# onscreen-translator/helpers/translation.py

from transformers import (M2M100Tokenizer, M2M100ForConditionalGeneration,
                          AutoTokenizer, AutoModelForSeq2SeqLM,
                          GPTQConfig, AutoModelForCausalLM)
import google.generativeai as genai
import torch, os, sys, ast
from utils import standardize_lang
from functools import wraps
from batching import generate_text, Gemini
from logging_config import logger
from multiprocessing import Process, Event
# root dir
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models
##############################
# translation decorator: short-circuits empty input and guards every backend
# against uncaught exceptions
def translate(translation_func):
    @wraps(translation_func)
    def wrapper(text, *args, **kwargs):
        try:
            if len(text) == 0:
                return []
            return translation_func(text, *args, **kwargs)
        except Exception as e:
            logger.error(f"Translation error with the following function: {translation_func.__name__}. Text: {text}\nError: {e}")
            return []  # always hand callers a list[str], even on failure
    return wrapper
###############################
###############################
def init_GEMINI(models_and_rates=None):
    if not models_and_rates:
        # defaults for the free tier (requests per minute), ordered from most
        # to least preferred
        models_and_rates = {'gemini-1.5-pro': 2, 'gemini-1.5-flash': 15, 'gemini-1.5-flash-8b': 8, 'gemini-1.0-pro': 15}
    models = [Gemini(name, rate) for name, rate in models_and_rates.items()]
    for model in models:
        model.start()
    genai.configure(api_key=GEMINI_KEY)
    return models
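# Usage sketch: a custom quota map can be passed for a paid tier; the
# per-minute rates below are hypothetical:
# models = init_GEMINI({'gemini-1.5-pro': 60, 'gemini-1.5-flash': 120})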
def translate_GEMINI(text, models, from_lang, target_lang):
    safety_settings = {
        "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
        "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
        "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
    prompt = (f"Without any additional remarks, and without any code, translate the following items "
              f"of the Python list from {from_lang} into {target_lang} and output as a Python list, "
              f"ensuring proper escaping of characters: {text}")
    for model in models:
        # skip models that have exhausted their per-minute quota
        if model.curr_calls < model.rate:
            try:
                response = genai.GenerativeModel(model.name).generate_content(prompt,
                                                                              safety_settings=safety_settings)
                model.curr_calls += 1
                logger.info(repr(model))
                logger.info(f'Model Response: {response.text.strip()}')
                # the model is asked to answer with a Python list literal
                return ast.literal_eval(response.text.strip())
            except Exception as e:
                logger.error(f"Error with model {model.name}. Error: {e}")
    logger.error("No models available to translate. Please wait for a model to be available.")
###############################
# Best model by far: Aya-23-8B. Gemma is relatively good, and Llama 3.2 is strong as well. If I get the time to quantize either Gemma or Aya, those would be good to use.
def init_AYA():
    model_id = "CohereForAI/aya-23-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, local_files_only=True, torch_dtype=torch.float16).to(device)
    model.eval()
    return (model, tokenizer)
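# The quantization mentioned above could use the (currently unused) GPTQConfig
# import. A minimal sketch, assuming auto-gptq/optimum are installed; quantizing
# on load calibrates against the named dataset and is slow the first time:
def init_AYA_quantized():
    model_id = "CohereForAI/aya-23-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
    gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",
                                                 quantization_config=gptq_config)
    model.eval()
    return (model, tokenizer)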
##############################
# M2M100 model
def init_M2M():
    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=LOCAL_FILES_ONLY)
    model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=LOCAL_FILES_ONLY).to(device)
    model.eval()
    return (model, tokenizer)
def translate_M2M(text, model, tokenizer, from_lang='ch_sim', target_lang='en') -> list[str]:
    model_lang_from = standardize_lang(from_lang)['translation_model_lang']
    model_lang_to = standardize_lang(target_lang)['translation_model_lang']
    if len(text) == 0:
        return []
    tokenizer.src_lang = model_lang_from
    # force the target-language token as BOS so M2M100 decodes into target_lang
    generated_translations = generate_text(text, model, tokenizer, batch_size=BATCH_SIZE,
                                           max_length=MAX_INPUT_TOKENS, max_new_tokens=MAX_OUTPUT_TOKENS,
                                           forced_bos_token_id=tokenizer.get_lang_id(model_lang_to))
    return generated_translations
###############################
###############################
# Helsinki-NLP model Opus MT
# Refer here for all the models https://huggingface.co/Helsinki-NLP
def get_OPUS_model(from_lang, target_lang):
    model_lang_from = standardize_lang(from_lang)['translation_model_lang']
    model_lang_to = standardize_lang(target_lang)['translation_model_lang']
    return f"Helsinki-NLP/opus-mt-{model_lang_from}-{model_lang_to}"
def init_OPUS(from_lang='ch_sim', target_lang='en'):
    opus_model = get_OPUS_model(from_lang, target_lang)
    tokenizer = AutoTokenizer.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY)
    model = AutoModelForSeq2SeqLM.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY, torch_dtype=torch.float16).to(device)
    model.eval()
    return (model, tokenizer)
def translate_OPUS(text: list[str], model, tokenizer) -> list[str]:
    # argument order matched to the generate_text call in translate_M2M above;
    # the two calls previously disagreed on where `text` goes
    translated_text = generate_text(text, model, tokenizer,
                                    batch_size=BATCH_SIZE, device=device,
                                    max_length=MAX_INPUT_TOKENS, max_new_tokens=MAX_OUTPUT_TOKENS)
    return translated_text
###############################
def init_Seq_LLM(model_type, **kwargs):  # model_type = 'opus' or 'm2m'
    if model_type == 'opus':
        return init_OPUS(**kwargs)
    elif model_type == 'm2m':
        return init_M2M()
    else:
        raise ValueError(f"Invalid model. Please use {' or '.join(seq_llm_models)}.")
def init_API_LLM(model_type, **kwargs):  # model_type = 'gemini'
    if model_type == 'gemini':
        return init_GEMINI(**kwargs)
    else:
        raise ValueError(f"Invalid model type. Please use {' or '.join(api_llm_models)}.")
def init_Causal_LLM(model_type, **kwargs):  # model_type = 'aya'
    if model_type == 'aya':  # minimal dispatch, mirroring the initializers above
        return init_AYA()
    raise ValueError(f"Invalid model type. Please use {' or '.join(causal_llm_models)}.")
###
@translate
def translate_Seq_LLM(text,
                      model_type,  # 'opus' or 'm2m'
                      model,
                      tokenizer,
                      **kwargs):
    if model_type == 'opus':
        return translate_OPUS(text, model, tokenizer)
    elif model_type == 'm2m':
        try:
            return translate_M2M(text, model, tokenizer, **kwargs)
        except Exception as e:
            logger.error(f"Error with M2M model. Error: {e}")
            # raise ValueError(f"Please provide the correct from_lang and target_lang variables if you are using the M2M model. Use the list from {available_langs}.")
    else:
        raise ValueError(f"Invalid model. Please use {' or '.join(seq_llm_models)}.")
### to use any other translation backend, define a translate function that takes the input text and returns the translated text, as sketched below.
# def translate_api(text):
#@translate
#def translate_Causal_LLM(text, model_type, model)
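# A hedged sketch of the note above: any list[str] -> list[str] callable works
# once wrapped with @translate. The body is a placeholder, not a real translator.
@translate
def translate_custom(text: list[str]) -> list[str]:
    return list(text)  # echo the input; swap in a real API call here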
@translate
def translate_API_LLM(text: list[str],
                      model_type: str,  # 'gemini'
                      models: list,  # list of objects of classes defined in batching.py
                      from_lang: str,  # suggested to use ISO 639-1 codes
                      target_lang: str  # suggested to use ISO 639-1 codes
                      ) -> list[str]:
    if model_type == 'gemini':
        from_lang = standardize_lang(from_lang)['translation_model_lang']
        target_lang = standardize_lang(target_lang)['translation_model_lang']
        return translate_GEMINI(text, models, from_lang, target_lang)
    else:
        raise ValueError(f"Invalid model. Please use {' or '.join(api_llm_models)}.")
@translate
def translate_Causal_LLM(text: list[str],
                         model_type,  # 'aya'
                         model,
                         tokenizer,
                         from_lang: str,
                         target_lang: str) -> list[str]:
    model_lang_from = standardize_lang(from_lang)['translation_model_lang']
    model_lang_to = standardize_lang(target_lang)['translation_model_lang']
    if len(text) == 0:
        return []
    # sketch of the missing body: prompt the chat model like the Gemini backend
    # and parse the returned list literal (assumes a chat template, as Aya-23 ships)
    if model_type == 'aya':
        prompt = (f"Without any additional remarks, translate the following items of the Python "
                  f"list from {model_lang_from} into {model_lang_to} and output as a Python list: {text}")
        input_ids = tokenizer.apply_chat_template([{"role": "user", "content": prompt}],
                                                  add_generation_prompt=True, return_tensors="pt").to(device)
        with torch.no_grad():
            output = model.generate(input_ids, max_new_tokens=MAX_OUTPUT_TOKENS)
        decoded = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
        return ast.literal_eval(decoded.strip())
    else:
        raise ValueError(f"Invalid model. Please use {' or '.join(causal_llm_models)}.")
# choose between a local Seq2Seq LLM, a causal LLM, or translations from an API
def init_func(model):
    if model in seq_llm_models:
        return init_Seq_LLM
    elif model in api_llm_models:
        return init_API_LLM
    elif model in causal_llm_models:
        return init_Causal_LLM
    else:
        raise ValueError("Invalid model category. Please use a model from the 'seq', 'api', or 'causal' lists.")
def translate_func(model):
    if model in seq_llm_models:
        return translate_Seq_LLM
    elif model in api_llm_models:
        return translate_API_LLM
    elif model in causal_llm_models:
        return translate_Causal_LLM
    else:
        raise ValueError("Invalid model category. Please use a model from the 'seq', 'api', or 'causal' lists.")
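# Usage sketch for the two dispatchers above ('m2m' is assumed to be listed
# in seq_llm_models in config.py):
# model_name = 'm2m'
# model, tokenizer = init_func(model_name)(model_name)
# translations = translate_func(model_name)(text, model_name, model, tokenizer,
#                                           from_lang='ch_sim', target_lang='en')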
### todo: if CUDA is not detected, default to online translation since CPU inference is far too slow. Parallelize requests over multiple services to make it faster.
if __name__ == "__main__":
    models = init_GEMINI()
    print(translate_API_LLM(['想要借用合成台。', '有什么卖的?', '声音太大会有什么影响吗?', '不怕丢东西吗?', '再见。', '布纳马', '想买什么自己拿。把钱留在旁边就好。', '回顾', '隐藏'], 'gemini', models, from_lang='ch_sim', target_lang='en'))
    # model, tokenizer = init_M2M()
    # print(translate_Seq_LLM(['想要借用合成台。', '有什么卖的?', '声音太大会有什么影响吗?', '不怕丢东西吗?', '再见。', '布纳马', '想买什么自己拿。把钱留在旁边就好。', '回顾', '隐藏'], 'm2m', model, tokenizer, from_lang='ch_sim', target_lang='en'))