# Translation utilities: M2M100 and Opus-MT batch-translation wrappers.
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig

import torch, os

from dotenv import load_dotenv

# Pull configuration (TRANSLATION_USE_GPU) from a local .env file.
load_dotenv()

# Device selection: any common "falsy" spelling of TRANSLATION_USE_GPU forces
# CPU; anything else (including an unset variable -> getenv() returns None)
# falls through to CUDA auto-detection. Normalizing with strip()/lower()
# covers every casing ('False', 'FALSE', 'no', ...) the old literal list
# enumerated, plus variants like ' false ' it missed.
_use_gpu_flag = (os.getenv('TRANSLATION_USE_GPU') or '').strip().lower()
if _use_gpu_flag in ('false', '0', 'no'):
    device = torch.device("cpu")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
### Batch translate a list of strings

# M2M100 model
def init_M2M():
    """Load the facebook/m2m100_418M checkpoint (tokenizer + fp16 model).

    Populates the module-level ``tokenizer`` and ``model`` globals used by
    translate_M2M(). Weights must already be present in the local HF cache
    (``local_files_only=True``); the model is moved to ``device`` and put
    into eval mode.
    """
    global tokenizer, model
    checkpoint = "facebook/m2m100_418M"
    model = M2M100ForConditionalGeneration.from_pretrained(
        checkpoint,
        local_files_only=True,
        torch_dtype=torch.float16,
    ).to(device)
    model.eval()
    tokenizer = M2M100Tokenizer.from_pretrained(checkpoint, local_files_only=True)
def translate_M2M(text, from_lang = 'zh', target_lang = 'en'):
    """Batch-translate ``text`` with the M2M100 model.

    Args:
        text: list of source-language strings; an empty list short-circuits
            to an empty result without touching the model.
        from_lang: source language code (default 'zh').
        target_lang: target language code (default 'en').

    Returns:
        list[str]: one translation per input string.

    Requires init_M2M() to have been called first — reads the module-level
    ``tokenizer`` and ``model`` globals.
    """
    if not text:  # idiomatic empty check (was len(text) == 0)
        return []
    tokenizer.src_lang = from_lang
    with torch.no_grad():
        encoded = tokenizer(text, return_tensors="pt", padding=True).to(device)
        # Force the first generated token to the target-language tag so the
        # multilingual model decodes into target_lang.
        generated_tokens = model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.get_lang_id(target_lang),
        )
    # Decoding needs no autograd context; done outside the no_grad block.
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# Helsinki-NLP model Opus MT
def init_OPUS():
    """Load the Helsinki-NLP Opus-MT zhx->en checkpoint (tokenizer + fp16 model).

    Populates the module-level ``tokenizer`` and ``model`` globals used by
    translate_OPUS(). Weights must already be present in the local HF cache
    (``local_files_only=True``); the model is moved to ``device`` and put
    into eval mode.
    """
    global tokenizer, model
    checkpoint = "Helsinki-NLP/opus-mt-tc-bible-big-zhx-en"
    model = AutoModelForSeq2SeqLM.from_pretrained(
        checkpoint,
        local_files_only=True,
        torch_dtype=torch.float16,
    ).to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, local_files_only=True)
def translate_OPUS(text: list[str]) -> list[str]:
    """Batch-translate Chinese strings to English with the Opus-MT model.

    Args:
        text: list of source strings; an empty list short-circuits to an
            empty result without touching the model.

    Returns:
        One translation per input string.

    Requires init_OPUS() to have been called first — reads the module-level
    ``tokenizer`` and ``model`` globals.
    """
    if not text:  # idiomatic empty check (was len(text) == 0)
        return []
    with torch.no_grad():
        encoded = tokenizer(text, return_tensors="pt", padding=True).to(device)
        generated_tokens = model.generate(**encoded)
    # Decoding needs no autograd context; done outside the no_grad block.
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
###
def init_TRANSLATE(model): # model = 'opus' or 'm2m'
    """Initialise the requested translation backend.

    Args:
        model: backend selector, either 'opus' or 'm2m'.

    Raises:
        ValueError: if ``model`` is neither 'opus' nor 'm2m'.
    """
    if model == 'opus':
        init_OPUS()
        return
    if model == 'm2m':
        init_M2M()
        return
    raise ValueError("Invalid model. Please use 'opus' or 'm2m'.")
###
def translate(text, model, **kwargs):
    """Translate ``text`` using the backend selected by ``model``.

    Args:
        text: list of source strings.
        model: 'opus' or 'm2m'.
        **kwargs: forwarded to translate_M2M() (``from_lang`` /
            ``target_lang``); ignored for the Opus backend.

    Returns:
        list[str]: translated strings.

    Raises:
        ValueError: if ``model`` is not 'opus'/'m2m', or if unsupported
            keyword arguments are passed to the M2M backend.
    """
    if model == 'opus':
        return translate_OPUS(text)
    if model == 'm2m':
        try:
            return translate_M2M(text, **kwargs)
        except TypeError as err:
            # An unknown/misspelled keyword argument surfaces as TypeError.
            # (Previously a bare ``except:`` swallowed *every* error here —
            # including CUDA OOM and uninitialized-model NameError — and
            # misreported it as a kwargs problem. Other errors now propagate.)
            raise ValueError("Please provide the from_lang and target_lang variables if you are using the M2M model.") from err
    raise ValueError("Invalid model. Please use 'opus' or 'm2m'.")