# onscreen-translator/helpers/batching.py
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Optional
from datetime import datetime, timedelta
from dotenv import load_dotenv
import os, sys, torch, ast, json, pytz
load_dotenv()
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from config import device, GEMINI_API_KEY, GROQ_API_KEY, MAX_TRANSLATE
from logging_config import logger
import aiohttp
from data import session, Api, Translations
class ApiModel:
    def __init__(self, model,                   # model name as defined by the API
                 site,                          # provider name, e.g. 'Groq' or 'Google'
                 api_key: Optional[str] = None, # API key for the model on that site
                 rpmin: Optional[int] = None,   # max calls per minute
                 rph: Optional[int] = None,     # max calls per hour
                 rpd: Optional[int] = None,     # max calls per day
                 rpw: Optional[int] = None,     # max calls per week
                 rpmth: Optional[int] = None,   # max calls per month
                 rpy: Optional[int] = None      # max calls per year
                 ):
self.api_key = api_key
self.model = model
self.rpmin = rpmin
self.rph = rph
self.rpd = rpd
self.rpw = rpw
self.rpmth = rpmth
self.rpy = rpy
self.site = site
self.from_lang = None
self.target_lang = None
self.db_table = None
self.session_calls = 0
self._id = None
        # Create the DB row for this model if it does not already exist
        if self._get_db_model_id():
            self._set_db_model_id()
        else:
            self.update_db()
def __repr__(self):
return f'{self.site} Model: {self.model}; Total calls this session: {self.session_calls}; rpmin: {self.rpmin}; rph: {self.rph}; rpd: {self.rpd}; rpw: {self.rpw}; rpmth: {self.rpmth}; rpy: {self.rpy}'
def __str__(self):
return self.model
async def __aenter__(self):
self.session = aiohttp.ClientSession()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.session:
await self.session.close()
def _get_db_model_id(self):
model = session.query(Api).filter_by(model_name = self.model, site = self.site).first()
if model:
return model.id
else:
return None
def _set_db_model_id(self):
self._id = self._get_db_model_id()
@staticmethod
def _get_time():
return datetime.now(tz=pytz.timezone('Australia/Sydney'))
def set_lang(self, from_lang, target_lang):
self.from_lang = from_lang
self.target_lang = target_lang
def set_db_table(self, db_table):
self.db_table = db_table
def update_db(self):
api = session.query(Api).filter_by(model_name = self.model, site = self.site).first()
if not api:
api = Api(model_name = self.model,
site = self.site,
rpmin = self.rpmin,
rph = self.rph,
rpd = self.rpd,
rpw = self.rpw,
rpmth = self.rpmth,
rpy = self.rpy)
session.add(api)
session.commit()
self._set_db_model_id()
else:
api.rpmin = self.rpmin
api.rph = self.rph
api.rpd = self.rpd
api.rpw = self.rpw
api.rpmth = self.rpmth
api.rpy = self.rpy
session.commit()
    def _db_add_translation(self, text: list | str, translation: list, mismatch = False):
        source = json.dumps(text) if isinstance(text, list) else json.dumps([text])
        record = Translations(source_texts = source, translated_texts = json.dumps(translation),
                              model_id = self._id, source_lang = self.from_lang, target_lang = self.target_lang,
                              timestamp = self._get_time(),
                              translation_mismatch = mismatch)
        session.add(record)
        session.commit()
    @staticmethod
    def _single_period_calls_check(max_calls, call_count):
        # A limit of None means the period is unconstrained.
        return max_calls is None or call_count < max_calls
    def _are_rates_good(self):
        curr_time = self._get_time()
        # (window name, lookback, limit) for each rate period; a limit of None means unlimited.
        windows = [('minute', timedelta(minutes=1), self.rpmin),
                   ('hour',   timedelta(hours=1),   self.rph),
                   ('day',    timedelta(days=1),    self.rpd),
                   ('week',   timedelta(weeks=1),   self.rpw),
                   ('month',  timedelta(days=30),   self.rpmth),
                   ('year',   timedelta(days=365),  self.rpy)]
        counts = {name: session.query(Translations).join(Api)
                               .filter(Api.id == self._id,
                                       Translations.timestamp >= curr_time - delta)
                               .count()
                  for name, delta, _ in windows}
        if all(self._single_period_calls_check(limit, counts[name])
               for name, _, limit in windows):
            return True
        logger.warning(f"Rate limit reached for {self.model} from {self.site}. Current calls: "
                       + '; '.join(f"{counts[name]} in the last {name}" for name, _, _ in windows)
                       + '.')
        return False
    async def translate(self, texts_to_translate, store = False) -> tuple[int,       # exit code: 0 success, 1 bad response type, 2 translation count mismatch
                                                                          list[str], # the translations
                                                                          int        # difference between number of texts sent and translations received
                                                                          ]:
        """Main translation entry point. Every API provider needs its own subclass
        defining a `_request` coroutine, as the Groq and Gemini classes below do."""
        if isinstance(texts_to_translate, str):
            texts_to_translate = [texts_to_translate]
        if len(texts_to_translate) == 0:
            return (0, [], 0)
        prompt = f"""INSTRUCTIONS:
- Provide ONE and ONLY ONE translation for each text in the JSON array given.
- Respond using ONLY valid JSON array syntax. Do not use Python-like dictionary syntax; the output must not contain any keys or curly braces.
- Do not include explanations or additional text.
- The translations must preserve the original order.
- Each translation must be from the source language to the target language.
- Source language: {self.from_lang}
- Target language: {self.target_lang}
- Escape special characters properly.
Input texts:
{json.dumps(texts_to_translate, ensure_ascii=False)}
Expected format:
["translation1", "translation2", ...]
Translation:"""
        try:
            response = await self._request(prompt)
            # literal_eval is more forgiving than json.loads if the model answers
            # with single-quoted (Python-style) strings instead of strict JSON.
            response_list = ast.literal_eval(response.strip())
        except Exception as e:
            logger.error(f"Failed to evaluate response from {self.model} from {self.site}. Error: {e}.")
            return (1, [], 99999)  # 99999 is a sentinel for "no usable translation count"
        logger.debug(repr(self))
        logger.info(f'{self.model} translated texts from: {texts_to_translate} to {response_list}.')
        if not isinstance(response_list, list):
            logger.error(f"Incorrect response type. Expected list, got {type(response_list)}")
            return (1, [], 99999)
        # Batches larger than MAX_TRANSLATE are expected to come back short,
        # so a mismatch is only flagged for batches within the limit.
        if len(response_list) != len(texts_to_translate) and len(texts_to_translate) <= MAX_TRANSLATE:
            logger.error(f"Number of translations does not match number of texts to translate. Sent: {len(texts_to_translate)}. Received: {len(response_list)}.")
            if store:
                self._db_add_translation(texts_to_translate, response_list, mismatch=True)
            return (2, response_list, abs(len(texts_to_translate) - len(response_list)))
        if store:
            self._db_add_translation(texts_to_translate, response_list)
        return (0, response_list, 0)
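# A minimal usage sketch (hypothetical rate limits; assumes the database and the
# API keys from config are set up, and that an event loop is available):
#
#     import asyncio
#
#     async def demo():
#         gemini = Gemini('gemini-1.5-flash', rpmin=15, rpd=1500)
#         gemini.set_lang('English', 'Chinese')
#         code, translations, diff = await gemini.translate(['Load data', 'Back'], store=True)
#         if code == 0:
#             print(translations)
#
#     asyncio.run(demo())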
class Groq(ApiModel):
    def __init__(self,
                 model,                  # model name as defined by the API
                 api_key = GROQ_API_KEY, # API key for the model on that site
                 **kwargs):
        super().__init__(model,
                         api_key = api_key,
                         site = 'Groq', **kwargs)
    async def _request(self, content: str) -> str:
        # See https://console.groq.com/settings/limits for the account's rate limits.
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "https://api.groq.com/openai/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "messages": [{"role": "user", "content": content}],
                    "model": self.model
                }
            ) as response:
                response_json = await response.json()
                return response_json["choices"][0]["message"]["content"]
class Gemini(ApiModel):
    def __init__(self,
                 model,                    # model name as defined by the API
                 api_key = GEMINI_API_KEY, # API key for the model on that site
                 **kwargs):
        super().__init__(model,
                         api_key = api_key,
                         site = 'Google',
                         **kwargs)
    async def _request(self, content):
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent?key={self.api_key}",
                headers={
                    "Content-Type": "application/json"
                },
                json={
                    "contents": [{"parts": [{"text": content}]}],
                    # Each safety category needs its own dict; a single dict with
                    # repeated keys would silently collapse to the last pair.
                    "safetySettings": [
                        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}
                    ]
                }
            ) as response:
                response_json = await response.json()
                return response_json['candidates'][0]['content']['parts'][0]['text']
"""
DEFINE YOUR OWN API MODELS BELOW WITH THE SAME TEMPLATE AS BELOW. All fields required are indicated by <required field>.
class <NameOfWebsite>(ApiModel):
def __init__(self, # model name as defined by the API
model,
api_key = <API_KEY>, # api key for the model wrt the site
**kwargs):
super().__init__(model,
api_key = api_key,
site = <name_of_website>,
**kwargs)
async def _request(self, content):
async with aiohttp.ClientSession() as session:
async with session.post(
<API ENDPOINT e.g. https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={self.api_key}>,
headers={
"Content-Type": "application/json"
<ANY OTHER HEADERS REQUIRED BY THE API separated by commas>
},
json={
"contents": [{"parts": [{"text": content}]}]
<ANY OTHER JSON PAIRS REQUIRED separated by commas>
}
) as response:
response_json = await response.json()
return <Anything needed to extract the message response from `response_json`>
"""
###################################################################################################
### LOCAL LLM TRANSLATION
class TranslationDataset(Dataset):
def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
"""
Custom dataset for efficient text processing
Args:
texts: List of input texts
tokenizer: HuggingFace tokenizer
max_length: Maximum sequence length
"""
self.texts = texts
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.texts)
    def __getitem__(self, idx):
        text = self.texts[idx]
        # Tokenize with padding and truncation. Tensors stay on the CPU here and
        # are moved to the target device once per batch in generate_text; moving
        # them here would be redundant and would break DataLoader workers.
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Remove the batch dimension added by the tokenizer
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0)
        }
def collate_fn(batch: List[Dict]):
"""
Custom collate function to handle batching
"""
input_ids = torch.stack([item['input_ids'] for item in batch])
attention_mask = torch.stack([item['attention_mask'] for item in batch])
return {
'input_ids': input_ids,
'attention_mask': attention_mask
}
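# A quick sanity-check sketch (hypothetical; assumes a `tokenizer` is already loaded):
#
#     ds = TranslationDataset(['hello', 'world'], tokenizer, max_length=16)
#     dl = DataLoader(ds, batch_size=2, collate_fn=collate_fn)
#     batch = next(iter(dl))
#     # batch['input_ids'].shape == (2, 16): squeeze(0) in __getitem__ drops the
#     # per-item batch dimension so torch.stack in collate_fn can rebuild it.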
def generate_text(
    texts: List[str],
    model,
    tokenizer,
    batch_size: int = 6, # smaller batch size uses less VRAM
    device = device,     # defaults to the device imported from config, matching where the model is loaded
    max_length: int = 512,
    max_new_tokens: int = 512,
    **generate_kwargs
):
    """
    Optimized batch text generation.
    Args:
        texts: List of input texts
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        batch_size: Batch size for processing
        device: Device to run inference on
        max_length: Maximum input sequence length
        max_new_tokens: Maximum number of tokens to generate
        generate_kwargs: Additional kwargs for model.generate
    Returns:
        List of generated texts
    """
# Create dataset and dataloader
dataset = TranslationDataset(texts, tokenizer, max_length)
dataloader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
collate_fn=collate_fn
)
all_generated_texts = []
# Default generation parameters
generation_config = {
'max_new_tokens': max_new_tokens,
'num_beams': 4,
'do_sample': True,
'top_k': 50,
'top_p': 0.95,
'temperature': 0.7,
'no_repeat_ngram_size': 2,
'pad_token_id': tokenizer.pad_token_id,
'eos_token_id': tokenizer.eos_token_id
}
# Update with user-provided parameters
generation_config.update(generate_kwargs)
# Perform generation
with torch.no_grad():
for batch in dataloader:
# Move batch to device
batch = {k: v.to(device) for k, v in batch.items()}
# Generate text
outputs = model.generate(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
**generation_config
)
# Decode generated tokens
decoded_texts = tokenizer.batch_decode(
outputs,
skip_special_tokens=True,
clean_up_tokenization_spaces=True
)
all_generated_texts.extend(decoded_texts)
return all_generated_texts
if __name__ == '__main__':
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, M2M100Tokenizer, M2M100ForConditionalGeneration
opus_model = 'Helsinki-NLP/opus-mt-en-zh'
LOCAL_FILES_ONLY = True
tokenizer = AutoTokenizer.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY)
model = AutoModelForSeq2SeqLM.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY, torch_dtype=torch.float16).to(device)
# tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=LOCAL_FILES_ONLY)
# tokenizer.src_lang = "en"
# model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=LOCAL_FILES_ONLY).to(device)
    # Deliberately noisy strings that simulate OCR output from on-screen captures.
    sample_texts = ['placeholder', 'Story', 'StoRY', 'TufoRIaL', 'CovFfG', 'LoaD DaTA',
                    'SAME DATa', 'ReTulN@TitIE', 'View', '@niirm', 'SysceM', 'MeNu:',
                    'MaND', 'CoM', 'SeLEcT', 'Frogguingang', 'Tutorias', 'Back']
    print(generate_text([i.lower().capitalize() for i in sample_texts], model, tokenizer))