114 lines
4.7 KiB
Python
114 lines
4.7 KiB
Python
from flask import Flask, Response, render_template
|
|
import threading
|
|
import io
|
|
import os, time, sys, threading, subprocess
|
|
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
|
|
|
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
|
|
from utils import printsc, convert_image_to_bytes, bytes_to_image, check_similarity, is_wayland
|
|
from ocr import get_words, init_OCR, id_keep_source_lang
|
|
from data import Base, engine, create_tables
|
|
from draw import modify_image
|
|
import asyncio
|
|
from config import SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL, IMAGE_CHANGE_THRESHOLD, TEMP_IMG_PATH
|
|
from logging_config import logger
|
|
|
|
|
|
app = Flask(__name__)
|
|
|
|
latest_image = None
|
|
|
|
async def web_app_main():
    """Run the capture -> OCR -> translate -> redraw loop indefinitely.

    On start-up this creates the database tables, initialises the OCR engine
    and the API translation models.  Each iteration then captures ``REGION``
    to ``TEMP_IMG_PATH``, runs OCR on it and, when the detected words differ
    enough from the previously translated ones, translates up to
    ``MAX_TRANSLATE`` phrases and publishes the redrawn image through the
    module-level ``latest_image``.

    The coroutine never returns; it is meant to be driven by ``asyncio.run``
    in a background thread.

    NOTE(review): the block nesting below was reconstructed from a flattened
    copy of this file -- confirm the placement of the trailing sleep.
    """
    global latest_image

    # --- Initialisation ---------------------------------------------------
    # Create the database if not present.
    create_tables()

    # Initialise the OCR engine.
    ocr_languages = [SOURCE_LANG, 'en']
    ocr = init_OCR(
        model=OCR_MODEL,
        paddle_lang=SOURCE_LANG,
        easy_languages=ocr_languages,
        use_GPU=OCR_USE_GPU,
    )

    # Initialise the translation backend (API models; the Seq LLM path via
    # init_Seq_LLM / translate_Seq_LLM is currently not used here).
    models = init_API_LLM(SOURCE_LANG, TARGET_LANG)

    runs = 0
    prev_words = set()

    while True:
        logger.debug("Capturing screen")
        printsc(REGION, TEMP_IMG_PATH)
        logger.debug("Screen Captured. Proceeding to perform OCR.")
        # Keep only phrases containing the source language.
        ocr_output = id_keep_source_lang(ocr, TEMP_IMG_PATH, SOURCE_LANG)
        logger.debug(f"OCR completed. Detected {len(ocr_output)} phrases.")

        if runs == 0:
            logger.info('Initial run')
        else:
            logger.debug(f'Run number: {runs}.')
        runs += 1

        curr_words = set(get_words(ocr_output))
        logger.debug(f'Current words: {curr_words} Previous words: {prev_words}')

        # Only translate when the OCR detects different words, so the screen
        # is not refreshed constantly and GPU power is saved.
        screen_changed = prev_words != curr_words and not check_similarity(
            list(curr_words),
            list(prev_words),
            threshold=IMAGE_CHANGE_THRESHOLD,
            method="tfidf",
        )

        if screen_changed:
            logger.info('Beginning Translation')
            to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]

            if not to_translate:
                logger.info("No text detected. Skipping translation. Continuing to next iteration.")
                # Wait before retrying; without this pause the loop would
                # re-capture and re-OCR the screen back-to-back.
                await asyncio.sleep(INTERVAL)
                continue

            try:
                translation = await translate_API_LLM(to_translate, models, call_size=3)
            except TypeError as e:
                logger.error(f"Failed to translate using API models. Error: {e}. Sleeping for {2*INTERVAL} seconds.")
                # Back off for twice the normal interval before retrying.
                await asyncio.sleep(2 * INTERVAL)
                continue

            logger.debug('Translation complete. Modifying image.')
            translated_image = modify_image(TEMP_IMG_PATH, ocr_output, translation)
            latest_image = bytes_to_image(translated_image)
            logger.debug("Image modified. Saving image.")
            # Remember what was translated so an unchanged screen is skipped.
            prev_words = curr_words
        else:
            logger.info(f"Skipping translation. No significant change in the screen detected. Total translation attempts so far: {runs}.")

        logger.debug("Continuing to next iteration.")
        await asyncio.sleep(INTERVAL)
|
|
|
|
# Accessor for the module-level `latest_image` global that holds the current image
|
|
def curr_image():
    """Return whatever is currently bound to the module-level ``latest_image``.

    This is ``None`` until ``web_app_main`` has published a translated frame.
    """
    current = latest_image
    return current
|
|
|
|
@app.route('/')
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page
|
|
|
|
|
|
@app.route('/image')
def stream_image():
    """Return the most recent translated image as a non-cacheable PNG.

    Returns:
        A ``("No image generated yet.", 503)`` tuple while no image has been
        published, otherwise a ``Response`` with mimetype ``image/png`` and
        headers that disable caching.
    """
    # Read the shared image exactly once: the capture thread may rebind it at
    # any time, so the None check and the encode must use the same object.
    image = curr_image()
    if image is None:
        return "No image generated yet.", 503

    # presumably a PIL image (it exposes ``save(fileobj, format)``) -- confirm
    buffer = io.BytesIO()
    image.save(buffer, 'PNG')

    # getvalue() returns the whole buffer regardless of the stream position,
    # so no rewind is needed.
    response = Response(buffer.getvalue(), mimetype='image/png')
    response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate'  # HTTP 1.1
    response.headers['Pragma'] = 'no-cache'  # HTTP 1.0
    response.headers['Expires'] = '0'  # Proxies

    return response
|
|
|
|
if __name__ == '__main__':
    # Run the capture/translate coroutine on its own event loop in a daemon
    # thread so it does not keep the process alive once the server exits.
    updater = threading.Thread(
        target=asyncio.run,
        args=(web_app_main(),),
        daemon=True,
    )
    updater.start()

    # Serve the Flask app on all interfaces, port 5000, with debug off.
    app.run(host='0.0.0.0', port=5000, debug=False)
|