From 499a2c39723a29a70e51fb7b8efe259c0c0c1c62 Mon Sep 17 00:00:00 2001
From: chickenflyshigh
Date: Mon, 4 Nov 2024 15:36:50 +1100
Subject: [PATCH] Added support for Groq API models. Created Qt6 overlay app.

---
 .gitignore             |   3 +-
 README.md              |   2 +-
 api_models.json        |  13 ++
 translate.py => app.py |  25 +++-
 config.py              |  23 +++-
 create_overlay.py      | 305 +++++++++++++++++++++++++++++++++++++++++
 draw.py                | 111 +++++++++++----
 helpers/batching.py    | 134 +++++++++++++++---
 helpers/translation.py |  81 ++++-------
 logging_config.py      |   2 +-
 qtapp.py               | 115 ++++++++++++++++
 web.py                 |   6 +-
 12 files changed, 699 insertions(+), 121 deletions(-)
 create mode 100644 api_models.json
 rename translate.py => app.py (82%)
 create mode 100644 create_overlay.py
 create mode 100644 qtapp.py

diff --git a/.gitignore b/.gitignore
index 93e590d..d32e2e0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,5 @@ translate/
 __pycache__/
 .*
 test.py
-
+notebooks/
+qttest.py
\ No newline at end of file
diff --git a/README.md b/README.md
index 0f89a85..b442cc4 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
 ## Debugging Issues
-1. CUDNN Version mismatch when using PaddleOCR. Check if LD_LIBRARY_PATH is correctly set to the directory containing the cudnn.so file. If using a local installation, it could help to just remove nvidia-pip cudnn from python environment.
+1. cuDNN version mismatch when using PaddleOCR. Check that LD_LIBRARY_PATH points to the directory containing the cudnn.so file. If using a local installation, it can help to remove the nvidia-cudnn-cu12 pip package from the Python environment.
 2. Segmentation fault when using PaddleOCR, EasyOCR or RapidOCR. Ensure the only cv2 library is the opencv-contrib-python library. Check out https://pypi.org/project/opencv-python-headless/ for more info. 
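Both debugging items above can be sanity-checked before launching with a small script along these lines (illustrative sketch only, not shipped in this patch):

    # check_env.py -- hypothetical helper, not included in the patch
    import os
    import cv2

    # 1. cuDNN: LD_LIBRARY_PATH must include the directory holding libcudnn.so*
    print("LD_LIBRARY_PATH:", os.environ.get("LD_LIBRARY_PATH", "<unset>"))

    # 2. cv2 provider: the import should resolve inside opencv-contrib-python,
    #    not opencv-python or opencv-python-headless
    print("cv2 resolved from:", cv2.__file__)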
diff --git a/api_models.json b/api_models.json
new file mode 100644
index 0000000..1d04faf
--- /dev/null
+++ b/api_models.json
@@ -0,0 +1,13 @@
+{
+    "Gemini": {
+        "gemini-1.5-pro": 2,
+        "gemini-1.5-flash": 15,
+        "gemini-1.5-flash-8b": 8,
+        "gemini-1.0-pro": 15
+    },
+    "Groqq": {
+        "llama-3.2-90b-text-preview": 30,
+        "llama3-70b-8192": 30,
+        "mixtral-8x7b-32768": 30
+    }
+}
diff --git a/translate.py b/app.py
similarity index 82%
rename from translate.py
rename to app.py
index de71c60..5309916 100644
--- a/translate.py
+++ b/app.py
@@ -9,9 +9,11 @@ from utils import printsc, convert_image_to_bytes, bytes_to_image
 from ocr import get_words, init_OCR, id_keep_source_lang
 from logging_config import logger
 from draw import modify_image_bytes
-from config import SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL
+from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL
 ###################################################################################

+ADD_OVERLAY = False  # forced off here: this entry point feeds web.py; the Qt overlay path lives in qtapp.py
+
 latest_image = None

 def main():
@@ -23,11 +25,21 @@ def main():

     ##### Initialize the translation #####
     # model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
-    models = init_API_LLM(TRANSLATION_MODEL)
+    models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
     ###################################################################################
     runs = 0
     while True:
+        if ADD_OVERLAY:
+            overlay.clear_all_text()
+
         untranslated_image = printsc(REGION)
+
+        if ADD_OVERLAY:
+            overlay.text_entries = overlay.text_entries_copy.copy()  # copy, so clear() below cannot empty the snapshot
+            overlay.update()
+            overlay.text_entries.clear()
+
         byte_image = convert_image_to_bytes(untranslated_image)
         ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG) # keep only phrases containing the source language
@@ -46,7 +58,7 @@ def main():
             to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
             # translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
-            translation = translate_API_LLM(to_translate, TRANSLATION_MODEL, models, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
+            translation = translate_API_LLM(to_translate, models)
             logger.info(f'Translation from {to_translate} to\n {translation}')
             translated_image = modify_image_bytes(byte_image, ocr_output, translation)
             latest_image = bytes_to_image(translated_image)
@@ -58,11 +70,14 @@ def main():
         logger.info(f'Sleeping for {INTERVAL} seconds')
         time.sleep(INTERVAL)
-
+    # if ADD_OVERLAY:
+    #     sys.exit(app.exec())
+
################### TODO ##################
# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
# 5. Maybe refreshing issue of flask app. Also get webpage to update only if the image changes.
# Create a way for it to just replace the text and provide only the translation on-screen (Qt6).

 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    sys.exit(main())
\ No newline at end of file
diff --git a/config.py b/config.py
index 3230746..d0998e3 100644
--- a/config.py
+++ b/config.py
@@ -6,6 +6,7 @@ load_dotenv(override=True)
 ### EDIT THESE VARIABLES ###
 ### available languages: 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
+
 INTERVAL = int(os.getenv('INTERVAL'))

 ### OCR
@@ -13,24 +14,34 @@ OCR_MODEL = os.getenv('OCR_MODEL', 'easy') # 'easy', 'paddle', 'rapid' ### easy
 OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True'))

 ### Drawing/Overlay Config
+ADD_OVERLAY = ast.literal_eval(os.getenv('ADD_OVERLAY', 'True'))
+FILL_COLOUR = os.getenv('FILL_COLOUR', 'white')
 FONT_FILE = os.getenv('FONT_FILE')
 FONT_SIZE = int(os.getenv('FONT_SIZE', 16))
 LINE_SPACING = int(os.getenv('LINE_SPACING', 3))
 REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)'))
-TEXT_COLOR = os.getenv('TEXT_COLOR', "#ff0000")
+FONT_COLOUR = os.getenv('FONT_COLOUR', "#ff0000")
 TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True'))

+
+# API KEYS: https://github.com/cheahjs/free-llm-api-resources?tab=readme-ov-file
+GEMINI_API_KEY = os.getenv('GEMINI_KEY')
+GROQ_API_KEY = os.getenv('GROQ_API_KEY')
+# MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY') # https://console.mistral.ai/api-keys/ -- too slow to be useful
+
 ### Translation
-BATCH_SIZE = int(os.getenv('BATCH_SIZE', 6))
-GEMINI_KEY = os.getenv('GEMINI_KEY')
-LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False'))
-MAX_INPUT_TOKENS = int(os.getenv('MAX_INPUT_TOKENS', 512))
-MAX_OUTPUT_TOKENS = int(os.getenv('MAX_OUTPUT_TOKENS', 512))
 MAX_TRANSLATE = int(os.getenv('MAX_TRANSLATION', 200))
 SOURCE_LANG = os.getenv('SOURCE_LANG', 'ja')
 TARGET_LANG = os.getenv('TARGET_LANG', 'en')

+
+### Local Translation
 TRANSLATION_MODEL= os.environ['TRANSLATION_MODEL'] # 'opus' or 'm2m' # opus is a lot more lightweight
 TRANSLATION_USE_GPU = ast.literal_eval(os.getenv('TRANSLATION_USE_GPU', 'True'))
+MAX_INPUT_TOKENS = int(os.getenv('MAX_INPUT_TOKENS', 512))
+MAX_OUTPUT_TOKENS = int(os.getenv('MAX_OUTPUT_TOKENS', 512))
+BATCH_SIZE = int(os.getenv('BATCH_SIZE', 6))
+LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False'))

###################################################################################################
diff --git a/create_overlay.py b/create_overlay.py
new file mode 100644
index 0000000..91c8556
--- /dev/null
+++ b/create_overlay.py
@@ -0,0 +1,305 @@
+from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QBuffer
+from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont, QScreen, QIcon)
+from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
+                               QLabel, QSystemTrayIcon, QMenu)
+import sys, io, os, signal, time, platform
+from PIL import Image
+from dataclasses import dataclass, field
+from typing import List, Optional
+from config import ADD_OVERLAY, FONT_FILE, FONT_SIZE
+from logging_config import logger
+
+
+def qpixmap_to_bytes(qpixmap):
+    """Encode a QPixmap as PNG and return the raw bytes."""
+    qimage = qpixmap.toImage()
+    buffer = QBuffer()
+    buffer.open(QBuffer.ReadWrite)
+    qimage.save(buffer, "PNG")
+    return bytes(buffer.data())  # was returning the QImage, despite the function's name
+
+@dataclass
+class TextEntry:
+    text: str
+    x: int
+    y: int
+    font: QFont = field(default_factory=lambda: QFont('Arial', FONT_SIZE))  # fresh QFont per entry instead of one shared default
+    visible: bool = True
+    text_color: Qt.GlobalColor = Qt.GlobalColor.red
+    background_color: Optional[Qt.GlobalColor] = None
+    padding: int = 1
+
+class TranslationOverlay(QMainWindow):
+    def __init__(self):
+        super().__init__()
+        self.setWindowTitle("Translation Overlay")
+        self.is_passthrough = True
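+        # Passthrough: mouse clicks fall through the overlay to the window
+        # beneath; toggled at runtime with Alt+Shift+P (see setup_shortcuts)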
+        self.text_entries: List[TextEntry] = []
+        self.setup_window_attributes()
+        self.setup_shortcuts()
+        self.closeEvent = lambda event: QApplication.quit()
+        self.default_font = QFont('Arial', FONT_SIZE)
+        self.text_entries_copy: List[TextEntry] = []
+        self.next_text_entries: List[TextEntry] = []
+        self.show_background = True  # must exist before the first toggle_background call
+        self.background_opacity = 0.5
+        # self.setup_tray()
+
+    def prepare_for_capture(self):
+        """Preserve current state and clear overlay"""
+        if ADD_OVERLAY:
+            self.text_entries_copy = self.text_entries.copy()
+            self.clear_all_text()
+            self.update()
+
+    def restore_after_capture(self):
+        """Restore overlay state after capture"""
+        if ADD_OVERLAY:
+            logger.debug(f'Text entries copy during initial phase of restore_after_capture: {self.text_entries_copy}')
+            self.text_entries = self.text_entries_copy.copy()
+            logger.debug(f"Restored text entries: {self.text_entries}")
+            self.update()
+
+    def add_next_text_at_position_no_update(self, x: int, y: int, text: str,
+                                            font: Optional[QFont] = None, text_color: Qt.GlobalColor = Qt.GlobalColor.red):
+        """Add new text without triggering update"""
+        entry = TextEntry(
+            text=text,
+            x=x,
+            y=y,
+            font=font or self.default_font,
+            text_color=text_color
+        )
+        self.next_text_entries.append(entry)
+
+    def update_translation(self, ocr_output, translation):
+        """Swap in the entries queued during the last draw pass and repaint."""
+        self.clear_all_text()
+        self.text_entries = self.next_text_entries.copy()
+        self.next_text_entries.clear()
+        self.update()
+
+    def capture_behind(self, x=None, y=None, width=None, height=None):
+        """
+        Capture the screen area behind the overlay as PNG bytes.
+        If no coordinates are provided, captures the area under the window.
+        """
+        # Temporarily hide the window
+        self.hide()
+
+        # Grab the screen region the overlay was covering
+        screen = QScreen.grabWindow(
+            self.screen(),
+            0,
+            x if x is not None else self.x(),
+            y if y is not None else self.y(),
+            width if width is not None else self.width(),
+            height if height is not None else self.height()
+        )
+
+        # Show the window again
+        self.show()
+        screen_bytes = qpixmap_to_bytes(screen)
+        return screen_bytes
+
+    def clear_all_text(self):
+        """Clear all text entries"""
+        self.text_entries.clear()
+        self.update()
+
+    def setup_window_attributes(self):
+        # Set window flags for overlay behavior
+        self.setWindowFlags(
+            Qt.WindowType.FramelessWindowHint |
+            Qt.WindowType.WindowStaysOnTopHint |
+            Qt.WindowType.Tool
+        )
+
+        # Set attributes for transparency
+        self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground)
+        # Start in passthrough mode, matching is_passthrough = True above
+        self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents)
+
+        # Make the window cover the entire screen
+        self.setGeometry(QApplication.primaryScreen().geometry())
+
+        # Special handling for Wayland
+        if platform.system() == "Linux":
+            if "WAYLAND_DISPLAY" in os.environ:
+                self.setAttribute(Qt.WidgetAttribute.WA_X11NetWmWindowTypeCombo)
+                self.setAttribute(Qt.WidgetAttribute.WA_DontCreateNativeAncestors)
+
+    def setup_shortcuts(self):
+        # Toggle visibility (Alt+Shift+T)
+        self.toggle_visibility_shortcut = QShortcut(QKeySequence("Alt+Shift+T"), self)
+        self.toggle_visibility_shortcut.activated.connect(self.toggle_visibility)
+
+        # Toggle passthrough mode (Alt+Shift+P)
+        self.toggle_passthrough_shortcut = QShortcut(QKeySequence("Alt+Shift+P"), self)
+        self.toggle_passthrough_shortcut.activated.connect(self.toggle_passthrough)
+
+        # Quick hide (Escape)
+        self.hide_shortcut = QShortcut(QKeySequence("Esc"), self)
+        self.hide_shortcut.activated.connect(self.hide)
+
+        # Clear all text (Alt+Shift+C)
+        self.clear_shortcut = QShortcut(QKeySequence("Alt+Shift+C"), self)
QShortcut(QKeySequence("Alt+Shift+C"), self) + self.clear_shortcut.activated.connect(self.clear_all_text) + + # Toggle background + self.toggle_background_shortcut = QShortcut(QKeySequence("Alt+Shift+B"), self) + self.toggle_background_shortcut.activated.connect(self.toggle_background) + + def toggle_visibility(self): + if self.isVisible(): + self.hide() + else: + self.show() + + def toggle_passthrough(self): + self.is_passthrough = not self.is_passthrough + if self.is_passthrough: + self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents) + if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ: + self.setWindowFlags(self.windowFlags() | Qt.WindowType.X11BypassWindowManagerHint) + else: + self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents, False) + if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ: + self.setWindowFlags(self.windowFlags() & ~Qt.WindowType.X11BypassWindowManagerHint) + + self.hide() + self.show() + + def toggle_background(self): + """Toggle background visibility""" + self.show_background = not self.show_background + self.update() + + def set_background_opacity(self, opacity: float): + """Set background opacity (0.0 to 1.0)""" + self.background_opacity = max(0.0, min(1.0, opacity)) + self.update() + + + + def add_text_at_position(self, x: int, y: int, text: str): + """Add new text at specific coordinates""" + entry = TextEntry(text, x, y) + self.text_entries.append(entry) + self.update() + + def update_text_at_position(self, x: int, y: int, text: str): + """Update text at specific coordinates, or add if none exists""" + # Look for existing text entry near these coordinates (within 5 pixels) + for entry in self.text_entries: + if abs(entry.x - x) <= 1 and abs(entry.y - y) <= 1: + entry.text = text + self.update() + return + + # If no existing entry found, add new one + self.add_text_at_position(x, y, text) + + def setup_tray(self): + self.tray_icon = QSystemTrayIcon(self) + self.tray_icon.setIcon(QIcon.fromTheme("applications-system")) + + tray_menu = QMenu() + + toggle_action = tray_menu.addAction("Show/Hide Overlay") + toggle_action.triggered.connect(self.toggle_visibility) + + toggle_passthrough = tray_menu.addAction("Toggle Passthrough") + toggle_passthrough.triggered.connect(self.toggle_passthrough) + + # Add background toggle to tray menu + toggle_background = tray_menu.addAction("Toggle Background") + toggle_background.triggered.connect(self.toggle_background) + + clear_action = tray_menu.addAction("Clear All Text") + clear_action.triggered.connect(self.clear_all_text) + + tray_menu.addSeparator() + + quit_action = tray_menu.addAction("Quit") + quit_action.triggered.connect(self.clean_exit) + + self.tray_icon.setToolTip("Translation Overlay") + self.tray_icon.setContextMenu(tray_menu) + self.tray_icon.show() + self.tray_icon.activated.connect(self.tray_activated) + + def remove_text_at_position(self, x: int, y: int): + """Remove text entry near specified coordinates""" + self.text_entries = [ + entry for entry in self.text_entries + if abs(entry.x - x) > 1 or abs(entry.y - y) > 1 + ] + self.update() + + def paintEvent(self, event): + painter = QPainter(self) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + + # Draw each text entry + for entry in self.text_entries: + if not entry.visible: + continue + + # Set the font for this specific entry + painter.setFont(entry.font) + text_metrics = painter.fontMetrics() + + # Get the bounding rectangles for text + text_bounds = text_metrics.boundingRect( + 
+            total_width = text_bounds.width()
+            total_height = text_bounds.height()
+
+            # Rectangle for text placement
+            text_rect = QRect(entry.x, entry.y, total_width, total_height)
+            # Background rectangle that encloses the text plus padding
+            if entry.background_color is not None:
+                bg_rect = QRect(entry.x - entry.padding,
+                                entry.y - entry.padding,
+                                total_width + (2 * entry.padding),
+                                total_height + (2 * entry.padding))
+
+                painter.setPen(Qt.PenStyle.NoPen)
+                painter.setBrush(entry.background_color)
+                painter.drawRect(bg_rect)
+
+            # Draw the text
+            painter.setPen(entry.text_color)
+            painter.drawText(text_rect, Qt.AlignmentFlag.AlignLeft, entry.text)
+
+
+def handle_exit(signum, frame):
+    QApplication.quit()
+
+def start_overlay():
+    app = QApplication(sys.argv)
+
+    # Enable Wayland support if available
+    if platform.system() == "Linux" and "WAYLAND_DISPLAY" in os.environ:
+        app.setProperty("platform", "wayland")
+
+    overlay = TranslationOverlay()
+
+    overlay.show()
+    signal.signal(signal.SIGINT, handle_exit)   # Handle Ctrl+C (KeyboardInterrupt)
+    signal.signal(signal.SIGTERM, handle_exit)
+    return (app, overlay)
+
+if ADD_OVERLAY:
+    app, overlay = start_overlay()
+
+if __name__ == "__main__":
+    # Manual test: create the overlay here if config did not already do so above
+    if not ADD_OVERLAY:
+        app, overlay = start_overlay()
+    overlay.add_text_at_position(600, 100, "Hello World I AM A BIG FAAT FOROGGGGGGGGGG")
+    capture = overlay.capture_behind()
+    with open("capture.png", "wb") as f:  # capture_behind returns PNG bytes
+        f.write(capture)
+    sys.exit(app.exec())
\ No newline at end of file
diff --git a/draw.py b/draw.py
index a83bce7..9cf231e 100644
--- a/draw.py
+++ b/draw.py
@@ -1,44 +1,52 @@
 from PIL import Image, ImageDraw, ImageFont
-import os,io, sys, numpy as np
+import os, io, sys, numpy as np
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
 from utils import romanize, intercepts, add_furigana
 from logging_config import logger
-from config import SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE, LINE_SPACING, TEXT_COLOR, LINE_HEIGHT, TO_ROMANIZE
-
+from config import ADD_OVERLAY, SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE, LINE_SPACING, FONT_COLOUR, LINE_HEIGHT, TO_ROMANIZE, FILL_COLOUR, REGION
+from PySide6.QtGui import QFont, QColor
+if ADD_OVERLAY:
+    from create_overlay import overlay  # draw_one_phrase_add mirrors its text onto the Qt overlay

 font = ImageFont.truetype(FONT_FILE, FONT_SIZE)

 def modify_image_bytes(image_bytes: io.BytesIO, ocr_output, translation: list) -> bytes:
-    # Load the image from bytes
+    """Modify the image bytes with the translated text and return the modified image bytes"""
+
     with io.BytesIO(image_bytes) as byte_stream:
         image = Image.open(byte_stream)
         draw = ImageDraw.Draw(image)
-        translate_image(draw, translation, ocr_output, MAX_TRANSLATE)
+        draw_on_image(draw, translation, ocr_output, MAX_TRANSLATE)

         # Save the modified image back to bytes without changing the format
         with io.BytesIO() as byte_stream:
             image.save(byte_stream, format=image.format) # Save in original format
             modified_image_bytes = byte_stream.getvalue()
     return modified_image_bytes

-def translate_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int) -> ImageDraw:
+def draw_on_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int, replace = False) -> ImageDraw:
+    """Draw the original, translated and optionally the romanisation of the texts on the image"""
     translated_number = 0
     bounding_boxes = []
+    logger.debug(f"Translations: {len(translation)} {translation}")
+    logger.debug(f"OCR output: {len(ocr_output)} {ocr_output}")
     for i, (position, untranslated_phrase, confidence) in enumerate(ocr_output):
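+        # each ocr_output entry is (corner points, recognised text, confidence)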
-        if translated_number >= max_translate:
+        logger.debug(f"Untranslated phrase: {untranslated_phrase}")
+        if translated_number >= min(max_translate, len(translation)):  # also guards translation[i] if the API returned fewer items
             break
-        translate_one_phrase(draw, translation[i], position, bounding_boxes, untranslated_phrase)
+        if replace:
+            draw_one_phrase_replace(draw, translation[i], position, bounding_boxes, untranslated_phrase)
+        else:
+            draw_one_phrase_add(draw, translation[i], position, bounding_boxes, untranslated_phrase)
         translated_number += 1
     return draw

-def translate_one_phrase(draw: ImageDraw, translated_phrase: str, position: tuple, bounding_boxes: list, untranslated_phrase: str) -> ImageDraw:
-    # Draw the bounding box
-    top_left, _, _, _ = position
-    position = (top_left[0], top_left[1] - 60)
+def draw_one_phrase_add(draw: ImageDraw,
+                        translated_phrase: str,
+                        position: tuple, bounding_boxes: list,
+                        untranslated_phrase: str) -> ImageDraw:
+    """Draw the bounding box rectangle and text on the image above the original text"""
+
     if SOURCE_LANG == 'ja':
         untranslated_phrase = add_furigana(untranslated_phrase)
         romanized_phrase = romanize(untranslated_phrase, 'ja')
@@ -50,26 +58,75 @@ def translate_one_phrase(draw: ImageDraw, translated_phrase: str, position: tupl
         text_content = f"{translated_phrase}\n{untranslated_phrase}"

     lines = text_content.split('\n')
-    x,y = position
-    max_width = 0
-    total_height = 0
-    total_height = len(lines) * (LINE_HEIGHT + LINE_SPACING)
-    for line in lines:
-        bbox = draw.textbbox(position, line, font=font)
-        line_width = bbox[2] - bbox[0]
-        max_width = max(max_width, line_width)
-    bounding_box = (x, y, x + max_width, y + total_height, untranslated_phrase)
-    adjust_if_intersects(x, y, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
+    # Draw the bounding box
+    top_left, _, _, _ = position
+    max_width = get_max_width(lines, FONT_FILE, FONT_SIZE)
+    total_height = get_max_height(lines, FONT_SIZE, LINE_SPACING)
+    right_edge = REGION[2]
+
+    # Ensure the text stays on screen.
+    # Note: text at the edge may still be squished together if there are too many boxes to place.
+    x_onscreen = top_left[0] if top_left[0] + max_width <= right_edge else right_edge - max_width
+    y_onscreen = max(top_left[1] - total_height, 0)
+    bounding_box = (x_onscreen, y_onscreen, x_onscreen + max_width, y_onscreen + total_height, untranslated_phrase)
+
+    adjust_if_intersects(x_onscreen, y_onscreen, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
+
     adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]
     draw.rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], outline="black", width=1)
     position = (adjusted_x,adjusted_y)

     for line in lines:
-        draw.text(position, line, fill= TEXT_COLOR, font=font)
+        draw.text(position, line, fill= FONT_COLOUR, font=font)
+        if ADD_OVERLAY:
+            overlay.add_next_text_at_position_no_update(position[0], position[1], line, text_color=QColor(FONT_COLOUR))  # QColor: the overlay painter expects a colour, not a hex string
         adjusted_y += FONT_SIZE + LINE_SPACING
         position = (adjusted_x,adjusted_y)

+
+### Only horizontal text is supported for now; vertical text is on the todo list
+def draw_one_phrase_replace(draw: ImageDraw,
+                            translated_phrase: str,
+                            position: tuple, bounding_boxes: list,
+                            untranslated_phrase: str) -> ImageDraw:
+    """Cover up the old text and draw the translation directly on top"""
+    top_left, _, _, bottom_right = position
+    max_width = bottom_right[0] - top_left[0]
+    font_size = bottom_right[1] - top_left[1]
+    draw.rectangle([top_left, bottom_right], fill=FILL_COLOUR)
+    # Shrink the font until the translation fits inside the original box
+    while font_size >= 1:
+        fitted_font = ImageFont.truetype(FONT_FILE, font_size)
+        if fitted_font.getlength(translated_phrase) <= max_width:  # was `font.get_max_width < max_width`, comparing a method object to an int
+            draw.text(top_left, translated_phrase, fill= FONT_COLOUR, font=fitted_font)
+            break
+        font_size -= 1
+
+def get_max_width(lines: list, font_path, font_size) -> int:
+    """Get the maximum width of the text lines"""
+    font = ImageFont.truetype(font_path, font_size)
+    max_width = 0
+    dummy_image = Image.new("RGB", (1, 1))
+    draw = ImageDraw.Draw(dummy_image)
+    for line in lines:
+        bbox = draw.textbbox((0,0), line, font=font)
+        line_width = bbox[2] - bbox[0]
+        max_width = max(max_width, line_width)
+    return max_width
+
+def get_max_height(lines: list, font_size, line_spacing) -> int:
+    """Get the total height of the text lines"""
+    return len(lines) * (font_size + line_spacing)
+
+def adjust_if_intersects(x: int, y: int,
+                         bounding_box: tuple, bounding_boxes: list,
+                         untranslated_phrase: str,
+                         max_width: int, total_height: int) -> tuple:
+    """Adjust the y coordinate if the bounding box intersects with any other bounding box"""
     y = np.max([y,0])
     if len(bounding_boxes) > 0:
         for box in bounding_boxes:
@@ -79,5 +136,3 @@ def adjust_if_intersects(x: int, y: int, bounding_box: tuple, bounding_boxes: li

     bounding_boxes.append(adjusted_bounding_box)
     return adjusted_bounding_box
-
-
diff --git a/helpers/batching.py b/helpers/batching.py
index 2a89f27..5c49437 100644
--- a/helpers/batching.py
+++ b/helpers/batching.py
@@ -1,37 +1,63 @@
 from torch.utils.data import Dataset, DataLoader
 from typing import List, Dict
 from dotenv import load_dotenv
-import os , sys, torch, time
-from multiprocessing import Process, Event
+import os, sys, torch, time, ast
+from werkzeug.exceptions import TooManyRequests
+from multiprocessing import Process, Event, Value

 load_dotenv()
 sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
-from config import device
+from config import device, GEMINI_API_KEY, GROQ_API_KEY
 from logging_config import logger
+from groq import Groq
+import google.generativeai as genai
+import asyncio
+from functools import wraps

-class Gemini():
-    def __init__(self, name, rate):
-        self.name = name
+
+class ApiModel():
+    def __init__(self, model,    # model name
+                 rate,           # allowed calls per minute
+                 api_key,        # API key for the provider this model belongs to
+                 ):
+        self.model = model
         self.rate = rate
-        self.curr_calls = 0
-        self.time = 0
+        self.api_key = api_key
+        self.curr_calls = Value('i', 0)  # shared with the background rate-check process
+        self.time = Value('i', 0)
         self.process = None
         self.stop_event = Event()
+        self.site = None
+        self.from_lang = None
+        self.target_lang = None
+        # (no self.request attribute here: it would shadow the subclasses' request() method)

     def __repr__(self):
-        return f'Model: {self.name}; Rate: {self.rate}; Current_Calls: {self.curr_calls} calls; Time Passed: {self.time} seconds.'
+        return f'{self.site} Model: {self.model}; Rate: {self.rate}; Current_Calls: {self.curr_calls.value} calls; Time Passed: {self.time.value} seconds.'

     def __str__(self):
-        return self.name
+        return self.model

-    def background_task(self):
+    def set_lang(self, from_lang, target_lang):
+        self.from_lang = from_lang
+        self.target_lang = target_lang
+
+    ### Tracks per-minute API rates only. Hourly or monthly quotas would need separate bookkeeping, but those limits are unlikely to be hit.
+    async def api_rate_check(self):
         # Background task to manage the rate of calls to the API
         while not self.stop_event.is_set():
-            time.sleep(5)
-            self.time += 5
-            if self.time >= 60:
-                self.time = 0
-                self.curr_calls = 0
+            start_time = time.monotonic()
+            self.time.value += 5
+            if self.time.value >= 60:
+                self.time.value = 0
+                self.curr_calls.value = 0
+            elapsed = time.monotonic() - start_time
+            # Sleep for five seconds minus the time spent above
+            sleep_time = max(0, 5 - elapsed)
+            await asyncio.sleep(sleep_time)
+
+    def background_task(self):
+        asyncio.run(self.api_rate_check())

     def start(self):
         # Start the background task
@@ -49,6 +75,68 @@
         if self.process.is_alive():
             self.process.terminate()

+    def request_func(request):
+        @wraps(request)
+        def wrapper(self, text, *args, **kwargs):
+            if self.curr_calls.value < self.rate:
+                response = request(self, text, *args, **kwargs)
+                self.curr_calls.value += 1
+                return response
+            else:
+                logger.error(f"Rate limit reached for this model. Please wait for the rate to reset in {60 - self.time.value} seconds.")
+                raise TooManyRequests('Rate limit reached for this model.')
+        return wrapper
+
+    @request_func
+    def translate(self, request_fn, texts_to_translate):
+        if len(texts_to_translate) == 0:
+            return []
+        prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {self.from_lang} into {self.target_lang} and output as a Python list ensuring proper escaping of characters: {texts_to_translate}"
+        response = request_fn(self, prompt)
+        return ast.literal_eval(response.strip())
+
+class Groqq(ApiModel):
+    def __init__(self, model, rate, api_key = GROQ_API_KEY):
+        super().__init__(model, rate, api_key)
+        self.site = "Groq"
+
+    def request(self, content):
+        client = Groq(api_key=self.api_key)  # pass the stored key instead of relying on the environment
+        chat_completion = client.chat.completions.create(
+            messages=[
+                {
+                    "role": "user",
+                    "content": content,
+                }
+            ],
+            model=self.model
+        )
+        return chat_completion.choices[0].message.content
+
+    def translate(self, texts_to_translate):
+        return super().translate(Groqq.request, texts_to_translate)
+
+class Gemini(ApiModel):
+    def __init__(self, model, rate, api_key = GEMINI_API_KEY):
+        super().__init__(model, rate, api_key)
+        self.site = "Gemini"
+
+    def request(self, content):
+        genai.configure(api_key=self.api_key)
+        safety_settings = {
+            "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
+            "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
+            "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
+            "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
+        response = genai.GenerativeModel(self.model).generate_content(content, safety_settings=safety_settings)
+        return response.text.strip()
+
+    def translate(self, texts_to_translate):
+        return super().translate(Gemini.request, texts_to_translate)
+
 class TranslationDataset(Dataset):
     def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
         """
@@ -174,9 +262,11 @@
     return all_generated_texts

 if __name__ == '__main__':
-    from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
-    model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=True).to(device)
-    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=True)
-    tokenizer.src_lang = "zh"
-    texts = ["你好","我"]
-    print(generate_text(texts,model, tokenizer, forced_bos_token_id=tokenizer.get_lang_id('en')))
+    GROQ_API_KEY = os.getenv('GROQ_API_KEY')
+    GEMINI_API_KEY = os.getenv('GEMINI_KEY')  # config.py reads the key from the GEMINI_KEY variable
+    groq = Groqq('gemma-7b-it', 15, GROQ_API_KEY)
+    groq.set_lang('zh','en')
+    gemini = Gemini('gemini-1.5-pro', 15, GEMINI_API_KEY)
+    gemini.set_lang('zh','en')
+    print(gemini.translate(['荷兰咯']))
+    print(groq.translate(['荷兰咯']))
\ No newline at end of file
diff --git a/helpers/translation.py b/helpers/translation.py
index 5ded3b9..cbb2a4a 100644
--- a/helpers/translation.py
+++ b/helpers/translation.py
@@ -1,15 +1,17 @@
 from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig, AutoModelForCausalLM
 import google.generativeai as genai
-import torch, os, sys, ast
+import torch, os, sys, ast, json
 from utils import standardize_lang
 from functools import wraps
-from batching import generate_text, Gemini
+import random
+import batching
+from batching import generate_text, Gemini, Groqq
 from logging_config import logger
 from multiprocessing import Process,Event

 # root dir
 sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
-from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models
+from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_API_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models

 ##############################
 # translation decorator
@@ -27,36 +29,30 @@ def translate(translation_func):

 ###############################

-def init_GEMINI(models_and_rates = None):
-    if not models_and_rates:
-        ## this is default for free tier
-        models_and_rates = {'gemini-1.5-pro': 2, 'gemini-1.5-flash': 15, 'gemini-1.5-flash-8b': 8, 'gemini-1.0-pro': 15} # order from most pref to least pref
-    models = [Gemini(name, rate) for name, rate in models_and_rates.items()]
+def init_API_LLM(from_lang, target_lang):
+    from_lang = standardize_lang(from_lang)['translation_model_lang']
+    target_lang = standardize_lang(target_lang)['translation_model_lang']
+    with open('api_models.json', 'r') as f:
+        models_and_rates = json.load(f)
+    models = []
+    for class_type, class_models in models_and_rates.items():
+        cls = getattr(batching, class_type)  # class names in api_models.json must match those in batching.py
+        instantiated_objects = [cls(model, rate) for model, rate in class_models.items()]
+        models.extend(instantiated_objects)
+
     for model in models:
         model.start()
-    genai.configure(api_key=GEMINI_KEY)
+        model.set_lang(from_lang, target_lang)
     return models

-def translate_GEMINI(text, models, from_lang, target_lang):
-    safety_settings = {
-        "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
-        "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
-        "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
-        "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
-    prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {from_lang} into {target_lang} and output as a Python list ensuring proper escaping of characters: {text}"
+def translate_API_LLM(text, models):
+    random.shuffle(models)
     for model in models:
-        if model.curr_calls < model.rate:
-            try:
-                response = genai.GenerativeModel(model.name).generate_content(prompt,
-                                                safety_settings=safety_settings)
-                model.curr_calls += 1
-                logger.info(repr(model))
-                logger.info(f'Model Response: {response.text.strip()}')
-                return ast.literal_eval(response.text.strip())
-            except Exception as e:
-                logger.error(f"Error with model {model.name}. Error: {e}")
-    logger.error("No models available to translate. Please wait for a model to be available.")
-
+        try:
+            return model.translate(text)
+        except Exception as e:  # rate limit hit, or a provider/parse error; try the next model
+            logger.warning(f"Model {model} failed: {e}")
+            continue
+    logger.error("All models have failed to translate the text.")
+    return []  # callers expect a list

 ###############################
 # Best model by far. Aya-23-8B. Gemma is relatively good. If I get the time to quantize either gemma or aya those will be good to use. llama3.2 is really good as well.
@@ -127,12 +123,6 @@ def init_Seq_LLM(model_type, **kwargs): # model = 'opus' or 'm2m'
     else:
         raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.")

-def init_API_LLM(model_type, **kwargs): # model = 'gemma'
-    if model_type == 'gemini':
-        return init_GEMINI(**kwargs)
-    else:
-        raise ValueError(f"Invalid model type. Please use {' or '.join(api_llm_models)}.")
-
 def init_Causal_LLM(model_type, **kwargs):
     pass
 ###
@@ -155,23 +145,8 @@ def translate_Seq_LLM(text,

 ### if you want to use any other translation, just define a translate function with input text and output text, for example as sketched below.
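+# Sketch of a custom backend (hypothetical names, adjust to taste): it only needs
+# to map a list of source strings to a list of translated strings, wrapped with
+# the @translate decorator like the functions in this module:
+#
+#   @translate
+#   def translate_MyAPI(text: list[str], client) -> list[str]:
+#       return [client.translate(t) for t in text]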
-# def translate_api(text):
-#@translate
-#def translate_Causal_LLM(text, model_type, model)
-@translate
-def translate_API_LLM(text: list[str],
-                      model_type: str, # 'gemma'
-                      models: list, # list of objects of classes defined in batching.py
-                      from_lang: str, # suggested to use ISO 639-1 codes
-                      target_lang: str # suggested to use ISO 639-1 codes
-                      ) -> list[str]:
-    if model_type == 'gemini':
-        from_lang = standardize_lang(from_lang)['translation_model_lang']
-        target_lang = standardize_lang(target_lang)['translation_model_lang']
-        return translate_GEMINI(text, models, from_lang, target_lang)
-    else:
-        raise ValueError(f"Invalid model. Please use {' or '.join(api_llm_models)}.")
+#def translate_Causal_LLM(text, model_type, model)

 @translate
 def translate_Causal_LLM(text: list[str],
@@ -211,7 +186,5 @@ def translate_func(model):

 ### todo: if cuda is not detected, default to online translation as cpu just won't cut it bro. Parallel process it over multiple websites to make it faster

 if __name__ == "__main__":
-    models = init_GEMINI()
-    print(translate_API_LLM(['想要借用合成台。', '有什么卖的?', '声音太大会有什么影响吗?', '不怕丢东西吗?', '再见。', '布纳马', '想买什么自己拿。把钱留在旁边就好。', '回顾', '隐藏'], 'gemini', models, from_lang='ch_sim', target_lang='en'))
-    # model, tokenizer = init_M2M()
-    # print(translate_Seq_LLM( ['想要借用合成台。', '有什么卖的?', '声音太大会有什么影响吗?', '不怕丢东西吗?', '再见。', '布纳马', '想买什么自己拿。把钱留在旁边就好。', '回顾', '隐藏'], 'm2m', model, tokenizer, from_lang='ch_sim', target_lang='en'))
+    models = init_API_LLM('ja', 'en')
+    print(translate_API_LLM(['こんにちは'], models))
\ No newline at end of file
diff --git a/logging_config.py b/logging_config.py
index 3b74af8..c1dce08 100644
--- a/logging_config.py
+++ b/logging_config.py
@@ -64,4 +64,4 @@ def setup_logger(
         print(f"Failed to setup logger: {e}")
         return None

-logger = setup_logger('on_screen_translator', log_file='translate.log')
\ No newline at end of file
+logger = setup_logger('on_screen_translator', log_file='translate.log', level=logging.DEBUG)
\ No newline at end of file
diff --git a/qtapp.py b/qtapp.py
new file mode 100644
index 0000000..2d72364
--- /dev/null
+++ b/qtapp.py
@@ -0,0 +1,115 @@
+###################################################################################
+##### IMPORT LIBRARIES #####
+import os, time, sys
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
+from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
+from utils import printsc, convert_image_to_bytes, bytes_to_image
+from ocr import get_words, init_OCR, id_keep_source_lang
+from logging_config import logger
+from draw import modify_image_bytes
+from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL, FONT_SIZE, FONT_FILE, FONT_COLOUR
+from create_overlay import app, overlay  # create_overlay only builds these when ADD_OVERLAY=True, so this entry point requires it
+from typing import Optional, List
+###################################################################################
+from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QThread, Signal
+from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont,
+                           QColor, QIcon, QImage, QPixmap)
+from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
+                               QLabel, QSystemTrayIcon, QMenu)
+from dataclasses import dataclass
+
+
+class TranslationThread(QThread):
+    translation_ready = Signal(list, list)  # Signal to send translation results
+    start_capture = Signal()
+    end_capture = Signal()
+    screen_capture = Signal(int, int, int, int)
+    def __init__(self, ocr, models, source_lang, target_lang, interval):
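+        # Runs capture -> OCR -> translate off the GUI thread; results are
+        # handed back to the overlay through queued (thread-safe) Qt signals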
+        super().__init__()
+        self.ocr = ocr
+        self.models = models
+        self.source_lang = source_lang
+        self.target_lang = target_lang
+        self.interval = interval
+        self.running = True
+        self.prev_words = set()
+        self.runs = 0
+
+    def run(self):
+        while self.running:
+            self.start_capture.emit()
+            untranslated_image = printsc(REGION)
+            self.end_capture.emit()
+            byte_image = convert_image_to_bytes(untranslated_image)
+            ocr_output = id_keep_source_lang(self.ocr, byte_image, self.source_lang)
+
+            if self.runs == 0:
+                logger.info('Initial run')
+            else:
+                logger.info(f'Run number: {self.runs}.')
+            self.runs += 1
+
+            curr_words = set(get_words(ocr_output))
+
+            if self.prev_words != curr_words:
+                logger.info('Translating')
+                to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
+                translation = translate_API_LLM(to_translate, self.models)
+                logger.info(f'Translation from {to_translate} to\n {translation}')
+
+                # Called for its side effect: draw_on_image queues the overlay text entries
+                modify_image_bytes(byte_image, ocr_output, translation)
+                self.translation_ready.emit(ocr_output, translation)
+
+                self.prev_words = curr_words
+            else:
+                logger.info("No new words to translate. Output will not refresh.")
+
+            logger.info(f'Sleeping for {self.interval} seconds')
+            time.sleep(self.interval)
+
+    def stop(self):
+        self.running = False
+
+
+def main():
+
+    # Initialize OCR
+    OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
+    ocr = init_OCR(model=OCR_MODEL, easy_languages=OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
+
+    # Initialize translation
+    models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
+
+    # Create and start translation thread
+    translation_thread = TranslationThread(
+        ocr=ocr,
+        models=models,
+        source_lang=SOURCE_LANG,
+        target_lang=TARGET_LANG,
+        interval=INTERVAL
+    )
+
+    # Connect translation results to overlay updates
+    translation_thread.start_capture.connect(overlay.prepare_for_capture)
+    translation_thread.end_capture.connect(overlay.restore_after_capture)
+    translation_thread.translation_ready.connect(overlay.update_translation)
+    translation_thread.screen_capture.connect(overlay.capture_behind)
+
+    # Start the translation thread
+    translation_thread.start()
+
+    # Start Qt event loop
+    result = app.exec()
+
+    # Cleanup
+    translation_thread.stop()
+    translation_thread.wait()
+
+    return result
+
+if __name__ == "__main__":
+    sys.exit(main())
\ No newline at end of file
diff --git a/web.py b/web.py
index 5f544f5..0f7275e 100644
--- a/web.py
+++ b/web.py
@@ -1,12 +1,12 @@
 from flask import Flask, Response, render_template
 import threading
 import io
-import translate
+import app as screen_translator  # aliased: a bare `import app` would be shadowed by the Flask instance below

 app = Flask(__name__)

 # Global variable to hold the current image
 def curr_image():
-    return translate.latest_image
+    return screen_translator.latest_image

 @app.route('/')
 def index():
@@ -29,7 +29,7 @@ def stream_image():

 if __name__ == '__main__':
     # Start the image updating thread
-    threading.Thread(target=translate.main, daemon=True).start()
+    threading.Thread(target=screen_translator.main, daemon=True).start()
     # Start the Flask web server
     app.run(host='0.0.0.0', port=5000, debug=True)
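For reviewers, a minimal end-to-end sketch of how api_models.json, the ApiModel classes in helpers/batching.py and the fallback logic in translate_API_LLM compose (illustrative only; assumes GROQ_API_KEY and GEMINI_KEY are set and the repo root is the working directory):

    import json, random, sys
    sys.path.insert(0, 'helpers')
    import batching  # defines the Gemini and Groqq ApiModel subclasses

    # api_models.json maps a batching class name to {model_name: calls_per_minute}
    with open('api_models.json') as f:
        registry = json.load(f)

    models = []
    for class_name, entries in registry.items():
        cls = getattr(batching, class_name)  # JSON keys must match class names in batching.py
        models.extend(cls(model, rate) for model, rate in entries.items())

    for m in models:
        m.set_lang('ja', 'en')  # init_API_LLM normalises these via standardize_lang first
        m.start()               # background process that resets the per-minute call counter

    # translate_API_LLM shuffles the pool and falls back on rate limits or errors
    random.shuffle(models)
    for m in models:
        try:
            print(m.translate(['こんにちは']))
            break
        except Exception:
            continue  # rate-limited or failed; try the next provider/model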