from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Optional
from datetime import datetime, timedelta
from dotenv import load_dotenv
import os, sys, torch, ast, json, pytz
import asyncio
import aiohttp

load_dotenv()
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from config import device, GEMINI_API_KEY, GROQ_API_KEY, MAX_TRANSLATE
from logging_config import logger
from data import session, Api, Translations


class ApiModel:
    def __init__(self,
                 model,                          # model name as defined by the API
                 site,                           # site identifier stored in the database; the subclasses below use 'Groq' and 'Google'
                 api_key: Optional[str] = None,  # API key for the model on that site
                 rpmin: Optional[int] = None,    # maximum calls per minute
                 rph: Optional[int] = None,      # maximum calls per hour
                 rpd: Optional[int] = None,      # maximum calls per day
                 rpw: Optional[int] = None,      # maximum calls per week
                 rpmth: Optional[int] = None,    # maximum calls per month
                 rpy: Optional[int] = None       # maximum calls per year
                 ):
        self.api_key = api_key
        self.model = model
        self.rpmin = rpmin
        self.rph = rph
        self.rpd = rpd
        self.rpw = rpw
        self.rpmth = rpmth
        self.rpy = rpy
        self.site = site
        self.from_lang = None
        self.target_lang = None
        self.db_table = None
        self.session_calls = 0
        self._id = None
        self.session = None  # aiohttp session opened by __aenter__

        # Load this model's database row id, creating the row if it does not already exist.
        if self._get_db_model_id():
            self._set_db_model_id()
        else:
            self.update_db()

    def __repr__(self):
        return (f'{self.site} Model: {self.model}; Total calls this session: {self.session_calls}; '
                f'rpmin: {self.rpmin}; rph: {self.rph}; rpd: {self.rpd}; rpw: {self.rpw}; '
                f'rpmth: {self.rpmth}; rpy: {self.rpy}')

    def __str__(self):
        return self.model

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()
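
    # Note: __aenter__/__aexit__ manage self.session, but the default _request
    # implementations below each open their own short-lived aiohttp.ClientSession;
    # self.session is not used by them.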

    def _get_db_model_id(self):
        """Return the database id for this model/site pair, or None if it is not stored yet."""
        model = session.query(Api).filter_by(model_name=self.model, site=self.site).first()
        return model.id if model else None

    def _set_db_model_id(self):
        self._id = self._get_db_model_id()

    @staticmethod
    def _get_time():
        return datetime.now(tz=pytz.timezone('Australia/Sydney'))

    def set_lang(self, from_lang, target_lang):
        self.from_lang = from_lang
        self.target_lang = target_lang

    def set_db_table(self, db_table):
        self.db_table = db_table

    def update_db(self):
        """Create this model's row in the Api table, or refresh its rate limits if the row exists."""
        api = session.query(Api).filter_by(model_name=self.model, site=self.site).first()
        if not api:
            api = Api(model_name=self.model,
                      site=self.site,
                      rpmin=self.rpmin,
                      rph=self.rph,
                      rpd=self.rpd,
                      rpw=self.rpw,
                      rpmth=self.rpmth,
                      rpy=self.rpy)
            session.add(api)
            session.commit()
            self._set_db_model_id()
        else:
            api.rpmin = self.rpmin
            api.rph = self.rph
            api.rpd = self.rpd
            api.rpw = self.rpw
            api.rpmth = self.rpmth
            api.rpy = self.rpy
            session.commit()

    def _db_add_translation(self, text: list | str, translation: list, mismatch=False):
        """Persist one translation call (source texts, translated texts, metadata) to the database."""
        source_json = json.dumps(text) if isinstance(text, list) else json.dumps([text])
        row = Translations(source_texts=source_json,
                           translated_texts=json.dumps(translation),
                           model_id=self._id,
                           source_lang=self.from_lang,
                           target_lang=self.target_lang,
                           timestamp=self._get_time(),
                           translation_mismatch=mismatch)
        session.add(row)
        session.commit()

    @staticmethod
    def _single_period_calls_check(max_calls, call_count):
        """Return True when the call count is under the limit (or when no limit is set)."""
        if not max_calls:
            return True
        return call_count < max_calls

    def _are_rates_good(self):
        """Check the recorded call counts against every configured rate limit."""
        curr_time = self._get_time()
        windows = [('minute', timedelta(minutes=1), self.rpmin),
                   ('hour', timedelta(hours=1), self.rph),
                   ('day', timedelta(days=1), self.rpd),
                   ('week', timedelta(weeks=1), self.rpw),
                   ('month', timedelta(days=30), self.rpmth),
                   ('year', timedelta(days=365), self.rpy)]
        counts = {}
        for name, delta, _ in windows:
            counts[name] = session.query(Translations).join(Api). \
                filter(Api.id == self._id,
                       Translations.timestamp >= curr_time - delta
                       ).count()
        if all(self._single_period_calls_check(limit, counts[name]) for name, _, limit in windows):
            return True
        usage = '; '.join(f'{counts[name]} in the last {name}' for name, _, _ in windows)
        logger.warning(f"Rate limit reached for {self.model} from {self.site}. Current calls: {usage}.")
        return False

    async def translate(self, texts_to_translate, store=False) -> tuple[int,        # exit code: 0 success, 1 bad response type, 2 translation-count mismatch
                                                                        list[str],  # the translations
                                                                        int         # how far the translation count is from the request count
                                                                        ]:
        """Main translation entry point. Each API provider needs its own subclass defining
        a _request coroutine, as the Groq and Gemini classes below demonstrate."""
        if isinstance(texts_to_translate, str):
            texts_to_translate = [texts_to_translate]
        if len(texts_to_translate) == 0:
            return (0, [], 0)
        prompt = f"""INSTRUCTIONS:
- Provide ONE and ONLY ONE translation for each text in the JSON array given.
- Respond using ONLY valid JSON array syntax. Do not use Python-like dictionary syntax: no keys, no curly braces.
- Do not include explanations or additional text.
- The translations must preserve the original order.
- Each translation must be from the Source language to the Target language.
- Source language: {self.from_lang}
- Target language: {self.target_lang}
- Escape special characters properly.

Input texts:
{texts_to_translate}

Expected format:
["translation1", "translation2", ...]

Translation:"""

        try:
            response = await self._request(prompt)
            # A JSON array of strings is also a valid Python literal, so
            # ast.literal_eval can parse the expected response.
            response_list = ast.literal_eval(response.strip())
        except Exception as e:
            logger.error(f"Failed to evaluate response from {self.model} from {self.site}. Error: {e}.")
            return (1, [], 99999)
        logger.debug(repr(self))
        logger.info(f'{self.model} translated texts from: {texts_to_translate} to {response_list}.')
        if not isinstance(response_list, list):
            logger.error(f"Incorrect response type. Expected list, got {type(response_list)}")
            return (1, [], 99999)
        if len(response_list) != len(texts_to_translate) and len(texts_to_translate) <= MAX_TRANSLATE:
            logger.error(f"Number of translations does not match number of texts to translate. "
                         f"Sent: {len(texts_to_translate)}. Received: {len(response_list)}.")
            if store:
                self._db_add_translation(texts_to_translate, response_list, mismatch=True)
            return (2, response_list, abs(len(texts_to_translate) - len(response_list)))
        else:
            if store:
                self._db_add_translation(texts_to_translate, response_list)
            return (0, response_list, 0)
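

# A minimal usage sketch (an assumption about the intended call flow, not code
# this module runs itself): subclasses are driven from an event loop, and the
# caller is expected to consult _are_rates_good() before each request, e.g.
#
#     model = Gemini('gemini-1.5-pro', rpmin=15)
#     model.set_lang('zh', 'en')
#     if model._are_rates_good():
#         code, translations, diff = await model.translate(['你好'], store=True)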


class Groq(ApiModel):
    def __init__(self,
                 model,                 # model name as defined by the API
                 api_key=GROQ_API_KEY,  # API key for the model on this site
                 **kwargs):
        super().__init__(model,
                         api_key=api_key,
                         site='Groq',
                         **kwargs)

    async def _request(self, content: str) -> str:
        async with aiohttp.ClientSession() as http_session:
            async with http_session.post(
                "https://api.groq.com/openai/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "messages": [{"role": "user", "content": content}],
                    "model": self.model
                }
            ) as response:
                response_json = await response.json()
                return response_json["choices"][0]["message"]["content"]
    # See https://console.groq.com/settings/limits for rate limits.


class Gemini(ApiModel):
    def __init__(self,
                 model,                   # model name as defined by the API
                 api_key=GEMINI_API_KEY,  # API key for the model on this site
                 **kwargs):
        super().__init__(model,
                         api_key=api_key,
                         site='Google',
                         **kwargs)

    async def _request(self, content):
        async with aiohttp.ClientSession() as http_session:
            async with http_session.post(
                f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent?key={self.api_key}",
                headers={
                    "Content-Type": "application/json"
                },
                json={
                    "contents": [{"parts": [{"text": content}]}],
                    "safetySettings": [
                        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}
                    ]
                }
            ) as response:
                response_json = await response.json()
                return response_json['candidates'][0]['content']['parts'][0]['text']


"""
|
|
DEFINE YOUR OWN API MODELS BELOW WITH THE SAME TEMPLATE AS BELOW. All fields required are indicated by <required field>.
|
|
|
|
class <NameOfWebsite>(ApiModel):
|
|
def __init__(self, # model name as defined by the API
|
|
model,
|
|
api_key = <API_KEY>, # api key for the model wrt the site
|
|
**kwargs):
|
|
super().__init__(model,
|
|
api_key = api_key,
|
|
site = <name_of_website>,
|
|
**kwargs)
|
|
|
|
async def _request(self, content):
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.post(
|
|
<API ENDPOINT e.g. https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={self.api_key}>,
|
|
headers={
|
|
"Content-Type": "application/json"
|
|
<ANY OTHER HEADERS REQUIRED BY THE API separated by commas>
|
|
},
|
|
json={
|
|
"contents": [{"parts": [{"text": content}]}]
|
|
<ANY OTHER JSON PAIRS REQUIRED separated by commas>
|
|
}
|
|
) as response:
|
|
response_json = await response.json()
|
|
return <Anything needed to extract the message response from `response_json`>
|
|
"""


###################################################################################################


### LOCAL LLM TRANSLATION


class TranslationDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
        """
        Custom dataset for efficient text processing

        Args:
            texts: List of input texts
            tokenizer: HuggingFace tokenizer
            max_length: Maximum sequence length
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]

        # Tokenize with padding and truncation; tensors stay on the CPU here
        # and are moved to the target device in the generation loop.
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Remove the batch dimension added by the tokenizer
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0)
        }


def collate_fn(batch: List[Dict]):
    """
    Custom collate function to handle batching
    """
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }


def generate_text(
        texts: List[str],
        model,
        tokenizer,
        batch_size: int = 6,  # smaller batch size uses less VRAM
        device: str = 'cuda',
        max_length: int = 512,
        max_new_tokens: int = 512,
        **generate_kwargs
):
    """
    Optimized text generation function

    Args:
        texts: List of input texts
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        batch_size: Batch size for processing
        device: Device to run inference on
        max_length: Maximum input sequence length
        max_new_tokens: Maximum number of tokens to generate
        generate_kwargs: Additional kwargs for model.generate

    Returns:
        List of generated texts
    """
    # Create dataset and dataloader
    dataset = TranslationDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    all_generated_texts = []

    # Default generation parameters
    generation_config = {
        'max_new_tokens': max_new_tokens,
        'num_beams': 4,
        'do_sample': True,
        'top_k': 50,
        'top_p': 0.95,
        'temperature': 0.7,
        'no_repeat_ngram_size': 2,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id
    }

    # Update with user-provided parameters
    generation_config.update(generate_kwargs)

    # Perform generation
    with torch.no_grad():
        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}

            # Generate text
            outputs = model.generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                **generation_config
            )

            # Decode generated tokens
            decoded_texts = tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )

            all_generated_texts.extend(decoded_texts)

    return all_generated_texts
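

# A minimal usage sketch for generate_text. The checkpoint name and the
# `transformers` calls are illustrative assumptions (any seq2seq translation
# model with a compatible tokenizer should work); `transformers` itself is not
# otherwise required by this module.
def _example_generate_text():
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    checkpoint = 'Helsinki-NLP/opus-mt-zh-en'  # hypothetical example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)
    # do_sample=False overrides the sampling default for a deterministic beam search.
    return generate_text(['你好，世界'], model, tokenizer,
                         batch_size=1, device=device, do_sample=False)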


if __name__ == '__main__':
    # Demo: translate a short Chinese phrase with both providers. `translate`
    # is a coroutine, so it is driven with asyncio.run; rate limits are passed
    # as keyword arguments (rpmin = 15 calls per minute), and the API keys
    # default to the values loaded in config.
    async def _demo():
        groq = Groq('gemma-7b-it', rpmin=15)
        groq.set_lang('zh', 'en')
        gemini = Gemini('gemini-1.5-pro', rpmin=15)
        gemini.set_lang('zh', 'en')
        print(await gemini.translate(['荷兰咯']))
        print(await groq.translate(['荷兰咯']))

    asyncio.run(_demo())