from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
from dotenv import load_dotenv
import os, sys, torch, time
# Threads (not processes) are used for the rate-limit bookkeeping below:
# a child process would mutate its own copy of the counters, so the parent
# would never see them reset.
from threading import Thread, Event

load_dotenv()
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from config import device
from logging_config import logger


class Gemini:
    """Tracks API calls over a rolling one-minute window.

    A daemon thread resets ``curr_calls`` every 60 seconds; callers are
    expected to check ``curr_calls`` against ``rate`` before calling the API.
    """

    def __init__(self, name, rate):
        self.name = name
        self.rate = rate  # maximum calls allowed per minute
        self.curr_calls = 0
        self.time = 0
        self.worker = None
        self.stop_event = Event()

    def __repr__(self):
        return (f'Model: {self.name}; Rate: {self.rate}; '
                f'Current_Calls: {self.curr_calls} calls; '
                f'Time Passed: {self.time} seconds.')

    def __str__(self):
        return self.name

    def background_task(self):
        # Background task that resets the call counter once per minute
        while not self.stop_event.is_set():
            time.sleep(5)
            self.time += 5
            if self.time >= 60:
                self.time = 0
                self.curr_calls = 0

    def start(self):
        # Start the background task
        self.worker = Thread(target=self.background_task, daemon=True)
        self.worker.start()
        logger.info(f"Background thread started: {self.worker.name}")

    def stop(self):
        # Stop the background task
        self.stop_event.set()
        if self.worker:
            logger.info(f"Stopping background thread: {self.worker.name}")
            self.worker.join(timeout=10)
            if self.worker.is_alive():
                # Threads cannot be force-terminated; the daemon flag lets
                # the interpreter exit even if this one is still sleeping.
                logger.warning("Background thread did not stop in time.")


class TranslationDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
        """
        Custom dataset for efficient text processing

        Args:
            texts: List of input texts
            tokenizer: HuggingFace tokenizer
            max_length: Maximum sequence length
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]

        # Tokenize with padding and truncation. Tensors stay on the CPU
        # here and are moved to the target device in the generation loop,
        # which keeps the dataset safe to use with multi-worker DataLoaders.
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Remove batch dimension added by tokenizer
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0)
        }


def collate_fn(batch: List[Dict]):
    """
    Custom collate function to handle batching
    """
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }


def generate_text(
    texts: List[str],
    model,
    tokenizer,
    batch_size: int = 6,  # Smaller batch size uses less VRAM
    device=device,  # Device to run inference on; defaults to config's device
    max_length: int = 512,
    max_new_tokens: int = 512,
    **generate_kwargs
):
    """
    Optimized text generation function

    Args:
        texts: List of input texts
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        batch_size: Batch size for processing
        device: Device to run inference on
        max_length: Maximum input sequence length
        max_new_tokens: Maximum number of tokens to generate
        generate_kwargs: Additional kwargs for model.generate

    Returns:
        List of generated texts
    """
    # Create dataset and dataloader
    dataset = TranslationDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    all_generated_texts = []

    # Default generation parameters
    generation_config = {
        'max_new_tokens': max_new_tokens,
        'num_beams': 4,
        'do_sample': True,
        'top_k': 50,
        'top_p': 0.95,
        'temperature': 0.7,
        'no_repeat_ngram_size': 2,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id
    }

    # Update with user-provided parameters
    generation_config.update(generate_kwargs)

    # Perform generation
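    # torch.no_grad() disables autograd bookkeeping for these forward
    # passes, lowering memory use during inference.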
    with torch.no_grad():
        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}

            # Generate text
            outputs = model.generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                **generation_config
            )

            # Decode generated tokens
            decoded_texts = tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            all_generated_texts.extend(decoded_texts)

    return all_generated_texts


if __name__ == '__main__':
    from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

    model = M2M100ForConditionalGeneration.from_pretrained(
        "facebook/m2m100_418M", local_files_only=True
    ).to(device)
    tokenizer = M2M100Tokenizer.from_pretrained(
        "facebook/m2m100_418M", local_files_only=True
    )
    tokenizer.src_lang = "zh"
    texts = ["你好", "我"]
    print(generate_text(texts, model, tokenizer,
                        forced_bos_token_id=tokenizer.get_lang_id('en')))
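
# --- Usage sketch (illustrative addition, not part of the original API) ---
# Shows one way the Gemini rate limiter above could gate calls to an
# external service. `fake_api_call` is a hypothetical stand-in for a real
# client; the limiter only tracks counts, so the caller enforces the limit
# by waiting for the one-minute window to reset.
def _demo_rate_limiter(prompts: List[str]):
    def fake_api_call(prompt: str) -> str:
        # Hypothetical placeholder for a real API client call
        return f"response to {prompt!r}"

    limiter = Gemini(name="gemini-demo", rate=60)  # at most 60 calls/minute
    limiter.start()
    try:
        for prompt in prompts:
            # Block until the rolling window has capacity again
            while limiter.curr_calls >= limiter.rate:
                time.sleep(1)
            limiter.curr_calls += 1
            print(fake_api_call(prompt))
    finally:
        limiter.stop()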