from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
from dotenv import load_dotenv
import os, sys, torch, time, ast
from werkzeug.exceptions import TooManyRequests
from multiprocessing import Process, Event, Value

load_dotenv()

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from config import device, GEMINI_API_KEY, GROQ_API_KEY
from logging_config import logger
from groq import Groq
import google.generativeai as genai
import asyncio
from functools import wraps


class ApiModel:

    def __init__(self, model,      # model name
                 rate,             # allowed calls per minute
                 api_key,          # API key for the model's provider
                 ):
        self.model = model
        self.rate = rate
        self.api_key = api_key
        self.curr_calls = Value('i', 0)   # calls made in the current window
        self.time = Value('i', 0)         # seconds elapsed in the current window
        self.process = None
        self.stop_event = Event()
        self.site = None
        self.from_lang = None
        self.target_lang = None
        self.request = None               # unused placeholder; subclasses define a request() method

    def __repr__(self):
        return (f'{self.site} Model: {self.model}; Rate: {self.rate}; '
                f'Current_Calls: {self.curr_calls.value} calls; '
                f'Time Passed: {self.time.value} seconds.')

    def __str__(self):
        return self.model

    def set_lang(self, from_lang, target_lang):
        self.from_lang = from_lang
        self.target_lang = target_lang

    ### Checks per-minute API rates. Hourly and monthly limits would need a separate
    ### implementation, but those limits are unlikely to be hit in practice.
    async def api_rate_check(self):
        # Background task to manage the rate of calls to the API: every five
        # seconds advance the clock, and reset the window once a minute elapses
        while not self.stop_event.is_set():
            start_time = time.monotonic()
            self.time.value += 5
            if self.time.value >= 60:
                self.time.value = 0
                self.curr_calls.value = 0
            elapsed = time.monotonic() - start_time
            # Sleep for exactly five seconds minus the time spent above
            sleep_time = max(0, 5 - elapsed)
            await asyncio.sleep(sleep_time)

    def background_task(self):
        asyncio.run(self.api_rate_check())

    def start(self):
        # Start the background rate-reset task in a daemon process
        self.process = Process(target=self.background_task)
        self.process.daemon = True
        self.process.start()
        logger.info(f"Background process started with PID: {self.process.pid}")

    def stop(self):
        # Stop the background task
        self.stop_event.set()
        if self.process:
            logger.info(f"Stopping background process with PID: {self.process.pid}")
            self.process.join(timeout=5)
            if self.process.is_alive():
                self.process.terminate()

    def request_func(request):
        # Decorator that enforces the per-minute rate limit around a request function
        @wraps(request)
        def wrapper(self, text, *args, **kwargs):
            if self.curr_calls.value < self.rate:
                try:
                    response = request(self, text, *args, **kwargs)
                    self.curr_calls.value += 1
                    return response
                except Exception as e:
                    logger.error(f"Error with model {self.model} from {self.site}. Error: {e}")
                    raise
            else:
                logger.error(f"Rate limit reached for this model. Please wait for the rate "
                             f"to reset in {60 - self.time.value} seconds.")
                raise TooManyRequests('Rate limit reached for this model.')
        return wrapper

    @request_func
    def translate(self, request_fn, texts_to_translate):
        if len(texts_to_translate) == 0:
            return []
        prompt = (f"Without any additional remarks, and without any code, translate the "
                  f"following items of the Python list from {self.from_lang} into "
                  f"{self.target_lang} and output as a Python list ensuring proper "
                  f"escaping of characters: {texts_to_translate}")
        response = request_fn(self, prompt)
        # The model is expected to reply with a Python-list literal;
        # ast.literal_eval raises if the response contains anything else
        return ast.literal_eval(response.strip())
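
    # Illustrative round-trip (values are hypothetical, not a recorded response):
    # with from_lang='zh' and target_lang='en', translate(['你好']) builds the
    # prompt above, the model is expected to reply with the literal "['hello']",
    # and ast.literal_eval turns that back into the list ['hello'].
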
class Groqq(ApiModel):

    def __init__(self, model, rate, api_key=GROQ_API_KEY):
        super().__init__(model, rate, api_key)
        self.site = "Groq"

    def request(self, content):
        client = Groq(api_key=self.api_key)
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": content,
                }
            ],
            model=self.model
        )
        return chat_completion.choices[0].message.content

    def translate(self, texts_to_translate):
        # Pass the unbound request function; the base translate calls it
        # as request_fn(self, prompt)
        return super().translate(Groqq.request, texts_to_translate)


class Gemini(ApiModel):

    def __init__(self, model, rate, api_key=GEMINI_API_KEY):
        super().__init__(model, rate, api_key)
        self.site = "Gemini"

    def request(self, content):
        genai.configure(api_key=self.api_key)
        safety_settings = {
            "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
            "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
            "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE",
        }
        response = genai.GenerativeModel(self.model).generate_content(
            content, safety_settings=safety_settings)
        return response.text.strip()

    def translate(self, texts_to_translate):
        return super().translate(Gemini.request, texts_to_translate)


class TranslationDataset(Dataset):

    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
        """
        Custom dataset for efficient text processing

        Args:
            texts: List of input texts
            tokenizer: HuggingFace tokenizer
            max_length: Maximum sequence length
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]

        # Tokenize with padding and truncation
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(device)

        # Remove batch dimension added by tokenizer
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0)
        }


def collate_fn(batch: List[Dict]):
    """
    Custom collate function to handle batching
    """
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }


def generate_text(
        texts: List[str],
        model,
        tokenizer,
        batch_size: int = 6,        # smaller batch size uses less VRAM
        device: str = 'cuda',
        max_length: int = 512,
        max_new_tokens: int = 512,
        **generate_kwargs
):
    """
    Optimized text generation function

    Args:
        texts: List of input texts
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        batch_size: Batch size for processing
        device: Device to run inference on
        max_length: Maximum input sequence length
        max_new_tokens: Maximum number of tokens to generate
        generate_kwargs: Additional kwargs for model.generate

    Returns:
        List of generated texts
    """

    # Create dataset and dataloader
    dataset = TranslationDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    all_generated_texts = []

    # Default generation parameters
    generation_config = {
        'max_new_tokens': max_new_tokens,
        'num_beams': 4,
        'do_sample': True,
        'top_k': 50,
        'top_p': 0.95,
        'temperature': 0.7,
        'no_repeat_ngram_size': 2,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id
    }

    # Update with user-provided parameters
    generation_config.update(generate_kwargs)

    # Perform generation
    with torch.no_grad():
        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}

            # Generate text
            outputs = model.generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                **generation_config
            )

            # Decode generated tokens
            decoded_texts = tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )

            all_generated_texts.extend(decoded_texts)

    return all_generated_texts
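

# Minimal usage sketch for generate_text. The checkpoint name below is an
# assumption for illustration (any HuggingFace seq2seq translation model with
# a matching tokenizer should work), not something this project pins:
#
#   from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
#   mdl = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-zh-en").to(device)
#   print(generate_text(['荷兰咯'], mdl, tok, batch_size=2, device=device))

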
if __name__ == '__main__':
    GROQ_API_KEY = os.getenv('GROQ_API_KEY')
    GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
    groq = Groqq('gemma-7b-it', 15, GROQ_API_KEY)
    groq.set_lang('zh', 'en')
    gemini = Gemini('gemini-1.5-pro', 15, GEMINI_API_KEY)
    gemini.set_lang('zh', 'en')
    # Start the background processes so the per-minute rate windows actually reset
    groq.start()
    gemini.start()
    try:
        print(gemini.translate(['荷兰咯']))
        print(groq.translate(['荷兰咯']))
    finally:
        groq.stop()
        gemini.stop()