# onscreen-translator/helpers/batching.py
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
from dotenv import load_dotenv
import os, sys, torch, time, ast
from werkzeug.exceptions import TooManyRequests
from multiprocessing import Process, Event, Value
load_dotenv()
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from config import device, GEMINI_API_KEY, GROQ_API_KEY
from logging_config import logger
from groq import Groq
import google.generativeai as genai
import asyncio
from functools import wraps

class ApiModel:
    def __init__(self,
                 model,    # model name
                 rate,     # rate of calls per minute
                 api_key,  # API key for the model on the provider's site
                 ):
        self.model = model
        self.rate = rate
        self.api_key = api_key
        self.curr_calls = Value('i', 0)  # calls made in the current minute
        self.time = Value('i', 0)        # seconds elapsed in the current minute
        self.process = None
        self.stop_event = Event()
        self.site = None
        self.from_lang = None
        self.target_lang = None
        self.request = None              # request response from API

    def __repr__(self):
        return f'{self.site} Model: {self.model}; Rate: {self.rate}; Current_Calls: {self.curr_calls.value} calls; Time Passed: {self.time.value} seconds.'

    def __str__(self):
        return self.model

    def set_lang(self, from_lang, target_lang):
        self.from_lang = from_lang
        self.target_lang = target_lang
    ### Checks per-minute API rates. Handling hourly or monthly rates would need a
    ### separate file, and those limits are unlikely to be hit anyway.
    async def api_rate_check(self):
        # Background task that manages the rate of calls to the API
        while not self.stop_event.is_set():
            start_time = time.monotonic()
            self.time.value += 5
            if self.time.value >= 60:
                # A minute has passed: reset the counter
                self.time.value = 0
                self.curr_calls.value = 0
            elapsed = time.monotonic() - start_time
            # Sleep for exactly 5 seconds minus the time spent in this iteration
            sleep_time = max(0, 5 - elapsed)
            await asyncio.sleep(sleep_time)

    def background_task(self):
        asyncio.run(self.api_rate_check())
    def start(self):
        # Start the background rate-tracking task in its own process
        self.process = Process(target=self.background_task)
        self.process.daemon = True
        self.process.start()
        logger.info(f"Background process started with PID: {self.process.pid}")

    def stop(self):
        # Stop the background rate-tracking task
        logger.info(f"Stopping background process with PID: {self.process.pid}")
        self.stop_event.set()
        if self.process:
            self.process.join(timeout=5)
            if self.process.is_alive():
                self.process.terminate()
    def request_func(request):
        # Decorator that enforces the per-minute rate limit around an API call
        @wraps(request)
        def wrapper(self, text, *args, **kwargs):
            if self.curr_calls.value < self.rate:
                try:
                    response = request(self, text, *args, **kwargs)
                    self.curr_calls.value += 1
                    return response
                except Exception as e:
                    logger.error(f"Error with model {self.model} from {self.site}. Error: {e}")
                    raise
            else:
                logger.error(f"Rate limit reached for this model. Please wait for the rate to reset in {60 - self.time.value} seconds.")
                raise TooManyRequests('Rate limit reached for this model.')
        return wrapper
    @request_func
    def translate(self, request_fn, texts_to_translate):
        if len(texts_to_translate) == 0:
            return []
        prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {self.from_lang} into {self.target_lang} and output as a Python list ensuring proper escaping of characters: {texts_to_translate}"
        response = request_fn(self, prompt)
        # The model is asked to reply with a Python list literal, parsed here
        return ast.literal_eval(response.strip())
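
# Illustrative sketch only (not used by the app): a hypothetical EchoModel subclass
# showing how @request_func gates calls. The class, its fake response, and the
# rate of 2 are assumptions for demonstration; the real providers are below.
#
#     class EchoModel(ApiModel):
#         def request(self, content):
#             return "['echo']"   # pretend the API returned a Python list literal
#         def translate(self, texts_to_translate):
#             return super().translate(EchoModel.request, texts_to_translate)
#
#     m = EchoModel('dummy-model', rate=2, api_key='unused')
#     m.set_lang('zh', 'en')
#     m.translate(['你好'])   # ok, curr_calls -> 1
#     m.translate(['你好'])   # ok, curr_calls -> 2
#     m.translate(['你好'])   # raises TooManyRequests (the counter resets only while start() is running)
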
class Groqq(ApiModel):
    def __init__(self, model, rate, api_key=GROQ_API_KEY):
        super().__init__(model, rate, api_key)
        self.site = "Groq"

    def request(self, content):
        client = Groq(api_key=self.api_key)
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": content,
                }
            ],
            model=self.model
        )
        return chat_completion.choices[0].message.content

    def translate(self, texts_to_translate):
        return super().translate(Groqq.request, texts_to_translate)

class Gemini(ApiModel):
    def __init__(self, model, rate, api_key=GEMINI_API_KEY):
        super().__init__(model, rate, api_key)
        self.site = "Gemini"

    def request(self, content):
        genai.configure(api_key=self.api_key)
        safety_settings = {
            "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
            "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
            "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE",
        }
        response = genai.GenerativeModel(self.model).generate_content(content, safety_settings=safety_settings)
        return response.text.strip()

    def translate(self, texts_to_translate):
        return super().translate(Gemini.request, texts_to_translate)
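
# Illustrative only: both providers are prompted for a Python list literal, which
# ApiModel.translate parses with ast.literal_eval. The response text below is a
# made-up example of what a well-behaved reply to translate(['荷兰咯']) with
# from_lang='zh', target_lang='en' might look like.
#
#     raw = "['The Netherlands']"
#     ast.literal_eval(raw)   # -> ['The Netherlands']
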
class TranslationDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
        """
        Custom dataset for efficient text processing
        Args:
            texts: List of input texts
            tokenizer: HuggingFace tokenizer
            max_length: Maximum sequence length
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        # Tokenize with padding and truncation
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(device)
        # Remove the batch dimension added by the tokenizer
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0)
        }

def collate_fn(batch: List[Dict]):
    """
    Custom collate function to handle batching
    """
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }
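
# Shape sketch (illustrative): with max_length=512 and a batch of 6 items,
# collate_fn stacks the per-item tensors into
#     input_ids:      torch.Size([6, 512])
#     attention_mask: torch.Size([6, 512])
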
def generate_text(
    texts: List[str],
    model,
    tokenizer,
    batch_size: int = 6,  # Smaller batch size uses less VRAM
    device: str = 'cuda',
    max_length: int = 512,
    max_new_tokens: int = 512,
    **generate_kwargs
):
    """
    Optimized text generation function
    Args:
        texts: List of input texts
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        batch_size: Batch size for processing
        device: Device to run inference on
        max_length: Maximum input sequence length
        max_new_tokens: Maximum number of tokens to generate
        generate_kwargs: Additional kwargs for model.generate
    Returns:
        List of generated texts
    """
    # Create dataset and dataloader
    dataset = TranslationDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )
    all_generated_texts = []
    # Default generation parameters
    generation_config = {
        'max_new_tokens': max_new_tokens,
        'num_beams': 4,
        'do_sample': True,
        'top_k': 50,
        'top_p': 0.95,
        'temperature': 0.7,
        'no_repeat_ngram_size': 2,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id
    }
    # Update with user-provided parameters
    generation_config.update(generate_kwargs)
    # Perform generation
    with torch.no_grad():
        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}
            # Generate text
            outputs = model.generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                **generation_config
            )
            # Decode generated tokens
            decoded_texts = tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            all_generated_texts.extend(decoded_texts)
    return all_generated_texts
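
# Hedged usage sketch: one way generate_text might be called with a HuggingFace
# seq2seq checkpoint. The checkpoint name is an assumption for illustration, not
# one this module pins anywhere.
#
#     from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
#
#     tok = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
#     mdl = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').to(device)
#     translations = generate_text(['荷兰咯'], mdl, tok, batch_size=2, device=device)
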
if __name__ == '__main__':
    GROQ_API_KEY = os.getenv('GROQ_API_KEY')
    GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
    groq = Groqq('gemma-7b-it', 15, GROQ_API_KEY)
    groq.set_lang('zh', 'en')
    gemini = Gemini('gemini-1.5-pro', 15, GEMINI_API_KEY)
    gemini.set_lang('zh', 'en')
    print(gemini.translate(['荷兰咯']))
    print(groq.translate(['荷兰咯']))
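    # Sketch (assumption, not part of the original run): the per-minute counter
    # only resets while the background process is running, so a longer-lived
    # session would wrap calls in start()/stop():
    #
    #     gemini.start()
    #     try:
    #         print(gemini.translate(['荷兰咯']))
    #     finally:
    #         gemini.stop()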