Batched asynchronous API requests. Added additional draw options.
This commit is contained in:
parent
499a2c3972
commit
11600ae70f
@ -1,13 +1,20 @@
|
||||
{
|
||||
"Gemini": {
|
||||
"gemini-1.5-pro": 2,
|
||||
"gemini-1.5-flash": 15,
|
||||
"gemini-1.5-flash-8b": 8,
|
||||
"gemini-1.0-pro": 15
|
||||
"gemini-1.5-pro": { "rpmin": 2, "rpd": 50 },
|
||||
"gemini-1.5-flash": { "rpmin": 15, "rpd": 1500 },
|
||||
"gemini-1.5-flash-8b": { "rpmin": 15, "rpd": 1500 },
|
||||
"gemini-1.0-pro": { "rpmin": 15, "rpd": 1500 }
|
||||
},
|
||||
"Groqq": {
|
||||
"llama-3.2-90b-text-preview": 30,
|
||||
"llama3-70b-8192": 30,
|
||||
"mixtral-8x7b-32768": 30
|
||||
"Groq": {
|
||||
"llama-3.2-90b-text-preview": { "rpmin": 30, "rpd": 7000 },
|
||||
"llama3-70b-8192": { "rpmin": 30, "rpd": 14400 },
|
||||
"mixtral-8x7b-32768": { "rpmin": 30, "rpd": 14400 },
|
||||
"llama-3.1-70b-versatile": { "rpmin": 30, "rpd": 14400 },
|
||||
"gemma2-9b-it": { "rpmin": 30, "rpd": 14400 },
|
||||
"llama3-groq-8b-8192-tool-use-preview": { "rpmin": 30, "rpd": 14400 },
|
||||
"llama3-groq-70b-8192-tool-use-preview": { "rpmin": 30, "rpd": 14400 },
|
||||
"llama-3.2-90b-vision-preview": { "rpmin": 15, "rpd": 3500 },
|
||||
"llama-3.2-11b-text-preview": { "rpmin": 30, "rpd": 7000 },
|
||||
"llama-3.2-11b-vision-preview": { "rpmin": 30, "rpd": 7000 }
|
||||
}
|
||||
}
|
||||
|
||||
83
app.py
83
app.py
@ -1,83 +0,0 @@
|
||||
###################################################################################
|
||||
##### IMPORT LIBRARIES #####
|
||||
import os, time, sys
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
||||
|
||||
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
|
||||
from utils import printsc, convert_image_to_bytes, bytes_to_image
|
||||
from ocr import get_words, init_OCR, id_keep_source_lang
|
||||
from logging_config import logger
|
||||
from draw import modify_image_bytes
|
||||
from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL
|
||||
###################################################################################
|
||||
|
||||
ADD_OVERLAY = False
|
||||
|
||||
latest_image = None
|
||||
|
||||
def main():
|
||||
global latest_image
|
||||
|
||||
##### Initialize the OCR #####
|
||||
OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
|
||||
ocr = init_OCR(model=OCR_MODEL, easy_languages = OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
|
||||
|
||||
##### Initialize the translation #####
|
||||
# model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
|
||||
models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
|
||||
###################################################################################
|
||||
runs = 0
|
||||
app.exec()
|
||||
while True:
|
||||
if ADD_OVERLAY:
|
||||
overlay.clear_all_text()
|
||||
|
||||
untranslated_image = printsc(REGION)
|
||||
|
||||
if ADD_OVERLAY:
|
||||
overlay.text_entries = overlay.text_entries_copy
|
||||
overlay.update()
|
||||
overlay.text_entries.clear()
|
||||
|
||||
byte_image = convert_image_to_bytes(untranslated_image)
|
||||
ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG) # keep only phrases containing the source language
|
||||
|
||||
if runs == 0:
|
||||
logger.info('Initial run')
|
||||
prev_words = set()
|
||||
else:
|
||||
logger.info(f'Run number: {runs}.')
|
||||
runs += 1
|
||||
|
||||
curr_words = set(get_words(ocr_output))
|
||||
|
||||
### If the OCR detects different words, translate screen -> to ensure that the screen is not refreshing constantly and to save GPU power
|
||||
if prev_words != curr_words:
|
||||
logger.info('Translating')
|
||||
|
||||
to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
|
||||
# translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
|
||||
translation = translate_API_LLM(to_translate, models)
|
||||
logger.info(f'Translation from {to_translate} to\n {translation}')
|
||||
translated_image = modify_image_bytes(byte_image, ocr_output, translation)
|
||||
latest_image = bytes_to_image(translated_image)
|
||||
# latest_image.show() # for debugging
|
||||
|
||||
prev_words = curr_words
|
||||
else:
|
||||
logger.info("No new words to translate. Output will not refresh.")
|
||||
|
||||
logger.info(f'Sleeping for {INTERVAL} seconds')
|
||||
time.sleep(INTERVAL)
|
||||
# if ADD_OVERLAY:
|
||||
# sys.exit(app.exec())
|
||||
|
||||
################### TODO ##################
|
||||
# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
|
||||
# 5. Maybe refreshing issue of flask app. Also get webpage to update only if the image changes.
|
||||
# Create a way for it to just replace the text and provide only the translation on-screen. Qt6
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
22
config.py
22
config.py
@ -10,22 +10,31 @@ load_dotenv(override=True)
|
||||
INTERVAL = int(os.getenv('INTERVAL'))
|
||||
|
||||
### OCR
|
||||
IMAGE_CHANGE_THRESHOLD = float(os.getenv('IMAGE_CHANGE_THRESHOLD', 0.75)) # higher values mean more sensitivity to changes in the screen, too high and the screen will constantly refresh
|
||||
OCR_MODEL = os.getenv('OCR_MODEL', 'easy') # 'easy', 'paddle', 'rapid' ### easy is the most accurate, paddle is the fastest with CUDA and rapid is the fastest with CPU. Rapid has only between Chinese and English unless you add more languages
|
||||
OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True'))
|
||||
|
||||
|
||||
### Drawing/Overlay Config
|
||||
ADD_OVERLAY = ast.literal_eval(os.getenv('ADD_OVERLAY', 'True'))
|
||||
FILL_COLOUR = os.getenv('FILL_COLOUR', 'white')
|
||||
FONT_FILE = os.getenv('FONT_FILE')
|
||||
FONT_SIZE = int(os.getenv('FONT_SIZE', 16))
|
||||
FONT_SIZE_MAX = int(os.getenv('FONT_SIZE_MAX', 20))
|
||||
FONT_SIZE_MIN = int(os.getenv('FONT_SIZE_MIN', 8))
|
||||
LINE_SPACING = int(os.getenv('LINE_SPACING', 3))
|
||||
REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)'))
|
||||
DRAW_TRANSLATIONS_MODE = os.getenv('DRAW_TRANSLATIONS_MODE', 'add')
|
||||
"""
|
||||
`learn': adds translated text, original text (should be added so when texts get moved around the translation of which it references is understood) and (optionally with the other TO_ROMANIZE option) romanized text above the original text. Texts can overlap if squished into a corner. Works well for games where texts are sparser
|
||||
'learn_cover': same as above but covers the original text with the translated text. Can help with readability and is less cluttered but with sufficiently dense text the texts can still overlap
|
||||
'translation_only_cover': cover the original text with the translated text - will not show the original text at all but not affected by overlapping texts
|
||||
"""
|
||||
|
||||
FONT_COLOUR = os.getenv('FONT_COLOUR', "#ff0000")
|
||||
TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True'))
|
||||
|
||||
|
||||
# API KEYS https://github.com/cheahjs/free-llm-api-resources?tab=readme-ov-file
|
||||
GEMINI_API_KEY = os.getenv('GEMINI_KEY')
|
||||
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
|
||||
GROQ_API_KEY = os.getenv('GROQ_API_KEY') #
|
||||
# MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY') # https://console.mistral.ai/api-keys/ slow asf
|
||||
|
||||
@ -44,12 +53,13 @@ BATCH_SIZE = int(os.getenv('BATCH_SIZE', 6))
|
||||
LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False'))
|
||||
###################################################################################################
|
||||
|
||||
## Filepaths
|
||||
API_MODELS_FILEPATH = os.path.join(os.path.dirname(__file__), 'api_models.json')
|
||||
|
||||
|
||||
FONT_SIZE = int((FONT_SIZE_MAX + FONT_SIZE_MIN)/2)
|
||||
LINE_HEIGHT = FONT_SIZE
|
||||
|
||||
|
||||
|
||||
|
||||
if TRANSLATION_USE_GPU is False:
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
|
||||
@ -1,305 +0,0 @@
|
||||
from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QBuffer
|
||||
from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont, QScreen, QIcon)
|
||||
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
|
||||
QLabel, QSystemTrayIcon, QMenu)
|
||||
import sys, io, os, signal, time, platform
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from config import ADD_OVERLAY, FONT_FILE, FONT_SIZE
|
||||
from logging_config import logger
|
||||
|
||||
|
||||
def qpixmap_to_bytes(qpixmap):
|
||||
qimage = qpixmap.toImage()
|
||||
buffer = QBuffer()
|
||||
buffer.open(QBuffer.ReadWrite)
|
||||
qimage.save(buffer, "PNG")
|
||||
return qimage
|
||||
|
||||
@dataclass
|
||||
class TextEntry:
|
||||
text: str
|
||||
x: int
|
||||
y: int
|
||||
font: QFont = QFont('Arial', FONT_SIZE)
|
||||
visible: bool = True
|
||||
text_color: Qt.GlobalColor = Qt.GlobalColor.red
|
||||
background_color: Optional[Qt.GlobalColor] = None
|
||||
padding: int = 1
|
||||
|
||||
class TranslationOverlay(QMainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setWindowTitle("Translation Overlay")
|
||||
self.is_passthrough = True
|
||||
self.text_entries: List[TextEntry] = []
|
||||
self.setup_window_attributes()
|
||||
self.setup_shortcuts()
|
||||
self.closeEvent = lambda event: QApplication.quit()
|
||||
self.default_font = QFont('Arial', FONT_SIZE)
|
||||
self.text_entries_copy: List[TextEntry] = []
|
||||
self.next_text_entries: List[TextEntry] = []
|
||||
#self.show_background = True
|
||||
self.background_opacity = 0.5
|
||||
# self.setup_tray()
|
||||
|
||||
def prepare_for_capture(self):
|
||||
"""Preserve current state and clear overlay"""
|
||||
if ADD_OVERLAY:
|
||||
self.text_entries_copy = self.text_entries.copy()
|
||||
self.clear_all_text()
|
||||
self.update()
|
||||
|
||||
def restore_after_capture(self):
|
||||
"""Restore overlay state after capture"""
|
||||
if ADD_OVERLAY:
|
||||
logger.debug(f'Text entries copy during initial phase of restore_after_capture: {self.text_entries_copy}')
|
||||
self.text_entries = self.text_entries_copy.copy()
|
||||
logger.debug(f"Restored text entries: {self.text_entries}")
|
||||
self.update()
|
||||
|
||||
def add_next_text_at_position_no_update(self, x: int, y: int, text: str,
|
||||
font: Optional[QFont] = None, text_color: Qt.GlobalColor = Qt.GlobalColor.red):
|
||||
"""Add new text without triggering update"""
|
||||
entry = TextEntry(
|
||||
text=text,
|
||||
x=x,
|
||||
y=y,
|
||||
font=font or self.default_font,
|
||||
text_color=text_color
|
||||
)
|
||||
self.next_text_entries.append(entry)
|
||||
|
||||
def update_translation(self, ocr_output, translation):
|
||||
# Update your overlay with new translations here
|
||||
# You'll need to implement the logic to display the translations
|
||||
self.clear_all_text()
|
||||
self.text_entries = self.next_text_entries.copy()
|
||||
self.next_text_entries.clear()
|
||||
self.update()
|
||||
|
||||
def capture_behind(self, x=None, y=None, width=None, height=None):
|
||||
"""
|
||||
Capture the screen area behind the overlay.
|
||||
If no coordinates provided, captures the area under the window.
|
||||
"""
|
||||
# Temporarily hide the window
|
||||
self.hide()
|
||||
|
||||
# Get screen
|
||||
screen = QScreen.grabWindow(
|
||||
self.screen(),
|
||||
0,
|
||||
x if x is not None else self.x(),
|
||||
y if y is not None else self.y(),
|
||||
width if width is not None else self.width(),
|
||||
height if height is not None else self.height()
|
||||
)
|
||||
|
||||
# Show the window again
|
||||
self.show()
|
||||
screen_bytes = qpixmap_to_bytes(screen)
|
||||
return screen_bytes
|
||||
|
||||
def clear_all_text(self):
|
||||
"""Clear all text entries"""
|
||||
self.text_entries.clear()
|
||||
self.update()
|
||||
|
||||
def setup_window_attributes(self):
|
||||
# Set window flags for overlay behavior
|
||||
self.setWindowFlags(
|
||||
Qt.WindowType.FramelessWindowHint |
|
||||
Qt.WindowType.WindowStaysOnTopHint |
|
||||
Qt.WindowType.Tool
|
||||
)
|
||||
|
||||
# Set attributes for transparency
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground)
|
||||
|
||||
# Make the window cover the entire screen
|
||||
self.setGeometry(QApplication.primaryScreen().geometry())
|
||||
|
||||
# Special handling for Wayland
|
||||
if platform.system() == "Linux":
|
||||
if "WAYLAND_DISPLAY" in os.environ:
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_X11NetWmWindowTypeCombo)
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_DontCreateNativeAncestors)
|
||||
|
||||
def setup_shortcuts(self):
|
||||
# Toggle visibility (Alt+Shift+T)
|
||||
self.toggle_visibility_shortcut = QShortcut(QKeySequence("Alt+Shift+T"), self)
|
||||
self.toggle_visibility_shortcut.activated.connect(self.toggle_visibility)
|
||||
|
||||
# Toggle passthrough mode (Alt+Shift+P)
|
||||
self.toggle_passthrough_shortcut = QShortcut(QKeySequence("Alt+Shift+P"), self)
|
||||
self.toggle_passthrough_shortcut.activated.connect(self.toggle_passthrough)
|
||||
|
||||
# Quick hide (Escape)
|
||||
self.hide_shortcut = QShortcut(QKeySequence("Esc"), self)
|
||||
self.hide_shortcut.activated.connect(self.hide)
|
||||
|
||||
# Clear all text (Alt+Shift+C)
|
||||
self.clear_shortcut = QShortcut(QKeySequence("Alt+Shift+C"), self)
|
||||
self.clear_shortcut.activated.connect(self.clear_all_text)
|
||||
|
||||
# Toggle background
|
||||
self.toggle_background_shortcut = QShortcut(QKeySequence("Alt+Shift+B"), self)
|
||||
self.toggle_background_shortcut.activated.connect(self.toggle_background)
|
||||
|
||||
def toggle_visibility(self):
|
||||
if self.isVisible():
|
||||
self.hide()
|
||||
else:
|
||||
self.show()
|
||||
|
||||
def toggle_passthrough(self):
|
||||
self.is_passthrough = not self.is_passthrough
|
||||
if self.is_passthrough:
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents)
|
||||
if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ:
|
||||
self.setWindowFlags(self.windowFlags() | Qt.WindowType.X11BypassWindowManagerHint)
|
||||
else:
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents, False)
|
||||
if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ:
|
||||
self.setWindowFlags(self.windowFlags() & ~Qt.WindowType.X11BypassWindowManagerHint)
|
||||
|
||||
self.hide()
|
||||
self.show()
|
||||
|
||||
def toggle_background(self):
|
||||
"""Toggle background visibility"""
|
||||
self.show_background = not self.show_background
|
||||
self.update()
|
||||
|
||||
def set_background_opacity(self, opacity: float):
|
||||
"""Set background opacity (0.0 to 1.0)"""
|
||||
self.background_opacity = max(0.0, min(1.0, opacity))
|
||||
self.update()
|
||||
|
||||
|
||||
|
||||
def add_text_at_position(self, x: int, y: int, text: str):
|
||||
"""Add new text at specific coordinates"""
|
||||
entry = TextEntry(text, x, y)
|
||||
self.text_entries.append(entry)
|
||||
self.update()
|
||||
|
||||
def update_text_at_position(self, x: int, y: int, text: str):
|
||||
"""Update text at specific coordinates, or add if none exists"""
|
||||
# Look for existing text entry near these coordinates (within 5 pixels)
|
||||
for entry in self.text_entries:
|
||||
if abs(entry.x - x) <= 1 and abs(entry.y - y) <= 1:
|
||||
entry.text = text
|
||||
self.update()
|
||||
return
|
||||
|
||||
# If no existing entry found, add new one
|
||||
self.add_text_at_position(x, y, text)
|
||||
|
||||
def setup_tray(self):
|
||||
self.tray_icon = QSystemTrayIcon(self)
|
||||
self.tray_icon.setIcon(QIcon.fromTheme("applications-system"))
|
||||
|
||||
tray_menu = QMenu()
|
||||
|
||||
toggle_action = tray_menu.addAction("Show/Hide Overlay")
|
||||
toggle_action.triggered.connect(self.toggle_visibility)
|
||||
|
||||
toggle_passthrough = tray_menu.addAction("Toggle Passthrough")
|
||||
toggle_passthrough.triggered.connect(self.toggle_passthrough)
|
||||
|
||||
# Add background toggle to tray menu
|
||||
toggle_background = tray_menu.addAction("Toggle Background")
|
||||
toggle_background.triggered.connect(self.toggle_background)
|
||||
|
||||
clear_action = tray_menu.addAction("Clear All Text")
|
||||
clear_action.triggered.connect(self.clear_all_text)
|
||||
|
||||
tray_menu.addSeparator()
|
||||
|
||||
quit_action = tray_menu.addAction("Quit")
|
||||
quit_action.triggered.connect(self.clean_exit)
|
||||
|
||||
self.tray_icon.setToolTip("Translation Overlay")
|
||||
self.tray_icon.setContextMenu(tray_menu)
|
||||
self.tray_icon.show()
|
||||
self.tray_icon.activated.connect(self.tray_activated)
|
||||
|
||||
def remove_text_at_position(self, x: int, y: int):
|
||||
"""Remove text entry near specified coordinates"""
|
||||
self.text_entries = [
|
||||
entry for entry in self.text_entries
|
||||
if abs(entry.x - x) > 1 or abs(entry.y - y) > 1
|
||||
]
|
||||
self.update()
|
||||
|
||||
def paintEvent(self, event):
|
||||
painter = QPainter(self)
|
||||
painter.setRenderHint(QPainter.RenderHint.Antialiasing)
|
||||
|
||||
# Draw each text entry
|
||||
for entry in self.text_entries:
|
||||
if not entry.visible:
|
||||
continue
|
||||
|
||||
# Set the font for this specific entry
|
||||
painter.setFont(entry.font)
|
||||
text_metrics = painter.fontMetrics()
|
||||
|
||||
# Get the bounding rectangles for text
|
||||
text_bounds = text_metrics.boundingRect(
|
||||
entry.text
|
||||
)
|
||||
total_width = text_bounds.width()
|
||||
total_height = text_bounds.height()
|
||||
|
||||
# Create rectangles for text placement
|
||||
text_rect = QRect(entry.x, entry.y, total_width, total_height)
|
||||
# Calculate background rectangle that encompasses both texts
|
||||
if entry.background_color is not None:
|
||||
bg_rect = QRect(entry.x - entry.padding,
|
||||
entry.y - entry.padding,
|
||||
total_width + (2 * entry.padding),
|
||||
total_height + (2 * entry.padding))
|
||||
|
||||
painter.setPen(Qt.PenStyle.NoPen)
|
||||
painter.setBrush(entry.background_color)
|
||||
painter.drawRect(bg_rect)
|
||||
|
||||
# Draw the texts
|
||||
painter.setPen(entry.text_color)
|
||||
painter.drawText(text_rect, Qt.AlignmentFlag.AlignLeft, entry.text)
|
||||
|
||||
|
||||
|
||||
def handle_exit(signum, frame):
|
||||
QApplication.quit()
|
||||
|
||||
def start_overlay():
|
||||
app = QApplication(sys.argv)
|
||||
|
||||
# Enable Wayland support if available
|
||||
if platform.system() == "Linux" and "WAYLAND_DISPLAY" in os.environ:
|
||||
app.setProperty("platform", "wayland")
|
||||
|
||||
overlay = TranslationOverlay()
|
||||
|
||||
overlay.show()
|
||||
signal.signal(signal.SIGINT, handle_exit) # Handle Ctrl+C (KeyboardInterrupt)
|
||||
signal.signal(signal.SIGTERM, handle_exit)
|
||||
return (app, overlay)
|
||||
# sys.exit(app.exec())
|
||||
|
||||
if ADD_OVERLAY:
|
||||
app, overlay = start_overlay()
|
||||
|
||||
if __name__ == "__main__":
|
||||
ADD_OVERLAY = True
|
||||
if not ADD_OVERLAY:
|
||||
app, overlay = start_overlay()
|
||||
overlay.add_text_at_position(600, 100, "Hello World I AM A BIG FAAT FOROGGGGGGGGGG")
|
||||
capture = overlay.capture_behind()
|
||||
capture.save("capture.png")
|
||||
sys.exit(app.exec())
|
||||
59
data.py
Normal file
59
data.py
Normal file
@ -0,0 +1,59 @@
|
||||
from sqlalchemy import create_engine, Column, Index, Integer, String, MetaData, Table, DateTime, ForeignKey, Boolean
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import relationship, declarative_base, sessionmaker
|
||||
import logging
|
||||
from logging_config import logger
|
||||
import os
|
||||
# Set up the database connection
|
||||
data_dir = os.path.join(os.path.dirname(__file__), 'database')
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
database_file = os.path.join(os.path.dirname(__file__), data_dir, 'translations.db')
|
||||
engine = create_engine(f'sqlite:///{database_file}', echo=False)
|
||||
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
|
||||
|
||||
class Api(Base):
|
||||
__tablename__ = 'api'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
model_name = Column(String, nullable=False)
|
||||
site = Column(Integer, nullable=False)
|
||||
rpmin = Column(Integer) # rate per minute
|
||||
rph = Column(Integer) # rate per hour
|
||||
rpd = Column(Integer) # rate per day
|
||||
rpw = Column(Integer) # rate per week
|
||||
rpmth = Column(Integer) # rate per month
|
||||
rpy = Column(Integer) # rate per year
|
||||
|
||||
translations = relationship("Translations", back_populates="api")
|
||||
|
||||
class Translations(Base):
|
||||
__tablename__ = 'translations'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
model_id = Column(Integer, ForeignKey('api.id'), nullable=False)
|
||||
source_texts = Column(String, nullable=False) # as a json string
|
||||
translated_texts = Column(String, nullable=False) # as a json string
|
||||
source_lang = Column(String, nullable=False)
|
||||
target_lang = Column(String, nullable=False)
|
||||
timestamp = Column(DateTime, nullable=False)
|
||||
translation_mismatch = Column(Boolean, nullable=False)
|
||||
api = relationship("Api", back_populates="translations")
|
||||
__table_args__ = (
|
||||
Index('idx_timestamp', 'timestamp'),
|
||||
)
|
||||
|
||||
|
||||
def create_tables():
|
||||
if not os.path.exists(database_file):
|
||||
Base.metadata.create_all(engine)
|
||||
logger.info(f"Database created at {database_file}")
|
||||
else:
|
||||
logger.info(f"Using Pre-existing Database at {database_file}.")
|
||||
|
||||
|
||||
BIN
database/translations.db
Normal file
BIN
database/translations.db
Normal file
Binary file not shown.
222
draw.py
222
draw.py
@ -1,14 +1,14 @@
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from PIL import Image, ImageDraw, ImageFont, ImageFilter
|
||||
import os, io, sys, numpy as np
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
||||
from utils import romanize, intercepts, add_furigana
|
||||
from logging_config import logger
|
||||
from config import ADD_OVERLAY, SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE, LINE_SPACING, FONT_COLOUR, LINE_HEIGHT, TO_ROMANIZE, FILL_COLOUR, REGION
|
||||
from config import SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE_MAX,FONT_SIZE_MIN, FONT_SIZE, LINE_SPACING, FONT_COLOUR, LINE_HEIGHT, TO_ROMANIZE, FILL_COLOUR, REGION, DRAW_TRANSLATIONS_MODE
|
||||
|
||||
from PySide6.QtGui import QFont
|
||||
font = ImageFont.truetype(FONT_FILE, FONT_SIZE)
|
||||
|
||||
#### CREATE A CLASS LATER so it doesn't have to inherit the same arguments all the way too confusing :| its so ass like this man i had no foresight
|
||||
|
||||
def modify_image_bytes(image_bytes: io.BytesIO, ocr_output, translation: list) -> bytes:
|
||||
"""Modify the image bytes with the translated text and return the modified image bytes"""
|
||||
@ -24,45 +24,36 @@ def modify_image_bytes(image_bytes: io.BytesIO, ocr_output, translation: list) -
|
||||
modified_image_bytes = byte_stream.getvalue()
|
||||
return modified_image_bytes
|
||||
|
||||
def draw_on_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int, replace = False) -> ImageDraw:
|
||||
def draw_on_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int, draw_mode: str = DRAW_TRANSLATIONS_MODE) -> ImageDraw:
|
||||
"""Draw the original, translated and optionally the romanisation of the texts on the image"""
|
||||
translated_number = 0
|
||||
bounding_boxes = []
|
||||
logger.debug(f"Translations: {len(translation)} {translation}")
|
||||
logger.debug(f"OCR output: {len(ocr_output)} {ocr_output}")
|
||||
for i, (position, untranslated_phrase, confidence) in enumerate(ocr_output):
|
||||
logger.debug(f"Untranslated phrase: {untranslated_phrase}")
|
||||
if translated_number >= max_translate - 1:
|
||||
if translated_number >= len(translation): # note if using api llm some issues may cause it to return less translations than expected
|
||||
break
|
||||
if replace:
|
||||
draw = draw_one_phrase_replace(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
else:
|
||||
draw_one_phrase_add(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
if draw_mode == 'learn':
|
||||
draw_one_phrase_learn(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
elif draw_mode == 'translation_only':
|
||||
draw_one_phrase_translation_only(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
elif draw_mode == 'learn_cover':
|
||||
draw_one_phrase_learn_cover(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
elif draw_mode == 'translation_only_cover':
|
||||
draw_one_phrase_translation_only_cover(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
translated_number += 1
|
||||
return draw
|
||||
|
||||
def draw_one_phrase_add(draw: ImageDraw,
|
||||
def draw_one_phrase_learn(draw: ImageDraw,
|
||||
translated_phrase: str,
|
||||
position: tuple, bounding_boxes: list,
|
||||
untranslated_phrase: str) -> ImageDraw:
|
||||
"""Draw the bounding box rectangle and text on the image above the original text"""
|
||||
|
||||
if SOURCE_LANG == 'ja':
|
||||
untranslated_phrase = add_furigana(untranslated_phrase)
|
||||
romanized_phrase = romanize(untranslated_phrase, 'ja')
|
||||
else:
|
||||
romanized_phrase = romanize(untranslated_phrase, SOURCE_LANG)
|
||||
if TO_ROMANIZE:
|
||||
text_content = f"{translated_phrase}\n{romanized_phrase}\n{untranslated_phrase}"
|
||||
else:
|
||||
text_content = f"{translated_phrase}\n{untranslated_phrase}"
|
||||
|
||||
lines = text_content.split('\n')
|
||||
|
||||
lines = get_lines(untranslated_phrase, translated_phrase)
|
||||
# Draw the bounding box
|
||||
top_left, _, _, _ = position
|
||||
max_width = get_max_width(lines, FONT_FILE, FONT_SIZE)
|
||||
total_height = get_max_height(lines, FONT_SIZE, LINE_SPACING)
|
||||
top_left, _, bottom_right,_ = position
|
||||
font_size = get_font_size(top_left[1], bottom_right[1], FONT_SIZE_MAX, FONT_SIZE_MIN)
|
||||
max_width = get_max_width(lines, FONT_FILE, font_size)
|
||||
total_height = get_max_height(lines, font_size, LINE_SPACING)
|
||||
font = ImageFont.truetype(FONT_FILE, font_size)
|
||||
right_edge = REGION[2]
|
||||
|
||||
# Ensure the text is within the screen. P.S. Text on the edge may still be squished together if there are too many to translate
|
||||
@ -75,58 +66,142 @@ def draw_one_phrase_add(draw: ImageDraw,
|
||||
adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]
|
||||
draw.rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], outline="black", width=1)
|
||||
position = (adjusted_x,adjusted_y)
|
||||
|
||||
|
||||
for line in lines:
|
||||
draw.text(position, line, fill= FONT_COLOUR, font=font)
|
||||
if ADD_OVERLAY:
|
||||
overlay.add_next_text_at_position_no_update(position[0], position[1], line, text_color=FONT_COLOUR)
|
||||
adjusted_y += FONT_SIZE + LINE_SPACING
|
||||
if FONT_COLOUR == 'rainbow':
|
||||
rainbow_text(draw, line, *position, font)
|
||||
else:
|
||||
draw.text(position, line, fill= FONT_COLOUR, font=font)
|
||||
adjusted_y += font_size + LINE_SPACING
|
||||
position = (adjusted_x,adjusted_y)
|
||||
|
||||
|
||||
|
||||
|
||||
### Only support for horizontal text atm, vertical text is on the todo list
|
||||
def draw_one_phrase_replace(draw: ImageDraw,
|
||||
def draw_one_phrase_translation_only_cover(draw: ImageDraw,
|
||||
translated_phrase: str,
|
||||
position: tuple, bounding_boxes: list,
|
||||
untranslated_phrase: str) -> ImageDraw:
|
||||
"""Cover up old text and add translation directly on top"""
|
||||
# Draw the bounding box
|
||||
top_left, _, _, bottom_right = position
|
||||
max_width = bottom_right[0] - top_left[0]
|
||||
font_size = bottom_right[1] - top_left[1]
|
||||
draw.rectangle([top_left, bottom_right], fill=FILL_COLOUR)
|
||||
while True:
|
||||
font = ImageFont.truetype(FONT_FILE, font_size)
|
||||
if font.get_max_width < max_width:
|
||||
draw.text(top_left, translated_phrase, fill= FONT_COLOUR, font=font)
|
||||
break
|
||||
elif font_size <= 1:
|
||||
break
|
||||
else:
|
||||
font_size -= 1
|
||||
"""Cover up old text and add translation directly on top"""
|
||||
# Draw the bounding box
|
||||
top_left, _, bottom_right, _ = position
|
||||
bounding_boxes.append((top_left[0], top_left[1], bottom_right[0], bottom_right[1], untranslated_phrase)) # Debugging purposes
|
||||
max_width = bottom_right[0] - top_left[0]
|
||||
font_size = get_font_size(top_left[1], bottom_right[1], FONT_SIZE_MAX, FONT_SIZE_MIN)
|
||||
while True:
|
||||
font = ImageFont.truetype(FONT_FILE, font_size)
|
||||
phrase_width = get_max_width(translated_phrase, FONT_FILE, font_size)
|
||||
rectangle = get_rectangle_coordinates(translated_phrase, top_left, FONT_FILE, font_size, LINE_SPACING)
|
||||
|
||||
def get_max_width(lines: list, font_path, font_size) -> int:
|
||||
if phrase_width < max_width:
|
||||
draw.rectangle(rectangle, fill=FILL_COLOUR)
|
||||
if FONT_COLOUR == 'rainbow':
|
||||
rainbow_text(draw, translated_phrase, *top_left, font)
|
||||
else:
|
||||
draw.text(top_left, translated_phrase, fill= FONT_COLOUR, font=font)
|
||||
|
||||
break
|
||||
elif font_size <= FONT_SIZE_MIN:
|
||||
break
|
||||
else:
|
||||
font_size -= 1
|
||||
|
||||
def draw_one_phrase_learn_cover(draw: ImageDraw,
|
||||
translated_phrase: str,
|
||||
position: tuple, bounding_boxes: list,
|
||||
untranslated_phrase: str) -> ImageDraw:
|
||||
"""Cover up old text and add translation directly on top"""
|
||||
lines = get_lines(untranslated_phrase, translated_phrase)
|
||||
# Draw the bounding box
|
||||
top_left, _, bottom_right,_ = position
|
||||
font_size = get_font_size(top_left[1], bottom_right[1], FONT_SIZE_MAX, FONT_SIZE_MIN)
|
||||
max_width = get_max_width(lines, FONT_FILE, font_size)
|
||||
total_height = get_max_height(lines, font_size, LINE_SPACING)
|
||||
font = ImageFont.truetype(FONT_FILE, font_size)
|
||||
right_edge = REGION[2]
|
||||
|
||||
# Ensure the text is within the screen. P.S. Text on the edge may still be squished together if there are too many to translate
|
||||
x_onscreen = top_left[0] if top_left[0] + max_width <= right_edge else right_edge - max_width
|
||||
y_onscreen = max(top_left[1] - int(total_height/3), 0)
|
||||
bounding_box = (x_onscreen, y_onscreen, x_onscreen + max_width, y_onscreen + total_height, untranslated_phrase)
|
||||
|
||||
adjust_if_intersects(x_onscreen, y_onscreen, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
|
||||
|
||||
adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]
|
||||
draw.rounded_rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], fill=FILL_COLOUR,outline="black", width=2, radius=5)
|
||||
position = (adjusted_x,adjusted_y)
|
||||
|
||||
|
||||
for line in lines:
|
||||
if FONT_COLOUR == 'rainbow': # easter egg yay
|
||||
rainbow_text(draw, line, *position, font)
|
||||
else:
|
||||
draw.text(position, line, fill= FONT_COLOUR, font=font)
|
||||
adjusted_y += font_size + LINE_SPACING
|
||||
position = (adjusted_x,adjusted_y)
|
||||
|
||||
def draw_one_phrase_translation_only(draw: ImageDraw,
|
||||
translated_phrase: str,
|
||||
position: tuple, bounding_boxes: list,
|
||||
untranslated_phrase: str) -> ImageDraw:
|
||||
"""Cover up old text and add translation directly on top"""
|
||||
# Draw the bounding box
|
||||
pass
|
||||
|
||||
def get_rectangle_coordinates(lines: list | str, top_left: tuple | list, font_path, font_size, line_spacing, padding: int = 1) -> list:
|
||||
|
||||
"""Get the coordinates of the rectangle surrounding the text"""
|
||||
|
||||
text_width = get_max_width(lines, font_path, font_size)
|
||||
text_height = get_max_height(lines, font_size, line_spacing)
|
||||
x1 = top_left[0] - padding
|
||||
y1 = top_left[1] - padding
|
||||
x2 = top_left[0] + text_width + padding
|
||||
y2 = top_left[1] + text_height + padding
|
||||
return [(x1,y1), (x2,y2)]
|
||||
|
||||
def get_max_width(lines: list | str, font_path, font_size) -> int:
|
||||
"""Get the maximum width of the text lines"""
|
||||
font = ImageFont.truetype(font_path, font_size)
|
||||
max_width = 0
|
||||
dummy_image = Image.new("RGB", (1, 1))
|
||||
draw = ImageDraw.Draw(dummy_image)
|
||||
for line in lines:
|
||||
bbox = draw.textbbox((0,0), line, font=font)
|
||||
line_width = bbox[2] - bbox[0]
|
||||
max_width = max(max_width, line_width)
|
||||
if isinstance(lines, list):
|
||||
for line in lines:
|
||||
bbox = draw.textbbox((0,0), line, font=font)
|
||||
line_width = bbox[2] - bbox[0]
|
||||
max_width = max(max_width, line_width)
|
||||
else:
|
||||
bbox = draw.textbbox((0,0), lines, font=font)
|
||||
max_width = bbox[2] - bbox[0]
|
||||
return max_width
|
||||
|
||||
def get_max_height(lines: list, font_size, line_spacing) -> int:
|
||||
def get_max_height(lines: list | str, font_size, line_spacing) -> int:
|
||||
"""Get the maximum height of the text lines"""
|
||||
return len(lines) * (font_size + line_spacing)
|
||||
no_of_lines = len(lines) if isinstance(lines, list) else 1
|
||||
return no_of_lines * (font_size + line_spacing)
|
||||
|
||||
def get_lines(untranslated_phrase: str, translated_phrase: str) -> list:
|
||||
"""Get the translated. untranslated and optionally the romanised text as a list"""
|
||||
if SOURCE_LANG == 'ja':
|
||||
untranslated_phrase = add_furigana(untranslated_phrase)
|
||||
romanized_phrase = romanize(untranslated_phrase, 'ja')
|
||||
else:
|
||||
romanized_phrase = romanize(untranslated_phrase, SOURCE_LANG)
|
||||
if TO_ROMANIZE:
|
||||
text_content = f"{translated_phrase}\n{romanized_phrase}\n{untranslated_phrase}"
|
||||
else:
|
||||
text_content = f"{translated_phrase}\n{untranslated_phrase}"
|
||||
return text_content.split('\n')
|
||||
|
||||
|
||||
def adjust_if_intersects(x: int, y: int,
|
||||
bounding_box: tuple, bounding_boxes: list,
|
||||
untranslated_phrase: str,
|
||||
max_width: int, total_height: int) -> tuple:
|
||||
"""Adjust the y coordinate if the bounding box intersects with any other bounding box"""
|
||||
"""Adjust the y coordinate every time the bounding box intersects with any previous bounding boxes. OCR returns results from top to bottom so it works."""
|
||||
y = np.max([y,0])
|
||||
if len(bounding_boxes) > 0:
|
||||
for box in bounding_boxes:
|
||||
@ -136,3 +211,36 @@ def adjust_if_intersects(x: int, y: int,
|
||||
bounding_boxes.append(adjusted_bounding_box)
|
||||
return adjusted_bounding_box
|
||||
|
||||
|
||||
def get_font_size(y_1, y_2, font_size_max: int, font_size_min: int) -> int:
|
||||
"""Get the average of the maximum and minimum font sizes"""
|
||||
if font_size_min > font_size_max:
|
||||
raise ValueError("Minimum font size cannot be greater than maximum font size")
|
||||
font_size = min(
|
||||
max(int(abs(2/3*(y_2-y_1))), font_size_min),
|
||||
font_size_max)
|
||||
return font_size
|
||||
|
||||
|
||||
|
||||
|
||||
def rainbow_text(draw,text,x,y,font):
|
||||
for i, letter in enumerate(text):
|
||||
# Calculate hue for rainbow effect
|
||||
# Convert HSV to RGB (using full saturation and value)
|
||||
rgb = tuple(np.random.randint(50,255,3))
|
||||
# Get the width of this letter
|
||||
|
||||
letter_bbox = draw.textbbox((x, y), letter, font=font)
|
||||
letter_width = letter_bbox[2] - letter_bbox[0]
|
||||
|
||||
# Draw the letter
|
||||
draw.text((x, y), letter, fill=rgb, font=font)
|
||||
|
||||
# Move x position for next letter
|
||||
x += letter_width
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
|
||||
@ -1,141 +1,325 @@
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from typing import List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
from dotenv import load_dotenv
|
||||
import os , sys, torch, time, ast
|
||||
import os , sys, torch, time, ast, json, pytz
|
||||
from werkzeug.exceptions import TooManyRequests
|
||||
from multiprocessing import Process, Event, Value
|
||||
load_dotenv()
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
from config import device, GEMINI_API_KEY, GROQ_API_KEY
|
||||
from config import device, GEMINI_API_KEY, GROQ_API_KEY, MAX_TRANSLATE
|
||||
from logging_config import logger
|
||||
from groq import Groq
|
||||
from groq import Groq as Groqq
|
||||
import google.generativeai as genai
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from functools import wraps
|
||||
|
||||
from data import session, Api, Translations
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ApiModel():
|
||||
def __init__(self, model, # model name
|
||||
rate, # rate of calls per minute
|
||||
api_key, # api key for the model wrt the site
|
||||
def __init__(self, model, # model name as defined by the API
|
||||
site, # site of the model; # to be precise, use the name as defined precisely by the class names in this script, i.e. Groqq and Gemini
|
||||
api_key: Optional[str] = None, # api key for the model wrt the site
|
||||
rpmin: Optional[int] = None, # rate of calls per minute
|
||||
rph: Optional[int] = None, # rate of calls per hour
|
||||
rpd: Optional[int] = None, # rate of calls per day
|
||||
rpw: Optional[int] = None, # rate of calls per week
|
||||
rpmth: Optional[int] = None, # rate of calls per month
|
||||
rpy: Optional[int] = None # rate of calls per year
|
||||
):
|
||||
self.model = model
|
||||
self.rate = rate
|
||||
self.api_key = api_key
|
||||
self.curr_calls = Value('i', 0)
|
||||
self.time = Value('i', 0)
|
||||
self.process = None
|
||||
self.stop_event = Event()
|
||||
self.site = None
|
||||
self.model = model
|
||||
self.rpmin = rpmin
|
||||
self.rph = rph
|
||||
self.rpd = rpd
|
||||
self.rpw = rpw
|
||||
self.rpmth = rpmth
|
||||
self.rpy = rpy
|
||||
self.site = site
|
||||
self.from_lang = None
|
||||
self.target_lang = None
|
||||
self.request = None # request response from API
|
||||
self.db_table = None
|
||||
self.session_calls = 0
|
||||
self._id = None
|
||||
self._set_db_model_id() if self._get_db_model_id() else self.update_db()
|
||||
# Create the table if it does not already exist
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.site} Model: {self.model}; Rate: {self.rate}; Current_Calls: {self.curr_calls.value} calls; Time Passed: {self.time.value} seconds.'
|
||||
return f'{self.site} Model: {self.model}; Total calls this session: {self.session_calls}; rpmin: {self.rpmin}; rph: {self.rph}; rpd: {self.rpd}; rpw: {self.rpw}; rpmth: {self.rpmth}; rpy: {self.rpy}'
|
||||
|
||||
def __str__(self):
|
||||
return self.model
|
||||
|
||||
async def __aenter__(self):
|
||||
self.session = aiohttp.ClientSession()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
def _get_db_model_id(self):
|
||||
model = session.query(Api).filter_by(model_name = self.model, site = self.site).first()
|
||||
if model:
|
||||
return model.id
|
||||
else:
|
||||
return None
|
||||
|
||||
def _set_db_model_id(self):
|
||||
self._id = self._get_db_model_id()
|
||||
|
||||
@staticmethod
|
||||
def _get_time():
|
||||
return datetime.now(tz=pytz.timezone('Australia/Sydney'))
|
||||
|
||||
def set_lang(self, from_lang, target_lang):
|
||||
self.from_lang = from_lang
|
||||
self.target_lang = target_lang
|
||||
|
||||
### CHECK MINUTELY API RATES. For working with hourly rates and monthly will need to create another file. Also just unlikely those rates will be hit
|
||||
async def api_rate_check(self):
|
||||
# Background task to manage the rate of calls to the API
|
||||
while not self.stop_event.is_set():
|
||||
start_time = time.monotonic()
|
||||
self.time.value += 5
|
||||
if self.time.value >= 60:
|
||||
self.time.value = 0
|
||||
self.curr_calls.value = 0
|
||||
elapsed = time.monotonic() - start_time
|
||||
# Sleep for exactly 5 seconds minus the elapsed time
|
||||
sleep_time = max(0, 5 - elapsed)
|
||||
await asyncio.sleep(sleep_time)
|
||||
def set_db_table(self, db_table):
|
||||
self.db_table = db_table
|
||||
|
||||
def background_task(self):
|
||||
asyncio.run(self.api_rate_check())
|
||||
def update_db(self):
|
||||
api = session.query(Api).filter_by(model_name = self.model, site = self.site).first()
|
||||
if not api:
|
||||
api = Api(model_name = self.model,
|
||||
site = self.site,
|
||||
rpmin = self.rpmin,
|
||||
rph = self.rph,
|
||||
rpd = self.rpd,
|
||||
rpw = self.rpw,
|
||||
rpmth = self.rpmth,
|
||||
rpy = self.rpy)
|
||||
session.add(api)
|
||||
session.commit()
|
||||
self._set_db_model_id()
|
||||
else:
|
||||
api.rpmin = self.rpmin
|
||||
api.rph = self.rph
|
||||
api.rpd = self.rpd
|
||||
api.rpw = self.rpw
|
||||
api.rpmth = self.rpmth
|
||||
api.rpy = self.rpy
|
||||
session.commit()
|
||||
|
||||
def start(self):
|
||||
# Start the background task
|
||||
self.process = Process(target=self.background_task)
|
||||
self.process.daemon = True
|
||||
self.process.start()
|
||||
logger.info(f"Background process started with PID: {self.process.pid}")
|
||||
def _db_add_translation(self, text: list | str, translation: list, mismatch = False):
|
||||
text = json.dumps(text) if isinstance(text, list) else json.dumps([text])
|
||||
translation = json.dumps(translation)
|
||||
translation = Translations(source_texts = text, translated_texts = translation,
|
||||
model_id = self._id, source_lang = self.from_lang, target_lang = self.target_lang,
|
||||
timestamp = datetime.now(tz=pytz.timezone('Australia/Sydney')),
|
||||
translation_mismatch = mismatch)
|
||||
session.add(translation)
|
||||
session.commit()
|
||||
|
||||
def stop(self):
|
||||
# Stop the background task
|
||||
logger.info(f"Stopping background process with PID: {self.process.pid}")
|
||||
self.stop_event.set()
|
||||
if self.process:
|
||||
self.process.join(timeout=5)
|
||||
if self.process.is_alive():
|
||||
self.process.terminate()
|
||||
@staticmethod
|
||||
def _single_period_calls_check(max_calls, call_count):
|
||||
if not max_calls:
|
||||
return True
|
||||
if max_calls <= call_count:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def request_func(request):
|
||||
@wraps(request)
|
||||
def wrapper(self, text, *args, **kwargs):
|
||||
if self.curr_calls.value < self.rate:
|
||||
# try:
|
||||
response = request(self, text, *args, **kwargs)
|
||||
self.curr_calls.value += 1
|
||||
return response
|
||||
# except Exception as e:
|
||||
#logger.error(f"Error with model {self.model} from {self.site}. Error: {e}")
|
||||
else:
|
||||
logger.error(f"Rate limit reached for this model. Please wait for the rate to reset in {60 - self.time} seconds.")
|
||||
raise TooManyRequests('Rate limit reached for this model.')
|
||||
return wrapper
|
||||
def _are_rates_good(self):
|
||||
curr_time = self._get_time()
|
||||
min_ago = curr_time - timedelta(minutes=1)
|
||||
hour_ago = curr_time - timedelta(hours=1)
|
||||
day_ago = curr_time - timedelta(days=1)
|
||||
week_ago = curr_time - timedelta(weeks=1)
|
||||
month_ago = curr_time - timedelta(days=30)
|
||||
year_ago = curr_time - timedelta(days=365)
|
||||
min_calls = session.query(Translations).join(Api). \
|
||||
filter(Api.id==self._id,
|
||||
Translations.timestamp >= min_ago
|
||||
).count()
|
||||
hour_calls = session.query(Translations).join(Api). \
|
||||
filter(Api.id==self._id,
|
||||
Translations.timestamp >= hour_ago
|
||||
).count()
|
||||
day_calls = session.query(Translations).join(Api). \
|
||||
filter(Api.id==self._id,
|
||||
Translations.timestamp >= day_ago
|
||||
).count()
|
||||
week_calls = session.query(Translations).join(Api). \
|
||||
filter(Api.id==self._id,
|
||||
Translations.timestamp >= week_ago
|
||||
).count()
|
||||
month_calls = session.query(Translations).join(Api). \
|
||||
filter(Api.id==self._id,
|
||||
Translations.timestamp >= month_ago
|
||||
).count()
|
||||
year_calls = session.query(Translations).join(Api). \
|
||||
filter(Api.id==self._id,
|
||||
Translations.timestamp >= year_ago
|
||||
).count()
|
||||
if self._single_period_calls_check(self.rpmin, min_calls) \
|
||||
and self._single_period_calls_check(self.rph, hour_calls) \
|
||||
and self._single_period_calls_check(self.rpd, day_calls) \
|
||||
and self._single_period_calls_check(self.rpw, week_calls) \
|
||||
and self._single_period_calls_check(self.rpmth, month_calls) \
|
||||
and self._single_period_calls_check(self.rpy, year_calls):
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Rate limit reached for {self.model} from {self.site}. Current calls: {min_calls} in the last minute; {hour_calls} in the last hour; {day_calls} in the last day; {week_calls} in the last week; {month_calls} in the last month; {year_calls} in the last year.")
|
||||
return False
|
||||
|
||||
@request_func
|
||||
def translate(self, request_fn, texts_to_translate):
|
||||
if len(texts_to_translate) == 0:
|
||||
return []
|
||||
prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {self.from_lang} into {self.target_lang} and output as a Python list ensuring proper escaping of characters: {texts_to_translate}"
|
||||
response = request_fn(self, prompt)
|
||||
return ast.literal_eval(response.strip())
|
||||
|
||||
class Groqq(ApiModel):
|
||||
def __init__(self, model, rate, api_key = GROQ_API_KEY):
|
||||
super().__init__(model, rate, api_key)
|
||||
self.site = "Groq"
|
||||
# async def request_func(request):
|
||||
# @wraps(request)
|
||||
# async def wrapper(self, text, *args, **kwargs):
|
||||
# if await self._are_rates_good():
|
||||
# try:
|
||||
# self.session_calls += 1
|
||||
# response = await request(self, text, *args, **kwargs)
|
||||
# return response
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error with model {self.model} from {self.site}. Error: {e}")
|
||||
# else:
|
||||
# logger.error(f"Rate limit reached for this model.")
|
||||
# raise TooManyRequests('Rate limit reached for this model.')
|
||||
# return wrapper
|
||||
|
||||
def request(self, content):
|
||||
client = Groq()
|
||||
chat_completion = client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
# @request_func
|
||||
async def translate(self, texts_to_translate, store = False):
|
||||
if isinstance(texts_to_translate, str):
|
||||
texts_to_translate = [texts_to_translate]
|
||||
if len(texts_to_translate) == 0:
|
||||
return []
|
||||
#prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {self.from_lang} into {self.target_lang} and output as a Python list ensuring proper escaping of characters and ensuring the length of the list given is exactly equal to the length of the list you provide. Do not output in any other language other than the specified target language: {texts_to_translate}"
|
||||
prompt = f"""INSTRUCTIONS:
|
||||
- Provide ONE and ONLY ONE translation to each text provided in the JSON array given.
|
||||
- The translations must preserve the original order.
|
||||
- Each translation must be from the Source language to the Target language
|
||||
- Source language: {self.from_lang}
|
||||
- Target language: {self.target_lang}
|
||||
- Texts are provided in JSON array syntax.
|
||||
- Respond using ONLY valid JSON array syntax.
|
||||
- Do not include explanations or additional text
|
||||
- Escape special characters properly
|
||||
|
||||
Input texts:
|
||||
{texts_to_translate}
|
||||
|
||||
Expected format:
|
||||
["translation1", "translation2", ...]
|
||||
|
||||
Translation:"""
|
||||
response = await self._request(prompt)
|
||||
response_list = ast.literal_eval(response.strip())
|
||||
logger.debug(repr(self))
|
||||
logger.info(f'{self.model} translated texts from: {texts_to_translate} to {response_list}.')
|
||||
|
||||
if len(response_list) != len(texts_to_translate) and len(texts_to_translate) <= MAX_TRANSLATE:
|
||||
logger.error(f"{self.model} model failed to translate all the texts. Number of translations to make: {len(texts_to_translate)}; Number of translated texts: {len(response_list)}.")
|
||||
if store:
|
||||
self._db_add_translation(texts_to_translate, response_list, mismatch=True)
|
||||
else:
|
||||
if store:
|
||||
self._db_add_translation(texts_to_translate, response_list)
|
||||
print(response_list)
|
||||
return response_list
|
||||
|
||||
class Groq(ApiModel):
|
||||
def __init__(self, # model name as defined by the API
|
||||
model,
|
||||
api_key = GROQ_API_KEY, # api key for the model wrt the site
|
||||
**kwargs):
|
||||
super().__init__(model,
|
||||
api_key = api_key,
|
||||
site = 'Groq', **kwargs)
|
||||
self.client = Groqq()
|
||||
|
||||
async def _request(self, content: str) -> str:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"https://api.groq.com/openai/v1/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {GROQ_API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"model": self.model
|
||||
}
|
||||
],
|
||||
model=self.model
|
||||
)
|
||||
return chat_completion.choices[0].message.content
|
||||
) as response:
|
||||
response_json = await response.json()
|
||||
return response_json["choices"][0]["message"]["content"]
|
||||
# https://console.groq.com/settings/limits for limits
|
||||
# def request(self, content):
|
||||
# chat_completion = self.client.chat.completions.create(
|
||||
# messages=[
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": content,
|
||||
# }
|
||||
# ],
|
||||
# model=self.model
|
||||
# )
|
||||
# return chat_completion.choices[0].message.content
|
||||
|
||||
def translate(self, texts_to_translate):
|
||||
return super().translate(Groqq.request, texts_to_translate)
|
||||
# async def translate(self, texts_to_translate):
|
||||
# return super().translate(self.request, texts_to_translate)
|
||||
|
||||
class Gemini(ApiModel):
|
||||
def __init__(self, model, rate, api_key = GEMINI_API_KEY):
|
||||
super().__init__(model, rate, api_key)
|
||||
self.site = "Gemini"
|
||||
def __init__(self, # model name as defined by the API
|
||||
model,
|
||||
api_key = GEMINI_API_KEY, # api key for the model wrt the site
|
||||
**kwargs):
|
||||
super().__init__(model,
|
||||
api_key = api_key,
|
||||
site = 'Google',
|
||||
**kwargs)
|
||||
|
||||
def request(self, content):
|
||||
genai.configure(api_key=self.api_key)
|
||||
safety_settings = {
|
||||
"HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
|
||||
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
|
||||
response = genai.GenerativeModel(self.model).generate_content(content, safety_settings=safety_settings)
|
||||
return response.text.strip()
|
||||
# def request(self, content):
|
||||
# genai.configure(api_key=self.api_key)
|
||||
# safety_settings = {
|
||||
# "HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
|
||||
# "HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
|
||||
# "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
|
||||
# "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
|
||||
# try:
|
||||
# response = genai.GenerativeModel(self.model).generate_content(content, safety_settings=safety_settings)
|
||||
# except ResourceExhausted as e:
|
||||
# logger.error(f"Rate limited with {self.model}. Error: {e}")
|
||||
# raise ResourceExhausted("Rate limited.")
|
||||
# return response.text.strip()
|
||||
|
||||
def translate(self, texts_to_translate):
|
||||
return super().translate(Gemini.request, texts_to_translate)
|
||||
async def _request(self, content):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={self.api_key}",
|
||||
headers={
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"contents": [{"parts": [{"text": content}]}],
|
||||
"safetySettings": [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE",
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE",
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE",
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"
|
||||
}
|
||||
]
|
||||
}
|
||||
) as response:
|
||||
response_json = await response.json()
|
||||
return response_json['candidates'][0]['content']['parts'][0]['text']
|
||||
|
||||
# async def translate(self, texts_to_translate):
|
||||
# return super().translate(self.request, texts_to_translate)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
###################################################################################################
|
||||
|
||||
### LOCAL LLM TRANSLATION
|
||||
|
||||
class TranslationDataset(Dataset):
|
||||
def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
|
||||
|
||||
@ -21,7 +21,6 @@ def _paddle_init(paddle_lang, use_angle_cls=False, use_GPU=True, **kwargs):
|
||||
|
||||
|
||||
def _paddle_ocr(ocr, image) -> list:
|
||||
|
||||
### return a list containing the bounding box, text and confidence of the detected text
|
||||
result = ocr.ocr(image, cls=False)[0]
|
||||
if not isinstance(result, list):
|
||||
@ -32,28 +31,29 @@ def _paddle_ocr(ocr, image) -> list:
|
||||
# EasyOCR has support for many languages
|
||||
|
||||
def _easy_init(easy_languages: list, use_GPU=True, **kwargs):
|
||||
langs = []
|
||||
for lang in easy_languages:
|
||||
langs.append(standardize_lang(lang)['easyocr_lang'])
|
||||
return easyocr.Reader(langs, gpu=use_GPU, **kwargs)
|
||||
return easyocr.Reader(easy_languages, gpu=use_GPU, **kwargs)
|
||||
|
||||
def _easy_ocr(ocr,image) -> list:
|
||||
return ocr.readtext(image)
|
||||
|
||||
# RapidOCR mostly for mandarin and some other asian languages
|
||||
|
||||
# default only supports chinese and english
|
||||
def _rapid_init(use_GPU=True, **kwargs):
|
||||
return RapidOCR(use_gpu=use_GPU, **kwargs)
|
||||
|
||||
def _rapid_ocr(ocr, image) -> list:
|
||||
return ocr(image)
|
||||
return ocr(image)[0]
|
||||
|
||||
### Initialize the OCR model
|
||||
def init_OCR(model='paddle', easy_languages: Optional[list] = ['ch_sim','en'], paddle_lang: Optional[str] = 'ch', use_GPU=True, **kwargs):
|
||||
def init_OCR(model='paddle', easy_languages: Optional[list] = ['ch_sim','en'], paddle_lang: Optional[str] = 'ch_sim', use_GPU=True):
|
||||
if model == 'paddle':
|
||||
paddle_lang = standardize_lang(paddle_lang)['paddleocr_lang']
|
||||
return _paddle_init(paddle_lang=paddle_lang, use_GPU=use_GPU)
|
||||
elif model == 'easy':
|
||||
return _easy_init(easy_languages=easy_languages, use_GPU=use_GPU)
|
||||
langs = []
|
||||
for lang in easy_languages:
|
||||
langs.append(standardize_lang(lang)['easyocr_lang'])
|
||||
return _easy_init(easy_languages=langs, use_GPU=use_GPU)
|
||||
elif model == 'rapid':
|
||||
return _rapid_init(use_GPU=use_GPU)
|
||||
|
||||
@ -82,10 +82,11 @@ def _id_filtered(ocr, image, lang) -> list:
|
||||
return results_no_eng
|
||||
|
||||
|
||||
# ch_sim, ch_tra, ja, ko, en
|
||||
# ch_sim, ch_tra, ja, ko, en input
|
||||
def _id_lang(ocr, image, lang) -> list:
|
||||
result = _identify(ocr, image)
|
||||
lang = standardize_lang(lang)['id_model_lang']
|
||||
print(result)
|
||||
try:
|
||||
filtered = [entry for entry in result if contains_lang(entry[1], lang)]
|
||||
except:
|
||||
|
||||
@ -1,17 +1,16 @@
|
||||
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig, AutoModelForCausalLM
|
||||
import google.generativeai as genai
|
||||
import torch, os, sys, ast, json
|
||||
import torch, os, sys, ast, json, asyncio, batching, random
|
||||
from typing import List, Optional, Set
|
||||
from utils import standardize_lang
|
||||
from functools import wraps
|
||||
import random
|
||||
import batching
|
||||
from batching import generate_text, Gemini, Groq
|
||||
from batching import generate_text, Gemini, Groq, ApiModel
|
||||
from logging_config import logger
|
||||
from multiprocessing import Process,Event
|
||||
from asyncio import Task
|
||||
# root dir
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_API_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models
|
||||
from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_API_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models, API_MODELS_FILEPATH
|
||||
|
||||
##############################
|
||||
# translation decorator
|
||||
@ -32,27 +31,66 @@ def translate(translation_func):
|
||||
def init_API_LLM(from_lang, target_lang):
|
||||
from_lang = standardize_lang(from_lang)['translation_model_lang']
|
||||
target_lang = standardize_lang(target_lang)['translation_model_lang']
|
||||
with open('api_models.json', 'r') as f:
|
||||
with open(API_MODELS_FILEPATH, 'r') as f:
|
||||
models_and_rates = json.load(f)
|
||||
models = []
|
||||
for class_type, class_models in models_and_rates.items():
|
||||
cls = getattr(batching, class_type)
|
||||
instantiated_objects = [ cls(model, rate) for model, rate in class_models.items()]
|
||||
instantiated_objects = [ cls(model = model, **rates) for model, rates in class_models.items()]
|
||||
models.extend(instantiated_objects)
|
||||
|
||||
for model in models:
|
||||
model.start()
|
||||
model.update_db()
|
||||
model.set_lang(from_lang, target_lang)
|
||||
return models
|
||||
|
||||
def translate_API_LLM(text, models):
|
||||
random.shuffle(models)
|
||||
for model in models:
|
||||
async def translate_API_LLM(texts_to_translate: List[str],
|
||||
models: List[ApiModel],
|
||||
call_size: int = 2,
|
||||
stagger_delay: int = 2) -> List[str]:
|
||||
async def try_translate(model: ApiModel) -> Optional[List[str]]:
|
||||
try:
|
||||
return model.translate(text)
|
||||
except:
|
||||
continue
|
||||
result = await model.translate(texts_to_translate, store=True)
|
||||
logger.debug(f'Try_translate result: {result}')
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Translation failed for {model.model} from {model.site}: {e}")
|
||||
return None
|
||||
random.shuffle(models)
|
||||
groups = [models[i:i+call_size] for i in range(0, len(models), call_size)]
|
||||
|
||||
for group in groups:
|
||||
tasks = set(asyncio.create_task(try_translate(model)) for model in group)
|
||||
while tasks:
|
||||
done, pending = await asyncio.wait(tasks,
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
logger.debug(f"Tasks done: {done}")
|
||||
logger.debug(f"Tasks remaining: {pending}")
|
||||
for task in done:
|
||||
result = await task
|
||||
logger.debug(f'Result: {result}')
|
||||
if result is not None:
|
||||
# Cancel remaining tasks
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
return result
|
||||
logger.error("All models have failed to translate the text.")
|
||||
raise TypeError("Models have likely all outputted garbage translations or rate limited.")
|
||||
# def translate_API_LLM(text, models):
|
||||
# random.shuffle(models)
|
||||
# logger.debug(f"All Models Available: {models}")
|
||||
# for model in models:
|
||||
# logger.info(f"Attempting translation with model {model}.")
|
||||
# try:
|
||||
# translation = model.translate(text)
|
||||
# logger.debug(f"Translation obtained: {translation}")
|
||||
# if translation or translation == []:
|
||||
# return translation
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error with model {repr(model)}. Error: {e}")
|
||||
# continue
|
||||
# logger.error("All models have failed to translate the text.")
|
||||
# raise TypeError("Models have likely all outputted garbage translations or rate limited.")
|
||||
|
||||
###############################
|
||||
# Best model by far. Aya-23-8B. Gemma is relatively good. If I get the time to quantize either gemma or aya those will be good to use. llama3.2 is really good as well.
|
||||
|
||||
@ -4,7 +4,9 @@ import pyscreenshot as ImageGrab # wayland tings not sure if it will work on oth
|
||||
import mss, io, os
|
||||
from PIL import Image
|
||||
import jaconv, MeCab, unidic, pykakasi
|
||||
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
import numpy as np
|
||||
# for creating furigana
|
||||
mecab = MeCab.Tagger('-d "{}"'.format(unidic.DICDIR))
|
||||
uroman = ur.Uroman()
|
||||
@ -95,7 +97,7 @@ def contains_katakana(text):
|
||||
|
||||
# use kakasi to romanize japanese text
|
||||
def romanize(text, lang):
|
||||
if lang == 'zh':
|
||||
if lang in ['zh','ch_sim','ch_tra']:
|
||||
return ' '.join([ py[0] for py in pinyin(text, heteronym=True)])
|
||||
if lang == 'ja':
|
||||
return kks.convert(text)[0]['hepburn']
|
||||
@ -131,13 +133,13 @@ def standardize_lang(lang):
|
||||
id_model_lang = 'zh'
|
||||
elif lang == 'ja':
|
||||
easyocr_lang = 'ja'
|
||||
paddleocr_lang = 'ja'
|
||||
paddleocr_lang = 'japan'
|
||||
rapidocr_lang = 'ja'
|
||||
translation_model_lang = 'ja'
|
||||
id_model_lang = 'ja'
|
||||
elif lang == 'ko':
|
||||
easyocr_lang = 'korean'
|
||||
paddleocr_lang = 'ko'
|
||||
paddleocr_lang = 'korean'
|
||||
rapidocr_lang = 'ko'
|
||||
translation_model_lang = 'ko'
|
||||
id_model_lang = 'ko'
|
||||
@ -165,6 +167,23 @@ def which_ocr_lang(model):
|
||||
else:
|
||||
raise ValueError("Invalid OCR model. Please use one of 'easy', 'paddle', or 'rapid'.")
|
||||
|
||||
def similar_tfidf(list1,list2,threshold) -> float:
|
||||
"""Calculate cosine similarity using TF-IDF vectors."""
|
||||
if not list1 or not list2:
|
||||
return 0.0
|
||||
|
||||
vectorizer = TfidfVectorizer()
|
||||
all_texts = list1 + list2
|
||||
tfidf_matrix = vectorizer.fit_transform(all_texts)
|
||||
|
||||
# Calculate average vectors for each list
|
||||
vec1 = np.mean(tfidf_matrix[:len(list1)].toarray(), axis=0).reshape(1, -1)
|
||||
vec2 = np.mean(tfidf_matrix[len(list1):].toarray(), axis=0).reshape(1, -1)
|
||||
|
||||
return float(cosine_similarity(vec1, vec2)[0, 0]) > threshold
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
|
||||
@ -48,8 +48,8 @@ def setup_logger(
|
||||
|
||||
# Create a formatter and set it for both handlers
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - [%(levelname)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
'%(asctime)s.%(msecs)03d - %(name)s - [%(levelname)s] %(message)s',
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
console_handler.setFormatter(formatter)
|
||||
@ -65,3 +65,4 @@ def setup_logger(
|
||||
return None
|
||||
|
||||
logger = setup_logger('on_screen_translator', log_file='translate.log', level=logging.DEBUG)
|
||||
|
||||
|
||||
100
main.py
Normal file
100
main.py
Normal file
@ -0,0 +1,100 @@
|
||||
###################################################################################
|
||||
##### IMPORT LIBRARIES #####
|
||||
import os, time, sys, threading, subprocess
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
||||
|
||||
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
|
||||
from utils import printsc, convert_image_to_bytes, bytes_to_image, similar_tfidf
|
||||
from ocr import get_words, init_OCR, id_keep_source_lang
|
||||
from data import Base, engine, create_tables
|
||||
from draw import modify_image_bytes
|
||||
import config, asyncio
|
||||
from config import SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL, IMAGE_CHANGE_THRESHOLD
|
||||
from logging_config import logger
|
||||
import web_app
|
||||
import view_buffer_app
|
||||
###################################################################################
|
||||
|
||||
|
||||
async def main():
    """Capture the screen in a loop, OCR it, translate changed text, and
    publish the annotated image to the web app.

    Runs forever; intended to be driven via ``asyncio.run`` (see ``__main__``).
    """
    ###################################################################################
    # Initialisation

    ##### Create the database if not present #####
    create_tables()

    ##### Initialize the OCR #####
    OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
    ocr = init_OCR(model=OCR_MODEL, paddle_lang=SOURCE_LANG,
                   easy_languages=OCR_LANGUAGES, use_GPU=OCR_USE_GPU)

    ##### Initialize the translation #####
    # Local alternative:
    # model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang=SOURCE_LANG, target_lang=TARGET_LANG)
    models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
    ###################################################################################

    runs = 0
    prev_words = set()  # words seen on the previous pass

    while True:
        logger.debug("Capturing screen")
        untranslated_image = printsc(REGION)
        logger.debug(f"Screen Captured. Proceeding to perform OCR.")
        byte_image = convert_image_to_bytes(untranslated_image)
        ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG)  # keep only phrases containing the source language
        logger.debug(f"OCR completed. Detected {len(ocr_output)} phrases.")

        if runs == 0:
            logger.info('Initial run')
            prev_words = set()
        else:
            logger.info(f'Run number: {runs}.')
        runs += 1

        curr_words = set(get_words(ocr_output))
        logger.debug(f'Current words: {curr_words} Previous words: {prev_words}')

        ### If the OCR detects different words, translate screen -> to ensure that
        ### the screen is not refreshing constantly and to save GPU power.
        if not similar_tfidf(list(curr_words), list(prev_words), threshold=IMAGE_CHANGE_THRESHOLD) and prev_words != curr_words:
            logger.info('Beginning Translation')

            to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
            # Local alternative:
            # translation = translate_Seq_LLM(to_translate, model_type=TRANSLATION_MODEL, model=model, tokenizer=tokenizer, from_lang=SOURCE_LANG, target_lang=TARGET_LANG)
            try:
                translation = await translate_API_LLM(to_translate, models, call_size=3)
            except TypeError as e:
                logger.error(f"Failed to translate using API models. Error: {e}. Sleeping for 30 seconds.")
                # BUGFIX: time.sleep(30) blocked the whole event loop; use a
                # non-blocking sleep so other coroutines can keep running.
                await asyncio.sleep(30)
                continue

            logger.debug('Translation complete. Modifying image.')
            translated_image = modify_image_bytes(byte_image, ocr_output, translation)
            web_app.latest_image = bytes_to_image(translated_image)
            logger.debug("Image modified. Saving image.")
            prev_words = curr_words
        else:
            logger.info("Skipping translation. No significant change in the screen detected.")

        logger.debug("Continuing to next iteration.")
        # BUGFIX: the original called asyncio.sleep(INTERVAL) without awaiting it,
        # so the coroutine was never executed and the loop never actually paused.
        await asyncio.sleep(INTERVAL)
|
||||
################### TODO ##################
|
||||
# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
|
||||
# 5. Maybe refreshing issue of flask app. Also get webpage to update only if the image changes.
|
||||
|
||||
if __name__ == "__main__":
    # Log the active configuration: every non-callable, non-dunder attribute
    # of the config module.
    logger.info('Configuration:')
    for name in dir(config):
        value = getattr(config, name)
        if not callable(value) and not name.startswith("__"):
            logger.info(f'{name}: {value}')

    # Run the capture/translate loop on a background daemon thread so the
    # Flask server below can own the main thread.
    worker = threading.Thread(target=asyncio.run, args=(main(),), daemon=True)
    worker.start()

    # Start the Flask web server (blocks until shutdown).
    web_app.app.run(host='0.0.0.0', port=5000, debug=False)
|
||||
115
qtapp.py
115
qtapp.py
@ -1,115 +0,0 @@
|
||||
###################################################################################
|
||||
##### IMPORT LIBRARIES #####
|
||||
import os, time, sys
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
||||
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
|
||||
from utils import printsc, convert_image_to_bytes, bytes_to_image
|
||||
from ocr import get_words, init_OCR, id_keep_source_lang
|
||||
from logging_config import logger
|
||||
from draw import modify_image_bytes
|
||||
from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL, FONT_SIZE, FONT_FILE, FONT_COLOUR
|
||||
from create_overlay import app, overlay
|
||||
from typing import Optional, List
|
||||
###################################################################################
|
||||
from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QThread, Signal
|
||||
from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont,
|
||||
QColor, QIcon, QImage, QPixmap)
|
||||
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
|
||||
QLabel, QSystemTrayIcon, QMenu)
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class TranslationThread(QThread):
    """Worker QThread: capture screen -> OCR -> translate, emitting results.

    Loops until stop() is called. Communicates with the Qt overlay purely
    through signals so widget updates happen on the GUI thread.
    """

    translation_ready = Signal(list, list)  # Signal to send translation results
    start_capture = Signal()  # emitted just before a screen grab
    end_capture = Signal()    # emitted right after the grab
    screen_capture = Signal(int, int, int, int)  # NOTE(review): connected but never emitted here — confirm still needed

    def __init__(self, ocr, models, source_lang, target_lang, interval):
        """Store the OCR engine, translation models, language pair and poll interval (seconds)."""
        super().__init__()
        self.ocr = ocr
        self.models = models
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.interval = interval
        self.running = True       # cleared by stop() to end the loop
        self.prev_words = set()   # words seen on the previous pass
        self.runs = 0             # iteration counter (for logging)

    def run(self):
        """Thread body: grab screen, OCR, and translate when the text changed."""
        while self.running:
            self.start_capture.emit()
            untranslated_image = printsc(REGION)
            self.end_capture.emit()
            byte_image = convert_image_to_bytes(untranslated_image)
            # Keep only phrases containing the source language.
            ocr_output = id_keep_source_lang(self.ocr, byte_image, self.source_lang)

            if self.runs == 0:
                logger.info('Initial run')
            else:
                logger.info(f'Run number: {self.runs}.')
            self.runs += 1

            curr_words = set(get_words(ocr_output))

            # Translate only when the detected words changed since the last pass.
            if self.prev_words != curr_words:
                logger.info('Translating')
                to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
                translation = translate_API_LLM(to_translate, self.models)
                logger.info(f'Translation from {to_translate} to\n {translation}')

                # Emit the translation results
                modify_image_bytes(byte_image, ocr_output, translation)
                self.translation_ready.emit(ocr_output, translation)

                self.prev_words = curr_words
            else:
                logger.info("No new words to translate. Output will not refresh.")

            logger.info(f'Sleeping for {self.interval} seconds')
            time.sleep(self.interval)

    def stop(self):
        """Ask the loop to exit; caller should then wait() on the thread."""
        self.running = False
|
||||
|
||||
|
||||
|
||||
def main():
    """Wire up OCR, translation, the worker thread and the Qt overlay,
    then run the Qt event loop until the window closes.

    Returns the Qt application's exit code.
    """

    # Initialize OCR
    OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
    ocr = init_OCR(model=OCR_MODEL, easy_languages=OCR_LANGUAGES, use_GPU=OCR_USE_GPU)

    # Initialize translation
    models = init_API_LLM(SOURCE_LANG, TARGET_LANG)

    # Create and start translation thread
    translation_thread = TranslationThread(
        ocr=ocr,
        models=models,
        source_lang=SOURCE_LANG,
        target_lang=TARGET_LANG,
        interval=INTERVAL
    )

    # Connect translation results to overlay update; Qt delivers these
    # cross-thread signals on the GUI thread.
    translation_thread.start_capture.connect(overlay.prepare_for_capture)
    translation_thread.end_capture.connect(overlay.restore_after_capture)
    translation_thread.translation_ready.connect(overlay.update_translation)
    translation_thread.screen_capture.connect(overlay.capture_behind)
    # Start the translation thread
    translation_thread.start()

    # Start Qt event loop (blocks until the application quits)
    result = app.exec()

    # Cleanup: stop the worker and join it before returning
    translation_thread.stop()
    translation_thread.wait()

    return result

if __name__ == "__main__":
    sys.exit(main())
|
||||
@ -17,7 +17,7 @@
|
||||
setInterval(function () {
|
||||
document.getElementById("live-image").src =
|
||||
"/image?" + new Date().getTime();
|
||||
}, 3500); // Update every 3.5 seconds
|
||||
}, 2500); // Update every 2.5 seconds. Beware that if the image fails to reload on time, the browser will continuously refresh without being able to display the images.
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
52
view_buffer_app.py
Normal file
52
view_buffer_app.py
Normal file
@ -0,0 +1,52 @@
|
||||
|
||||
#### Same thread as main.py so it will be relatively unresponsive. Just for use locally for a faster image display from buffer.
|
||||
|
||||
|
||||
from PySide6.QtWidgets import QApplication, QLabel
|
||||
from PySide6.QtCore import Qt
|
||||
from PySide6.QtGui import QImage, QPixmap
|
||||
import sys
|
||||
def create_viewer():
    """Build (or reuse) the process-wide QApplication and an image QLabel.

    Returns:
        tuple: (QLabel, QApplication) — the visible viewer widget and the app.
    """
    # Qt allows only one QApplication per process; reuse it if it exists.
    qt_app = QApplication.instance()
    if qt_app is None:
        qt_app = QApplication(sys.argv)

    viewer = QLabel()
    viewer.setWindowTitle("Image Viewer")
    viewer.setMinimumSize(640, 480)
    viewer.setMouseTracking(True)   # enable mouse tracking for potential future interactivity
    viewer.setScaledContents(True)  # better scaling quality
    viewer.show()

    return viewer, qt_app
|
||||
|
||||
def show_buffer_image(buffer, label):
    """Render raw image bytes into the given QLabel and refresh the display.

    Parameters:
    buffer: bytes
        Raw image data in memory
    label: QLabel
        Qt label widget to display the image
    """
    # Decode the in-memory bytes into a Qt image, then a paintable pixmap.
    qimg = QImage.fromData(buffer)
    pixmap = QPixmap.fromImage(qimg)

    # Fit the pixmap to the label with smooth, aspect-preserving scaling.
    fitted = pixmap.scaled(
        label.size(),
        Qt.KeepAspectRatio,
        Qt.SmoothTransformation
    )
    label.setPixmap(fitted)

    # Pump pending Qt events so the new frame actually paints now.
    QApplication.processEvents()
|
||||
@ -1,12 +1,12 @@
|
||||
from flask import Flask, Response, render_template
|
||||
import threading
|
||||
import io
|
||||
import app
|
||||
app = Flask(__name__)
|
||||
|
||||
latest_image = None
|
||||
# Global variable to hold the current image
|
||||
def curr_image():
    """Return the most recently translated image (set by the main loop).

    Returns None until the first frame has been produced.
    """
    # BUGFIX: resolve leftover diff residue — the stale `return app.latest_image`
    # referenced the Flask object (the module import is shadowed by
    # `app = Flask(__name__)`); the image now lives in this module's global.
    return latest_image
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
@ -29,7 +29,8 @@ def stream_image():
|
||||
|
||||
if __name__ == '__main__':
    # Start the image updating thread
    import main, asyncio
    # BUGFIX: `main` here is the MODULE; calling main() raised
    # "TypeError: 'module' object is not callable". The coroutine function
    # is main.main().
    threading.Thread(target=asyncio.run, args=(main.main(),), daemon=True).start()

    # Start the Flask web server
    app.run(host='0.0.0.0', port=5000, debug=True)
|
||||
Loading…
Reference in New Issue
Block a user