onscreen-translator/helpers/translation.py

235 lines
10 KiB
Python

from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig, AutoModelForCausalLM
import google.generativeai as genai
import torch, os, sys, ast, json, asyncio, batching, random
from typing import List, Optional, Set
from utils import standardize_lang
from functools import wraps
from batching import generate_text, Gemini, Groq, ApiModel
from logging_config import logger
from asyncio import Task
# root dir
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_API_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models, API_MODELS_FILEPATH
##############################
# translation decorator
def translate(translation_func):
    """Wrap a translation function with empty-input and error handling.

    The wrapped function:
    * returns ``[]`` immediately when ``text`` is empty (``len(text) == 0``),
      without calling ``translation_func``;
    * otherwise returns whatever ``translation_func`` returns;
    * if anything raises (including ``len(text)`` itself), logs the error
      together with its traceback and returns ``None`` so that one failed
      translation does not crash the caller.
    """
    @wraps(translation_func)
    def wrapper(text, *args, **kwargs):
        try:
            if len(text) == 0:
                return []
            return translation_func(text, *args, **kwargs)
        except Exception as e:
            # logger.exception (unlike logger.error) also records the traceback,
            # which is needed to find where inside the model call things broke.
            logger.exception(f"Translation error with the following function: {translation_func.__name__}. Text: {text}\nError: {e}")
            return None
    return wrapper
###############################
###############################
def init_API_LLM(from_lang, target_lang):
    """Instantiate every API model listed in the API models JSON file.

    The JSON file at ``API_MODELS_FILEPATH`` maps a class name found in the
    ``batching`` module to a mapping ``{model_name: kwargs}``. Each model is
    built as ``cls(model=model_name, **kwargs)`` (the kwargs are presumably the
    API rate settings — confirm against the JSON file). Every model then has
    ``update_db()`` called and its languages set with ``set_lang``.

    Args:
        from_lang: source language; normalised with ``standardize_lang``.
        target_lang: target language; normalised with ``standardize_lang``.

    Returns:
        list of the instantiated API model objects.
    """
    from_lang = standardize_lang(from_lang)['translation_model_lang']
    target_lang = standardize_lang(target_lang)['translation_model_lang']
    # JSON is UTF-8 by specification; don't depend on the platform's locale
    # encoding (e.g. cp1252 on Windows) when reading it.
    with open(API_MODELS_FILEPATH, 'r', encoding='utf-8') as f:
        models_and_rates = json.load(f)
    models = []
    for class_type, class_models in models_and_rates.items():
        cls = getattr(batching, class_type)
        models.extend(cls(model=model, **rates) for model, rates in class_models.items())
    for model in models:
        model.update_db()
        model.set_lang(from_lang, target_lang)
    return models
async def translate_API_LLM(texts_to_translate: List[str],
                            models: List[ApiModel],
                            call_size: int = 2) -> List[str]:
    """Translate texts with the API models, ``call_size`` models concurrently.

    The models are tried in a random order (the caller's list is not
    modified), in groups of ``call_size`` requests running concurrently.
    Each ``model.translate(...)`` result is either ``None`` or a tuple
    ``(status_code, translations, translation_mismatches)``:

    * status 0 — the translations are returned immediately and the group's
      other in-flight requests are cancelled;
    * status 2 — kept as a fallback candidate; among several, the one with the
      fewest ``translation_mismatches`` wins;
    * any other status, a ``None`` result, or a raised exception — that model
      is skipped and the next one is tried.

    Returns:
        The first status-0 translation, otherwise the best status-2 fallback.
        ``[]`` for empty input (no API call is made).

    Raises:
        TypeError: if no model produced a usable translation (type kept for
            compatibility with existing callers).
    """
    if not texts_to_translate:
        return []

    async def try_translate(model: ApiModel) -> Optional[List[str]]:
        result = await model.translate(texts_to_translate, store=True)
        logger.debug(f'Try_translate result: {result}')
        return result

    # Work on a shuffled copy so the caller's model list is left untouched.
    shuffled = random.sample(models, k=len(models))
    best_translation = None  # (translations, translation_mismatches)
    for start in range(0, len(shuffled), call_size):
        pending = {asyncio.create_task(try_translate(model))
                   for model in shuffled[start:start + call_size]}
        try:
            while pending:
                # Rebinding `pending` drops finished tasks, so a task that
                # returned None can never be picked up again (the original
                # version looped forever on it).
                done, pending = await asyncio.wait(pending,
                                                   return_when=asyncio.FIRST_COMPLETED)
                logger.debug(f"Tasks done: {done}")
                logger.debug(f"Tasks remaining: {pending}")
                for task in done:
                    try:
                        result = task.result()
                    except Exception as e:
                        # One failing API must not abort the whole fallback chain.
                        logger.error(f"Model has failed to translate the text. Error: {e}")
                        continue
                    logger.debug(f'Result: {result}')
                    if result is None:
                        continue
                    status_code, translations, translation_mismatches = result
                    if status_code == 0:
                        return translations
                    logger.error(f"Model has failed to translate the text. Result: {result}")
                    if status_code == 2 and (best_translation is None
                                             or len(translation_mismatches) < len(best_translation[1])):
                        best_translation = (translations, translation_mismatches)
        finally:
            # Cancel whatever is still running (early return, error or cancellation).
            for task in pending:
                task.cancel()
    if best_translation is not None:
        return best_translation[0]
    logger.error("All models have failed to translate the text.")
    raise TypeError("Models have likely all outputted garbage translations or rate limited.")
###############################
# Aya-23-8B is the best model by far; Gemma is relatively good and Llama 3.2 is also very good. Once Gemma or Aya is quantized, either would be a good choice here.
def init_AYA():
    """Load the Aya-23-8B causal LM and its tokenizer from local files only.

    Returns:
        (model, tokenizer): the model is loaded in float16, moved to
        ``device`` and put in eval mode.
    """
    model_id = "CohereForAI/aya-23-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
    # The keyword was previously misspelled as ``locals_files_only``, so the
    # local-only flag was never applied to the model load.
    model = AutoModelForCausalLM.from_pretrained(model_id, local_files_only=True, torch_dtype=torch.float16).to(device)
    model.eval()
    return (model, tokenizer)
##############################
# M2M100 model
def init_M2M():
    """Load the multilingual facebook/m2m100_418M model and its tokenizer.

    Returns:
        (model, tokenizer) with the model moved to ``device`` and in eval mode.
    """
    checkpoint = "facebook/m2m100_418M"
    m2m_tokenizer = M2M100Tokenizer.from_pretrained(checkpoint, local_files_only=LOCAL_FILES_ONLY)
    m2m_model = M2M100ForConditionalGeneration.from_pretrained(checkpoint, local_files_only=LOCAL_FILES_ONLY)
    m2m_model = m2m_model.to(device)
    m2m_model.eval()
    return m2m_model, m2m_tokenizer
def translate_M2M(text, model, tokenizer, from_lang = 'ch_sim', target_lang = 'en') -> list[str]:
    """Translate a batch of texts with an M2M100 model.

    Args:
        text: texts to translate.
        model, tokenizer: as returned by ``init_M2M``.
        from_lang, target_lang: language codes, normalised with ``standardize_lang``.

    Returns:
        The translations produced by ``generate_text``, or ``[]`` for empty input.
    """
    src_code = standardize_lang(from_lang)['translation_model_lang']
    tgt_code = standardize_lang(target_lang)['translation_model_lang']
    if not len(text):
        return []
    # M2M100 reads the source language from the tokenizer; the target
    # language is selected by forcing its language token as the BOS token.
    tokenizer.src_lang = src_code
    target_bos = tokenizer.get_lang_id(tgt_code)
    return generate_text(
        text, model, tokenizer,
        batch_size=BATCH_SIZE,
        max_length=MAX_INPUT_TOKENS,
        max_new_tokens=MAX_OUTPUT_TOKENS,
        forced_bos_token_id=target_bos,
    )
###############################
###############################
# Helsinki-NLP model Opus MT
# Refer here for all the models https://huggingface.co/Helsinki-NLP
def get_OPUS_model(from_lang, target_lang):
    """Return the Hugging Face repo id of the Opus-MT model for a language pair.

    Both languages are normalised with ``standardize_lang`` first.
    """
    src, tgt = (standardize_lang(lang)['translation_model_lang']
                for lang in (from_lang, target_lang))
    return "Helsinki-NLP/opus-mt-{}-{}".format(src, tgt)
def init_OPUS(from_lang = 'ch_sim', target_lang = 'en'):
    """Load the Helsinki-NLP Opus-MT model and tokenizer for a language pair.

    Args:
        from_lang: source language (resolved by ``get_OPUS_model``).
        target_lang: target language (resolved by ``get_OPUS_model``).

    Returns:
        (model, tokenizer) with the model moved to ``device`` and in eval mode.
    """
    opus_model = get_OPUS_model(from_lang, target_lang)
    logger.debug(f"OPUS model: {opus_model}")
    # float16 halves accelerator memory, but half-precision inference on CPU
    # is slow and some ops lack CPU Half kernels in older PyTorch releases,
    # so fall back to float32 when running on the CPU.
    dtype = torch.float32 if torch.device(device).type == 'cpu' else torch.float16
    tokenizer = AutoTokenizer.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY)
    model = AutoModelForSeq2SeqLM.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY, torch_dtype=dtype).to(device)
    model.eval()
    return (model, tokenizer)
def translate_OPUS(text: list[str], model, tokenizer) -> list[str]:
    """Translate a batch of texts with an Opus-MT model/tokenizer pair.

    Generation limits come from BATCH_SIZE, MAX_INPUT_TOKENS and
    MAX_OUTPUT_TOKENS; the result of ``generate_text`` is logged and returned.
    """
    generation_kwargs = dict(
        batch_size=BATCH_SIZE,
        device=device,
        max_length=MAX_INPUT_TOKENS,
        max_new_tokens=MAX_OUTPUT_TOKENS,
    )
    translated_text = generate_text(text, model, tokenizer, **generation_kwargs)
    logger.debug(f"Translated text: {translated_text}")
    return translated_text
###############################
def init_Seq_LLM(model_type, **kwargs): # model = 'opus' or 'm2m'
    """Load a local sequence-to-sequence translation model.

    Args:
        model_type: 'opus' or 'm2m'.
        **kwargs: forwarded to ``init_OPUS`` (e.g. from_lang / target_lang);
            ignored for 'm2m'.

    Returns:
        (model, tokenizer).

    Raises:
        ValueError: for any other ``model_type``.
    """
    if model_type == 'm2m':
        return init_M2M()
    if model_type == 'opus':
        return init_OPUS(**kwargs)
    raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.")
def init_Causal_LLM(model_type, **kwargs):
    """Placeholder initialiser for causal-LLM translators.

    Not implemented yet: all arguments are ignored and ``None`` is returned.
    """
    return None
###
@translate
def translate_Seq_LLM(text,
                      model_type, # 'opus' or 'm2m'
                      model,
                      tokenizer,
                      **kwargs):
    """Normalise casing, then translate with an Opus-MT or M2M100 model.

    Each input string is lower-cased and then capitalised before translation.
    ``kwargs`` (e.g. from_lang / target_lang) are passed only to the M2M path.
    An error in the M2M path is logged here and yields ``None``. An unknown
    ``model_type`` raises ValueError, which the ``translate`` decorator catches.
    """
    normalised = list(map(lambda s: s.lower().capitalize(), text))
    if model_type == 'opus':
        return translate_OPUS(normalised, model, tokenizer)
    if model_type != 'm2m':
        raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.")
    try:
        return translate_M2M(normalised, model, tokenizer, **kwargs)
    except Exception as e:
        logger.error(f"Error with M2M model. Error: {e}")
        return None
### To use another translation backend, define a translate function that takes the input texts and returns the translated texts.
#def translate_Causal_LLM(text, model_type, model)
@translate
def translate_Causal_LLM(text: list[str],
                         model_type, # aya
                         model,
                         tokenizer,
                         from_lang: str,
                         target_lang: str) -> list[str]:
    """Placeholder for causal-LLM (e.g. Aya) translation — not implemented yet.

    Both language codes are normalised with ``standardize_lang`` (the results
    are currently unused). Empty input yields ``[]``; any other input currently
    yields ``None``.
    """
    _src, _tgt = (standardize_lang(code)['translation_model_lang']
                  for code in (from_lang, target_lang))
    if len(text) == 0:
        return []
    return None
# Choose between a local Seq2Seq LLM, a local causal LLM, or translations obtained from an API.
def _select_by_category(model, seq_choice, api_choice, causal_choice):
    """Return the choice for whichever model list (seq / API / causal) contains ``model``.

    Raises:
        ValueError: if ``model`` is in none of the configured model lists.
    """
    for category_models, choice in ((seq_llm_models, seq_choice),
                                    (api_llm_models, api_choice),
                                    (causal_llm_models, causal_choice)):
        if model in category_models:
            return choice
    # The previous message talked about 'seq'/'api' categories, but callers
    # pass a model name and there are three categories.
    raise ValueError(
        f"Invalid model {model!r}. Please use a model listed in "
        "seq_llm_models, api_llm_models or causal_llm_models."
    )
def init_func(model):
    """Return the initialiser function for ``model``'s category.

    Raises:
        ValueError: if ``model`` is not a configured model.
    """
    return _select_by_category(model, init_Seq_LLM, init_API_LLM, init_Causal_LLM)
def translate_func(model):
    """Return the translate function for ``model``'s category.

    Raises:
        ValueError: if ``model`` is not a configured model.
    """
    return _select_by_category(model, translate_Seq_LLM, translate_API_LLM, translate_Causal_LLM)
### TODO: if CUDA is not detected, default to online translation, since CPU inference is too slow. Spread the requests over multiple services in parallel to make it faster.
if __name__ == "__main__":
    # Manual smoke test: translate a Japanese greeting to English via the API models.
    models = init_API_LLM('ja', 'en')
    # translate_API_LLM is a coroutine function; without an event loop the
    # call only creates (and prints) an un-awaited coroutine object.
    print(asyncio.run(translate_API_LLM(['こんにちは'], models)))