Added Gemini API support, proper Japanese romanisation and furigana support, and optimised batching for local LLMs
parent ee4b3ed43e
commit 17e7f6526f
.gitignore (vendored): 2 changes
@@ -3,3 +3,5 @@
 translate/
 __pycache__/
 .*
+
+test.py
@@ -0,0 +1,4 @@
+## Debugging Issues
+
+1. cuDNN version mismatch when using PaddleOCR. Check that LD_LIBRARY_PATH is set to the directory containing the cudnn .so file. If using a local CUDA installation, removing the pip-installed nvidia cudnn package from the Python environment can also help.
+2. Segmentation fault when using PaddleOCR, EasyOCR, or RapidOCR. Ensure the only installed cv2 library is opencv-contrib-python. See https://pypi.org/project/opencv-python-headless/ for more info.
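A quick way to check for the duplicate-cv2 condition in point 2 is to list the installed OpenCV distributions from Python. This is an illustrative sketch, not part of the commit:

```python
# Sketch: list installed OpenCV packages; more than one entry usually
# explains the cv2 segmentation faults mentioned above.
import importlib.metadata
import os

cv2_pkgs = sorted({d.metadata["Name"] for d in importlib.metadata.distributions()
                   if (d.metadata["Name"] or "").lower().startswith("opencv")})
print(cv2_pkgs)  # expect exactly ['opencv-contrib-python']

# For point 1, confirm LD_LIBRARY_PATH points at the directory holding libcudnn:
print(os.environ.get("LD_LIBRARY_PATH", "<unset>"))
```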
@@ -1,132 +0,0 @@
-###################################################################################
-##### IMPORT LIBRARIES #####
-import os, time, logging, ast
-from helpers.translation import init_TRANSLATE, translate
-from helpers.utils import intercepts, contains_lang, printsc, romanize, convert_image_to_bytes, bytes_to_image
-from helpers.ocr import id_filtered, id_lang, get_words, get_positions, get_confidences, init_OCR
-from logging_config import setup_logger
-from helpers.draw import modify_image_bytes
-###################################################################################
-
-#### LOGGING ####
-setup_logger('chinese_to_eng', log_file='chinese_to_eng.log')
-
-###################################################################################
-##### Variables to edit #####
-
-INTERVAL = int(os.getenv('INTERVAL'))
-
-### available languages: 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
-SOURCE_LANG = os.getenv('SOURCE_LANG', 'ja')
-TARGET_LANG = os.getenv('TARGET_LANG', 'en')
-
-### Translation
-TRANSLATION_MODEL = os.getenv('TRANSLATION_MODEL', 'opus') # 'opus' or 'm2m' # opus is a lot more lightweight
-MAX_TRANSLATE = 200
-
-### OCR
-OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True'))
-OCR_MODEL = os.getenv('OCR_MODEL', 'easy') # 'easy', 'paddle', 'rapid' ### easy is the most accurate, paddle is the fastest with CUDA and rapid is the fastest with CPU
-
-REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)'))
-###################################################################################
-
-OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
-latest_image = None
-
-def main():
-    global latest_image
-    # screenshot
-    untranslated_image = printsc(REGION)
-    byte_image = convert_image_to_bytes(untranslated_image)
-
-    ###################################################################################
-    ##### Initialize the OCR #####
-    ocr = init_OCR(model=OCR_MODEL, ocr_languages = OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
-    ocr_output = id_lang(ocr, byte_image, 'ja')
-    curr_words = set(get_words(ocr_output))
-    prev_words = set()
-
-    ##### Initialize the translation #####
-    init_TRANSLATE()
-    ###################################################################################
-
-    while True:
-        print('Running')
-        if prev_words != curr_words:
-            print('Translating')
-            to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
-            translation = translate(to_translate, from_lang, target_lang)
-            print(translation)
-            translated_image = modify_image_bytes(byte_image, ocr_output, translation)
-            latest_image = bytes_to_image(translated_image)
-            prev_words = curr_words
-            logging.info(f"Successfully translated image. Prev words are:\n{prev_words}")
-        else:
-            logging.info("The image has remained the same.")
-        # torch.cuda.empty_cache()
-        logging.info(f'Sleeping for {INTERVAL} seconds')
-        time.sleep(INTERVAL)
-
-        untranslated_image = printsc(REGION)
-        byte_image = convert_image_to_bytes(untranslated_image)
-        ocr_output = id_lang(ocr, byte_image, 'ja')
-        curr_words = set(get_words(ocr_output))
-        logging.info(f'Curr words to translate are:\n{curr_words}')
-
-if __name__ == "__main__":
-    main()
-
-
-# image = Image.open(SCREENSHOT_PATH)
-# draw = ImageDraw.Draw(image)
-
-# # set counter for limiting the number of translations
-# translated_number = 0
-# bounding_boxes = []
-# for i, (position,words,confidence) in enumerate(ocr_output):
-#     if translated_number >= MAX_TRANSLATE:
-#         break
-#     # try:
-#     top_left, _, _, _ = position
-#     position = (top_left[0], top_left[1] - 60)
-#     text_content = f"{translation[i]}\n{romanize(words)}\n{words}"
-#     lines = text_content.split('\n')
-#     x,y = position
-
-#     max_width = 0
-#     total_height = 0
-#     line_spacing = 3
-#     line_height = FONT_SIZE
-
-#     for line in lines:
-#         bbox = draw.textbbox(position, line, font=font)
-#         line_width, _ = bbox[2] - bbox[0], bbox[3] - bbox[1]
-#         max_width = max(max_width, line_width)
-#         total_height += line_height + line_spacing
-
-#     bounding_box = (x, y, x + max_width, y + total_height, words)
-#     print(f"Bounding Box of Interest: {bounding_box}")
-
-#     y = np.max([y,0])
-#     if len(bounding_boxes) > 0:
-#         for box in bounding_boxes:
-#             print(f'Investigating box: {box}')
-#             if intercepts((box[0],box[2]),(bounding_box[0],bounding_box[2])) and intercepts((box[1],box[3]),(y, y+total_height)):
-#                 print(f'Overlapping change adjustment to {words}')
-#                 y = np.max([y,box[3]]) + line_spacing
-#                 print(y, box[3])
-#                 print(f'Changed to {(x,y, x+max_width, y+total_height, words)}')
-#     adjusted_bounding_box = (x, y, x + max_width, y + total_height, words)
-#     bounding_boxes.append(adjusted_bounding_box)
-#     draw.rectangle([(x,y), (x+max_width, y+total_height)], outline="black", width=1)
-#     position = (x,y)
-#     for line in lines:
-#         draw.text(position, line, fill= TEXT_COLOR, font=font)
-#         y += FONT_SIZE + line_spacing
-#         position = (x,y)
-#     print("Adjusted_bounding_box:",adjusted_bounding_box)
-#     print('\n')
-#     translated_number += 1
config.py (new file): 54 lines
@@ -0,0 +1,54 @@
+import os, ast, torch
+from dotenv import load_dotenv
+load_dotenv(override=True)
+
+###################################################################################################
+### EDIT THESE VARIABLES ###
+
+### available languages: 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
+INTERVAL = int(os.getenv('INTERVAL'))
+
+### OCR
+OCR_MODEL = os.getenv('OCR_MODEL', 'easy') # 'easy', 'paddle', 'rapid' ### easy is the most accurate, paddle is the fastest with CUDA and rapid is the fastest with CPU. Rapid only has Chinese and English unless you add more languages
+OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True'))
+
+### Drawing/Overlay Config
+FONT_FILE = os.getenv('FONT_FILE')
+FONT_SIZE = int(os.getenv('FONT_SIZE', 16))
+LINE_SPACING = int(os.getenv('LINE_SPACING', 3))
+REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)'))
+TEXT_COLOR = os.getenv('TEXT_COLOR', "#ff0000")
+TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True'))
+
+### Translation
+BATCH_SIZE = int(os.getenv('BATCH_SIZE', 6))
+GEMINI_KEY = os.getenv('GEMINI_KEY')
+LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False'))
+MAX_INPUT_TOKENS = int(os.getenv('MAX_INPUT_TOKENS', 512))
+MAX_OUTPUT_TOKENS = int(os.getenv('MAX_OUTPUT_TOKENS', 512))
+MAX_TRANSLATE = int(os.getenv('MAX_TRANSLATION', 200))
+SOURCE_LANG = os.getenv('SOURCE_LANG', 'ja')
+TARGET_LANG = os.getenv('TARGET_LANG', 'en')
+TRANSLATION_MODEL = os.environ['TRANSLATION_MODEL'] # 'opus' or 'm2m' # opus is a lot more lightweight
+TRANSLATION_USE_GPU = ast.literal_eval(os.getenv('TRANSLATION_USE_GPU', 'True'))
+###################################################################################################
+
+
+LINE_HEIGHT = FONT_SIZE
+
+
+if TRANSLATION_USE_GPU is False:
+    device = torch.device("cpu")
+else:
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+### Just for info
+available_langs = ['ch_sim', 'ch_tra', 'ja', 'ko', 'en'] # there are limitations with the languages that can be used with the OCR models
+seq_llm_models = ['opus', 'm2m']
+api_llm_models = ['gemini']
+causal_llm_models = []
+curr_models = seq_llm_models + api_llm_models + causal_llm_models
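Since several of these settings are parsed with ast.literal_eval, the corresponding .env values must be valid Python literals. A small illustrative check, not part of the commit:

```python
# Sketch: the tuple/bool settings in config.py must be Python literals in .env.
import ast, os

os.environ.setdefault('REGION', '(0,0,2560,1440)')
os.environ.setdefault('OCR_USE_GPU', 'True')

REGION = ast.literal_eval(os.getenv('REGION'))            # -> (0, 0, 2560, 1440)
OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU'))  # -> True
print(type(REGION), type(OCR_USE_GPU))  # <class 'tuple'> <class 'bool'>
```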
@@ -1,19 +1,12 @@
 from PIL import Image, ImageDraw, ImageFont
-from dotenv import load_dotenv
-import os
-import io
-import numpy as np
-import ast
-from helpers.utils import romanize, intercepts
-load_dotenv()
-
-MAX_TRANSLATE = int(os.getenv('MAX_TRANSLATION', 200))
-FONT_FILE = os.getenv('FONT_FILE')
-FONT_SIZE = int(os.getenv('FONT_SIZE', 16))
-LINE_SPACING = int(os.getenv('LINE_SPACING', 3))
-TEXT_COLOR = os.getenv('TEXT_COLOR', "#ff0000")
-LINE_HEIGHT = FONT_SIZE
-TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True'))
+import os, io, sys, numpy as np
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
+from utils import romanize, intercepts, add_furigana
+from logging_config import logger
+from config import SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE, LINE_SPACING, TEXT_COLOR, LINE_HEIGHT, TO_ROMANIZE
 font = ImageFont.truetype(FONT_FILE, FONT_SIZE)

@@ -33,7 +26,6 @@ def modify_image_bytes(image_bytes: io.BytesIO, ocr_output, translation: list) -
     return modified_image_bytes

 def translate_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int) -> ImageDraw:
-    translation
     translated_number = 0
     bounding_boxes = []
     for i, (position, untranslated_phrase, confidence) in enumerate(ocr_output):
@@ -47,7 +39,16 @@ def translate_one_phrase(draw: ImageDraw, translated_phrase: str, position: tupl
     # Draw the bounding box
     top_left, _, _, _ = position
     position = (top_left[0], top_left[1] - 60)
-    text_content = f"{translated_phrase}\n{romanize(untranslated_phrase, TO_ROMANIZE)}\n{untranslated_phrase}"
+    if SOURCE_LANG == 'ja':
+        untranslated_phrase = add_furigana(untranslated_phrase)
+        romanized_phrase = romanize(untranslated_phrase, 'ja')
+    else:
+        romanized_phrase = romanize(untranslated_phrase, SOURCE_LANG)
+    if TO_ROMANIZE:
+        text_content = f"{translated_phrase}\n{romanized_phrase}\n{untranslated_phrase}"
+    else:
+        text_content = f"{translated_phrase}\n{untranslated_phrase}"
+
     lines = text_content.split('\n')
     x,y = position
     max_width = 0
@@ -58,7 +59,6 @@ def translate_one_phrase(draw: ImageDraw, translated_phrase: str, position: tupl
         line_width = bbox[2] - bbox[0]
         max_width = max(max_width, line_width)
     bounding_box = (x, y, x + max_width, y + total_height, untranslated_phrase)
-    print(f"Bounding Box of Interest: {bounding_box}")

     adjust_if_intersects(x, y, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
     adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]
@@ -68,18 +68,13 @@ def translate_one_phrase(draw: ImageDraw, translated_phrase: str, position: tupl
         draw.text(position, line, fill= TEXT_COLOR, font=font)
         adjusted_y += FONT_SIZE + LINE_SPACING
         position = (adjusted_x,adjusted_y)
-    print(f"Adjusted_bounding_box: {bounding_box[-1]}.\n")

 def adjust_if_intersects(x: int, y: int, bounding_box: tuple, bounding_boxes: list, untranslated_phrase: str, max_width: int, total_height: int) -> tuple:
     y = np.max([y,0])
     if len(bounding_boxes) > 0:
         for box in bounding_boxes:
-            print(f'Investigating box: {box}')
             if intercepts((box[0],box[2]),(bounding_box[0],bounding_box[2])) and intercepts((box[1],box[3]),(y, y+total_height)):
-                print(f'Overlapping change adjustment to {untranslated_phrase}')
                 y = np.max([y,box[3]]) + LINE_SPACING
-                print(y, box[3])
-                print(f'Changed to {(x,y, x+max_width, y+total_height, untranslated_phrase)}')
     adjusted_bounding_box = (x, y, x + max_width, y + total_height, untranslated_phrase)
     bounding_boxes.append(adjusted_bounding_box)
     return adjusted_bounding_box
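The overlap adjustment above leans entirely on the 1-D interval test intercepts from the utils helper; a quick standalone illustration of its behaviour:

```python
# intercepts as defined in the utils helper: True when two 1-D intervals overlap.
def intercepts(x, y):
    x1, x2 = x
    y1, y2 = y
    return (x1 <= y1 <= x2) or (x1 <= y2 <= x2) or (y1 <= x1 <= y2) or (y1 <= x2 <= y2)

# Two boxes overlap only if they overlap on both axes, which is exactly
# how adjust_if_intersects combines two intercepts calls.
print(intercepts((0, 10), (5, 15)))   # True: [0, 10] and [5, 15] share [5, 10]
print(intercepts((0, 10), (11, 15)))  # False: disjoint intervals
```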
@@ -1,153 +0,0 @@
-import easyocr
-from pypinyin import pinyin
-from PIL import Image, ImageDraw, ImageFont
-import os, time, logging, torch, subprocess
-from helpers.translation import init_M2M, translate_M2M
-import langid
-import numpy as np
-
-##### Variables to edit
-
-text_color = "#ff0000"
-font_file = "/home/James/.local/share/fonts/Arial-Unicode-Bold.ttf"
-font_size = 16
-
-pyin = True # whether to add pinyin or not
-max_translate = 100
-
-# for detecting language to filter out other languages. Only writes the text when it is detected to be src_lang
-src_lang = "zh"
-tgt_lang = "en"
-# af, am, an, ar, as, az, be, bg, bn, br, bs, ca, cs, cy, da, de, dz, el, en, eo, es, et, eu, fa, fi, fo, fr, ga, gl, gu, he, hi, hr, ht, hu, hy, id, is, it, ja, jv, ka, kk, km, kn, ko, ku, ky, la, lb, lo, lt, lv, mg, mk, ml, mn, mr, ms, mt, nb, ne, nl, nn, no, oc, or, pa, pl, ps, pt, qu, ro, ru, rw, se, si, sk, sl, sq, sr, sv, sw, ta, te, th, tl, tr, ug, uk, ur, vi, vo, wa, xh, zh, zu
-langid.set_languages([src_lang,tgt_lang,'en'])
-
-# for translator (M2M100)
-from_lang = "zh"
-target_lang = "en"
-
-# Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn), Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)
-
-# for easyOCR
-OCR_languages = ['ch_sim','en'] # languages to recognise
-# https://www.jaided.ai/easyocr/
-
-log_directory = '/var/log/ocr'
-printsc = lambda x: subprocess.run(f"grim -t png -o DP-1 -l 0 {x}", shell=True)
-
-# Configure the logger
-os.makedirs(log_directory, exist_ok=True)
-
-logging.basicConfig(
-    filename=os.path.join(log_directory, 'ocr.log'),
-    level=logging.DEBUG, # Set the logging level
-    format='%(asctime)s - %(message)s', # Define the format for logging
-    datefmt='%Y-%m-%d %H:%M:%S' # Define the date format
-)
-
-# screenshot
-printsc(image_old)
-time.sleep(1)
-
-# EasyOCR
-reader = easyocr.Reader(OCR_languages) # this needs to run only once to load the model into memory
-
-def results():
-    result = reader.readtext(image_old)
-    results_no_eng = [entry for entry in result if langid.classify(entry[1])[0] == src_lang]
-    return results_no_eng
-
-# result is a list of tuples with the following structure:
-# (top_left, top_right, bottom_right, bottom_left, text, confidence)
-# top_left, top_right, bottom_right, bottom_left are the coordinates of the bounding box
-ocr_output = results()
-curr_words = set(entry[1] for entry in ocr_output)
-prev_words = set()
-
-# translator = GoogleTranslator(source=from_language, target=target_language)
-
-font = ImageFont.truetype(font_file, font_size)
-
-# define a function for checking whether one axis of a shape intercepts with another
-def intercepts(x,y):
-    # both x and y are two dimensional tuples representing the ends of a line on one dimension.
-    x1, x2 = x
-    y1, y2 = y
-    return (x1 <= y1 <= x2) or (x1 <= y2 <= x2) or (y1 <= x1 <= y2) or (y1 <= x2 <= y2)
-
-while True:
-    print('Running')
-    if prev_words != curr_words:
-        print('Translating')
-        image = Image.open(image_old)
-        draw = ImageDraw.Draw(image)
-        to_translate = [entry[1] for entry in ocr_output][:max_translate]
-        translation = translate_M2M(to_translate, from_lang = from_lang, target_lang = target_lang)
-        # set counter for limiting the number of translations
-        translated_number = 0
-        bounding_boxes = []
-        for i, (position,words,confidence) in enumerate(ocr_output):
-            if translated_number >= max_translate:
-                break
-            word = translation[i]
-            # try:
-            top_left, _, _, _ = position
-            position = (top_left[0], top_left[1] - 60)
-            if pyin:
-                py = ' '.join([ py[0] for py in pinyin(words)])
-                text_content = f"{translation[i]}\n{py}\n{words}"
-            else:
-                text_content = f"{translation[i]}\n{words}"
-            lines = text_content.split('\n')
-            x,y = position
-
-            max_width = 0
-            total_height = 0
-            line_spacing = 3
-            line_height = font_size
-
-            for line in lines:
-                bbox = draw.textbbox(position, line, font=font)
-                line_width, _ = bbox[2] - bbox[0], bbox[3] - bbox[1]
-                max_width = max(max_width, line_width)
-                total_height += line_height + line_spacing
-
-            bounding_box = (x, y, x + max_width, y + total_height, words)
-            print(f"Bounding Box of Interest: {bounding_box}")
-
-            y = np.max([y,0])
-            if len(bounding_boxes) > 0:
-                for box in bounding_boxes:
-                    print(f'Investigating box: {box}')
-                    if intercepts((box[0],box[2]),(bounding_box[0],bounding_box[2])) and intercepts((box[1],box[3]),(y, y+total_height)):
-                        print(f'Overlapping change adjustment to {words}')
-                        y = np.max([y,box[3]]) + line_spacing
-                        print(y, box[3])
-                        print(f'Changed to {(x,y, x+max_width, y+total_height, words)}')
-            adjusted_bounding_box = (x, y, x + max_width, y + total_height, words)
-            bounding_boxes.append(adjusted_bounding_box)
-            draw.rectangle([(x,y), (x+max_width, y+total_height)], outline="black", width=1)
-            position = (x,y)
-            for line in lines:
-                draw.text(position, line, fill= text_color, font=font)
-                y += font_size + line_spacing
-                position = (x,y)
-            print("Adjusted_bounding_box:",adjusted_bounding_box)
-            print('\n')
-            # except Exception as e:
-            #     logging.error(e)
-            translated_number += 1
-        image.save(image_new)
-        logging.info(f"Saved the image to {image_new}")
-        prev_words = curr_words
-        logging.info(f"Successfully translated image. Prev words are:\n{prev_words}")
-    else:
-        logging.info("The image has remained the same.")
-    torch.cuda.empty_cache()
-    print('Sleeping')
-    time.sleep(10)
-
-    printsc(image_old)
-    ocr_output = results()
-    curr_words = set(entry[1] for entry in ocr_output)
-    logging.info(f'Curr words are:\n{curr_words}')
helpers/batching.py (new file): 182 lines
@@ -0,0 +1,182 @@
+from torch.utils.data import Dataset, DataLoader
+from typing import List, Dict
+from dotenv import load_dotenv
+import os, sys, torch, time
+from multiprocessing import Process, Event
+load_dotenv()
+sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
+from config import device
+from logging_config import logger
+
+
+class Gemini():
+    def __init__(self, name, rate):
+        self.name = name
+        self.rate = rate
+        self.curr_calls = 0
+        self.time = 0
+        self.process = None
+        self.stop_event = Event()
+
+    def __repr__(self):
+        return f'Model: {self.name}; Rate: {self.rate}; Current_Calls: {self.curr_calls} calls; Time Passed: {self.time} seconds.'
+
+    def __str__(self):
+        return self.name
+
+    def background_task(self):
+        # Background task to manage the rate of calls to the API
+        while not self.stop_event.is_set():
+            time.sleep(5)
+            self.time += 5
+            if self.time >= 60:
+                self.time = 0
+                self.curr_calls = 0
+
+    def start(self):
+        # Start the background task
+        self.process = Process(target=self.background_task)
+        self.process.daemon = True
+        self.process.start()
+        logger.info(f"Background process started with PID: {self.process.pid}")
+
+    def stop(self):
+        # Stop the background task
+        logger.info(f"Stopping background process with PID: {self.process.pid}")
+        self.stop_event.set()
+        if self.process:
+            self.process.join(timeout=5)
+            if self.process.is_alive():
+                self.process.terminate()
+
+
+class TranslationDataset(Dataset):
+    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
+        """
+        Custom dataset for efficient text processing
+
+        Args:
+            texts: List of input texts
+            tokenizer: HuggingFace tokenizer
+            max_length: Maximum sequence length
+        """
+        self.texts = texts
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+    def __len__(self):
+        return len(self.texts)
+
+    def __getitem__(self, idx):
+        text = self.texts[idx]
+
+        # Tokenize with padding and truncation
+        encoding = self.tokenizer(
+            text,
+            max_length=self.max_length,
+            padding='max_length',
+            truncation=True,
+            return_tensors='pt'
+        ).to(device)
+
+        # Remove batch dimension added by tokenizer
+        return {
+            'input_ids': encoding['input_ids'].squeeze(0),
+            'attention_mask': encoding['attention_mask'].squeeze(0)
+        }
+
+def collate_fn(batch: List[Dict]):
+    """
+    Custom collate function to handle batching
+    """
+    input_ids = torch.stack([item['input_ids'] for item in batch])
+    attention_mask = torch.stack([item['attention_mask'] for item in batch])
+
+    return {
+        'input_ids': input_ids,
+        'attention_mask': attention_mask
+    }
+
+def generate_text(
+    texts: List[str],
+    model,
+    tokenizer,
+    batch_size: int = 6, # Smaller batch size uses less VRAM
+    device: str = 'cuda',
+    max_length: int = 512,
+    max_new_tokens: int = 512,
+    **generate_kwargs
+):
+    """
+    Optimized text generation function
+
+    Args:
+        model: HuggingFace model
+        tokenizer: HuggingFace tokenizer
+        texts: List of input texts
+        batch_size: Batch size for processing
+        device: Device to run inference on
+        max_length: Maximum input sequence length
+        max_new_tokens: Maximum number of tokens to generate
+        generate_kwargs: Additional kwargs for model.generate
+
+    Returns:
+        List of generated texts
+    """
+
+    # Create dataset and dataloader
+    dataset = TranslationDataset(texts, tokenizer, max_length)
+    dataloader = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        collate_fn=collate_fn
+    )
+
+    all_generated_texts = []
+
+    # Default generation parameters
+    generation_config = {
+        'max_new_tokens': max_new_tokens,
+        'num_beams': 4,
+        'do_sample': True,
+        'top_k': 50,
+        'top_p': 0.95,
+        'temperature': 0.7,
+        'no_repeat_ngram_size': 2,
+        'pad_token_id': tokenizer.pad_token_id,
+        'eos_token_id': tokenizer.eos_token_id
+    }
+
+    # Update with user-provided parameters
+    generation_config.update(generate_kwargs)
+
+    # Perform generation
+    with torch.no_grad():
+        for batch in dataloader:
+            # Move batch to device
+            batch = {k: v.to(device) for k, v in batch.items()}
+
+            # Generate text
+            outputs = model.generate(
+                input_ids=batch['input_ids'],
+                attention_mask=batch['attention_mask'],
+                **generation_config
+            )
+
+            # Decode generated tokens
+            decoded_texts = tokenizer.batch_decode(
+                outputs,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=True
+            )
+
+            all_generated_texts.extend(decoded_texts)
+
+    return all_generated_texts
+
+if __name__ == '__main__':
+    from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
+    model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=True).to(device)
+    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=True)
+    tokenizer.src_lang = "zh"
+    texts = ["你好","我"]
+    print(generate_text(texts, model, tokenizer, forced_bos_token_id=tokenizer.get_lang_id('en')))
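One possible refinement, not in the commit: `padding='max_length'` pads every item to the full max_length, so batches of short OCR strings are mostly padding. Tokenizing per batch with dynamic padding keeps the same call shape while shrinking the tensors; a sketch under the same tokenizer assumptions:

```python
# Sketch of a dynamic-padding collate_fn: pad each batch only to its
# own longest sequence instead of a fixed max_length.
def make_collate_fn(tokenizer, max_length=512):
    def collate(texts):  # batch of raw strings from a plain list-backed Dataset
        return tokenizer(
            texts,
            padding=True,           # pad to the longest item in this batch
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        )
    return collate
```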
@@ -1,20 +1,27 @@
 from paddleocr import PaddleOCR
 import easyocr
+from typing import Optional
 from rapidocr_onnxruntime import RapidOCR
-import langid
-from helpers.utils import contains_lang
+import langid, sys, os
+from utils import contains_lang, standardize_lang
 from concurrent.futures import ThreadPoolExecutor
+sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
+from logging_config import logger
 # PaddleOCR
 # Paddleocr supports Chinese, English, French, German, Korean and Japanese.
 # You can set the parameter `lang` as `ch`, `en`, `fr`, `german`, `korean`, `japan`
 # to switch the language model in order.
 # need to run only once to download and load model into memory

-def _paddle_init(lang='ch', use_angle_cls=False, use_GPU=True):
-    return PaddleOCR(use_angle_cls=use_angle_cls, lang=lang, use_GPU=use_GPU)
+default_languages = ['en', 'ch', 'ja', 'ko']
+
+def _paddle_init(paddle_lang, use_angle_cls=False, use_GPU=True, **kwargs):
+    return PaddleOCR(use_angle_cls=use_angle_cls, lang=paddle_lang, use_GPU=use_GPU, **kwargs)

 def _paddle_ocr(ocr, image) -> list:
     ### return a list containing the bounding box, text and confidence of the detected text
     result = ocr.ocr(image, cls=False)[0]
     if not isinstance(result, list):
@@ -24,31 +31,34 @@ def _paddle_ocr(ocr, image) -> list:

 # EasyOCR has support for many languages
-def _easy_init(ocr_languages: list, use_GPU=True):
-    return easyocr.Reader(ocr_languages, gpu=use_GPU)
+def _easy_init(easy_languages: list, use_GPU=True, **kwargs):
+    langs = []
+    for lang in easy_languages:
+        langs.append(standardize_lang(lang)['easyocr_lang'])
+    return easyocr.Reader(langs, gpu=use_GPU, **kwargs)

 def _easy_ocr(ocr,image) -> list:
     return ocr.readtext(image)

 # RapidOCR mostly for mandarin and some other asian languages
-def _rapid_init(use_GPU=True):
-    return RapidOCR(use_gpu=use_GPU)
+def _rapid_init(use_GPU=True, **kwargs):
+    return RapidOCR(use_gpu=use_GPU, **kwargs)

 def _rapid_ocr(ocr, image) -> list:
     return ocr(image)

 ### Initialize the OCR model
-def init_OCR(model='paddle', **kwargs):
+def init_OCR(model='paddle', easy_languages: Optional[list] = ['ch_sim','en'], paddle_lang: Optional[str] = 'ch', use_GPU=True, **kwargs):
     if model == 'paddle':
-        return _paddle_init(**kwargs)
+        return _paddle_init(paddle_lang=paddle_lang, use_GPU=use_GPU)
     elif model == 'easy':
-        return _easy_init(**kwargs)
+        return _easy_init(easy_languages=easy_languages, use_GPU=use_GPU)
     elif model == 'rapid':
-        return _rapid_init(**kwargs)
+        return _rapid_init(use_GPU=use_GPU)

 ### Perform OCR on the image
-def identify(ocr, image) -> list:
+def _identify(ocr, image) -> list:
     if isinstance(ocr, PaddleOCR):
         return _paddle_ocr(ocr, image)
     elif isinstance(ocr, easyocr.Reader):
@@ -56,13 +66,14 @@ def identify(ocr, image) -> list:
     elif isinstance(ocr, RapidOCR):
         return _rapid_ocr(ocr, image)
     else:
-        raise ValueError("Invalid OCR model. Please initialise the OCR model first with init() and pass it as an argument to identify().")
+        raise ValueError("Invalid OCR model. Please initialise the OCR model first with init() and pass it as an argument to _identify().")

-### Filter out the results that are not in the source language
-def id_filtered(ocr, image, lang) -> list:
-    result = identify(ocr, image)
+### Filter out the results that are not in the source language. Slower but for a wider range of languages
+# not working but also not very reliable so don't worry about it
+def _id_filtered(ocr, image, lang) -> list:
+    lang = standardize_lang(lang)['id_model_lang']
+    result = _identify(ocr, image)
     ### Parallelise since langid is slow
     def classify_text(entry):
         return entry if langid.classify(entry[1])[0] == lang else None
@@ -71,12 +82,29 @@ def _id_filtered(ocr, image, lang) -> list:
     return results_no_eng

-# zh, ja, ko
-def id_lang(ocr, image, lang) -> list:
-    result = identify(ocr, image)
-    filtered = [entry for entry in result if contains_lang(entry[1], lang)]
+# ch_sim, ch_tra, ja, ko, en
+def _id_lang(ocr, image, lang) -> list:
+    result = _identify(ocr, image)
+    lang = standardize_lang(lang)['id_model_lang']
+    try:
+        filtered = [entry for entry in result if contains_lang(entry[1], lang)]
+    except:
+        logger.error(f"Selected language not part of default: {default_languages}.")
+        raise ValueError(f"Selected language not part of default: {default_languages}.")
     return filtered

+def id_keep_source_lang(ocr, image, lang) -> list:
+    try:
+        return _id_lang(ocr, image, lang)
+    except ValueError:
+        try:
+            return _id_filtered(ocr, image, lang)
+        except Exception as e:
+            print(f'Probably an issue with the _id_filtered function. {e}')
+            sys.exit(1)

 def get_words(ocr_output) -> list:
     return [entry[1] for entry in ocr_output]
@@ -85,3 +113,12 @@ def get_positions(ocr_output) -> list:

 def get_confidences(ocr_output) -> list:
     return [entry[2] for entry in ocr_output]

+if __name__ == '__main__':
+    # OCR_languages = ['ch_sim','en']
+    # image_old = '/home/James/Pictures/Screenshots/DP-1.jpg'
+    # reader = easyocr.Reader(OCR_languages, gpu=True) # this needs to run only once to load the model into memory
+    # result = reader.readtext(image_old)
+    # print(result)
+    print(id_keep_source_lang(init_OCR(model='paddle', paddle_lang='zh', easy_languages=['en', 'ch_sim']), '/home/James/Pictures/Screenshots/DP-1.jpg', 'ch_sim'))
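For reference when reading these helpers: each OCR result entry has the shape (bounding_box, text, confidence), which is what get_positions, get_words, and get_confidences index into. An illustrative entry with invented values:

```python
# Hypothetical single OCR entry: four corner points, recognised text, confidence.
entry = ([[10, 20], [120, 20], [120, 48], [10, 48]], '合成台', 0.93)
ocr_output = [entry]

print([e[0] for e in ocr_output])  # what get_positions returns
print([e[1] for e in ocr_output])  # what get_words returns
print([e[2] for e in ocr_output])  # what get_confidences returns
```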
@@ -1,76 +1,217 @@
-from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig
-import torch, os
-from dotenv import load_dotenv
-load_dotenv()
-
-if os.getenv('TRANSLATION_USE_GPU') in ['False', '0', 'false', 'no', 'No', 'NO', 'FALSE']:
-    device = torch.device("cpu")
-else:
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-### Batch translate a list of strings
+from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig, AutoModelForCausalLM
+import google.generativeai as genai
+import torch, os, sys, ast
+from utils import standardize_lang
+from functools import wraps
+from batching import generate_text, Gemini
+from logging_config import logger
+from multiprocessing import Process, Event
+# root dir
+sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
+from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models
+
+##############################
+# translation decorator
+def translate(translation_func):
+    @wraps(translation_func)
+    def wrapper(text, *args, **kwargs):
+        try:
+            if len(text) == 0:
+                return []
+            return translation_func(text, *args, **kwargs)
+        except Exception as e:
+            logger.error(f"Translation error with the following function: {translation_func.__name__}. Text: {text}\nError: {e}")
+    return wrapper
+###############################
+
+
+###############################
+def init_GEMINI(models_and_rates = None):
+    if not models_and_rates:
+        ## this is default for free tier
+        models_and_rates = {'gemini-1.5-pro': 2, 'gemini-1.5-flash': 15, 'gemini-1.5-flash-8b': 8, 'gemini-1.0-pro': 15} # order from most pref to least pref
+    models = [Gemini(name, rate) for name, rate in models_and_rates.items()]
+    for model in models:
+        model.start()
+    genai.configure(api_key=GEMINI_KEY)
+    return models
+
+def translate_GEMINI(text, models, from_lang, target_lang):
+    safety_settings = {
+        "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
+        "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
+        "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
+        "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
+    prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {from_lang} into {target_lang} and output as a Python list ensuring proper escaping of characters: {text}"
+    for model in models:
+        if model.curr_calls < model.rate:
+            try:
+                response = genai.GenerativeModel(model.name).generate_content(prompt,
+                                                                              safety_settings=safety_settings)
+                model.curr_calls += 1
+                logger.info(repr(model))
+                logger.info(f'Model Response: {response.text.strip()}')
+                return ast.literal_eval(response.text.strip())
+            except Exception as e:
+                logger.error(f"Error with model {model.name}. Error: {e}")
+    logger.error("No models available to translate. Please wait for a model to be available.")
+
+
+###############################
+# Best model by far: Aya-23-8B. Gemma is relatively good. If I get the time to quantize either gemma or aya those will be good to use. llama3.2 is really good as well.
+def init_AYA():
+    model_id = "CohereForAI/aya-23-8B"
+    tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
+    model = AutoModelForCausalLM.from_pretrained(model_id, local_files_only=True, torch_dtype=torch.float16).to(device)
+    model.eval()
+    return (model, tokenizer)
+
+
+##############################
 # M2M100 model

 def init_M2M():
-    global tokenizer, model
-    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=True)
-    model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=True, torch_dtype=torch.float16).to(device)
+    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=LOCAL_FILES_ONLY)
+    model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=LOCAL_FILES_ONLY).to(device)
     model.eval()
+    return (model, tokenizer)

-def translate_M2M(text, from_lang = 'zh', target_lang = 'en'):
+def translate_M2M(text, model, tokenizer, from_lang = 'ch_sim', target_lang = 'en') -> list[str]:
+    model_lang_from = standardize_lang(from_lang)['translation_model_lang']
+    model_lang_to = standardize_lang(target_lang)['translation_model_lang']
     if len(text) == 0:
         return []
-    tokenizer.src_lang = from_lang
-    with torch.no_grad():
-        encoded = tokenizer(text, return_tensors="pt", padding=True).to(device)
-        generated_tokens = model.generate(**encoded,
-                                          forced_bos_token_id=tokenizer.get_lang_id(target_lang))
-        translated = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-    return translated
+    tokenizer.src_lang = model_lang_from
+    generated_translations = generate_text(text, model, tokenizer, batch_size=BATCH_SIZE,
+                                           max_length=MAX_INPUT_TOKENS, max_new_tokens=MAX_OUTPUT_TOKENS,
+                                           forced_bos_token_id=tokenizer.get_lang_id(model_lang_to))
+    return generated_translations
+###############################
+
+
+###############################
 # Helsinki-NLP model Opus MT
+# Refer here for all the models https://huggingface.co/Helsinki-NLP
+def get_OPUS_model(from_lang, target_lang):
+    model_lang_from = standardize_lang(from_lang)['translation_model_lang']
+    model_lang_to = standardize_lang(target_lang)['translation_model_lang']
+    return f"Helsinki-NLP/opus-mt-{model_lang_from}-{model_lang_to}"

-def init_OPUS():
-    global tokenizer, model
-    tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-tc-bible-big-zhx-en", local_files_only=True)
-    model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-tc-bible-big-zhx-en", local_files_only=True, torch_dtype=torch.float16).to(device)
+def init_OPUS(from_lang = 'ch_sim', target_lang = 'en'):
+    opus_model = get_OPUS_model(from_lang, target_lang)
+    tokenizer = AutoTokenizer.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY)
+    model = AutoModelForSeq2SeqLM.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY, torch_dtype=torch.float16).to(device)
     model.eval()
+    return (model, tokenizer)

-def translate_OPUS(text: list[str]) -> list[str]:
-    if len(text) == 0:
-        return []
-    with torch.no_grad():
-        encoded = tokenizer(text, return_tensors="pt", padding=True).to(device)
-        generated_tokens = model.generate(**encoded)
-        translated = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-    return translated
+def translate_OPUS(text: list[str], model, tokenizer) -> list[str]:
+    translated_text = generate_text(text, model, tokenizer,
+                                    batch_size=BATCH_SIZE, device=device,
+                                    max_length=MAX_INPUT_TOKENS, max_new_tokens=MAX_OUTPUT_TOKENS)
+    return translated_text
+###############################
+
+
+def init_Seq_LLM(model_type, **kwargs): # model = 'opus' or 'm2m'
+    if model_type == 'opus':
+        return init_OPUS(**kwargs)
+    elif model_type == 'm2m':
+        return init_M2M()
+    else:
+        raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.")
+
+def init_API_LLM(model_type, **kwargs): # model = 'gemini'
+    if model_type == 'gemini':
+        return init_GEMINI(**kwargs)
+    else:
+        raise ValueError(f"Invalid model type. Please use {' or '.join(api_llm_models)}.")
+
+def init_Causal_LLM(model_type, **kwargs):
+    pass
+###
+@translate
+def translate_Seq_LLM(text,
+                      model_type, # 'opus' or 'm2m'
+                      model,
+                      tokenizer,
+                      **kwargs):
+    if model_type == 'opus':
+        return translate_OPUS(text, model, tokenizer)
+    elif model_type == 'm2m':
+        try:
+            return translate_M2M(text, model, tokenizer, **kwargs)
+        except Exception as e:
+            logger.error(f"Error with M2M model. Error: {e}")
+            # raise ValueError(f"Please provide the correct from_lang and target_lang variables if you are using the M2M model. Use the list from {available_langs}.")
+    else:
+        raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.")
+
+
+### if you want to use any other translation, just define a translate function with input text and output text.
+# def translate_api(text):
+#@translate
+#def translate_Causal_LLM(text, model_type, model)
+
+@translate
+def translate_API_LLM(text: list[str],
+                      model_type: str, # 'gemini'
+                      models: list, # list of objects of classes defined in batching.py
+                      from_lang: str, # suggested to use ISO 639-1 codes
+                      target_lang: str # suggested to use ISO 639-1 codes
+                      ) -> list[str]:
+    if model_type == 'gemini':
+        from_lang = standardize_lang(from_lang)['translation_model_lang']
+        target_lang = standardize_lang(target_lang)['translation_model_lang']
+        return translate_GEMINI(text, models, from_lang, target_lang)
+    else:
+        raise ValueError(f"Invalid model. Please use {' or '.join(api_llm_models)}.")
+
+@translate
+def translate_Causal_LLM(text: list[str],
+                         model_type, # aya
+                         model,
+                         tokenizer,
+                         from_lang: str,
+                         target_lang: str) -> list[str]:
+    model_lang_from = standardize_lang(from_lang)['translation_model_lang']
+    model_lang_to = standardize_lang(target_lang)['translation_model_lang']
+    if len(text) == 0:
+        return []
+    pass

-###
-def init_TRANSLATE(model): # model = 'opus' or 'm2m'
-    if model == 'opus':
-        init_OPUS()
-    elif model == 'm2m':
-        init_M2M()
-    else:
-        raise ValueError("Invalid model. Please use 'opus' or 'm2m'.")

-###
-def translate(text, model, **kwargs):
-    if model == 'opus':
-        return translate_OPUS(text)
-    elif model == 'm2m':
-        try:
-            return translate_M2M(text, **kwargs)
-        except:
-            raise ValueError("Please provide the from_lang and target_lang variables if you are using the M2M model.")
-    else:
-        raise ValueError("Invalid model. Please use 'opus' or 'm2m'.")
+# choose between local Seq2Seq LLM or obtain translations from an API
+def init_func(model):
+    if model in seq_llm_models:
+        return init_Seq_LLM
+    elif model in api_llm_models:
+        return init_API_LLM
+    elif model in causal_llm_models:
+        return init_Causal_LLM
+    else:
+        raise ValueError("Invalid model category. Please use either 'seq' or 'api'.")
+
+
+def translate_func(model):
+    if model in seq_llm_models:
+        return translate_Seq_LLM
+    elif model in api_llm_models:
+        return translate_API_LLM
+    elif model in causal_llm_models:
+        return translate_Causal_LLM
+    else:
+        raise ValueError("Invalid model category. Please use either 'seq' or 'api'.")
+
+
+### todo: if cuda is not detected, default to online translation as cpu just won't cut it. Parallel process it over multiple websites to make it faster
+if __name__ == "__main__":
+    models = init_GEMINI()
+    print(translate_API_LLM(['想要借用合成台。', '有什么卖的?', '声音太大会有什么影响吗?', '不怕丢东西吗?', '再见。', '布纳马', '想买什么自己拿。把钱留在旁边就好。', '回顾', '隐藏'], 'gemini', models, from_lang='ch_sim', target_lang='en'))
+    # model, tokenizer = init_M2M()
+    # print(translate_Seq_LLM( ['想要借用合成台。', '有什么卖的?', '声音太大会有什么影响吗?', '不怕丢东西吗?', '再见。', '布纳马', '想买什么自己拿。把钱留在旁边就好。', '回顾', '隐藏'], 'm2m', model, tokenizer, from_lang='ch_sim', target_lang='en'))
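A sketch of how the new init_func/translate_func dispatch is meant to be wired together (the exact call site in the main loop is not shown in this diff, so argument choices here are assumptions based on the signatures above):

```python
# Sketch: resolve the init and translate functions from the configured model name.
from config import TRANSLATION_MODEL, SOURCE_LANG, TARGET_LANG
from helpers.translation import init_func, translate_func

init = init_func(TRANSLATION_MODEL)         # e.g. init_Seq_LLM for 'opus'/'m2m'
model, tokenizer = init(TRANSLATION_MODEL)  # Seq2Seq inits return (model, tokenizer)
run = translate_func(TRANSLATION_MODEL)     # e.g. translate_Seq_LLM

print(run(['你好'], TRANSLATION_MODEL, model, tokenizer,
          from_lang=SOURCE_LANG, target_lang=TARGET_LANG))
```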
@ -3,12 +3,15 @@ from pypinyin import pinyin
|
|||||||
import pyscreenshot as ImageGrab # wayland tings not sure if it will work on other machines alternatively use mss
|
import pyscreenshot as ImageGrab # wayland tings not sure if it will work on other machines alternatively use mss
|
||||||
import mss, io, os
|
import mss, io, os
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import jaconv
|
import jaconv, MeCab, unidic, pykakasi
|
||||||
import MeCab
|
|
||||||
import unidic
|
# for creating furigana
|
||||||
mecab = MeCab.Tagger('-d "{}"'.format(unidic.DICDIR))
|
mecab = MeCab.Tagger('-d "{}"'.format(unidic.DICDIR))
|
||||||
uroman = ur.Uroman()
|
uroman = ur.Uroman()
|
||||||
|
|
||||||
|
# for romanising japanese text. Can convert to hiragana or katakana as well but does not split the words up so harder to use for furigana
|
||||||
|
kks = pykakasi.kakasi()
|
||||||
|
|
||||||
|
|
||||||
# define a function for checking whether one axis of a shape intercepts with another
|
# define a function for checking whether one axis of a shape intercepts with another
|
||||||
def intercepts(x,y):
|
def intercepts(x,y):
|
||||||
@ -71,7 +74,10 @@ def add_furigana(text):
|
|||||||
furigana_string = ''
|
furigana_string = ''
|
||||||
for i in parsed:
|
for i in parsed:
|
||||||
words = i.split('\t')[0]
|
words = i.split('\t')[0]
|
||||||
|
try :
|
||||||
add = f'({jaconv.kata2hira(i.split(',')[6])})'
|
add = f'({jaconv.kata2hira(i.split(',')[6])})'
|
||||||
|
except:
|
||||||
|
add = ''
|
||||||
to_add = add if contains_kanji(words) else ''
|
to_add = add if contains_kanji(words) else ''
|
||||||
furigana_string += i.split('\t')[0] + to_add
|
furigana_string += i.split('\t')[0] + to_add
|
||||||
return furigana_string
|
return furigana_string
|
||||||
@ -87,10 +93,12 @@ def contains_katakana(text):
|
|||||||
return bool(re.search(r'[\u30A0-\u30FF]', text))
|
return bool(re.search(r'[\u30A0-\u30FF]', text))
|
||||||
|
|
||||||
|
|
||||||
|
# use kakasi to romanize japanese text
|
||||||
def romanize(text, piny=False):
|
def romanize(text, lang):
|
||||||
if piny:
|
if lang == 'zh':
|
||||||
return ' '.join([ py[0] for py in pinyin(text, heteronym=True)])
|
return ' '.join([ py[0] for py in pinyin(text, heteronym=True)])
|
||||||
|
if lang == 'ja':
|
||||||
|
return kks.convert(text)[0]['hepburn']
|
||||||
return uroman.romanize_string(text)
|
return uroman.romanize_string(text)
|
||||||
|
|
||||||
# check if a string contains words from a language
|
# check if a string contains words from a language
|
@@ -107,36 +115,45 @@ def contains_lang(text, lang):
     else:
         raise ValueError("Invalid language. Please use one of 'en', 'zh', 'ja', or 'ko'.")
 
 
-### en, ch_sim, ja, ko rapidocr only has chinese and en at the moment
+### en, ch_sim, ch_tra, ja, ko rapidocr only has chinese and en at the moment
 def standardize_lang(lang):
     if lang == 'ch_sim':
         easyocr_lang = 'ch_sim'
         paddleocr_lang = 'ch'
         rapidocr_lang = 'ch'
         translation_model_lang = 'zh'
+        id_model_lang = 'zh'
     elif lang == 'ch_tra':
         easyocr_lang = 'ch_tra'
         paddleocr_lang = 'ch'
         rapidocr_lang = 'ch'
         translation_model_lang = 'zh'
+        id_model_lang = 'zh'
     elif lang == 'ja':
         easyocr_lang = 'ja'
         paddleocr_lang = 'ja'
         rapidocr_lang = 'ja'
         translation_model_lang = 'ja'
+        id_model_lang = 'ja'
     elif lang == 'ko':
         easyocr_lang = 'korean'
         paddleocr_lang = 'ko'
         rapidocr_lang = 'ko'
         translation_model_lang = 'ko'
+        id_model_lang = 'ko'
     elif lang == 'en':
         easyocr_lang = 'en'
         paddleocr_lang = 'en'
         rapidocr_lang = 'en'
         translation_model_lang = 'en'
+        id_model_lang = 'en'
     else:
-        raise ValueError("Invalid language. Please use one of 'en', 'ch_sim', 'ch_tra', 'ja', or 'ko'.")
-    return {'easyocr_lang': easyocr_lang, 'paddleocr_lang': paddleocr_lang, 'rapidocr_lang': rapidocr_lang, 'translation_model_lang': translation_model_lang}
+        raise ValueError(f"Invalid language {lang}. Please use one of 'en', 'ch_sim', 'ch_tra', 'ja', or 'ko'.")
+    return {'easyocr_lang': easyocr_lang,
+            'paddleocr_lang': paddleocr_lang,
+            'rapidocr_lang': rapidocr_lang,
+            'translation_model_lang': translation_model_lang,
+            'id_model_lang': id_model_lang}
 
 
 def which_ocr_lang(model):
     if model == 'easy':
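Splitting the return dict across lines and adding `id_model_lang` keeps call sites explicit. A quick usage sketch:

    langs = standardize_lang('ch_tra')
    # {'easyocr_lang': 'ch_tra', 'paddleocr_lang': 'ch', 'rapidocr_lang': 'ch',
    #  'translation_model_lang': 'zh', 'id_model_lang': 'zh'}
    print(langs['id_model_lang'])  # 'zh'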
logging_config.py

@@ -1,17 +1,42 @@
 import logging, os
+from typing import Optional
 from dotenv import load_dotenv
 
 load_dotenv()
 
-# Configure the logger
-def setup_logger(name: str, log_file: str = None, level: int = logging.INFO) -> logging.Logger:
-    """Set up a logger with the specified name and level."""
+def setup_logger(
+    name: str,
+    log_file: Optional[str] = None,
+    level: int = logging.INFO
+) -> Optional[logging.Logger]:
+    """
+    Set up a logger with the specified name and level.
+
+    Args:
+        name: Logger name
+        log_file: Path to log file (defaults to name.log)
+        level: Logging level (defaults to INFO)
+
+    Returns:
+        Logger object if successful, None if setup fails
+    """
+    try:
         if log_file is None:
             log_file = f"{name}.log"
+
+        # Validate logging level
+        valid_levels = [logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL]
+        if level not in valid_levels:
+            level = logging.INFO
 
         # Create a logger
         logger = logging.getLogger(name)
         logger.setLevel(level)
+        logger.propagate = False
+
+        # Clear existing handlers
+        if logger.handlers:
+            logger.handlers.clear()
 
         # Create file handler
         file_handler = logging.FileHandler(log_file)
@@ -22,15 +47,21 @@ def setup_logger(name: str, log_file: str = None, level: int = logging.INFO) ->
         console_handler.setLevel(level)
 
         # Create a formatter and set it for both handlers
-        formatter = logging.Formatter('%(asctime)s - %(name)s - [%(levelname)s] %(message)s',
-                                      datefmt='%Y-%m-%d %H:%M:%S')
+        formatter = logging.Formatter(
+            '%(asctime)s - %(name)s - [%(levelname)s] %(message)s',
+            datefmt='%Y-%m-%d %H:%M:%S'
+        )
         file_handler.setFormatter(formatter)
         console_handler.setFormatter(formatter)
 
         # Add handlers to the logger
         logger.addHandler(file_handler)
         logger.addHandler(console_handler)
 
         return logger
 
+    except Exception as e:
+        print(f"Failed to setup logger: {e}")
+        return None
+
+logger = setup_logger('on_screen_translator', log_file='translate.log')
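With `logger.propagate = False` and the handler clearing added above, repeated `setup_logger` calls no longer duplicate handlers or double-log through the root logger. A usage sketch (the 'ocr_worker' name is illustrative):

    import logging
    from logging_config import setup_logger

    log = setup_logger('ocr_worker', level=logging.DEBUG)  # log_file defaults to 'ocr_worker.log'
    if log is not None:  # None signals that handler setup failed
        log.info('OCR worker started')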
@@ -10,7 +10,7 @@
       <img
         id="live-image"
         src="/image"
-        alt="Live Image"
+        alt="No Translations Available"
         style="max-width: 100%; height: auto" />
 
     <script>
68 translate.py Normal file

@@ -0,0 +1,68 @@
+###################################################################################
+##### IMPORT LIBRARIES #####
+import os, time, sys
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
+
+from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
+from utils import printsc, convert_image_to_bytes, bytes_to_image
+from ocr import get_words, init_OCR, id_keep_source_lang
+from logging_config import logger
+from draw import modify_image_bytes
+from config import SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL
+###################################################################################
+
+latest_image = None
+
+def main():
+    global latest_image
+
+    ##### Initialize the OCR #####
+    OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
+    ocr = init_OCR(model=OCR_MODEL, easy_languages = OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
+
+    ##### Initialize the translation #####
+    # model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
+    models = init_API_LLM(TRANSLATION_MODEL)
+    ###################################################################################
+    runs = 0
+    while True:
+        untranslated_image = printsc(REGION)
+        byte_image = convert_image_to_bytes(untranslated_image)
+        ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG) # keep only phrases containing the source language
+
+        if runs == 0:
+            logger.info('Initial run')
+            prev_words = set()
+        else:
+            logger.info(f'Run number: {runs}.')
+        runs += 1
+
+        curr_words = set(get_words(ocr_output))
+
+        ### If the OCR detects different words, translate screen -> to ensure that the screen is not refreshing constantly and to save GPU power
+        if prev_words != curr_words:
+            logger.info('Translating')
+
+            to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
+            # translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
+            translation = translate_API_LLM(to_translate, TRANSLATION_MODEL, models, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
+            logger.info(f'Translation from {to_translate} to\n {translation}')
+            translated_image = modify_image_bytes(byte_image, ocr_output, translation)
+            latest_image = bytes_to_image(translated_image)
+            # latest_image.show() # for debugging
+
+            prev_words = curr_words
+        else:
+            logger.info("No new words to translate. Output will not refresh.")
+
+        logger.info(f'Sleeping for {INTERVAL} seconds')
+        time.sleep(INTERVAL)
+
+################### TODO ##################
+# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
+# 5. Maybe refreshing issue of flask app. Also get webpage to update only if the image changes.
+# Create a way for it to just replace the text and provide only the translation on-screen. Qt6
+
+if __name__ == "__main__":
+    main()
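`init_API_LLM` and `translate_API_LLM` live in helpers/translation.py, which is not shown in this view. Since the commit adds Gemini API support, here is a minimal sketch of what such a wrapper could look like, assuming the google-generativeai client; the function bodies, prompt format, and model name are illustrative, not the actual helpers:

    import os
    import google.generativeai as genai

    def init_API_LLM(model_name):
        # hypothetical stand-in for helpers/translation.init_API_LLM
        genai.configure(api_key=os.environ['GEMINI_API_KEY'])
        return genai.GenerativeModel('gemini-1.5-flash')

    def translate_API_LLM(phrases, model_name, model, from_lang='ja', target_lang='en'):
        # batch every OCR phrase into a single request to cut API round trips
        prompt = (f"Translate the numbered lines from {from_lang} to {target_lang}. "
                  "Return one translation per line.\n"
                  + "\n".join(f"{i}. {p}" for i, p in enumerate(phrases)))
        response = model.generate_content(prompt)
        return [line.split('. ', 1)[-1] for line in response.text.splitlines() if line.strip()]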
@@ -1,14 +1,12 @@
 from flask import Flask, Response, render_template
-import time
 import threading
-from PIL import Image
 import io
-import chinese_to_eng
+import translate
 app = Flask(__name__)
 
 # Global variable to hold the current image
 def curr_image():
-    return chinese_to_eng.latest_image
+    return translate.latest_image
 
 @app.route('/')
 def index():
@@ -19,8 +17,6 @@ def index():
 def stream_image():
     if curr_image() is None:
         return "No image generated yet.", 503
-    print('streaming')
-    print(curr_image())
     file_object = io.BytesIO()
     curr_image().save(file_object, 'PNG')
     file_object.seek(0)
@@ -33,7 +29,7 @@ def stream_image():
 
 if __name__ == '__main__':
     # Start the image updating thread
-    threading.Thread(target=chinese_to_eng.main, daemon=True).start()
+    threading.Thread(target=translate.main, daemon=True).start()
 
     # Start the Flask web server
     app.run(host='0.0.0.0', port=5000, debug=True)
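One caveat with this entry point: `debug=True` enables Werkzeug's auto-reloader, which re-runs the module in a child process and can start the background translation thread twice. If that shows up, keeping the debugger but disabling the reloader is a common fix:

    if __name__ == '__main__':
        threading.Thread(target=translate.main, daemon=True).start()
        # debug pages still work, but main() is started exactly once
        app.run(host='0.0.0.0', port=5000, debug=True, use_reloader=False)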