# onscreen-translator/helpers/batching.py
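"""Batched translation helpers: rate-limited Groq/Gemini API wrappers that log
every call to the local database, plus a Dataset/DataLoader pipeline for
batched generation with a local HuggingFace model."""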


from torch.utils.data import Dataset, DataLoader
from typing import Dict, List, Optional
from datetime import datetime, timedelta
from dotenv import load_dotenv
import os, sys, torch, time, ast, json, pytz
from werkzeug.exceptions import TooManyRequests
from multiprocessing import Process, Event, Value
load_dotenv()
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from config import device, GEMINI_API_KEY, GROQ_API_KEY, MAX_TRANSLATE
from logging_config import logger
from groq import Groq as Groqq
import google.generativeai as genai
from google.api_core.exceptions import ResourceExhausted
import asyncio
import aiohttp
from functools import wraps
from data import session, Api, Translations


class ApiModel:
    """Base wrapper for a rate-limited translation API model; logs calls to the local DB."""
    def __init__(self, model,                    # model name as defined by the API
                 site,                           # provider name as stored in the DB, e.g. 'Groq' or 'Google'
                 api_key: Optional[str] = None,  # API key for the model on this site
                 rpmin: Optional[int] = None,    # maximum calls per minute
                 rph: Optional[int] = None,      # maximum calls per hour
                 rpd: Optional[int] = None,      # maximum calls per day
                 rpw: Optional[int] = None,      # maximum calls per week
                 rpmth: Optional[int] = None,    # maximum calls per month
                 rpy: Optional[int] = None       # maximum calls per year
                 ):
        self.api_key = api_key
        self.model = model
        self.rpmin = rpmin
        self.rph = rph
        self.rpd = rpd
        self.rpw = rpw
        self.rpmth = rpmth
        self.rpy = rpy
        self.site = site
        self.from_lang = None
        self.target_lang = None
        self.db_table = None
        self.session_calls = 0
        self._id = None
        self.session = None  # aiohttp session, created in __aenter__
        # Register the model in the DB if it is not already there
        if self._get_db_model_id():
            self._set_db_model_id()
        else:
            self.update_db()

    def __repr__(self):
        return (f'{self.site} Model: {self.model}; Total calls this session: {self.session_calls}; '
                f'rpmin: {self.rpmin}; rph: {self.rph}; rpd: {self.rpd}; rpw: {self.rpw}; '
                f'rpmth: {self.rpmth}; rpy: {self.rpy}')

    def __str__(self):
        return self.model

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()

    def _get_db_model_id(self):
        model = session.query(Api).filter_by(model_name=self.model, site=self.site).first()
        return model.id if model else None

    def _set_db_model_id(self):
        self._id = self._get_db_model_id()

    @staticmethod
    def _get_time():
        return datetime.now(tz=pytz.timezone('Australia/Sydney'))

    def set_lang(self, from_lang, target_lang):
        self.from_lang = from_lang
        self.target_lang = target_lang

    def set_db_table(self, db_table):
        self.db_table = db_table

    def update_db(self):
        api = session.query(Api).filter_by(model_name=self.model, site=self.site).first()
        if not api:
            api = Api(model_name=self.model,
                      site=self.site,
                      rpmin=self.rpmin,
                      rph=self.rph,
                      rpd=self.rpd,
                      rpw=self.rpw,
                      rpmth=self.rpmth,
                      rpy=self.rpy)
            session.add(api)
            session.commit()
            self._set_db_model_id()
        else:
            api.rpmin = self.rpmin
            api.rph = self.rph
            api.rpd = self.rpd
            api.rpw = self.rpw
            api.rpmth = self.rpmth
            api.rpy = self.rpy
            session.commit()

    def _db_add_translation(self, text: list | str, translation: list, mismatch=False):
        source_json = json.dumps(text) if isinstance(text, list) else json.dumps([text])
        record = Translations(source_texts=source_json, translated_texts=json.dumps(translation),
                              model_id=self._id, source_lang=self.from_lang, target_lang=self.target_lang,
                              timestamp=self._get_time(),
                              translation_mismatch=mismatch)
        session.add(record)
        session.commit()

    @staticmethod
    def _single_period_calls_check(max_calls, call_count):
        # A period with no configured limit always passes
        if not max_calls:
            return True
        return call_count < max_calls

    def _are_rates_good(self):
        curr_time = self._get_time()
        # (window name, configured limit, lookback period)
        windows = [
            ('minute', self.rpmin, timedelta(minutes=1)),
            ('hour', self.rph, timedelta(hours=1)),
            ('day', self.rpd, timedelta(days=1)),
            ('week', self.rpw, timedelta(weeks=1)),
            ('month', self.rpmth, timedelta(days=30)),
            ('year', self.rpy, timedelta(days=365)),
        ]
        counts = {}
        for name, limit, delta in windows:
            counts[name] = session.query(Translations).join(Api). \
                filter(Api.id == self._id,
                       Translations.timestamp >= curr_time - delta
                       ).count()
        if all(self._single_period_calls_check(limit, counts[name]) for name, limit, _ in windows):
            return True
        logger.warning(f"Rate limit reached for {self.model} from {self.site}. Current calls: "
                       f"{counts['minute']} in the last minute; {counts['hour']} in the last hour; "
                       f"{counts['day']} in the last day; {counts['week']} in the last week; "
                       f"{counts['month']} in the last month; {counts['year']} in the last year.")
        return False
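    # Example: with rpmin=15, fifteen logged calls in the past minute make
    # _single_period_calls_check(15, 15) return False, so _are_rates_good()
    # logs a warning and returns False until the window clears.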
    # async def request_func(request):
    #     @wraps(request)
    #     async def wrapper(self, text, *args, **kwargs):
    #         if self._are_rates_good():
    #             try:
    #                 self.session_calls += 1
    #                 response = await request(self, text, *args, **kwargs)
    #                 return response
    #             except Exception as e:
    #                 logger.error(f"Error with model {self.model} from {self.site}. Error: {e}")
    #         else:
    #             logger.error(f"Rate limit reached for this model.")
    #             raise TooManyRequests('Rate limit reached for this model.')
    #     return wrapper
    # @request_func
    async def translate(self, texts_to_translate, store=False):
        if isinstance(texts_to_translate, str):
            texts_to_translate = [texts_to_translate]
        if len(texts_to_translate) == 0:
            return []
        # prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {self.from_lang} into {self.target_lang} and output as a Python list ensuring proper escaping of characters and ensuring the length of the list given is exactly equal to the length of the list you provide. Do not output in any other language other than the specified target language: {texts_to_translate}"
        prompt = f"""INSTRUCTIONS:
- Provide ONE and ONLY ONE translation for each text in the JSON array given.
- The translations must preserve the original order.
- Each translation must be from the Source language to the Target language.
- Source language: {self.from_lang}
- Target language: {self.target_lang}
- Texts are provided in JSON array syntax.
- Respond using ONLY valid JSON array syntax.
- Do not include explanations or additional text.
- Escape special characters properly.
Input texts:
{texts_to_translate}
Expected format:
["translation1", "translation2", ...]
Translation:"""
        response = await self._request(prompt)
        # Parse the model's JSON-style list reply
        response_list = ast.literal_eval(response.strip())
        logger.debug(repr(self))
        logger.info(f'{self.model} translated texts from: {texts_to_translate} to {response_list}.')
        if len(response_list) != len(texts_to_translate) and len(texts_to_translate) <= MAX_TRANSLATE:
            logger.error(f"{self.model} model failed to translate all the texts. "
                         f"Number of translations to make: {len(texts_to_translate)}; "
                         f"Number of translated texts: {len(response_list)}.")
            if store:
                self._db_add_translation(texts_to_translate, response_list, mismatch=True)
        else:
            if store:
                self._db_add_translation(texts_to_translate, response_list)
        logger.debug(response_list)
        return response_list


class Groq(ApiModel):
    def __init__(self,
                 model,                  # model name as defined by the API
                 api_key=GROQ_API_KEY,   # API key for the model on this site
                 **kwargs):
        super().__init__(model,
                         api_key=api_key,
                         site='Groq', **kwargs)
        self.client = Groqq()

    async def _request(self, content: str) -> str:
        # http_session avoids shadowing the SQLAlchemy session imported from data
        async with aiohttp.ClientSession() as http_session:
            async with http_session.post(
                "https://api.groq.com/openai/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "messages": [{"role": "user", "content": content}],
                    "model": self.model
                }
            ) as response:
                response_json = await response.json()
                return response_json["choices"][0]["message"]["content"]
    # See https://console.groq.com/settings/limits for rate limits
    # def request(self, content):
    #     chat_completion = self.client.chat.completions.create(
    #         messages=[
    #             {
    #                 "role": "user",
    #                 "content": content,
    #             }
    #         ],
    #         model=self.model
    #     )
    #     return chat_completion.choices[0].message.content
    # async def translate(self, texts_to_translate):
    #     return super().translate(self.request, texts_to_translate)


class Gemini(ApiModel):
    def __init__(self,
                 model,                   # model name as defined by the API
                 api_key=GEMINI_API_KEY,  # API key for the model on this site
                 **kwargs):
        super().__init__(model,
                         api_key=api_key,
                         site='Google',
                         **kwargs)
    # def request(self, content):
    #     genai.configure(api_key=self.api_key)
    #     safety_settings = {
    #         "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
    #         "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
    #         "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
    #         "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
    #     try:
    #         response = genai.GenerativeModel(self.model).generate_content(content, safety_settings=safety_settings)
    #     except ResourceExhausted as e:
    #         logger.error(f"Rate limited with {self.model}. Error: {e}")
    #         raise ResourceExhausted("Rate limited.")
    #     return response.text.strip()

    async def _request(self, content):
        async with aiohttp.ClientSession() as http_session:
            async with http_session.post(
                f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent?key={self.api_key}",
                headers={
                    "Content-Type": "application/json"
                },
                json={
                    "contents": [{"parts": [{"text": content}]}],
                    # One entry per harm category
                    "safetySettings": [
                        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}
                    ]
                }
            ) as response:
                response_json = await response.json()
                return response_json['candidates'][0]['content']['parts'][0]['text']
    # async def translate(self, texts_to_translate):
    #     return super().translate(self.request, texts_to_translate)
###################################################################################################
### LOCAL LLM TRANSLATION


class TranslationDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
        """
        Custom dataset for efficient text processing
        Args:
            texts: List of input texts
            tokenizer: HuggingFace tokenizer
            max_length: Maximum sequence length
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        # Tokenize with padding and truncation; tensors stay on the CPU here and
        # are moved to the target device batch-by-batch in generate_text
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Remove batch dimension added by tokenizer
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0)
        }


def collate_fn(batch: List[Dict]):
    """
    Custom collate function to handle batching
    """
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }
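# Since every item is padded to max_length, collate_fn stacks a batch of B
# items into input_ids and attention_mask tensors of shape (B, max_length).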


def generate_text(
        texts: List[str],
        model,
        tokenizer,
        batch_size: int = 6,    # smaller batch size uses less VRAM
        device: str = 'cuda',
        max_length: int = 512,
        max_new_tokens: int = 512,
        **generate_kwargs
):
    """
    Optimized text generation function
    Args:
        texts: List of input texts
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        batch_size: Batch size for processing
        device: Device to run inference on
        max_length: Maximum input sequence length
        max_new_tokens: Maximum number of tokens to generate
        generate_kwargs: Additional kwargs for model.generate
    Returns:
        List of generated texts
    """
    # Create dataset and dataloader
    dataset = TranslationDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )
    all_generated_texts = []
    # Default generation parameters
    generation_config = {
        'max_new_tokens': max_new_tokens,
        'num_beams': 4,
        'do_sample': True,
        'top_k': 50,
        'top_p': 0.95,
        'temperature': 0.7,
        'no_repeat_ngram_size': 2,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id
    }
    # Update with user-provided parameters
    generation_config.update(generate_kwargs)
    # Perform generation
    with torch.no_grad():
        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}
            # Generate text
            outputs = model.generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                **generation_config
            )
            # Decode generated tokens
            decoded_texts = tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            all_generated_texts.extend(decoded_texts)
    return all_generated_texts
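

# A minimal usage sketch for generate_text. The checkpoint name is illustrative;
# any HuggingFace seq2seq model works, and model-specific options (e.g. a forced
# BOS token for the target language) can be passed through generate_kwargs:
#
#   from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
#   mdl = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M").to('cuda')
#   print(generate_text(["你好", "谢谢"], mdl, tok, batch_size=2, device='cuda'))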


if __name__ == '__main__':
    GROQ_API_KEY = os.getenv('GROQ_API_KEY')
    GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
    groq = Groq('gemma-7b-it', api_key=GROQ_API_KEY, rpmin=15)  # 15 calls/minute assumed as the per-minute limit
    groq.set_lang('zh', 'en')
    gemini = Gemini('gemini-1.5-pro', api_key=GEMINI_API_KEY, rpmin=15)
    gemini.set_lang('zh', 'en')
    # translate() is a coroutine and must be driven by an event loop
    print(asyncio.run(gemini.translate(['荷兰咯'])))
    print(asyncio.run(groq.translate(['荷兰咯'])))