Added support for Groqq API models. Created Qt6 overlay app.

This commit is contained in:
chickenflyshigh 2024-11-04 15:36:50 +11:00
parent 17e7f6526f
commit 499a2c3972
12 changed files with 699 additions and 121 deletions

3
.gitignore vendored
View File

@ -4,4 +4,5 @@ translate/
__pycache__/ __pycache__/
.* .*
test.py test.py
notebooks/
qttest.py

View File

@ -1,4 +1,4 @@
## Debugging Issues ## Debugging Issues
1. CUDNN Version mismatch when using PaddleOCR. Check if LD_LIBRARY_PATH is correctly set to the directory containing the cudnn.so file. If using a local installation, it could help to just remove nvidia-pip cudnn from python environment. 1. CUDNN Version mismatch when using PaddleOCR. Check if LD_LIBRARY_PATH is correctly set to the directory containing the cudnn.so file. If using a local installation, it could help to just remove nvidia-cudnn-cu12 from python environment.
2. Segmentation fault when using PaddleOCR, EasyOCR or RapidOCR. Ensure the only cv2 library is the opencv-contrib-python library. Check out https://pypi.org/project/opencv-python-headless/ for more info. 2. Segmentation fault when using PaddleOCR, EasyOCR or RapidOCR. Ensure the only cv2 library is the opencv-contrib-python library. Check out https://pypi.org/project/opencv-python-headless/ for more info.

13
api_models.json Normal file
View File

@ -0,0 +1,13 @@
{
"Gemini": {
"gemini-1.5-pro": 2,
"gemini-1.5-flash": 15,
"gemini-1.5-flash-8b": 8,
"gemini-1.0-pro": 15
},
"Groqq": {
"llama-3.2-90b-text-preview": 30,
"llama3-70b-8192": 30,
"mixtral-8x7b-32768": 30
}
}

View File

@ -9,9 +9,11 @@ from utils import printsc, convert_image_to_bytes, bytes_to_image
from ocr import get_words, init_OCR, id_keep_source_lang from ocr import get_words, init_OCR, id_keep_source_lang
from logging_config import logger from logging_config import logger
from draw import modify_image_bytes from draw import modify_image_bytes
from config import SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL
################################################################################### ###################################################################################
ADD_OVERLAY = False
latest_image = None latest_image = None
def main(): def main():
@ -23,11 +25,21 @@ def main():
##### Initialize the translation ##### ##### Initialize the translation #####
# model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG) # model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
models = init_API_LLM(TRANSLATION_MODEL) models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
################################################################################### ###################################################################################
runs = 0 runs = 0
app.exec()
while True: while True:
if ADD_OVERLAY:
overlay.clear_all_text()
untranslated_image = printsc(REGION) untranslated_image = printsc(REGION)
if ADD_OVERLAY:
overlay.text_entries = overlay.text_entries_copy
overlay.update()
overlay.text_entries.clear()
byte_image = convert_image_to_bytes(untranslated_image) byte_image = convert_image_to_bytes(untranslated_image)
ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG) # keep only phrases containing the source language ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG) # keep only phrases containing the source language
@ -46,7 +58,7 @@ def main():
to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE] to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
# translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG) # translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
translation = translate_API_LLM(to_translate, TRANSLATION_MODEL, models, from_lang = SOURCE_LANG, target_lang = TARGET_LANG) translation = translate_API_LLM(to_translate, models)
logger.info(f'Translation from {to_translate} to\n {translation}') logger.info(f'Translation from {to_translate} to\n {translation}')
translated_image = modify_image_bytes(byte_image, ocr_output, translation) translated_image = modify_image_bytes(byte_image, ocr_output, translation)
latest_image = bytes_to_image(translated_image) latest_image = bytes_to_image(translated_image)
@ -58,11 +70,14 @@ def main():
logger.info(f'Sleeping for {INTERVAL} seconds') logger.info(f'Sleeping for {INTERVAL} seconds')
time.sleep(INTERVAL) time.sleep(INTERVAL)
# if ADD_OVERLAY:
# sys.exit(app.exec())
################### TODO ################## ################### TODO ##################
# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models. # 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
# 5. Maybe refreshing issue of flask app. Also get webpage to update only if the image changes. # 5. Maybe refreshing issue of flask app. Also get webpage to update only if the image changes.
# Create a way for it to just replace the text and provide only the translation on-screen. Qt6 # Create a way for it to just replace the text and provide only the translation on-screen. Qt6
if __name__ == "__main__": if __name__ == "__main__":
main() sys.exit(main())

View File

@ -6,6 +6,7 @@ load_dotenv(override=True)
### EDIT THESE VARIABLES ### ### EDIT THESE VARIABLES ###
### available languages: 'ch_sim', 'ch_tra', 'ja', 'ko', 'en' ### available languages: 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
INTERVAL = int(os.getenv('INTERVAL')) INTERVAL = int(os.getenv('INTERVAL'))
### OCR ### OCR
@ -13,24 +14,34 @@ OCR_MODEL = os.getenv('OCR_MODEL', 'easy') # 'easy', 'paddle', 'rapid' ### easy
OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True')) OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True'))
### Drawing/Overlay Config ### Drawing/Overlay Config
ADD_OVERLAY = ast.literal_eval(os.getenv('ADD_OVERLAY', 'True'))
FILL_COLOUR = os.getenv('FILL_COLOUR', 'white')
FONT_FILE = os.getenv('FONT_FILE') FONT_FILE = os.getenv('FONT_FILE')
FONT_SIZE = int(os.getenv('FONT_SIZE', 16)) FONT_SIZE = int(os.getenv('FONT_SIZE', 16))
LINE_SPACING = int(os.getenv('LINE_SPACING', 3)) LINE_SPACING = int(os.getenv('LINE_SPACING', 3))
REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)')) REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)'))
TEXT_COLOR = os.getenv('TEXT_COLOR', "#ff0000") FONT_COLOUR = os.getenv('FONT_COLOUR', "#ff0000")
TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True')) TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True'))
# API KEYS https://github.com/cheahjs/free-llm-api-resources?tab=readme-ov-file
GEMINI_API_KEY = os.getenv('GEMINI_KEY')
GROQ_API_KEY = os.getenv('GROQ_API_KEY') #
# MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY') # https://console.mistral.ai/api-keys/ slow asf
### Translation ### Translation
BATCH_SIZE = int(os.getenv('BATCH_SIZE', 6))
GEMINI_KEY = os.getenv('GEMINI_KEY')
LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False'))
MAX_INPUT_TOKENS = int(os.getenv('MAX_INPUT_TOKENS', 512))
MAX_OUTPUT_TOKENS = int(os.getenv('MAX_OUTPUT_TOKENS', 512))
MAX_TRANSLATE = int(os.getenv('MAX_TRANSLATION', 200)) MAX_TRANSLATE = int(os.getenv('MAX_TRANSLATION', 200))
SOURCE_LANG = os.getenv('SOURCE_LANG', 'ja') SOURCE_LANG = os.getenv('SOURCE_LANG', 'ja')
TARGET_LANG = os.getenv('TARGET_LANG', 'en') TARGET_LANG = os.getenv('TARGET_LANG', 'en')
### Local Translation
TRANSLATION_MODEL= os.environ['TRANSLATION_MODEL'] # 'opus' or 'm2m' # opus is a lot more lightweight TRANSLATION_MODEL= os.environ['TRANSLATION_MODEL'] # 'opus' or 'm2m' # opus is a lot more lightweight
TRANSLATION_USE_GPU = ast.literal_eval(os.getenv('TRANSLATION_USE_GPU', 'True')) TRANSLATION_USE_GPU = ast.literal_eval(os.getenv('TRANSLATION_USE_GPU', 'True'))
MAX_INPUT_TOKENS = int(os.getenv('MAX_INPUT_TOKENS', 512))
MAX_OUTPUT_TOKENS = int(os.getenv('MAX_OUTPUT_TOKENS', 512))
BATCH_SIZE = int(os.getenv('BATCH_SIZE', 6))
LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False'))
################################################################################################### ###################################################################################################

305
create_overlay.py Normal file
View File

@ -0,0 +1,305 @@
from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QBuffer
from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont, QScreen, QIcon)
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
QLabel, QSystemTrayIcon, QMenu)
import sys, io, os, signal, time, platform
from PIL import Image
from dataclasses import dataclass
from typing import List, Optional
from config import ADD_OVERLAY, FONT_FILE, FONT_SIZE
from logging_config import logger
def qpixmap_to_bytes(qpixmap):
qimage = qpixmap.toImage()
buffer = QBuffer()
buffer.open(QBuffer.ReadWrite)
qimage.save(buffer, "PNG")
return qimage
@dataclass
class TextEntry:
text: str
x: int
y: int
font: QFont = QFont('Arial', FONT_SIZE)
visible: bool = True
text_color: Qt.GlobalColor = Qt.GlobalColor.red
background_color: Optional[Qt.GlobalColor] = None
padding: int = 1
class TranslationOverlay(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle("Translation Overlay")
self.is_passthrough = True
self.text_entries: List[TextEntry] = []
self.setup_window_attributes()
self.setup_shortcuts()
self.closeEvent = lambda event: QApplication.quit()
self.default_font = QFont('Arial', FONT_SIZE)
self.text_entries_copy: List[TextEntry] = []
self.next_text_entries: List[TextEntry] = []
#self.show_background = True
self.background_opacity = 0.5
# self.setup_tray()
def prepare_for_capture(self):
"""Preserve current state and clear overlay"""
if ADD_OVERLAY:
self.text_entries_copy = self.text_entries.copy()
self.clear_all_text()
self.update()
def restore_after_capture(self):
"""Restore overlay state after capture"""
if ADD_OVERLAY:
logger.debug(f'Text entries copy during initial phase of restore_after_capture: {self.text_entries_copy}')
self.text_entries = self.text_entries_copy.copy()
logger.debug(f"Restored text entries: {self.text_entries}")
self.update()
def add_next_text_at_position_no_update(self, x: int, y: int, text: str,
font: Optional[QFont] = None, text_color: Qt.GlobalColor = Qt.GlobalColor.red):
"""Add new text without triggering update"""
entry = TextEntry(
text=text,
x=x,
y=y,
font=font or self.default_font,
text_color=text_color
)
self.next_text_entries.append(entry)
def update_translation(self, ocr_output, translation):
# Update your overlay with new translations here
# You'll need to implement the logic to display the translations
self.clear_all_text()
self.text_entries = self.next_text_entries.copy()
self.next_text_entries.clear()
self.update()
def capture_behind(self, x=None, y=None, width=None, height=None):
"""
Capture the screen area behind the overlay.
If no coordinates provided, captures the area under the window.
"""
# Temporarily hide the window
self.hide()
# Get screen
screen = QScreen.grabWindow(
self.screen(),
0,
x if x is not None else self.x(),
y if y is not None else self.y(),
width if width is not None else self.width(),
height if height is not None else self.height()
)
# Show the window again
self.show()
screen_bytes = qpixmap_to_bytes(screen)
return screen_bytes
def clear_all_text(self):
"""Clear all text entries"""
self.text_entries.clear()
self.update()
def setup_window_attributes(self):
# Set window flags for overlay behavior
self.setWindowFlags(
Qt.WindowType.FramelessWindowHint |
Qt.WindowType.WindowStaysOnTopHint |
Qt.WindowType.Tool
)
# Set attributes for transparency
self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground)
# Make the window cover the entire screen
self.setGeometry(QApplication.primaryScreen().geometry())
# Special handling for Wayland
if platform.system() == "Linux":
if "WAYLAND_DISPLAY" in os.environ:
self.setAttribute(Qt.WidgetAttribute.WA_X11NetWmWindowTypeCombo)
self.setAttribute(Qt.WidgetAttribute.WA_DontCreateNativeAncestors)
def setup_shortcuts(self):
# Toggle visibility (Alt+Shift+T)
self.toggle_visibility_shortcut = QShortcut(QKeySequence("Alt+Shift+T"), self)
self.toggle_visibility_shortcut.activated.connect(self.toggle_visibility)
# Toggle passthrough mode (Alt+Shift+P)
self.toggle_passthrough_shortcut = QShortcut(QKeySequence("Alt+Shift+P"), self)
self.toggle_passthrough_shortcut.activated.connect(self.toggle_passthrough)
# Quick hide (Escape)
self.hide_shortcut = QShortcut(QKeySequence("Esc"), self)
self.hide_shortcut.activated.connect(self.hide)
# Clear all text (Alt+Shift+C)
self.clear_shortcut = QShortcut(QKeySequence("Alt+Shift+C"), self)
self.clear_shortcut.activated.connect(self.clear_all_text)
# Toggle background
self.toggle_background_shortcut = QShortcut(QKeySequence("Alt+Shift+B"), self)
self.toggle_background_shortcut.activated.connect(self.toggle_background)
def toggle_visibility(self):
if self.isVisible():
self.hide()
else:
self.show()
def toggle_passthrough(self):
self.is_passthrough = not self.is_passthrough
if self.is_passthrough:
self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents)
if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ:
self.setWindowFlags(self.windowFlags() | Qt.WindowType.X11BypassWindowManagerHint)
else:
self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents, False)
if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ:
self.setWindowFlags(self.windowFlags() & ~Qt.WindowType.X11BypassWindowManagerHint)
self.hide()
self.show()
def toggle_background(self):
"""Toggle background visibility"""
self.show_background = not self.show_background
self.update()
def set_background_opacity(self, opacity: float):
"""Set background opacity (0.0 to 1.0)"""
self.background_opacity = max(0.0, min(1.0, opacity))
self.update()
def add_text_at_position(self, x: int, y: int, text: str):
"""Add new text at specific coordinates"""
entry = TextEntry(text, x, y)
self.text_entries.append(entry)
self.update()
def update_text_at_position(self, x: int, y: int, text: str):
"""Update text at specific coordinates, or add if none exists"""
# Look for existing text entry near these coordinates (within 5 pixels)
for entry in self.text_entries:
if abs(entry.x - x) <= 1 and abs(entry.y - y) <= 1:
entry.text = text
self.update()
return
# If no existing entry found, add new one
self.add_text_at_position(x, y, text)
def setup_tray(self):
self.tray_icon = QSystemTrayIcon(self)
self.tray_icon.setIcon(QIcon.fromTheme("applications-system"))
tray_menu = QMenu()
toggle_action = tray_menu.addAction("Show/Hide Overlay")
toggle_action.triggered.connect(self.toggle_visibility)
toggle_passthrough = tray_menu.addAction("Toggle Passthrough")
toggle_passthrough.triggered.connect(self.toggle_passthrough)
# Add background toggle to tray menu
toggle_background = tray_menu.addAction("Toggle Background")
toggle_background.triggered.connect(self.toggle_background)
clear_action = tray_menu.addAction("Clear All Text")
clear_action.triggered.connect(self.clear_all_text)
tray_menu.addSeparator()
quit_action = tray_menu.addAction("Quit")
quit_action.triggered.connect(self.clean_exit)
self.tray_icon.setToolTip("Translation Overlay")
self.tray_icon.setContextMenu(tray_menu)
self.tray_icon.show()
self.tray_icon.activated.connect(self.tray_activated)
def remove_text_at_position(self, x: int, y: int):
"""Remove text entry near specified coordinates"""
self.text_entries = [
entry for entry in self.text_entries
if abs(entry.x - x) > 1 or abs(entry.y - y) > 1
]
self.update()
def paintEvent(self, event):
painter = QPainter(self)
painter.setRenderHint(QPainter.RenderHint.Antialiasing)
# Draw each text entry
for entry in self.text_entries:
if not entry.visible:
continue
# Set the font for this specific entry
painter.setFont(entry.font)
text_metrics = painter.fontMetrics()
# Get the bounding rectangles for text
text_bounds = text_metrics.boundingRect(
entry.text
)
total_width = text_bounds.width()
total_height = text_bounds.height()
# Create rectangles for text placement
text_rect = QRect(entry.x, entry.y, total_width, total_height)
# Calculate background rectangle that encompasses both texts
if entry.background_color is not None:
bg_rect = QRect(entry.x - entry.padding,
entry.y - entry.padding,
total_width + (2 * entry.padding),
total_height + (2 * entry.padding))
painter.setPen(Qt.PenStyle.NoPen)
painter.setBrush(entry.background_color)
painter.drawRect(bg_rect)
# Draw the texts
painter.setPen(entry.text_color)
painter.drawText(text_rect, Qt.AlignmentFlag.AlignLeft, entry.text)
def handle_exit(signum, frame):
QApplication.quit()
def start_overlay():
app = QApplication(sys.argv)
# Enable Wayland support if available
if platform.system() == "Linux" and "WAYLAND_DISPLAY" in os.environ:
app.setProperty("platform", "wayland")
overlay = TranslationOverlay()
overlay.show()
signal.signal(signal.SIGINT, handle_exit) # Handle Ctrl+C (KeyboardInterrupt)
signal.signal(signal.SIGTERM, handle_exit)
return (app, overlay)
# sys.exit(app.exec())
if ADD_OVERLAY:
app, overlay = start_overlay()
if __name__ == "__main__":
ADD_OVERLAY = True
if not ADD_OVERLAY:
app, overlay = start_overlay()
overlay.add_text_at_position(600, 100, "Hello World I AM A BIG FAAT FOROGGGGGGGGGG")
capture = overlay.capture_behind()
capture.save("capture.png")
sys.exit(app.exec())

111
draw.py
View File

@ -1,44 +1,52 @@
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
import os,io, sys, numpy as np import os, io, sys, numpy as np
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
from utils import romanize, intercepts, add_furigana from utils import romanize, intercepts, add_furigana
from logging_config import logger from logging_config import logger
from config import SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE, LINE_SPACING, TEXT_COLOR, LINE_HEIGHT, TO_ROMANIZE from config import ADD_OVERLAY, SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE, LINE_SPACING, FONT_COLOUR, LINE_HEIGHT, TO_ROMANIZE, FILL_COLOUR, REGION
from PySide6.QtGui import QFont
font = ImageFont.truetype(FONT_FILE, FONT_SIZE) font = ImageFont.truetype(FONT_FILE, FONT_SIZE)
def modify_image_bytes(image_bytes: io.BytesIO, ocr_output, translation: list) -> bytes: def modify_image_bytes(image_bytes: io.BytesIO, ocr_output, translation: list) -> bytes:
# Load the image from bytes """Modify the image bytes with the translated text and return the modified image bytes"""
with io.BytesIO(image_bytes) as byte_stream: with io.BytesIO(image_bytes) as byte_stream:
image = Image.open(byte_stream) image = Image.open(byte_stream)
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)
translate_image(draw, translation, ocr_output, MAX_TRANSLATE) draw_on_image(draw, translation, ocr_output, MAX_TRANSLATE)
# Save the modified image back to bytes without changing the format # Save the modified image back to bytes without changing the format
with io.BytesIO() as byte_stream: with io.BytesIO() as byte_stream:
image.save(byte_stream, format=image.format) # Save in original format image.save(byte_stream, format=image.format) # Save in original format
modified_image_bytes = byte_stream.getvalue() modified_image_bytes = byte_stream.getvalue()
return modified_image_bytes return modified_image_bytes
def translate_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int) -> ImageDraw: def draw_on_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int, replace = False) -> ImageDraw:
"""Draw the original, translated and optionally the romanisation of the texts on the image"""
translated_number = 0 translated_number = 0
bounding_boxes = [] bounding_boxes = []
logger.debug(f"Translations: {len(translation)} {translation}")
logger.debug(f"OCR output: {len(ocr_output)} {ocr_output}")
for i, (position, untranslated_phrase, confidence) in enumerate(ocr_output): for i, (position, untranslated_phrase, confidence) in enumerate(ocr_output):
if translated_number >= max_translate: logger.debug(f"Untranslated phrase: {untranslated_phrase}")
if translated_number >= max_translate - 1:
break break
translate_one_phrase(draw, translation[i], position, bounding_boxes, untranslated_phrase) if replace:
draw = draw_one_phrase_replace(draw, translation[i], position, bounding_boxes, untranslated_phrase)
else:
draw_one_phrase_add(draw, translation[i], position, bounding_boxes, untranslated_phrase)
translated_number += 1 translated_number += 1
return draw return draw
def translate_one_phrase(draw: ImageDraw, translated_phrase: str, position: tuple, bounding_boxes: list, untranslated_phrase: str) -> ImageDraw: def draw_one_phrase_add(draw: ImageDraw,
# Draw the bounding box translated_phrase: str,
top_left, _, _, _ = position position: tuple, bounding_boxes: list,
position = (top_left[0], top_left[1] - 60) untranslated_phrase: str) -> ImageDraw:
"""Draw the bounding box rectangle and text on the image above the original text"""
if SOURCE_LANG == 'ja': if SOURCE_LANG == 'ja':
untranslated_phrase = add_furigana(untranslated_phrase) untranslated_phrase = add_furigana(untranslated_phrase)
romanized_phrase = romanize(untranslated_phrase, 'ja') romanized_phrase = romanize(untranslated_phrase, 'ja')
@ -50,26 +58,75 @@ def translate_one_phrase(draw: ImageDraw, translated_phrase: str, position: tupl
text_content = f"{translated_phrase}\n{untranslated_phrase}" text_content = f"{translated_phrase}\n{untranslated_phrase}"
lines = text_content.split('\n') lines = text_content.split('\n')
x,y = position
max_width = 0
total_height = 0
total_height = len(lines) * (LINE_HEIGHT + LINE_SPACING)
for line in lines:
bbox = draw.textbbox(position, line, font=font)
line_width = bbox[2] - bbox[0]
max_width = max(max_width, line_width)
bounding_box = (x, y, x + max_width, y + total_height, untranslated_phrase)
adjust_if_intersects(x, y, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height) # Draw the bounding box
top_left, _, _, _ = position
max_width = get_max_width(lines, FONT_FILE, FONT_SIZE)
total_height = get_max_height(lines, FONT_SIZE, LINE_SPACING)
right_edge = REGION[2]
# Ensure the text is within the screen. P.S. Text on the edge may still be squished together if there are too many to translate
x_onscreen = top_left[0] if top_left[0] + max_width <= right_edge else right_edge - max_width
y_onscreen = max(top_left[1] - total_height, 0)
bounding_box = (x_onscreen, y_onscreen, x_onscreen + max_width, y_onscreen + total_height, untranslated_phrase)
adjust_if_intersects(x_onscreen, y_onscreen, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1] adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]
draw.rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], outline="black", width=1) draw.rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], outline="black", width=1)
position = (adjusted_x,adjusted_y) position = (adjusted_x,adjusted_y)
for line in lines: for line in lines:
draw.text(position, line, fill= TEXT_COLOR, font=font) draw.text(position, line, fill= FONT_COLOUR, font=font)
if ADD_OVERLAY:
overlay.add_next_text_at_position_no_update(position[0], position[1], line, text_color=FONT_COLOUR)
adjusted_y += FONT_SIZE + LINE_SPACING adjusted_y += FONT_SIZE + LINE_SPACING
position = (adjusted_x,adjusted_y) position = (adjusted_x,adjusted_y)
def adjust_if_intersects(x: int, y: int, bounding_box: tuple, bounding_boxes: list, untranslated_phrase: str, max_width: int, total_height: int) -> tuple:
### Only support for horizontal text atm, vertical text is on the todo list
def draw_one_phrase_replace(draw: ImageDraw,
translated_phrase: str,
position: tuple, bounding_boxes: list,
untranslated_phrase: str) -> ImageDraw:
"""Cover up old text and add translation directly on top"""
# Draw the bounding box
top_left, _, _, bottom_right = position
max_width = bottom_right[0] - top_left[0]
font_size = bottom_right[1] - top_left[1]
draw.rectangle([top_left, bottom_right], fill=FILL_COLOUR)
while True:
font = ImageFont.truetype(FONT_FILE, font_size)
if font.get_max_width < max_width:
draw.text(top_left, translated_phrase, fill= FONT_COLOUR, font=font)
break
elif font_size <= 1:
break
else:
font_size -= 1
def get_max_width(lines: list, font_path, font_size) -> int:
"""Get the maximum width of the text lines"""
font = ImageFont.truetype(font_path, font_size)
max_width = 0
dummy_image = Image.new("RGB", (1, 1))
draw = ImageDraw.Draw(dummy_image)
for line in lines:
bbox = draw.textbbox((0,0), line, font=font)
line_width = bbox[2] - bbox[0]
max_width = max(max_width, line_width)
return max_width
def get_max_height(lines: list, font_size, line_spacing) -> int:
"""Get the maximum height of the text lines"""
return len(lines) * (font_size + line_spacing)
def adjust_if_intersects(x: int, y: int,
bounding_box: tuple, bounding_boxes: list,
untranslated_phrase: str,
max_width: int, total_height: int) -> tuple:
"""Adjust the y coordinate if the bounding box intersects with any other bounding box"""
y = np.max([y,0]) y = np.max([y,0])
if len(bounding_boxes) > 0: if len(bounding_boxes) > 0:
for box in bounding_boxes: for box in bounding_boxes:
@ -79,5 +136,3 @@ def adjust_if_intersects(x: int, y: int, bounding_box: tuple, bounding_boxes: li
bounding_boxes.append(adjusted_bounding_box) bounding_boxes.append(adjusted_bounding_box)
return adjusted_bounding_box return adjusted_bounding_box

View File

@ -1,37 +1,63 @@
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from typing import List, Dict from typing import List, Dict
from dotenv import load_dotenv from dotenv import load_dotenv
import os , sys, torch, time import os , sys, torch, time, ast
from multiprocessing import Process, Event from werkzeug.exceptions import TooManyRequests
from multiprocessing import Process, Event, Value
load_dotenv() load_dotenv()
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from config import device from config import device, GEMINI_API_KEY, GROQ_API_KEY
from logging_config import logger from logging_config import logger
from groq import Groq
import google.generativeai as genai
import asyncio
from functools import wraps
class Gemini():
def __init__(self, name, rate): class ApiModel():
self.name = name def __init__(self, model, # model name
rate, # rate of calls per minute
api_key, # api key for the model wrt the site
):
self.model = model
self.rate = rate self.rate = rate
self.curr_calls = 0 self.api_key = api_key
self.time = 0 self.curr_calls = Value('i', 0)
self.time = Value('i', 0)
self.process = None self.process = None
self.stop_event = Event() self.stop_event = Event()
self.site = None
self.from_lang = None
self.target_lang = None
self.request = None # request response from API
def __repr__(self): def __repr__(self):
return f'Model: {self.name}; Rate: {self.rate}; Current_Calls: {self.curr_calls} calls; Time Passed: {self.time} seconds.' return f'{self.site} Model: {self.model}; Rate: {self.rate}; Current_Calls: {self.curr_calls.value} calls; Time Passed: {self.time.value} seconds.'
def __str__(self): def __str__(self):
return self.name return self.model
def background_task(self): def set_lang(self, from_lang, target_lang):
self.from_lang = from_lang
self.target_lang = target_lang
### CHECK MINUTELY API RATES. For working with hourly rates and monthly will need to create another file. Also just unlikely those rates will be hit
async def api_rate_check(self):
# Background task to manage the rate of calls to the API # Background task to manage the rate of calls to the API
while not self.stop_event.is_set(): while not self.stop_event.is_set():
time.sleep(5) start_time = time.monotonic()
self.time += 5 self.time.value += 5
if self.time >= 60: if self.time.value >= 60:
self.time = 0 self.time.value = 0
self.curr_calls = 0 self.curr_calls.value = 0
elapsed = time.monotonic() - start_time
# Sleep for exactly 5 seconds minus the elapsed time
sleep_time = max(0, 5 - elapsed)
await asyncio.sleep(sleep_time)
def background_task(self):
asyncio.run(self.api_rate_check())
def start(self): def start(self):
# Start the background task # Start the background task
@ -49,6 +75,68 @@ class Gemini():
if self.process.is_alive(): if self.process.is_alive():
self.process.terminate() self.process.terminate()
def request_func(request):
@wraps(request)
def wrapper(self, text, *args, **kwargs):
if self.curr_calls.value < self.rate:
# try:
response = request(self, text, *args, **kwargs)
self.curr_calls.value += 1
return response
# except Exception as e:
#logger.error(f"Error with model {self.model} from {self.site}. Error: {e}")
else:
logger.error(f"Rate limit reached for this model. Please wait for the rate to reset in {60 - self.time} seconds.")
raise TooManyRequests('Rate limit reached for this model.')
return wrapper
@request_func
def translate(self, request_fn, texts_to_translate):
if len(texts_to_translate) == 0:
return []
prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {self.from_lang} into {self.target_lang} and output as a Python list ensuring proper escaping of characters: {texts_to_translate}"
response = request_fn(self, prompt)
return ast.literal_eval(response.strip())
class Groqq(ApiModel):
def __init__(self, model, rate, api_key = GROQ_API_KEY):
super().__init__(model, rate, api_key)
self.site = "Groq"
def request(self, content):
client = Groq()
chat_completion = client.chat.completions.create(
messages=[
{
"role": "user",
"content": content,
}
],
model=self.model
)
return chat_completion.choices[0].message.content
def translate(self, texts_to_translate):
return super().translate(Groqq.request, texts_to_translate)
class Gemini(ApiModel):
def __init__(self, model, rate, api_key = GEMINI_API_KEY):
super().__init__(model, rate, api_key)
self.site = "Gemini"
def request(self, content):
genai.configure(api_key=self.api_key)
safety_settings = {
"HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
response = genai.GenerativeModel(self.model).generate_content(content, safety_settings=safety_settings)
return response.text.strip()
def translate(self, texts_to_translate):
return super().translate(Gemini.request, texts_to_translate)
class TranslationDataset(Dataset): class TranslationDataset(Dataset):
def __init__(self, texts: List[str], tokenizer, max_length: int = 512): def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
""" """
@ -174,9 +262,11 @@ def generate_text(
return all_generated_texts return all_generated_texts
if __name__ == '__main__': if __name__ == '__main__':
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer GROQ_API_KEY = os.getenv('GROQ_API_KEY')
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=True).to(device) GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=True) groq = Groqq('gemma-7b-it', 15, GROQ_API_KEY)
tokenizer.src_lang = "zh" groq.set_lang('zh','en')
texts = ["你好",""] gemini = Gemini('gemini-1.5-pro', 15, GEMINI_API_KEY)
print(generate_text(texts,model, tokenizer, forced_bos_token_id=tokenizer.get_lang_id('en'))) gemini.set_lang('zh','en')
print(gemini.translate(['荷兰咯']))
print(groq.translate(['荷兰咯']))

View File

@ -1,15 +1,17 @@
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig, AutoModelForCausalLM from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig, AutoModelForCausalLM
import google.generativeai as genai import google.generativeai as genai
import torch, os, sys, ast import torch, os, sys, ast, json
from utils import standardize_lang from utils import standardize_lang
from functools import wraps from functools import wraps
from batching import generate_text, Gemini import random
import batching
from batching import generate_text, Gemini, Groq
from logging_config import logger from logging_config import logger
from multiprocessing import Process,Event from multiprocessing import Process,Event
# root dir # root dir
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_API_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models
############################## ##############################
# translation decorator # translation decorator
@ -27,36 +29,30 @@ def translate(translation_func):
############################### ###############################
def init_GEMINI(models_and_rates = None): def init_API_LLM(from_lang, target_lang):
if not models_and_rates: from_lang = standardize_lang(from_lang)['translation_model_lang']
## this is default for free tier target_lang = standardize_lang(target_lang)['translation_model_lang']
models_and_rates = {'gemini-1.5-pro': 2, 'gemini-1.5-flash': 15, 'gemini-1.5-flash-8b': 8, 'gemini-1.0-pro': 15} # order from most pref to least pref with open('api_models.json', 'r') as f:
models = [Gemini(name, rate) for name, rate in models_and_rates.items()] models_and_rates = json.load(f)
models = []
for class_type, class_models in models_and_rates.items():
cls = getattr(batching, class_type)
instantiated_objects = [ cls(model, rate) for model, rate in class_models.items()]
models.extend(instantiated_objects)
for model in models: for model in models:
model.start() model.start()
genai.configure(api_key=GEMINI_KEY) model.set_lang(from_lang, target_lang)
return models return models
def translate_GEMINI(text, models, from_lang, target_lang): def translate_API_LLM(text, models):
safety_settings = { random.shuffle(models)
"HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {from_lang} into {target_lang} and output as a Python list ensuring proper escaping of characters: {text}"
for model in models: for model in models:
if model.curr_calls < model.rate: try:
try: return model.translate(text)
response = genai.GenerativeModel(model.name).generate_content(prompt, except:
safety_settings=safety_settings) continue
model.curr_calls += 1 logger.error("All models have failed to translate the text.")
logger.info(repr(model))
logger.info(f'Model Response: {response.text.strip()}')
return ast.literal_eval(response.text.strip())
except Exception as e:
logger.error(f"Error with model {model.name}. Error: {e}")
logger.error("No models available to translate. Please wait for a model to be available.")
############################### ###############################
# Best model by far. Aya-23-8B. Gemma is relatively good. If I get the time to quantize either gemma or aya those will be good to use. llama3.2 is really good as well. # Best model by far. Aya-23-8B. Gemma is relatively good. If I get the time to quantize either gemma or aya those will be good to use. llama3.2 is really good as well.
@ -127,12 +123,6 @@ def init_Seq_LLM(model_type, **kwargs): # model = 'opus' or 'm2m'
else: else:
raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.") raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.")
def init_API_LLM(model_type, **kwargs): # model = 'gemma'
if model_type == 'gemini':
return init_GEMINI(**kwargs)
else:
raise ValueError(f"Invalid model type. Please use {' or '.join(api_llm_models)}.")
def init_Causal_LLM(model_type, **kwargs): def init_Causal_LLM(model_type, **kwargs):
pass pass
### ###
@ -155,23 +145,8 @@ def translate_Seq_LLM(text,
### if you want to use any other translation, just define a translate function with input text and output text. ### if you want to use any other translation, just define a translate function with input text and output text.
# def translate_api(text):
#@translate
#def translate_Causal_LLM(text, model_type, model)
@translate #def translate_Causal_LLM(text, model_type, model)
def translate_API_LLM(text: list[str],
model_type: str, # 'gemma'
models: list, # list of objects of classes defined in batching.py
from_lang: str, # suggested to use ISO 639-1 codes
target_lang: str # suggested to use ISO 639-1 codes
) -> list[str]:
if model_type == 'gemini':
from_lang = standardize_lang(from_lang)['translation_model_lang']
target_lang = standardize_lang(target_lang)['translation_model_lang']
return translate_GEMINI(text, models, from_lang, target_lang)
else:
raise ValueError(f"Invalid model. Please use {' or '.join(api_llm_models)}.")
@translate @translate
def translate_Causal_LLM(text: list[str], def translate_Causal_LLM(text: list[str],
@ -211,7 +186,5 @@ def translate_func(model):
### todo: if cuda is not detected, default to online translation as cpu just won't cut it bro. Parallel process it over multiple websites to make it faster ### todo: if cuda is not detected, default to online translation as cpu just won't cut it bro. Parallel process it over multiple websites to make it faster
if __name__ == "__main__": if __name__ == "__main__":
models = init_GEMINI() models = init_API_LLM('ja', 'en')
print(translate_API_LLM(['想要借用合成台。', '有什么卖的?', '声音太大会有什么影响吗?', '不怕丢东西吗?', '再见。', '布纳马', '想买什么自己拿。把钱留在旁边就好。', '回顾', '隐藏'], 'gemini', models, from_lang='ch_sim', target_lang='en')) print(translate_API_LLM(['こんにちは'], models))
# model, tokenizer = init_M2M()
# print(translate_Seq_LLM( ['想要借用合成台。', '有什么卖的?', '声音太大会有什么影响吗?', '不怕丢东西吗?', '再见。', '布纳马', '想买什么自己拿。把钱留在旁边就好。', '回顾', '隐藏'], 'm2m', model, tokenizer, from_lang='ch_sim', target_lang='en'))

View File

@ -64,4 +64,4 @@ def setup_logger(
print(f"Failed to setup logger: {e}") print(f"Failed to setup logger: {e}")
return None return None
logger = setup_logger('on_screen_translator', log_file='translate.log') logger = setup_logger('on_screen_translator', log_file='translate.log', level=logging.DEBUG)

115
qtapp.py Normal file
View File

@ -0,0 +1,115 @@
###################################################################################
##### IMPORT LIBRARIES #####
import os, time, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
from utils import printsc, convert_image_to_bytes, bytes_to_image
from ocr import get_words, init_OCR, id_keep_source_lang
from logging_config import logger
from draw import modify_image_bytes
from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL, FONT_SIZE, FONT_FILE, FONT_COLOUR
from create_overlay import app, overlay
from typing import Optional, List
###################################################################################
from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QThread, Signal
from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont,
QColor, QIcon, QImage, QPixmap)
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
QLabel, QSystemTrayIcon, QMenu)
from dataclasses import dataclass
class TranslationThread(QThread):
    """Background worker: periodically screenshots a region, OCRs it, and
    translates any newly-seen text, publishing results via Qt signals."""

    translation_ready = Signal(list, list)   # (ocr_output, translated strings)
    start_capture = Signal()                 # emitted just before a screenshot
    end_capture = Signal()                   # emitted right after a screenshot
    screen_capture = Signal(int, int, int, int)

    def __init__(self, ocr, models, source_lang, target_lang, interval):
        super().__init__()
        self.ocr = ocr                  # initialized OCR engine
        self.models = models            # API translation model pool
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.interval = interval        # seconds to sleep between captures
        self.running = True             # cleared by stop() to end the loop
        self.prev_words = set()         # words seen on the previous frame
        self.runs = 0                   # iteration counter (for logging)

    def run(self):
        while self.running:
            # Hide the overlay while grabbing the screen so we never OCR our own output.
            self.start_capture.emit()
            raw_image = printsc(REGION)
            self.end_capture.emit()

            image_bytes = convert_image_to_bytes(raw_image)
            ocr_output = id_keep_source_lang(self.ocr, image_bytes, self.source_lang)

            if self.runs == 0:
                logger.info('Initial run')
            else:
                logger.info(f'Run number: {self.runs}.')
            self.runs += 1

            words_now = set(get_words(ocr_output))
            if words_now == self.prev_words:
                # Nothing changed on screen; keep the existing overlay contents.
                logger.info("No new words to translate. Output will not refresh.")
            else:
                logger.info('Translating')
                to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
                translation = translate_API_LLM(to_translate, self.models)
                logger.info(f'Translation from {to_translate} to\n {translation}')
                modify_image_bytes(image_bytes, ocr_output, translation)
                # Hand the results to the overlay on the GUI thread.
                self.translation_ready.emit(ocr_output, translation)
                self.prev_words = words_now

            logger.info(f'Sleeping for {self.interval} seconds')
            time.sleep(self.interval)

    def stop(self):
        """Ask the run loop to exit after its current iteration completes."""
        self.running = False
def main():
    """Wire up OCR, the translation model pool, the worker thread and the Qt
    overlay, then run the Qt event loop until the app exits.

    Returns the Qt event loop's exit code.
    """
    # OCR must recognise both configured languages (plus English as a fallback).
    ocr_langs = [SOURCE_LANG, TARGET_LANG, 'en']
    ocr = init_OCR(model=OCR_MODEL, easy_languages=ocr_langs, use_GPU=OCR_USE_GPU)

    # Pool of rate-limited API translation models.
    models = init_API_LLM(SOURCE_LANG, TARGET_LANG)

    worker = TranslationThread(
        ocr=ocr,
        models=models,
        source_lang=SOURCE_LANG,
        target_lang=TARGET_LANG,
        interval=INTERVAL,
    )

    # Route worker signals to the overlay widget.
    worker.start_capture.connect(overlay.prepare_for_capture)
    worker.end_capture.connect(overlay.restore_after_capture)
    worker.translation_ready.connect(overlay.update_translation)
    worker.screen_capture.connect(overlay.capture_behind)

    worker.start()

    # Block here until the Qt application quits.
    exit_code = app.exec()

    # Shut the worker down cleanly before returning.
    worker.stop()
    worker.wait()
    return exit_code
# Run the overlay app directly; the exit status comes from Qt's event loop.
if __name__ == "__main__":
    sys.exit(main())

6
web.py
View File

@ -1,12 +1,12 @@
from flask import Flask, Response, render_template from flask import Flask, Response, render_template
import threading import threading
import io import io
import translate import app
app = Flask(__name__) app = Flask(__name__)
# Global variable to hold the current image # Global variable to hold the current image
def curr_image(): def curr_image():
return translate.latest_image return app.latest_image
@app.route('/') @app.route('/')
def index(): def index():
@ -29,7 +29,7 @@ def stream_image():
if __name__ == '__main__': if __name__ == '__main__':
# Start the image updating thread # Start the image updating thread
threading.Thread(target=translate.main, daemon=True).start() threading.Thread(target=app.main, daemon=True).start()
# Start the Flask web server # Start the Flask web server
app.run(host='0.0.0.0', port=5000, debug=True) app.run(host='0.0.0.0', port=5000, debug=True)