# onscreen-translator/main.py
#
# 101 lines
# 4.8 KiB
# Python

###################################################################################
##### IMPORT LIBRARIES #####
import os, time, sys, threading, subprocess
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
from utils import printsc, convert_image_to_bytes, bytes_to_image, similar_tfidf
from ocr import get_words, init_OCR, id_keep_source_lang
from data import Base, engine, create_tables
from draw import modify_image_bytes
import config, asyncio
from config import SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL, IMAGE_CHANGE_THRESHOLD
from logging_config import logger
import web_app
import view_buffer_app
###################################################################################
async def main():
    """Run the capture -> OCR -> translate -> render loop forever.

    Initialises the database tables, the OCR engine and the API translation
    models once, then repeatedly:

    1. captures the configured screen ``REGION``;
    2. runs OCR and keeps only phrases that ``id_keep_source_lang`` attributes
       to ``SOURCE_LANG``;
    3. if the detected words differ from those of the last *translated* frame
       (the word sets differ and ``similar_tfidf`` does not report them as
       similar at ``IMAGE_CHANGE_THRESHOLD``), translates the first
       ``MAX_TRANSLATE`` phrases, draws the translations onto the captured
       image and publishes it as ``web_app.latest_image``;
    4. waits ``INTERVAL`` seconds before the next capture.

    If ``translate_API_LLM`` raises ``TypeError`` the frame is dropped and the
    loop backs off for 30 seconds before capturing again. Any other exception
    propagates and ends the loop. The coroutine never returns normally.
    """
    ##### Initialisation #####
    # Create the database tables if they are not present.
    create_tables()
    # Initialise the OCR engine. init_OCR receives both a single paddle_lang
    # and a list of easy_languages; presumably OCR_MODEL decides which is used.
    OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
    ocr = init_OCR(model=OCR_MODEL, paddle_lang=SOURCE_LANG, easy_languages=OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
    # Initialise the API-based translation models.
    # Local seq2seq alternative:
    # model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang=SOURCE_LANG, target_lang=TARGET_LANG)
    models = init_API_LLM(SOURCE_LANG, TARGET_LANG)

    # Optional Qt viewer (disabled): label, app = view_buffer_app.create_viewer()
    # before the loop, view_buffer_app.show_buffer_image(translated_image, label)
    # per translated frame, and label.close() / app.quit() in a finally block.
    runs = 0
    # Words of the last frame that was actually translated. This is only
    # updated after a successful translation, so the change check compares the
    # new capture against what is currently rendered.
    prev_words = set()
    while True:
        logger.debug("Capturing screen")
        untranslated_image = printsc(REGION)
        logger.debug("Screen Captured. Proceeding to perform OCR.")
        byte_image = convert_image_to_bytes(untranslated_image)
        # Keep only phrases containing the source language.
        ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG)
        logger.debug(f"OCR completed. Detected {len(ocr_output)} phrases.")
        if runs == 0:
            logger.info('Initial run')
        else:
            logger.info(f'Run number: {runs}.')
        runs += 1

        curr_words = set(get_words(ocr_output))
        logger.debug(f'Current words: {curr_words} Previous words: {prev_words}')
        # Only re-translate when the detected text has really changed, so the
        # output is not refreshed constantly and GPU/API work is saved. The
        # cheap set comparison runs first so the TF-IDF similarity is only
        # computed when the words actually differ.
        screen_changed = prev_words != curr_words and not similar_tfidf(
            list(curr_words), list(prev_words), threshold=IMAGE_CHANGE_THRESHOLD
        )
        if screen_changed:
            logger.info('Beginning Translation')
            # entry[1] is the text field of each OCR result; cap the batch size.
            to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
            # Local seq2seq alternative:
            # translation = translate_Seq_LLM(to_translate, model_type=TRANSLATION_MODEL, model=model, tokenizer=tokenizer, from_lang=SOURCE_LANG, target_lang=TARGET_LANG)
            try:
                translation = await translate_API_LLM(to_translate, models, call_size=3)
            except TypeError as e:
                # NOTE(review): only TypeError is treated as an API failure;
                # confirm this is what translate_API_LLM raises when models fail.
                logger.error(f"Failed to translate using API models. Error: {e}. Sleeping for 30 seconds.")
                # Non-blocking back-off: time.sleep() would stall this event loop.
                await asyncio.sleep(30)
                continue
            logger.debug('Translation complete. Modifying image.')
            # NOTE(review): `translation` covers at most MAX_TRANSLATE phrases
            # while the full `ocr_output` is passed; confirm modify_image_bytes
            # tolerates the length mismatch.
            translated_image = modify_image_bytes(byte_image, ocr_output, translation)
            web_app.latest_image = bytes_to_image(translated_image)
            logger.debug("Image modified. Saving image.")
            # Saving to disk is currently disabled:
            # web_app.latest_image.save('/home/James/Pictures/translated.png')
            prev_words = curr_words
        else:
            logger.info("Skipping translation. No significant change in the screen detected.")
        logger.debug("Continuing to next iteration.")
        # Must be awaited: a bare asyncio.sleep(...) only creates a coroutine
        # object, leaving the loop to spin with no delay between captures.
        await asyncio.sleep(INTERVAL)
################### TODO ##################
# 3. Quantise/fine-tune larger LLMs; candidate models include Aya-23-8B, Gemma and Llama 3.2.
# 5. Investigate a possible refresh issue in the Flask app, and make the webpage update only when the image changes.
if __name__ == "__main__":
    # Record the effective settings at startup: every attribute of the
    # `config` module that is neither callable nor a dunder name.
    logger.info('Configuration:')
    for setting in dir(config):
        value = getattr(config, setting)
        if callable(value) or setting.startswith("__"):
            continue
        logger.info(f'{setting}: {value}')

    # Start the image-updating loop. `main()` is a coroutine, so it is driven by
    # `asyncio.run` on a separate daemon thread (daemon=True means the thread
    # does not keep the process alive once the main thread finishes).
    # Alternatives kept from development:
    #   subprocess.Popen(['feh', '--auto-reload', '/home/James/Pictures/translated.png'])
    #   asyncio.run(main())  # run the loop in the foreground instead
    updater = threading.Thread(target=asyncio.run, args=(main(),), daemon=True)
    updater.start()

    # Start the Flask web server on the main thread; presumably it serves the
    # web_app.latest_image that main() publishes.
    web_app.app.run(host='0.0.0.0', port=5000, debug=False)