onscreen-translator/helpers/translation.py
2024-11-01 15:44:12 +11:00

76 lines
2.6 KiB
Python

from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig
import torch, os
from dotenv import load_dotenv
load_dotenv()
# Select the compute device once at import time.
# TRANSLATION_USE_GPU opts *out* of the GPU: any case variant of
# "false"/"0"/"no" forces CPU (the original hand-enumerated list missed
# mixed-case spellings like "nO"; lower-casing covers them all).
_use_gpu = os.getenv('TRANSLATION_USE_GPU')
if _use_gpu is not None and _use_gpu.strip().lower() in ('false', '0', 'no'):
    device = torch.device("cpu")
else:
    # Default: use CUDA when present, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
### Batch translate a list of strings
# M2M100 model
def init_M2M():
    """Load the M2M100 checkpoint into the module-level tokenizer/model
    globals (fp16 weights, moved to the selected device, eval mode)."""
    global tokenizer, model
    checkpoint = "facebook/m2m100_418M"
    tokenizer = M2M100Tokenizer.from_pretrained(checkpoint, local_files_only=True)
    # .eval() returns the model itself, so loading, device placement and
    # eval-mode switching chain into one assignment.
    model = M2M100ForConditionalGeneration.from_pretrained(
        checkpoint,
        local_files_only=True,
        torch_dtype=torch.float16,
    ).to(device).eval()
def translate_M2M(text: list[str], from_lang: str = 'zh', target_lang: str = 'en') -> list[str]:
    """Batch-translate *text* with the M2M100 model loaded by init_M2M().

    Args:
        text: source-language strings (may be empty).
        from_lang: M2M100 source-language code (default 'zh').
        target_lang: M2M100 target-language code (default 'en').

    Returns:
        One translated string per input; [] for empty input.
    """
    if not text:  # idiomatic empty check (was `len(text) == 0`)
        return []
    # M2M100 is multilingual: the source language is set on the tokenizer,
    # the target language via the forced BOS token at generation time.
    tokenizer.src_lang = from_lang
    with torch.no_grad():
        encoded = tokenizer(text, return_tensors="pt", padding=True).to(device)
        generated_tokens = model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.get_lang_id(target_lang),
        )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# Helsinki-NLP model Opus MT
def init_OPUS():
    """Load the Helsinki-NLP Opus-MT (zhx->en) checkpoint into the
    module-level tokenizer/model globals (fp16, on device, eval mode)."""
    global tokenizer, model
    checkpoint = "Helsinki-NLP/opus-mt-tc-bible-big-zhx-en"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, local_files_only=True)
    # Chain device placement and eval-mode switch (.eval() returns self).
    model = AutoModelForSeq2SeqLM.from_pretrained(
        checkpoint,
        local_files_only=True,
        torch_dtype=torch.float16,
    ).to(device).eval()
def translate_OPUS(text: list[str]) -> list[str]:
    """Batch-translate Chinese *text* to English with the Opus-MT model
    loaded by init_OPUS().

    The checkpoint is a fixed zh->en pair, so no language arguments exist.

    Returns:
        One translated string per input; [] for empty input.
    """
    if not text:  # idiomatic empty check (was `len(text) == 0`)
        return []
    with torch.no_grad():
        encoded = tokenizer(text, return_tensors="pt", padding=True).to(device)
        generated_tokens = model.generate(**encoded)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
###
def init_TRANSLATE(model):  # model = 'opus' or 'm2m'
    """Initialise the global tokenizer/model pair for the chosen backend.

    Raises:
        ValueError: if *model* is neither 'opus' nor 'm2m'.
    """
    # Guard clause first, then dispatch to the matching loader.
    if model not in ('opus', 'm2m'):
        raise ValueError("Invalid model. Please use 'opus' or 'm2m'.")
    loader = init_OPUS if model == 'opus' else init_M2M
    loader()
###
def translate(text, model, **kwargs):
    """Dispatch a batch translation to the selected backend.

    Args:
        text: list of strings to translate.
        model: backend name, 'opus' or 'm2m'.
        **kwargs: forwarded to translate_M2M (from_lang / target_lang);
            the OPUS backend takes no extra arguments.

    Raises:
        ValueError: for an unknown model name, or for invalid keyword
            arguments to the M2M backend.
    """
    if model == 'opus':
        return translate_OPUS(text)
    elif model == 'm2m':
        try:
            return translate_M2M(text, **kwargs)
        # Bad/unknown keyword arguments surface as TypeError. The previous
        # bare `except:` also swallowed real failures (NameError, CUDA
        # errors, KeyboardInterrupt) and mislabelled them as a kwargs
        # problem; catch only TypeError and chain the original cause.
        except TypeError as e:
            raise ValueError("Please provide the from_lang and target_lang variables if you are using the M2M model.") from e
    else:
        raise ValueError("Invalid model. Please use 'opus' or 'm2m'.")