# onscreen-translator/helpers/batching.py

from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
from dotenv import load_dotenv
import os, sys, time
import torch
from threading import Thread, Event

load_dotenv()
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from config import device
from logging_config import logger

class Gemini:
    def __init__(self, name, rate):
        self.name = name
        self.rate = rate        # allowed calls per minute
        self.curr_calls = 0     # calls made in the current window
        self.time = 0           # seconds elapsed in the current window
        self.thread = None
        self.stop_event = Event()

    def __repr__(self):
        return f'Model: {self.name}; Rate: {self.rate}; Current_Calls: {self.curr_calls} calls; Time Passed: {self.time} seconds.'

    def __str__(self):
        return self.name

    def background_task(self):
        # Background task that resets the per-minute call counter
        while not self.stop_event.is_set():
            time.sleep(5)
            self.time += 5
            if self.time >= 60:
                self.time = 0
                self.curr_calls = 0

    def start(self):
        # Run the counter in a thread rather than a multiprocessing.Process:
        # a child process gets its own copy of this object, so resets to
        # curr_calls made there would never be visible to the caller.
        self.thread = Thread(target=self.background_task, daemon=True)
        self.thread.start()
        logger.info("Background rate-limit thread started")

    def stop(self):
        # Stop the background task
        logger.info("Stopping background rate-limit thread")
        self.stop_event.set()
        if self.thread:
            self.thread.join(timeout=5)
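
# Usage sketch (hypothetical caller; call_gemini_api and the model name are
# illustrative, not part of this module). The background thread resets
# curr_calls once per minute, so a caller only needs to check and increment
# the counter around each API request:
#
#   limiter = Gemini("gemini-1.5-flash", rate=15)
#   limiter.start()
#   if limiter.curr_calls < limiter.rate:
#       limiter.curr_calls += 1
#       call_gemini_api(...)
#   limiter.stop()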

class TranslationDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
        """
        Custom dataset for efficient text processing
        Args:
            texts: List of input texts
            tokenizer: HuggingFace tokenizer
            max_length: Maximum sequence length
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        # Tokenize with padding and truncation. Tensors stay on the CPU here:
        # generate_text moves each batch to the target device, and calling
        # .to(device) inside a Dataset breaks with CUDA in DataLoader workers.
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # Remove the batch dimension added by the tokenizer
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0)
        }

def collate_fn(batch: List[Dict]):
    """
    Custom collate function to handle batching
    """
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }
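
# Alternative collate (a sketch, not wired into generate_text): tokenizing the
# whole batch here pads to the longest sequence in each batch instead of
# max_length, which saves memory on short inputs. The DataLoader would then
# wrap the raw list of strings rather than TranslationDataset:
#
#   def make_collate_fn(tokenizer, max_length):
#       def collate(texts: List[str]):
#           return tokenizer(texts, max_length=max_length, padding=True,
#                            truncation=True, return_tensors='pt')
#       return collate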

def generate_text(
    texts: List[str],
    model,
    tokenizer,
    batch_size: int = 6,  # A smaller batch size uses less VRAM
    device=device,        # Defaults to the device configured in config.py
    max_length: int = 512,
    max_new_tokens: int = 512,
    **generate_kwargs
):
    """
    Optimized text generation function
    Args:
        texts: List of input texts
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        batch_size: Batch size for processing
        device: Device to run inference on
        max_length: Maximum input sequence length
        max_new_tokens: Maximum number of tokens to generate
        generate_kwargs: Additional kwargs for model.generate
    Returns:
        List of generated texts
    """
    # Create dataset and dataloader
    dataset = TranslationDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )
    all_generated_texts = []
    # Default generation parameters (overridable via generate_kwargs)
    generation_config = {
        'max_new_tokens': max_new_tokens,
        'num_beams': 4,
        'do_sample': True,
        'top_k': 50,
        'top_p': 0.95,
        'temperature': 0.7,
        'no_repeat_ngram_size': 2,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id
    }
    # Update with user-provided parameters
    generation_config.update(generate_kwargs)
    # Perform generation
    with torch.no_grad():
        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}
            # Generate text
            outputs = model.generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                **generation_config
            )
            # Decode generated tokens
            decoded_texts = tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            all_generated_texts.extend(decoded_texts)
    return all_generated_texts
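
# Example override (values illustrative): generate_kwargs is merged over the
# defaults above, so a caller can switch to deterministic beam search:
#
#   translations = generate_text(texts, model, tokenizer,
#                                do_sample=False, num_beams=5)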

if __name__ == '__main__':
    from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
    model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=True).to(device)
    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=True)
    tokenizer.src_lang = "zh"
    texts = ["你好", ""]
    print(generate_text(texts, model, tokenizer, forced_bos_token_id=tokenizer.get_lang_id('en')))