from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Optional
from datetime import datetime, timedelta
from dotenv import load_dotenv
import os, sys, torch, time, ast, json, pytz
from werkzeug.exceptions import TooManyRequests
from multiprocessing import Process, Event, Value

load_dotenv()
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from config import device, GEMINI_API_KEY, GROQ_API_KEY, MAX_TRANSLATE
from logging_config import logger
from groq import Groq as Groqq
import google.generativeai as genai
from google.api_core.exceptions import ResourceExhausted
import asyncio
import aiohttp
from functools import wraps
from data import session, Api, Translations


class ApiModel():

    def __init__(self, model,                      # model name as defined by the API
                 site,                             # hosting site; use the site names set by the subclasses below, i.e. 'Groq' and 'Google'
                 api_key: Optional[str] = None,    # API key for the model on that site
                 rpmin: Optional[int] = None,      # maximum calls per minute
                 rph: Optional[int] = None,        # maximum calls per hour
                 rpd: Optional[int] = None,        # maximum calls per day
                 rpw: Optional[int] = None,        # maximum calls per week
                 rpmth: Optional[int] = None,      # maximum calls per month
                 rpy: Optional[int] = None         # maximum calls per year
                 ):
        self.api_key = api_key
        self.model = model
        self.rpmin = rpmin
        self.rph = rph
        self.rpd = rpd
        self.rpw = rpw
        self.rpmth = rpmth
        self.rpy = rpy
        self.site = site
        self.from_lang = None
        self.target_lang = None
        self.db_table = None
        self.session_calls = 0
        self._id = None
        # Register this model in the Api table if it is not already there
        if self._get_db_model_id():
            self._set_db_model_id()
        else:
            self.update_db()

    def __repr__(self):
        return f'{self.site} Model: {self.model}; Total calls this session: {self.session_calls}; rpmin: {self.rpmin}; rph: {self.rph}; rpd: {self.rpd}; rpw: {self.rpw}; rpmth: {self.rpmth}; rpy: {self.rpy}'

    def __str__(self):
        return self.model

    async def __aenter__(self):
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()

    def _get_db_model_id(self):
        model = session.query(Api).filter_by(model_name = self.model, site = self.site).first()
        return model.id if model else None

    def _set_db_model_id(self):
        self._id = self._get_db_model_id()

    @staticmethod
    def _get_time():
        return datetime.now(tz=pytz.timezone('Australia/Sydney'))

    def set_lang(self, from_lang, target_lang):
        self.from_lang = from_lang
        self.target_lang = target_lang

    def set_db_table(self, db_table):
        self.db_table = db_table

    def update_db(self):
        api = session.query(Api).filter_by(model_name = self.model, site = self.site).first()
        if not api:
            api = Api(model_name = self.model,
                      site = self.site,
                      rpmin = self.rpmin,
                      rph = self.rph,
                      rpd = self.rpd,
                      rpw = self.rpw,
                      rpmth = self.rpmth,
                      rpy = self.rpy)
            session.add(api)
            session.commit()
            self._set_db_model_id()
        else:
            api.rpmin = self.rpmin
            api.rph = self.rph
            api.rpd = self.rpd
            api.rpw = self.rpw
            api.rpmth = self.rpmth
            api.rpy = self.rpy
            session.commit()

    def _db_add_translation(self, text: list | str, translation: list, mismatch = False):
        text = json.dumps(text) if isinstance(text, list) else json.dumps([text])
        row = Translations(source_texts = text, translated_texts = json.dumps(translation),
                           model_id = self._id, source_lang = self.from_lang, target_lang = self.target_lang,
                           timestamp = self._get_time(),
                           translation_mismatch = mismatch)
        session.add(row)
        session.commit()
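
    # Stored row sketch (illustrative values, not real data): both text columns
    # hold JSON-encoded lists, e.g.
    #   source_texts='["你好"]', translated_texts='["hello"]',
    #   source_lang='zh', target_lang='en', translation_mismatch=False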

    @staticmethod
    def _single_period_calls_check(max_calls, call_count):
        # A period with no configured limit never blocks a call
        if not max_calls:
            return True
        return call_count < max_calls

    def _are_rates_good(self):
        curr_time = self._get_time()
        # (limit, window length, label) for each rate-limit period
        periods = [(self.rpmin, timedelta(minutes=1), 'minute'),
                   (self.rph, timedelta(hours=1), 'hour'),
                   (self.rpd, timedelta(days=1), 'day'),
                   (self.rpw, timedelta(weeks=1), 'week'),
                   (self.rpmth, timedelta(days=30), 'month'),
                   (self.rpy, timedelta(days=365), 'year')]
        # Count the logged calls for this model within each window
        counts = {label: session.query(Translations).join(Api)
                                .filter(Api.id == self._id,
                                        Translations.timestamp >= curr_time - window)
                                .count()
                  for _, window, label in periods}
        if all(self._single_period_calls_check(limit, counts[label])
               for limit, _, label in periods):
            return True
        else:
            calls_summary = '; '.join(f'{count} in the last {label}' for label, count in counts.items())
            logger.warning(f"Rate limit reached for {self.model} from {self.site}. Current calls: {calls_summary}.")
            return False
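
    # Example (hypothetical limits): a model constructed with rpmin=30 and
    # rpd=14400 makes _are_rates_good() return False as soon as 30 calls have
    # been logged in the past minute or 14,400 in the past day.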

    # async def request_func(request):
    #     @wraps(request)
    #     async def wrapper(self, text, *args, **kwargs):
    #         if await self._are_rates_good():
    #             try:
    #                 self.session_calls += 1
    #                 response = await request(self, text, *args, **kwargs)
    #                 return response
    #             except Exception as e:
    #                 logger.error(f"Error with model {self.model} from {self.site}. Error: {e}")
    #         else:
    #             logger.error(f"Rate limit reached for this model.")
    #             raise TooManyRequests('Rate limit reached for this model.')
    #     return wrapper

    # @request_func
    async def translate(self, texts_to_translate, store = False):
        if isinstance(texts_to_translate, str):
            texts_to_translate = [texts_to_translate]
        if len(texts_to_translate) == 0:
            return []
        # prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {self.from_lang} into {self.target_lang} and output as a Python list ensuring proper escaping of characters and ensuring the length of the list given is exactly equal to the length of the list you provide. Do not output in any other language other than the specified target language: {texts_to_translate}"
        prompt = f"""INSTRUCTIONS:
- Provide ONE and ONLY ONE translation for each text in the JSON array given.
- The translations must preserve the original order.
- Each translation must be from the source language to the target language.
- Source language: {self.from_lang}
- Target language: {self.target_lang}
- Texts are provided in JSON array syntax.
- Respond using ONLY valid JSON array syntax.
- Do not include explanations or additional text.
- Escape special characters properly.

Input texts:
{json.dumps(texts_to_translate, ensure_ascii=False)}

Expected format:
["translation1", "translation2", ...]

Translation:"""
        response = await self._request(prompt)
        # literal_eval also parses the JSON-style list of strings requested above
        response_list = ast.literal_eval(response.strip())
        logger.debug(repr(self))
        logger.info(f'{self.model} translated texts from: {texts_to_translate} to {response_list}.')

        if len(response_list) != len(texts_to_translate) and len(texts_to_translate) <= MAX_TRANSLATE:
            logger.error(f"{self.model} model failed to translate all the texts. Number of translations to make: {len(texts_to_translate)}; Number of translated texts: {len(response_list)}.")
            if store:
                self._db_add_translation(texts_to_translate, response_list, mismatch=True)
        else:
            if store:
                self._db_add_translation(texts_to_translate, response_list)
        logger.debug(response_list)
        return response_list
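
    # Usage sketch (hypothetical limit; translate() is a coroutine, so it must
    # run inside an event loop):
    #
    #   async def demo():
    #       model = Gemini('gemini-1.5-pro', rpmin=15)
    #       model.set_lang('zh', 'en')
    #       print(await model.translate(['你好'], store=True))
    #
    #   asyncio.run(demo())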


class Groq(ApiModel):

    def __init__(self,
                 model,                     # model name as defined by the API
                 api_key = GROQ_API_KEY,    # API key for the model on this site
                 **kwargs):
        super().__init__(model,
                         api_key = api_key,
                         site = 'Groq', **kwargs)
        self.client = Groqq(api_key=api_key)

    async def _request(self, content: str) -> str:
        async with aiohttp.ClientSession() as http_session:
            async with http_session.post(
                "https://api.groq.com/openai/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "messages": [{"role": "user", "content": content}],
                    "model": self.model
                }
            ) as response:
                response_json = await response.json()
                return response_json["choices"][0]["message"]["content"]
    # https://console.groq.com/settings/limits for limits
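
    # Example with per-period limits (illustrative numbers and model name only;
    # check the console page above for the current limits of each model):
    #
    #   groq = Groq('llama-3.1-8b-instant', rpmin=30, rpd=14400)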

    # def request(self, content):
    #     chat_completion = self.client.chat.completions.create(
    #         messages=[
    #             {
    #                 "role": "user",
    #                 "content": content,
    #             }
    #         ],
    #         model=self.model
    #     )
    #     return chat_completion.choices[0].message.content

    # async def translate(self, texts_to_translate):
    #     return super().translate(self.request, texts_to_translate)


class Gemini(ApiModel):

    def __init__(self,
                 model,                      # model name as defined by the API
                 api_key = GEMINI_API_KEY,   # API key for the model on this site
                 **kwargs):
        super().__init__(model,
                         api_key = api_key,
                         site = 'Google',
                         **kwargs)

    # def request(self, content):
    #     genai.configure(api_key=self.api_key)
    #     safety_settings = {
    #         "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
    #         "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
    #         "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
    #         "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
    #     try:
    #         response = genai.GenerativeModel(self.model).generate_content(content, safety_settings=safety_settings)
    #     except ResourceExhausted as e:
    #         logger.error(f"Rate limited with {self.model}. Error: {e}")
    #         raise ResourceExhausted("Rate limited.")
    #     return response.text.strip()

    async def _request(self, content):
        async with aiohttp.ClientSession() as http_session:
            async with http_session.post(
                f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent?key={self.api_key}",
                headers={
                    "Content-Type": "application/json"
                },
                json={
                    "contents": [{"parts": [{"text": content}]}],
                    # One category/threshold pair per entry; the original single
                    # dict silently collapsed to its last duplicate key
                    "safetySettings": [
                        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}
                    ]
                }
            ) as response:
                response_json = await response.json()
                return response_json['candidates'][0]['content']['parts'][0]['text']

    # async def translate(self, texts_to_translate):
    #     return super().translate(self.request, texts_to_translate)


###################################################################################################

### LOCAL LLM TRANSLATION

class TranslationDataset(Dataset):

    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
        """
        Custom dataset for efficient text processing

        Args:
            texts: List of input texts
            tokenizer: HuggingFace tokenizer
            max_length: Maximum sequence length
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]

        # Tokenize with padding and truncation; tensors stay on the CPU here
        # because generate_text moves each batch to the target device
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Remove batch dimension added by tokenizer
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0)
        }
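
    # Each item is a dict of 1-D tensors, e.g. with max_length=512:
    #   {'input_ids': LongTensor of shape (512,), 'attention_mask': LongTensor of shape (512,)}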

def collate_fn(batch: List[Dict]):
    """
    Custom collate function to handle batching
    """
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }

def generate_text(
        texts: List[str],
        model,
        tokenizer,
        batch_size: int = 6,   # Smaller batch size uses less VRAM
        device: str = 'cuda',
        max_length: int = 512,
        max_new_tokens: int = 512,
        **generate_kwargs
        ):
    """
    Optimized text generation function

    Args:
        texts: List of input texts
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        batch_size: Batch size for processing
        device: Device to run inference on
        max_length: Maximum input sequence length
        max_new_tokens: Maximum number of tokens to generate
        generate_kwargs: Additional kwargs for model.generate

    Returns:
        List of generated texts
    """

    # Create dataset and dataloader
    dataset = TranslationDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    all_generated_texts = []

    # Default generation parameters
    generation_config = {
        'max_new_tokens': max_new_tokens,
        'num_beams': 4,
        'do_sample': True,
        'top_k': 50,
        'top_p': 0.95,
        'temperature': 0.7,
        'no_repeat_ngram_size': 2,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id
    }

    # Update with user-provided parameters
    generation_config.update(generate_kwargs)

    # Perform generation
    with torch.no_grad():
        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}

            # Generate text
            outputs = model.generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                **generation_config
            )

            # Decode generated tokens
            decoded_texts = tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )

            all_generated_texts.extend(decoded_texts)

    return all_generated_texts
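
# Usage sketch (a minimal example; 'Helsinki-NLP/opus-mt-zh-en' is an
# illustrative seq2seq checkpoint, not one this project pins):
#
#   from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
#   model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').to(device)
#   print(generate_text(['你好,世界'], model, tokenizer, device=device))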


if __name__ == '__main__':

    async def main():
        # API keys default to the values loaded from config; 15 calls per
        # minute is the limit the original demo appears to have intended
        groq = Groq('gemma-7b-it', rpmin=15)
        groq.set_lang('zh', 'en')
        gemini = Gemini('gemini-1.5-pro', rpmin=15)
        gemini.set_lang('zh', 'en')
        print(await gemini.translate(['荷兰咯']))
        print(await groq.translate(['荷兰咯']))

    asyncio.run(main())