Compare commits

...

5 Commits

22 changed files with 1398 additions and 824 deletions

.gitignore vendored

@@ -5,4 +5,5 @@ __pycache__/
.*
test.py
notebooks/
qttest.py
*.db

LICENSE Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 chickenflyshigh
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

@@ -1,4 +1,70 @@
## What does this do?
It continuously translates text in a specified region of your screen from a source language into another language, while also (optionally) providing romanisation (including pinyin and furigana) as a guide to pronunciation. The main goal is to help people with a low/basic understanding of a language develop it further, by giving them a tool to immerse themselves in native content. Main uses include, but are not limited to, playing games and watching videos with subtitles in another language (although it might technically be better to obtain an audio transcription, translate it and replace the subtitles where possible -- this is not always feasible if you are watching many episodes and/or watching videos spontaneously).
## Limitations
If the `learn` mode is enabled, the added translations and romanisation naturally make texts take up three times the space, so the app is less suitable for tightly packed text. You can optionally change the config to insert smaller text, or reduce the overall font size of your screen so there is less text. A pure translation mode also exists, although if it is intended for web browsing, Google itself provides a more reliable method of translation that does not rely on computationally heavy optical character recognition (OCR).
## Usage (draft)
1. Clone the repository, navigate into it and install all required packages with `pip install -r requirements.txt` in a new Python environment (the OCR packages are very finicky).
2. If using external APIs, you will need to obtain API keys for the currently supported sites [Google (Gemini models), Groq (an assortment of open-source LLMs)] and add the associated keys to the environment variables file. If using another API, you will need to define a new class with a `_request` function in `helpers/batching.py`, inheriting from the `ApiModel` class; a sketch is given after this list. A template is provided in the file under the `Gemini` and `Groq` classes. All exception handling is already taken care of.
3. Edit the `api_models.json` file to list the models you want added (the full file appears in this diff below). The first level of the JSON file is the corresponding class name defined in `helpers/batching.py`. The second level defines the `model` names from the corresponding API endpoints. The third level specifies the rate limits of each model: `rpmin`, `rph`, `rpd`, `rpw`, `rpmth` and `rpy` are respectively the allowed requests per minute, hour, day, week, month and year.
4. Create and edit the `.env` config file (a sketch is given after this list). For information about all the variables to edit, check the section under "EDIT THESE VARIABLES" in `config.py`. If CUDA is not detected, the app defaults to CPU mode for all local LLMs and OCRs. In this case, it is recommended to set the `OCR_MODEL` variable to `rapid`, which is optimised for CPUs. Currently this is only supported with `SOURCE_LANG=ch_tra`, `ch_sim` or `en`. Refer to [notes][1]
5. If you are using the `wayland` display protocol (Linux only -- check with `echo $WAYLAND_DISPLAY`), install the `grim` package locally with your package manager.
- Arch-based: `sudo pacman -S grim`
- Debian-based: `sudo apt install grim`
- Fedora: `sudo dnf install grim`
- OpenSUSE: `sudo zypper install grim`
- NixOS: `nix-shell -p grim`
Screenshotting is limited in Wayland, and `grim` is one of the more lightweight options out there.
6. The RapidOCR, PaddleOCR and possibly the EasyOCR models need to be downloaded before any of this can be used. They should download when you execute a function that initialises the model with the desired OCR language, but this appears not to work well when running the app directly (more on this later). The same obviously holds for the local translation and LLM models.
7. Run the `main.py` file and a Qt6 app should appear. Alternatively, if that doesn't work, go to the last line of `main.py` and change the argument to `web`, which will serve the translations locally on `0.0.0.0:5000` or any other port you specify.
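For step 2, here is a minimal sketch of what a new API backend class might look like. It mirrors the `Groq` template in `helpers/batching.py` (shown further down this page); the `MyApi` name, endpoint URL and payload shape are placeholders, not a real provider:

```python
# Hypothetical backend: MyApi, https://api.example.com and the payload shape
# are placeholders -- adapt them to whatever service you actually call.
import aiohttp
from batching import ApiModel  # helpers/batching.py

class MyApi(ApiModel):
    def __init__(self, model, api_key=None, **kwargs):
        # site must match this class name so api_models.json entries resolve
        super().__init__(model, site='MyApi', api_key=api_key, **kwargs)

    async def _request(self, content: str) -> str:
        # Send the prompt and return the raw text of the completion.
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "https://api.example.com/v1/chat/completions",
                headers={"Authorization": f"Bearer {self.api_key}"},
                json={"model": self.model,
                      "messages": [{"role": "user", "content": content}]},
            ) as response:
                response_json = await response.json()
                return response_json["choices"][0]["message"]["content"]
```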
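For step 4, a sketch of what the `.env` file might contain. The variable names come from `config.py` in this diff; every value below is illustrative only:

```
# .env -- illustrative values only
GEMINI_API_KEY=your-gemini-key
GROQ_API_KEY=your-groq-key
SOURCE_LANG=ja
TARGET_LANG=en
OCR_MODEL=easy
TRANSLATION_MODEL=opus
REGION=(0,0,2560,1440)
INTERVAL=1.5
```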
## Notes and optimisations
- Accuracy is limited with RapidOCR, especially if there is a high dynamic range in the graphics. [1]
- Consider lowering the quality of the screen capture for faster OCR processing and a lower screen-capture time. OCR accuracy and the subsequent translations can be affected, but the entire translation process should then run in under 2 seconds without sacrificing too much OCR quality. Edit the `printsc` function in `helpers/utils.py` (a config option for this is planned); a rough sketch of the idea follows these notes.
- Not much of the database side has been worked on yet. Right now it stores all the texts and translations in Unicode/ASCII in the `database/translations.db` file. Use it however you want; it is stored locally, only for you.
- Downloading all the models may take up a few GBs of space.
- EasyOCR uses about 3.5GB of VRAM; PaddleOCR and RapidOCR use up to 1.5GB.
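A rough sketch of the capture-downscaling idea from the notes above, assuming the capture is a PIL image (the real capture logic lives in the `printsc` function in `helpers/utils.py`):

```python
# Sketch only: shrink the capture before OCR. Note that OCR bounding boxes
# would then need to be scaled back up by 1/factor before drawing.
from PIL import Image

def downscale_for_ocr(image: Image.Image, factor: float = 0.5) -> Image.Image:
    """Smaller images OCR noticeably faster, at some cost in accuracy."""
    new_size = (int(image.width * factor), int(image.height * factor))
    return image.resize(new_size, Image.LANCZOS)
```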
## Debugging Issues
1. CUDNN version mismatch when using PaddleOCR. Check that LD_LIBRARY_PATH is set to the directory containing the cudnn.so file. If using a local installation, it can help to simply remove the nvidia-cudnn-cu12 package from your Python environment.
2. Segmentation fault when using PaddleOCR, EasyOCR or RapidOCR. Ensure the only cv2 library installed is opencv-contrib-python. Check out https://pypi.org/project/opencv-python-headless/ for more info.
## Demo
[Demo](https://youtu.be/Tmv_I0GkOQc) of Korean to Chinese (simplified) translation with the `learn-cover` mode (a mode intended for language learners, showing the romanisation/pinyin/furigana etc. with the translation above).
## TODO:
- Create an overlay window that works in Wayland.
- Make use of the translation data -> maybe make a personalised game that uses the data.
- Provide an option to simplify and automate most of the install process.
# Terms of Use
## Data Collection and External API Use
1.1 Onscreen Data Transmission: The application is designed to send data displayed on your screen, including potentially sensitive or personal information, to an external API if local processing is not set up.
1.2 Third-Party API Integration: When local methods cannot fulfill certain functions, the App will transmit data to external third-party APIs. These APIs are not under our control, and we do not guarantee the security, confidentiality, or purpose of use of the data once transmitted.
## Acknowledgment
By using the app, you acknowledge that you have read, understood, and agree to these Terms of Use, including the potential risks associated with transmitting data to external APIs.

api_models.json

@@ -1,13 +1,22 @@
{
    "Gemini": {
        "gemini-1.5-pro": { "rpmin": 2, "rpd": 50 },
        "gemini-1.5-flash": { "rpmin": 15, "rpd": 1500 },
        "gemini-1.5-flash-8b": { "rpmin": 15, "rpd": 1500 },
        "gemini-1.0-pro": { "rpmin": 15, "rpd": 1500 }
    },
    "Groq": {
        "llama-3.2-90b-text-preview": { "rpmin": 30, "rpd": 7000 },
        "llama3-70b-8192": { "rpmin": 30, "rpd": 14400 },
        "mixtral-8x7b-32768": { "rpmin": 30, "rpd": 14400 },
        "llama-3.1-70b-versatile": { "rpmin": 30, "rpd": 14400 },
        "gemma2-9b-it": { "rpmin": 30, "rpd": 14400 },
        "llama3-groq-8b-8192-tool-use-preview": { "rpmin": 30, "rpd": 14400 },
        "llama3-groq-70b-8192-tool-use-preview": { "rpmin": 30, "rpd": 14400 },
        "llama-3.2-90b-vision-preview": { "rpmin": 15, "rpd": 3500 },
        "llama-3.2-11b-text-preview": { "rpmin": 30, "rpd": 7000 },
        "llama-3.2-11b-vision-preview": { "rpmin": 30, "rpd": 7000 },
        "gemma-7b-it": { "rpmin": 30, "rpd": 14400 },
        "llama3-8b-8192": { "rpmin": 30, "rpd": 14400 }
    }
}

app.py

@@ -1,83 +0,0 @@
###################################################################################
##### IMPORT LIBRARIES #####
import os, time, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
from utils import printsc, convert_image_to_bytes, bytes_to_image
from ocr import get_words, init_OCR, id_keep_source_lang
from logging_config import logger
from draw import modify_image_bytes
from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL
###################################################################################
ADD_OVERLAY = False
latest_image = None
def main():
global latest_image
##### Initialize the OCR #####
OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
ocr = init_OCR(model=OCR_MODEL, easy_languages = OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
##### Initialize the translation #####
# model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
###################################################################################
runs = 0
app.exec()
while True:
if ADD_OVERLAY:
overlay.clear_all_text()
untranslated_image = printsc(REGION)
if ADD_OVERLAY:
overlay.text_entries = overlay.text_entries_copy
overlay.update()
overlay.text_entries.clear()
byte_image = convert_image_to_bytes(untranslated_image)
ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG) # keep only phrases containing the source language
if runs == 0:
logger.info('Initial run')
prev_words = set()
else:
logger.info(f'Run number: {runs}.')
runs += 1
curr_words = set(get_words(ocr_output))
### If the OCR detects different words, translate screen -> to ensure that the screen is not refreshing constantly and to save GPU power
if prev_words != curr_words:
logger.info('Translating')
to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
# translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
translation = translate_API_LLM(to_translate, models)
logger.info(f'Translation from {to_translate} to\n {translation}')
translated_image = modify_image_bytes(byte_image, ocr_output, translation)
latest_image = bytes_to_image(translated_image)
# latest_image.show() # for debugging
prev_words = curr_words
else:
logger.info("No new words to translate. Output will not refresh.")
logger.info(f'Sleeping for {INTERVAL} seconds')
time.sleep(INTERVAL)
# if ADD_OVERLAY:
# sys.exit(app.exec())
################### TODO ##################
# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
# 5. Maybe refreshing issue of flask app. Also get webpage to update only if the image changes.
# Create a way for it to just replace the text and provide only the translation on-screen. Qt6
if __name__ == "__main__":
sys.exit(main())

config.py

@@ -1,65 +1,101 @@
import os, ast, torch, platform
from dotenv import load_dotenv
load_dotenv(override=True)
if platform.system() == 'Windows':
default_tmp_dir = "C:\\Users\\AppData\\Local\\Temp"
elif platform.system() in ['Linux', 'Darwin']:
default_tmp_dir = "/tmp"
###################################################################################################
### EDIT THESE VARIABLES ###
# Create a .env file in the same directory as this file and add the variables there. You can of course edit this file directly, but if you pull from the repository again all the config will be gone unless it is saved or stashed.
# The default values should be fine for most cases. The only ones you need to change are the API keys, and the variables under Translation and API Translation if you choose to use an external API.
# available languages: 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
INTERVAL = float(os.getenv('INTERVAL', 1.5)) # Interval in seconds between translations. If your system is slow, a lower value is probably fine with regards to API rates.
### OCR
IMAGE_CHANGE_THRESHOLD = float(os.getenv('IMAGE_CHANGE_THRESHOLD', 0.75)) # higher values mean more sensitivity to changes on the screen; too high and the screen will constantly refresh
OCR_MODEL = os.getenv('OCR_MODEL', 'easy') # 'easy', 'paddle', 'rapid' ### easy is the most accurate, paddle is the fastest with CUDA and rapid is the fastest on CPU. Rapid only supports Chinese and English unless you add more languages
OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True')) # True or False to use CUDA for OCR. Defaults to CPU if no CUDA GPU is available
### Drawing/Overlay Config
FILL_COLOUR = os.getenv('FILL_COLOUR', 'white') # colour of the textboxes
FONT_COLOUR = os.getenv('FONT_COLOUR', "#ff0000") # colour of the font
FONT_FILE = os.getenv('FONT_FILE', os.path.join(__file__, "fonts", "Arial-Unicode-Bold.ttf")) # path to the font file. Ensure it is a unicode .ttf file if you want to be able to see most languages.
FONT_SIZE_MAX = int(os.getenv('FONT_SIZE_MAX', 20)) # Maximum font size you want to see onscreen
FONT_SIZE_MIN = int(os.getenv('FONT_SIZE_MIN', 8)) # Minimum font size you want to see onscreen
LINE_SPACING = int(os.getenv('LINE_SPACING', 3)) # spacing between lines of text with the learn modes in DRAW_TRANSLATIONS_MODE
REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)')) # (x1, y1, x2, y2) - the region of the screen to capture
DRAW_TRANSLATIONS_MODE = os.getenv('DRAW_TRANSLATIONS_MODE', 'learn_cover')
"""
DRAW_TRANSLATIONS_MODE possible options:
'learn': adds the translated text, the original text (kept so that when texts get moved around it is clear which translation each refers to) and (optionally, via TO_ROMANIZE) the romanised text above the original text. Texts can overlap if squished into a corner. Works well for games where text is sparser.
'learn_cover': same as above but covers the original text with the translated text. Can help with readability and is less cluttered, but with sufficiently dense text the texts can still overlap.
'translation_only_cover': covers the original text with the translated text - will not show the original text at all, but is also not affected by overlapping texts.
"""
### Translation
MAX_TRANSLATE = int(os.getenv('MAX_TRANSLATION', 200)) # Maximum number of phrases to send to the translation model to translate
SOURCE_LANG = os.getenv('SOURCE_LANG', 'ch_sim') # Translate from 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
TARGET_LANG = os.getenv('TARGET_LANG', 'en') # Translate to 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True')) # whether to romanise the text. Only available for one of the learn modes in DRAW_TRANSLATIONS_MODE; it is added above the original text
### API Translation (could be an external or a local API)
# API KEYS
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY') # https://ai.google.dev/
GROQ_API_KEY = os.getenv('GROQ_API_KEY') # https://console.groq.com/keys
# MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY') # https://console.mistral.ai/api-keys/ very slow
### Local Translation Models
TRANSLATION_MODEL = os.environ['TRANSLATION_MODEL'] # 'opus' or 'm2m' # opus is a lot more lightweight
TRANSLATION_USE_GPU = ast.literal_eval(os.getenv('TRANSLATION_USE_GPU', 'True'))
MAX_INPUT_TOKENS = int(os.getenv('MAX_INPUT_TOKENS', 512))
MAX_OUTPUT_TOKENS = int(os.getenv('MAX_OUTPUT_TOKENS', 512))
BATCH_SIZE = int(os.getenv('BATCH_SIZE', 6))
LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False')) # will not attempt to ping Huggingface for the models and will just use the cached local models
###################################################################################################
###################################################################################################
### DO NOT EDIT THESE VARIABLES ###
## Filepaths
API_MODELS_FILEPATH = os.path.join(os.path.dirname(__file__), 'api_models.json')
FONT_SIZE = int((FONT_SIZE_MAX + FONT_SIZE_MIN)/2)
LINE_HEIGHT = FONT_SIZE
if TRANSLATION_USE_GPU is False:
device = torch.device("cpu")
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEMP_IMG_DIR = os.getenv('TEMP_IMG_PATH', default_tmp_dir) # where the temporary images are stored
TEMP_IMG_PATH = os.path.join(TEMP_IMG_DIR, 'tempP_img91258102.png')
### Just for info
available_langs = ['ch_sim', 'ch_tra', 'ja', 'ko', 'en'] # there are limitations with the languages that can be used with the OCR models
seq_llm_models = ['opus', 'm2m']
api_llm_models = ['gemini']
causal_llm_models = []
curr_models = seq_llm_models + api_llm_models + causal_llm_models


@@ -1,305 +0,0 @@
from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QBuffer
from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont, QScreen, QIcon)
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
QLabel, QSystemTrayIcon, QMenu)
import sys, io, os, signal, time, platform
from PIL import Image
from dataclasses import dataclass
from typing import List, Optional
from config import ADD_OVERLAY, FONT_FILE, FONT_SIZE
from logging_config import logger
def qpixmap_to_bytes(qpixmap):
qimage = qpixmap.toImage()
buffer = QBuffer()
buffer.open(QBuffer.ReadWrite)
qimage.save(buffer, "PNG")
return qimage
@dataclass
class TextEntry:
text: str
x: int
y: int
font: QFont = QFont('Arial', FONT_SIZE)
visible: bool = True
text_color: Qt.GlobalColor = Qt.GlobalColor.red
background_color: Optional[Qt.GlobalColor] = None
padding: int = 1
class TranslationOverlay(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle("Translation Overlay")
self.is_passthrough = True
self.text_entries: List[TextEntry] = []
self.setup_window_attributes()
self.setup_shortcuts()
self.closeEvent = lambda event: QApplication.quit()
self.default_font = QFont('Arial', FONT_SIZE)
self.text_entries_copy: List[TextEntry] = []
self.next_text_entries: List[TextEntry] = []
#self.show_background = True
self.background_opacity = 0.5
# self.setup_tray()
def prepare_for_capture(self):
"""Preserve current state and clear overlay"""
if ADD_OVERLAY:
self.text_entries_copy = self.text_entries.copy()
self.clear_all_text()
self.update()
def restore_after_capture(self):
"""Restore overlay state after capture"""
if ADD_OVERLAY:
logger.debug(f'Text entries copy during initial phase of restore_after_capture: {self.text_entries_copy}')
self.text_entries = self.text_entries_copy.copy()
logger.debug(f"Restored text entries: {self.text_entries}")
self.update()
def add_next_text_at_position_no_update(self, x: int, y: int, text: str,
font: Optional[QFont] = None, text_color: Qt.GlobalColor = Qt.GlobalColor.red):
"""Add new text without triggering update"""
entry = TextEntry(
text=text,
x=x,
y=y,
font=font or self.default_font,
text_color=text_color
)
self.next_text_entries.append(entry)
def update_translation(self, ocr_output, translation):
# Update your overlay with new translations here
# You'll need to implement the logic to display the translations
self.clear_all_text()
self.text_entries = self.next_text_entries.copy()
self.next_text_entries.clear()
self.update()
def capture_behind(self, x=None, y=None, width=None, height=None):
"""
Capture the screen area behind the overlay.
If no coordinates provided, captures the area under the window.
"""
# Temporarily hide the window
self.hide()
# Get screen
screen = QScreen.grabWindow(
self.screen(),
0,
x if x is not None else self.x(),
y if y is not None else self.y(),
width if width is not None else self.width(),
height if height is not None else self.height()
)
# Show the window again
self.show()
screen_bytes = qpixmap_to_bytes(screen)
return screen_bytes
def clear_all_text(self):
"""Clear all text entries"""
self.text_entries.clear()
self.update()
def setup_window_attributes(self):
# Set window flags for overlay behavior
self.setWindowFlags(
Qt.WindowType.FramelessWindowHint |
Qt.WindowType.WindowStaysOnTopHint |
Qt.WindowType.Tool
)
# Set attributes for transparency
self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground)
# Make the window cover the entire screen
self.setGeometry(QApplication.primaryScreen().geometry())
# Special handling for Wayland
if platform.system() == "Linux":
if "WAYLAND_DISPLAY" in os.environ:
self.setAttribute(Qt.WidgetAttribute.WA_X11NetWmWindowTypeCombo)
self.setAttribute(Qt.WidgetAttribute.WA_DontCreateNativeAncestors)
def setup_shortcuts(self):
# Toggle visibility (Alt+Shift+T)
self.toggle_visibility_shortcut = QShortcut(QKeySequence("Alt+Shift+T"), self)
self.toggle_visibility_shortcut.activated.connect(self.toggle_visibility)
# Toggle passthrough mode (Alt+Shift+P)
self.toggle_passthrough_shortcut = QShortcut(QKeySequence("Alt+Shift+P"), self)
self.toggle_passthrough_shortcut.activated.connect(self.toggle_passthrough)
# Quick hide (Escape)
self.hide_shortcut = QShortcut(QKeySequence("Esc"), self)
self.hide_shortcut.activated.connect(self.hide)
# Clear all text (Alt+Shift+C)
self.clear_shortcut = QShortcut(QKeySequence("Alt+Shift+C"), self)
self.clear_shortcut.activated.connect(self.clear_all_text)
# Toggle background
self.toggle_background_shortcut = QShortcut(QKeySequence("Alt+Shift+B"), self)
self.toggle_background_shortcut.activated.connect(self.toggle_background)
def toggle_visibility(self):
if self.isVisible():
self.hide()
else:
self.show()
def toggle_passthrough(self):
self.is_passthrough = not self.is_passthrough
if self.is_passthrough:
self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents)
if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ:
self.setWindowFlags(self.windowFlags() | Qt.WindowType.X11BypassWindowManagerHint)
else:
self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents, False)
if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ:
self.setWindowFlags(self.windowFlags() & ~Qt.WindowType.X11BypassWindowManagerHint)
self.hide()
self.show()
def toggle_background(self):
"""Toggle background visibility"""
self.show_background = not self.show_background
self.update()
def set_background_opacity(self, opacity: float):
"""Set background opacity (0.0 to 1.0)"""
self.background_opacity = max(0.0, min(1.0, opacity))
self.update()
def add_text_at_position(self, x: int, y: int, text: str):
"""Add new text at specific coordinates"""
entry = TextEntry(text, x, y)
self.text_entries.append(entry)
self.update()
def update_text_at_position(self, x: int, y: int, text: str):
"""Update text at specific coordinates, or add if none exists"""
# Look for existing text entry near these coordinates (within 5 pixels)
for entry in self.text_entries:
if abs(entry.x - x) <= 1 and abs(entry.y - y) <= 1:
entry.text = text
self.update()
return
# If no existing entry found, add new one
self.add_text_at_position(x, y, text)
def setup_tray(self):
self.tray_icon = QSystemTrayIcon(self)
self.tray_icon.setIcon(QIcon.fromTheme("applications-system"))
tray_menu = QMenu()
toggle_action = tray_menu.addAction("Show/Hide Overlay")
toggle_action.triggered.connect(self.toggle_visibility)
toggle_passthrough = tray_menu.addAction("Toggle Passthrough")
toggle_passthrough.triggered.connect(self.toggle_passthrough)
# Add background toggle to tray menu
toggle_background = tray_menu.addAction("Toggle Background")
toggle_background.triggered.connect(self.toggle_background)
clear_action = tray_menu.addAction("Clear All Text")
clear_action.triggered.connect(self.clear_all_text)
tray_menu.addSeparator()
quit_action = tray_menu.addAction("Quit")
quit_action.triggered.connect(self.clean_exit)
self.tray_icon.setToolTip("Translation Overlay")
self.tray_icon.setContextMenu(tray_menu)
self.tray_icon.show()
self.tray_icon.activated.connect(self.tray_activated)
def remove_text_at_position(self, x: int, y: int):
"""Remove text entry near specified coordinates"""
self.text_entries = [
entry for entry in self.text_entries
if abs(entry.x - x) > 1 or abs(entry.y - y) > 1
]
self.update()
def paintEvent(self, event):
painter = QPainter(self)
painter.setRenderHint(QPainter.RenderHint.Antialiasing)
# Draw each text entry
for entry in self.text_entries:
if not entry.visible:
continue
# Set the font for this specific entry
painter.setFont(entry.font)
text_metrics = painter.fontMetrics()
# Get the bounding rectangles for text
text_bounds = text_metrics.boundingRect(
entry.text
)
total_width = text_bounds.width()
total_height = text_bounds.height()
# Create rectangles for text placement
text_rect = QRect(entry.x, entry.y, total_width, total_height)
# Calculate background rectangle that encompasses both texts
if entry.background_color is not None:
bg_rect = QRect(entry.x - entry.padding,
entry.y - entry.padding,
total_width + (2 * entry.padding),
total_height + (2 * entry.padding))
painter.setPen(Qt.PenStyle.NoPen)
painter.setBrush(entry.background_color)
painter.drawRect(bg_rect)
# Draw the texts
painter.setPen(entry.text_color)
painter.drawText(text_rect, Qt.AlignmentFlag.AlignLeft, entry.text)
def handle_exit(signum, frame):
QApplication.quit()
def start_overlay():
app = QApplication(sys.argv)
# Enable Wayland support if available
if platform.system() == "Linux" and "WAYLAND_DISPLAY" in os.environ:
app.setProperty("platform", "wayland")
overlay = TranslationOverlay()
overlay.show()
signal.signal(signal.SIGINT, handle_exit) # Handle Ctrl+C (KeyboardInterrupt)
signal.signal(signal.SIGTERM, handle_exit)
return (app, overlay)
# sys.exit(app.exec())
if ADD_OVERLAY:
app, overlay = start_overlay()
if __name__ == "__main__":
ADD_OVERLAY = True
if not ADD_OVERLAY:
app, overlay = start_overlay()
overlay.add_text_at_position(600, 100, "Hello World I AM A BIG FAAT FOROGGGGGGGGGG")
capture = overlay.capture_behind()
capture.save("capture.png")
sys.exit(app.exec())

data.py Normal file

@@ -0,0 +1,59 @@
from sqlalchemy import create_engine, Column, Index, Integer, String, MetaData, Table, DateTime, ForeignKey, Boolean
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import relationship, declarative_base, sessionmaker
import logging
from logging_config import logger
import os
# Set up the database connection
data_dir = os.path.join(os.path.dirname(__file__), 'database')
os.makedirs(data_dir, exist_ok=True)
database_file = os.path.join(os.path.dirname(__file__), data_dir, 'translations.db')
engine = create_engine(f'sqlite:///{database_file}', echo=False)
Session = sessionmaker(bind=engine)
session = Session()
Base = declarative_base()
logging.basicConfig()
logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
class Api(Base):
__tablename__ = 'api'
id = Column(Integer, primary_key=True, autoincrement=True)
model_name = Column(String, nullable=False)
site = Column(Integer, nullable=False)
rpmin = Column(Integer) # rate per minute
rph = Column(Integer) # rate per hour
rpd = Column(Integer) # rate per day
rpw = Column(Integer) # rate per week
rpmth = Column(Integer) # rate per month
rpy = Column(Integer) # rate per year
translations = relationship("Translations", back_populates="api")
class Translations(Base):
__tablename__ = 'translations'
id = Column(Integer, primary_key=True, autoincrement=True)
model_id = Column(Integer, ForeignKey('api.id'), nullable=False)
source_texts = Column(String, nullable=False) # as a json string
translated_texts = Column(String, nullable=False) # as a json string
source_lang = Column(String, nullable=False)
target_lang = Column(String, nullable=False)
timestamp = Column(DateTime, nullable=False)
translation_mismatch = Column(Boolean, nullable=False)
api = relationship("Api", back_populates="translations")
__table_args__ = (
Index('idx_timestamp', 'timestamp'),
)
def create_tables():
if not os.path.exists(database_file):
Base.metadata.create_all(engine)
logger.info(f"Database created at {database_file}")
else:
logger.info(f"Using Pre-existing Database at {database_file}.")

draw.py

@@ -1,68 +1,66 @@
from PIL import Image, ImageDraw, ImageFont, ImageFilter
import os, io, sys, numpy as np
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
from utils import romanize, intercepts, add_furigana
from logging_config import logger
from config import SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE_MAX, FONT_SIZE_MIN, FONT_SIZE, LINE_SPACING, FONT_COLOUR, LINE_HEIGHT, TO_ROMANIZE, FILL_COLOUR, REGION, DRAW_TRANSLATIONS_MODE
from PySide6.QtGui import QFont
font = ImageFont.truetype(FONT_FILE, FONT_SIZE)
#### TODO: create a class later so these functions don't have to pass the same arguments around -- it's confusing like this
def modify_image(input: bytes | str, ocr_output, translation: list) -> bytes:
"""Modify the image bytes with the translated text and return the modified image bytes. If the input is a path then open it directly."""
if isinstance(input, str):
image = Image.open(input)
draw = ImageDraw.Draw(image)
draw_on_image(draw, translation, ocr_output, MAX_TRANSLATE)
elif isinstance(input, bytes):
with io.BytesIO(input) as byte_stream:
image = Image.open(byte_stream)
draw = ImageDraw.Draw(image)
draw_on_image(draw, translation, ocr_output, MAX_TRANSLATE)
else:
raise TypeError('Incorrect filetype input')
# Save the modified image back to bytes
with io.BytesIO() as byte_stream:
image.save(byte_stream, format='PNG') # Save as PNG
modified_image_bytes = byte_stream.getvalue()
return modified_image_bytes
def draw_on_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int, draw_mode: str = DRAW_TRANSLATIONS_MODE) -> ImageDraw:
"""Draw the original, translated and optionally the romanisation of the texts on the image"""
translated_number = 0
bounding_boxes = []
logger.debug(f"Translations: {len(translation)} {translation}")
logger.debug(f"OCR output: {len(ocr_output)} {ocr_output}")
for i, (position, untranslated_phrase, confidence) in enumerate(ocr_output):
if translated_number >= len(translation): # note: if using an API LLM, some issues may cause it to return fewer translations than expected
break
if draw_mode == 'learn':
draw_one_phrase_learn(draw, translation[i], position, bounding_boxes, untranslated_phrase)
elif draw_mode == 'translation_only':
draw_one_phrase_translation_only(draw, translation[i], position, bounding_boxes, untranslated_phrase)
elif draw_mode == 'learn_cover':
draw_one_phrase_learn_cover(draw, translation[i], position, bounding_boxes, untranslated_phrase)
elif draw_mode == 'translation_only_cover':
draw_one_phrase_translation_only_cover(draw, translation[i], position, bounding_boxes, untranslated_phrase)
translated_number += 1
return draw
def draw_one_phrase_learn(draw: ImageDraw,
translated_phrase: str,
position: tuple, bounding_boxes: list,
untranslated_phrase: str) -> ImageDraw:
"""Draw the bounding box rectangle and text on the image above the original text"""
lines = get_lines(untranslated_phrase, translated_phrase)
# Draw the bounding box
top_left, _, bottom_right, _ = position
font_size = get_font_size(top_left[1], bottom_right[1], FONT_SIZE_MAX, FONT_SIZE_MIN)
max_width = get_max_width(lines, FONT_FILE, font_size)
total_height = get_max_height(lines, font_size, LINE_SPACING)
font = ImageFont.truetype(FONT_FILE, font_size)
right_edge = REGION[2]
# Ensure the text is within the screen. P.S. Text on the edge may still be squished together if there are too many to translate
@@ -73,60 +71,146 @@ def draw_one_phrase_add(draw: ImageDraw,
adjust_if_intersects(x_onscreen, y_onscreen, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]
draw.rounded_rectangle([(adjusted_x, adjusted_y), (adjusted_max_x, adjusted_max_y)], fill=FILL_COLOUR, outline="purple", width=1, radius=5)
# draw.rectangle([(adjusted_x, adjusted_y), (adjusted_max_x, adjusted_max_y)], outline="black", width=1)
position = (adjusted_x, adjusted_y)
for line in lines:
if FONT_COLOUR == 'rainbow':
rainbow_text(draw, line, *position, font)
else:
draw.text(position, line, fill=FONT_COLOUR, font=font)
adjusted_y += font_size + LINE_SPACING
position = (adjusted_x, adjusted_y)
### Only support for horizontal text atm, vertical text is on the todo list
def draw_one_phrase_translation_only_cover(draw: ImageDraw,
translated_phrase: str,
position: tuple, bounding_boxes: list,
untranslated_phrase: str) -> ImageDraw:
"""Cover up old text and add translation directly on top"""
# Draw the bounding box
top_left, _, bottom_right, _ = position
bounding_boxes.append((top_left[0], top_left[1], bottom_right[0], bottom_right[1], untranslated_phrase)) # Debugging purposes
max_width = bottom_right[0] - top_left[0]
font_size = get_font_size(top_left[1], bottom_right[1], FONT_SIZE_MAX, FONT_SIZE_MIN)
while True:
font = ImageFont.truetype(FONT_FILE, font_size)
phrase_width = get_max_width(translated_phrase, FONT_FILE, font_size)
rectangle = get_rectangle_coordinates(translated_phrase, top_left, FONT_FILE, font_size, LINE_SPACING)
if phrase_width < max_width:
draw.rectangle(rectangle, fill=FILL_COLOUR)
if FONT_COLOUR == 'rainbow':
rainbow_text(draw, translated_phrase, *top_left, font)
else:
draw.rounded_rectangle([top_left, bottom_right], fill=FILL_COLOUR, outline="purple", width=1, radius=5)
draw.text(top_left, translated_phrase, fill=FONT_COLOUR, font=font)
break
elif font_size <= FONT_SIZE_MIN:
break
else:
font_size -= 1
font_size -= 1
def draw_one_phrase_learn_cover(draw: ImageDraw,
translated_phrase: str,
position: tuple, bounding_boxes: list,
untranslated_phrase: str) -> ImageDraw:
"""Cover up old text and add translation directly on top"""
lines = get_lines(untranslated_phrase, translated_phrase)
# Draw the bounding box
top_left, _, bottom_right,_ = position
font_size = get_font_size(top_left[1], bottom_right[1], FONT_SIZE_MAX, FONT_SIZE_MIN)
max_width = get_max_width(lines, FONT_FILE, font_size)
total_height = get_max_height(lines, font_size, LINE_SPACING)
font = ImageFont.truetype(FONT_FILE, font_size)
right_edge = REGION[2]
# Ensure the text is within the screen. P.S. Text on the edge may still be squished together if there are too many to translate
x_onscreen = top_left[0] if top_left[0] + max_width <= right_edge else right_edge - max_width
y_onscreen = max(top_left[1] - int(total_height/3), 0)
bounding_box = (x_onscreen, y_onscreen, x_onscreen + max_width, y_onscreen + total_height, untranslated_phrase)
adjust_if_intersects(x_onscreen, y_onscreen, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]
draw.rounded_rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], fill=FILL_COLOUR,outline="purple", width=1, radius=5)
position = (adjusted_x,adjusted_y)
for line in lines:
if FONT_COLOUR == 'rainbow': # easter egg yay
rainbow_text(draw, line, *position, font)
else:
draw.text(position, line, fill= FONT_COLOUR, font=font)
adjusted_y += font_size + LINE_SPACING
position = (adjusted_x,adjusted_y)
def draw_one_phrase_translation_only(draw: ImageDraw,
translated_phrase: str,
position: tuple, bounding_boxes: list,
untranslated_phrase: str) -> ImageDraw:
"""Cover up old text and add translation directly on top"""
# Draw the bounding box
pass
def get_rectangle_coordinates(lines: list | str, top_left: tuple | list, font_path, font_size, line_spacing, padding: int = 1) -> list:
"""Get the coordinates of the rectangle surrounding the text"""
text_width = get_max_width(lines, font_path, font_size)
text_height = get_max_height(lines, font_size, line_spacing)
x1 = top_left[0] - padding
y1 = top_left[1] - padding
x2 = top_left[0] + text_width + padding
y2 = top_left[1] + text_height + padding
return [(x1,y1), (x2,y2)]
def get_max_width(lines: list | str, font_path, font_size) -> int:
"""Get the maximum width of the text lines"""
font = ImageFont.truetype(font_path, font_size)
max_width = 0
dummy_image = Image.new("RGB", (1, 1))
draw = ImageDraw.Draw(dummy_image)
if isinstance(lines, list):
for line in lines:
bbox = draw.textbbox((0,0), line, font=font)
line_width = bbox[2] - bbox[0]
max_width = max(max_width, line_width)
else:
bbox = draw.textbbox((0,0), lines, font=font)
max_width = bbox[2] - bbox[0]
return max_width
def get_max_height(lines: list | str, font_size, line_spacing) -> int:
"""Get the maximum height of the text lines"""
no_of_lines = len(lines) if isinstance(lines, list) else 1
return no_of_lines * (font_size + line_spacing)
def get_lines(untranslated_phrase: str, translated_phrase: str) -> list:
"""Get the translated. untranslated and optionally the romanised text as a list"""
if SOURCE_LANG == 'ja':
untranslated_phrase = add_furigana(untranslated_phrase)
romanized_phrase = romanize(untranslated_phrase, 'ja')
else:
romanized_phrase = romanize(untranslated_phrase, SOURCE_LANG)
if TO_ROMANIZE:
text_content = f"{translated_phrase}\n{romanized_phrase}\n{untranslated_phrase}"
else:
text_content = f"{translated_phrase}\n{untranslated_phrase}"
return text_content.split('\n')
def adjust_if_intersects(x: int, y: int,
bounding_box: tuple, bounding_boxes: list,
untranslated_phrase: str,
max_width: int, total_height: int) -> tuple:
"""Adjust the y coordinate every time the bounding box intersects with any previous bounding boxes. OCR returns results from top to bottom so it works."""
y = np.max([y,0])
if len(bounding_boxes) > 0:
for box in bounding_boxes:
@@ -136,3 +220,36 @@ def adjust_if_intersects(x: int, y: int,
bounding_boxes.append(adjusted_bounding_box)
return adjusted_bounding_box
def get_font_size(y_1, y_2, font_size_max: int, font_size_min: int) -> int:
"""Get the average of the maximum and minimum font sizes"""
if font_size_min > font_size_max:
raise ValueError("Minimum font size cannot be greater than maximum font size")
font_size = min(
max(int(abs(2/3*(y_2-y_1))), font_size_min),
font_size_max)
return font_size
def rainbow_text(draw, text, x, y, font):
for i, letter in enumerate(text):
# Pick a random bright RGB colour for each letter (no actual HSV conversion is done)
rgb = tuple(np.random.randint(50,255,3))
# Get the width of this letter
letter_bbox = draw.textbbox((x, y), letter, font=font)
letter_width = letter_bbox[2] - letter_bbox[0]
# Draw the letter
draw.text((x, y), letter, fill=rgb, font=font)
# Move x position for next letter
x += letter_width
if __name__ == "__main__":
pass
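For reference, a hedged sketch of calling `modify_image` directly; the `ocr_output` shape (four corner points, phrase, confidence) is inferred from how `draw_on_image` unpacks each entry, and the filenames are illustrative:

```python
# Illustrative only: draw a fake translation onto a screenshot saved on disk.
# modify_image returns PNG bytes per the save call above.
from draw import modify_image

ocr_output = [
    ([(100, 50), (300, 50), (300, 90), (100, 90)], "こんにちは", 0.98),
]
translation = ["Hello"]
png_bytes = modify_image("screenshot.png", ocr_output, translation)
with open("translated.png", "wb") as f:
    f.write(png_bytes)
```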

Binary file not shown.

helpers/batching.py

@@ -1,142 +1,315 @@
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict
from datetime import datetime, timedelta
from dotenv import load_dotenv
import os, sys, torch, time, ast, json, pytz
from werkzeug.exceptions import TooManyRequests
from multiprocessing import Process, Event, Value
load_dotenv()
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from config import device, GEMINI_API_KEY, GROQ_API_KEY, MAX_TRANSLATE
from logging_config import logger
from groq import Groq as Groqq
import google.generativeai as genai
from google.api_core.exceptions import ResourceExhausted
import asyncio
import aiohttp
from functools import wraps
from data import session, Api, Translations
from typing import Optional
class ApiModel():
def __init__(self, model, # model name as defined by the API
site, # site of the model; use the name exactly as defined by the class names in this script, i.e. Groq and Gemini
api_key: Optional[str] = None, # api key for the model wrt the site
rpmin: Optional[int] = None, # rate of calls per minute
rph: Optional[int] = None, # rate of calls per hour
rpd: Optional[int] = None, # rate of calls per day
rpw: Optional[int] = None, # rate of calls per week
rpmth: Optional[int] = None, # rate of calls per month
rpy: Optional[int] = None # rate of calls per year
):
self.api_key = api_key
self.model = model
self.rpmin = rpmin
self.rph = rph
self.rpd = rpd
self.rpw = rpw
self.rpmth = rpmth
self.rpy = rpy
self.site = site
self.from_lang = None
self.target_lang = None
self.db_table = None
self.session_calls = 0
self._id = None
self._set_db_model_id() if self._get_db_model_id() else self.update_db()
# Create the table entry if it does not already exist
def __repr__(self):
return f'{self.site} Model: {self.model}; Total calls this session: {self.session_calls}; rpmin: {self.rpmin}; rph: {self.rph}; rpd: {self.rpd}; rpw: {self.rpw}; rpmth: {self.rpmth}; rpy: {self.rpy}'
def __str__(self):
return self.model
async def __aenter__(self):
self.session = aiohttp.ClientSession()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.session:
await self.session.close()
def _get_db_model_id(self):
model = session.query(Api).filter_by(model_name = self.model, site = self.site).first()
if model:
return model.id
else:
return None
def _set_db_model_id(self):
self._id = self._get_db_model_id()
@staticmethod
def _get_time():
return datetime.now(tz=pytz.timezone('Australia/Sydney'))
def set_lang(self, from_lang, target_lang):
self.from_lang = from_lang
self.target_lang = target_lang
def set_db_table(self, db_table):
self.db_table = db_table
def update_db(self):
api = session.query(Api).filter_by(model_name = self.model, site = self.site).first()
if not api:
api = Api(model_name = self.model,
site = self.site,
rpmin = self.rpmin,
rph = self.rph,
rpd = self.rpd,
rpw = self.rpw,
rpmth = self.rpmth,
rpy = self.rpy)
session.add(api)
session.commit()
self._set_db_model_id()
else:
api.rpmin = self.rpmin
api.rph = self.rph
api.rpd = self.rpd
api.rpw = self.rpw
api.rpmth = self.rpmth
api.rpy = self.rpy
session.commit()
def _db_add_translation(self, text: list | str, translation: list, mismatch = False):
text = json.dumps(text) if isinstance(text, list) else json.dumps([text])
translation = json.dumps(translation)
translation = Translations(source_texts = text, translated_texts = translation,
model_id = self._id, source_lang = self.from_lang, target_lang = self.target_lang,
timestamp = datetime.now(tz=pytz.timezone('Australia/Sydney')),
translation_mismatch = mismatch)
session.add(translation)
session.commit()
@staticmethod
def _single_period_calls_check(max_calls, call_count):
if not max_calls:
return True
if max_calls <= call_count:
return False
else:
return True
def _are_rates_good(self):
curr_time = self._get_time()
min_ago = curr_time - timedelta(minutes=1)
hour_ago = curr_time - timedelta(hours=1)
day_ago = curr_time - timedelta(days=1)
week_ago = curr_time - timedelta(weeks=1)
month_ago = curr_time - timedelta(days=30)
year_ago = curr_time - timedelta(days=365)
min_calls = session.query(Translations).join(Api). \
filter(Api.id==self._id,
Translations.timestamp >= min_ago
).count()
hour_calls = session.query(Translations).join(Api). \
filter(Api.id==self._id,
Translations.timestamp >= hour_ago
).count()
day_calls = session.query(Translations).join(Api). \
filter(Api.id==self._id,
Translations.timestamp >= day_ago
).count()
week_calls = session.query(Translations).join(Api). \
filter(Api.id==self._id,
Translations.timestamp >= week_ago
).count()
month_calls = session.query(Translations).join(Api). \
filter(Api.id==self._id,
Translations.timestamp >= month_ago
).count()
year_calls = session.query(Translations).join(Api). \
filter(Api.id==self._id,
Translations.timestamp >= year_ago
).count()
if self._single_period_calls_check(self.rpmin, min_calls) \
and self._single_period_calls_check(self.rph, hour_calls) \
and self._single_period_calls_check(self.rpd, day_calls) \
and self._single_period_calls_check(self.rpw, week_calls) \
and self._single_period_calls_check(self.rpmth, month_calls) \
and self._single_period_calls_check(self.rpy, year_calls):
return True
else:
logger.warning(f"Rate limit reached for {self.model} from {self.site}. Current calls: {min_calls} in the last minute; {hour_calls} in the last hour; {day_calls} in the last day; {week_calls} in the last week; {month_calls} in the last month; {year_calls} in the last year.")
return False
    async def translate(self, texts_to_translate, store = False) -> tuple[int,        # exit code: 0 for success, 1 for incorrect response type, 2 for incorrect translation count
                                                                          list[str],  # translations
                                                                          int         # number of translations that do not match the number of texts to translate
                                                                          ]:
        """Main translation function. Each API site needs its own class that defines a `_request` coroutine, as shown in the Gemini and Groq classes below."""
        if isinstance(texts_to_translate, str):
            texts_to_translate = [texts_to_translate]
        if len(texts_to_translate) == 0:
            return (0, [], 0)
        #prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {self.from_lang} into {self.target_lang} and output as a Python list ensuring proper escaping of characters and ensuring the length of the list given is exactly equal to the length of the list you provide. Do not output in any other language other than the specified target language: {texts_to_translate}"
        prompt = f"""INSTRUCTIONS:
- Provide ONE and ONLY ONE translation for each text in the JSON array given.
- Respond using ONLY valid JSON array syntax. Do not use Python-style dictionary syntax: the output must not contain any keys or curly braces.
- Do not include explanations or additional text.
- The translations must preserve the original order.
- Each translation must be from the Source language to the Target language.
- Source language: {self.from_lang}
- Target language: {self.target_lang}
- Escape special characters properly.
Input texts:
{texts_to_translate}
Expected format:
["translation1", "translation2", ...]
Translation:"""
        try:
            response = await self._request(prompt)
            response_list = ast.literal_eval(response.strip())
        except Exception as e:
            logger.error(f"Failed to evaluate response from {self.model} from {self.site}. Error: {e}.")
            return (1, [], 99999)
        logger.debug(repr(self))
        logger.info(f'{self.model} translated texts from: {texts_to_translate} to {response_list}.')
        if not isinstance(response_list, list):
            # raise TypeError(f"Incorrect response type. Expected list, got {type(response_list)}")
            logger.error(f"Incorrect response type. Expected list, got {type(response_list)}")
            return (1, [], 99999)
        if len(response_list) != len(texts_to_translate) and len(texts_to_translate) <= MAX_TRANSLATE:
            logger.error(f"Number of translations does not match the number of texts to translate. Sent: {len(texts_to_translate)}. Received: {len(response_list)}.")
            if store:
                self._db_add_translation(texts_to_translate, response_list, mismatch=True)
            # raise ValueError(f"Number of translations does not match number of texts to translate. Sent: {len(texts_to_translate)}. Received: {len(response_list)}.")
            return (2, response_list, abs(len(texts_to_translate) - len(response_list)))
        else:
            if store:
                self._db_add_translation(texts_to_translate, response_list)
            return (0, response_list, 0)
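# A hypothetical usage sketch (not part of this file): how a caller might branch
# on the exit code; the model name and rate value below are assumptions.
#
#   model = Groq('llama-3.1-8b-instant', rpmin=30)
#   model.set_lang('zh', 'en')
#   status, translations, mismatches = await model.translate(['你好', '世界'], store=True)
#   if status == 0:      # success: one translation per input text
#       ...
#   elif status == 2:    # usable output, but `mismatches` items too many/few
#       ...
#   else:                # status 1: the response could not be parsed as a list
#       ...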
class Groq(ApiModel):
    def __init__(self,
                 model,                    # model name as defined by the API
                 api_key = GROQ_API_KEY,   # api key for the model wrt the site
                 **kwargs):
        super().__init__(model,
                         api_key = api_key,
                         site = 'Groq',
                         **kwargs)
    async def _request(self, content: str) -> str:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "https://api.groq.com/openai/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {GROQ_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "messages": [{"role": "user", "content": content}],
                    "model": self.model
                }
            ) as response:
                response_json = await response.json()
                return response_json["choices"][0]["message"]["content"]
# https://console.groq.com/settings/limits for limits

class Gemini(ApiModel):
    def __init__(self,
                 model,                     # model name as defined by the API
                 api_key = GEMINI_API_KEY,  # api key for the model wrt the site
                 **kwargs):
        super().__init__(model,
                         api_key = api_key,
                         site = 'Google',
                         **kwargs)
    async def _request(self, content):
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"https://generativelanguage.googleapis.com/v1beta/models/{self.model}:generateContent?key={self.api_key}",
                headers={
                    "Content-Type": "application/json"
                },
                json={
                    "contents": [{"parts": [{"text": content}]}],
                    "safetySettings": [
                        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}
                    ]
                }
            ) as response:
                response_json = await response.json()
                return response_json['candidates'][0]['content']['parts'][0]['text']
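# Hypothetical direct instantiation, bypassing init_API_LLM/api_models.json;
# the rate values here are illustrative assumptions.
#
#   gemini = Gemini('gemini-1.5-flash', rpmin=15, rpd=1500)
#   gemini.set_lang('zh', 'en')
#   status, translations, _ = asyncio.run(gemini.translate(['荷兰咯']))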
"""
DEFINE YOUR OWN API MODELS BELOW WITH THE SAME TEMPLATE AS BELOW. All fields required are indicated by <required field>.
class <NameOfWebsite>(ApiModel):
def __init__(self, # model name as defined by the API
model,
api_key = <API_KEY>, # api key for the model wrt the site
**kwargs):
super().__init__(model,
api_key = api_key,
site = <name_of_website>,
**kwargs)
async def _request(self, content):
async with aiohttp.ClientSession() as session:
async with session.post(
<API ENDPOINT e.g. https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={self.api_key}>,
headers={
"Content-Type": "application/json"
<ANY OTHER HEADERS REQUIRED BY THE API separated by commas>
},
json={
"contents": [{"parts": [{"text": content}]}]
<ANY OTHER JSON PAIRS REQUIRED separated by commas>
}
) as response:
response_json = await response.json()
return <Anything needed to extract the message response from `response_json`>
"""
###################################################################################################
### LOCAL LLM TRANSLATION
class TranslationDataset(Dataset):
    def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
        """
@ -262,11 +435,13 @@ def generate_text(
    return all_generated_texts

if __name__ == '__main__':
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, M2M100Tokenizer, M2M100ForConditionalGeneration
    opus_model = 'Helsinki-NLP/opus-mt-en-zh'
    LOCAL_FILES_ONLY = True
    tokenizer = AutoTokenizer.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY)
    model = AutoModelForSeq2SeqLM.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY, torch_dtype=torch.float16).to(device)
    # tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=LOCAL_FILES_ONLY)
    # tokenizer.src_lang = "en"
    # model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=LOCAL_FILES_ONLY).to(device)
    print(generate_text([ i.lower().capitalize() for i in ['placeholder','Story','StoRY', 'TufoRIaL', 'CovFfG', 'LoaD DaTA', 'SAME DATa', 'ReTulN@TitIE', 'View', '@niirm', 'SysceM', 'MeNu:', 'MaND', 'CoM', 'SeLEcT', 'Frogguingang', 'Tutorias', 'Back']], model, tokenizer))

View File

@ -21,7 +21,6 @@ def _paddle_init(paddle_lang, use_angle_cls=False, use_GPU=True, **kwargs):
def _paddle_ocr(ocr, image) -> list:
    ### return a list containing the bounding box, text and confidence of the detected text
    result = ocr.ocr(image, cls=False)[0]
    if not isinstance(result, list):
@ -32,28 +31,30 @@ def _paddle_ocr(ocr, image) -> list:
# EasyOCR has support for many languages
def _easy_init(easy_languages: list, use_GPU=True, **kwargs):
    return easyocr.Reader(easy_languages, gpu=use_GPU, **kwargs)

def _easy_ocr(ocr, image) -> list:
    detected_texts = ocr.readtext(image)
    return detected_texts

# RapidOCR mostly for mandarin and some other asian languages
# default only supports chinese and english
def _rapid_init(use_GPU=True, **kwargs):
    return RapidOCR(use_gpu=use_GPU, **kwargs)

def _rapid_ocr(ocr, image) -> list:
    return ocr(image)[0]

### Initialize the OCR model
def init_OCR(model='paddle', easy_languages: Optional[list] = ['ch_sim','en'], paddle_lang: Optional[str] = 'ch_sim', use_GPU=True):
    if model == 'paddle':
        paddle_lang = standardize_lang(paddle_lang)['paddleocr_lang']
        return _paddle_init(paddle_lang=paddle_lang, use_GPU=use_GPU)
    elif model == 'easy':
        langs = []
        for lang in easy_languages:
            langs.append(standardize_lang(lang)['easyocr_lang'])
        return _easy_init(easy_languages=langs, use_GPU=use_GPU)
    elif model == 'rapid':
        return _rapid_init(use_GPU=use_GPU)
@ -82,15 +83,16 @@ def _id_filtered(ocr, image, lang) -> list:
    return results_no_eng

# ch_sim, ch_tra, ja, ko, en input
def _id_lang(ocr, image, lang) -> list:
    result = _identify(ocr, image)
    lang = standardize_lang(lang)['id_model_lang']
    # try:
    logger.info(f"Filtering out phrases not in {lang}.")
    filtered = [entry for entry in result if contains_lang(entry[1], lang)]
    # except:
    #     logger.error(f"Selected language not part of default: {default_languages}.")
    #     raise ValueError(f"Selected language not part of default: {default_languages}.")
    return filtered

def id_keep_source_lang(ocr, image, lang) -> list:
@ -116,9 +118,6 @@ def get_confidences(ocr_output) -> list:
if __name__ == '__main__':
    OCR_languages = ['ch_sim','en']
    reader = easyocr.Reader(OCR_languages, gpu=True)  # this needs to run only once to load the model into memory
    print(id_keep_source_lang(init_OCR(model='paddle', paddle_lang='zh', easy_languages=['en', 'ch_sim']), '/home/James/Pictures/Screenshots/DP-1.jpg', 'ch_sim'))

View File

@ -1,17 +1,16 @@
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig, AutoModelForCausalLM
import google.generativeai as genai
import torch, os, sys, ast, json, asyncio, batching, random
from typing import List, Optional, Set
from utils import standardize_lang
from functools import wraps
from batching import generate_text, Gemini, Groq, ApiModel
from logging_config import logger
from asyncio import Task
# root dir
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_API_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models, API_MODELS_FILEPATH
##############################
# translation decorator
@ -30,30 +29,73 @@ def translate(translation_func):
###############################
def init_API_LLM(from_lang, target_lang):
    """Initialise the API models defined in the api models json file: instantiate each model, sync its rate limits to the database and set its languages."""
    from_lang = standardize_lang(from_lang)['translation_model_lang']
    target_lang = standardize_lang(target_lang)['translation_model_lang']
    with open(API_MODELS_FILEPATH, 'r') as f:
        models_and_rates = json.load(f)
    models = []
    for class_type, class_models in models_and_rates.items():
        cls = getattr(batching, class_type)
        instantiated_objects = [ cls(model = model, **rates) for model, rates in class_models.items()]
        models.extend(instantiated_objects)
    for model in models:
        model.update_db()
        model.set_lang(from_lang, target_lang)
    return models
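# For reference, a sketch of the structure init_API_LLM expects in the api models
# json file (class name -> model name -> rate kwargs); the model names and numbers
# here are illustrative assumptions, not the repo's actual config.
#
#   {
#       "Gemini": {"gemini-1.5-flash": {"rpmin": 15, "rpd": 1500}},
#       "Groq":   {"llama-3.1-8b-instant": {"rpmin": 30, "rpd": 14400}}
#   }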
async def translate_API_LLM(texts_to_translate: List[str],
                            models: List[ApiModel],
                            call_size: int = 2) -> List[str]:
    """Translate the texts using the models, `call_size` models at a time. If a model fails to translate the texts, the next model in the list is tried."""
    async def try_translate(model: ApiModel) -> tuple:
        result = await model.translate(texts_to_translate, store=True)
        logger.debug(f'Try_translate result: {result}')
        return result
    random.shuffle(models)
    groups = [models[i:i+call_size] for i in range(0, len(models), call_size)]
    no_of_models = len(models)
    translation_attempts = 0

    best_translation = None  # (translations, translation_mismatches)

    for group in groups:
        tasks = set(asyncio.create_task(try_translate(model)) for model in group)
        while tasks:
            done, pending = await asyncio.wait(tasks,
                                               return_when=asyncio.FIRST_COMPLETED
                                               )
            logger.debug(f"Tasks done: {done}")
            logger.debug(f"Tasks remaining: {pending}")
            for task in done:
                result = await task
                logger.debug(f'Result: {result}')
                if result is not None:
                    tasks.discard(task)
                    translation_attempts += 1
                    status_code, translations, translation_mismatches = result
                    if status_code == 0:
                        # Cancel remaining tasks
                        for t in pending:
                            t.cancel()
                        return translations
                    else:
                        logger.error(f"Model has failed to translate the text. Result: {result}")
                        if status_code == 2:
                            # keep the closest (fewest mismatches) translation seen so far
                            if best_translation is None or translation_mismatches < best_translation[1]:
                                best_translation = (translations, translation_mismatches)
                        if translation_attempts == no_of_models:
                            if best_translation is not None:
                                return best_translation[0]
                            logger.error("All models have failed to translate the text.")
                            raise TypeError("Models have likely all outputted garbage translations or been rate limited.")
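# Hypothetical end-to-end sketch of the API translation path:
#
#   models = init_API_LLM('zh', 'en')
#   translations = asyncio.run(translate_API_LLM(['你好', '世界'], models, call_size=3))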
###############################
# Best model by far. Aya-23-8B. Gemma is relatively good. If I get the time to quantize either gemma or aya those will be good to use. llama3.2 is really good as well.
def init_AYA():
@ -101,15 +143,17 @@ def get_OPUS_model(from_lang, target_lang):
def init_OPUS(from_lang = 'ch_sim', target_lang = 'en'):
    opus_model = get_OPUS_model(from_lang, target_lang)
    logger.debug(f"OPUS model: {opus_model}")
    tokenizer = AutoTokenizer.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY)
    model = AutoModelForSeq2SeqLM.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY, torch_dtype=torch.float16).to(device)
    model.eval()
    return (model, tokenizer)

def translate_OPUS(text: list[str], model, tokenizer) -> list[str]:
    translated_text = generate_text(text, model, tokenizer,
                                    batch_size=BATCH_SIZE, device=device,
                                    max_length=MAX_INPUT_TOKENS, max_new_tokens=MAX_OUTPUT_TOKENS)
    logger.debug(f"Translated text: {translated_text}")
    return translated_text

###############################
@ -132,6 +176,7 @@ def translate_Seq_LLM(text,
                      model,
                      tokenizer,
                      **kwargs):
    text = [t.lower().capitalize() for t in text]
    if model_type == 'opus':
        return translate_OPUS(text, model, tokenizer)
    elif model_type == 'm2m':

View File

@ -4,7 +4,10 @@ import pyscreenshot as ImageGrab # wayland tings not sure if it will work on oth
import mss, io, os
from PIL import Image
import jaconv, MeCab, unidic, pykakasi
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
import subprocess
# for creating furigana
mecab = MeCab.Tagger('-d "{}"'.format(unidic.DICDIR))
uroman = ur.Uroman()
@ -23,33 +26,26 @@ def intercepts(x,y):
def is_wayland():
    return 'WAYLAND_DISPLAY' in os.environ

# please install grim otherwise this is way too slow for wayland
def printsc_wayland(region: tuple, path: str):
    subprocess.run(['grim', '-g', f'{region[0]},{region[1]} {region[2]-region[0]}x{region[3]-region[1]}', '-t', 'jpeg', '-q', '90', path])

def printsc_non_wayland(region: tuple, path: str):
    # use mss to capture the screen
    with mss.mss() as sct:
        # grab the screen
        img = sct.grab(region)
        # convert the image to a PIL image
        image = Image.frombytes("RGB", img.size, img.bgra, "raw", "BGRX")
        image.save(path)

def printsc(region: tuple, path: str):
    try:
        if is_wayland():
            printsc_wayland(region, path)
        else:
            printsc_non_wayland(region, path)
    except Exception as e:
        print(f'Error {e}')
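# Hypothetical usage: capture the (left, top, right, bottom) region of the screen
# to a file; the region and path below are assumptions.
#
#   printsc((0, 0, 1920, 1080), '/tmp/translator_frame.jpg')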
@ -95,10 +91,10 @@ def contains_katakana(text):
# use kakasi to romanize japanese text
def romanize(text, lang):
    if lang in ['zh','ch_sim','ch_tra']:
        return ' '.join([ py[0] for py in pinyin(text, heteronym=True)])
    if lang == 'ja':
        return ' '.join([romaji['hepburn'] for romaji in kks.convert(text)])
    return uroman.romanize_string(text)
# check if a string contains words from a language
@ -107,7 +103,7 @@ def contains_lang(text, lang):
    if lang == 'zh':
        return bool(re.search(r'[\u4e00-\u9fff]', text))
    elif lang == 'ja':
        return bool(re.search(r'[\u3040-\u30ff]', text)) or bool(re.search(r'[\u4e00-\u9fff]', text))
    elif lang == 'ko':
        return bool(re.search(r'[\uac00-\ud7af]', text))
    elif lang == 'en':
@ -131,13 +127,13 @@ def standardize_lang(lang):
        id_model_lang = 'zh'
    elif lang == 'ja':
        easyocr_lang = 'ja'
        paddleocr_lang = 'japan'
        rapidocr_lang = 'ja'
        translation_model_lang = 'ja'
        id_model_lang = 'ja'
    elif lang == 'ko':
        easyocr_lang = 'ko'
        paddleocr_lang = 'korean'
        rapidocr_lang = 'ko'
        translation_model_lang = 'ko'
        id_model_lang = 'ko'
@ -165,8 +161,38 @@ def which_ocr_lang(model):
    else:
        raise ValueError("Invalid OCR model. Please use one of 'easy', 'paddle', or 'rapid'.")
def similar_tfidf(list1, list2) -> float:
    """Calculate cosine similarity using TF-IDF vectors."""
    if not list1 or not list2:
        return 0.0
    vectorizer = TfidfVectorizer()
    all_texts = list1 + list2
    tfidf_matrix = vectorizer.fit_transform(all_texts)
    # Calculate average vectors for each list
    vec1 = np.mean(tfidf_matrix[:len(list1)].toarray(), axis=0).reshape(1, -1)
    vec2 = np.mean(tfidf_matrix[len(list1):].toarray(), axis=0).reshape(1, -1)
    return cosine_similarity(vec1, vec2)[0, 0]

def similar_jacard(list1, list2) -> float:
    """Jaccard similarity: |intersection| / |union| of the two word sets."""
    if not list1 or not list2:
        return 0.0
    return len(set(list1).intersection(set(list2))) / len(set(list1).union(set(list2)))

def check_similarity(list1, list2, threshold, method = 'tfidf'):
    """Return True if the two word lists are at least `threshold` similar under the chosen method."""
    if method == 'tfidf':
        try:
            confidence = similar_tfidf(list1, list2)
        except ValueError:
            # TfidfVectorizer raises ValueError on an empty vocabulary; treat the lists as similar
            return True
        return confidence > threshold
    elif method == 'jacard':
        return similar_jacard(list1, list2) >= threshold
    else:
        raise ValueError("Invalid method. Please use one of 'tfidf' or 'jacard'.")
if __name__ == "__main__": if __name__ == "__main__":
# Example usage # Example usage
japanesetext = "本が好きにちは" print(romanize(lang='ja', text='世界はひろい'))
print(add_furigana(japanesetext))

View File

@ -48,8 +48,8 @@ def setup_logger(
    # Create a formatter and set it for both handlers
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d - %(name)s - [%(levelname)s] %(message)s',
        datefmt="%Y-%m-%d %H:%M:%S"
    )
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)
@ -64,4 +64,5 @@ def setup_logger(
        print(f"Failed to setup logger: {e}")
        return None

logger = setup_logger('on_screen_translator', log_file='translate.log', level=logging.INFO)

33
main.py Normal file
View File

@ -0,0 +1,33 @@
###################################################################################
##### IMPORT LIBRARIES #####
import os, time, sys, threading, subprocess, asyncio
import config
import web_app, qt_app
from logging_config import logger
from data import create_tables
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
###################################################################################
create_tables()
def main(app):
logger.info('Configuration:')
for i in dir(config):
if not callable(getattr(config, i)) and not i.startswith("__"):
logger.info(f'{i}: {getattr(config, i)}')
if app == 'qt':
# Start the Qt app
qt_app.qt_app_main()
elif app == 'web':
threading.Thread(target=asyncio.run, args=(web_app.web_app_main(),), daemon=True).start()
web_app.app.run(host='0.0.0.0', port=5000, debug=False)
################### TODO ##################
# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
if __name__ == '__main__':
main('qt')
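#   main('web')  # alternative entry point: serve the translated frames over HTTP on port 5000 (see web_app.py) instead of the Qt window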

149
qt_app.py Normal file
View File

@ -0,0 +1,149 @@
import config, asyncio, sys, os, time, numpy as np, qt_app, web_app
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
from utils import printsc, convert_image_to_bytes, bytes_to_image, check_similarity, is_wayland
from ocr import get_words, init_OCR, id_keep_source_lang
from data import Base, engine, create_tables
from draw import modify_image
from config import (SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY,
REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL,
IMAGE_CHANGE_THRESHOLD, TEMP_IMG_PATH)
from logging_config import logger
from PySide6.QtWidgets import QMainWindow, QLabel, QVBoxLayout, QWidget, QApplication
from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtGui import QPixmap, QImage
class MainWindow(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle("Translator")
# Create main widget and layout
main_widget = QWidget()
self.setCentralWidget(main_widget)
layout = QVBoxLayout(main_widget)
# Create image label
self.image_label = QLabel()
layout.addWidget(self.image_label)
# Set up image generator thread
self.generator = qt_app.ImageGenerator()
self.generator.image_ready.connect(self.update_image)
self.generator.start()
# Set initial window size
window_width, window_height = REGION[2] - REGION[0], REGION[3] - REGION[1]
self.resize(window_width, window_height)
def update_image(self, image_buffer):
"""Update the displayed image directly from buffer bytes"""
if image_buffer is None:
return
# Convert buffer to QImage
q_image = QImage.fromData(image_buffer)
if q_image.isNull():
logger.error("Failed to create QImage from buffer")
return
# Convert QImage to QPixmap and display it
pixmap = QPixmap.fromImage(q_image)
# Scale the pixmap to fit the label while maintaining aspect ratio
scaled_pixmap = pixmap.scaled(
self.image_label.size(),
Qt.KeepAspectRatio,
Qt.SmoothTransformation
)
self.image_label.setPixmap(scaled_pixmap)
class ImageGenerator(QThread):
"""Thread for generating images continuously"""
image_ready = Signal(object)  # encoded image buffer for the UI; None until the first translation completes
def __init__(self):
super().__init__()
printsc(REGION, TEMP_IMG_PATH)
self.running = True
self.OCR_LANGUAGES = [SOURCE_LANG, 'en']
self.ocr = init_OCR(model=OCR_MODEL, paddle_lang= SOURCE_LANG, easy_languages = self.OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
self.ocr_output = id_keep_source_lang(self.ocr, TEMP_IMG_PATH, SOURCE_LANG)
self.models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
# self.model, self.tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
self.runs = 0
self.prev_words = set()
self.curr_words = set(get_words(self.ocr_output))
self.translated_image = None
def run(self):
asyncio.run(self.async_run())
async def async_run(self):
while self.running:
logger.debug("Capturing screen")
printsc(REGION, TEMP_IMG_PATH)
logger.debug(f"Screen Captured. Proceeding to perform OCR.")
self.ocr_output = id_keep_source_lang(self.ocr, TEMP_IMG_PATH, SOURCE_LANG) # keep only phrases containing the source language
logger.debug(f"OCR completed. Detected {len(self.ocr_output)} phrases.")
if self.runs == 0:
logger.info('Initial run')
self.prev_words = set()
else:
logger.debug(f'Run number: {self.runs}.')
self.runs += 1
self.curr_words = set(get_words(self.ocr_output))
logger.debug(f'Current words: {self.curr_words} Previous words: {self.prev_words}')
### If the OCR detects different words, translate screen -> to ensure that the screen is not refreshing constantly and to save GPU power
if self.prev_words != self.curr_words and not check_similarity(list(self.curr_words), list(self.prev_words), threshold = IMAGE_CHANGE_THRESHOLD, method="tfidf"):
logger.info('Beginning Translation')
to_translate = [entry[1] for entry in self.ocr_output][:MAX_TRANSLATE]
# translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = self.model, tokenizer = self.tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
try:
translation = await translate_API_LLM(to_translate, self.models, call_size = 3)
except TypeError as e:
logger.error(f"Failed to translate using API models. Error: {e}. Sleeping for {2*INTERVAL} seconds.")
await asyncio.sleep(2*INTERVAL)
continue
logger.debug('Translation complete. Modifying image.')
self.translated_image = modify_image(TEMP_IMG_PATH, self.ocr_output, translation)
# view_buffer_app.show_buffer_image(translated_image, label)
logger.debug("Image modified. Saving image.")
self.prev_words = self.curr_words
else:
logger.info(f"Skipping translation. No significant change in the screen detected. Total translation attempts so far: {self.runs}.")
logger.debug("Continuing to next iteration.")
await asyncio.sleep(INTERVAL)
self.image_ready.emit(self.translated_image)
def stop(self):
self.running = False
self.wait()
def closeEvent(self, event):
"""Clean up when closing the window"""
self.generator.stop()
event.accept()
def qt_app_main():
app = QApplication(sys.argv)
window = MainWindow()
window.show()
sys.exit(app.exec())
if __name__ == "__main__":
qt_app_main()

115
qtapp.py
View File

@ -1,115 +0,0 @@
###################################################################################
##### IMPORT LIBRARIES #####
import os, time, sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
from utils import printsc, convert_image_to_bytes, bytes_to_image
from ocr import get_words, init_OCR, id_keep_source_lang
from logging_config import logger
from draw import modify_image_bytes
from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL, FONT_SIZE, FONT_FILE, FONT_COLOUR
from create_overlay import app, overlay
from typing import Optional, List
###################################################################################
from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QThread, Signal
from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont,
QColor, QIcon, QImage, QPixmap)
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
QLabel, QSystemTrayIcon, QMenu)
from dataclasses import dataclass
class TranslationThread(QThread):
translation_ready = Signal(list, list) # Signal to send translation results
start_capture = Signal()
end_capture = Signal()
screen_capture = Signal(int, int, int, int)
def __init__(self, ocr, models, source_lang, target_lang, interval):
super().__init__()
self.ocr = ocr
self.models = models
self.source_lang = source_lang
self.target_lang = target_lang
self.interval = interval
self.running = True
self.prev_words = set()
self.runs = 0
def run(self):
while self.running:
self.start_capture.emit()
untranslated_image = printsc(REGION)
self.end_capture.emit()
byte_image = convert_image_to_bytes(untranslated_image)
ocr_output = id_keep_source_lang(self.ocr, byte_image, self.source_lang)
if self.runs == 0:
logger.info('Initial run')
else:
logger.info(f'Run number: {self.runs}.')
self.runs += 1
curr_words = set(get_words(ocr_output))
if self.prev_words != curr_words:
logger.info('Translating')
to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
translation = translate_API_LLM(to_translate, self.models)
logger.info(f'Translation from {to_translate} to\n {translation}')
# Emit the translation results
modify_image_bytes(byte_image, ocr_output, translation)
self.translation_ready.emit(ocr_output, translation)
self.prev_words = curr_words
else:
logger.info("No new words to translate. Output will not refresh.")
logger.info(f'Sleeping for {self.interval} seconds')
time.sleep(self.interval)
def stop(self):
self.running = False
def main():
# Initialize OCR
OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
ocr = init_OCR(model=OCR_MODEL, easy_languages=OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
# Initialize translation
models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
# Create and start translation thread
translation_thread = TranslationThread(
ocr=ocr,
models=models,
source_lang=SOURCE_LANG,
target_lang=TARGET_LANG,
interval=INTERVAL
)
# Connect translation results to overlay update
translation_thread.start_capture.connect(overlay.prepare_for_capture)
translation_thread.end_capture.connect(overlay.restore_after_capture)
translation_thread.translation_ready.connect(overlay.update_translation)
translation_thread.screen_capture.connect(overlay.capture_behind)
# Start the translation thread
translation_thread.start()
# Start Qt event loop
result = app.exec()
# Cleanup
translation_thread.stop()
translation_thread.wait()
return result
if __name__ == "__main__":
sys.exit(main())

262
requirements.txt Normal file
View File

@ -0,0 +1,262 @@
absl-py==2.1.0
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiosignal==1.3.1
albucore==0.0.13
albumentations==1.4.10
annotated-types==0.7.0
anyio==4.6.2.post1
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
astor==0.8.1
asttokens==2.4.1
astunparse==1.6.3
async-lru==2.0.4
attrs==24.2.0
babel==2.16.0
beautifulsoup4==4.12.3
bleach==6.2.0
blinker==1.8.2
cachetools==5.5.0
certifi==2024.8.30
cffi==1.17.1
charset-normalizer==3.4.0
click==8.1.7
coloredlogs==15.0.1
comm==0.2.2
contourpy==1.3.0
ctranslate2==4.5.0
cycler==0.12.1
Cython==3.0.11
datasets==3.1.0
debugpy==1.8.7
decorator==5.1.1
defusedxml==0.7.1
Deprecated==1.2.14
dill==0.3.8
distro==1.9.0
easyocr==1.7.2
EasyProcess==1.1
entrypoint2==1.1
eval_type_backport==0.2.0
executing==2.1.0
fastjsonschema==2.20.0
filelock==3.16.1
fire==0.7.0
Flask==3.0.3
Flask-SSE==1.0.0
flatbuffers==24.3.25
fonttools==4.54.1
fqdn==1.5.1
frozenlist==1.5.0
fsspec==2024.10.0
gast==0.6.0
google==3.0.0
google-ai-generativelanguage==0.6.10
google-api-core==2.22.0
google-api-python-client==2.151.0
google-auth==2.35.0
google-auth-httplib2==0.2.0
google-generativeai==0.8.3
google-pasta==0.2.0
googleapis-common-protos==1.65.0
greenlet==3.1.1
groq==0.11.0
grpcio==1.67.1
grpcio-status==1.67.1
h11==0.14.0
h5py==3.12.1
httpcore==1.0.6
httplib2==0.22.0
httpx==0.27.2
huggingface-hub==0.26.2
humanfriendly==10.0
idna==3.10
imageio==2.36.0
imgaug==0.4.0
ipykernel==6.29.5
ipython==8.29.0
ipywidgets==8.1.5
isoduration==20.11.0
itsdangerous==2.2.0
jaconv==0.4.0
jedi==0.19.1
jeepney==0.8.0
Jinja2==3.1.4
jiter==0.7.0
joblib==1.4.2
json5==0.9.25
jsonpath-python==1.0.6
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter==1.1.1
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.14.2
jupyter_server_terminals==0.5.3
jupyterlab==4.2.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
keras==3.6.0
kiwisolver==1.4.7
langid==1.1.6
lazy_loader==0.4
libclang==18.1.1
lmdb==1.5.1
lxml==5.3.0
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.9.2
matplotlib-inline==0.1.7
mdurl==0.1.2
mecab-python3==1.0.10
mistralai==1.1.0
mistune==3.0.2
ml-dtypes==0.4.1
mpmath==1.3.0
mss==9.0.2
multidict==6.1.0
multiprocess==0.70.16
mypy-extensions==1.0.0
namex==0.0.8
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
ninja==1.11.1.1
notebook==7.2.2
notebook_shim==0.2.4
numpy==1.26.4
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
onnxruntime==1.19.2
openai==1.53.0
opencv-contrib-python==4.10.0.84
opt-einsum==3.3.0
optimum==1.23.3
optree==0.13.0
overrides==7.7.0
packaging==24.1
paddleocr==2.9.1
paddlepaddle-gpu==2.6.2
pandas==2.2.3
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==11.0.0
pinyin==0.4.0
plac==1.4.3
platformdirs==4.3.6
prometheus_client==0.21.0
prompt_toolkit==3.0.48
propcache==0.2.0
proto-plus==1.25.0
protobuf==5.28.3
psutil==6.1.0
ptyprocess==0.7.0
pure_eval==0.2.3
pyarrow==18.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pyclipper==1.3.0.post6
pycparser==2.22
pydantic==2.9.2
pydantic_core==2.23.4
pydotenv==0.0.7
Pygments==2.18.0
pykakasi==2.3.0
pyparsing==3.2.0
pypinyin==0.53.0
pyscreenshot==3.1
PySide6==6.8.0.2
PySide6_Addons==6.8.0.2
PySide6_Essentials==6.8.0.2
python-bidi==0.6.3
python-dateutil==2.8.2
python-docx==1.1.2
python-dotenv==1.0.1
python-json-logger==2.0.7
pytz==2024.2
PyYAML==6.0.2
pyzmq==26.2.0
RapidFuzz==3.10.1
rapidocr-onnxruntime==1.3.25
redis==5.2.0
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.9.3
rpds-py==0.20.0
rsa==4.9
sacremoses==0.1.1
safetensors==0.4.5
scikit-image==0.24.0
scikit-learn==1.5.2
scipy==1.14.1
Send2Trash==1.8.3
sentencepiece==0.2.0
setuptools==75.3.0
shapely==2.0.6
shiboken6==6.8.0.2
six==1.16.0
sniffio==1.3.1
soupsieve==2.6
SQLAlchemy==2.0.36
stack-data==0.6.3
sympy==1.13.1
tensorboard==2.18.0
tensorboard-data-server==0.7.2
termcolor==2.5.0
terminado==0.18.1
threadpoolctl==3.5.0
tifffile==2024.9.20
tinycss2==1.4.0
tokenizers==0.20.1
tomli==2.0.2
torch==2.5.1
torchvision==0.20.1
tornado==6.4.1
tqdm==4.66.6
traitlets==5.14.3
transformers==4.46.1
triton==3.1.0
types-python-dateutil==2.9.0.20241003
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2024.2
unidic==1.1.0
uri-template==1.3.0
uritemplate==4.1.1
urllib3==2.2.3
uroman==1.3.1.1
wasabi==0.10.1
wcwidth==0.2.13
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
Werkzeug==3.0.6
wheel==0.44.0
widgetsnbextension==4.0.13
wrapt==1.16.0
xxhash==3.5.0
yarl==1.17.1

View File

@ -17,7 +17,7 @@
setInterval(function () {
    document.getElementById("live-image").src =
        "/image?" + new Date().getTime();
}, 1500); // Update every 1.5 seconds. Beware that if the image fails to reload in time, the browser will continuously refresh without being able to display the images.
</script>
</body>
</html>

35
web.py
View File

@ -1,35 +0,0 @@
from flask import Flask, Response, render_template
import threading
import io
import app
app = Flask(__name__)
# Global variable to hold the current image
def curr_image():
return app.latest_image
@app.route('/')
def index():
return render_template('index.html')
@app.route('/image')
def stream_image():
if curr_image() is None:
return "No image generated yet.", 503
file_object = io.BytesIO()
curr_image().save(file_object, 'PNG')
file_object.seek(0)
response = Response(file_object.getvalue(), mimetype='image/png')
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate' # HTTP 1.1
response.headers['Pragma'] = 'no-cache' # HTTP 1.0
response.headers['Expires'] = '0' # Proxies
return response
if __name__ == '__main__':
# Start the image updating thread
threading.Thread(target=app.main, daemon=True).start()
# Start the Flask web server
app.run(host='0.0.0.0', port=5000, debug=True)

113
web_app.py Normal file
View File

@ -0,0 +1,113 @@
from flask import Flask, Response, render_template
import io, os, time, sys, threading, subprocess
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
from utils import printsc, convert_image_to_bytes, bytes_to_image, check_similarity, is_wayland
from ocr import get_words, init_OCR, id_keep_source_lang
from data import Base, engine, create_tables
from draw import modify_image
import asyncio
from config import SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL, IMAGE_CHANGE_THRESHOLD, TEMP_IMG_PATH
from logging_config import logger
app = Flask(__name__)
latest_image = None
async def web_app_main():
###################################################################################
global latest_image
# Initialisation
##### Create the database if not present #####
create_tables()
##### Initialize the OCR #####
OCR_LANGUAGES = [SOURCE_LANG, 'en']
ocr = init_OCR(model=OCR_MODEL, paddle_lang= SOURCE_LANG, easy_languages = OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
##### Initialize the translation #####
# model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
###################################################################################
runs = 0
# label, app = view_buffer_app.create_viewer()
# try:
while True:
logger.debug("Capturing screen")
printsc(REGION, TEMP_IMG_PATH)
logger.debug(f"Screen Captured. Proceeding to perform OCR.")
ocr_output = id_keep_source_lang(ocr, TEMP_IMG_PATH, SOURCE_LANG) # keep only phrases containing the source language
logger.debug(f"OCR completed. Detected {len(ocr_output)} phrases.")
if runs == 0:
logger.info('Initial run')
prev_words = set()
else:
logger.debug(f'Run number: {runs}.')
runs += 1
curr_words = set(get_words(ocr_output))
logger.debug(f'Current words: {curr_words} Previous words: {prev_words}')
### If the OCR detects different words, translate screen -> to ensure that the screen is not refreshing constantly and to save GPU power
if prev_words != curr_words and not check_similarity(list(curr_words),list(prev_words), threshold = IMAGE_CHANGE_THRESHOLD, method="tfidf"):
logger.info('Beginning Translation')
to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
# translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
try:
if len(to_translate) == 0:
    logger.info("No text detected. Skipping translation. Continuing to next iteration.")
    await asyncio.sleep(INTERVAL)
    continue
translation = await translate_API_LLM(to_translate, models, call_size = 3)
except TypeError as e:
logger.error(f"Failed to translate using API models. Error: {e}. Sleeping for {2*INTERVAL} seconds.")
await asyncio.sleep(2*INTERVAL)
continue
logger.debug('Translation complete. Modifying image.')
translated_image = modify_image(TEMP_IMG_PATH, ocr_output, translation)
latest_image = bytes_to_image(translated_image)
logger.debug("Image modified. Saving image.")
prev_words = curr_words
else:
logger.info(f"Skipping translation. No significant change in the screen detected. Total translation attempts so far: {runs}.")
logger.debug("Continuing to next iteration.")
await asyncio.sleep(INTERVAL)
# Global variable to hold the current image
def curr_image():
return latest_image
@app.route('/')
def index():
return render_template('index.html')
@app.route('/image')
def stream_image():
if curr_image() is None:
return "No image generated yet.", 503
file_object = io.BytesIO()
curr_image().save(file_object, 'PNG')
file_object.seek(0)
response = Response(file_object.getvalue(), mimetype='image/png')
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate' # HTTP 1.1
response.headers['Pragma'] = 'no-cache' # HTTP 1.0
response.headers['Expires'] = '0' # Proxies
return response
if __name__ == '__main__':
# Start the image updating thread
threading.Thread(target=asyncio.run, args=(web_app_main(),), daemon=True).start()
# Start the Flask web server
app.run(host='0.0.0.0', port=5000, debug=False)
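# Hypothetical smoke test once the server is running (requests is already in
# requirements.txt); the endpoint returns 503 until the first translated frame is ready.
#
#   import requests
#   r = requests.get('http://localhost:5000/image')
#   open('frame.png', 'wb').write(r.content)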