from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
from dotenv import load_dotenv
import os, sys, torch, time
from threading import Thread, Event

load_dotenv()

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from config import device
from logging_config import logger


class Gemini:

    def __init__(self, name, rate):
        self.name = name
        self.rate = rate            # maximum calls allowed per minute
        self.curr_calls = 0         # calls made in the current one-minute window
        self.time = 0               # seconds elapsed in the current window
        self.process = None         # background worker (a Thread; see start())
        self.stop_event = Event()

    def __repr__(self):
        return (f'Model: {self.name}; Rate: {self.rate}; '
                f'Current_Calls: {self.curr_calls} calls; Time Passed: {self.time} seconds.')

    def __str__(self):
        return self.name

    def background_task(self):
        # Background task to manage the rate of calls to the API:
        # once 60 seconds have passed, the call counter is reset.
        while not self.stop_event.is_set():
            time.sleep(5)
            self.time += 5
            if self.time >= 60:
                self.time = 0
                self.curr_calls = 0

    def start(self):
        # Run the reset loop in a daemon thread. A thread (rather than a separate
        # process) is used so that curr_calls and time are shared with the caller;
        # a child process would only update its own copies of these counters.
        self.process = Thread(target=self.background_task, daemon=True)
        self.process.start()
        logger.info(f"Background thread started: {self.process.name}")

    def stop(self):
        # Stop the background task.
        if self.process:
            logger.info(f"Stopping background thread: {self.process.name}")
            self.stop_event.set()
            self.process.join(timeout=10)
            if self.process.is_alive():
                logger.warning("Background thread did not stop within the timeout.")

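
# Illustrative sketch (not used elsewhere in this module): one way a caller might gate
# requests with the Gemini rate tracker above. `send_request` is a hypothetical stand-in
# for a real API call, and the model name and per-minute quota below are assumed values;
# only name/rate/curr_calls/start/stop come from the class itself.
def _example_rate_limited_calls(prompts: List[str], send_request=lambda p: p) -> List[str]:
    limiter = Gemini(name="gemini-pro", rate=60)
    limiter.start()
    try:
        responses = []
        for prompt in prompts:
            # Block until the background thread resets the per-minute counter.
            while limiter.curr_calls >= limiter.rate:
                time.sleep(1)
            limiter.curr_calls += 1
            responses.append(send_request(prompt))
        return responses
    finally:
        limiter.stop()
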

class TranslationDataset(Dataset):

    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
        """
        Custom dataset for efficient text processing

        Args:
            texts: List of input texts
            tokenizer: HuggingFace tokenizer
            max_length: Maximum sequence length
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]

        # Tokenize with padding and truncation. Tensors are kept on the CPU here;
        # generate_text moves each batch to the target device, which also keeps the
        # dataset safe to use with multiple DataLoader workers.
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Remove batch dimension added by tokenizer
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0)
        }


def collate_fn(batch: List[Dict]):
    """
    Custom collate function to handle batching
    """
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }

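
# Alternative sketch (not used by this module): if TranslationDataset were changed to
# return un-padded, variable-length encodings (i.e. without padding='max_length'), the
# batch could instead be padded dynamically to the longest item, which avoids wasting
# compute on short inputs. The pad_token_id default of 0 is purely illustrative; a real
# caller would pass tokenizer.pad_token_id.
def collate_fn_dynamic(batch: List[Dict], pad_token_id: int = 0):
    input_ids = torch.nn.utils.rnn.pad_sequence(
        [item['input_ids'] for item in batch],
        batch_first=True, padding_value=pad_token_id
    )
    attention_mask = torch.nn.utils.rnn.pad_sequence(
        [item['attention_mask'] for item in batch],
        batch_first=True, padding_value=0
    )
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }
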

def generate_text(
        texts: List[str],
        model,
        tokenizer,
        batch_size: int = 6,        # Smaller batch size uses less VRAM
        device: str = 'cuda',
        max_length: int = 512,
        max_new_tokens: int = 512,
        **generate_kwargs
):
    """
    Optimized text generation function

    Args:
        texts: List of input texts
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        batch_size: Batch size for processing
        device: Device to run inference on
        max_length: Maximum input sequence length
        max_new_tokens: Maximum number of tokens to generate
        generate_kwargs: Additional kwargs for model.generate

    Returns:
        List of generated texts
    """
    # Create dataset and dataloader
    dataset = TranslationDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )

    all_generated_texts = []

    # Default generation parameters
    generation_config = {
        'max_new_tokens': max_new_tokens,
        'num_beams': 4,
        'do_sample': True,
        'top_k': 50,
        'top_p': 0.95,
        'temperature': 0.7,
        'no_repeat_ngram_size': 2,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id
    }

    # Update with user-provided parameters
    generation_config.update(generate_kwargs)

    # Perform generation
    with torch.no_grad():
        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}

            # Generate text
            outputs = model.generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                **generation_config
            )

            # Decode generated tokens
            decoded_texts = tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )

            all_generated_texts.extend(decoded_texts)

    return all_generated_texts

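
# Illustrative wrapper (not part of the original module): for very large corpora it can
# help to call generate_text in fixed-size chunks and release cached GPU memory between
# chunks. chunk_size is an assumed tuning knob, not a measured value.
def generate_text_chunked(texts: List[str], model, tokenizer, chunk_size: int = 512, **kwargs) -> List[str]:
    results: List[str] = []
    for start in range(0, len(texts), chunk_size):
        # Each chunk is batched internally by generate_text.
        results.extend(generate_text(texts[start:start + chunk_size], model, tokenizer, **kwargs))
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # drop cached allocator blocks between chunks
    return results
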

if __name__ == '__main__':
    from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

    model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=True).to(device)
    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=True)
    tokenizer.src_lang = "zh"

    texts = ["你好", "我"]  # "hello", "I"
    # Run on the same device the model was moved to, and force English as the target language.
    print(generate_text(texts, model, tokenizer, device=device,
                        forced_bos_token_id=tokenizer.get_lang_id('en')))