Added support for Groqq API models. Created QT6 overlay app.
This commit is contained in:
parent
17e7f6526f
commit
499a2c3972
3
.gitignore
vendored
3
.gitignore
vendored
@ -4,4 +4,5 @@ translate/
|
|||||||
__pycache__/
|
__pycache__/
|
||||||
.*
|
.*
|
||||||
test.py
|
test.py
|
||||||
|
notebooks/
|
||||||
|
qttest.py
|
||||||
@ -1,4 +1,4 @@
|
|||||||
## Debugging Issues
|
## Debugging Issues
|
||||||
|
|
||||||
1. CUDNN Version mismatch when using PaddleOCR. Check if LD_LIBRARY_PATH is correctly set to the directory containing the cudnn.so file. If using a local installation, it could help to just remove nvidia-pip cudnn from python environment.
|
1. CUDNN Version mismatch when using PaddleOCR. Check if LD_LIBRARY_PATH is correctly set to the directory containing the cudnn.so file. If using a local installation, it could help to just remove nvidia-cudnn-cn12 from python environment.
|
||||||
2. Segmentation fault when using PaddleOCR, EasyOCR or RapidOCR. Ensure the only cv2 library is the opencv-contrib-python library. Check out https://pypi.org/project/opencv-python-headless/ for more info.
|
2. Segmentation fault when using PaddleOCR, EasyOCR or RapidOCR. Ensure the only cv2 library is the opencv-contrib-python library. Check out https://pypi.org/project/opencv-python-headless/ for more info.
|
||||||
|
|||||||
13
api_models.json
Normal file
13
api_models.json
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"Gemini": {
|
||||||
|
"gemini-1.5-pro": 2,
|
||||||
|
"gemini-1.5-flash": 15,
|
||||||
|
"gemini-1.5-flash-8b": 8,
|
||||||
|
"gemini-1.0-pro": 15
|
||||||
|
},
|
||||||
|
"Groqq": {
|
||||||
|
"llama-3.2-90b-text-preview": 30,
|
||||||
|
"llama3-70b-8192": 30,
|
||||||
|
"mixtral-8x7b-32768": 30
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -9,9 +9,11 @@ from utils import printsc, convert_image_to_bytes, bytes_to_image
|
|||||||
from ocr import get_words, init_OCR, id_keep_source_lang
|
from ocr import get_words, init_OCR, id_keep_source_lang
|
||||||
from logging_config import logger
|
from logging_config import logger
|
||||||
from draw import modify_image_bytes
|
from draw import modify_image_bytes
|
||||||
from config import SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL
|
from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL
|
||||||
###################################################################################
|
###################################################################################
|
||||||
|
|
||||||
|
ADD_OVERLAY = False
|
||||||
|
|
||||||
latest_image = None
|
latest_image = None
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@ -23,11 +25,21 @@ def main():
|
|||||||
|
|
||||||
##### Initialize the translation #####
|
##### Initialize the translation #####
|
||||||
# model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
|
# model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
|
||||||
models = init_API_LLM(TRANSLATION_MODEL)
|
models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
|
||||||
###################################################################################
|
###################################################################################
|
||||||
runs = 0
|
runs = 0
|
||||||
|
app.exec()
|
||||||
while True:
|
while True:
|
||||||
|
if ADD_OVERLAY:
|
||||||
|
overlay.clear_all_text()
|
||||||
|
|
||||||
untranslated_image = printsc(REGION)
|
untranslated_image = printsc(REGION)
|
||||||
|
|
||||||
|
if ADD_OVERLAY:
|
||||||
|
overlay.text_entries = overlay.text_entries_copy
|
||||||
|
overlay.update()
|
||||||
|
overlay.text_entries.clear()
|
||||||
|
|
||||||
byte_image = convert_image_to_bytes(untranslated_image)
|
byte_image = convert_image_to_bytes(untranslated_image)
|
||||||
ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG) # keep only phrases containing the source language
|
ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG) # keep only phrases containing the source language
|
||||||
|
|
||||||
@ -46,7 +58,7 @@ def main():
|
|||||||
|
|
||||||
to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
|
to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
|
||||||
# translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
|
# translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
|
||||||
translation = translate_API_LLM(to_translate, TRANSLATION_MODEL, models, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
|
translation = translate_API_LLM(to_translate, models)
|
||||||
logger.info(f'Translation from {to_translate} to\n {translation}')
|
logger.info(f'Translation from {to_translate} to\n {translation}')
|
||||||
translated_image = modify_image_bytes(byte_image, ocr_output, translation)
|
translated_image = modify_image_bytes(byte_image, ocr_output, translation)
|
||||||
latest_image = bytes_to_image(translated_image)
|
latest_image = bytes_to_image(translated_image)
|
||||||
@ -58,6 +70,8 @@ def main():
|
|||||||
|
|
||||||
logger.info(f'Sleeping for {INTERVAL} seconds')
|
logger.info(f'Sleeping for {INTERVAL} seconds')
|
||||||
time.sleep(INTERVAL)
|
time.sleep(INTERVAL)
|
||||||
|
# if ADD_OVERLAY:
|
||||||
|
# sys.exit(app.exec())
|
||||||
|
|
||||||
################### TODO ##################
|
################### TODO ##################
|
||||||
# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
|
# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
|
||||||
@ -65,4 +79,5 @@ def main():
|
|||||||
# Create a way for it to just replace the text and provide only the translation on-screen. Qt6
|
# Create a way for it to just replace the text and provide only the translation on-screen. Qt6
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
sys.exit(main())
|
||||||
|
|
||||||
23
config.py
23
config.py
@ -6,6 +6,7 @@ load_dotenv(override=True)
|
|||||||
### EDIT THESE VARIABLES ###
|
### EDIT THESE VARIABLES ###
|
||||||
|
|
||||||
### available languages: 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
|
### available languages: 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
|
||||||
|
|
||||||
INTERVAL = int(os.getenv('INTERVAL'))
|
INTERVAL = int(os.getenv('INTERVAL'))
|
||||||
|
|
||||||
### OCR
|
### OCR
|
||||||
@ -13,24 +14,34 @@ OCR_MODEL = os.getenv('OCR_MODEL', 'easy') # 'easy', 'paddle', 'rapid' ### easy
|
|||||||
OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True'))
|
OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True'))
|
||||||
|
|
||||||
### Drawing/Overlay Config
|
### Drawing/Overlay Config
|
||||||
|
ADD_OVERLAY = ast.literal_eval(os.getenv('ADD_OVERLAY', 'True'))
|
||||||
|
FILL_COLOUR = os.getenv('FILL_COLOUR', 'white')
|
||||||
FONT_FILE = os.getenv('FONT_FILE')
|
FONT_FILE = os.getenv('FONT_FILE')
|
||||||
FONT_SIZE = int(os.getenv('FONT_SIZE', 16))
|
FONT_SIZE = int(os.getenv('FONT_SIZE', 16))
|
||||||
LINE_SPACING = int(os.getenv('LINE_SPACING', 3))
|
LINE_SPACING = int(os.getenv('LINE_SPACING', 3))
|
||||||
REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)'))
|
REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)'))
|
||||||
TEXT_COLOR = os.getenv('TEXT_COLOR', "#ff0000")
|
FONT_COLOUR = os.getenv('FONT_COLOUR', "#ff0000")
|
||||||
TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True'))
|
TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True'))
|
||||||
|
|
||||||
|
|
||||||
|
# API KEYS https://github.com/cheahjs/free-llm-api-resources?tab=readme-ov-file
|
||||||
|
GEMINI_API_KEY = os.getenv('GEMINI_KEY')
|
||||||
|
GROQ_API_KEY = os.getenv('GROQ_API_KEY') #
|
||||||
|
# MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY') # https://console.mistral.ai/api-keys/ slow asf
|
||||||
|
|
||||||
### Translation
|
### Translation
|
||||||
BATCH_SIZE = int(os.getenv('BATCH_SIZE', 6))
|
|
||||||
GEMINI_KEY = os.getenv('GEMINI_KEY')
|
|
||||||
LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False'))
|
|
||||||
MAX_INPUT_TOKENS = int(os.getenv('MAX_INPUT_TOKENS', 512))
|
|
||||||
MAX_OUTPUT_TOKENS = int(os.getenv('MAX_OUTPUT_TOKENS', 512))
|
|
||||||
MAX_TRANSLATE = int(os.getenv('MAX_TRANSLATION', 200))
|
MAX_TRANSLATE = int(os.getenv('MAX_TRANSLATION', 200))
|
||||||
SOURCE_LANG = os.getenv('SOURCE_LANG', 'ja')
|
SOURCE_LANG = os.getenv('SOURCE_LANG', 'ja')
|
||||||
TARGET_LANG = os.getenv('TARGET_LANG', 'en')
|
TARGET_LANG = os.getenv('TARGET_LANG', 'en')
|
||||||
|
|
||||||
|
|
||||||
|
### Local Translation
|
||||||
TRANSLATION_MODEL= os.environ['TRANSLATION_MODEL'] # 'opus' or 'm2m' # opus is a lot more lightweight
|
TRANSLATION_MODEL= os.environ['TRANSLATION_MODEL'] # 'opus' or 'm2m' # opus is a lot more lightweight
|
||||||
TRANSLATION_USE_GPU = ast.literal_eval(os.getenv('TRANSLATION_USE_GPU', 'True'))
|
TRANSLATION_USE_GPU = ast.literal_eval(os.getenv('TRANSLATION_USE_GPU', 'True'))
|
||||||
|
MAX_INPUT_TOKENS = int(os.getenv('MAX_INPUT_TOKENS', 512))
|
||||||
|
MAX_OUTPUT_TOKENS = int(os.getenv('MAX_OUTPUT_TOKENS', 512))
|
||||||
|
BATCH_SIZE = int(os.getenv('BATCH_SIZE', 6))
|
||||||
|
LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False'))
|
||||||
###################################################################################################
|
###################################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
305
create_overlay.py
Normal file
305
create_overlay.py
Normal file
@ -0,0 +1,305 @@
|
|||||||
|
from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QBuffer
|
||||||
|
from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont, QScreen, QIcon)
|
||||||
|
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
|
||||||
|
QLabel, QSystemTrayIcon, QMenu)
|
||||||
|
import sys, io, os, signal, time, platform
|
||||||
|
from PIL import Image
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional
|
||||||
|
from config import ADD_OVERLAY, FONT_FILE, FONT_SIZE
|
||||||
|
from logging_config import logger
|
||||||
|
|
||||||
|
|
||||||
|
def qpixmap_to_bytes(qpixmap):
|
||||||
|
qimage = qpixmap.toImage()
|
||||||
|
buffer = QBuffer()
|
||||||
|
buffer.open(QBuffer.ReadWrite)
|
||||||
|
qimage.save(buffer, "PNG")
|
||||||
|
return qimage
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextEntry:
|
||||||
|
text: str
|
||||||
|
x: int
|
||||||
|
y: int
|
||||||
|
font: QFont = QFont('Arial', FONT_SIZE)
|
||||||
|
visible: bool = True
|
||||||
|
text_color: Qt.GlobalColor = Qt.GlobalColor.red
|
||||||
|
background_color: Optional[Qt.GlobalColor] = None
|
||||||
|
padding: int = 1
|
||||||
|
|
||||||
|
class TranslationOverlay(QMainWindow):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.setWindowTitle("Translation Overlay")
|
||||||
|
self.is_passthrough = True
|
||||||
|
self.text_entries: List[TextEntry] = []
|
||||||
|
self.setup_window_attributes()
|
||||||
|
self.setup_shortcuts()
|
||||||
|
self.closeEvent = lambda event: QApplication.quit()
|
||||||
|
self.default_font = QFont('Arial', FONT_SIZE)
|
||||||
|
self.text_entries_copy: List[TextEntry] = []
|
||||||
|
self.next_text_entries: List[TextEntry] = []
|
||||||
|
#self.show_background = True
|
||||||
|
self.background_opacity = 0.5
|
||||||
|
# self.setup_tray()
|
||||||
|
|
||||||
|
def prepare_for_capture(self):
|
||||||
|
"""Preserve current state and clear overlay"""
|
||||||
|
if ADD_OVERLAY:
|
||||||
|
self.text_entries_copy = self.text_entries.copy()
|
||||||
|
self.clear_all_text()
|
||||||
|
self.update()
|
||||||
|
|
||||||
|
def restore_after_capture(self):
|
||||||
|
"""Restore overlay state after capture"""
|
||||||
|
if ADD_OVERLAY:
|
||||||
|
logger.debug(f'Text entries copy during initial phase of restore_after_capture: {self.text_entries_copy}')
|
||||||
|
self.text_entries = self.text_entries_copy.copy()
|
||||||
|
logger.debug(f"Restored text entries: {self.text_entries}")
|
||||||
|
self.update()
|
||||||
|
|
||||||
|
def add_next_text_at_position_no_update(self, x: int, y: int, text: str,
|
||||||
|
font: Optional[QFont] = None, text_color: Qt.GlobalColor = Qt.GlobalColor.red):
|
||||||
|
"""Add new text without triggering update"""
|
||||||
|
entry = TextEntry(
|
||||||
|
text=text,
|
||||||
|
x=x,
|
||||||
|
y=y,
|
||||||
|
font=font or self.default_font,
|
||||||
|
text_color=text_color
|
||||||
|
)
|
||||||
|
self.next_text_entries.append(entry)
|
||||||
|
|
||||||
|
def update_translation(self, ocr_output, translation):
|
||||||
|
# Update your overlay with new translations here
|
||||||
|
# You'll need to implement the logic to display the translations
|
||||||
|
self.clear_all_text()
|
||||||
|
self.text_entries = self.next_text_entries.copy()
|
||||||
|
self.next_text_entries.clear()
|
||||||
|
self.update()
|
||||||
|
|
||||||
|
def capture_behind(self, x=None, y=None, width=None, height=None):
|
||||||
|
"""
|
||||||
|
Capture the screen area behind the overlay.
|
||||||
|
If no coordinates provided, captures the area under the window.
|
||||||
|
"""
|
||||||
|
# Temporarily hide the window
|
||||||
|
self.hide()
|
||||||
|
|
||||||
|
# Get screen
|
||||||
|
screen = QScreen.grabWindow(
|
||||||
|
self.screen(),
|
||||||
|
0,
|
||||||
|
x if x is not None else self.x(),
|
||||||
|
y if y is not None else self.y(),
|
||||||
|
width if width is not None else self.width(),
|
||||||
|
height if height is not None else self.height()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Show the window again
|
||||||
|
self.show()
|
||||||
|
screen_bytes = qpixmap_to_bytes(screen)
|
||||||
|
return screen_bytes
|
||||||
|
|
||||||
|
def clear_all_text(self):
|
||||||
|
"""Clear all text entries"""
|
||||||
|
self.text_entries.clear()
|
||||||
|
self.update()
|
||||||
|
|
||||||
|
def setup_window_attributes(self):
|
||||||
|
# Set window flags for overlay behavior
|
||||||
|
self.setWindowFlags(
|
||||||
|
Qt.WindowType.FramelessWindowHint |
|
||||||
|
Qt.WindowType.WindowStaysOnTopHint |
|
||||||
|
Qt.WindowType.Tool
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set attributes for transparency
|
||||||
|
self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground)
|
||||||
|
|
||||||
|
# Make the window cover the entire screen
|
||||||
|
self.setGeometry(QApplication.primaryScreen().geometry())
|
||||||
|
|
||||||
|
# Special handling for Wayland
|
||||||
|
if platform.system() == "Linux":
|
||||||
|
if "WAYLAND_DISPLAY" in os.environ:
|
||||||
|
self.setAttribute(Qt.WidgetAttribute.WA_X11NetWmWindowTypeCombo)
|
||||||
|
self.setAttribute(Qt.WidgetAttribute.WA_DontCreateNativeAncestors)
|
||||||
|
|
||||||
|
def setup_shortcuts(self):
|
||||||
|
# Toggle visibility (Alt+Shift+T)
|
||||||
|
self.toggle_visibility_shortcut = QShortcut(QKeySequence("Alt+Shift+T"), self)
|
||||||
|
self.toggle_visibility_shortcut.activated.connect(self.toggle_visibility)
|
||||||
|
|
||||||
|
# Toggle passthrough mode (Alt+Shift+P)
|
||||||
|
self.toggle_passthrough_shortcut = QShortcut(QKeySequence("Alt+Shift+P"), self)
|
||||||
|
self.toggle_passthrough_shortcut.activated.connect(self.toggle_passthrough)
|
||||||
|
|
||||||
|
# Quick hide (Escape)
|
||||||
|
self.hide_shortcut = QShortcut(QKeySequence("Esc"), self)
|
||||||
|
self.hide_shortcut.activated.connect(self.hide)
|
||||||
|
|
||||||
|
# Clear all text (Alt+Shift+C)
|
||||||
|
self.clear_shortcut = QShortcut(QKeySequence("Alt+Shift+C"), self)
|
||||||
|
self.clear_shortcut.activated.connect(self.clear_all_text)
|
||||||
|
|
||||||
|
# Toggle background
|
||||||
|
self.toggle_background_shortcut = QShortcut(QKeySequence("Alt+Shift+B"), self)
|
||||||
|
self.toggle_background_shortcut.activated.connect(self.toggle_background)
|
||||||
|
|
||||||
|
def toggle_visibility(self):
|
||||||
|
if self.isVisible():
|
||||||
|
self.hide()
|
||||||
|
else:
|
||||||
|
self.show()
|
||||||
|
|
||||||
|
def toggle_passthrough(self):
|
||||||
|
self.is_passthrough = not self.is_passthrough
|
||||||
|
if self.is_passthrough:
|
||||||
|
self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents)
|
||||||
|
if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ:
|
||||||
|
self.setWindowFlags(self.windowFlags() | Qt.WindowType.X11BypassWindowManagerHint)
|
||||||
|
else:
|
||||||
|
self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents, False)
|
||||||
|
if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ:
|
||||||
|
self.setWindowFlags(self.windowFlags() & ~Qt.WindowType.X11BypassWindowManagerHint)
|
||||||
|
|
||||||
|
self.hide()
|
||||||
|
self.show()
|
||||||
|
|
||||||
|
def toggle_background(self):
|
||||||
|
"""Toggle background visibility"""
|
||||||
|
self.show_background = not self.show_background
|
||||||
|
self.update()
|
||||||
|
|
||||||
|
def set_background_opacity(self, opacity: float):
|
||||||
|
"""Set background opacity (0.0 to 1.0)"""
|
||||||
|
self.background_opacity = max(0.0, min(1.0, opacity))
|
||||||
|
self.update()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def add_text_at_position(self, x: int, y: int, text: str):
|
||||||
|
"""Add new text at specific coordinates"""
|
||||||
|
entry = TextEntry(text, x, y)
|
||||||
|
self.text_entries.append(entry)
|
||||||
|
self.update()
|
||||||
|
|
||||||
|
def update_text_at_position(self, x: int, y: int, text: str):
|
||||||
|
"""Update text at specific coordinates, or add if none exists"""
|
||||||
|
# Look for existing text entry near these coordinates (within 5 pixels)
|
||||||
|
for entry in self.text_entries:
|
||||||
|
if abs(entry.x - x) <= 1 and abs(entry.y - y) <= 1:
|
||||||
|
entry.text = text
|
||||||
|
self.update()
|
||||||
|
return
|
||||||
|
|
||||||
|
# If no existing entry found, add new one
|
||||||
|
self.add_text_at_position(x, y, text)
|
||||||
|
|
||||||
|
def setup_tray(self):
|
||||||
|
self.tray_icon = QSystemTrayIcon(self)
|
||||||
|
self.tray_icon.setIcon(QIcon.fromTheme("applications-system"))
|
||||||
|
|
||||||
|
tray_menu = QMenu()
|
||||||
|
|
||||||
|
toggle_action = tray_menu.addAction("Show/Hide Overlay")
|
||||||
|
toggle_action.triggered.connect(self.toggle_visibility)
|
||||||
|
|
||||||
|
toggle_passthrough = tray_menu.addAction("Toggle Passthrough")
|
||||||
|
toggle_passthrough.triggered.connect(self.toggle_passthrough)
|
||||||
|
|
||||||
|
# Add background toggle to tray menu
|
||||||
|
toggle_background = tray_menu.addAction("Toggle Background")
|
||||||
|
toggle_background.triggered.connect(self.toggle_background)
|
||||||
|
|
||||||
|
clear_action = tray_menu.addAction("Clear All Text")
|
||||||
|
clear_action.triggered.connect(self.clear_all_text)
|
||||||
|
|
||||||
|
tray_menu.addSeparator()
|
||||||
|
|
||||||
|
quit_action = tray_menu.addAction("Quit")
|
||||||
|
quit_action.triggered.connect(self.clean_exit)
|
||||||
|
|
||||||
|
self.tray_icon.setToolTip("Translation Overlay")
|
||||||
|
self.tray_icon.setContextMenu(tray_menu)
|
||||||
|
self.tray_icon.show()
|
||||||
|
self.tray_icon.activated.connect(self.tray_activated)
|
||||||
|
|
||||||
|
def remove_text_at_position(self, x: int, y: int):
|
||||||
|
"""Remove text entry near specified coordinates"""
|
||||||
|
self.text_entries = [
|
||||||
|
entry for entry in self.text_entries
|
||||||
|
if abs(entry.x - x) > 1 or abs(entry.y - y) > 1
|
||||||
|
]
|
||||||
|
self.update()
|
||||||
|
|
||||||
|
def paintEvent(self, event):
|
||||||
|
painter = QPainter(self)
|
||||||
|
painter.setRenderHint(QPainter.RenderHint.Antialiasing)
|
||||||
|
|
||||||
|
# Draw each text entry
|
||||||
|
for entry in self.text_entries:
|
||||||
|
if not entry.visible:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Set the font for this specific entry
|
||||||
|
painter.setFont(entry.font)
|
||||||
|
text_metrics = painter.fontMetrics()
|
||||||
|
|
||||||
|
# Get the bounding rectangles for text
|
||||||
|
text_bounds = text_metrics.boundingRect(
|
||||||
|
entry.text
|
||||||
|
)
|
||||||
|
total_width = text_bounds.width()
|
||||||
|
total_height = text_bounds.height()
|
||||||
|
|
||||||
|
# Create rectangles for text placement
|
||||||
|
text_rect = QRect(entry.x, entry.y, total_width, total_height)
|
||||||
|
# Calculate background rectangle that encompasses both texts
|
||||||
|
if entry.background_color is not None:
|
||||||
|
bg_rect = QRect(entry.x - entry.padding,
|
||||||
|
entry.y - entry.padding,
|
||||||
|
total_width + (2 * entry.padding),
|
||||||
|
total_height + (2 * entry.padding))
|
||||||
|
|
||||||
|
painter.setPen(Qt.PenStyle.NoPen)
|
||||||
|
painter.setBrush(entry.background_color)
|
||||||
|
painter.drawRect(bg_rect)
|
||||||
|
|
||||||
|
# Draw the texts
|
||||||
|
painter.setPen(entry.text_color)
|
||||||
|
painter.drawText(text_rect, Qt.AlignmentFlag.AlignLeft, entry.text)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def handle_exit(signum, frame):
|
||||||
|
QApplication.quit()
|
||||||
|
|
||||||
|
def start_overlay():
|
||||||
|
app = QApplication(sys.argv)
|
||||||
|
|
||||||
|
# Enable Wayland support if available
|
||||||
|
if platform.system() == "Linux" and "WAYLAND_DISPLAY" in os.environ:
|
||||||
|
app.setProperty("platform", "wayland")
|
||||||
|
|
||||||
|
overlay = TranslationOverlay()
|
||||||
|
|
||||||
|
overlay.show()
|
||||||
|
signal.signal(signal.SIGINT, handle_exit) # Handle Ctrl+C (KeyboardInterrupt)
|
||||||
|
signal.signal(signal.SIGTERM, handle_exit)
|
||||||
|
return (app, overlay)
|
||||||
|
# sys.exit(app.exec())
|
||||||
|
|
||||||
|
if ADD_OVERLAY:
|
||||||
|
app, overlay = start_overlay()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
ADD_OVERLAY = True
|
||||||
|
if not ADD_OVERLAY:
|
||||||
|
app, overlay = start_overlay()
|
||||||
|
overlay.add_text_at_position(600, 100, "Hello World I AM A BIG FAAT FOROGGGGGGGGGG")
|
||||||
|
capture = overlay.capture_behind()
|
||||||
|
capture.save("capture.png")
|
||||||
|
sys.exit(app.exec())
|
||||||
109
draw.py
109
draw.py
@ -4,41 +4,49 @@ import os,io, sys, numpy as np
|
|||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
||||||
from utils import romanize, intercepts, add_furigana
|
from utils import romanize, intercepts, add_furigana
|
||||||
from logging_config import logger
|
from logging_config import logger
|
||||||
from config import SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE, LINE_SPACING, TEXT_COLOR, LINE_HEIGHT, TO_ROMANIZE
|
from config import ADD_OVERLAY, SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE, LINE_SPACING, FONT_COLOUR, LINE_HEIGHT, TO_ROMANIZE, FILL_COLOUR, REGION
|
||||||
|
|
||||||
|
|
||||||
|
from PySide6.QtGui import QFont
|
||||||
font = ImageFont.truetype(FONT_FILE, FONT_SIZE)
|
font = ImageFont.truetype(FONT_FILE, FONT_SIZE)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def modify_image_bytes(image_bytes: io.BytesIO, ocr_output, translation: list) -> bytes:
|
def modify_image_bytes(image_bytes: io.BytesIO, ocr_output, translation: list) -> bytes:
|
||||||
# Load the image from bytes
|
"""Modify the image bytes with the translated text and return the modified image bytes"""
|
||||||
|
|
||||||
with io.BytesIO(image_bytes) as byte_stream:
|
with io.BytesIO(image_bytes) as byte_stream:
|
||||||
image = Image.open(byte_stream)
|
image = Image.open(byte_stream)
|
||||||
draw = ImageDraw.Draw(image)
|
draw = ImageDraw.Draw(image)
|
||||||
translate_image(draw, translation, ocr_output, MAX_TRANSLATE)
|
draw_on_image(draw, translation, ocr_output, MAX_TRANSLATE)
|
||||||
|
|
||||||
# Save the modified image back to bytes without changing the format
|
# Save the modified image back to bytes without changing the format
|
||||||
with io.BytesIO() as byte_stream:
|
with io.BytesIO() as byte_stream:
|
||||||
image.save(byte_stream, format=image.format) # Save in original format
|
image.save(byte_stream, format=image.format) # Save in original format
|
||||||
modified_image_bytes = byte_stream.getvalue()
|
modified_image_bytes = byte_stream.getvalue()
|
||||||
|
|
||||||
return modified_image_bytes
|
return modified_image_bytes
|
||||||
|
|
||||||
def translate_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int) -> ImageDraw:
|
def draw_on_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int, replace = False) -> ImageDraw:
|
||||||
|
"""Draw the original, translated and optionally the romanisation of the texts on the image"""
|
||||||
translated_number = 0
|
translated_number = 0
|
||||||
bounding_boxes = []
|
bounding_boxes = []
|
||||||
|
logger.debug(f"Translations: {len(translation)} {translation}")
|
||||||
|
logger.debug(f"OCR output: {len(ocr_output)} {ocr_output}")
|
||||||
for i, (position, untranslated_phrase, confidence) in enumerate(ocr_output):
|
for i, (position, untranslated_phrase, confidence) in enumerate(ocr_output):
|
||||||
if translated_number >= max_translate:
|
logger.debug(f"Untranslated phrase: {untranslated_phrase}")
|
||||||
|
if translated_number >= max_translate - 1:
|
||||||
break
|
break
|
||||||
translate_one_phrase(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
if replace:
|
||||||
|
draw = draw_one_phrase_replace(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||||
|
else:
|
||||||
|
draw_one_phrase_add(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||||
translated_number += 1
|
translated_number += 1
|
||||||
return draw
|
return draw
|
||||||
|
|
||||||
def translate_one_phrase(draw: ImageDraw, translated_phrase: str, position: tuple, bounding_boxes: list, untranslated_phrase: str) -> ImageDraw:
|
def draw_one_phrase_add(draw: ImageDraw,
|
||||||
# Draw the bounding box
|
translated_phrase: str,
|
||||||
top_left, _, _, _ = position
|
position: tuple, bounding_boxes: list,
|
||||||
position = (top_left[0], top_left[1] - 60)
|
untranslated_phrase: str) -> ImageDraw:
|
||||||
|
"""Draw the bounding box rectangle and text on the image above the original text"""
|
||||||
|
|
||||||
if SOURCE_LANG == 'ja':
|
if SOURCE_LANG == 'ja':
|
||||||
untranslated_phrase = add_furigana(untranslated_phrase)
|
untranslated_phrase = add_furigana(untranslated_phrase)
|
||||||
romanized_phrase = romanize(untranslated_phrase, 'ja')
|
romanized_phrase = romanize(untranslated_phrase, 'ja')
|
||||||
@ -50,26 +58,75 @@ def translate_one_phrase(draw: ImageDraw, translated_phrase: str, position: tupl
|
|||||||
text_content = f"{translated_phrase}\n{untranslated_phrase}"
|
text_content = f"{translated_phrase}\n{untranslated_phrase}"
|
||||||
|
|
||||||
lines = text_content.split('\n')
|
lines = text_content.split('\n')
|
||||||
x,y = position
|
|
||||||
max_width = 0
|
|
||||||
total_height = 0
|
|
||||||
total_height = len(lines) * (LINE_HEIGHT + LINE_SPACING)
|
|
||||||
for line in lines:
|
|
||||||
bbox = draw.textbbox(position, line, font=font)
|
|
||||||
line_width = bbox[2] - bbox[0]
|
|
||||||
max_width = max(max_width, line_width)
|
|
||||||
bounding_box = (x, y, x + max_width, y + total_height, untranslated_phrase)
|
|
||||||
|
|
||||||
adjust_if_intersects(x, y, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
|
# Draw the bounding box
|
||||||
|
top_left, _, _, _ = position
|
||||||
|
max_width = get_max_width(lines, FONT_FILE, FONT_SIZE)
|
||||||
|
total_height = get_max_height(lines, FONT_SIZE, LINE_SPACING)
|
||||||
|
right_edge = REGION[2]
|
||||||
|
|
||||||
|
# Ensure the text is within the screen. P.S. Text on the edge may still be squished together if there are too many to translate
|
||||||
|
x_onscreen = top_left[0] if top_left[0] + max_width <= right_edge else right_edge - max_width
|
||||||
|
y_onscreen = max(top_left[1] - total_height, 0)
|
||||||
|
bounding_box = (x_onscreen, y_onscreen, x_onscreen + max_width, y_onscreen + total_height, untranslated_phrase)
|
||||||
|
|
||||||
|
adjust_if_intersects(x_onscreen, y_onscreen, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
|
||||||
|
|
||||||
adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]
|
adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]
|
||||||
draw.rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], outline="black", width=1)
|
draw.rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], outline="black", width=1)
|
||||||
position = (adjusted_x,adjusted_y)
|
position = (adjusted_x,adjusted_y)
|
||||||
for line in lines:
|
for line in lines:
|
||||||
draw.text(position, line, fill= TEXT_COLOR, font=font)
|
draw.text(position, line, fill= FONT_COLOUR, font=font)
|
||||||
|
if ADD_OVERLAY:
|
||||||
|
overlay.add_next_text_at_position_no_update(position[0], position[1], line, text_color=FONT_COLOUR)
|
||||||
adjusted_y += FONT_SIZE + LINE_SPACING
|
adjusted_y += FONT_SIZE + LINE_SPACING
|
||||||
position = (adjusted_x,adjusted_y)
|
position = (adjusted_x,adjusted_y)
|
||||||
|
|
||||||
def adjust_if_intersects(x: int, y: int, bounding_box: tuple, bounding_boxes: list, untranslated_phrase: str, max_width: int, total_height: int) -> tuple:
|
|
||||||
|
|
||||||
|
|
||||||
|
### Only support for horizontal text atm, vertical text is on the todo list
|
||||||
|
def draw_one_phrase_replace(draw: ImageDraw,
|
||||||
|
translated_phrase: str,
|
||||||
|
position: tuple, bounding_boxes: list,
|
||||||
|
untranslated_phrase: str) -> ImageDraw:
|
||||||
|
"""Cover up old text and add translation directly on top"""
|
||||||
|
# Draw the bounding box
|
||||||
|
top_left, _, _, bottom_right = position
|
||||||
|
max_width = bottom_right[0] - top_left[0]
|
||||||
|
font_size = bottom_right[1] - top_left[1]
|
||||||
|
draw.rectangle([top_left, bottom_right], fill=FILL_COLOUR)
|
||||||
|
while True:
|
||||||
|
font = ImageFont.truetype(FONT_FILE, font_size)
|
||||||
|
if font.get_max_width < max_width:
|
||||||
|
draw.text(top_left, translated_phrase, fill= FONT_COLOUR, font=font)
|
||||||
|
break
|
||||||
|
elif font_size <= 1:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
font_size -= 1
|
||||||
|
|
||||||
|
def get_max_width(lines: list, font_path, font_size) -> int:
|
||||||
|
"""Get the maximum width of the text lines"""
|
||||||
|
font = ImageFont.truetype(font_path, font_size)
|
||||||
|
max_width = 0
|
||||||
|
dummy_image = Image.new("RGB", (1, 1))
|
||||||
|
draw = ImageDraw.Draw(dummy_image)
|
||||||
|
for line in lines:
|
||||||
|
bbox = draw.textbbox((0,0), line, font=font)
|
||||||
|
line_width = bbox[2] - bbox[0]
|
||||||
|
max_width = max(max_width, line_width)
|
||||||
|
return max_width
|
||||||
|
|
||||||
|
def get_max_height(lines: list, font_size, line_spacing) -> int:
|
||||||
|
"""Get the maximum height of the text lines"""
|
||||||
|
return len(lines) * (font_size + line_spacing)
|
||||||
|
|
||||||
|
def adjust_if_intersects(x: int, y: int,
|
||||||
|
bounding_box: tuple, bounding_boxes: list,
|
||||||
|
untranslated_phrase: str,
|
||||||
|
max_width: int, total_height: int) -> tuple:
|
||||||
|
"""Adjust the y coordinate if the bounding box intersects with any other bounding box"""
|
||||||
y = np.max([y,0])
|
y = np.max([y,0])
|
||||||
if len(bounding_boxes) > 0:
|
if len(bounding_boxes) > 0:
|
||||||
for box in bounding_boxes:
|
for box in bounding_boxes:
|
||||||
@ -79,5 +136,3 @@ def adjust_if_intersects(x: int, y: int, bounding_box: tuple, bounding_boxes: li
|
|||||||
bounding_boxes.append(adjusted_bounding_box)
|
bounding_boxes.append(adjusted_bounding_box)
|
||||||
return adjusted_bounding_box
|
return adjusted_bounding_box
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,37 +1,63 @@
|
|||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import os , sys, torch, time
|
import os , sys, torch, time, ast
|
||||||
from multiprocessing import Process, Event
|
from werkzeug.exceptions import TooManyRequests
|
||||||
|
from multiprocessing import Process, Event, Value
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||||
from config import device
|
from config import device, GEMINI_API_KEY, GROQ_API_KEY
|
||||||
from logging_config import logger
|
from logging_config import logger
|
||||||
|
from groq import Groq
|
||||||
|
import google.generativeai as genai
|
||||||
|
import asyncio
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
|
||||||
class Gemini():
|
|
||||||
def __init__(self, name, rate):
|
class ApiModel():
|
||||||
self.name = name
|
def __init__(self, model, rate, api_key):
    """Shared state for a rate-limited translation API model.

    Args:
        model: model name as the provider ("site") knows it.
        rate: allowed number of calls per minute.
        api_key: API key for the provider this model belongs to.
    """
    self.model = model
    self.rate = rate
    self.api_key = api_key
    # Shared C-int counters ('i') so the watchdog process can reset them.
    self.curr_calls = Value('i', 0)
    self.time = Value('i', 0)
    # Bookkeeping for the background rate-reset process.
    self.process = None
    self.stop_event = Event()
    # Filled in later by subclasses / set_lang().
    self.site = None
    self.from_lang = None
    self.target_lang = None
    # NOTE(review): this instance attribute shadows the subclass
    # request() method; callers pass the unbound class function
    # instead — confirm the shadowing is intentional.
    self.request = None  # request response from API
||||||
|
|
||||||
def __repr__(self):
    """Debug summary: site, model, rate budget and rate-window progress."""
    return f'{self.site} Model: {self.model}; Rate: {self.rate}; Current_Calls: {self.curr_calls.value} calls; Time Passed: {self.time.value} seconds.'
||||||
|
|
||||||
def __str__(self):
    """The bare model name."""
    return self.model
||||||
|
|
||||||
def background_task(self):
|
def set_lang(self, from_lang, target_lang):
    """Record the source/target language pair used when building prompts."""
    self.from_lang = from_lang
    self.target_lang = target_lang
||||||
|
|
||||||
|
### CHECK MINUTELY API RATES. For working with hourly rates and monthly will need to create another file. Also just unlikely those rates will be hit
|
||||||
|
async def api_rate_check(self):
    """Advance the shared clock in 5-second ticks; reset counters each minute.

    Runs until stop_event is set. Every tick adds 5 s to the shared
    timer; once a full minute has accumulated, both the timer and the
    per-minute call counter are zeroed so callers regain budget.
    """
    while not self.stop_event.is_set():
        tick_start = time.monotonic()
        self.time.value += 5
        if self.time.value >= 60:
            # Minute is up: open a fresh rate window.
            self.time.value = 0
            self.curr_calls.value = 0
        # Sleep out whatever remains of this 5-second tick.
        remaining = max(0, 5 - (time.monotonic() - tick_start))
        await asyncio.sleep(remaining)
||||||
|
|
||||||
|
def background_task(self):
    """Process target: run the rate-check coroutine in its own event loop."""
    asyncio.run(self.api_rate_check())
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
# Start the background task
|
# Start the background task
|
||||||
@ -49,6 +75,68 @@ class Gemini():
|
|||||||
if self.process.is_alive():
|
if self.process.is_alive():
|
||||||
self.process.terminate()
|
self.process.terminate()
|
||||||
|
|
||||||
|
def request_func(request):
    """Decorator that enforces the per-minute rate limit on an API call.

    If the model still has call budget this minute, forward the call and
    count it; otherwise log the remaining wait time and raise
    TooManyRequests so the caller can fall back to another model.
    """
    @wraps(request)
    def wrapper(self, text, *args, **kwargs):
        if self.curr_calls.value < self.rate:
            response = request(self, text, *args, **kwargs)
            self.curr_calls.value += 1
            return response
        # self.time is a multiprocessing.Value; the original computed
        # `60 - self.time`, which raises TypeError — read .value instead.
        logger.error(f"Rate limit reached for this model. Please wait for the rate to reset in {60 - self.time.value} seconds.")
        raise TooManyRequests('Rate limit reached for this model.')
    return wrapper
||||||
|
|
||||||
|
@request_func
def translate(self, request_fn, texts_to_translate):
    """Translate a list of strings via *request_fn*, rate-limited.

    Builds one prompt asking for the whole list to be translated from
    self.from_lang to self.target_lang as a Python list, then parses the
    model's reply back into a list with ast.literal_eval.
    """
    if not texts_to_translate:
        return []
    prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {self.from_lang} into {self.target_lang} and output as a Python list ensuring proper escaping of characters: {texts_to_translate}"
    reply = request_fn(self, prompt)
    return ast.literal_eval(reply.strip())
||||||
|
|
||||||
|
class Groqq(ApiModel):
    """ApiModel backed by the Groq chat-completions API."""

    def __init__(self, model, rate, api_key = GROQ_API_KEY):
        super().__init__(model, rate, api_key)
        self.site = "Groq"

    def request(self, content):
        """Send *content* as a single user message and return the reply text."""
        # Pass the configured key explicitly: a bare Groq() would silently
        # fall back to the GROQ_API_KEY environment variable and ignore
        # the key stored on this instance.
        client = Groq(api_key=self.api_key)
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": content,
                }
            ],
            model=self.model,
        )
        return chat_completion.choices[0].message.content

    def translate(self, texts_to_translate):
        """Translate a list of strings via the inherited rate-limited flow."""
        return super().translate(Groqq.request, texts_to_translate)
||||||
|
|
||||||
|
class Gemini(ApiModel):
    """ApiModel backed by Google's Gemini generative-language API."""

    def __init__(self, model, rate, api_key = GEMINI_API_KEY):
        super().__init__(model, rate, api_key)
        self.site = "Gemini"

    def request(self, content):
        """Send *content* to the model and return the stripped reply text."""
        genai.configure(api_key=self.api_key)
        # Turn every content filter off so translation output is not blocked.
        block_none = "BLOCK_NONE"
        safety_settings = {
            "HARM_CATEGORY_HARASSMENT": block_none,
            "HARM_CATEGORY_HATE_SPEECH": block_none,
            "HARM_CATEGORY_SEXUALLY_EXPLICIT": block_none,
            "HARM_CATEGORY_DANGEROUS_CONTENT": block_none,
        }
        response = genai.GenerativeModel(self.model).generate_content(
            content, safety_settings=safety_settings
        )
        return response.text.strip()

    def translate(self, texts_to_translate):
        """Translate a list of strings via the inherited rate-limited flow."""
        return super().translate(Gemini.request, texts_to_translate)
||||||
|
|
||||||
class TranslationDataset(Dataset):
|
class TranslationDataset(Dataset):
|
||||||
def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
|
def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
|
||||||
"""
|
"""
|
||||||
@ -174,9 +262,11 @@ def generate_text(
|
|||||||
return all_generated_texts
|
return all_generated_texts
|
||||||
|
|
||||||
if __name__ == '__main__':
    # Manual smoke test: build one model per provider and translate a phrase.
    GROQ_API_KEY = os.getenv('GROQ_API_KEY')
    GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')

    groq = Groqq('gemma-7b-it', 15, GROQ_API_KEY)
    groq.set_lang('zh', 'en')

    gemini = Gemini('gemini-1.5-pro', 15, GEMINI_API_KEY)
    gemini.set_lang('zh', 'en')

    print(gemini.translate(['荷兰咯']))
    print(groq.translate(['荷兰咯']))
||||||
@ -1,15 +1,17 @@
|
|||||||
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig, AutoModelForCausalLM
|
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig, AutoModelForCausalLM
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
import torch, os, sys, ast
|
import torch, os, sys, ast, json
|
||||||
from utils import standardize_lang
|
from utils import standardize_lang
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from batching import generate_text, Gemini
|
import random
|
||||||
|
import batching
|
||||||
|
from batching import generate_text, Gemini, Groq
|
||||||
from logging_config import logger
|
from logging_config import logger
|
||||||
from multiprocessing import Process,Event
|
from multiprocessing import Process,Event
|
||||||
# root dir
|
# root dir
|
||||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||||
|
|
||||||
from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models
|
from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_API_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models
|
||||||
|
|
||||||
##############################
|
##############################
|
||||||
# translation decorator
|
# translation decorator
|
||||||
@ -27,36 +29,30 @@ def translate(translation_func):
|
|||||||
|
|
||||||
|
|
||||||
###############################
|
###############################
|
||||||
def init_GEMINI(models_and_rates = None):
|
def init_API_LLM(from_lang, target_lang):
    """Instantiate and start every API model listed in api_models.json.

    The JSON maps a batching class name (e.g. "Gemini", "Groqq") to a
    {model_name: calls_per_minute} dict. Each model gets its rate-reset
    watchdog started and the translation direction applied.

    Args:
        from_lang: source language (normalized via standardize_lang).
        target_lang: target language (normalized via standardize_lang).

    Returns:
        List of started, language-configured ApiModel instances.
    """
    from_lang = standardize_lang(from_lang)['translation_model_lang']
    target_lang = standardize_lang(target_lang)['translation_model_lang']
    # Resolve the config relative to the project root (parent of this
    # module's directory — the same root added to sys.path above) so the
    # lookup works regardless of the current working directory.
    config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'api_models.json')
    with open(config_path, 'r') as f:
        models_and_rates = json.load(f)

    models = []
    for class_type, class_models in models_and_rates.items():
        cls = getattr(batching, class_type)
        models.extend(cls(model, rate) for model, rate in class_models.items())

    for model in models:
        model.start()
        model.set_lang(from_lang, target_lang)
    return models
|
||||||
|
|
||||||
def translate_GEMINI(text, models, from_lang, target_lang):
|
def translate_API_LLM(text, models):
    """Try each API model in random order until one translates *text*.

    Args:
        text: list of strings to translate.
        models: ApiModel instances (as built by init_API_LLM).

    Returns:
        The translated list from the first model that succeeds, or None
        if every model fails (an error is logged in that case).
    """
    # In-place shuffle spreads load across the models between calls.
    random.shuffle(models)
    for model in models:
        try:
            return model.translate(text)
        except Exception as e:
            # Narrowed from a bare `except:` so KeyboardInterrupt and
            # SystemExit still propagate; record why this model was skipped.
            logger.warning(f"Model {model} failed, trying next. Error: {e}")
            continue
    logger.error("All models have failed to translate the text.")
|
||||||
logger.info(f'Model Response: {response.text.strip()}')
|
|
||||||
return ast.literal_eval(response.text.strip())
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error with model {model.name}. Error: {e}")
|
|
||||||
logger.error("No models available to translate. Please wait for a model to be available.")
|
|
||||||
|
|
||||||
|
|
||||||
###############################
|
###############################
|
||||||
# Best model by far. Aya-23-8B. Gemma is relatively good. If I get the time to quantize either gemma or aya those will be good to use. llama3.2 is really good as well.
|
# Best model by far. Aya-23-8B. Gemma is relatively good. If I get the time to quantize either gemma or aya those will be good to use. llama3.2 is really good as well.
|
||||||
@ -127,12 +123,6 @@ def init_Seq_LLM(model_type, **kwargs): # model = 'opus' or 'm2m'
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.")
|
raise ValueError(f"Invalid model. Please use {' or '.join(curr_models)}.")
|
||||||
|
|
||||||
def init_API_LLM(model_type, **kwargs): # model = 'gemma'
|
|
||||||
if model_type == 'gemini':
|
|
||||||
return init_GEMINI(**kwargs)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid model type. Please use {' or '.join(api_llm_models)}.")
|
|
||||||
|
|
||||||
def init_Causal_LLM(model_type, **kwargs):
|
def init_Causal_LLM(model_type, **kwargs):
|
||||||
pass
|
pass
|
||||||
###
|
###
|
||||||
@ -155,23 +145,8 @@ def translate_Seq_LLM(text,
|
|||||||
|
|
||||||
|
|
||||||
### if you want to use any other translation, just define a translate function with input text and output text.
|
### if you want to use any other translation, just define a translate function with input text and output text.
|
||||||
# def translate_api(text):
|
|
||||||
#@translate
|
|
||||||
#def translate_Causal_LLM(text, model_type, model)
|
|
||||||
|
|
||||||
@translate
|
#def translate_Causal_LLM(text, model_type, model)
|
||||||
def translate_API_LLM(text: list[str],
|
|
||||||
model_type: str, # 'gemma'
|
|
||||||
models: list, # list of objects of classes defined in batching.py
|
|
||||||
from_lang: str, # suggested to use ISO 639-1 codes
|
|
||||||
target_lang: str # suggested to use ISO 639-1 codes
|
|
||||||
) -> list[str]:
|
|
||||||
if model_type == 'gemini':
|
|
||||||
from_lang = standardize_lang(from_lang)['translation_model_lang']
|
|
||||||
target_lang = standardize_lang(target_lang)['translation_model_lang']
|
|
||||||
return translate_GEMINI(text, models, from_lang, target_lang)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid model. Please use {' or '.join(api_llm_models)}.")
|
|
||||||
|
|
||||||
@translate
|
@translate
|
||||||
def translate_Causal_LLM(text: list[str],
|
def translate_Causal_LLM(text: list[str],
|
||||||
@ -211,7 +186,5 @@ def translate_func(model):
|
|||||||
|
|
||||||
### todo: if cuda is not detected, default to online translation as cpu just won't cut it bro. Parallel process it over multiple websites to make it faster
|
### todo: if cuda is not detected, default to online translation as cpu just won't cut it bro. Parallel process it over multiple websites to make it faster
|
||||||
if __name__ == "__main__":
    # Quick manual check: start the API models and translate one phrase.
    models = init_API_LLM('ja', 'en')
    print(translate_API_LLM(['こんにちは'], models))
||||||
@ -64,4 +64,4 @@ def setup_logger(
|
|||||||
print(f"Failed to setup logger: {e}")
|
print(f"Failed to setup logger: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
logger = setup_logger('on_screen_translator', log_file='translate.log')
|
logger = setup_logger('on_screen_translator', log_file='translate.log', level=logging.DEBUG)
|
||||||
115
qtapp.py
Normal file
115
qtapp.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
###################################################################################
|
||||||
|
##### IMPORT LIBRARIES #####
|
||||||
|
import os, time, sys
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
||||||
|
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
|
||||||
|
from utils import printsc, convert_image_to_bytes, bytes_to_image
|
||||||
|
from ocr import get_words, init_OCR, id_keep_source_lang
|
||||||
|
from logging_config import logger
|
||||||
|
from draw import modify_image_bytes
|
||||||
|
from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL, FONT_SIZE, FONT_FILE, FONT_COLOUR
|
||||||
|
from create_overlay import app, overlay
|
||||||
|
from typing import Optional, List
|
||||||
|
###################################################################################
|
||||||
|
from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QThread, Signal
|
||||||
|
from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont,
|
||||||
|
QColor, QIcon, QImage, QPixmap)
|
||||||
|
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
|
||||||
|
QLabel, QSystemTrayIcon, QMenu)
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationThread(QThread):
    """Worker thread: capture screen → OCR → translate → notify the overlay."""

    translation_ready = Signal(list, list)  # Signal to send translation results
    start_capture = Signal()
    end_capture = Signal()
    screen_capture = Signal(int, int, int, int)

    def __init__(self, ocr, models, source_lang, target_lang, interval):
        super().__init__()
        self.ocr = ocr
        self.models = models
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.interval = interval   # seconds between capture cycles
        self.running = True        # cleared by stop()
        self.prev_words = set()    # OCR words seen on the previous cycle
        self.runs = 0              # completed capture cycles so far

    def run(self):
        while self.running:
            # Hide the overlay while grabbing the screen so we don't OCR
            # our own rendered translations.
            self.start_capture.emit()
            screenshot = printsc(REGION)
            self.end_capture.emit()

            image_bytes = convert_image_to_bytes(screenshot)
            detections = id_keep_source_lang(self.ocr, image_bytes, self.source_lang)

            if self.runs == 0:
                logger.info('Initial run')
            else:
                logger.info(f'Run number: {self.runs}.')
            self.runs += 1

            words_now = set(get_words(detections))

            # Only re-translate when the on-screen text actually changed.
            if self.prev_words != words_now:
                logger.info('Translating')
                to_translate = [entry[1] for entry in detections][:MAX_TRANSLATE]
                translation = translate_API_LLM(to_translate, self.models)
                logger.info(f'Translation from {to_translate} to\n {translation}')

                # Emit the translation results
                modify_image_bytes(image_bytes, detections, translation)
                self.translation_ready.emit(detections, translation)

                self.prev_words = words_now
            else:
                logger.info("No new words to translate. Output will not refresh.")

            logger.info(f'Sleeping for {self.interval} seconds')
            time.sleep(self.interval)

    def stop(self):
        # Flag is checked at the top of each loop; the thread exits after
        # the current cycle finishes.
        self.running = False
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Wire OCR, translation models, the worker thread and the Qt overlay."""
    # OCR: recognise source, target and English text on screen.
    OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
    ocr = init_OCR(model=OCR_MODEL, easy_languages=OCR_LANGUAGES, use_GPU=OCR_USE_GPU)

    # Translation backends (rate-limited API models).
    models = init_API_LLM(SOURCE_LANG, TARGET_LANG)

    # Background worker that captures, OCRs and translates on a timer.
    translation_thread = TranslationThread(
        ocr=ocr,
        models=models,
        source_lang=SOURCE_LANG,
        target_lang=TARGET_LANG,
        interval=INTERVAL,
    )

    # Route worker signals to the overlay widget.
    translation_thread.start_capture.connect(overlay.prepare_for_capture)
    translation_thread.end_capture.connect(overlay.restore_after_capture)
    translation_thread.translation_ready.connect(overlay.update_translation)
    translation_thread.screen_capture.connect(overlay.capture_behind)

    translation_thread.start()

    # Blocks until the Qt application quits.
    result = app.exec()

    # Shut the worker down cleanly before exiting.
    translation_thread.stop()
    translation_thread.wait()

    return result


if __name__ == "__main__":
    sys.exit(main())
||||||
6
web.py
6
web.py
@ -1,12 +1,12 @@
|
|||||||
from flask import Flask, Response, render_template
|
from flask import Flask, Response, render_template
|
||||||
import threading
|
import threading
|
||||||
import io
|
import io
|
||||||
import translate
|
import app
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
# Global variable to hold the current image
|
# Global variable to hold the current image
|
||||||
def curr_image():
    """Return the latest rendered frame from the capture/translate loop.

    The module-level name `app` is rebound to the Flask instance right
    after `import app`, so `app.latest_image` would hit the Flask object
    and raise AttributeError. Re-import the module locally instead
    (Python returns the cached module object from sys.modules).
    """
    import app as capture_app
    return capture_app.latest_image
||||||
|
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
def index():
|
def index():
|
||||||
@ -29,7 +29,7 @@ def stream_image():
|
|||||||
|
|
||||||
if __name__ == '__main__':
    # `app` at module level is the Flask instance (it shadows the imported
    # module of the same name), so bind the capture module separately —
    # `app.main` on the Flask object would raise AttributeError.
    import app as capture_app

    # Start the image updating thread
    threading.Thread(target=capture_app.main, daemon=True).start()

    # Start the Flask web server
    app.run(host='0.0.0.0', port=5000, debug=True)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user