from typing import List, Dict
from functools import wraps
from multiprocessing import Process, Event, Value
import os
import sys
import time
import ast
import asyncio

import torch
from torch.utils.data import Dataset, DataLoader
from dotenv import load_dotenv
from werkzeug.exceptions import TooManyRequests
from groq import Groq
import google.generativeai as genai

load_dotenv()
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from config import device, GEMINI_API_KEY, GROQ_API_KEY
from logging_config import logger


class ApiModel:
    def __init__(self,
                 model,    # model name
                 rate,     # allowed calls per minute
                 api_key,  # API key for the model's provider
                 ):
        self.model = model
        self.rate = rate
        self.api_key = api_key
        self.curr_calls = Value('i', 0)   # calls made in the current one-minute window
        self.time = Value('i', 0)         # seconds elapsed in the current window
        self.process = None
        self.stop_event = Event()
        self.site = None
        self.from_lang = None
        self.target_lang = None
        self.request = None  # latest raw response from the API

    def __repr__(self):
        return (f'{self.site} Model: {self.model}; Rate: {self.rate}; '
                f'Current_Calls: {self.curr_calls.value} calls; '
                f'Time Passed: {self.time.value} seconds.')

    def __str__(self):
        return self.model

    def set_lang(self, from_lang, target_lang):
        self.from_lang = from_lang
        self.target_lang = target_lang

    # Tracks per-minute API rates only. Hourly or monthly quotas would need a
    # separate mechanism, but those limits are unlikely to be hit.
    async def api_rate_check(self):
        # Background task to manage the rate of calls to the API
        while not self.stop_event.is_set():
            start_time = time.monotonic()
            self.time.value += 5
            if self.time.value >= 60:
                # A new minute has started: reset the window and the call count
                self.time.value = 0
                self.curr_calls.value = 0
            elapsed = time.monotonic() - start_time
            # Sleep for exactly 5 seconds minus the time spent above
            sleep_time = max(0, 5 - elapsed)
            await asyncio.sleep(sleep_time)

    def background_task(self):
        asyncio.run(self.api_rate_check())

    def start(self):
        # Start the background rate-tracking process
        self.process = Process(target=self.background_task)
        self.process.daemon = True
        self.process.start()
        logger.info(f"Background process started with PID: {self.process.pid}")

    def stop(self):
        # Stop the background rate-tracking process
        self.stop_event.set()
        if self.process:
            logger.info(f"Stopping background process with PID: {self.process.pid}")
            self.process.join(timeout=5)
            if self.process.is_alive():
                self.process.terminate()

    def request_func(request):
        # Decorator (applied within the class body) that enforces the per-minute rate limit
        @wraps(request)
        def wrapper(self, text, *args, **kwargs):
            if self.curr_calls.value < self.rate:
                try:
                    response = request(self, text, *args, **kwargs)
                    self.curr_calls.value += 1
                    return response
                except Exception as e:
                    logger.error(f"Error with model {self.model} from {self.site}. Error: {e}")
                    raise
            else:
                logger.error(f"Rate limit reached for this model. "
                             f"Please wait for the rate to reset in {60 - self.time.value} seconds.")
                raise TooManyRequests('Rate limit reached for this model.')
        return wrapper

    @request_func
    def translate(self, request_fn, texts_to_translate):
        if not texts_to_translate:
            return []
        prompt = (f"Without any additional remarks, and without any code, translate the following items "
                  f"of the Python list from {self.from_lang} into {self.target_lang} and output as a "
                  f"Python list ensuring proper escaping of characters: {texts_to_translate}")
        response = request_fn(self, prompt)
        return ast.literal_eval(response.strip())


class Groqq(ApiModel):
    # Named Groqq to avoid clashing with the imported Groq client class
    def __init__(self, model, rate, api_key=GROQ_API_KEY):
        super().__init__(model, rate, api_key)
        self.site = "Groq"

    def request(self, content):
        client = Groq(api_key=self.api_key)
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": content,
                }
            ],
            model=self.model
        )
        return chat_completion.choices[0].message.content

    def translate(self, texts_to_translate):
        return super().translate(Groqq.request, texts_to_translate)


class Gemini(ApiModel):
    def __init__(self, model, rate, api_key=GEMINI_API_KEY):
        super().__init__(model, rate, api_key)
        self.site = "Gemini"

    def request(self, content):
        genai.configure(api_key=self.api_key)
        safety_settings = {
            "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
            "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
            "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE",
        }
        response = genai.GenerativeModel(self.model).generate_content(
            content, safety_settings=safety_settings)
        return response.text.strip()

    def translate(self, texts_to_translate):
        return super().translate(Gemini.request, texts_to_translate)


class TranslationDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
        """
        Custom dataset for efficient text processing

        Args:
            texts: List of input texts
            tokenizer: HuggingFace tokenizer
            max_length: Maximum sequence length
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        # Tokenize with padding and truncation
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(device)
        # Remove the batch dimension added by the tokenizer
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0)
        }


def collate_fn(batch: List[Dict]):
    """
    Custom collate function to handle batching
    """
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }


def generate_text(
    texts: List[str],
    model,
    tokenizer,
    batch_size: int = 6,  # Smaller batch size uses less VRAM
    device: str = 'cuda',
    max_length: int = 512,
    max_new_tokens: int = 512,
    **generate_kwargs
):
    """
    Optimized text generation function

    Args:
        texts: List of input texts
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        batch_size: Batch size for processing
        device: Device to run inference on
        max_length: Maximum input sequence length
        max_new_tokens: Maximum number of tokens to generate
        generate_kwargs: Additional kwargs for model.generate

    Returns:
        List of generated texts
    """
    # Create dataset and dataloader
    dataset = TranslationDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    all_generated_texts = []

    # Default generation parameters
    generation_config = {
        'max_new_tokens': max_new_tokens,
        'num_beams': 4,
        'do_sample': True,
        'top_k': 50,
        'top_p': 0.95,
        'temperature': 0.7,
        'no_repeat_ngram_size': 2,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id
    }
    # Update with user-provided parameters
    generation_config.update(generate_kwargs)

    # Perform generation
    with torch.no_grad():
        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}

            # Generate text
            outputs = model.generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                **generation_config
            )

            # Decode generated tokens
            decoded_texts = tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            all_generated_texts.extend(decoded_texts)

    return all_generated_texts


if __name__ == '__main__':
    GROQ_API_KEY = os.getenv('GROQ_API_KEY')
    GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')

    groq = Groqq('gemma-7b-it', 15, GROQ_API_KEY)
    groq.set_lang('zh', 'en')

    gemini = Gemini('gemini-1.5-pro', 15, GEMINI_API_KEY)
    gemini.set_lang('zh', 'en')

    print(gemini.translate(['荷兰咯']))
    print(groq.translate(['荷兰咯']))
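
# --- Illustrative usage sketch (not part of the original module) ---
# The commented lines below show how generate_text could be driven by a local
# Hugging Face seq2seq model as an offline alternative to the API-backed classes
# above. The checkpoint name "Helsinki-NLP/opus-mt-zh-en" and the override values
# are assumptions for illustration only; swap in whichever model you actually use.
#
# from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#
# checkpoint = "Helsinki-NLP/opus-mt-zh-en"   # assumed zh -> en checkpoint
# tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)
#
# translations = generate_text(
#     ['荷兰咯'],
#     model,
#     tokenizer,
#     batch_size=2,
#     device=device,
#     max_new_tokens=128,  # passed through **generate_kwargs, overriding the default of 512
# )
# print(translations)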