from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Optional
from datetime import datetime, timedelta
from dotenv import load_dotenv
import os, sys, torch, time, ast, json, pytz
from werkzeug.exceptions import TooManyRequests
from multiprocessing import Process, Event, Value

load_dotenv()
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from config import device, GEMINI_API_KEY, GROQ_API_KEY, MAX_TRANSLATE
from logging_config import logger
from groq import Groq as Groqq
import google.generativeai as genai
from google.api_core.exceptions import ResourceExhausted
import asyncio
import aiohttp
from functools import wraps
from data import session, Api, Translations
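
# The `data` module is assumed (inferred from how it is used in this file, not
# from its actual source) to expose a SQLAlchemy `session` plus two declarative
# models shaped roughly like:
#
#   Api:          id, model_name, site, rpmin, rph, rpd, rpw, rpmth, rpy
#   Translations: id, source_texts, translated_texts, model_id (FK -> Api.id),
#                 source_lang, target_lang, timestamp, translation_mismatch
#
# The real definitions live in data.py and may differ.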

class ApiModel:
    def __init__(self,
                 model,                          # model name as defined by the API
                 site,                           # site of the model; use the site names set by the
                                                 # subclasses in this script, i.e. 'Groq' and 'Google'
                 api_key: Optional[str] = None,  # api key for the model wrt the site
                 rpmin: Optional[int] = None,    # rate of calls per minute
                 rph: Optional[int] = None,      # rate of calls per hour
                 rpd: Optional[int] = None,      # rate of calls per day
                 rpw: Optional[int] = None,      # rate of calls per week
                 rpmth: Optional[int] = None,    # rate of calls per month
                 rpy: Optional[int] = None       # rate of calls per year
                 ):
        self.api_key = api_key
        self.model = model
        self.rpmin = rpmin
        self.rph = rph
        self.rpd = rpd
        self.rpw = rpw
        self.rpmth = rpmth
        self.rpy = rpy
        self.site = site
        self.from_lang = None
        self.target_lang = None
        self.db_table = None
        self.session_calls = 0
        self._id = None
        # Look the model up in the Api table; create the row if it does not already exist
        if self._get_db_model_id():
            self._set_db_model_id()
        else:
            self.update_db()

    def __repr__(self):
        return (f'{self.site} Model: {self.model}; Total calls this session: {self.session_calls}; '
                f'rpmin: {self.rpmin}; rph: {self.rph}; rpd: {self.rpd}; rpw: {self.rpw}; '
                f'rpmth: {self.rpmth}; rpy: {self.rpy}')

    def __str__(self):
        return self.model

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()

    def _get_db_model_id(self):
        model = session.query(Api).filter_by(model_name=self.model, site=self.site).first()
        return model.id if model else None

    def _set_db_model_id(self):
        self._id = self._get_db_model_id()

    @staticmethod
    def _get_time():
        return datetime.now(tz=pytz.timezone('Australia/Sydney'))

    def set_lang(self, from_lang, target_lang):
        self.from_lang = from_lang
        self.target_lang = target_lang

    def set_db_table(self, db_table):
        self.db_table = db_table

    def update_db(self):
        api = session.query(Api).filter_by(model_name=self.model, site=self.site).first()
        if not api:
            api = Api(model_name=self.model, site=self.site, rpmin=self.rpmin, rph=self.rph,
                      rpd=self.rpd, rpw=self.rpw, rpmth=self.rpmth, rpy=self.rpy)
            session.add(api)
            session.commit()
            self._set_db_model_id()
        else:
            api.rpmin = self.rpmin
            api.rph = self.rph
            api.rpd = self.rpd
            api.rpw = self.rpw
            api.rpmth = self.rpmth
            api.rpy = self.rpy
            session.commit()

    def _db_add_translation(self, text: list | str, translation: list, mismatch=False):
        text = json.dumps(text) if isinstance(text, list) else json.dumps([text])
        translation = json.dumps(translation)
        record = Translations(source_texts=text,
                              translated_texts=translation,
                              model_id=self._id,
                              source_lang=self.from_lang,
                              target_lang=self.target_lang,
                              timestamp=self._get_time(),
                              translation_mismatch=mismatch)
        session.add(record)
        session.commit()

    @staticmethod
    def _single_period_calls_check(max_calls, call_count):
        # A period with no configured limit always passes
        if not max_calls:
            return True
        return call_count < max_calls

    def _are_rates_good(self):
        """Return True when every configured rate window still has headroom."""
        curr_time = self._get_time()
        # (limit, window length, label); months and years are approximated
        # as 30 and 365 days
        windows = [
            (self.rpmin, timedelta(minutes=1), 'minute'),
            (self.rph,   timedelta(hours=1),   'hour'),
            (self.rpd,   timedelta(days=1),    'day'),
            (self.rpw,   timedelta(weeks=1),   'week'),
            (self.rpmth, timedelta(days=30),   'month'),
            (self.rpy,   timedelta(days=365),  'year'),
        ]
        counts = {
            label: session.query(Translations).join(Api)
                          .filter(Api.id == self._id,
                                  Translations.timestamp >= curr_time - delta)
                          .count()
            for _, delta, label in windows
        }
        if all(self._single_period_calls_check(limit, counts[label])
               for limit, _, label in windows):
            return True
        logger.warning(
            f"Rate limit reached for {self.model} from {self.site}. Current calls: "
            f"{counts['minute']} in the last minute; {counts['hour']} in the last hour; "
            f"{counts['day']} in the last day; {counts['week']} in the last week; "
            f"{counts['month']} in the last month; {counts['year']} in the last year.")
        return False

    async def translate(self, texts_to_translate, store=False) -> tuple[
            int,        # exit code: 0 for success, 1 for incorrect response type, 2 for incorrect translation count
            list[str],  # the translations
            int         # number of translations that do not match the number of texts to translate
            ]:
        """Main translation function. Every API model needs to define a new class and a
        _request coroutine, as shown below in the Gemini and Groq classes. Note that
        _are_rates_good() is not invoked automatically; it is left to the caller."""
        if isinstance(texts_to_translate, str):
            texts_to_translate = [texts_to_translate]
        if len(texts_to_translate) == 0:
            return (0, [], 0)
        prompt = f"""INSTRUCTIONS:
- Provide ONE and ONLY ONE translation for each text in the JSON array given.
- Respond using ONLY valid JSON array syntax. Do not use any Python-like dictionary syntax; the response must not contain any keys or curly braces.
- Do not include explanations or additional text.
- The translations must preserve the original order.
- Each translation must be from the Source language to the Target language.
- Source language: {self.from_lang}
- Target language: {self.target_lang}
- Escape special characters properly.

Input texts: {texts_to_translate}

Expected format: ["translation1", "translation2", ...]

Translation:"""
        try:
            response = await self._request(prompt)
            response_list = ast.literal_eval(response.strip())
        except Exception as e:
            logger.error(f"Failed to evaluate response from {self.model} from {self.site}. Error: {e}.")
            return (1, [], 99999)  # 99999 is a sentinel mismatch count
        # Count this round-trip against the session total reported by __repr__
        self.session_calls += 1
        logger.debug(repr(self))
        logger.info(f'{self.model} translated texts from: {texts_to_translate} to {response_list}.')
        if not isinstance(response_list, list):
            logger.error(f"Incorrect response type. Expected list, got {type(response_list)}")
            return (1, [], 99999)
        # Count mismatches are only treated as errors for requests within MAX_TRANSLATE
        if len(response_list) != len(texts_to_translate) and len(texts_to_translate) <= MAX_TRANSLATE:
            logger.error(f"Number of translations does not match number of texts to translate. "
                         f"Sent: {len(texts_to_translate)}. Received: {len(response_list)}.")
            if store:
                self._db_add_translation(texts_to_translate, response_list, mismatch=True)
            return (2, response_list, abs(len(texts_to_translate) - len(response_list)))
        if store:
            self._db_add_translation(texts_to_translate, response_list)
        return (0, response_list, 0)
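
# Note: translate() does not call _are_rates_good() itself; the check is left to
# the caller. A minimal gating sketch (assuming an already-constructed `model`
# instance and a `texts` list, both hypothetical names):
#
#   if model._are_rates_good():
#       code, translations, diff = await model.translate(texts, store=True)
#   else:
#       ...  # back off, or fail over to another ApiModel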

class Groq(ApiModel):
    def __init__(self,
                 model,                  # model name as defined by the API
                 api_key=GROQ_API_KEY,   # api key for the model wrt the site
                 **kwargs):
        super().__init__(model, api_key=api_key, site='Groq', **kwargs)
        self.client = Groqq(api_key=api_key)

    async def _request(self, content: str) -> str:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                    "https://api.groq.com/openai/v1/chat/completions",
                    headers={
                        "Authorization": f"Bearer {self.api_key}",
                        "Content-Type": "application/json"
                    },
                    json={
                        "messages": [{"role": "user", "content": content}],
                        "model": self.model
                    }
            ) as response:
                response_json = await response.json()
                return response_json["choices"][0]["message"]["content"]
    # https://console.groq.com/settings/limits for limits


class Gemini(ApiModel):
    def __init__(self,
                 model,                   # model name as defined by the API
                 api_key=GEMINI_API_KEY,  # api key for the model wrt the site
                 **kwargs):
        super().__init__(model, api_key=api_key, site='Google', **kwargs)

    async def _request(self, content):
        async with aiohttp.ClientSession() as session:
            async with session.post(
                    f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent?key={self.api_key}",
                    headers={"Content-Type": "application/json"},
                    json={
                        "contents": [{"parts": [{"text": content}]}],
                        # Each category needs its own dict; a single dict with
                        # repeated keys silently collapses to the last entry
                        "safetySettings": [
                            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}
                        ]
                    }
            ) as response:
                response_json = await response.json()
                return response_json['candidates'][0]['content']['parts'][0]['text']


"""
DEFINE YOUR OWN API MODELS WITH THE SAME TEMPLATE AS BELOW.
All required fields are left blank.

class (ApiModel):
    def __init__(self,
                 model,          # model name as defined by the API
                 api_key = ,     # api key for the model wrt the site
                 **kwargs):
        super().__init__(model, api_key = api_key, site = , **kwargs)

    async def _request(self, content):
        async with aiohttp.ClientSession() as session:
            async with session.post(
                    ,            # endpoint URL
                    headers={"Content-Type": "application/json"},
                    json={"contents": [{"parts": [{"text": content}]}]}
            ) as response:
                response_json = await response.json()
                return           # extract the generated text from response_json
"""
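
# A filled-in example of the template above. Everything here is hypothetical:
# 'ExampleSite', EXAMPLE_API_KEY, and the endpoint URL are stand-ins rather than
# a real provider, and the response parsing assumes an OpenAI-style JSON shape,
# so adjust both for whatever API you actually target.
class ExampleSite(ApiModel):
    def __init__(self,
                 model,                                  # model name as defined by the API
                 api_key=os.getenv('EXAMPLE_API_KEY'),   # hypothetical key variable
                 **kwargs):
        super().__init__(model, api_key=api_key, site='ExampleSite', **kwargs)

    async def _request(self, content):
        async with aiohttp.ClientSession() as session:
            async with session.post(
                    "https://api.example.com/v1/chat/completions",  # placeholder URL
                    headers={
                        "Authorization": f"Bearer {self.api_key}",
                        "Content-Type": "application/json"
                    },
                    json={
                        "messages": [{"role": "user", "content": content}],
                        "model": self.model
                    }
            ) as response:
                response_json = await response.json()
                return response_json["choices"][0]["message"]["content"]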

###################################################################################################
### LOCAL LLM TRANSLATION

class TranslationDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
        """
        Custom dataset for efficient text processing

        Args:
            texts: List of input texts
            tokenizer: HuggingFace tokenizer
            max_length: Maximum sequence length
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        # Tokenize with padding and truncation
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Remove batch dimension added by tokenizer; tensors are moved to the
        # target device later, in the generation loop
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0)
        }


def collate_fn(batch: List[Dict]):
    """
    Custom collate function to handle batching
    """
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }


def generate_text(
        texts: List[str],
        model,
        tokenizer,
        batch_size: int = 6,     # Smaller batch size uses less VRAM
        device: str = 'cuda',
        max_length: int = 512,
        max_new_tokens: int = 512,
        **generate_kwargs
):
    """
    Optimized text generation function

    Args:
        texts: List of input texts
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        batch_size: Batch size for processing
        device: Device to run inference on
        max_length: Maximum input sequence length
        max_new_tokens: Maximum number of tokens to generate
        generate_kwargs: Additional kwargs for model.generate

    Returns:
        List of generated texts
    """
    # Create dataset and dataloader
    dataset = TranslationDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    all_generated_texts = []

    # Default generation parameters
    generation_config = {
        'max_new_tokens': max_new_tokens,
        'num_beams': 4,
        'do_sample': True,
        'top_k': 50,
        'top_p': 0.95,
        'temperature': 0.7,
        'no_repeat_ngram_size': 2,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id
    }
    # Update with user-provided parameters
    generation_config.update(generate_kwargs)

    # Perform generation
    with torch.no_grad():
        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}
            # Generate text
            outputs = model.generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                **generation_config
            )
            # Decode generated tokens
            decoded_texts = tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            all_generated_texts.extend(decoded_texts)

    return all_generated_texts
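
# A minimal usage sketch for generate_text, assuming a HuggingFace seq2seq
# checkpoint such as Helsinki-NLP/opus-mt-zh-en (the model name is an
# illustration, not something this project pins; any AutoModelForSeq2SeqLM
# translation model should work the same way):
#
#   from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
#
#   tok = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
#   mdl = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').to(device)
#   translations = generate_text(['你好世界'], mdl, tok, batch_size=2, device=device)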

if __name__ == '__main__':
    GROQ_API_KEY = os.getenv('GROQ_API_KEY')
    GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
    # Groq here is the ApiModel subclass, not Groqq (the groq SDK client);
    # 15 is the per-minute rate limit
    groq = Groq('gemma-7b-it', api_key=GROQ_API_KEY, rpmin=15)
    groq.set_lang('zh', 'en')
    gemini = Gemini('gemini-1.5-pro', api_key=GEMINI_API_KEY, rpmin=15)
    gemini.set_lang('zh', 'en')
    # translate() is a coroutine, so it has to be driven by an event loop
    print(asyncio.run(gemini.translate(['荷兰咯'])))
    print(asyncio.run(groq.translate(['荷兰咯'])))