Compare commits

...

1 Commit

17 changed files with 788 additions and 711 deletions

View File

@ -1,13 +1,20 @@
{
"Gemini": {
"gemini-1.5-pro": 2,
"gemini-1.5-flash": 15,
"gemini-1.5-flash-8b": 8,
"gemini-1.0-pro": 15
"gemini-1.5-pro": { "rpmin": 2, "rpd": 50 },
"gemini-1.5-flash": { "rpmin": 15, "rpd": 1500 },
"gemini-1.5-flash-8b": { "rpmin": 15, "rpd": 1500 },
"gemini-1.0-pro": { "rpmin": 15, "rpd": 1500 }
},
"Groqq": {
"llama-3.2-90b-text-preview": 30,
"llama3-70b-8192": 30,
"mixtral-8x7b-32768": 30
"Groq": {
"llama-3.2-90b-text-preview": { "rpmin": 30, "rpd": 7000 },
"llama3-70b-8192": { "rpmin": 30, "rpd": 14400 },
"mixtral-8x7b-32768": { "rpmin": 30, "rpd": 14400 },
"llama-3.1-70b-versatile": { "rpmin": 30, "rpd": 14400 },
"gemma2-9b-it": { "rpmin": 30, "rpd": 14400 },
"llama3-groq-8b-8192-tool-use-preview": { "rpmin": 30, "rpd": 14400 },
"llama3-groq-70b-8192-tool-use-preview": { "rpmin": 30, "rpd": 14400 },
"llama-3.2-90b-vision-preview": { "rpmin": 15, "rpd": 3500 },
"llama-3.2-11b-text-preview": { "rpmin": 30, "rpd": 7000 },
"llama-3.2-11b-vision-preview": { "rpmin": 30, "rpd": 7000 }
}
}

83
app.py
View File

@ -1,83 +0,0 @@
###################################################################################
##### IMPORT LIBRARIES #####
import os, time, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
from utils import printsc, convert_image_to_bytes, bytes_to_image
from ocr import get_words, init_OCR, id_keep_source_lang
from logging_config import logger
from draw import modify_image_bytes
from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL
###################################################################################
ADD_OVERLAY = False
latest_image = None
def main():
    """Capture→OCR→translate loop.

    Screenshots REGION, OCRs it, translates any newly seen text through the
    API models, and renders the result into the module-global `latest_image`
    every INTERVAL seconds.  Runs forever; never returns.
    """
    global latest_image
    ##### Initialize the OCR #####
    OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
    ocr = init_OCR(model=OCR_MODEL, easy_languages = OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
    ##### Initialize the translation #####
    # model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
    models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
    ###################################################################################
    runs = 0
    # Bug fix: removed a stray `app.exec()` here — `app` is never defined or
    # imported in this module, so the call raised NameError (and would have
    # blocked the loop from ever starting if it existed).
    while True:
        if ADD_OVERLAY:
            # NOTE(review): `overlay` is not defined in this module either;
            # this branch is unreachable while ADD_OVERLAY is forced False
            # above — confirm where overlay is meant to come from.
            overlay.clear_all_text()
        untranslated_image = printsc(REGION)
        if ADD_OVERLAY:
            overlay.text_entries = overlay.text_entries_copy
            overlay.update()
            overlay.text_entries.clear()
        byte_image = convert_image_to_bytes(untranslated_image)
        ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG) # keep only phrases containing the source language
        if runs == 0:
            logger.info('Initial run')
            prev_words = set()
        else:
            logger.info(f'Run number: {runs}.')
        runs += 1
        curr_words = set(get_words(ocr_output))
        ### If the OCR detects different words, translate screen -> to ensure that the screen is not refreshing constantly and to save GPU power
        if prev_words != curr_words:
            logger.info('Translating')
            to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
            # translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
            translation = translate_API_LLM(to_translate, models)
            logger.info(f'Translation from {to_translate} to\n {translation}')
            translated_image = modify_image_bytes(byte_image, ocr_output, translation)
            latest_image = bytes_to_image(translated_image)
            # latest_image.show() # for debugging
            prev_words = curr_words
        else:
            logger.info("No new words to translate. Output will not refresh.")
        logger.info(f'Sleeping for {INTERVAL} seconds')
        time.sleep(INTERVAL)
    # if ADD_OVERLAY:
    #     sys.exit(app.exec())
################### TODO ##################
# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
# 5. Maybe refreshing issue of flask app. Also get webpage to update only if the image changes.
# Create a way for it to just replace the text and provide only the translation on-screen. Qt6
if __name__ == "__main__":
sys.exit(main())

View File

@ -10,22 +10,31 @@ load_dotenv(override=True)
INTERVAL = int(os.getenv('INTERVAL'))
### OCR
IMAGE_CHANGE_THRESHOLD = float(os.getenv('IMAGE_CHANGE_THRESHOLD', 0.75)) # higher values mean more sensitivity to changes in the screen, too high and the screen will constantly refresh
OCR_MODEL = os.getenv('OCR_MODEL', 'easy') # 'easy', 'paddle', 'rapid' ### easy is the most accurate, paddle is the fastest with CUDA and rapid is the fastest with CPU. Rapid has only between Chinese and English unless you add more languages
OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True'))
### Drawing/Overlay Config
ADD_OVERLAY = ast.literal_eval(os.getenv('ADD_OVERLAY', 'True'))
FILL_COLOUR = os.getenv('FILL_COLOUR', 'white')
FONT_FILE = os.getenv('FONT_FILE')
FONT_SIZE = int(os.getenv('FONT_SIZE', 16))
FONT_SIZE_MAX = int(os.getenv('FONT_SIZE_MAX', 20))
FONT_SIZE_MIN = int(os.getenv('FONT_SIZE_MIN', 8))
LINE_SPACING = int(os.getenv('LINE_SPACING', 3))
REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)'))
DRAW_TRANSLATIONS_MODE = os.getenv('DRAW_TRANSLATIONS_MODE', 'add')
"""
'learn': adds translated text, original text (should be added so when texts get moved around the translation of which it references is understood) and (optionally with the other TO_ROMANIZE option) romanized text above the original text. Texts can overlap if squished into a corner. Works well for games where texts are sparser
'learn_cover': same as above but covers the original text with the translated text. Can help with readability and is less cluttered but with sufficiently dense text the texts can still overlap
'translation_only_cover': cover the original text with the translated text - will not show the original text at all but not affected by overlapping texts
"""
FONT_COLOUR = os.getenv('FONT_COLOUR', "#ff0000")
TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True'))
# API KEYS https://github.com/cheahjs/free-llm-api-resources?tab=readme-ov-file
GEMINI_API_KEY = os.getenv('GEMINI_KEY')
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
GROQ_API_KEY = os.getenv('GROQ_API_KEY') #
# MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY') # https://console.mistral.ai/api-keys/ slow asf
@ -44,12 +53,13 @@ BATCH_SIZE = int(os.getenv('BATCH_SIZE', 6))
LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False'))
###################################################################################################
## Filepaths
API_MODELS_FILEPATH = os.path.join(os.path.dirname(__file__), 'api_models.json')
FONT_SIZE = int((FONT_SIZE_MAX + FONT_SIZE_MIN)/2)
LINE_HEIGHT = FONT_SIZE
if TRANSLATION_USE_GPU is False:
device = torch.device("cpu")
else:

View File

@ -1,305 +0,0 @@
from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QBuffer
from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont, QScreen, QIcon)
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
QLabel, QSystemTrayIcon, QMenu)
import sys, io, os, signal, time, platform
from PIL import Image
from dataclasses import dataclass, field
from typing import List, Optional
from config import ADD_OVERLAY, FONT_FILE, FONT_SIZE
from logging_config import logger
def qpixmap_to_bytes(qpixmap):
    """Render a QPixmap to PNG in an in-memory QBuffer.

    NOTE(review): despite the name, this returns the QImage, not the PNG
    bytes — `buffer` is filled by qimage.save() and then discarded.  The
    __main__ demo relies on the returned object having .save(), so the
    name (not the return value) appears to be what is wrong; confirm
    intent before changing either.
    """
    qimage = qpixmap.toImage()
    buffer = QBuffer()
    buffer.open(QBuffer.ReadWrite)
    qimage.save(buffer, "PNG")  # PNG bytes land in `buffer` (currently unused)
    return qimage
@dataclass
class TextEntry:
    """One positioned text item rendered by the overlay's paintEvent."""
    text: str
    x: int
    y: int
    # Bug fix: a plain `QFont('Arial', FONT_SIZE)` default is evaluated once
    # and shared (mutably) by every entry; default_factory gives each entry
    # its own QFont instance.
    font: QFont = field(default_factory=lambda: QFont('Arial', FONT_SIZE))
    visible: bool = True
    text_color: Qt.GlobalColor = Qt.GlobalColor.red
    background_color: Optional[Qt.GlobalColor] = None  # None = no box behind the text
    padding: int = 1  # pixels of padding around the background box
class TranslationOverlay(QMainWindow):
    """Frameless, always-on-top, click-through transparent window that paints
    translated text over the screen.

    Current text lives in `text_entries`; `next_text_entries` stages the
    upcoming frame; `text_entries_copy` preserves state across screen
    captures.  Keyboard shortcuts toggle visibility, mouse passthrough,
    clearing, and the background.
    """
    def __init__(self):
        super().__init__()
        self.setWindowTitle("Translation Overlay")
        self.is_passthrough = True  # start in click-through mode
        self.text_entries: List[TextEntry] = []
        self.setup_window_attributes()
        self.setup_shortcuts()
        # Closing the overlay quits the whole application
        self.closeEvent = lambda event: QApplication.quit()
        self.default_font = QFont('Arial', FONT_SIZE)
        self.text_entries_copy: List[TextEntry] = []
        self.next_text_entries: List[TextEntry] = []
        #self.show_background = True
        # NOTE(review): show_background is never initialized (commented out
        # above), so toggle_background() raises AttributeError on first use.
        self.background_opacity = 0.5
        # self.setup_tray()

    def prepare_for_capture(self):
        """Preserve current state and clear overlay"""
        if ADD_OVERLAY:
            self.text_entries_copy = self.text_entries.copy()
            self.clear_all_text()
            self.update()

    def restore_after_capture(self):
        """Restore overlay state after capture"""
        if ADD_OVERLAY:
            logger.debug(f'Text entries copy during initial phase of restore_after_capture: {self.text_entries_copy}')
            self.text_entries = self.text_entries_copy.copy()
            logger.debug(f"Restored text entries: {self.text_entries}")
            self.update()

    def add_next_text_at_position_no_update(self, x: int, y: int, text: str,
            font: Optional[QFont] = None, text_color: Qt.GlobalColor = Qt.GlobalColor.red):
        """Stage new text for the next frame without triggering a repaint."""
        entry = TextEntry(
            text=text,
            x=x,
            y=y,
            font=font or self.default_font,
            text_color=text_color
        )
        self.next_text_entries.append(entry)

    def update_translation(self, ocr_output, translation):
        """Swap in the staged next_text_entries and repaint.

        NOTE(review): the ocr_output/translation parameters are unused here;
        callers stage entries via add_next_text_at_position_no_update first.
        """
        self.clear_all_text()
        self.text_entries = self.next_text_entries.copy()
        self.next_text_entries.clear()
        self.update()

    def capture_behind(self, x=None, y=None, width=None, height=None):
        """
        Capture the screen area behind the overlay.
        If no coordinates provided, captures the area under the window.
        Hides the window during the grab so the overlay itself is excluded.
        """
        # Temporarily hide the window
        self.hide()
        # Get screen
        screen = QScreen.grabWindow(
            self.screen(),
            0,
            x if x is not None else self.x(),
            y if y is not None else self.y(),
            width if width is not None else self.width(),
            height if height is not None else self.height()
        )
        # Show the window again
        self.show()
        # NOTE(review): qpixmap_to_bytes currently returns a QImage, not bytes
        screen_bytes = qpixmap_to_bytes(screen)
        return screen_bytes

    def clear_all_text(self):
        """Clear all text entries"""
        self.text_entries.clear()
        self.update()

    def setup_window_attributes(self):
        # Set window flags for overlay behavior
        self.setWindowFlags(
            Qt.WindowType.FramelessWindowHint |
            Qt.WindowType.WindowStaysOnTopHint |
            Qt.WindowType.Tool
        )
        # Set attributes for transparency
        self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground)
        # Make the window cover the entire screen
        self.setGeometry(QApplication.primaryScreen().geometry())
        # Special handling for Wayland
        if platform.system() == "Linux":
            if "WAYLAND_DISPLAY" in os.environ:
                # NOTE(review): verify WA_X11NetWmWindowTypeCombo exists in the
                # installed PySide6 — an X11 attribute under Wayland looks odd.
                self.setAttribute(Qt.WidgetAttribute.WA_X11NetWmWindowTypeCombo)
                self.setAttribute(Qt.WidgetAttribute.WA_DontCreateNativeAncestors)

    def setup_shortcuts(self):
        # Toggle visibility (Alt+Shift+T)
        self.toggle_visibility_shortcut = QShortcut(QKeySequence("Alt+Shift+T"), self)
        self.toggle_visibility_shortcut.activated.connect(self.toggle_visibility)
        # Toggle passthrough mode (Alt+Shift+P)
        self.toggle_passthrough_shortcut = QShortcut(QKeySequence("Alt+Shift+P"), self)
        self.toggle_passthrough_shortcut.activated.connect(self.toggle_passthrough)
        # Quick hide (Escape)
        self.hide_shortcut = QShortcut(QKeySequence("Esc"), self)
        self.hide_shortcut.activated.connect(self.hide)
        # Clear all text (Alt+Shift+C)
        self.clear_shortcut = QShortcut(QKeySequence("Alt+Shift+C"), self)
        self.clear_shortcut.activated.connect(self.clear_all_text)
        # Toggle background (Alt+Shift+B)
        self.toggle_background_shortcut = QShortcut(QKeySequence("Alt+Shift+B"), self)
        self.toggle_background_shortcut.activated.connect(self.toggle_background)

    def toggle_visibility(self):
        # Show/hide the whole overlay window
        if self.isVisible():
            self.hide()
        else:
            self.show()

    def toggle_passthrough(self):
        # Flip between click-through and interactive; on X11 also toggles the
        # bypass-window-manager flag (flag changes require a hide/show cycle).
        self.is_passthrough = not self.is_passthrough
        if self.is_passthrough:
            self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents)
            if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ:
                self.setWindowFlags(self.windowFlags() | Qt.WindowType.X11BypassWindowManagerHint)
        else:
            self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents, False)
            if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ:
                self.setWindowFlags(self.windowFlags() & ~Qt.WindowType.X11BypassWindowManagerHint)
        self.hide()
        self.show()

    def toggle_background(self):
        """Toggle background visibility"""
        # NOTE(review): self.show_background is never initialized in __init__,
        # so the first call raises AttributeError; paintEvent also never reads
        # it — confirm this feature is still wired up.
        self.show_background = not self.show_background
        self.update()

    def set_background_opacity(self, opacity: float):
        """Set background opacity (0.0 to 1.0)"""
        self.background_opacity = max(0.0, min(1.0, opacity))
        self.update()

    def add_text_at_position(self, x: int, y: int, text: str):
        """Add new text at specific coordinates"""
        entry = TextEntry(text, x, y)
        self.text_entries.append(entry)
        self.update()

    def update_text_at_position(self, x: int, y: int, text: str):
        """Update text at specific coordinates, or add if none exists"""
        # Look for existing text entry near these coordinates (within 1 pixel)
        for entry in self.text_entries:
            if abs(entry.x - x) <= 1 and abs(entry.y - y) <= 1:
                entry.text = text
                self.update()
                return
        # If no existing entry found, add new one
        self.add_text_at_position(x, y, text)

    def setup_tray(self):
        # NOTE(review): currently unused (the __init__ call is commented out)
        # and it references self.clean_exit and self.tray_activated, neither of
        # which is defined on this class — calling it would fail.
        self.tray_icon = QSystemTrayIcon(self)
        self.tray_icon.setIcon(QIcon.fromTheme("applications-system"))
        tray_menu = QMenu()
        toggle_action = tray_menu.addAction("Show/Hide Overlay")
        toggle_action.triggered.connect(self.toggle_visibility)
        toggle_passthrough = tray_menu.addAction("Toggle Passthrough")
        toggle_passthrough.triggered.connect(self.toggle_passthrough)
        # Add background toggle to tray menu
        toggle_background = tray_menu.addAction("Toggle Background")
        toggle_background.triggered.connect(self.toggle_background)
        clear_action = tray_menu.addAction("Clear All Text")
        clear_action.triggered.connect(self.clear_all_text)
        tray_menu.addSeparator()
        quit_action = tray_menu.addAction("Quit")
        quit_action.triggered.connect(self.clean_exit)
        self.tray_icon.setToolTip("Translation Overlay")
        self.tray_icon.setContextMenu(tray_menu)
        self.tray_icon.show()
        self.tray_icon.activated.connect(self.tray_activated)

    def remove_text_at_position(self, x: int, y: int):
        """Remove text entry near specified coordinates"""
        self.text_entries = [
            entry for entry in self.text_entries
            if abs(entry.x - x) > 1 or abs(entry.y - y) > 1
        ]
        self.update()

    def paintEvent(self, event):
        # Repaint every visible TextEntry, optionally with a padded background
        # rectangle behind it.
        painter = QPainter(self)
        painter.setRenderHint(QPainter.RenderHint.Antialiasing)
        # Draw each text entry
        for entry in self.text_entries:
            if not entry.visible:
                continue
            # Set the font for this specific entry
            painter.setFont(entry.font)
            text_metrics = painter.fontMetrics()
            # Get the bounding rectangles for text
            text_bounds = text_metrics.boundingRect(
                entry.text
            )
            total_width = text_bounds.width()
            total_height = text_bounds.height()
            # Create rectangles for text placement
            text_rect = QRect(entry.x, entry.y, total_width, total_height)
            # Calculate background rectangle that encompasses both texts
            if entry.background_color is not None:
                bg_rect = QRect(entry.x - entry.padding,
                                entry.y - entry.padding,
                                total_width + (2 * entry.padding),
                                total_height + (2 * entry.padding))
                painter.setPen(Qt.PenStyle.NoPen)
                painter.setBrush(entry.background_color)
                painter.drawRect(bg_rect)
            # Draw the texts
            painter.setPen(entry.text_color)
            painter.drawText(text_rect, Qt.AlignmentFlag.AlignLeft, entry.text)
def handle_exit(signum, frame):
    # SIGINT/SIGTERM handler: shut the Qt event loop down cleanly.
    QApplication.quit()
def start_overlay():
    """Build the QApplication and the overlay window.

    Installs SIGINT/SIGTERM handlers so Ctrl+C quits cleanly, shows the
    window, and returns ``(app, overlay)``.  The caller is responsible for
    running ``app.exec()``.
    """
    application = QApplication(sys.argv)
    # Enable Wayland support if available
    if platform.system() == "Linux" and "WAYLAND_DISPLAY" in os.environ:
        application.setProperty("platform", "wayland")
    window = TranslationOverlay()
    window.show()
    # Handle Ctrl+C (KeyboardInterrupt) and termination requests
    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, handle_exit)
    return (application, window)
# sys.exit(app.exec())
if ADD_OVERLAY:
    app, overlay = start_overlay()
if __name__ == "__main__":
    # Manual demo of the overlay.
    # Bug fix: the old code set ADD_OVERLAY = True *before* testing
    # `if not ADD_OVERLAY`, so when the config had the overlay disabled the
    # window was never created and `overlay`/`app` were undefined (NameError).
    if not ADD_OVERLAY:
        app, overlay = start_overlay()
        ADD_OVERLAY = True
    overlay.add_text_at_position(600, 100, "Hello World I AM A BIG FAAT FOROGGGGGGGGGG")
    capture = overlay.capture_behind()
    capture.save("capture.png")
    sys.exit(app.exec())

59
data.py Normal file
View File

@ -0,0 +1,59 @@
from sqlalchemy import create_engine, Column, Index, Integer, String, MetaData, Table, DateTime, ForeignKey, Boolean
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import relationship, declarative_base, sessionmaker
import logging
from logging_config import logger
import os
# Set up the database connection: an SQLite file under ./database next to
# this module.
data_dir = os.path.join(os.path.dirname(__file__), 'database')
os.makedirs(data_dir, exist_ok=True)
# Fix: data_dir is already an absolute path, so the old
# os.path.join(os.path.dirname(__file__), data_dir, ...) only worked because
# join discards components before an absolute one.  Same resulting path.
database_file = os.path.join(data_dir, 'translations.db')
engine = create_engine(f'sqlite:///{database_file}', echo=False)
Session = sessionmaker(bind=engine)
session = Session()  # module-level session shared by the app
Base = declarative_base()
logging.basicConfig()
# Keep SQLAlchemy's engine chatter out of the logs
logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
class Api(Base):
    """One API model per row: provider/model identity plus its rate limits."""
    __tablename__ = 'api'
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Model name exactly as the provider's API expects it
    model_name = Column(String, nullable=False)
    # NOTE(review): declared Integer, but callers elsewhere in this change
    # store site *names* ('Gemini', 'Groqq') — SQLite won't enforce the type,
    # but this should probably be String; confirm.
    site = Column(Integer, nullable=False)
    rpmin = Column(Integer) # rate per minute
    rph = Column(Integer) # rate per hour
    rpd = Column(Integer) # rate per day
    rpw = Column(Integer) # rate per week
    rpmth = Column(Integer) # rate per month
    rpy = Column(Integer) # rate per year
    translations = relationship("Translations", back_populates="api")
class Translations(Base):
    """One translation request/response pair, linked to the Api model used."""
    __tablename__ = 'translations'
    id = Column(Integer, primary_key=True, autoincrement=True)
    model_id = Column(Integer, ForeignKey('api.id'), nullable=False)
    source_texts = Column(String, nullable=False) # as a json string
    translated_texts = Column(String, nullable=False) # as a json string
    source_lang = Column(String, nullable=False)
    target_lang = Column(String, nullable=False)
    timestamp = Column(DateTime, nullable=False)
    # True when the number of translations returned differed from the number
    # of source texts sent
    translation_mismatch = Column(Boolean, nullable=False)
    api = relationship("Api", back_populates="translations")
    __table_args__ = (
        # Speed up time-range queries over the history
        Index('idx_timestamp', 'timestamp'),
    )
def create_tables():
    """Create all ORM tables, skipping the work when the SQLite file exists."""
    if os.path.exists(database_file):
        logger.info(f"Using Pre-existing Database at {database_file}.")
        return
    Base.metadata.create_all(engine)
    logger.info(f"Database created at {database_file}")

BIN
database/translations.db Normal file

Binary file not shown.

192
draw.py
View File

@ -1,14 +1,14 @@
from PIL import Image, ImageDraw, ImageFont
from PIL import Image, ImageDraw, ImageFont, ImageFilter
import os, io, sys, numpy as np
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
from utils import romanize, intercepts, add_furigana
from logging_config import logger
from config import ADD_OVERLAY, SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE, LINE_SPACING, FONT_COLOUR, LINE_HEIGHT, TO_ROMANIZE, FILL_COLOUR, REGION
from config import SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE_MAX,FONT_SIZE_MIN, FONT_SIZE, LINE_SPACING, FONT_COLOUR, LINE_HEIGHT, TO_ROMANIZE, FILL_COLOUR, REGION, DRAW_TRANSLATIONS_MODE
from PySide6.QtGui import QFont
font = ImageFont.truetype(FONT_FILE, FONT_SIZE)
#### TODO: refactor into a class so the shared drawing state (font, colours, bounding boxes) isn't threaded through every function signature
def modify_image_bytes(image_bytes: io.BytesIO, ocr_output, translation: list) -> bytes:
"""Modify the image bytes with the translated text and return the modified image bytes"""
@ -24,45 +24,36 @@ def modify_image_bytes(image_bytes: io.BytesIO, ocr_output, translation: list) -
modified_image_bytes = byte_stream.getvalue()
return modified_image_bytes
def draw_on_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int, replace = False) -> ImageDraw:
def draw_on_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int, draw_mode: str = DRAW_TRANSLATIONS_MODE) -> ImageDraw:
"""Draw the original, translated and optionally the romanisation of the texts on the image"""
translated_number = 0
bounding_boxes = []
logger.debug(f"Translations: {len(translation)} {translation}")
logger.debug(f"OCR output: {len(ocr_output)} {ocr_output}")
for i, (position, untranslated_phrase, confidence) in enumerate(ocr_output):
logger.debug(f"Untranslated phrase: {untranslated_phrase}")
if translated_number >= max_translate - 1:
if translated_number >= len(translation): # note if using api llm some issues may cause it to return less translations than expected
break
if replace:
draw = draw_one_phrase_replace(draw, translation[i], position, bounding_boxes, untranslated_phrase)
else:
draw_one_phrase_add(draw, translation[i], position, bounding_boxes, untranslated_phrase)
if draw_mode == 'learn':
draw_one_phrase_learn(draw, translation[i], position, bounding_boxes, untranslated_phrase)
elif draw_mode == 'translation_only':
draw_one_phrase_translation_only(draw, translation[i], position, bounding_boxes, untranslated_phrase)
elif draw_mode == 'learn_cover':
draw_one_phrase_learn_cover(draw, translation[i], position, bounding_boxes, untranslated_phrase)
elif draw_mode == 'translation_only_cover':
draw_one_phrase_translation_only_cover(draw, translation[i], position, bounding_boxes, untranslated_phrase)
translated_number += 1
return draw
def draw_one_phrase_add(draw: ImageDraw,
def draw_one_phrase_learn(draw: ImageDraw,
translated_phrase: str,
position: tuple, bounding_boxes: list,
untranslated_phrase: str) -> ImageDraw:
"""Draw the bounding box rectangle and text on the image above the original text"""
if SOURCE_LANG == 'ja':
untranslated_phrase = add_furigana(untranslated_phrase)
romanized_phrase = romanize(untranslated_phrase, 'ja')
else:
romanized_phrase = romanize(untranslated_phrase, SOURCE_LANG)
if TO_ROMANIZE:
text_content = f"{translated_phrase}\n{romanized_phrase}\n{untranslated_phrase}"
else:
text_content = f"{translated_phrase}\n{untranslated_phrase}"
lines = text_content.split('\n')
lines = get_lines(untranslated_phrase, translated_phrase)
# Draw the bounding box
top_left, _, _, _ = position
max_width = get_max_width(lines, FONT_FILE, FONT_SIZE)
total_height = get_max_height(lines, FONT_SIZE, LINE_SPACING)
top_left, _, bottom_right,_ = position
font_size = get_font_size(top_left[1], bottom_right[1], FONT_SIZE_MAX, FONT_SIZE_MIN)
max_width = get_max_width(lines, FONT_FILE, font_size)
total_height = get_max_height(lines, font_size, LINE_SPACING)
font = ImageFont.truetype(FONT_FILE, font_size)
right_edge = REGION[2]
# Ensure the text is within the screen. P.S. Text on the edge may still be squished together if there are too many to translate
@ -75,58 +66,142 @@ def draw_one_phrase_add(draw: ImageDraw,
adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]
draw.rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], outline="black", width=1)
position = (adjusted_x,adjusted_y)
for line in lines:
if FONT_COLOUR == 'rainbow':
rainbow_text(draw, line, *position, font)
else:
draw.text(position, line, fill= FONT_COLOUR, font=font)
if ADD_OVERLAY:
overlay.add_next_text_at_position_no_update(position[0], position[1], line, text_color=FONT_COLOUR)
adjusted_y += FONT_SIZE + LINE_SPACING
adjusted_y += font_size + LINE_SPACING
position = (adjusted_x,adjusted_y)
### Only support for horizontal text atm, vertical text is on the todo list
def draw_one_phrase_replace(draw: ImageDraw,
def draw_one_phrase_translation_only_cover(draw: ImageDraw,
translated_phrase: str,
position: tuple, bounding_boxes: list,
untranslated_phrase: str) -> ImageDraw:
"""Cover up old text and add translation directly on top"""
# Draw the bounding box
top_left, _, _, bottom_right = position
top_left, _, bottom_right, _ = position
bounding_boxes.append((top_left[0], top_left[1], bottom_right[0], bottom_right[1], untranslated_phrase)) # Debugging purposes
max_width = bottom_right[0] - top_left[0]
font_size = bottom_right[1] - top_left[1]
draw.rectangle([top_left, bottom_right], fill=FILL_COLOUR)
font_size = get_font_size(top_left[1], bottom_right[1], FONT_SIZE_MAX, FONT_SIZE_MIN)
while True:
font = ImageFont.truetype(FONT_FILE, font_size)
if font.get_max_width < max_width:
phrase_width = get_max_width(translated_phrase, FONT_FILE, font_size)
rectangle = get_rectangle_coordinates(translated_phrase, top_left, FONT_FILE, font_size, LINE_SPACING)
if phrase_width < max_width:
draw.rectangle(rectangle, fill=FILL_COLOUR)
if FONT_COLOUR == 'rainbow':
rainbow_text(draw, translated_phrase, *top_left, font)
else:
draw.text(top_left, translated_phrase, fill= FONT_COLOUR, font=font)
break
elif font_size <= 1:
elif font_size <= FONT_SIZE_MIN:
break
else:
font_size -= 1
def get_max_width(lines: list, font_path, font_size) -> int:
def draw_one_phrase_learn_cover(draw: ImageDraw,
                                translated_phrase: str,
                                position: tuple, bounding_boxes: list,
                                untranslated_phrase: str) -> ImageDraw:
    """'learn_cover' mode: draw a filled rounded box over/near the original
    text and write the translated (and optionally romanised + original)
    lines inside it."""
    lines = get_lines(untranslated_phrase, translated_phrase)
    # Draw the bounding box
    top_left, _, bottom_right,_ = position
    # Font size scales with the OCR box height, clamped to the configured range
    font_size = get_font_size(top_left[1], bottom_right[1], FONT_SIZE_MAX, FONT_SIZE_MIN)
    max_width = get_max_width(lines, FONT_FILE, font_size)
    total_height = get_max_height(lines, font_size, LINE_SPACING)
    font = ImageFont.truetype(FONT_FILE, font_size)
    right_edge = REGION[2]
    # Ensure the text is within the screen. P.S. Text on the edge may still be squished together if there are too many to translate
    x_onscreen = top_left[0] if top_left[0] + max_width <= right_edge else right_edge - max_width
    y_onscreen = max(top_left[1] - int(total_height/3), 0)
    bounding_box = (x_onscreen, y_onscreen, x_onscreen + max_width, y_onscreen + total_height, untranslated_phrase)
    # Shift the box down until it no longer overlaps earlier boxes;
    # adjust_if_intersects appends the final box to bounding_boxes.
    adjust_if_intersects(x_onscreen, y_onscreen, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
    adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]
    draw.rounded_rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], fill=FILL_COLOUR,outline="black", width=2, radius=5)
    position = (adjusted_x,adjusted_y)
    for line in lines:
        if FONT_COLOUR == 'rainbow': # easter egg yay
            rainbow_text(draw, line, *position, font)
        else:
            draw.text(position, line, fill= FONT_COLOUR, font=font)
        # Advance one text band per line
        adjusted_y += font_size + LINE_SPACING
        position = (adjusted_x,adjusted_y)
def draw_one_phrase_translation_only(draw: ImageDraw,
                                     translated_phrase: str,
                                     position: tuple, bounding_boxes: list,
                                     untranslated_phrase: str) -> ImageDraw:
    """Placeholder for the 'translation_only' draw mode (translation without
    a covering box).

    Not implemented yet: draw_on_image dispatches here for
    draw_mode == 'translation_only' and this silently draws nothing.
    """
    # Draw the bounding box
    pass
def get_rectangle_coordinates(lines: list | str, top_left: tuple | list, font_path, font_size, line_spacing, padding: int = 1) -> list:
    """Return [(x1, y1), (x2, y2)] for a padded box enclosing the rendered text."""
    width = get_max_width(lines, font_path, font_size)
    height = get_max_height(lines, font_size, line_spacing)
    left = top_left[0] - padding
    top = top_left[1] - padding
    right = top_left[0] + width + padding
    bottom = top_left[1] + height + padding
    return [(left, top), (right, bottom)]
def get_max_width(lines: list | str, font_path, font_size) -> int:
    """Widest rendered line, in pixels, for the given font.

    Accepts either a list of lines or a single string (measured as one line).
    """
    font = ImageFont.truetype(font_path, font_size)
    scratch = ImageDraw.Draw(Image.new("RGB", (1, 1)))
    candidates = lines if isinstance(lines, list) else [lines]
    widths = []
    for text in candidates:
        box = scratch.textbbox((0, 0), text, font=font)
        widths.append(box[2] - box[0])
    return max(widths, default=0)
def get_max_height(lines: list, font_size, line_spacing) -> int:
def get_max_height(lines: list | str, font_size, line_spacing) -> int:
"""Get the maximum height of the text lines"""
return len(lines) * (font_size + line_spacing)
no_of_lines = len(lines) if isinstance(lines, list) else 1
return no_of_lines * (font_size + line_spacing)
def get_lines(untranslated_phrase: str, translated_phrase: str) -> list:
    """Build the display lines: translation, optional romanisation, original.

    Japanese source text additionally gets furigana added before romanising.
    """
    if SOURCE_LANG == 'ja':
        untranslated_phrase = add_furigana(untranslated_phrase)
        romanized_phrase = romanize(untranslated_phrase, 'ja')
    else:
        romanized_phrase = romanize(untranslated_phrase, SOURCE_LANG)
    parts = [translated_phrase]
    if TO_ROMANIZE:
        parts.append(romanized_phrase)
    parts.append(untranslated_phrase)
    # Join-then-split so embedded newlines inside any phrase become their
    # own display lines, exactly as before.
    return "\n".join(parts).split('\n')
def adjust_if_intersects(x: int, y: int,
bounding_box: tuple, bounding_boxes: list,
untranslated_phrase: str,
max_width: int, total_height: int) -> tuple:
"""Adjust the y coordinate if the bounding box intersects with any other bounding box"""
"""Adjust the y coordinate every time the bounding box intersects with any previous bounding boxes. OCR returns results from top to bottom so it works."""
y = np.max([y,0])
if len(bounding_boxes) > 0:
for box in bounding_boxes:
@ -136,3 +211,36 @@ def adjust_if_intersects(x: int, y: int,
bounding_boxes.append(adjusted_bounding_box)
return adjusted_bounding_box
def get_font_size(y_1, y_2, font_size_max: int, font_size_min: int) -> int:
    """Derive a font size from the OCR box height.

    Takes two thirds of |y_2 - y_1| (truncated to int) and clamps it to the
    inclusive range [font_size_min, font_size_max].

    Raises:
        ValueError: if font_size_min exceeds font_size_max.
    """
    if font_size_min > font_size_max:
        raise ValueError("Minimum font size cannot be greater than maximum font size")
    candidate = int(abs(2/3*(y_2-y_1)))
    return min(max(candidate, font_size_min), font_size_max)
def rainbow_text(draw, text, x, y, font):
    """Draw each character of `text` in a random bright colour, advancing the
    x position by each glyph's rendered width."""
    cursor_x = x
    for letter in text:
        # Random bright RGB (components in [50, 255)) per glyph
        colour = tuple(np.random.randint(50, 255, 3))
        # Measure this glyph so the next one starts right after it
        bbox = draw.textbbox((cursor_x, y), letter, font=font)
        draw.text((cursor_x, y), letter, fill=colour, font=font)
        cursor_x += bbox[2] - bbox[0]
if __name__ == "__main__":
pass

View File

@ -1,141 +1,325 @@
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
from datetime import datetime, timedelta
from dotenv import load_dotenv
import os , sys, torch, time, ast
import os , sys, torch, time, ast, json, pytz
from werkzeug.exceptions import TooManyRequests
from multiprocessing import Process, Event, Value
load_dotenv()
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from config import device, GEMINI_API_KEY, GROQ_API_KEY
from config import device, GEMINI_API_KEY, GROQ_API_KEY, MAX_TRANSLATE
from logging_config import logger
from groq import Groq
from groq import Groq as Groqq
import google.generativeai as genai
from google.api_core.exceptions import ResourceExhausted
import asyncio
import aiohttp
from functools import wraps
from data import session, Api, Translations
from typing import Optional
class ApiModel():
def __init__(self, model, # model name
rate, # rate of calls per minute
api_key, # api key for the model wrt the site
def __init__(self, model, # model name as defined by the API
site, # site of the model; # to be precise, use the name as defined precisely by the class names in this script, i.e. Groqq and Gemini
api_key: Optional[str] = None, # api key for the model wrt the site
rpmin: Optional[int] = None, # rate of calls per minute
rph: Optional[int] = None, # rate of calls per hour
rpd: Optional[int] = None, # rate of calls per day
rpw: Optional[int] = None, # rate of calls per week
rpmth: Optional[int] = None, # rate of calls per month
rpy: Optional[int] = None # rate of calls per year
):
self.model = model
self.rate = rate
self.api_key = api_key
self.curr_calls = Value('i', 0)
self.time = Value('i', 0)
self.process = None
self.stop_event = Event()
self.site = None
self.model = model
self.rpmin = rpmin
self.rph = rph
self.rpd = rpd
self.rpw = rpw
self.rpmth = rpmth
self.rpy = rpy
self.site = site
self.from_lang = None
self.target_lang = None
self.request = None # request response from API
self.db_table = None
self.session_calls = 0
self._id = None
self._set_db_model_id() if self._get_db_model_id() else self.update_db()
# Create the table if it does not already exist
def __repr__(self):
return f'{self.site} Model: {self.model}; Rate: {self.rate}; Current_Calls: {self.curr_calls.value} calls; Time Passed: {self.time.value} seconds.'
return f'{self.site} Model: {self.model}; Total calls this session: {self.session_calls}; rpmin: {self.rpmin}; rph: {self.rph}; rpd: {self.rpd}; rpw: {self.rpw}; rpmth: {self.rpmth}; rpy: {self.rpy}'
def __str__(self):
return self.model
async def __aenter__(self):
self.session = aiohttp.ClientSession()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.session:
await self.session.close()
def _get_db_model_id(self):
    """Look up this model's primary-key id in the Api table.

    Returns the row id for the matching (model_name, site) pair, or
    None when no such row exists yet.
    """
    row = session.query(Api).filter_by(model_name = self.model, site = self.site).first()
    return row.id if row else None
def _set_db_model_id(self):
    # Cache this model's database row id on the instance; stays None when
    # no matching Api row exists yet (see _get_db_model_id).
    self._id = self._get_db_model_id()
@staticmethod
def _get_time():
    # Current timezone-aware wall-clock time in Australia/Sydney; used for
    # translation timestamps and rate-window arithmetic.
    # NOTE(review): datetime and pytz are not in the imports visible in
    # this chunk — confirm they are imported at the top of the module.
    return datetime.now(tz=pytz.timezone('Australia/Sydney'))
def set_lang(self, from_lang, target_lang):
    """Record the source and destination languages used by translate()."""
    self.from_lang, self.target_lang = from_lang, target_lang
### CHECK MINUTELY API RATES. For working with hourly rates and monthly will need to create another file. Also just unlikely those rates will be hit
async def api_rate_check(self):
    # Background task to manage the rate of calls to the API.
    # Ticks every ~5 seconds; once 60 seconds have accumulated it resets
    # both the elapsed-time counter and the per-minute call counter.
    # self.time and self.curr_calls are multiprocessing Value('i') counters
    # shared with the parent process (see __init__), so the reset is
    # visible across the process boundary.
    while not self.stop_event.is_set():
        start_time = time.monotonic()
        self.time.value += 5
        if self.time.value >= 60:
            self.time.value = 0
            self.curr_calls.value = 0
        elapsed = time.monotonic() - start_time
        # Sleep for exactly 5 seconds minus the elapsed time
        sleep_time = max(0, 5 - elapsed)
        await asyncio.sleep(sleep_time)
def set_db_table(self, db_table):
    # Attach an externally supplied DB table handle to this model.
    # NOTE(review): self.db_table is assigned but never read in the code
    # visible here — confirm it is still needed.
    self.db_table = db_table
def background_task(self):
    # Entry point for the worker process started by start(): runs the
    # async rate-check loop on its own event loop until stop_event is set.
    asyncio.run(self.api_rate_check())
def update_db(self):
    """Insert this model's rate limits into the Api table, or refresh
    them when a row for (model_name, site) already exists.

    On first insert the cached row id (self._id) is also populated.
    """
    limits = dict(rpmin = self.rpmin, rph = self.rph, rpd = self.rpd,
                  rpw = self.rpw, rpmth = self.rpmth, rpy = self.rpy)
    row = session.query(Api).filter_by(model_name = self.model, site = self.site).first()
    if row is None:
        session.add(Api(model_name = self.model, site = self.site, **limits))
        session.commit()
        self._set_db_model_id()
    else:
        # Row exists: overwrite every rate-limit column with current values.
        for field, value in limits.items():
            setattr(row, field, value)
        session.commit()
def start(self):
    # Start the background task.
    # Launch api_rate_check in a separate daemon process so the rolling
    # minute window keeps ticking independently of the main loop.
    # NOTE(review): Process (like Event/Value in __init__) looks like
    # multiprocessing, but the import is not visible in this chunk — confirm.
    self.process = Process(target=self.background_task)
    self.process.daemon = True
    self.process.start()
    logger.info(f"Background process started with PID: {self.process.pid}")
def _db_add_translation(self, text: list | str, translation: list, mismatch = False):
    """Persist one translation call as a Translations row.

    Source texts and results are stored as JSON strings; `mismatch`
    flags calls whose translation count did not match the input count.
    """
    # Distinct local names: the original reused `text`/`translation`
    # for both the JSON strings and the ORM object, which obscured intent.
    source_json = json.dumps(text) if isinstance(text, list) else json.dumps([text])
    translated_json = json.dumps(translation)
    row = Translations(source_texts = source_json, translated_texts = translated_json,
                       model_id = self._id, source_lang = self.from_lang, target_lang = self.target_lang,
                       timestamp = datetime.now(tz=pytz.timezone('Australia/Sydney')),
                       translation_mismatch = mismatch)
    session.add(row)
    session.commit()
def stop(self):
    """Signal the background rate-check process to exit and reap it.

    BUGFIX: the PID was previously logged before the `if self.process`
    guard, so calling stop() on a model that was never start()ed raised
    AttributeError on `None.pid`. The log now lives inside the guard.
    """
    self.stop_event.set()
    if self.process:
        logger.info(f"Stopping background process with PID: {self.process.pid}")
        self.process.join(timeout=5)
        # Escalate if the worker ignored the stop event within 5s.
        if self.process.is_alive():
            self.process.terminate()
@staticmethod
def _single_period_calls_check(max_calls, call_count):
if not max_calls:
return True
if max_calls <= call_count:
return False
else:
return True
def request_func(request):
@wraps(request)
def wrapper(self, text, *args, **kwargs):
if self.curr_calls.value < self.rate:
def _are_rates_good(self):
    """Check every configured rate window against the stored call history.

    Counts the Translations rows recorded for this model within the last
    minute/hour/day/week/month/year and compares each count against the
    matching rp* limit (falsy limits are unlimited). Month and year use
    the same fixed 30/365-day approximations as before.

    Returns True when every window is under its limit; otherwise logs a
    warning with all six counts and returns False.

    Refactor: the original repeated the same session query six times;
    the windows are now data-driven.
    """
    curr_time = self._get_time()
    # (window name, configured limit, window length)
    windows = [('minute', self.rpmin, timedelta(minutes=1)),
               ('hour',   self.rph,   timedelta(hours=1)),
               ('day',    self.rpd,   timedelta(days=1)),
               ('week',   self.rpw,   timedelta(weeks=1)),
               ('month',  self.rpmth, timedelta(days=30)),
               ('year',   self.rpy,   timedelta(days=365))]
    counts = {}
    for name, limit, span in windows:
        counts[name] = session.query(Translations).join(Api). \
            filter(Api.id==self._id,
                   Translations.timestamp >= curr_time - span
                   ).count()
    if all(self._single_period_calls_check(limit, counts[name]) for name, limit, _ in windows):
        return True
    logger.warning(f"Rate limit reached for {self.model} from {self.site}. Current calls: {counts['minute']} in the last minute; {counts['hour']} in the last hour; {counts['day']} in the last day; {counts['week']} in the last week; {counts['month']} in the last month; {counts['year']} in the last year.")
    return False
# async def request_func(request):
# @wraps(request)
# async def wrapper(self, text, *args, **kwargs):
# if await self._are_rates_good():
# try:
response = request(self, text, *args, **kwargs)
self.curr_calls.value += 1
return response
# self.session_calls += 1
# response = await request(self, text, *args, **kwargs)
# return response
# except Exception as e:
# logger.error(f"Error with model {self.model} from {self.site}. Error: {e}")
else:
logger.error(f"Rate limit reached for this model. Please wait for the rate to reset in {60 - self.time} seconds.")
raise TooManyRequests('Rate limit reached for this model.')
return wrapper
# else:
# logger.error(f"Rate limit reached for this model.")
# raise TooManyRequests('Rate limit reached for this model.')
# return wrapper
@request_func
def translate(self, request_fn, texts_to_translate):
# @request_func
async def translate(self, texts_to_translate, store = False):
if isinstance(texts_to_translate, str):
texts_to_translate = [texts_to_translate]
if len(texts_to_translate) == 0:
return []
prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {self.from_lang} into {self.target_lang} and output as a Python list ensuring proper escaping of characters: {texts_to_translate}"
response = request_fn(self, prompt)
return ast.literal_eval(response.strip())
#prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {self.from_lang} into {self.target_lang} and output as a Python list ensuring proper escaping of characters and ensuring the length of the list given is exactly equal to the length of the list you provide. Do not output in any other language other than the specified target language: {texts_to_translate}"
prompt = f"""INSTRUCTIONS:
- Provide ONE and ONLY ONE translation to each text provided in the JSON array given.
- The translations must preserve the original order.
- Each translation must be from the Source language to the Target language
- Source language: {self.from_lang}
- Target language: {self.target_lang}
- Texts are provided in JSON array syntax.
- Respond using ONLY valid JSON array syntax.
- Do not include explanations or additional text
- Escape special characters properly
class Groqq(ApiModel):
def __init__(self, model, rate, api_key = GROQ_API_KEY):
super().__init__(model, rate, api_key)
self.site = "Groq"
Input texts:
{texts_to_translate}
def request(self, content):
client = Groq()
chat_completion = client.chat.completions.create(
messages=[
{
"role": "user",
"content": content,
Expected format:
["translation1", "translation2", ...]
Translation:"""
response = await self._request(prompt)
response_list = ast.literal_eval(response.strip())
logger.debug(repr(self))
logger.info(f'{self.model} translated texts from: {texts_to_translate} to {response_list}.')
if len(response_list) != len(texts_to_translate) and len(texts_to_translate) <= MAX_TRANSLATE:
logger.error(f"{self.model} model failed to translate all the texts. Number of translations to make: {len(texts_to_translate)}; Number of translated texts: {len(response_list)}.")
if store:
self._db_add_translation(texts_to_translate, response_list, mismatch=True)
else:
if store:
self._db_add_translation(texts_to_translate, response_list)
print(response_list)
return response_list
class Groq(ApiModel):
def __init__(self,
    model, # model name as defined by the API
    api_key = GROQ_API_KEY, # api key for the model wrt the site
    **kwargs): # remaining rate-limit keyword args forwarded to ApiModel
    """Groq-hosted chat model wrapper around ApiModel (site = 'Groq')."""
    super().__init__(model,
        api_key = api_key,
        site = 'Groq', **kwargs)
    # SDK client from the `groq` package (aliased Groqq to avoid clashing
    # with this class). NOTE(review): constructed without an explicit
    # api_key — presumably reads it from the environment; confirm.
    self.client = Groqq()
async def _request(self, content: str) -> str:
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.groq.com/openai/v1/chat/completions",
headers={
"Authorization": f"Bearer {GROQ_API_KEY}",
"Content-Type": "application/json"
},
json={
"messages": [{"role": "user", "content": content}],
"model": self.model
}
],
model=self.model
)
return chat_completion.choices[0].message.content
) as response:
response_json = await response.json()
return response_json["choices"][0]["message"]["content"]
# https://console.groq.com/settings/limits for limits
# def request(self, content):
# chat_completion = self.client.chat.completions.create(
# messages=[
# {
# "role": "user",
# "content": content,
# }
# ],
# model=self.model
# )
# return chat_completion.choices[0].message.content
def translate(self, texts_to_translate):
return super().translate(Groqq.request, texts_to_translate)
# async def translate(self, texts_to_translate):
# return super().translate(self.request, texts_to_translate)
class Gemini(ApiModel):
def __init__(self, model, rate, api_key = GEMINI_API_KEY):
super().__init__(model, rate, api_key)
self.site = "Gemini"
def __init__(self,
    model, # model name as defined by the API
    api_key = GEMINI_API_KEY, # api key for the model wrt the site
    **kwargs): # remaining rate-limit keyword args forwarded to ApiModel
    """Google Gemini model wrapper around ApiModel.

    NOTE(review): site is stored as 'Google' here while the rate-config
    JSON keys these models under 'Gemini'; DB lookups filter by site, so
    confirm the intended label is consistent.
    """
    super().__init__(model,
        api_key = api_key,
        site = 'Google',
        **kwargs)
def request(self, content):
    # Synchronous one-shot generation via the google.generativeai SDK.
    # All four harm categories are set to BLOCK_NONE so arbitrary
    # on-screen text is not refused by the safety filter.
    genai.configure(api_key=self.api_key)
    safety_settings = {
        "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
        "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
        "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
    response = genai.GenerativeModel(self.model).generate_content(content, safety_settings=safety_settings)
    # Trimmed text of the (single) generated candidate.
    return response.text.strip()
# def request(self, content):
# genai.configure(api_key=self.api_key)
# safety_settings = {
# "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
# "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
# "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
# "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
# try:
# response = genai.GenerativeModel(self.model).generate_content(content, safety_settings=safety_settings)
# except ResourceExhausted as e:
# logger.error(f"Rate limited with {self.model}. Error: {e}")
# raise ResourceExhausted("Rate limited.")
# return response.text.strip()
def translate(self, texts_to_translate):
return super().translate(Gemini.request, texts_to_translate)
async def _request(self, content):
    """Async call to the Gemini REST generateContent endpoint.

    Returns the first candidate's text.

    BUGFIX 1: the safetySettings array previously held ONE dict literal
    with four duplicate "category"/"threshold" key pairs — Python keeps
    only the last pair, so only DANGEROUS_CONTENT was set to BLOCK_NONE
    and the other three filters stayed at their defaults. Each category
    now gets its own object, matching the REST schema.
    BUGFIX 2: the URL hard-coded gemini-1.5-flash, ignoring self.model
    (the Groq counterpart posts self.model); it now uses self.model.
    """
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent?key={self.api_key}",
            headers={
                "Content-Type": "application/json"
            },
            json={
                "contents": [{"parts": [{"text": content}]}],
                "safetySettings": [
                    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}
                ]
            }
        ) as response:
            response_json = await response.json()
            return response_json['candidates'][0]['content']['parts'][0]['text']
# async def translate(self, texts_to_translate):
# return super().translate(self.request, texts_to_translate)
###################################################################################################
### LOCAL LLM TRANSLATION
class TranslationDataset(Dataset):
def __init__(self, texts: List[str], tokenizer, max_length: int = 512):

View File

@ -21,7 +21,6 @@ def _paddle_init(paddle_lang, use_angle_cls=False, use_GPU=True, **kwargs):
def _paddle_ocr(ocr, image) -> list:
### return a list containing the bounding box, text and confidence of the detected text
result = ocr.ocr(image, cls=False)[0]
if not isinstance(result, list):
@ -32,28 +31,29 @@ def _paddle_ocr(ocr, image) -> list:
# EasyOCR has support for many languages
def _easy_init(easy_languages: list, use_GPU=True, **kwargs):
langs = []
for lang in easy_languages:
langs.append(standardize_lang(lang)['easyocr_lang'])
return easyocr.Reader(langs, gpu=use_GPU, **kwargs)
return easyocr.Reader(easy_languages, gpu=use_GPU, **kwargs)
def _easy_ocr(ocr,image) -> list:
    # Run EasyOCR on the image; readtext returns a list of detections —
    # callers here index entry[1] for the text, so each entry is
    # (bbox, text, confidence) per EasyOCR's API.
    return ocr.readtext(image)
# RapidOCR mostly for mandarin and some other asian languages
# default only supports chinese and english
def _rapid_init(use_GPU=True, **kwargs):
    # Construct a RapidOCR engine (per the comment above: default build
    # supports Chinese and English only).
    return RapidOCR(use_gpu=use_GPU, **kwargs)
def _rapid_ocr(ocr, image) -> list:
return ocr(image)
return ocr(image)[0]
### Initialize the OCR model
def init_OCR(model='paddle', easy_languages: Optional[list] = ['ch_sim','en'], paddle_lang: Optional[str] = 'ch', use_GPU=True, **kwargs):
def init_OCR(model='paddle', easy_languages: Optional[list] = ['ch_sim','en'], paddle_lang: Optional[str] = 'ch_sim', use_GPU=True):
if model == 'paddle':
paddle_lang = standardize_lang(paddle_lang)['paddleocr_lang']
return _paddle_init(paddle_lang=paddle_lang, use_GPU=use_GPU)
elif model == 'easy':
return _easy_init(easy_languages=easy_languages, use_GPU=use_GPU)
langs = []
for lang in easy_languages:
langs.append(standardize_lang(lang)['easyocr_lang'])
return _easy_init(easy_languages=langs, use_GPU=use_GPU)
elif model == 'rapid':
return _rapid_init(use_GPU=use_GPU)
@ -82,10 +82,11 @@ def _id_filtered(ocr, image, lang) -> list:
return results_no_eng
# ch_sim, ch_tra, ja, ko, en
# ch_sim, ch_tra, ja, ko, en input
def _id_lang(ocr, image, lang) -> list:
result = _identify(ocr, image)
lang = standardize_lang(lang)['id_model_lang']
print(result)
try:
filtered = [entry for entry in result if contains_lang(entry[1], lang)]
except:

View File

@ -1,17 +1,16 @@
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig, AutoModelForCausalLM
import google.generativeai as genai
import torch, os, sys, ast, json
import torch, os, sys, ast, json, asyncio, batching, random
from typing import List, Optional, Set
from utils import standardize_lang
from functools import wraps
import random
import batching
from batching import generate_text, Gemini, Groq
from batching import generate_text, Gemini, Groq, ApiModel
from logging_config import logger
from multiprocessing import Process,Event
from asyncio import Task
# root dir
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_API_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models
from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_API_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models, API_MODELS_FILEPATH
##############################
# translation decorator
@ -32,27 +31,66 @@ def translate(translation_func):
def init_API_LLM(from_lang, target_lang):
from_lang = standardize_lang(from_lang)['translation_model_lang']
target_lang = standardize_lang(target_lang)['translation_model_lang']
with open('api_models.json', 'r') as f:
with open(API_MODELS_FILEPATH, 'r') as f:
models_and_rates = json.load(f)
models = []
for class_type, class_models in models_and_rates.items():
cls = getattr(batching, class_type)
instantiated_objects = [ cls(model, rate) for model, rate in class_models.items()]
instantiated_objects = [ cls(model = model, **rates) for model, rates in class_models.items()]
models.extend(instantiated_objects)
for model in models:
model.start()
model.update_db()
model.set_lang(from_lang, target_lang)
return models
def translate_API_LLM(text, models):
random.shuffle(models)
for model in models:
async def translate_API_LLM(texts_to_translate: List[str],
models: List[ApiModel],
call_size: int = 2,
stagger_delay: int = 2) -> List[str]:
async def try_translate(model: ApiModel) -> Optional[List[str]]:
try:
return model.translate(text)
except:
continue
result = await model.translate(texts_to_translate, store=True)
logger.debug(f'Try_translate result: {result}')
return result
except Exception as e:
logger.error(f"Translation failed for {model.model} from {model.site}: {e}")
return None
random.shuffle(models)
groups = [models[i:i+call_size] for i in range(0, len(models), call_size)]
for group in groups:
tasks = set(asyncio.create_task(try_translate(model)) for model in group)
while tasks:
done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_COMPLETED
)
logger.debug(f"Tasks done: {done}")
logger.debug(f"Tasks remaining: {pending}")
for task in done:
result = await task
logger.debug(f'Result: {result}')
if result is not None:
# Cancel remaining tasks
for t in pending:
t.cancel()
return result
logger.error("All models have failed to translate the text.")
raise TypeError("Models have likely all outputted garbage translations or rate limited.")
# def translate_API_LLM(text, models):
# random.shuffle(models)
# logger.debug(f"All Models Available: {models}")
# for model in models:
# logger.info(f"Attempting translation with model {model}.")
# try:
# translation = model.translate(text)
# logger.debug(f"Translation obtained: {translation}")
# if translation or translation == []:
# return translation
# except Exception as e:
# logger.error(f"Error with model {repr(model)}. Error: {e}")
# continue
# logger.error("All models have failed to translate the text.")
# raise TypeError("Models have likely all outputted garbage translations or rate limited.")
###############################
# Best model by far. Aya-23-8B. Gemma is relatively good. If I get the time to quantize either gemma or aya those will be good to use. llama3.2 is really good as well.

View File

@ -4,7 +4,9 @@ import pyscreenshot as ImageGrab # wayland tings not sure if it will work on oth
import mss, io, os
from PIL import Image
import jaconv, MeCab, unidic, pykakasi
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
# for creating furigana
mecab = MeCab.Tagger('-d "{}"'.format(unidic.DICDIR))
uroman = ur.Uroman()
@ -95,7 +97,7 @@ def contains_katakana(text):
# use kakasi to romanize japanese text
def romanize(text, lang):
if lang == 'zh':
if lang in ['zh','ch_sim','ch_tra']:
return ' '.join([ py[0] for py in pinyin(text, heteronym=True)])
if lang == 'ja':
return kks.convert(text)[0]['hepburn']
@ -131,13 +133,13 @@ def standardize_lang(lang):
id_model_lang = 'zh'
elif lang == 'ja':
easyocr_lang = 'ja'
paddleocr_lang = 'ja'
paddleocr_lang = 'japan'
rapidocr_lang = 'ja'
translation_model_lang = 'ja'
id_model_lang = 'ja'
elif lang == 'ko':
easyocr_lang = 'korean'
paddleocr_lang = 'ko'
paddleocr_lang = 'korean'
rapidocr_lang = 'ko'
translation_model_lang = 'ko'
id_model_lang = 'ko'
@ -165,6 +167,23 @@ def which_ocr_lang(model):
else:
raise ValueError("Invalid OCR model. Please use one of 'easy', 'paddle', or 'rapid'.")
def similar_tfidf(list1, list2, threshold) -> bool:
    """Return True when the TF-IDF cosine similarity of the two text
    lists exceeds `threshold`.

    Each list is reduced to its mean TF-IDF vector (fitted over both
    lists together) and the vectors are compared with cosine similarity.

    FIX: the annotation said `-> float` but the function has always
    returned the thresholded comparison; it now consistently returns a
    bool (False — previously 0.0, which is ==False — for empty input).
    """
    if not list1 or not list2:
        return False
    vectorizer = TfidfVectorizer()
    all_texts = list1 + list2
    tfidf_matrix = vectorizer.fit_transform(all_texts)
    # Mean TF-IDF vector of each list, treated as one pseudo-document.
    vec1 = np.mean(tfidf_matrix[:len(list1)].toarray(), axis=0).reshape(1, -1)
    vec2 = np.mean(tfidf_matrix[len(list1):].toarray(), axis=0).reshape(1, -1)
    return float(cosine_similarity(vec1, vec2)[0, 0]) > threshold
if __name__ == "__main__":
# Example usage

View File

@ -48,8 +48,8 @@ def setup_logger(
# Create a formatter and set it for both handlers
formatter = logging.Formatter(
'%(asctime)s - %(name)s - [%(levelname)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
'%(asctime)s.%(msecs)03d - %(name)s - [%(levelname)s] %(message)s',
datefmt="%Y-%m-%d %H:%M:%S"
)
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)
@ -65,3 +65,4 @@ def setup_logger(
return None
logger = setup_logger('on_screen_translator', log_file='translate.log', level=logging.DEBUG)

100
main.py Normal file
View File

@ -0,0 +1,100 @@
###################################################################################
##### IMPORT LIBRARIES #####
import os, time, sys, threading, subprocess
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
from utils import printsc, convert_image_to_bytes, bytes_to_image, similar_tfidf
from ocr import get_words, init_OCR, id_keep_source_lang
from data import Base, engine, create_tables
from draw import modify_image_bytes
import config, asyncio
from config import SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL, IMAGE_CHANGE_THRESHOLD
from logging_config import logger
import web_app
import view_buffer_app
###################################################################################
async def main():
    """Capture → OCR → translate → render loop.

    Runs forever: grabs the configured screen REGION, OCRs it keeping
    only source-language phrases, and — when the detected words differ
    enough from the previous frame — translates them via the API models
    and publishes the redrawn image to web_app.latest_image.

    BUGFIXES: `asyncio.sleep(INTERVAL)` was called without `await`, so
    the coroutine was created and discarded and the loop never paused;
    the 30s back-off used blocking `time.sleep` inside the coroutine.
    Both now use `await asyncio.sleep(...)`.
    """
    ###################################################################################
    # Initialisation
    ##### Create the database if not present #####
    create_tables()
    ##### Initialize the OCR #####
    # Source, target and English: screens often mix these languages.
    OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
    ocr = init_OCR(model=OCR_MODEL, paddle_lang= SOURCE_LANG, easy_languages = OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
    ##### Initialize the translation #####
    # model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
    models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
    ###################################################################################
    runs = 0
    # label, app = view_buffer_app.create_viewer()
    while True:
        logger.debug("Capturing screen")
        untranslated_image = printsc(REGION)
        logger.debug(f"Screen Captured. Proceeding to perform OCR.")
        byte_image = convert_image_to_bytes(untranslated_image)
        ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG) # keep only phrases containing the source language
        logger.debug(f"OCR completed. Detected {len(ocr_output)} phrases.")
        if runs == 0:
            logger.info('Initial run')
            prev_words = set()
        else:
            logger.info(f'Run number: {runs}.')
        runs += 1
        curr_words = set(get_words(ocr_output))
        logger.debug(f'Current words: {curr_words} Previous words: {prev_words}')
        ### Only translate when the OCR'd words changed enough -> avoids
        ### re-translating a static/refreshing screen and saves GPU power.
        if not similar_tfidf(list(curr_words), list(prev_words), threshold = IMAGE_CHANGE_THRESHOLD) and prev_words != curr_words:
            logger.info('Beginning Translation')
            to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
            # translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
            try:
                translation = await translate_API_LLM(to_translate, models, call_size = 3)
            except TypeError as e:
                # All models failed or are rate-limited: back off without
                # blocking the event loop (was time.sleep(30)).
                logger.error(f"Failed to translate using API models. Error: {e}. Sleeping for 30 seconds.")
                await asyncio.sleep(30)
                continue
            logger.debug('Translation complete. Modifying image.')
            translated_image = modify_image_bytes(byte_image, ocr_output, translation)
            # view_buffer_app.show_buffer_image(translated_image, label)
            web_app.latest_image = bytes_to_image(translated_image)
            logger.debug("Image modified. Saving image.")
            prev_words = curr_words
        else:
            logger.info("Skipping translation. No significant change in the screen detected.")
        logger.debug("Continuing to next iteration.")
        # Was: asyncio.sleep(INTERVAL) — un-awaited, so it never slept.
        await asyncio.sleep(INTERVAL)
################### TODO ##################
# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
# 5. Maybe refreshing issue of flask app. Also get webpage to update only if the image changes.
if __name__ == "__main__":
    # subprocess.Popen(['feh','--auto-reload', '/home/James/Pictures/translated.png'])
    # asyncio.run(main())
    # Start the image updating thread
    # Dump the effective configuration: every non-callable, non-dunder
    # attribute of the config module.
    logger.info('Configuration:')
    for i in dir(config):
        if not callable(getattr(config, i)) and not i.startswith("__"):
            logger.info(f'{i}: {getattr(config, i)}')
    # Run the async capture/translate loop on its own event loop inside a
    # daemon thread, so Flask can own the main thread below.
    threading.Thread(target=asyncio.run, args=(main(),), daemon=True).start()
    # Start the Flask web server
    web_app.app.run(host='0.0.0.0', port=5000, debug=False)

115
qtapp.py
View File

@ -1,115 +0,0 @@
###################################################################################
##### IMPORT LIBRARIES #####
import os, time, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
from utils import printsc, convert_image_to_bytes, bytes_to_image
from ocr import get_words, init_OCR, id_keep_source_lang
from logging_config import logger
from draw import modify_image_bytes
from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL, FONT_SIZE, FONT_FILE, FONT_COLOUR
from create_overlay import app, overlay
from typing import Optional, List
###################################################################################
from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QThread, Signal
from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont,
QColor, QIcon, QImage, QPixmap)
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
QLabel, QSystemTrayIcon, QMenu)
from dataclasses import dataclass
class TranslationThread(QThread):
    """Background Qt thread: capture screen -> OCR -> translate -> emit.

    Emits translation_ready(ocr_output, translation) whenever the set of
    OCR'd words changes between iterations.
    """
    translation_ready = Signal(list, list) # Signal to send translation results
    start_capture = Signal()  # fired just before each screen grab
    end_capture = Signal()    # fired right after the grab completes
    # NOTE(review): screen_capture is declared but never emitted in this
    # class — confirm it is still connected/used by the overlay.
    screen_capture = Signal(int, int, int, int)
    def __init__(self, ocr, models, source_lang, target_lang, interval):
        # ocr: initialized OCR engine; models: list from init_API_LLM;
        # interval: seconds to sleep between capture iterations.
        super().__init__()
        self.ocr = ocr
        self.models = models
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.interval = interval
        self.running = True      # cleared by stop() to end run()
        self.prev_words = set()  # OCR word set from the previous frame
        self.runs = 0            # iteration counter (for logging)
    def run(self):
        # Main loop; exits once stop() clears self.running.
        while self.running:
            self.start_capture.emit()
            untranslated_image = printsc(REGION)
            self.end_capture.emit()
            byte_image = convert_image_to_bytes(untranslated_image)
            # Keep only phrases containing the source language.
            ocr_output = id_keep_source_lang(self.ocr, byte_image, self.source_lang)
            if self.runs == 0:
                logger.info('Initial run')
            else:
                logger.info(f'Run number: {self.runs}.')
            self.runs += 1
            curr_words = set(get_words(ocr_output))
            # Only translate when the detected word set changed.
            if self.prev_words != curr_words:
                logger.info('Translating')
                to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
                translation = translate_API_LLM(to_translate, self.models)
                logger.info(f'Translation from {to_translate} to\n {translation}')
                # Emit the translation results
                modify_image_bytes(byte_image, ocr_output, translation)
                self.translation_ready.emit(ocr_output, translation)
                self.prev_words = curr_words
            else:
                logger.info("No new words to translate. Output will not refresh.")
            logger.info(f'Sleeping for {self.interval} seconds')
            time.sleep(self.interval)
    def stop(self):
        # Ask run() to exit after its current iteration.
        self.running = False
def main():
    """Wire up OCR, the translation thread and the Qt overlay, then run
    the Qt event loop; returns the application's exit code."""
    # Initialize OCR
    OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
    ocr = init_OCR(model=OCR_MODEL, easy_languages=OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
    # Initialize translation
    models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
    # Create and start translation thread
    translation_thread = TranslationThread(
        ocr=ocr,
        models=models,
        source_lang=SOURCE_LANG,
        target_lang=TARGET_LANG,
        interval=INTERVAL
    )
    # Connect translation results to overlay update.
    # NOTE(review): prepare/restore presumably hide the overlay during
    # capture so it is not re-OCR'd — confirm in create_overlay.
    translation_thread.start_capture.connect(overlay.prepare_for_capture)
    translation_thread.end_capture.connect(overlay.restore_after_capture)
    translation_thread.translation_ready.connect(overlay.update_translation)
    translation_thread.screen_capture.connect(overlay.capture_behind)
    # Start the translation thread
    translation_thread.start()
    # Start Qt event loop (blocks until the app quits)
    result = app.exec()
    # Cleanup: stop the worker and wait for its run() to return.
    translation_thread.stop()
    translation_thread.wait()
    return result
if __name__ == "__main__":
sys.exit(main())

View File

@ -17,7 +17,7 @@
setInterval(function () {
document.getElementById("live-image").src =
"/image?" + new Date().getTime();
}, 3500); // Update every 2 seconds
}, 2500); // Update every 2.5 seconds. Beware that if the image fails to reload on time, the browser will continuously refresh without being able to display the images.
</script>
</body>
</html>

52
view_buffer_app.py Normal file
View File

@ -0,0 +1,52 @@
#### Same thread as main.py so it will be relatively unresponsive. Just for use locally for a faster image display from buffer.
from PySide6.QtWidgets import QApplication, QLabel
from PySide6.QtCore import Qt
from PySide6.QtGui import QImage, QPixmap
import sys
def create_viewer():
    """Build and show a bare QLabel image viewer.

    Reuses the already-running QApplication when one exists (Qt allows
    only one per process). Returns the (label, app) pair.
    """
    qt_app = QApplication.instance() or QApplication(sys.argv)
    viewer = QLabel()
    viewer.setWindowTitle("Image Viewer")
    viewer.setMinimumSize(640, 480)
    viewer.setMouseTracking(True)   # mouse events kept for future interactivity
    viewer.setScaledContents(True)  # let the pixmap scale with the window
    viewer.show()
    return viewer, qt_app
def show_buffer_image(buffer, label):
    """Render raw image bytes from `buffer` onto the given QLabel.

    Parameters:
        buffer: bytes — raw image data in memory
        label: QLabel — widget to display the image on

    The pixmap is fitted to the label's current size (aspect ratio kept,
    smooth filtering), then Qt events are pumped so the update paints
    immediately — this module runs on the caller's thread.
    """
    pixmap = QPixmap.fromImage(QImage.fromData(buffer))
    fitted = pixmap.scaled(
        label.size(),
        Qt.KeepAspectRatio,
        Qt.SmoothTransformation
    )
    label.setPixmap(fitted)
    # Process Qt events to update the display
    QApplication.processEvents()

View File

@ -1,12 +1,12 @@
from flask import Flask, Response, render_template
import threading
import io
import app
app = Flask(__name__)
latest_image = None
# Global variable to hold the current image
def curr_image():
return app.latest_image
return latest_image
@app.route('/')
def index():
@ -29,7 +29,8 @@ def stream_image():
if __name__ == '__main__':
# Start the image updating thread
threading.Thread(target=app.main, daemon=True).start()
import main, asyncio
threading.Thread(target=asyncio.run, args=(main(),), daemon=True).start()
# Start the Flask web server
app.run(host='0.0.0.0', port=5000, debug=True)