from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig, AutoModelForCausalLM
import google.generativeai as genai
import torch, os, sys, ast, json, asyncio, batching, random
from typing import List, Optional, Set
from utils import standardize_lang
from functools import wraps
from batching import generate_text, Gemini, Groq, ApiModel
from logging_config import logger
from asyncio import Task

# root dir
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_API_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models, API_MODELS_FILEPATH


##############################
# translation decorator
def translate(translation_func):
    @wraps(translation_func)
    def wrapper(text, *args, **kwargs):
        try:
            if len(text) == 0:
                return []
            return translation_func(text, *args, **kwargs)
        except Exception as e:
            logger.error(f"Translation error with the following function: {translation_func.__name__}. Text: {text}\nError: {e}")
    return wrapper
###############################


###############################
def init_API_LLM(from_lang, target_lang):
    """Initialise the API models listed in the JSON file at API_MODELS_FILEPATH: each model is
    instantiated, its rate-limit entry in the database is updated, and its language pair is set."""
    from_lang = standardize_lang(from_lang)['translation_model_lang']
    target_lang = standardize_lang(target_lang)['translation_model_lang']
    with open(API_MODELS_FILEPATH, 'r') as f:
        models_and_rates = json.load(f)
    models = []
    for class_type, class_models in models_and_rates.items():
        cls = getattr(batching, class_type)
        instantiated_objects = [cls(model=model, **rates) for model, rates in class_models.items()]
        models.extend(instantiated_objects)
    for model in models:
        model.update_db()
        model.set_lang(from_lang, target_lang)
    return models
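
# A minimal sketch of the structure init_API_LLM assumes for the JSON file at
# API_MODELS_FILEPATH: top-level keys must be class names defined in batching
# (e.g. Gemini or Groq, imported above), and each model name maps to the keyword
# arguments its class accepts as rate limits. The field names below ("rpm", "rpd")
# and the model names are illustrative assumptions, not the actual schema.
#
# {
#     "Gemini": {"gemini-1.5-flash": {"rpm": 15, "rpd": 1500}},
#     "Groq": {"llama-3.1-8b-instant": {"rpm": 30, "rpd": 14400}}
# }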


async def translate_API_LLM(texts_to_translate: List[str],
                            models: List[ApiModel],
                            call_size: int = 2) -> List[str]:
    """Translate the texts using the models, `call_size` at a time. If a model fails to
    translate the text, the next model in the list is tried."""
    async def try_translate(model: ApiModel) -> Optional[tuple]:
        result = await model.translate(texts_to_translate, store=True)
        logger.debug(f'Try_translate result: {result}')
        return result

    random.shuffle(models)
    groups = [models[i:i+call_size] for i in range(0, len(models), call_size)]
    no_of_models = len(models)
    translation_attempts = 0

    best_translation = None  # (translations, translation_mismatches)

    for group in groups:
        tasks = set(asyncio.create_task(try_translate(model)) for model in group)
        while tasks:
            done, pending = await asyncio.wait(tasks,
                                               return_when=asyncio.FIRST_COMPLETED
                                               )
            logger.debug(f"Tasks done: {done}")
            logger.debug(f"Tasks remaining: {pending}")
            for task in done:
                # Always drop finished tasks so the while-loop cannot spin on them forever.
                tasks.discard(task)
                result = await task
                logger.debug(f'Result: {result}')
                if result is not None:
                    translation_attempts += 1
                    status_code, translations, translation_mismatches = result
                    if status_code == 0:
                        # Cancel remaining tasks
                        for t in pending:
                            t.cancel()
                        return translations
                    else:
                        logger.error(f"Model has failed to translate the text. Result: {result}")
                        if translation_attempts == no_of_models:
                            if best_translation is not None:
                                return best_translation[0]
                            else:
                                logger.error("All models have failed to translate the text.")
                                raise TypeError("Models have likely all produced garbage translations or been rate limited.")
                        elif status_code == 2:
                            if best_translation is None or len(translation_mismatches) < len(best_translation[1]):
                                best_translation = (translations, translation_mismatches)
                        else:
                            continue

    # If every group is exhausted without a clean result, fall back to the best partial translation.
    if best_translation is not None:
        return best_translation[0]
    raise TypeError("Models have likely all produced garbage translations or been rate limited.")


###############################
# Best model by far: Aya-23-8B. Gemma is relatively good; if I get the time to quantize either Gemma or Aya, those will be good to use. llama3.2 is really good as well.
def init_AYA():
    model_id = "CohereForAI/aya-23-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, local_files_only=True, torch_dtype=torch.float16).to(device)
    model.eval()
    return (model, tokenizer)


##############################
# M2M100 model

def init_M2M():
    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=LOCAL_FILES_ONLY)
    model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=LOCAL_FILES_ONLY).to(device)
    model.eval()
    return (model, tokenizer)


def translate_M2M(text, model, tokenizer, from_lang='ch_sim', target_lang='en') -> list[str]:
    model_lang_from = standardize_lang(from_lang)['translation_model_lang']
    model_lang_to = standardize_lang(target_lang)['translation_model_lang']
    if len(text) == 0:
        return []
    tokenizer.src_lang = model_lang_from
    generated_translations = generate_text(text, model, tokenizer, batch_size=BATCH_SIZE,
                                           max_length=MAX_INPUT_TOKENS, max_new_tokens=MAX_OUTPUT_TOKENS,
                                           forced_bos_token_id=tokenizer.get_lang_id(model_lang_to))
    return generated_translations
###############################


###############################
# Helsinki-NLP model Opus MT
# Refer here for all the models https://huggingface.co/Helsinki-NLP
def get_OPUS_model(from_lang, target_lang):
    model_lang_from = standardize_lang(from_lang)['translation_model_lang']
    model_lang_to = standardize_lang(target_lang)['translation_model_lang']
    return f"Helsinki-NLP/opus-mt-{model_lang_from}-{model_lang_to}"


def init_OPUS(from_lang='ch_sim', target_lang='en'):
    opus_model = get_OPUS_model(from_lang, target_lang)
    tokenizer = AutoTokenizer.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY)
    model = AutoModelForSeq2SeqLM.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY, torch_dtype=torch.float16).to(device)
    model.eval()
    return (model, tokenizer)


def translate_OPUS(text: list[str], model, tokenizer) -> list[str]:
    translated_text = generate_text(model, tokenizer, text,
                                    batch_size=BATCH_SIZE, device=device,
                                    max_length=MAX_INPUT_TOKENS, max_new_tokens=MAX_OUTPUT_TOKENS)
    return translated_text
###############################


def init_Seq_LLM(model_type, **kwargs):  # model = 'opus' or 'm2m'
    if model_type == 'opus':
        return init_OPUS(**kwargs)
    elif model_type == 'm2m':
        return init_M2M()
    else:
        raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.")


def init_Causal_LLM(model_type, **kwargs):
    # 'aya' is the only causal model wired up so far (see init_AYA above).
    if model_type == 'aya':
        return init_AYA()
    raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.")

###
@translate
def translate_Seq_LLM(text,
                      model_type,  # 'opus' or 'm2m'
                      model,
                      tokenizer,
                      **kwargs):
    if model_type == 'opus':
        return translate_OPUS(text, model, tokenizer)
    elif model_type == 'm2m':
        try:
            return translate_M2M(text, model, tokenizer, **kwargs)
        except Exception as e:
            logger.error(f"Error with M2M model. Error: {e}")
            # raise ValueError(f"Please provide the correct from_lang and target_lang variables if you are using the M2M model. Use the list from {available_langs}.")
    else:
        raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.")


### To use any other translation backend, just define a translate function that takes the input text and returns the translated text.
@translate
def translate_Causal_LLM(text: list[str],
                         model_type,  # aya
                         model,
                         tokenizer,
                         from_lang: str,
                         target_lang: str) -> list[str]:
    model_lang_from = standardize_lang(from_lang)['translation_model_lang']
    model_lang_to = standardize_lang(target_lang)['translation_model_lang']
    if len(text) == 0:
        return []
    pass
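
# A minimal sketch of how translate_Causal_LLM could finish the job with the Aya model from
# init_AYA, using the Hugging Face chat-template API. The prompt wording, the helper name
# _translate_causal_sketch and the one-prompt-per-sentence loop (no batching) are assumptions
# for illustration, not the project's actual implementation.
def _translate_causal_sketch(text: list[str], model, tokenizer, model_lang_from: str, model_lang_to: str) -> list[str]:
    translations = []
    for sentence in text:
        # Build a single-turn chat prompt asking for the translation only.
        messages = [{"role": "user",
                     "content": f"Translate the following text from {model_lang_from} to {model_lang_to}. "
                                f"Reply with the translation only.\n\n{sentence}"}]
        input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True,
                                                  return_tensors="pt").to(model.device)
        with torch.no_grad():
            output_ids = model.generate(input_ids, max_new_tokens=MAX_OUTPUT_TOKENS, do_sample=False)
        # Drop the prompt tokens and decode only the newly generated continuation.
        translations.append(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True).strip())
    return translations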


# choose between a local Seq2Seq LLM, a local causal LLM, or translations obtained from an API
def init_func(model):
    if model in seq_llm_models:
        return init_Seq_LLM
    elif model in api_llm_models:
        return init_API_LLM
    elif model in causal_llm_models:
        return init_Causal_LLM
    else:
        raise ValueError("Invalid model category. Please use a model from seq_llm_models, api_llm_models or causal_llm_models.")


def translate_func(model):
    if model in seq_llm_models:
        return translate_Seq_LLM
    elif model in api_llm_models:
        return translate_API_LLM
    elif model in causal_llm_models:
        return translate_Causal_LLM
    else:
        raise ValueError("Invalid model category. Please use a model from seq_llm_models, api_llm_models or causal_llm_models.")


### todo: if CUDA is not detected, default to online translation, as CPU just won't cut it. Parallel-process it over multiple websites to make it faster.
if __name__ == "__main__":
    models = init_API_LLM('ja', 'en')
    # translate_API_LLM is a coroutine, so it has to be driven by an event loop.
    print(asyncio.run(translate_API_LLM(['こんにちは'], models)))