Added support for Groqq API models. Created QT6 overlay app.
This commit is contained in:
parent
17e7f6526f
commit
499a2c3972
3
.gitignore
vendored
3
.gitignore
vendored
@ -4,4 +4,5 @@ translate/
|
||||
__pycache__/
|
||||
.*
|
||||
test.py
|
||||
|
||||
notebooks/
|
||||
qttest.py
|
||||
@ -1,4 +1,4 @@
|
||||
## Debugging Issues
|
||||
|
||||
1. CUDNN Version mismatch when using PaddleOCR. Check if LD_LIBRARY_PATH is correctly set to the directory containing the cudnn.so file. If using a local installation, it could help to just remove nvidia-pip cudnn from python environment.
|
||||
1. CUDNN Version mismatch when using PaddleOCR. Check if LD_LIBRARY_PATH is correctly set to the directory containing the cudnn.so file. If using a local installation, it could help to just remove nvidia-cudnn-cn12 from python environment.
|
||||
2. Segmentation fault when using PaddleOCR, EasyOCR or RapidOCR. Ensure the only cv2 library is the opencv-contrib-python library. Check out https://pypi.org/project/opencv-python-headless/ for more info.
|
||||
|
||||
13
api_models.json
Normal file
13
api_models.json
Normal file
@ -0,0 +1,13 @@
|
||||
{
|
||||
"Gemini": {
|
||||
"gemini-1.5-pro": 2,
|
||||
"gemini-1.5-flash": 15,
|
||||
"gemini-1.5-flash-8b": 8,
|
||||
"gemini-1.0-pro": 15
|
||||
},
|
||||
"Groqq": {
|
||||
"llama-3.2-90b-text-preview": 30,
|
||||
"llama3-70b-8192": 30,
|
||||
"mixtral-8x7b-32768": 30
|
||||
}
|
||||
}
|
||||
@ -9,9 +9,11 @@ from utils import printsc, convert_image_to_bytes, bytes_to_image
|
||||
from ocr import get_words, init_OCR, id_keep_source_lang
|
||||
from logging_config import logger
|
||||
from draw import modify_image_bytes
|
||||
from config import SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL
|
||||
from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL
|
||||
###################################################################################
|
||||
|
||||
ADD_OVERLAY = False
|
||||
|
||||
latest_image = None
|
||||
|
||||
def main():
|
||||
@ -23,11 +25,21 @@ def main():
|
||||
|
||||
##### Initialize the translation #####
|
||||
# model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
|
||||
models = init_API_LLM(TRANSLATION_MODEL)
|
||||
models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
|
||||
###################################################################################
|
||||
runs = 0
|
||||
app.exec()
|
||||
while True:
|
||||
if ADD_OVERLAY:
|
||||
overlay.clear_all_text()
|
||||
|
||||
untranslated_image = printsc(REGION)
|
||||
|
||||
if ADD_OVERLAY:
|
||||
overlay.text_entries = overlay.text_entries_copy
|
||||
overlay.update()
|
||||
overlay.text_entries.clear()
|
||||
|
||||
byte_image = convert_image_to_bytes(untranslated_image)
|
||||
ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG) # keep only phrases containing the source language
|
||||
|
||||
@ -46,7 +58,7 @@ def main():
|
||||
|
||||
to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
|
||||
# translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
|
||||
translation = translate_API_LLM(to_translate, TRANSLATION_MODEL, models, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
|
||||
translation = translate_API_LLM(to_translate, models)
|
||||
logger.info(f'Translation from {to_translate} to\n {translation}')
|
||||
translated_image = modify_image_bytes(byte_image, ocr_output, translation)
|
||||
latest_image = bytes_to_image(translated_image)
|
||||
@ -58,11 +70,14 @@ def main():
|
||||
|
||||
logger.info(f'Sleeping for {INTERVAL} seconds')
|
||||
time.sleep(INTERVAL)
|
||||
|
||||
# if ADD_OVERLAY:
|
||||
# sys.exit(app.exec())
|
||||
|
||||
################### TODO ##################
|
||||
# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
|
||||
# 5. Maybe refreshing issue of flask app. Also get webpage to update only if the image changes.
|
||||
# Create a way for it to just replace the text and provide only the translation on-screen. Qt6
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
sys.exit(main())
|
||||
|
||||
23
config.py
23
config.py
@ -6,6 +6,7 @@ load_dotenv(override=True)
|
||||
### EDIT THESE VARIABLES ###
|
||||
|
||||
### available languages: 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
|
||||
|
||||
INTERVAL = int(os.getenv('INTERVAL'))
|
||||
|
||||
### OCR
|
||||
@ -13,24 +14,34 @@ OCR_MODEL = os.getenv('OCR_MODEL', 'easy') # 'easy', 'paddle', 'rapid' ### easy
|
||||
OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True'))
|
||||
|
||||
### Drawing/Overlay Config
|
||||
ADD_OVERLAY = ast.literal_eval(os.getenv('ADD_OVERLAY', 'True'))
|
||||
FILL_COLOUR = os.getenv('FILL_COLOUR', 'white')
|
||||
FONT_FILE = os.getenv('FONT_FILE')
|
||||
FONT_SIZE = int(os.getenv('FONT_SIZE', 16))
|
||||
LINE_SPACING = int(os.getenv('LINE_SPACING', 3))
|
||||
REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)'))
|
||||
TEXT_COLOR = os.getenv('TEXT_COLOR', "#ff0000")
|
||||
FONT_COLOUR = os.getenv('FONT_COLOUR', "#ff0000")
|
||||
TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True'))
|
||||
|
||||
|
||||
# API KEYS https://github.com/cheahjs/free-llm-api-resources?tab=readme-ov-file
|
||||
GEMINI_API_KEY = os.getenv('GEMINI_KEY')
|
||||
GROQ_API_KEY = os.getenv('GROQ_API_KEY') #
|
||||
# MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY') # https://console.mistral.ai/api-keys/ slow asf
|
||||
|
||||
### Translation
|
||||
BATCH_SIZE = int(os.getenv('BATCH_SIZE', 6))
|
||||
GEMINI_KEY = os.getenv('GEMINI_KEY')
|
||||
LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False'))
|
||||
MAX_INPUT_TOKENS = int(os.getenv('MAX_INPUT_TOKENS', 512))
|
||||
MAX_OUTPUT_TOKENS = int(os.getenv('MAX_OUTPUT_TOKENS', 512))
|
||||
MAX_TRANSLATE = int(os.getenv('MAX_TRANSLATION', 200))
|
||||
SOURCE_LANG = os.getenv('SOURCE_LANG', 'ja')
|
||||
TARGET_LANG = os.getenv('TARGET_LANG', 'en')
|
||||
|
||||
|
||||
### Local Translation
|
||||
TRANSLATION_MODEL= os.environ['TRANSLATION_MODEL'] # 'opus' or 'm2m' # opus is a lot more lightweight
|
||||
TRANSLATION_USE_GPU = ast.literal_eval(os.getenv('TRANSLATION_USE_GPU', 'True'))
|
||||
MAX_INPUT_TOKENS = int(os.getenv('MAX_INPUT_TOKENS', 512))
|
||||
MAX_OUTPUT_TOKENS = int(os.getenv('MAX_OUTPUT_TOKENS', 512))
|
||||
BATCH_SIZE = int(os.getenv('BATCH_SIZE', 6))
|
||||
LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False'))
|
||||
###################################################################################################
|
||||
|
||||
|
||||
|
||||
305
create_overlay.py
Normal file
305
create_overlay.py
Normal file
@ -0,0 +1,305 @@
|
||||
from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QBuffer
|
||||
from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont, QScreen, QIcon)
|
||||
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
|
||||
QLabel, QSystemTrayIcon, QMenu)
|
||||
import sys, io, os, signal, time, platform
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from config import ADD_OVERLAY, FONT_FILE, FONT_SIZE
|
||||
from logging_config import logger
|
||||
|
||||
|
||||
def qpixmap_to_bytes(qpixmap):
|
||||
qimage = qpixmap.toImage()
|
||||
buffer = QBuffer()
|
||||
buffer.open(QBuffer.ReadWrite)
|
||||
qimage.save(buffer, "PNG")
|
||||
return qimage
|
||||
|
||||
@dataclass
|
||||
class TextEntry:
|
||||
text: str
|
||||
x: int
|
||||
y: int
|
||||
font: QFont = QFont('Arial', FONT_SIZE)
|
||||
visible: bool = True
|
||||
text_color: Qt.GlobalColor = Qt.GlobalColor.red
|
||||
background_color: Optional[Qt.GlobalColor] = None
|
||||
padding: int = 1
|
||||
|
||||
class TranslationOverlay(QMainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setWindowTitle("Translation Overlay")
|
||||
self.is_passthrough = True
|
||||
self.text_entries: List[TextEntry] = []
|
||||
self.setup_window_attributes()
|
||||
self.setup_shortcuts()
|
||||
self.closeEvent = lambda event: QApplication.quit()
|
||||
self.default_font = QFont('Arial', FONT_SIZE)
|
||||
self.text_entries_copy: List[TextEntry] = []
|
||||
self.next_text_entries: List[TextEntry] = []
|
||||
#self.show_background = True
|
||||
self.background_opacity = 0.5
|
||||
# self.setup_tray()
|
||||
|
||||
def prepare_for_capture(self):
|
||||
"""Preserve current state and clear overlay"""
|
||||
if ADD_OVERLAY:
|
||||
self.text_entries_copy = self.text_entries.copy()
|
||||
self.clear_all_text()
|
||||
self.update()
|
||||
|
||||
def restore_after_capture(self):
|
||||
"""Restore overlay state after capture"""
|
||||
if ADD_OVERLAY:
|
||||
logger.debug(f'Text entries copy during initial phase of restore_after_capture: {self.text_entries_copy}')
|
||||
self.text_entries = self.text_entries_copy.copy()
|
||||
logger.debug(f"Restored text entries: {self.text_entries}")
|
||||
self.update()
|
||||
|
||||
def add_next_text_at_position_no_update(self, x: int, y: int, text: str,
|
||||
font: Optional[QFont] = None, text_color: Qt.GlobalColor = Qt.GlobalColor.red):
|
||||
"""Add new text without triggering update"""
|
||||
entry = TextEntry(
|
||||
text=text,
|
||||
x=x,
|
||||
y=y,
|
||||
font=font or self.default_font,
|
||||
text_color=text_color
|
||||
)
|
||||
self.next_text_entries.append(entry)
|
||||
|
||||
def update_translation(self, ocr_output, translation):
|
||||
# Update your overlay with new translations here
|
||||
# You'll need to implement the logic to display the translations
|
||||
self.clear_all_text()
|
||||
self.text_entries = self.next_text_entries.copy()
|
||||
self.next_text_entries.clear()
|
||||
self.update()
|
||||
|
||||
def capture_behind(self, x=None, y=None, width=None, height=None):
|
||||
"""
|
||||
Capture the screen area behind the overlay.
|
||||
If no coordinates provided, captures the area under the window.
|
||||
"""
|
||||
# Temporarily hide the window
|
||||
self.hide()
|
||||
|
||||
# Get screen
|
||||
screen = QScreen.grabWindow(
|
||||
self.screen(),
|
||||
0,
|
||||
x if x is not None else self.x(),
|
||||
y if y is not None else self.y(),
|
||||
width if width is not None else self.width(),
|
||||
height if height is not None else self.height()
|
||||
)
|
||||
|
||||
# Show the window again
|
||||
self.show()
|
||||
screen_bytes = qpixmap_to_bytes(screen)
|
||||
return screen_bytes
|
||||
|
||||
def clear_all_text(self):
|
||||
"""Clear all text entries"""
|
||||
self.text_entries.clear()
|
||||
self.update()
|
||||
|
||||
def setup_window_attributes(self):
|
||||
# Set window flags for overlay behavior
|
||||
self.setWindowFlags(
|
||||
Qt.WindowType.FramelessWindowHint |
|
||||
Qt.WindowType.WindowStaysOnTopHint |
|
||||
Qt.WindowType.Tool
|
||||
)
|
||||
|
||||
# Set attributes for transparency
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground)
|
||||
|
||||
# Make the window cover the entire screen
|
||||
self.setGeometry(QApplication.primaryScreen().geometry())
|
||||
|
||||
# Special handling for Wayland
|
||||
if platform.system() == "Linux":
|
||||
if "WAYLAND_DISPLAY" in os.environ:
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_X11NetWmWindowTypeCombo)
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_DontCreateNativeAncestors)
|
||||
|
||||
def setup_shortcuts(self):
|
||||
# Toggle visibility (Alt+Shift+T)
|
||||
self.toggle_visibility_shortcut = QShortcut(QKeySequence("Alt+Shift+T"), self)
|
||||
self.toggle_visibility_shortcut.activated.connect(self.toggle_visibility)
|
||||
|
||||
# Toggle passthrough mode (Alt+Shift+P)
|
||||
self.toggle_passthrough_shortcut = QShortcut(QKeySequence("Alt+Shift+P"), self)
|
||||
self.toggle_passthrough_shortcut.activated.connect(self.toggle_passthrough)
|
||||
|
||||
# Quick hide (Escape)
|
||||
self.hide_shortcut = QShortcut(QKeySequence("Esc"), self)
|
||||
self.hide_shortcut.activated.connect(self.hide)
|
||||
|
||||
# Clear all text (Alt+Shift+C)
|
||||
self.clear_shortcut = QShortcut(QKeySequence("Alt+Shift+C"), self)
|
||||
self.clear_shortcut.activated.connect(self.clear_all_text)
|
||||
|
||||
# Toggle background
|
||||
self.toggle_background_shortcut = QShortcut(QKeySequence("Alt+Shift+B"), self)
|
||||
self.toggle_background_shortcut.activated.connect(self.toggle_background)
|
||||
|
||||
def toggle_visibility(self):
|
||||
if self.isVisible():
|
||||
self.hide()
|
||||
else:
|
||||
self.show()
|
||||
|
||||
def toggle_passthrough(self):
|
||||
self.is_passthrough = not self.is_passthrough
|
||||
if self.is_passthrough:
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents)
|
||||
if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ:
|
||||
self.setWindowFlags(self.windowFlags() | Qt.WindowType.X11BypassWindowManagerHint)
|
||||
else:
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents, False)
|
||||
if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ:
|
||||
self.setWindowFlags(self.windowFlags() & ~Qt.WindowType.X11BypassWindowManagerHint)
|
||||
|
||||
self.hide()
|
||||
self.show()
|
||||
|
||||
def toggle_background(self):
|
||||
"""Toggle background visibility"""
|
||||
self.show_background = not self.show_background
|
||||
self.update()
|
||||
|
||||
def set_background_opacity(self, opacity: float):
|
||||
"""Set background opacity (0.0 to 1.0)"""
|
||||
self.background_opacity = max(0.0, min(1.0, opacity))
|
||||
self.update()
|
||||
|
||||
|
||||
|
||||
def add_text_at_position(self, x: int, y: int, text: str):
|
||||
"""Add new text at specific coordinates"""
|
||||
entry = TextEntry(text, x, y)
|
||||
self.text_entries.append(entry)
|
||||
self.update()
|
||||
|
||||
def update_text_at_position(self, x: int, y: int, text: str):
|
||||
"""Update text at specific coordinates, or add if none exists"""
|
||||
# Look for existing text entry near these coordinates (within 5 pixels)
|
||||
for entry in self.text_entries:
|
||||
if abs(entry.x - x) <= 1 and abs(entry.y - y) <= 1:
|
||||
entry.text = text
|
||||
self.update()
|
||||
return
|
||||
|
||||
# If no existing entry found, add new one
|
||||
self.add_text_at_position(x, y, text)
|
||||
|
||||
def setup_tray(self):
|
||||
self.tray_icon = QSystemTrayIcon(self)
|
||||
self.tray_icon.setIcon(QIcon.fromTheme("applications-system"))
|
||||
|
||||
tray_menu = QMenu()
|
||||
|
||||
toggle_action = tray_menu.addAction("Show/Hide Overlay")
|
||||
toggle_action.triggered.connect(self.toggle_visibility)
|
||||
|
||||
toggle_passthrough = tray_menu.addAction("Toggle Passthrough")
|
||||
toggle_passthrough.triggered.connect(self.toggle_passthrough)
|
||||
|
||||
# Add background toggle to tray menu
|
||||
toggle_background = tray_menu.addAction("Toggle Background")
|
||||
toggle_background.triggered.connect(self.toggle_background)
|
||||
|
||||
clear_action = tray_menu.addAction("Clear All Text")
|
||||
clear_action.triggered.connect(self.clear_all_text)
|
||||
|
||||
tray_menu.addSeparator()
|
||||
|
||||
quit_action = tray_menu.addAction("Quit")
|
||||
quit_action.triggered.connect(self.clean_exit)
|
||||
|
||||
self.tray_icon.setToolTip("Translation Overlay")
|
||||
self.tray_icon.setContextMenu(tray_menu)
|
||||
self.tray_icon.show()
|
||||
self.tray_icon.activated.connect(self.tray_activated)
|
||||
|
||||
def remove_text_at_position(self, x: int, y: int):
|
||||
"""Remove text entry near specified coordinates"""
|
||||
self.text_entries = [
|
||||
entry for entry in self.text_entries
|
||||
if abs(entry.x - x) > 1 or abs(entry.y - y) > 1
|
||||
]
|
||||
self.update()
|
||||
|
||||
def paintEvent(self, event):
|
||||
painter = QPainter(self)
|
||||
painter.setRenderHint(QPainter.RenderHint.Antialiasing)
|
||||
|
||||
# Draw each text entry
|
||||
for entry in self.text_entries:
|
||||
if not entry.visible:
|
||||
continue
|
||||
|
||||
# Set the font for this specific entry
|
||||
painter.setFont(entry.font)
|
||||
text_metrics = painter.fontMetrics()
|
||||
|
||||
# Get the bounding rectangles for text
|
||||
text_bounds = text_metrics.boundingRect(
|
||||
entry.text
|
||||
)
|
||||
total_width = text_bounds.width()
|
||||
total_height = text_bounds.height()
|
||||
|
||||
# Create rectangles for text placement
|
||||
text_rect = QRect(entry.x, entry.y, total_width, total_height)
|
||||
# Calculate background rectangle that encompasses both texts
|
||||
if entry.background_color is not None:
|
||||
bg_rect = QRect(entry.x - entry.padding,
|
||||
entry.y - entry.padding,
|
||||
total_width + (2 * entry.padding),
|
||||
total_height + (2 * entry.padding))
|
||||
|
||||
painter.setPen(Qt.PenStyle.NoPen)
|
||||
painter.setBrush(entry.background_color)
|
||||
painter.drawRect(bg_rect)
|
||||
|
||||
# Draw the texts
|
||||
painter.setPen(entry.text_color)
|
||||
painter.drawText(text_rect, Qt.AlignmentFlag.AlignLeft, entry.text)
|
||||
|
||||
|
||||
|
||||
def handle_exit(signum, frame):
|
||||
QApplication.quit()
|
||||
|
||||
def start_overlay():
|
||||
app = QApplication(sys.argv)
|
||||
|
||||
# Enable Wayland support if available
|
||||
if platform.system() == "Linux" and "WAYLAND_DISPLAY" in os.environ:
|
||||
app.setProperty("platform", "wayland")
|
||||
|
||||
overlay = TranslationOverlay()
|
||||
|
||||
overlay.show()
|
||||
signal.signal(signal.SIGINT, handle_exit) # Handle Ctrl+C (KeyboardInterrupt)
|
||||
signal.signal(signal.SIGTERM, handle_exit)
|
||||
return (app, overlay)
|
||||
# sys.exit(app.exec())
|
||||
|
||||
if ADD_OVERLAY:
|
||||
app, overlay = start_overlay()
|
||||
|
||||
if __name__ == "__main__":
|
||||
ADD_OVERLAY = True
|
||||
if not ADD_OVERLAY:
|
||||
app, overlay = start_overlay()
|
||||
overlay.add_text_at_position(600, 100, "Hello World I AM A BIG FAAT FOROGGGGGGGGGG")
|
||||
capture = overlay.capture_behind()
|
||||
capture.save("capture.png")
|
||||
sys.exit(app.exec())
|
||||
111
draw.py
111
draw.py
@ -1,44 +1,52 @@
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import os,io, sys, numpy as np
|
||||
import os, io, sys, numpy as np
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
||||
from utils import romanize, intercepts, add_furigana
|
||||
from logging_config import logger
|
||||
from config import SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE, LINE_SPACING, TEXT_COLOR, LINE_HEIGHT, TO_ROMANIZE
|
||||
|
||||
from config import ADD_OVERLAY, SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE, LINE_SPACING, FONT_COLOUR, LINE_HEIGHT, TO_ROMANIZE, FILL_COLOUR, REGION
|
||||
|
||||
from PySide6.QtGui import QFont
|
||||
font = ImageFont.truetype(FONT_FILE, FONT_SIZE)
|
||||
|
||||
|
||||
|
||||
def modify_image_bytes(image_bytes: io.BytesIO, ocr_output, translation: list) -> bytes:
|
||||
# Load the image from bytes
|
||||
"""Modify the image bytes with the translated text and return the modified image bytes"""
|
||||
|
||||
with io.BytesIO(image_bytes) as byte_stream:
|
||||
image = Image.open(byte_stream)
|
||||
draw = ImageDraw.Draw(image)
|
||||
translate_image(draw, translation, ocr_output, MAX_TRANSLATE)
|
||||
draw_on_image(draw, translation, ocr_output, MAX_TRANSLATE)
|
||||
|
||||
# Save the modified image back to bytes without changing the format
|
||||
with io.BytesIO() as byte_stream:
|
||||
image.save(byte_stream, format=image.format) # Save in original format
|
||||
modified_image_bytes = byte_stream.getvalue()
|
||||
|
||||
return modified_image_bytes
|
||||
|
||||
def translate_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int) -> ImageDraw:
|
||||
def draw_on_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int, replace = False) -> ImageDraw:
|
||||
"""Draw the original, translated and optionally the romanisation of the texts on the image"""
|
||||
translated_number = 0
|
||||
bounding_boxes = []
|
||||
logger.debug(f"Translations: {len(translation)} {translation}")
|
||||
logger.debug(f"OCR output: {len(ocr_output)} {ocr_output}")
|
||||
for i, (position, untranslated_phrase, confidence) in enumerate(ocr_output):
|
||||
if translated_number >= max_translate:
|
||||
logger.debug(f"Untranslated phrase: {untranslated_phrase}")
|
||||
if translated_number >= max_translate - 1:
|
||||
break
|
||||
translate_one_phrase(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
if replace:
|
||||
draw = draw_one_phrase_replace(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
else:
|
||||
draw_one_phrase_add(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
translated_number += 1
|
||||
return draw
|
||||
|
||||
def translate_one_phrase(draw: ImageDraw, translated_phrase: str, position: tuple, bounding_boxes: list, untranslated_phrase: str) -> ImageDraw:
|
||||
# Draw the bounding box
|
||||
top_left, _, _, _ = position
|
||||
position = (top_left[0], top_left[1] - 60)
|
||||
def draw_one_phrase_add(draw: ImageDraw,
|
||||
translated_phrase: str,
|
||||
position: tuple, bounding_boxes: list,
|
||||
untranslated_phrase: str) -> ImageDraw:
|
||||
"""Draw the bounding box rectangle and text on the image above the original text"""
|
||||
|
||||
if SOURCE_LANG == 'ja':
|
||||
untranslated_phrase = add_furigana(untranslated_phrase)
|
||||
romanized_phrase = romanize(untranslated_phrase, 'ja')
|
||||
@ -50,26 +58,75 @@ def translate_one_phrase(draw: ImageDraw, translated_phrase: str, position: tupl
|
||||
text_content = f"{translated_phrase}\n{untranslated_phrase}"
|
||||
|
||||
lines = text_content.split('\n')
|
||||
x,y = position
|
||||
max_width = 0
|
||||
total_height = 0
|
||||
total_height = len(lines) * (LINE_HEIGHT + LINE_SPACING)
|
||||
for line in lines:
|
||||
bbox = draw.textbbox(position, line, font=font)
|
||||
line_width = bbox[2] - bbox[0]
|
||||
max_width = max(max_width, line_width)
|
||||
bounding_box = (x, y, x + max_width, y + total_height, untranslated_phrase)
|
||||
|
||||
adjust_if_intersects(x, y, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
|
||||
# Draw the bounding box
|
||||
top_left, _, _, _ = position
|
||||
max_width = get_max_width(lines, FONT_FILE, FONT_SIZE)
|
||||
total_height = get_max_height(lines, FONT_SIZE, LINE_SPACING)
|
||||
right_edge = REGION[2]
|
||||
|
||||
# Ensure the text is within the screen. P.S. Text on the edge may still be squished together if there are too many to translate
|
||||
x_onscreen = top_left[0] if top_left[0] + max_width <= right_edge else right_edge - max_width
|
||||
y_onscreen = max(top_left[1] - total_height, 0)
|
||||
bounding_box = (x_onscreen, y_onscreen, x_onscreen + max_width, y_onscreen + total_height, untranslated_phrase)
|
||||
|
||||
adjust_if_intersects(x_onscreen, y_onscreen, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
|
||||
|
||||
adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]
|
||||
draw.rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], outline="black", width=1)
|
||||
position = (adjusted_x,adjusted_y)
|
||||
for line in lines:
|
||||
draw.text(position, line, fill= TEXT_COLOR, font=font)
|
||||
draw.text(position, line, fill= FONT_COLOUR, font=font)
|
||||
if ADD_OVERLAY:
|
||||
overlay.add_next_text_at_position_no_update(position[0], position[1], line, text_color=FONT_COLOUR)
|
||||
adjusted_y += FONT_SIZE + LINE_SPACING
|
||||
position = (adjusted_x,adjusted_y)
|
||||
|
||||
def adjust_if_intersects(x: int, y: int, bounding_box: tuple, bounding_boxes: list, untranslated_phrase: str, max_width: int, total_height: int) -> tuple:
|
||||
|
||||
|
||||
|
||||
### Only support for horizontal text atm, vertical text is on the todo list
|
||||
def draw_one_phrase_replace(draw: ImageDraw,
|
||||
translated_phrase: str,
|
||||
position: tuple, bounding_boxes: list,
|
||||
untranslated_phrase: str) -> ImageDraw:
|
||||
"""Cover up old text and add translation directly on top"""
|
||||
# Draw the bounding box
|
||||
top_left, _, _, bottom_right = position
|
||||
max_width = bottom_right[0] - top_left[0]
|
||||
font_size = bottom_right[1] - top_left[1]
|
||||
draw.rectangle([top_left, bottom_right], fill=FILL_COLOUR)
|
||||
while True:
|
||||
font = ImageFont.truetype(FONT_FILE, font_size)
|
||||
if font.get_max_width < max_width:
|
||||
draw.text(top_left, translated_phrase, fill= FONT_COLOUR, font=font)
|
||||
break
|
||||
elif font_size <= 1:
|
||||
break
|
||||
else:
|
||||
font_size -= 1
|
||||
|
||||
def get_max_width(lines: list, font_path, font_size) -> int:
|
||||
"""Get the maximum width of the text lines"""
|
||||
font = ImageFont.truetype(font_path, font_size)
|
||||
max_width = 0
|
||||
dummy_image = Image.new("RGB", (1, 1))
|
||||
draw = ImageDraw.Draw(dummy_image)
|
||||
for line in lines:
|
||||
bbox = draw.textbbox((0,0), line, font=font)
|
||||
line_width = bbox[2] - bbox[0]
|
||||
max_width = max(max_width, line_width)
|
||||
return max_width
|
||||
|
||||
def get_max_height(lines: list, font_size, line_spacing) -> int:
|
||||
"""Get the maximum height of the text lines"""
|
||||
return len(lines) * (font_size + line_spacing)
|
||||
|
||||
def adjust_if_intersects(x: int, y: int,
|
||||
bounding_box: tuple, bounding_boxes: list,
|
||||
untranslated_phrase: str,
|
||||
max_width: int, total_height: int) -> tuple:
|
||||
"""Adjust the y coordinate if the bounding box intersects with any other bounding box"""
|
||||
y = np.max([y,0])
|
||||
if len(bounding_boxes) > 0:
|
||||
for box in bounding_boxes:
|
||||
@ -79,5 +136,3 @@ def adjust_if_intersects(x: int, y: int, bounding_box: tuple, bounding_boxes: li
|
||||
bounding_boxes.append(adjusted_bounding_box)
|
||||
return adjusted_bounding_box
|
||||
|
||||
|
||||
|
||||
|
||||
@ -1,37 +1,63 @@
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from typing import List, Dict
|
||||
from dotenv import load_dotenv
|
||||
import os , sys, torch, time
|
||||
from multiprocessing import Process, Event
|
||||
import os , sys, torch, time, ast
|
||||
from werkzeug.exceptions import TooManyRequests
|
||||
from multiprocessing import Process, Event, Value
|
||||
load_dotenv()
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
from config import device
|
||||
from config import device, GEMINI_API_KEY, GROQ_API_KEY
|
||||
from logging_config import logger
|
||||
from groq import Groq
|
||||
import google.generativeai as genai
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
|
||||
|
||||
class Gemini():
|
||||
def __init__(self, name, rate):
|
||||
self.name = name
|
||||
|
||||
class ApiModel():
|
||||
def __init__(self, model, # model name
|
||||
rate, # rate of calls per minute
|
||||
api_key, # api key for the model wrt the site
|
||||
):
|
||||
self.model = model
|
||||
self.rate = rate
|
||||
self.curr_calls = 0
|
||||
self.time = 0
|
||||
self.api_key = api_key
|
||||
self.curr_calls = Value('i', 0)
|
||||
self.time = Value('i', 0)
|
||||
self.process = None
|
||||
self.stop_event = Event()
|
||||
self.site = None
|
||||
self.from_lang = None
|
||||
self.target_lang = None
|
||||
self.request = None # request response from API
|
||||
|
||||
def __repr__(self):
|
||||
return f'Model: {self.name}; Rate: {self.rate}; Current_Calls: {self.curr_calls} calls; Time Passed: {self.time} seconds.'
|
||||
return f'{self.site} Model: {self.model}; Rate: {self.rate}; Current_Calls: {self.curr_calls.value} calls; Time Passed: {self.time.value} seconds.'
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
return self.model
|
||||
|
||||
def background_task(self):
|
||||
def set_lang(self, from_lang, target_lang):
|
||||
self.from_lang = from_lang
|
||||
self.target_lang = target_lang
|
||||
|
||||
### CHECK MINUTELY API RATES. For working with hourly rates and monthly will need to create another file. Also just unlikely those rates will be hit
|
||||
async def api_rate_check(self):
|
||||
# Background task to manage the rate of calls to the API
|
||||
while not self.stop_event.is_set():
|
||||
time.sleep(5)
|
||||
self.time += 5
|
||||
if self.time >= 60:
|
||||
self.time = 0
|
||||
self.curr_calls = 0
|
||||
start_time = time.monotonic()
|
||||
self.time.value += 5
|
||||
if self.time.value >= 60:
|
||||
self.time.value = 0
|
||||
self.curr_calls.value = 0
|
||||
elapsed = time.monotonic() - start_time
|
||||
# Sleep for exactly 5 seconds minus the elapsed time
|
||||
sleep_time = max(0, 5 - elapsed)
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
def background_task(self):
|
||||
asyncio.run(self.api_rate_check())
|
||||
|
||||
def start(self):
|
||||
# Start the background task
|
||||
@ -49,6 +75,68 @@ class Gemini():
|
||||
if self.process.is_alive():
|
||||
self.process.terminate()
|
||||
|
||||
def request_func(request):
|
||||
@wraps(request)
|
||||
def wrapper(self, text, *args, **kwargs):
|
||||
if self.curr_calls.value < self.rate:
|
||||
# try:
|
||||
response = request(self, text, *args, **kwargs)
|
||||
self.curr_calls.value += 1
|
||||
return response
|
||||
# except Exception as e:
|
||||
#logger.error(f"Error with model {self.model} from {self.site}. Error: {e}")
|
||||
else:
|
||||
logger.error(f"Rate limit reached for this model. Please wait for the rate to reset in {60 - self.time} seconds.")
|
||||
raise TooManyRequests('Rate limit reached for this model.')
|
||||
return wrapper
|
||||
|
||||
@request_func
|
||||
def translate(self, request_fn, texts_to_translate):
|
||||
if len(texts_to_translate) == 0:
|
||||
return []
|
||||
prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {self.from_lang} into {self.target_lang} and output as a Python list ensuring proper escaping of characters: {texts_to_translate}"
|
||||
response = request_fn(self, prompt)
|
||||
return ast.literal_eval(response.strip())
|
||||
|
||||
class Groqq(ApiModel):
|
||||
def __init__(self, model, rate, api_key = GROQ_API_KEY):
|
||||
super().__init__(model, rate, api_key)
|
||||
self.site = "Groq"
|
||||
|
||||
def request(self, content):
|
||||
client = Groq()
|
||||
chat_completion = client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
}
|
||||
],
|
||||
model=self.model
|
||||
)
|
||||
return chat_completion.choices[0].message.content
|
||||
|
||||
def translate(self, texts_to_translate):
|
||||
return super().translate(Groqq.request, texts_to_translate)
|
||||
|
||||
class Gemini(ApiModel):
|
||||
def __init__(self, model, rate, api_key = GEMINI_API_KEY):
|
||||
super().__init__(model, rate, api_key)
|
||||
self.site = "Gemini"
|
||||
|
||||
def request(self, content):
|
||||
genai.configure(api_key=self.api_key)
|
||||
safety_settings = {
|
||||
"HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
|
||||
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
|
||||
response = genai.GenerativeModel(self.model).generate_content(content, safety_settings=safety_settings)
|
||||
return response.text.strip()
|
||||
|
||||
def translate(self, texts_to_translate):
|
||||
return super().translate(Gemini.request, texts_to_translate)
|
||||
|
||||
class TranslationDataset(Dataset):
|
||||
def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
|
||||
"""
|
||||
@ -174,9 +262,11 @@ def generate_text(
|
||||
return all_generated_texts
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: translate one phrase through each provider.
    # NOTE(review): these local names shadow the module-level
    # GROQ_API_KEY / GEMINI_API_KEY used as class __init__ defaults;
    # harmless here only because the key is passed explicitly below.
    GROQ_API_KEY = os.getenv('GROQ_API_KEY')
    GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
    groq = Groqq('gemma-7b-it', 15, GROQ_API_KEY)
    groq.set_lang('zh','en')
    gemini = Gemini('gemini-1.5-pro', 15, GEMINI_API_KEY)
    gemini.set_lang('zh','en')
    print(gemini.translate(['荷兰咯']))
    print(groq.translate(['荷兰咯']))
|
||||
@ -1,15 +1,17 @@
|
||||
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig, AutoModelForCausalLM
|
||||
import google.generativeai as genai
|
||||
import torch, os, sys, ast
|
||||
import torch, os, sys, ast, json
|
||||
from utils import standardize_lang
|
||||
from functools import wraps
|
||||
from batching import generate_text, Gemini
|
||||
import random
|
||||
import batching
|
||||
from batching import generate_text, Gemini, Groq
|
||||
from logging_config import logger
|
||||
from multiprocessing import Process,Event
|
||||
# root dir
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models
|
||||
from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_API_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models
|
||||
|
||||
##############################
|
||||
# translation decorator
|
||||
@ -27,36 +29,30 @@ def translate(translation_func):
|
||||
|
||||
|
||||
###############################
|
||||
def init_GEMINI(models_and_rates = None):
|
||||
if not models_and_rates:
|
||||
## this is default for free tier
|
||||
models_and_rates = {'gemini-1.5-pro': 2, 'gemini-1.5-flash': 15, 'gemini-1.5-flash-8b': 8, 'gemini-1.0-pro': 15} # order from most pref to least pref
|
||||
models = [Gemini(name, rate) for name, rate in models_and_rates.items()]
|
||||
def init_API_LLM(from_lang, target_lang):
    """Build, start and configure every API-backed translation model.

    Model classes and their rate limits are read from api_models.json,
    whose top-level keys name classes in the `batching` module (e.g.
    "Gemini", "Groqq") and whose values map model name -> rate.

    Args:
        from_lang: source language in any form `standardize_lang` accepts.
        target_lang: target language in any form `standardize_lang` accepts.

    Returns:
        List of started model instances with languages configured.
    """
    from_lang = standardize_lang(from_lang)['translation_model_lang']
    target_lang = standardize_lang(target_lang)['translation_model_lang']
    # Explicit encoding: JSON must be read as UTF-8 regardless of locale.
    with open('api_models.json', 'r', encoding='utf-8') as f:
        models_and_rates = json.load(f)
    models = []
    for class_type, class_models in models_and_rates.items():
        # Unknown class names in the JSON fail loudly here (AttributeError).
        cls = getattr(batching, class_type)
        models.extend(cls(model, rate) for model, rate in class_models.items())

    for model in models:
        model.start()  # presumably starts the rate-limit bookkeeping thread — TODO confirm
        model.set_lang(from_lang, target_lang)
    return models
|
||||
|
||||
def translate_GEMINI(text, models, from_lang, target_lang):
|
||||
safety_settings = {
|
||||
"HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
|
||||
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
|
||||
prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {from_lang} into {target_lang} and output as a Python list ensuring proper escaping of characters: {text}"
|
||||
def translate_API_LLM(text, models):
    """Translate `text` (a list of strings) with the first API model that succeeds.

    Models are tried in random order to spread load across providers.
    Returns the translated list, or None when every model fails (callers
    must be prepared for None).
    """
    # Iterate a shuffled *copy* — the original code shuffled the caller's
    # list in place, a surprising side effect.
    for model in random.sample(models, len(models)):
        try:
            return model.translate(text)
        except Exception as e:
            # Narrowed from a bare `except:` (which also swallowed
            # KeyboardInterrupt/SystemExit) and the failure is now logged.
            logger.warning(f"API model failed, trying next. Error: {e}")
            continue
    logger.error("All models have failed to translate the text.")
|
||||
|
||||
###############################
|
||||
# Best model by far. Aya-23-8B. Gemma is relatively good. If I get the time to quantize either gemma or aya those will be good to use. llama3.2 is really good as well.
|
||||
@ -127,12 +123,6 @@ def init_Seq_LLM(model_type, **kwargs): # model = 'opus' or 'm2m'
|
||||
else:
|
||||
raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.")
|
||||
|
||||
def init_API_LLM(model_type, **kwargs): # model = 'gemma'
|
||||
if model_type == 'gemini':
|
||||
return init_GEMINI(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Invalid model type. Please use {' or '.join(api_llm_models)}.")
|
||||
|
||||
def init_Causal_LLM(model_type, **kwargs):
    """Placeholder: causal-LM initialization is not implemented yet."""
    pass
|
||||
###
|
||||
@ -155,23 +145,8 @@ def translate_Seq_LLM(text,
|
||||
|
||||
|
||||
### if you want to use any other translation, just define a translate function with input text and output text.
|
||||
# def translate_api(text):
|
||||
#@translate
|
||||
#def translate_Causal_LLM(text, model_type, model)
|
||||
|
||||
@translate
|
||||
def translate_API_LLM(text: list[str],
|
||||
model_type: str, # 'gemma'
|
||||
models: list, # list of objects of classes defined in batching.py
|
||||
from_lang: str, # suggested to use ISO 639-1 codes
|
||||
target_lang: str # suggested to use ISO 639-1 codes
|
||||
) -> list[str]:
|
||||
if model_type == 'gemini':
|
||||
from_lang = standardize_lang(from_lang)['translation_model_lang']
|
||||
target_lang = standardize_lang(target_lang)['translation_model_lang']
|
||||
return translate_GEMINI(text, models, from_lang, target_lang)
|
||||
else:
|
||||
raise ValueError(f"Invalid model. Please use {' or '.join(api_llm_models)}.")
|
||||
#def translate_Causal_LLM(text, model_type, model)
|
||||
|
||||
@translate
|
||||
def translate_Causal_LLM(text: list[str],
|
||||
@ -211,7 +186,5 @@ def translate_func(model):
|
||||
|
||||
### todo: if cuda is not detected, default to online translation as cpu just won't cut it bro. Parallel process it over multiple websites to make it faster
|
||||
if __name__ == "__main__":
    # Smoke test for the API-LLM path (Japanese -> English).
    # model, tokenizer = init_M2M()
    # print(translate_Seq_LLM( ['想要借用合成台。', '有什么卖的?', '声音太大会有什么影响吗?', '不怕丢东西吗?', '再见。', '布纳马', '想买什么自己拿。把钱留在旁边就好。', '回顾', '隐藏'], 'm2m', model, tokenizer, from_lang='ch_sim', target_lang='en'))
    models = init_API_LLM('ja', 'en')
    print(translate_API_LLM(['こんにちは'], models))
|
||||
@ -64,4 +64,4 @@ def setup_logger(
|
||||
print(f"Failed to setup logger: {e}")
|
||||
return None
|
||||
|
||||
# Module-wide logger; DEBUG level, presumably while the overlay feature is
# under development — TODO confirm before release.
logger = setup_logger('on_screen_translator', log_file='translate.log', level=logging.DEBUG)
|
||||
115
qtapp.py
Normal file
115
qtapp.py
Normal file
@ -0,0 +1,115 @@
|
||||
###################################################################################
|
||||
##### IMPORT LIBRARIES #####
|
||||
import os, time, sys
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
||||
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
|
||||
from utils import printsc, convert_image_to_bytes, bytes_to_image
|
||||
from ocr import get_words, init_OCR, id_keep_source_lang
|
||||
from logging_config import logger
|
||||
from draw import modify_image_bytes
|
||||
from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL, FONT_SIZE, FONT_FILE, FONT_COLOUR
|
||||
from create_overlay import app, overlay
|
||||
from typing import Optional, List
|
||||
###################################################################################
|
||||
from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QThread, Signal
|
||||
from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont,
|
||||
QColor, QIcon, QImage, QPixmap)
|
||||
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
|
||||
QLabel, QSystemTrayIcon, QMenu)
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class TranslationThread(QThread):
    """Background worker: screen capture -> OCR -> translate -> notify overlay."""

    translation_ready = Signal(list, list)  # Signal to send translation results (ocr_output, translations)
    start_capture = Signal()  # overlay should hide itself before the screenshot
    end_capture = Signal()    # overlay may restore itself after the screenshot
    # NOTE(review): screen_capture is declared (and connected in main()) but
    # never emitted anywhere in this class — confirm whether it is still needed.
    screen_capture = Signal(int, int, int, int)

    def __init__(self, ocr, models, source_lang, target_lang, interval):
        super().__init__()
        self.ocr = ocr                  # initialized OCR engine from init_OCR()
        self.models = models            # list of API translation models from init_API_LLM()
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.interval = interval        # seconds to sleep between capture cycles
        self.running = True             # cleared by stop() to end the run() loop
        self.prev_words = set()         # words seen on the previous capture
        self.runs = 0                   # cycle counter, used for logging only

    def run(self):
        # Capture/translate loop; runs until stop() clears self.running.
        while self.running:
            # Hide the overlay so its own text is not captured in the screenshot.
            self.start_capture.emit()
            untranslated_image = printsc(REGION)
            self.end_capture.emit()
            byte_image = convert_image_to_bytes(untranslated_image)
            # OCR, keeping only text identified as the source language.
            ocr_output = id_keep_source_lang(self.ocr, byte_image, self.source_lang)

            if self.runs == 0:
                logger.info('Initial run')
            else:
                logger.info(f'Run number: {self.runs}.')
            self.runs += 1

            curr_words = set(get_words(ocr_output))

            # Only re-translate when the visible text actually changed.
            if self.prev_words != curr_words:
                logger.info('Translating')
                to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
                translation = translate_API_LLM(to_translate, self.models)
                logger.info(f'Translation from {to_translate} to\n {translation}')

                # Emit the translation results
                # NOTE(review): translate_API_LLM may return None when every
                # model fails — confirm modify_image_bytes / the overlay
                # tolerate that.
                modify_image_bytes(byte_image, ocr_output, translation)
                self.translation_ready.emit(ocr_output, translation)

                self.prev_words = curr_words
            else:
                logger.info("No new words to translate. Output will not refresh.")

            logger.info(f'Sleeping for {self.interval} seconds')
            time.sleep(self.interval)

    def stop(self):
        # Ask the loop to finish; callers should wait() on the thread afterwards.
        self.running = False
|
||||
|
||||
|
||||
|
||||
def main():
    """Wire up OCR, translation models, the worker thread and the Qt overlay,
    then run the Qt event loop. Returns the application's exit code."""

    # Initialize OCR
    OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
    ocr = init_OCR(model=OCR_MODEL, easy_languages=OCR_LANGUAGES, use_GPU=OCR_USE_GPU)

    # Initialize translation
    models = init_API_LLM(SOURCE_LANG, TARGET_LANG)

    # Create and start translation thread
    translation_thread = TranslationThread(
        ocr=ocr,
        models=models,
        source_lang=SOURCE_LANG,
        target_lang=TARGET_LANG,
        interval=INTERVAL
    )

    # Connect translation results to overlay update.
    # These connections are made before start() so no early signal is missed.
    translation_thread.start_capture.connect(overlay.prepare_for_capture)
    translation_thread.end_capture.connect(overlay.restore_after_capture)
    translation_thread.translation_ready.connect(overlay.update_translation)
    # NOTE(review): screen_capture is never emitted by TranslationThread.run()
    # in this file — confirm whether this connection is dead code.
    translation_thread.screen_capture.connect(overlay.capture_behind)
    # Start the translation thread
    translation_thread.start()

    # Start Qt event loop
    result = app.exec()

    # Cleanup: stop the worker and join it before returning.
    translation_thread.stop()
    translation_thread.wait()

    return result
|
||||
|
||||
if __name__ == "__main__":
    # Run the overlay app and propagate its exit code to the shell.
    sys.exit(main())
|
||||
6
web.py
6
web.py
@ -1,12 +1,12 @@
|
||||
from flask import Flask, Response, render_template
|
||||
import threading
|
||||
import io
|
||||
import translate
|
||||
import app
|
||||
app = Flask(__name__)
|
||||
|
||||
# Global variable to hold the current image
|
||||
def curr_image():
    """Return the most recent rendered frame for the web stream."""
    # NOTE(review): `app = Flask(__name__)` above shadows the imported `app`
    # module, so this attribute lookup hits the Flask object, not the qtapp
    # module, and will raise AttributeError — rename one of the two.
    return app.latest_image
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
@ -29,7 +29,7 @@ def stream_image():
|
||||
|
||||
if __name__ == '__main__':
    # Start the image updating thread
    # NOTE(review): `app` here is the Flask instance, which shadows the
    # imported `app` module — `app.main` will raise AttributeError; rename
    # the module import or the Flask variable.
    threading.Thread(target=app.main, daemon=True).start()

    # Start the Flask web server
    # NOTE(review): debug=True enables the werkzeug reloader/debugger;
    # not suitable for anything beyond local development.
    app.run(host='0.0.0.0', port=5000, debug=True)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user