Added Gemini API support, proper support for Japanese romanisation and furigana, and optimised batching for local LLMs

chickenflyshigh 2024-11-03 13:29:50 +11:00
parent ee4b3ed43e
commit 17e7f6526f
14 changed files with 673 additions and 431 deletions

.gitignore (vendored, +2)

@@ -3,3 +3,5 @@
 translate/
 __pycache__/
 .*
+test.py

@@ -0,0 +1,4 @@
## Debugging Issues
1. cuDNN version mismatch when using PaddleOCR. Check that LD_LIBRARY_PATH points to the directory containing the cudnn .so files. With a local cuDNN installation, it can also help to remove the pip-installed NVIDIA cuDNN package from the Python environment.
2. Segmentation fault when using PaddleOCR, EasyOCR or RapidOCR. Ensure that the only installed cv2 package is opencv-contrib-python (a quick check is sketched below). See https://pypi.org/project/opencv-python-headless/ for more info.
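
For issue 2, one quick way to see which OpenCV wheels are actually installed is the sketch below (not part of this commit; any name other than opencv-contrib-python in the output indicates a conflict):

# Hypothetical check: list every installed OpenCV distribution.
from importlib.metadata import distributions

opencv_pkgs = sorted({d.metadata['Name'] for d in distributions()
                      if 'opencv' in (d.metadata['Name'] or '').lower()})
print(opencv_pkgs)  # expect exactly: ['opencv-contrib-python']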

@@ -1,132 +0,0 @@
###################################################################################
##### IMPORT LIBRARIES #####
import os, time, logging, ast
from helpers.translation import init_TRANSLATE, translate
from helpers.utils import intercepts, contains_lang, printsc, romanize, convert_image_to_bytes, bytes_to_image
from helpers.ocr import id_filtered, id_lang, get_words, get_positions, get_confidences, init_OCR
from logging_config import setup_logger
from helpers.draw import modify_image_bytes
###################################################################################
#### LOGGING ####
setup_logger('chinese_to_eng', log_file='chinese_to_eng.log')
###################################################################################
##### Variables to edit #####
INTERVAL = int(os.getenv('INTERVAL'))
### available languages: 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
SOURCE_LANG = os.getenv('SOURCE_LANG', 'ja')
TARGET_LANG = os.getenv('TARGET_LANG', 'en')
### Translation
TRANSLATION_MODEL = os.getenv('TRANSLATION_MODEL', 'opus') # 'opus' or 'm2m' # opus is a lot more lightweight
MAX_TRANSLATE = 200
### OCR
OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True'))
OCR_MODEL = os.getenv('OCR_MODEL', 'easy') # 'easy', 'paddle', 'rapid' ### easy is the most accurate, paddle is the fastest with CUDA and rapid is the fastest with CPU
REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)'))
###################################################################################
OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
latest_image = None
def main():
    global latest_image
    # screenshot
    untranslated_image = printsc(REGION)
    byte_image = convert_image_to_bytes(untranslated_image)
    ###################################################################################
    ##### Initialize the OCR #####
    ocr = init_OCR(model=OCR_MODEL, ocr_languages = OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
    ocr_output = id_lang(ocr, byte_image, 'ja')
    curr_words = set(get_words(ocr_output))
    prev_words = set()
    ##### Initialize the translation #####
    init_TRANSLATE()
    ###################################################################################
    while True:
        print('Running')
        if prev_words != curr_words:
            print('Translating')
            to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
            translation = translate(to_translate, from_lang, target_lang)
            print(translation)
            translated_image = modify_image_bytes(byte_image, ocr_output, translation)
            latest_image = bytes_to_image(translated_image)
            prev_words = curr_words
            logging.info(f"Successfully translated image. Prev words are:\n{prev_words}")
        else:
            logging.info("The image has remained the same.")
        # torch.cuda.empty_cache()
        logging.info(f'Sleeping for {INTERVAL} seconds')
        time.sleep(INTERVAL)
        untranslated_image = printsc(REGION)
        byte_image = convert_image_to_bytes(untranslated_image)
        ocr_output = id_lang(ocr, byte_image, 'ja')
        curr_words = set(get_words(ocr_output))
        logging.info(f'Curr words to translate are:\n{curr_words}')

if __name__ == "__main__":
    main()
# image = Image.open(SCREENSHOT_PATH)
# draw = ImageDraw.Draw(image)
# # set counter for limiting the number of translations
# translated_number = 0
# bounding_boxes = []
# for i, (position,words,confidence) in enumerate(ocr_output):
#     if translated_number >= MAX_TRANSLATE:
#         break
#     # try:
#     top_left, _, _, _ = position
#     position = (top_left[0], top_left[1] - 60)
#     text_content = f"{translation[i]}\n{romanize(words)}\n{words}"
#     lines = text_content.split('\n')
#     x,y = position
#     max_width = 0
#     total_height = 0
#     line_spacing = 3
#     line_height = FONT_SIZE
#     for line in lines:
#         bbox = draw.textbbox(position, line, font=font)
#         line_width, _ = bbox[2] - bbox[0], bbox[3] - bbox[1]
#         max_width = max(max_width, line_width)
#         total_height += line_height + line_spacing
#     bounding_box = (x, y, x + max_width, y + total_height, words)
#     print(f"Bounding Box of Interest: {bounding_box}")
#     y = np.max([y,0])
#     if len(bounding_boxes) > 0:
#         for box in bounding_boxes:
#             print(f'Investigating box: {box}')
#             if intercepts((box[0],box[2]),(bounding_box[0],bounding_box[2])) and intercepts((box[1],box[3]),(y, y+total_height)):
#                 print(f'Overlapping change adjustment to {words}')
#                 y = np.max([y,box[3]]) + line_spacing
#                 print(y, box[3])
#                 print(f'Changed to {(x,y, x+max_width, y+total_height, words)}')
#     adjusted_bounding_box = (x, y, x + max_width, y + total_height, words)
#     bounding_boxes.append(adjusted_bounding_box)
#     draw.rectangle([(x,y), (x+max_width, y+total_height)], outline="black", width=1)
#     position = (x,y)
#     for line in lines:
#         draw.text(position, line, fill= TEXT_COLOR, font=font)
#         y += FONT_SIZE + line_spacing
#         position = (x,y)
#     print("Adjusted_bounding_box:",adjusted_bounding_box)
#     print('\n')
#     translated_number += 1

config.py (new file, +54)

@@ -0,0 +1,54 @@
import os, ast, torch
from dotenv import load_dotenv
load_dotenv(override=True)
###################################################################################################
### EDIT THESE VARIABLES ###
### available languages: 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
INTERVAL = int(os.getenv('INTERVAL'))
### OCR
OCR_MODEL = os.getenv('OCR_MODEL', 'easy') # 'easy', 'paddle', 'rapid' ### easy is the most accurate, paddle is the fastest with CUDA and rapid is the fastest with CPU. Rapid only supports Chinese and English unless you add more languages
OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True'))
### Drawing/Overlay Config
FONT_FILE = os.getenv('FONT_FILE')
FONT_SIZE = int(os.getenv('FONT_SIZE', 16))
LINE_SPACING = int(os.getenv('LINE_SPACING', 3))
REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)'))
TEXT_COLOR = os.getenv('TEXT_COLOR', "#ff0000")
TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True'))
### Translation
BATCH_SIZE = int(os.getenv('BATCH_SIZE', 6))
GEMINI_KEY = os.getenv('GEMINI_KEY')
LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False'))
MAX_INPUT_TOKENS = int(os.getenv('MAX_INPUT_TOKENS', 512))
MAX_OUTPUT_TOKENS = int(os.getenv('MAX_OUTPUT_TOKENS', 512))
MAX_TRANSLATE = int(os.getenv('MAX_TRANSLATION', 200))
SOURCE_LANG = os.getenv('SOURCE_LANG', 'ja')
TARGET_LANG = os.getenv('TARGET_LANG', 'en')
TRANSLATION_MODEL = os.environ['TRANSLATION_MODEL'] # 'opus' or 'm2m' # opus is a lot more lightweight
TRANSLATION_USE_GPU = ast.literal_eval(os.getenv('TRANSLATION_USE_GPU', 'True'))
###################################################################################################
LINE_HEIGHT = FONT_SIZE
if TRANSLATION_USE_GPU is False:
    device = torch.device("cpu")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
### Just for info
available_langs = ['ch_sim', 'ch_tra', 'ja', 'ko', 'en'] # there are limitations with the languages that can be used with the OCR models
seq_llm_models = ['opus', 'm2m']
api_llm_models = ['gemini']
causal_llm_models = []
curr_models = seq_llm_models + api_llm_models + causal_llm_models
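
Because several of these settings go through ast.literal_eval, the corresponding .env values must be valid Python literals. A minimal sketch of the pattern with assumed values (not part of this commit):

# literal_eval turns env strings into real Python values; a plain bool() cast
# would treat the non-empty string 'False' as truthy.
import ast, os
os.environ['REGION'] = '(0,0,1920,1080)'
os.environ['OCR_USE_GPU'] = 'False'
print(ast.literal_eval(os.environ['REGION']))       # (0, 0, 1920, 1080)
print(ast.literal_eval(os.environ['OCR_USE_GPU']))  # False, an actual bool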

@@ -1,19 +1,12 @@
 from PIL import Image, ImageDraw, ImageFont
-from dotenv import load_dotenv
-import os
-import io
-import numpy as np
-import ast
-from helpers.utils import romanize, intercepts
-load_dotenv()
-MAX_TRANSLATE = int(os.getenv('MAX_TRANSLATION', 200))
-FONT_FILE = os.getenv('FONT_FILE')
-FONT_SIZE = int(os.getenv('FONT_SIZE', 16))
-LINE_SPACING = int(os.getenv('LINE_SPACING', 3))
-TEXT_COLOR = os.getenv('TEXT_COLOR', "#ff0000")
-LINE_HEIGHT = FONT_SIZE
-TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True'))
+import os, io, sys, numpy as np
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
+from utils import romanize, intercepts, add_furigana
+from logging_config import logger
+from config import SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE, LINE_SPACING, TEXT_COLOR, LINE_HEIGHT, TO_ROMANIZE

 font = ImageFont.truetype(FONT_FILE, FONT_SIZE)

@@ -33,7 +26,6 @@ def modify_image_bytes(image_bytes: io.BytesIO, ocr_output, translation: list) -
     return modified_image_bytes

 def translate_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int) -> ImageDraw:
-    translation
     translated_number = 0
     bounding_boxes = []
     for i, (position, untranslated_phrase, confidence) in enumerate(ocr_output):

@@ -47,7 +39,16 @@ def translate_one_phrase(draw: ImageDraw, translated_phrase: str, position: tupl
     # Draw the bounding box
     top_left, _, _, _ = position
     position = (top_left[0], top_left[1] - 60)
-    text_content = f"{translated_phrase}\n{romanize(untranslated_phrase, TO_ROMANIZE)}\n{untranslated_phrase}"
+    if SOURCE_LANG == 'ja':
+        untranslated_phrase = add_furigana(untranslated_phrase)
+        romanized_phrase = romanize(untranslated_phrase, 'ja')
+    else:
+        romanized_phrase = romanize(untranslated_phrase, SOURCE_LANG)
+    if TO_ROMANIZE:
+        text_content = f"{translated_phrase}\n{romanized_phrase}\n{untranslated_phrase}"
+    else:
+        text_content = f"{translated_phrase}\n{untranslated_phrase}"
     lines = text_content.split('\n')
     x,y = position
     max_width = 0

@@ -58,7 +59,6 @@ def translate_one_phrase(draw: ImageDraw, translated_phrase: str, position: tupl
         line_width = bbox[2] - bbox[0]
         max_width = max(max_width, line_width)
     bounding_box = (x, y, x + max_width, y + total_height, untranslated_phrase)
-    print(f"Bounding Box of Interest: {bounding_box}")
     adjust_if_intersects(x, y, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
     adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]

@@ -68,18 +68,13 @@ def translate_one_phrase(draw: ImageDraw, translated_phrase: str, position: tupl
         draw.text(position, line, fill= TEXT_COLOR, font=font)
         adjusted_y += FONT_SIZE + LINE_SPACING
         position = (adjusted_x,adjusted_y)
-    print(f"Adjusted_bounding_box: {bounding_box[-1]}.\n")

 def adjust_if_intersects(x: int, y: int, bounding_box: tuple, bounding_boxes: list, untranslated_phrase: str, max_width: int, total_height: int) -> tuple:
     y = np.max([y,0])
     if len(bounding_boxes) > 0:
         for box in bounding_boxes:
-            print(f'Investigating box: {box}')
             if intercepts((box[0],box[2]),(bounding_box[0],bounding_box[2])) and intercepts((box[1],box[3]),(y, y+total_height)):
-                print(f'Overlapping change adjustment to {untranslated_phrase}')
                 y = np.max([y,box[3]]) + LINE_SPACING
-                print(y, box[3])
-                print(f'Changed to {(x,y, x+max_width, y+total_height, untranslated_phrase)}')
     adjusted_bounding_box = (x, y, x + max_width, y + total_height, untranslated_phrase)
     bounding_boxes.append(adjusted_bounding_box)
     return adjusted_bounding_box
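
A minimal sketch of calling the overlay entry point above (not part of this commit; the OCR output shape of (corner_points, text, confidence) and the input file name are assumptions):

# Sketch: overlay one translated phrase onto a screenshot held in memory.
import io
from helpers.draw import modify_image_bytes

ocr_output = [([(10, 80), (120, 80), (120, 110), (10, 110)], '你好', 0.98)]
with open('screenshot.png', 'rb') as f:  # hypothetical input image
    new_bytes = modify_image_bytes(io.BytesIO(f.read()), ocr_output, ['Hello'])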

@@ -1,153 +0,0 @@
import easyocr
from pypinyin import pinyin
from PIL import Image, ImageDraw, ImageFont
import os, time, logging, torch, subprocess
from helpers.translation import init_M2M, translate_M2M
import langid
import numpy as np
##### Variables to edit
text_color = "#ff0000"
font_file = "/home/James/.local/share/fonts/Arial-Unicode-Bold.ttf"
font_size = 16
pyin = True # whether to add pinyin or not
max_translate = 100
# for detecting language to filter out other languages. Only writes the text when it is detected to be src_lang
src_lang = "zh"
tgt_lang = "en"
# af, am, an, ar, as, az, be, bg, bn, br, bs, ca, cs, cy, da, de, dz, el, en, eo, es, et, eu, fa, fi, fo, fr, ga, gl, gu, he, hi, hr, ht, hu, hy, id, is, it, ja, jv, ka, kk, km, kn, ko, ku, ky, la, lb, lo, lt, lv, mg, mk, ml, mn, mr, ms, mt, nb, ne, nl, nn, no, oc, or, pa, pl, ps, pt, qu, ro, ru, rw, se, si, sk, sl, sq, sr, sv, sw, ta, te, th, tl, tr, ug, uk, ur, vi, vo, wa, xh, zh, zu
langid.set_languages([src_lang,tgt_lang,'en'])
# for translator (M2M100)
from_lang = "zh"
target_lang = "en"
# Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greeek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn), Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)
# for easyOCR
OCR_languages = ['ch_sim','en'] # languages to recognise
# https://www.jaided.ai/easyocr/
log_directory = '/var/log/ocr'
printsc = lambda x: subprocess.run(f"grim -t png -o DP-1 -l 0 {x}", shell=True)
# Configure the logger
os.makedirs(log_directory, exist_ok=True)
logging.basicConfig(
    filename=os.path.join(log_directory, 'ocr.log'),
    level=logging.DEBUG,  # Set the logging level
    format='%(asctime)s - %(message)s',  # Define the format for logging
    datefmt='%Y-%m-%d %H:%M:%S'  # Define the date format
)

# screenshot
printsc(image_old)
time.sleep(1)

# EasyOCR
reader = easyocr.Reader(OCR_languages) # this needs to run only once to load the model into memory

def results():
    result = reader.readtext(image_old)
    results_no_eng = [entry for entry in result if langid.classify(entry[1])[0] == src_lang]
    return results_no_eng

# result is a list of tuples with the following structure:
# (top_left, top_right, bottom_right, bottom_left, text, confidence)
# top_left, top_right, bottom_right, bottom_left are the coordinates of the bounding box
ocr_output = results()
curr_words = set(entry[1] for entry in ocr_output)
prev_words = set()

# translator = GoogleTranslator(source=from_language, target=target_language)
font = ImageFont.truetype(font_file, font_size)

# define a function for checking whether one axis of a shape intercepts with another
def intercepts(x,y):
    # both x and y are two dimensional tuples representing the ends of a line on one dimension.
    x1, x2 = x
    y1, y2 = y
    return (x1 <= y1 <= x2) or (x1 <= y2 <= x2) or (y1 <= x1 <= y2) or (y1 <= x2 <= y2)

while True:
    print('Running')
    if prev_words != curr_words:
        print('Translating')
        image = Image.open(image_old)
        draw = ImageDraw.Draw(image)
        to_translate = [entry[1] for entry in ocr_output][:max_translate]
        translation = translate_M2M(to_translate, from_lang = from_lang, target_lang = target_lang)
        # set counter for limiting the number of translations
        translated_number = 0
        bounding_boxes = []
        for i, (position,words,confidence) in enumerate(ocr_output):
            if translated_number >= max_translate:
                break
            word = translation[i]
            # try:
            top_left, _, _, _ = position
            position = (top_left[0], top_left[1] - 60)
            if pyin:
                py = ' '.join([ py[0] for py in pinyin(words)])
                text_content = f"{translation[i]}\n{py}\n{words}"
            else:
                text_content = f"{translation[i]}\n{words}"
            lines = text_content.split('\n')
            x,y = position
            max_width = 0
            total_height = 0
            line_spacing = 3
            line_height = font_size
            for line in lines:
                bbox = draw.textbbox(position, line, font=font)
                line_width, _ = bbox[2] - bbox[0], bbox[3] - bbox[1]
                max_width = max(max_width, line_width)
                total_height += line_height + line_spacing
            bounding_box = (x, y, x + max_width, y + total_height, words)
            print(f"Bounding Box of Interest: {bounding_box}")
            y = np.max([y,0])
            if len(bounding_boxes) > 0:
                for box in bounding_boxes:
                    print(f'Investigating box: {box}')
                    if intercepts((box[0],box[2]),(bounding_box[0],bounding_box[2])) and intercepts((box[1],box[3]),(y, y+total_height)):
                        print(f'Overlapping change adjustment to {words}')
                        y = np.max([y,box[3]]) + line_spacing
                        print(y, box[3])
                        print(f'Changed to {(x,y, x+max_width, y+total_height, words)}')
            adjusted_bounding_box = (x, y, x + max_width, y + total_height, words)
            bounding_boxes.append(adjusted_bounding_box)
            draw.rectangle([(x,y), (x+max_width, y+total_height)], outline="black", width=1)
            position = (x,y)
            for line in lines:
                draw.text(position, line, fill= text_color, font=font)
                y += font_size + line_spacing
                position = (x,y)
            print("Adjusted_bounding_box:",adjusted_bounding_box)
            print('\n')
            # except Exception as e:
            #     logging.error(e)
            translated_number += 1
        image.save(image_new)
        logging.info(f"Saved the image to {image_new}")
        prev_words = curr_words
        logging.info(f"Successfully translated image. Prev words are:\n{prev_words}")
    else:
        logging.info("The image has remained the same.")
    torch.cuda.empty_cache()
    print('Sleeping')
    time.sleep(10)
    printsc(image_old)
    ocr_output = results()
    curr_words = set(entry[1] for entry in ocr_output)
    logging.info(f'Curr words are:\n{curr_words}')

helpers/batching.py (new file, +182)

@@ -0,0 +1,182 @@
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
from dotenv import load_dotenv
import os, sys, torch, time
from multiprocessing import Process, Event
load_dotenv()
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from config import device
from logging_config import logger
class Gemini():
    def __init__(self, name, rate):
        self.name = name
        self.rate = rate
        self.curr_calls = 0
        self.time = 0
        self.process = None
        self.stop_event = Event()

    def __repr__(self):
        return f'Model: {self.name}; Rate: {self.rate}; Current_Calls: {self.curr_calls} calls; Time Passed: {self.time} seconds.'

    def __str__(self):
        return self.name

    def background_task(self):
        # Background task to manage the rate of calls to the API
        while not self.stop_event.is_set():
            time.sleep(5)
            self.time += 5
            if self.time >= 60:
                self.time = 0
                self.curr_calls = 0

    def start(self):
        # Start the background task
        self.process = Process(target=self.background_task)
        self.process.daemon = True
        self.process.start()
        logger.info(f"Background process started with PID: {self.process.pid}")

    def stop(self):
        # Stop the background task
        logger.info(f"Stopping background process with PID: {self.process.pid}")
        self.stop_event.set()
        if self.process:
            self.process.join(timeout=5)
            if self.process.is_alive():
                self.process.terminate()

class TranslationDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
        """
        Custom dataset for efficient text processing
        Args:
            texts: List of input texts
            tokenizer: HuggingFace tokenizer
            max_length: Maximum sequence length
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        # Tokenize with padding and truncation
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(device)
        # Remove batch dimension added by tokenizer
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0)
        }

def collate_fn(batch: List[Dict]):
    """
    Custom collate function to handle batching
    """
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }

def generate_text(
    texts: List[str],
    model,
    tokenizer,
    batch_size: int = 6,  # Smaller batch size uses less VRAM
    device: str = 'cuda',
    max_length: int = 512,
    max_new_tokens: int = 512,
    **generate_kwargs
):
    """
    Optimized text generation function
    Args:
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        texts: List of input texts
        batch_size: Batch size for processing
        device: Device to run inference on
        max_length: Maximum input sequence length
        max_new_tokens: Maximum number of tokens to generate
        generate_kwargs: Additional kwargs for model.generate
    Returns:
        List of generated texts
    """
    # Create dataset and dataloader
    dataset = TranslationDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn
    )
    all_generated_texts = []
    # Default generation parameters
    generation_config = {
        'max_new_tokens': max_new_tokens,
        'num_beams': 4,
        'do_sample': True,
        'top_k': 50,
        'top_p': 0.95,
        'temperature': 0.7,
        'no_repeat_ngram_size': 2,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id
    }
    # Update with user-provided parameters
    generation_config.update(generate_kwargs)
    # Perform generation
    with torch.no_grad():
        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}
            # Generate text
            outputs = model.generate(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                **generation_config
            )
            # Decode generated tokens
            decoded_texts = tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            all_generated_texts.extend(decoded_texts)
    return all_generated_texts

if __name__ == '__main__':
    from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
    model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=True).to(device)
    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=True)
    tokenizer.src_lang = "zh"
    texts = ["你好", ""]
    print(generate_text(texts, model, tokenizer, forced_bos_token_id=tokenizer.get_lang_id('en')))
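
For reference, a minimal sketch of driving the Gemini rate limiter above on its own (assumed usage, not part of this commit). Note that background_task runs in a separate Process, so its periodic resets apply to the child's copy of the object; sharing the counters with the parent would need something like multiprocessing.Value:

# Sketch: budget one hypothetical API call against the per-minute rate.
from batching import Gemini

limiter = Gemini('gemini-1.5-flash', rate=15)
limiter.start()                      # spawns the 60-second reset loop
if limiter.curr_calls < limiter.rate:
    limiter.curr_calls += 1          # record the call before issuing the request
    ...                              # perform the actual API request here
limiter.stop()                       # signal, join and terminate the background process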

@@ -1,20 +1,27 @@
 from paddleocr import PaddleOCR
 import easyocr
+from typing import Optional
 from rapidocr_onnxruntime import RapidOCR
-import langid
-from helpers.utils import contains_lang
+import langid, sys, os
+from utils import contains_lang, standardize_lang
 from concurrent.futures import ThreadPoolExecutor
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
+from logging_config import logger

 # PaddleOCR
 # Paddleocr supports Chinese, English, French, German, Korean and Japanese.
 # You can set the parameter `lang` as `ch`, `en`, `fr`, `german`, `korean`, `japan`
 # to switch the language model in order.
 # need to run only once to download and load model into memory
-def _paddle_init(lang='ch', use_angle_cls=False, use_GPU=True):
-    return PaddleOCR(use_angle_cls=use_angle_cls, lang=lang, use_GPU=use_GPU)
+default_languages = ['en', 'ch', 'ja', 'ko']
+
+def _paddle_init(paddle_lang, use_angle_cls=False, use_GPU=True, **kwargs):
+    return PaddleOCR(use_angle_cls=use_angle_cls, lang=paddle_lang, use_GPU=use_GPU, **kwargs)

 def _paddle_ocr(ocr, image) -> list:
     ### return a list containing the bounding box, text and confidence of the detected text
     result = ocr.ocr(image, cls=False)[0]
     if not isinstance(result, list):

@@ -24,31 +31,34 @@ def _paddle_ocr(ocr, image) -> list:
 # EasyOCR has support for many languages
-def _easy_init(ocr_languages: list, use_GPU=True):
-    return easyocr.Reader(ocr_languages, gpu=use_GPU)
+def _easy_init(easy_languages: list, use_GPU=True, **kwargs):
+    langs = []
+    for lang in easy_languages:
+        langs.append(standardize_lang(lang)['easyocr_lang'])
+    return easyocr.Reader(langs, gpu=use_GPU, **kwargs)

 def _easy_ocr(ocr,image) -> list:
     return ocr.readtext(image)

 # RapidOCR mostly for mandarin and some other asian languages
-def _rapid_init(use_GPU=True):
-    return RapidOCR(use_gpu=use_GPU)
+def _rapid_init(use_GPU=True, **kwargs):
+    return RapidOCR(use_gpu=use_GPU, **kwargs)

 def _rapid_ocr(ocr, image) -> list:
     return ocr(image)

 ### Initialize the OCR model
-def init_OCR(model='paddle', **kwargs):
+def init_OCR(model='paddle', easy_languages: Optional[list] = ['ch_sim','en'], paddle_lang: Optional[str] = 'ch', use_GPU=True, **kwargs):
     if model == 'paddle':
-        return _paddle_init(**kwargs)
+        return _paddle_init(paddle_lang=paddle_lang, use_GPU=use_GPU)
     elif model == 'easy':
-        return _easy_init(**kwargs)
+        return _easy_init(easy_languages=easy_languages, use_GPU=use_GPU)
     elif model == 'rapid':
-        return _rapid_init(**kwargs)
+        return _rapid_init(use_GPU=use_GPU)

 ### Perform OCR on the image
-def identify(ocr, image) -> list:
+def _identify(ocr, image) -> list:
     if isinstance(ocr, PaddleOCR):
         return _paddle_ocr(ocr, image)
     elif isinstance(ocr, easyocr.Reader):

@@ -56,13 +66,14 @@ def identify(ocr, image) -> list:
     elif isinstance(ocr, RapidOCR):
         return _rapid_ocr(ocr, image)
     else:
-        raise ValueError("Invalid OCR model. Please initialise the OCR model first with init() and pass it as an argument to identify().")
+        raise ValueError("Invalid OCR model. Please initialise the OCR model first with init() and pass it as an argument to _identify().")

-### Filter out the results that are not in the source language
-def id_filtered(ocr, image, lang) -> list:
-    result = identify(ocr, image)
+### Filter out the results that are not in the source language. Slower but for a wider range of languages
+# not working but also not very reliable so don't worry about it
+def _id_filtered(ocr, image, lang) -> list:
+    lang = standardize_lang(lang)['id_model_lang']
+    result = _identify(ocr, image)
     ### Parallelise since langid is slow
     def classify_text(entry):
         return entry if langid.classify(entry[1])[0] == lang else None

@@ -71,12 +82,29 @@ def id_filtered(ocr, image, lang) -> list:
     return results_no_eng

-# zh, ja, ko
-def id_lang(ocr, image, lang) -> list:
-    result = identify(ocr, image)
-    filtered = [entry for entry in result if contains_lang(entry[1], lang)]
+# ch_sim, ch_tra, ja, ko, en
+def _id_lang(ocr, image, lang) -> list:
+    result = _identify(ocr, image)
+    lang = standardize_lang(lang)['id_model_lang']
+    try:
+        filtered = [entry for entry in result if contains_lang(entry[1], lang)]
+    except:
+        logger.error(f"Selected language not part of default: {default_languages}.")
+        raise ValueError(f"Selected language not part of default: {default_languages}.")
     return filtered

+def id_keep_source_lang(ocr, image, lang) -> list:
+    try:
+        return _id_lang(ocr, image, lang)
+    except ValueError:
+        try:
+            return _id_filtered(ocr, image, lang)
+        except Exception as e:
+            print(f'Probably an issue with the _id_filtered function. {e}')
+            sys.exit(1)

 def get_words(ocr_output) -> list:
     return [entry[1] for entry in ocr_output]

@@ -85,3 +113,12 @@ def get_positions(ocr_output) -> list:
 def get_confidences(ocr_output) -> list:
     return [entry[2] for entry in ocr_output]

+if __name__ == '__main__':
+    # OCR_languages = ['ch_sim','en']
+    # image_old = '/home/James/Pictures/Screenshots/DP-1.jpg'
+    # reader = easyocr.Reader(OCR_languages, gpu=True) # this needs to run only once to load the model into memory
+    # result = reader.readtext(image_old)
+    # print(result)
+    print(id_keep_source_lang(init_OCR(model='paddle', paddle_lang='zh', easy_languages=['en', 'ch_sim']), '/home/James/Pictures/Screenshots/DP-1.jpg', 'ch_sim'))

@@ -1,76 +1,217 @@
-from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig
-import torch, os
-from dotenv import load_dotenv
-load_dotenv()
-
-if os.getenv('TRANSLATION_USE_GPU') in ['False', '0', 'false', 'no', 'No', 'NO', 'FALSE']:
-    device = torch.device("cpu")
-else:
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-### Batch translate a list of strings
+from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig, AutoModelForCausalLM
+import google.generativeai as genai
+import torch, os, sys, ast
+from utils import standardize_lang
+from functools import wraps
+from batching import generate_text, Gemini
+from logging_config import logger
+from multiprocessing import Process,Event
+# root dir
+sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
+from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models
+
+##############################
+# translation decorator
+def translate(translation_func):
+    @wraps(translation_func)
+    def wrapper(text, *args, **kwargs):
+        try:
+            if len(text) == 0:
+                return []
+            return translation_func(text, *args, **kwargs)
+        except Exception as e:
+            logger.error(f"Translation error with the following function: {translation_func.__name__}. Text: {text}\nError: {e}")
+    return wrapper
+###############################
+
+###############################
+def init_GEMINI(models_and_rates = None):
+    if not models_and_rates:
+        ## this is default for free tier
+        models_and_rates = {'gemini-1.5-pro': 2, 'gemini-1.5-flash': 15, 'gemini-1.5-flash-8b': 8, 'gemini-1.0-pro': 15} # order from most pref to least pref
+    models = [Gemini(name, rate) for name, rate in models_and_rates.items()]
+    for model in models:
+        model.start()
+    genai.configure(api_key=GEMINI_KEY)
+    return models
+
+def translate_GEMINI(text, models, from_lang, target_lang):
+    safety_settings = {
+        "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
+        "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
+        "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
+        "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
+    prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {from_lang} into {target_lang} and output as a Python list ensuring proper escaping of characters: {text}"
+    for model in models:
+        if model.curr_calls < model.rate:
+            try:
+                response = genai.GenerativeModel(model.name).generate_content(prompt, safety_settings=safety_settings)
+                model.curr_calls += 1
+                logger.info(repr(model))
+                logger.info(f'Model Response: {response.text.strip()}')
+                return ast.literal_eval(response.text.strip())
+            except Exception as e:
+                logger.error(f"Error with model {model.name}. Error: {e}")
+    logger.error("No models available to translate. Please wait for a model to be available.")
+###############################
+
+# Best model by far. Aya-23-8B. Gemma is relatively good. If I get the time to quantize either gemma or aya those will be good to use. llama3.2 is really good as well.
+def init_AYA():
+    model_id = "CohereForAI/aya-23-8B"
+    tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
+    model = AutoModelForCausalLM.from_pretrained(model_id, locals_files_only=True, torch_dtype=torch.float16).to(device)
+    model.eval()
+    return (model, tokenizer)
+##############################

 # M2M100 model
 def init_M2M():
-    global tokenizer, model
-    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=True)
-    model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=True, torch_dtype=torch.float16).to(device)
+    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=LOCAL_FILES_ONLY)
+    model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=LOCAL_FILES_ONLY).to(device)
     model.eval()
+    return (model, tokenizer)

-def translate_M2M(text, from_lang = 'zh', target_lang = 'en'):
+def translate_M2M(text, model, tokenizer, from_lang = 'ch_sim', target_lang = 'en') -> list[str]:
+    model_lang_from = standardize_lang(from_lang)['translation_model_lang']
+    model_lang_to = standardize_lang(target_lang)['translation_model_lang']
     if len(text) == 0:
         return []
-    tokenizer.src_lang = from_lang
-    with torch.no_grad():
-        encoded = tokenizer(text, return_tensors="pt", padding=True).to(device)
-        generated_tokens = model.generate(**encoded,
-                                          forced_bos_token_id=tokenizer.get_lang_id(target_lang))
-    translated = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-    return translated
+    tokenizer.src_lang = model_lang_from
+    generated_translations = generate_text(text, model, tokenizer, batch_size=BATCH_SIZE,
+                                           max_length=MAX_INPUT_TOKENS, max_new_tokens=MAX_OUTPUT_TOKENS,
+                                           forced_bos_token_id=tokenizer.get_lang_id(model_lang_to))
+    return generated_translations
+###############################

+###############################
 # Helsinki-NLP model Opus MT
+# Refer here for all the models https://huggingface.co/Helsinki-NLP
+def get_OPUS_model(from_lang, target_lang):
+    model_lang_from = standardize_lang(from_lang)['translation_model_lang']
+    model_lang_to = standardize_lang(target_lang)['translation_model_lang']
+    return f"Helsinki-NLP/opus-mt-{model_lang_from}-{model_lang_to}"

-def init_OPUS():
-    global tokenizer, model
-    tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-tc-bible-big-zhx-en", local_files_only=True)
-    model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-tc-bible-big-zhx-en", local_files_only=True, torch_dtype=torch.float16).to(device)
+def init_OPUS(from_lang = 'ch_sim', target_lang = 'en'):
+    opus_model = get_OPUS_model(from_lang, target_lang)
+    tokenizer = AutoTokenizer.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY)
+    model = AutoModelForSeq2SeqLM.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY, torch_dtype=torch.float16).to(device)
     model.eval()
+    return (model, tokenizer)

-def translate_OPUS(text: list[str]) -> list[str]:
-    if len(text) == 0:
-        return []
-    with torch.no_grad():
-        encoded = tokenizer(text, return_tensors="pt", padding=True).to(device)
-        generated_tokens = model.generate(**encoded)
-    translated = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-    return translated
+def translate_OPUS(text: list[str], model, tokenizer) -> list[str]:
+    translated_text = generate_text(model, tokenizer, text,
+                                    batch_size=BATCH_SIZE, device=device,
+                                    max_length=MAX_INPUT_TOKENS, max_new_tokens=MAX_OUTPUT_TOKENS)
+    return translated_text
+###############################
+
+def init_Seq_LLM(model_type, **kwargs): # model = 'opus' or 'm2m'
+    if model_type == 'opus':
+        return init_OPUS(**kwargs)
+    elif model_type == 'm2m':
+        return init_M2M()
+    else:
+        raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.")
+
+def init_API_LLM(model_type, **kwargs): # model = 'gemma'
+    if model_type == 'gemini':
+        return init_GEMINI(**kwargs)
+    else:
+        raise ValueError(f"Invalid model type. Please use {' or '.join(api_llm_models)}.")
+
+def init_Causal_LLM(model_type, **kwargs):
+    pass
+
+###
+@translate
+def translate_Seq_LLM(text,
+                      model_type, # 'opus' or 'm2m'
+                      model,
+                      tokenizer,
+                      **kwargs):
+    if model_type == 'opus':
+        return translate_OPUS(text, model, tokenizer)
+    elif model_type == 'm2m':
+        try:
+            return translate_M2M(text, model, tokenizer, **kwargs)
+        except Exception as e:
+            logger.error(f"Error with M2M model. Error: {e}")
+            # raise ValueError(f"Please provide the correct from_lang and target_lang variables if you are using the M2M model. Use the list from {available_langs}.")
+    else:
+        raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.")
+
+### if you want to use any other translation, just define a translate function with input text and output text.
+# def translate_api(text):
+#@translate
+#def translate_Causal_LLM(text, model_type, model)
+@translate
+def translate_API_LLM(text: list[str],
+                      model_type: str, # 'gemma'
+                      models: list, # list of objects of classes defined in batching.py
+                      from_lang: str, # suggested to use ISO 639-1 codes
+                      target_lang: str # suggested to use ISO 639-1 codes
+                      ) -> list[str]:
+    if model_type == 'gemini':
+        from_lang = standardize_lang(from_lang)['translation_model_lang']
+        target_lang = standardize_lang(target_lang)['translation_model_lang']
+        return translate_GEMINI(text, models, from_lang, target_lang)
+    else:
+        raise ValueError(f"Invalid model. Please use {' or '.join(api_llm_models)}.")
+
+@translate
+def translate_Causal_LLM(text: list[str],
+                         model_type, # aya
+                         model,
+                         tokenizer,
+                         from_lang: str,
+                         target_lang: str) -> list[str]:
+    model_lang_from = standardize_lang(from_lang)['translation_model_lang']
+    model_lang_to = standardize_lang(target_lang)['translation_model_lang']
+    if len(text) == 0:
+        return []
+    pass

-###
-def init_TRANSLATE(model): # model = 'opus' or 'm2m'
-    if model == 'opus':
-        init_OPUS()
-    elif model == 'm2m':
-        init_M2M()
-    else:
-        raise ValueError("Invalid model. Please use 'opus' or 'm2m'.")
-
-###
-def translate(text, model, **kwargs):
-    if model == 'opus':
-        return translate_OPUS(text)
-    elif model == 'm2m':
-        try:
-            return translate_M2M(text, **kwargs)
-        except:
-            raise ValueError("Please provide the from_lang and target_lang variables if you are using the M2M model.")
-    else:
-        raise ValueError("Invalid model. Please use 'opus' or 'm2m'.")
+# choose between local Seq2Seq LLM or obtain translations from an API
+def init_func(model):
+    if model in seq_llm_models:
+        return init_Seq_LLM
+    elif model in api_llm_models:
+        return init_API_LLM
+    elif model in causal_llm_models:
+        return init_Causal_LLM
+    else:
+        raise ValueError("Invalid model category. Please use either 'seq' or 'api'.")
+
+def translate_func(model):
+    if model in seq_llm_models:
+        return translate_Seq_LLM
+    elif model in api_llm_models:
+        return translate_API_LLM
+    elif model in causal_llm_models:
+        return translate_Causal_LLM
+    else:
+        raise ValueError("Invalid model category. Please use either 'seq' or 'api'.")
+
+### todo: if cuda is not detected, default to online translation as cpu just won't cut it. Parallel process it over multiple websites to make it faster
+if __name__ == "__main__":
+    models = init_GEMINI()
+    print(translate_API_LLM(['想要借用合成台。', '有什么卖的?', '声音太大会有什么影响吗?', '不怕丢东西吗?', '再见。', '布纳马', '想买什么自己拿。把钱留在旁边就好。', '回顾', '隐藏'], 'gemini', models, from_lang='ch_sim', target_lang='en'))
+    # model, tokenizer = init_M2M()
+    # print(translate_Seq_LLM(['想要借用合成台。', '有什么卖的?', '声音太大会有什么影响吗?', '不怕丢东西吗?', '再见。', '布纳马', '想买什么自己拿。把钱留在旁边就好。', '回顾', '隐藏'], 'm2m', model, tokenizer, from_lang='ch_sim', target_lang='en'))
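
The @translate decorator above gives every backend the same guard rails. A minimal sketch of its effect, using a hypothetical function defined at the bottom of this module (not part of this commit):

# Sketch: empty input short-circuits to [], and exceptions are logged rather than raised.
@translate
def translate_upper(text, *args, **kwargs):
    return [t.upper() for t in text]

print(translate_upper([]))        # returns [] without calling the body
print(translate_upper(['hola']))  # ['HOLA']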

@@ -3,12 +3,15 @@ from pypinyin import pinyin
 import pyscreenshot as ImageGrab # wayland tings not sure if it will work on other machines alternatively use mss
 import mss, io, os
 from PIL import Image
-import jaconv
-import MeCab
-import unidic
+import jaconv, MeCab, unidic, pykakasi

+# for creating furigana
 mecab = MeCab.Tagger('-d "{}"'.format(unidic.DICDIR))
 uroman = ur.Uroman()
+# for romanising japanese text. Can convert to hiragana or katakana as well but does not split the words up so harder to use for furigana
+kks = pykakasi.kakasi()

 # define a function for checking whether one axis of a shape intercepts with another
 def intercepts(x,y):

@@ -71,7 +74,10 @@ def add_furigana(text):
     furigana_string = ''
     for i in parsed:
         words = i.split('\t')[0]
-        add = f'({jaconv.kata2hira(i.split(',')[6])})'
+        try:
+            add = f'({jaconv.kata2hira(i.split(',')[6])})'
+        except:
+            add = ''
         to_add = add if contains_kanji(words) else ''
         furigana_string += i.split('\t')[0] + to_add
     return furigana_string

@@ -87,10 +93,12 @@ def contains_katakana(text):
     return bool(re.search(r'[\u30A0-\u30FF]', text))

-def romanize(text, piny=False):
-    if piny:
+# use kakasi to romanize japanese text
+def romanize(text, lang):
+    if lang == 'zh':
         return ' '.join([ py[0] for py in pinyin(text, heteronym=True)])
+    if lang == 'ja':
+        return kks.convert(text)[0]['hepburn']
     return uroman.romanize_string(text)

 # check if a string contains words from a language

@@ -107,36 +115,45 @@ def contains_lang(text, lang):
     else:
         raise ValueError("Invalid language. Please use one of 'en', 'zh', 'ja', or 'ko'.")

-### en, ch_sim, ja, ko rapidocr only has chinese and en at the moment
+### en, ch_sim, ch_tra, ja, ko rapidocr only has chinese and en at the moment
 def standardize_lang(lang):
     if lang == 'ch_sim':
         easyocr_lang = 'ch_sim'
         paddleocr_lang = 'ch'
         rapidocr_lang = 'ch'
         translation_model_lang = 'zh'
+        id_model_lang = 'zh'
     elif lang == 'ch_tra':
         easyocr_lang = 'ch_tra'
         paddleocr_lang = 'ch'
         rapidocr_lang = 'ch'
         translation_model_lang = 'zh'
+        id_model_lang = 'zh'
     elif lang == 'ja':
         easyocr_lang = 'ja'
         paddleocr_lang = 'ja'
         rapidocr_lang = 'ja'
         translation_model_lang = 'ja'
+        id_model_lang = 'ja'
     elif lang == 'ko':
         easyocr_lang = 'korean'
         paddleocr_lang = 'ko'
         rapidocr_lang = 'ko'
         translation_model_lang = 'ko'
+        id_model_lang = 'ko'
     elif lang == 'en':
         easyocr_lang = 'en'
         paddleocr_lang = 'en'
         rapidocr_lang = 'en'
         translation_model_lang = 'en'
+        id_model_lang = 'en'
     else:
-        raise ValueError("Invalid language. Please use one of 'en', 'ch_sim', 'ch_tra', 'ja', or 'ko'.")
+        raise ValueError(f"Invalid language {lang}. Please use one of 'en', 'ch_sim', 'ch_tra', 'ja', or 'ko'.")
-    return {'easyocr_lang': easyocr_lang, 'paddleocr_lang': paddleocr_lang, 'rapidocr_lang': rapidocr_lang, 'translation_model_lang': translation_model_lang}
+    return {'easyocr_lang': easyocr_lang,
+            'paddleocr_lang': paddleocr_lang,
+            'rapidocr_lang': rapidocr_lang,
+            'translation_model_lang': translation_model_lang,
+            'id_model_lang': id_model_lang}

 def which_ocr_lang(model):
     if model == 'easy':
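
A short usage sketch for the helpers above (assumed outputs, not part of this commit; pykakasi and pypinyin must be installed, and for 'ja' only the hepburn reading of the first kakasi token is returned):

# Sketch: language-code mapping and romanisation.
from utils import standardize_lang, romanize

print(standardize_lang('ch_sim')['translation_model_lang'])  # 'zh'
print(romanize('你好', 'zh'))        # pinyin syllables joined with spaces
print(romanize('こんにちは', 'ja'))  # hepburn reading from pykakasi, e.g. 'konnichiha'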

@@ -1,17 +1,42 @@
 import logging, os
+from typing import Optional
 from dotenv import load_dotenv
 load_dotenv()

-# Configure the logger
-def setup_logger(name: str, log_file: str = None, level: int = logging.INFO) -> logging.Logger:
-    """Set up a logger with the specified name and level."""
+def setup_logger(
+    name: str,
+    log_file: Optional[str] = None,
+    level: int = logging.INFO
+) -> Optional[logging.Logger]:
+    """
+    Set up a logger with the specified name and level.
+
+    Args:
+        name: Logger name
+        log_file: Path to log file (defaults to name.log)
+        level: Logging level (defaults to INFO)
+    Returns:
+        Logger object if successful, None if setup fails
+    """
+    try:
         if log_file is None:
             log_file = f"{name}.log"
+        # Validate logging level
+        valid_levels = [logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL]
+        if level not in valid_levels:
+            level = logging.INFO
         # Create a logger
         logger = logging.getLogger(name)
         logger.setLevel(level)
+        logger.propagate = False
+        # Clear existing handlers
+        if logger.handlers:
+            logger.handlers.clear()
         # Create file handler
         file_handler = logging.FileHandler(log_file)

@@ -22,15 +47,21 @@ def setup_logger(name: str, log_file: str = None, level: int = logging.INFO) ->
         console_handler.setLevel(level)
         # Create a formatter and set it for both handlers
-        formatter = logging.Formatter('%(asctime)s - %(name)s - [%(levelname)s] %(message)s',
-                                      datefmt='%Y-%m-%d %H:%M:%S')
+        formatter = logging.Formatter(
+            '%(asctime)s - %(name)s - [%(levelname)s] %(message)s',
+            datefmt='%Y-%m-%d %H:%M:%S'
+        )
         file_handler.setFormatter(formatter)
         console_handler.setFormatter(formatter)
         # Add handlers to the logger
         logger.addHandler(file_handler)
         logger.addHandler(console_handler)
         return logger
+    except Exception as e:
+        print(f"Failed to setup logger: {e}")
+        return None
+
+logger = setup_logger('on_screen_translator', log_file='translate.log')
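
Since the module now builds a shared logger at import time, callers can simply import it rather than configuring their own. A minimal usage sketch (not part of this commit):

# Sketch: every helper module can reuse the pre-configured logger.
from logging_config import logger
logger.info('written to translate.log and echoed to the console')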

@@ -10,7 +10,7 @@
       <img
         id="live-image"
        src="/image"
-        alt="Live Image"
+        alt="No Translations Available"
         style="max-width: 100%; height: auto" />
       <script>

translate.py (new file, +68)

@@ -0,0 +1,68 @@
###################################################################################
##### IMPORT LIBRARIES #####
import os, time, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
from utils import printsc, convert_image_to_bytes, bytes_to_image
from ocr import get_words, init_OCR, id_keep_source_lang
from logging_config import logger
from draw import modify_image_bytes
from config import SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL
###################################################################################
latest_image = None

def main():
    global latest_image
    ##### Initialize the OCR #####
    OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
    ocr = init_OCR(model=OCR_MODEL, easy_languages = OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
    ##### Initialize the translation #####
    # model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
    models = init_API_LLM(TRANSLATION_MODEL)
    ###################################################################################
    runs = 0
    while True:
        untranslated_image = printsc(REGION)
        byte_image = convert_image_to_bytes(untranslated_image)
        ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG) # keep only phrases containing the source language
        if runs == 0:
            logger.info('Initial run')
            prev_words = set()
        else:
            logger.info(f'Run number: {runs}.')
        runs += 1
        curr_words = set(get_words(ocr_output))
        ### If the OCR detects different words, translate screen -> to ensure that the screen is not refreshing constantly and to save GPU power
        if prev_words != curr_words:
            logger.info('Translating')
            to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
            # translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
            translation = translate_API_LLM(to_translate, TRANSLATION_MODEL, models, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
            logger.info(f'Translation from {to_translate} to\n {translation}')
            translated_image = modify_image_bytes(byte_image, ocr_output, translation)
            latest_image = bytes_to_image(translated_image)
            # latest_image.show() # for debugging
            prev_words = curr_words
        else:
            logger.info("No new words to translate. Output will not refresh.")
        logger.info(f'Sleeping for {INTERVAL} seconds')
        time.sleep(INTERVAL)

################### TODO ##################
# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
# 5. Maybe refreshing issue of flask app. Also get webpage to update only if the image changes.
# Create a way for it to just replace the text and provide only the translation on-screen. Qt6

if __name__ == "__main__":
    main()

@@ -1,14 +1,12 @@
 from flask import Flask, Response, render_template
-import time
 import threading
-from PIL import Image
 import io
-import chinese_to_eng
+import translate

 app = Flask(__name__)

 # Global variable to hold the current image
 def curr_image():
-    return chinese_to_eng.latest_image
+    return translate.latest_image

 @app.route('/')
 def index():

@@ -19,8 +17,6 @@ def index():
 def stream_image():
     if curr_image() is None:
         return "No image generated yet.", 503
-    print('streaming')
-    print(curr_image())
     file_object = io.BytesIO()
     curr_image().save(file_object, 'PNG')
     file_object.seek(0)

@@ -33,7 +29,7 @@ def stream_image():
 if __name__ == '__main__':
     # Start the image updating thread
-    threading.Thread(target=chinese_to_eng.main, daemon=True).start()
+    threading.Thread(target=translate.main, daemon=True).start()
     # Start the Flask web server
     app.run(host='0.0.0.0', port=5000, debug=True)