Compare commits
5 Commits
main
...
wayland-op
| Author | SHA1 | Date | |
|---|---|---|---|
| 66bc8f205c | |||
| ecc264cf65 | |||
| 7e80713191 | |||
| 56d8c18871 | |||
| 11600ae70f |
1
.gitignore
vendored
1
.gitignore
vendored
@ -6,3 +6,4 @@ __pycache__/
|
||||
test.py
|
||||
notebooks/
|
||||
qttest.py
|
||||
*.db
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 chickenflyshigh
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
68
README.md
68
README.md
@ -1,4 +1,70 @@
|
||||
## What does this do?
|
||||
|
||||
It continuously provides translations from a source language to another language of a specified region on your screen while also (optionally) providing romanisation (including pinyin and furigana) to provide a guide to pronounciation. The main goal of this is primarily for people that have a low/basic level of understanding of a language to further develop that language by allowing the users to have the tool to allow them to immerse themselves in native content. Main uses of this include but are not limited to: playing games and watching videos with subtitles in another language (although technically it might just be better to obtain an audio transcription, translate and replace the subtitles if possible -- however this is not always feasible if watching many episodes and/or you are watching videos spontaneously).
|
||||
|
||||
## Limitations
|
||||
|
||||
If the `learn` mode is enabled for the app, the added translations and romanisation naturally results in texts taking up three times the space and therefore this is less suitable for texts that contain tightly packed words. You can optionally change the config to insert smaller text or change the overall font size of your screen so there are less text. A pure translation mode also exists, although if it is intended for web browsing, Google itself provides a more reliable method of translation which does not rely on the computationally heavy optical character recognition (OCR).
|
||||
|
||||
## Usage (draft)
|
||||
|
||||
1. Clone the repository, navigate to the repository and install all required packages with `pip install -r requirements.txt` in a new Python environment (the OCR packages are very finnicky).
|
||||
|
||||
2. If using external APIs, you will need to obtain the api keys for the currently supported sites [Google (Gemini models), Groq (an assortment of open-source LLMs)] and add in the associated api keys in the environmental variables file. If using another API, you will need to define a new class with the `_request` function in `helpers/batching.py`, inheriting the `ApiModels` class. A template is created in the file under the `Gemini` and `Groq` classes. All exception handling are already taken care of.
|
||||
|
||||
3. Edit the `api_models.json` file for the models you want added. The first level of the json file is the respective class name defined in `helpers/batching.py`. The second level defines the `model` names from their corresponding API endpoints. For the third level, the rates of each model are specified. `rpmin`, `rph`, `rpd`, `rpw`, `rpmth`, `rpy` are respectively the rates per minute, hour, day, week, month, year.
|
||||
|
||||
4. Create and edit the `.env` config file. For information about all the variables to edit, check the section under "EDIT THESE ENVIRONMENTAL VARIABLES" in the `config.py` file. If CUDA is not detected, it will default to using the `CPU` mode for all local LLMs and OCRs. In this case, it is recommended to set the `OCR_MODEL` variable to `rapid` which is optimised for CPUs. Currently the only support for this is with `SOURCE_LANG=ch_tra`, `ch_sim` or `en`. Refer to [notes][1]
|
||||
|
||||
5. If you are using the `wayland` display protocol (only available for Linux -- check with `echo $WAYLAND_DISPLAY`), download the `grim` package onto your machine locally with any of the package managers.
|
||||
|
||||
- Archlinux-based: `sudo pacman -S grim`
|
||||
- Debian-based: `sudo apt install grim`
|
||||
- Fedora: `dnf install grim`
|
||||
- OpenSUSE: `zypper install grim`
|
||||
- NixOS: `nix-shell -p grim`
|
||||
|
||||
Screenshotting is limited in Wayland, and `grim` is one of the more lightweight options out there.
|
||||
|
||||
6. The RapidOCR, PaddleOCR and (maybe I can't remember) the easyOCR models need to be downloaded before any of this can be used. It should download when you execute a function that initialises the model with the desired language to OCR but appears to not work well when running the app directly. (I'll add more to this later...). And this obviously holds the same for the local translation and LLM models...
|
||||
|
||||
7. Run the `main.py` file and a QT6 app should appear. Alternatively if that doesn't work, go to the last line of the `main.py` file and edit the argument to `web` which will run the translations locally on `0.0.0.0:5000` or on any other port you specify.
|
||||
|
||||
## Notes and optimisations
|
||||
|
||||
- Accuracy is limited with RapidOCR especially if there is a high dynamical range in graphics. [1]
|
||||
|
||||
- Consider lowering the quality of the screen capture for faster OCR processing and lower screen capture time -> OCR accuracy and subsequent translations can be affected but entire translation process should be under 2 seconds without too much sacrifice in OCR quality. Edit the `helpers/utils.py` `printsc` functions (will work on setting a config for this).
|
||||
|
||||
- Not much of the database aspect is worked on at the moment. Right now it stores all the texts and translations in unicode/ASCII in the database/translations.db file. Use it however you want, it is stored locally only for you.
|
||||
|
||||
- Downloading all the models may take up a few GBs of space.
|
||||
|
||||
- About 3.5GB of VRAM is used by easyOCR. Up to 1.5GB of VRAM for Paddle and Rapid OCR.
|
||||
|
||||
## Debugging Issues
|
||||
|
||||
1. CUDNN Version mismatch when using PaddleOCR. Check if LD_LIBRARY_PATH is correctly set to the directory containing the cudnn.so file. If using a local installation, it could help to just remove nvidia-cudnn-cn12 from python environment.
|
||||
1. CUDNN Version mismatch when using PaddleOCR. Check if LD_LIBRARY_PATH is correctly set to the directory containing the cudnn.so file. If using a local installation, it could help to just remove the nvidia-cudnn-cn12 from your Python environment.
|
||||
2. Segmentation fault when using PaddleOCR, EasyOCR or RapidOCR. Ensure the only cv2 library is the opencv-contrib-python library. Check out https://pypi.org/project/opencv-python-headless/ for more info.
|
||||
|
||||
## Demo
|
||||
|
||||
[Demo](https://youtu.be/Tmv_I0GkOQc) of Korean to Chinese (simplified) translation with the `learn-cover` mode (mode intended for people learning the language to see the romanisation/pinyin/furigana etc with the translation above).
|
||||
|
||||
## TODO:
|
||||
|
||||
- Create an overlay window that works in Wayland.
|
||||
- Make use of the translation data -> maybe make a personalised game that uses the data.
|
||||
- Providing the option for simplifying and automating most of the install process.
|
||||
|
||||
# Terms of Use
|
||||
|
||||
## Data Collection and External API Use
|
||||
|
||||
1.1 Onscreen Data Transmission: The application is designed to send data displayed on your screen, including potentially sensitive or personal information, to an external API if local processing is not setup.
|
||||
|
||||
1.2 Third-Party API Integration: When local methods cannot fulfill certain functions, the App will transmit data to external third-party APIs. These APIs are not under our control, and we do not guarantee the security, confidentiality, or purpose of use of the data once transmitted.
|
||||
|
||||
## Acknowledgment
|
||||
|
||||
By using the app, you acknowledge that you have read, understood, and agree to these Terms of Use, including the potential risks associated with transmitting data to external APIs.
|
||||
|
||||
@ -1,13 +1,22 @@
|
||||
{
|
||||
"Gemini": {
|
||||
"gemini-1.5-pro": 2,
|
||||
"gemini-1.5-flash": 15,
|
||||
"gemini-1.5-flash-8b": 8,
|
||||
"gemini-1.0-pro": 15
|
||||
"gemini-1.5-pro": { "rpmin": 2, "rpd": 50 },
|
||||
"gemini-1.5-flash": { "rpmin": 15, "rpd": 1500 },
|
||||
"gemini-1.5-flash-8b": { "rpmin": 15, "rpd": 1500 },
|
||||
"gemini-1.0-pro": { "rpmin": 15, "rpd": 1500 }
|
||||
},
|
||||
"Groqq": {
|
||||
"llama-3.2-90b-text-preview": 30,
|
||||
"llama3-70b-8192": 30,
|
||||
"mixtral-8x7b-32768": 30
|
||||
"Groq": {
|
||||
"llama-3.2-90b-text-preview": { "rpmin": 30, "rpd": 7000 },
|
||||
"llama3-70b-8192": { "rpmin": 30, "rpd": 14400 },
|
||||
"mixtral-8x7b-32768": { "rpmin": 30, "rpd": 14400 },
|
||||
"llama-3.1-70b-versatile": { "rpmin": 30, "rpd": 14400 },
|
||||
"gemma2-9b-it": { "rpmin": 30, "rpd": 14400 },
|
||||
"llama3-groq-8b-8192-tool-use-preview": { "rpmin": 30, "rpd": 14400 },
|
||||
"llama3-groq-70b-8192-tool-use-preview": { "rpmin": 30, "rpd": 14400 },
|
||||
"llama-3.2-90b-vision-preview": { "rpmin": 15, "rpd": 3500 },
|
||||
"llama-3.2-11b-text-preview": { "rpmin": 30, "rpd": 7000 },
|
||||
"llama-3.2-11b-vision-preview": { "rpmin": 30, "rpd": 7000 },
|
||||
"gemma-7b-it": { "rpmin": 30, "rpd": 14400 },
|
||||
"llama3-8b-8192": { "rpmin": 30, "rpd": 14400 }
|
||||
}
|
||||
}
|
||||
|
||||
83
app.py
83
app.py
@ -1,83 +0,0 @@
|
||||
###################################################################################
|
||||
##### IMPORT LIBRARIES #####
|
||||
import os, time, sys
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
||||
|
||||
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
|
||||
from utils import printsc, convert_image_to_bytes, bytes_to_image
|
||||
from ocr import get_words, init_OCR, id_keep_source_lang
|
||||
from logging_config import logger
|
||||
from draw import modify_image_bytes
|
||||
from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL
|
||||
###################################################################################
|
||||
|
||||
ADD_OVERLAY = False
|
||||
|
||||
latest_image = None
|
||||
|
||||
def main():
|
||||
global latest_image
|
||||
|
||||
##### Initialize the OCR #####
|
||||
OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
|
||||
ocr = init_OCR(model=OCR_MODEL, easy_languages = OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
|
||||
|
||||
##### Initialize the translation #####
|
||||
# model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
|
||||
models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
|
||||
###################################################################################
|
||||
runs = 0
|
||||
app.exec()
|
||||
while True:
|
||||
if ADD_OVERLAY:
|
||||
overlay.clear_all_text()
|
||||
|
||||
untranslated_image = printsc(REGION)
|
||||
|
||||
if ADD_OVERLAY:
|
||||
overlay.text_entries = overlay.text_entries_copy
|
||||
overlay.update()
|
||||
overlay.text_entries.clear()
|
||||
|
||||
byte_image = convert_image_to_bytes(untranslated_image)
|
||||
ocr_output = id_keep_source_lang(ocr, byte_image, SOURCE_LANG) # keep only phrases containing the source language
|
||||
|
||||
if runs == 0:
|
||||
logger.info('Initial run')
|
||||
prev_words = set()
|
||||
else:
|
||||
logger.info(f'Run number: {runs}.')
|
||||
runs += 1
|
||||
|
||||
curr_words = set(get_words(ocr_output))
|
||||
|
||||
### If the OCR detects different words, translate screen -> to ensure that the screen is not refreshing constantly and to save GPU power
|
||||
if prev_words != curr_words:
|
||||
logger.info('Translating')
|
||||
|
||||
to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
|
||||
# translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
|
||||
translation = translate_API_LLM(to_translate, models)
|
||||
logger.info(f'Translation from {to_translate} to\n {translation}')
|
||||
translated_image = modify_image_bytes(byte_image, ocr_output, translation)
|
||||
latest_image = bytes_to_image(translated_image)
|
||||
# latest_image.show() # for debugging
|
||||
|
||||
prev_words = curr_words
|
||||
else:
|
||||
logger.info("No new words to translate. Output will not refresh.")
|
||||
|
||||
logger.info(f'Sleeping for {INTERVAL} seconds')
|
||||
time.sleep(INTERVAL)
|
||||
# if ADD_OVERLAY:
|
||||
# sys.exit(app.exec())
|
||||
|
||||
################### TODO ##################
|
||||
# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
|
||||
# 5. Maybe refreshing issue of flask app. Also get webpage to update only if the image changes.
|
||||
# Create a way for it to just replace the text and provide only the translation on-screen. Qt6
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
88
config.py
88
config.py
@ -1,61 +1,97 @@
|
||||
import os, ast, torch
|
||||
import os, ast, torch, platform
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(override=True)
|
||||
|
||||
if platform.system() == 'Windows':
|
||||
default_tmp_dir = "C:\\Users\\AppData\\Local\\Temp"
|
||||
elif platform.system() in ['Linux', 'Darwin']:
|
||||
default_tmp_dir = "/tmp"
|
||||
|
||||
|
||||
|
||||
###################################################################################################
|
||||
### EDIT THESE VARIABLES ###
|
||||
# Create a .env file in the same directory as this file and add the variables there. Of course you can choose to edit this file but then if you pull from the repository again all the config will be goneeee unless perhaps it is saved or stashed.
|
||||
# The default values should be fine for most cases. Only ones that you need to change are the API keys, and the variables under Translation and API Translation if you choose to use an external API.
|
||||
# available languages: 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
|
||||
|
||||
### available languages: 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
|
||||
|
||||
INTERVAL = int(os.getenv('INTERVAL'))
|
||||
INTERVAL = float(os.getenv('INTERVAL', 1.5)) # Interval in seconds between translations. If your system is slow, a lower value is probably fine with regards to API rates.
|
||||
|
||||
### OCR
|
||||
IMAGE_CHANGE_THRESHOLD = float(os.getenv('IMAGE_CHANGE_THRESHOLD', 0.75)) # higher values mean more sensitivity to changes in the screen, too high and the screen will constantly refresh
|
||||
OCR_MODEL = os.getenv('OCR_MODEL', 'easy') # 'easy', 'paddle', 'rapid' ### easy is the most accurate, paddle is the fastest with CUDA and rapid is the fastest with CPU. Rapid has only between Chinese and English unless you add more languages
|
||||
OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True'))
|
||||
OCR_USE_GPU = ast.literal_eval(os.getenv('OCR_USE_GPU', 'True')) # True or False to use CUDA for OCR. Defaults to CPU if no CUDA GPU is available
|
||||
|
||||
|
||||
|
||||
|
||||
### Drawing/Overlay Config
|
||||
ADD_OVERLAY = ast.literal_eval(os.getenv('ADD_OVERLAY', 'True'))
|
||||
FILL_COLOUR = os.getenv('FILL_COLOUR', 'white')
|
||||
FONT_FILE = os.getenv('FONT_FILE')
|
||||
FONT_SIZE = int(os.getenv('FONT_SIZE', 16))
|
||||
LINE_SPACING = int(os.getenv('LINE_SPACING', 3))
|
||||
REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)'))
|
||||
FONT_COLOUR = os.getenv('FONT_COLOUR', "#ff0000")
|
||||
TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True'))
|
||||
FILL_COLOUR = os.getenv('FILL_COLOUR', 'white') # colour of the textboxes
|
||||
FONT_COLOUR = os.getenv('FONT_COLOUR', "#ff0000") # colour of the font
|
||||
FONT_FILE = os.getenv('FONT_FILE', os.path.join(__file__, "fonts", "Arial-Unicode-Bold.ttf")) # path to the font file. Ensure it is a unicode .ttf file if you want to be able to see most languages.
|
||||
FONT_SIZE_MAX = int(os.getenv('FONT_SIZE_MAX', 20)) # Maximum font size you want to be able to see onscreen
|
||||
FONT_SIZE_MIN = int(os.getenv('FONT_SIZE_MIN', 8)) # Minimum font size you want to be able to see onscreen
|
||||
LINE_SPACING = int(os.getenv('LINE_SPACING', 3)) # spacing between lines of text with the learn modes in DRAW_TRANSLATIONS_MODE
|
||||
REGION = ast.literal_eval(os.getenv('REGION','(0,0,2560,1440)')) # (x1, y1, x2, y2) - the region of the screen to capture
|
||||
DRAW_TRANSLATIONS_MODE = os.getenv('DRAW_TRANSLATIONS_MODE', 'learn_cover')
|
||||
"""
|
||||
DRAW_TRANSLATIONS_MODE possible options:
|
||||
`learn': adds translated text, original text (should be added so when texts get moved around the translation of which it references is understood) and (optionally with the other TO_ROMANIZE option) romanized text above the original text. Texts can overlap if squished into a corner. Works well for games where texts are sparser
|
||||
'learn_cover': same as above but covers the original text with the translated text. Can help with readability and is less cluttered but with sufficiently dense text the texts can still overlap
|
||||
'translation_only_cover': cover the original text with the translated text - will not show the original text at all but also will not be affected by overlapping texts
|
||||
"""
|
||||
|
||||
|
||||
|
||||
|
||||
# API KEYS https://github.com/cheahjs/free-llm-api-resources?tab=readme-ov-file
|
||||
GEMINI_API_KEY = os.getenv('GEMINI_KEY')
|
||||
GROQ_API_KEY = os.getenv('GROQ_API_KEY') #
|
||||
# MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY') # https://console.mistral.ai/api-keys/ slow asf
|
||||
|
||||
### Translation
|
||||
MAX_TRANSLATE = int(os.getenv('MAX_TRANSLATION', 200))
|
||||
SOURCE_LANG = os.getenv('SOURCE_LANG', 'ja')
|
||||
TARGET_LANG = os.getenv('TARGET_LANG', 'en')
|
||||
MAX_TRANSLATE = int(os.getenv('MAX_TRANSLATION', 200)) # Maximum number of phrases to send to the translation model to translate
|
||||
SOURCE_LANG = os.getenv('SOURCE_LANG', 'ch_sim') # Translate from 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
|
||||
TARGET_LANG = os.getenv('TARGET_LANG', 'en') # Translate to 'ch_sim', 'ch_tra', 'ja', 'ko', 'en'
|
||||
TO_ROMANIZE = ast.literal_eval(os.getenv('TO_ROMANIZE', 'True')) # romanize the text or not. Only available for one of the learn modes in DRAW_TRANSLATIONS_MODE. It is added above the original text
|
||||
|
||||
### API Translation (could be external or a local API)
|
||||
# API KEYS
|
||||
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY') # https://ai.google.dev/
|
||||
GROQ_API_KEY = os.getenv('GROQ_API_KEY') # https://console.groq.com/keys
|
||||
# MISTRAL_API_KEY = os.getenv('MISTRAL_API_KEY') # https://console.mistral.ai/api-keys/ slow asf
|
||||
|
||||
### Local Translation
|
||||
### Local Translation Models
|
||||
TRANSLATION_MODEL= os.environ['TRANSLATION_MODEL'] # 'opus' or 'm2m' # opus is a lot more lightweight
|
||||
TRANSLATION_USE_GPU = ast.literal_eval(os.getenv('TRANSLATION_USE_GPU', 'True'))
|
||||
MAX_INPUT_TOKENS = int(os.getenv('MAX_INPUT_TOKENS', 512))
|
||||
MAX_OUTPUT_TOKENS = int(os.getenv('MAX_OUTPUT_TOKENS', 512))
|
||||
BATCH_SIZE = int(os.getenv('BATCH_SIZE', 6))
|
||||
LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False'))
|
||||
LOCAL_FILES_ONLY = ast.literal_eval(os.getenv('LOCAL_FILES_ONLY', 'False')) # will not attempt pinging Huggingface for the models and just use the cached local models
|
||||
|
||||
|
||||
|
||||
###################################################################################################
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
###################################################################################################
|
||||
### DO NOT EDIT THESE VARIABLES ###
|
||||
## Filepaths
|
||||
API_MODELS_FILEPATH = os.path.join(os.path.dirname(__file__), 'api_models.json')
|
||||
|
||||
|
||||
FONT_SIZE = int((FONT_SIZE_MAX + FONT_SIZE_MIN)/2)
|
||||
LINE_HEIGHT = FONT_SIZE
|
||||
|
||||
|
||||
|
||||
|
||||
if TRANSLATION_USE_GPU is False:
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
TEMP_IMG_DIR = os.getenv('TEMP_IMG_PATH', default_tmp_dir) # where the temporary images are stored
|
||||
TEMP_IMG_PATH = os.path.join(TEMP_IMG_DIR, 'tempP_img91258102.png')
|
||||
### Just for info
|
||||
|
||||
available_langs = ['ch_sim', 'ch_tra', 'ja', 'ko', 'en'] # there are limitations with the languages that can be used with the OCR models
|
||||
|
||||
@ -1,305 +0,0 @@
|
||||
from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QBuffer
|
||||
from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont, QScreen, QIcon)
|
||||
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
|
||||
QLabel, QSystemTrayIcon, QMenu)
|
||||
import sys, io, os, signal, time, platform
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from config import ADD_OVERLAY, FONT_FILE, FONT_SIZE
|
||||
from logging_config import logger
|
||||
|
||||
|
||||
def qpixmap_to_bytes(qpixmap):
|
||||
qimage = qpixmap.toImage()
|
||||
buffer = QBuffer()
|
||||
buffer.open(QBuffer.ReadWrite)
|
||||
qimage.save(buffer, "PNG")
|
||||
return qimage
|
||||
|
||||
@dataclass
|
||||
class TextEntry:
|
||||
text: str
|
||||
x: int
|
||||
y: int
|
||||
font: QFont = QFont('Arial', FONT_SIZE)
|
||||
visible: bool = True
|
||||
text_color: Qt.GlobalColor = Qt.GlobalColor.red
|
||||
background_color: Optional[Qt.GlobalColor] = None
|
||||
padding: int = 1
|
||||
|
||||
class TranslationOverlay(QMainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setWindowTitle("Translation Overlay")
|
||||
self.is_passthrough = True
|
||||
self.text_entries: List[TextEntry] = []
|
||||
self.setup_window_attributes()
|
||||
self.setup_shortcuts()
|
||||
self.closeEvent = lambda event: QApplication.quit()
|
||||
self.default_font = QFont('Arial', FONT_SIZE)
|
||||
self.text_entries_copy: List[TextEntry] = []
|
||||
self.next_text_entries: List[TextEntry] = []
|
||||
#self.show_background = True
|
||||
self.background_opacity = 0.5
|
||||
# self.setup_tray()
|
||||
|
||||
def prepare_for_capture(self):
|
||||
"""Preserve current state and clear overlay"""
|
||||
if ADD_OVERLAY:
|
||||
self.text_entries_copy = self.text_entries.copy()
|
||||
self.clear_all_text()
|
||||
self.update()
|
||||
|
||||
def restore_after_capture(self):
|
||||
"""Restore overlay state after capture"""
|
||||
if ADD_OVERLAY:
|
||||
logger.debug(f'Text entries copy during initial phase of restore_after_capture: {self.text_entries_copy}')
|
||||
self.text_entries = self.text_entries_copy.copy()
|
||||
logger.debug(f"Restored text entries: {self.text_entries}")
|
||||
self.update()
|
||||
|
||||
def add_next_text_at_position_no_update(self, x: int, y: int, text: str,
|
||||
font: Optional[QFont] = None, text_color: Qt.GlobalColor = Qt.GlobalColor.red):
|
||||
"""Add new text without triggering update"""
|
||||
entry = TextEntry(
|
||||
text=text,
|
||||
x=x,
|
||||
y=y,
|
||||
font=font or self.default_font,
|
||||
text_color=text_color
|
||||
)
|
||||
self.next_text_entries.append(entry)
|
||||
|
||||
def update_translation(self, ocr_output, translation):
|
||||
# Update your overlay with new translations here
|
||||
# You'll need to implement the logic to display the translations
|
||||
self.clear_all_text()
|
||||
self.text_entries = self.next_text_entries.copy()
|
||||
self.next_text_entries.clear()
|
||||
self.update()
|
||||
|
||||
def capture_behind(self, x=None, y=None, width=None, height=None):
|
||||
"""
|
||||
Capture the screen area behind the overlay.
|
||||
If no coordinates provided, captures the area under the window.
|
||||
"""
|
||||
# Temporarily hide the window
|
||||
self.hide()
|
||||
|
||||
# Get screen
|
||||
screen = QScreen.grabWindow(
|
||||
self.screen(),
|
||||
0,
|
||||
x if x is not None else self.x(),
|
||||
y if y is not None else self.y(),
|
||||
width if width is not None else self.width(),
|
||||
height if height is not None else self.height()
|
||||
)
|
||||
|
||||
# Show the window again
|
||||
self.show()
|
||||
screen_bytes = qpixmap_to_bytes(screen)
|
||||
return screen_bytes
|
||||
|
||||
def clear_all_text(self):
|
||||
"""Clear all text entries"""
|
||||
self.text_entries.clear()
|
||||
self.update()
|
||||
|
||||
def setup_window_attributes(self):
|
||||
# Set window flags for overlay behavior
|
||||
self.setWindowFlags(
|
||||
Qt.WindowType.FramelessWindowHint |
|
||||
Qt.WindowType.WindowStaysOnTopHint |
|
||||
Qt.WindowType.Tool
|
||||
)
|
||||
|
||||
# Set attributes for transparency
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_TranslucentBackground)
|
||||
|
||||
# Make the window cover the entire screen
|
||||
self.setGeometry(QApplication.primaryScreen().geometry())
|
||||
|
||||
# Special handling for Wayland
|
||||
if platform.system() == "Linux":
|
||||
if "WAYLAND_DISPLAY" in os.environ:
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_X11NetWmWindowTypeCombo)
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_DontCreateNativeAncestors)
|
||||
|
||||
def setup_shortcuts(self):
|
||||
# Toggle visibility (Alt+Shift+T)
|
||||
self.toggle_visibility_shortcut = QShortcut(QKeySequence("Alt+Shift+T"), self)
|
||||
self.toggle_visibility_shortcut.activated.connect(self.toggle_visibility)
|
||||
|
||||
# Toggle passthrough mode (Alt+Shift+P)
|
||||
self.toggle_passthrough_shortcut = QShortcut(QKeySequence("Alt+Shift+P"), self)
|
||||
self.toggle_passthrough_shortcut.activated.connect(self.toggle_passthrough)
|
||||
|
||||
# Quick hide (Escape)
|
||||
self.hide_shortcut = QShortcut(QKeySequence("Esc"), self)
|
||||
self.hide_shortcut.activated.connect(self.hide)
|
||||
|
||||
# Clear all text (Alt+Shift+C)
|
||||
self.clear_shortcut = QShortcut(QKeySequence("Alt+Shift+C"), self)
|
||||
self.clear_shortcut.activated.connect(self.clear_all_text)
|
||||
|
||||
# Toggle background
|
||||
self.toggle_background_shortcut = QShortcut(QKeySequence("Alt+Shift+B"), self)
|
||||
self.toggle_background_shortcut.activated.connect(self.toggle_background)
|
||||
|
||||
def toggle_visibility(self):
|
||||
if self.isVisible():
|
||||
self.hide()
|
||||
else:
|
||||
self.show()
|
||||
|
||||
def toggle_passthrough(self):
|
||||
self.is_passthrough = not self.is_passthrough
|
||||
if self.is_passthrough:
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents)
|
||||
if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ:
|
||||
self.setWindowFlags(self.windowFlags() | Qt.WindowType.X11BypassWindowManagerHint)
|
||||
else:
|
||||
self.setAttribute(Qt.WidgetAttribute.WA_TransparentForMouseEvents, False)
|
||||
if platform.system() == "Linux" and "WAYLAND_DISPLAY" not in os.environ:
|
||||
self.setWindowFlags(self.windowFlags() & ~Qt.WindowType.X11BypassWindowManagerHint)
|
||||
|
||||
self.hide()
|
||||
self.show()
|
||||
|
||||
def toggle_background(self):
|
||||
"""Toggle background visibility"""
|
||||
self.show_background = not self.show_background
|
||||
self.update()
|
||||
|
||||
def set_background_opacity(self, opacity: float):
|
||||
"""Set background opacity (0.0 to 1.0)"""
|
||||
self.background_opacity = max(0.0, min(1.0, opacity))
|
||||
self.update()
|
||||
|
||||
|
||||
|
||||
def add_text_at_position(self, x: int, y: int, text: str):
|
||||
"""Add new text at specific coordinates"""
|
||||
entry = TextEntry(text, x, y)
|
||||
self.text_entries.append(entry)
|
||||
self.update()
|
||||
|
||||
def update_text_at_position(self, x: int, y: int, text: str):
|
||||
"""Update text at specific coordinates, or add if none exists"""
|
||||
# Look for existing text entry near these coordinates (within 5 pixels)
|
||||
for entry in self.text_entries:
|
||||
if abs(entry.x - x) <= 1 and abs(entry.y - y) <= 1:
|
||||
entry.text = text
|
||||
self.update()
|
||||
return
|
||||
|
||||
# If no existing entry found, add new one
|
||||
self.add_text_at_position(x, y, text)
|
||||
|
||||
def setup_tray(self):
|
||||
self.tray_icon = QSystemTrayIcon(self)
|
||||
self.tray_icon.setIcon(QIcon.fromTheme("applications-system"))
|
||||
|
||||
tray_menu = QMenu()
|
||||
|
||||
toggle_action = tray_menu.addAction("Show/Hide Overlay")
|
||||
toggle_action.triggered.connect(self.toggle_visibility)
|
||||
|
||||
toggle_passthrough = tray_menu.addAction("Toggle Passthrough")
|
||||
toggle_passthrough.triggered.connect(self.toggle_passthrough)
|
||||
|
||||
# Add background toggle to tray menu
|
||||
toggle_background = tray_menu.addAction("Toggle Background")
|
||||
toggle_background.triggered.connect(self.toggle_background)
|
||||
|
||||
clear_action = tray_menu.addAction("Clear All Text")
|
||||
clear_action.triggered.connect(self.clear_all_text)
|
||||
|
||||
tray_menu.addSeparator()
|
||||
|
||||
quit_action = tray_menu.addAction("Quit")
|
||||
quit_action.triggered.connect(self.clean_exit)
|
||||
|
||||
self.tray_icon.setToolTip("Translation Overlay")
|
||||
self.tray_icon.setContextMenu(tray_menu)
|
||||
self.tray_icon.show()
|
||||
self.tray_icon.activated.connect(self.tray_activated)
|
||||
|
||||
def remove_text_at_position(self, x: int, y: int):
|
||||
"""Remove text entry near specified coordinates"""
|
||||
self.text_entries = [
|
||||
entry for entry in self.text_entries
|
||||
if abs(entry.x - x) > 1 or abs(entry.y - y) > 1
|
||||
]
|
||||
self.update()
|
||||
|
||||
def paintEvent(self, event):
|
||||
painter = QPainter(self)
|
||||
painter.setRenderHint(QPainter.RenderHint.Antialiasing)
|
||||
|
||||
# Draw each text entry
|
||||
for entry in self.text_entries:
|
||||
if not entry.visible:
|
||||
continue
|
||||
|
||||
# Set the font for this specific entry
|
||||
painter.setFont(entry.font)
|
||||
text_metrics = painter.fontMetrics()
|
||||
|
||||
# Get the bounding rectangles for text
|
||||
text_bounds = text_metrics.boundingRect(
|
||||
entry.text
|
||||
)
|
||||
total_width = text_bounds.width()
|
||||
total_height = text_bounds.height()
|
||||
|
||||
# Create rectangles for text placement
|
||||
text_rect = QRect(entry.x, entry.y, total_width, total_height)
|
||||
# Calculate background rectangle that encompasses both texts
|
||||
if entry.background_color is not None:
|
||||
bg_rect = QRect(entry.x - entry.padding,
|
||||
entry.y - entry.padding,
|
||||
total_width + (2 * entry.padding),
|
||||
total_height + (2 * entry.padding))
|
||||
|
||||
painter.setPen(Qt.PenStyle.NoPen)
|
||||
painter.setBrush(entry.background_color)
|
||||
painter.drawRect(bg_rect)
|
||||
|
||||
# Draw the texts
|
||||
painter.setPen(entry.text_color)
|
||||
painter.drawText(text_rect, Qt.AlignmentFlag.AlignLeft, entry.text)
|
||||
|
||||
|
||||
|
||||
def handle_exit(signum, frame):
|
||||
QApplication.quit()
|
||||
|
||||
def start_overlay():
|
||||
app = QApplication(sys.argv)
|
||||
|
||||
# Enable Wayland support if available
|
||||
if platform.system() == "Linux" and "WAYLAND_DISPLAY" in os.environ:
|
||||
app.setProperty("platform", "wayland")
|
||||
|
||||
overlay = TranslationOverlay()
|
||||
|
||||
overlay.show()
|
||||
signal.signal(signal.SIGINT, handle_exit) # Handle Ctrl+C (KeyboardInterrupt)
|
||||
signal.signal(signal.SIGTERM, handle_exit)
|
||||
return (app, overlay)
|
||||
# sys.exit(app.exec())
|
||||
|
||||
if ADD_OVERLAY:
|
||||
app, overlay = start_overlay()
|
||||
|
||||
if __name__ == "__main__":
|
||||
ADD_OVERLAY = True
|
||||
if not ADD_OVERLAY:
|
||||
app, overlay = start_overlay()
|
||||
overlay.add_text_at_position(600, 100, "Hello World I AM A BIG FAAT FOROGGGGGGGGGG")
|
||||
capture = overlay.capture_behind()
|
||||
capture.save("capture.png")
|
||||
sys.exit(app.exec())
|
||||
59
data.py
Normal file
59
data.py
Normal file
@ -0,0 +1,59 @@
|
||||
from sqlalchemy import create_engine, Column, Index, Integer, String, MetaData, Table, DateTime, ForeignKey, Boolean
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import relationship, declarative_base, sessionmaker
|
||||
import logging
|
||||
from logging_config import logger
|
||||
import os
|
||||
# Set up the database connection
|
||||
data_dir = os.path.join(os.path.dirname(__file__), 'database')
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
database_file = os.path.join(os.path.dirname(__file__), data_dir, 'translations.db')
|
||||
engine = create_engine(f'sqlite:///{database_file}', echo=False)
|
||||
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
|
||||
|
||||
class Api(Base):
|
||||
__tablename__ = 'api'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
model_name = Column(String, nullable=False)
|
||||
site = Column(Integer, nullable=False)
|
||||
rpmin = Column(Integer) # rate per minute
|
||||
rph = Column(Integer) # rate per hour
|
||||
rpd = Column(Integer) # rate per day
|
||||
rpw = Column(Integer) # rate per week
|
||||
rpmth = Column(Integer) # rate per month
|
||||
rpy = Column(Integer) # rate per year
|
||||
|
||||
translations = relationship("Translations", back_populates="api")
|
||||
|
||||
class Translations(Base):
|
||||
__tablename__ = 'translations'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
model_id = Column(Integer, ForeignKey('api.id'), nullable=False)
|
||||
source_texts = Column(String, nullable=False) # as a json string
|
||||
translated_texts = Column(String, nullable=False) # as a json string
|
||||
source_lang = Column(String, nullable=False)
|
||||
target_lang = Column(String, nullable=False)
|
||||
timestamp = Column(DateTime, nullable=False)
|
||||
translation_mismatch = Column(Boolean, nullable=False)
|
||||
api = relationship("Api", back_populates="translations")
|
||||
__table_args__ = (
|
||||
Index('idx_timestamp', 'timestamp'),
|
||||
)
|
||||
|
||||
|
||||
def create_tables():
|
||||
if not os.path.exists(database_file):
|
||||
Base.metadata.create_all(engine)
|
||||
logger.info(f"Database created at {database_file}")
|
||||
else:
|
||||
logger.info(f"Using Pre-existing Database at {database_file}.")
|
||||
|
||||
|
||||
215
draw.py
215
draw.py
@ -1,68 +1,66 @@
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from PIL import Image, ImageDraw, ImageFont, ImageFilter
|
||||
import os, io, sys, numpy as np
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
||||
from utils import romanize, intercepts, add_furigana
|
||||
from logging_config import logger
|
||||
from config import ADD_OVERLAY, SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE, LINE_SPACING, FONT_COLOUR, LINE_HEIGHT, TO_ROMANIZE, FILL_COLOUR, REGION
|
||||
from config import SOURCE_LANG, MAX_TRANSLATE, FONT_FILE, FONT_SIZE_MAX,FONT_SIZE_MIN, FONT_SIZE, LINE_SPACING, FONT_COLOUR, LINE_HEIGHT, TO_ROMANIZE, FILL_COLOUR, REGION, DRAW_TRANSLATIONS_MODE
|
||||
|
||||
from PySide6.QtGui import QFont
|
||||
font = ImageFont.truetype(FONT_FILE, FONT_SIZE)
|
||||
|
||||
#### CREATE A CLASS LATER so it doesn't have to inherit the same arguments all the way too confusing :| its so ass like this man i had no foresight
|
||||
|
||||
def modify_image_bytes(image_bytes: io.BytesIO, ocr_output, translation: list) -> bytes:
|
||||
"""Modify the image bytes with the translated text and return the modified image bytes"""
|
||||
|
||||
with io.BytesIO(image_bytes) as byte_stream:
|
||||
def modify_image(input: io.BytesIO | str, ocr_output, translation: list) -> bytes:
|
||||
"""Modify the image bytes with the translated text and return the modified image bytes. If it is a path then open directly."""
|
||||
# if input is str, then check if it exists
|
||||
if isinstance(input, str):
|
||||
image = Image.open(input)
|
||||
draw = ImageDraw.Draw(image)
|
||||
draw_on_image(draw, translation, ocr_output, MAX_TRANSLATE)
|
||||
elif isinstance(input, io.BytesIO):
|
||||
with io.BytesIO(input) as byte_stream:
|
||||
image = Image.open(byte_stream)
|
||||
draw = ImageDraw.Draw(image)
|
||||
draw_on_image(draw, translation, ocr_output, MAX_TRANSLATE)
|
||||
|
||||
else:
|
||||
raise TypeError('Incorrect filetype input')
|
||||
# Save the modified image back to bytes without changing the format
|
||||
with io.BytesIO() as byte_stream:
|
||||
image.save(byte_stream, format=image.format) # Save in original format
|
||||
image.save(byte_stream, format='PNG') # Save in original format
|
||||
modified_image_bytes = byte_stream.getvalue()
|
||||
return modified_image_bytes
|
||||
|
||||
def draw_on_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int, replace = False) -> ImageDraw:
|
||||
|
||||
def draw_on_image(draw: ImageDraw, translation: list, ocr_output: list, max_translate: int, draw_mode: str = DRAW_TRANSLATIONS_MODE) -> ImageDraw:
|
||||
"""Draw the original, translated and optionally the romanisation of the texts on the image"""
|
||||
translated_number = 0
|
||||
bounding_boxes = []
|
||||
logger.debug(f"Translations: {len(translation)} {translation}")
|
||||
logger.debug(f"OCR output: {len(ocr_output)} {ocr_output}")
|
||||
for i, (position, untranslated_phrase, confidence) in enumerate(ocr_output):
|
||||
logger.debug(f"Untranslated phrase: {untranslated_phrase}")
|
||||
if translated_number >= max_translate - 1:
|
||||
if translated_number >= len(translation): # note if using api llm some issues may cause it to return less translations than expected
|
||||
break
|
||||
if replace:
|
||||
draw = draw_one_phrase_replace(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
else:
|
||||
draw_one_phrase_add(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
if draw_mode == 'learn':
|
||||
draw_one_phrase_learn(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
elif draw_mode == 'translation_only':
|
||||
draw_one_phrase_translation_only(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
elif draw_mode == 'learn_cover':
|
||||
draw_one_phrase_learn_cover(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
elif draw_mode == 'translation_only_cover':
|
||||
draw_one_phrase_translation_only_cover(draw, translation[i], position, bounding_boxes, untranslated_phrase)
|
||||
translated_number += 1
|
||||
return draw
|
||||
|
||||
def draw_one_phrase_add(draw: ImageDraw,
|
||||
def draw_one_phrase_learn(draw: ImageDraw,
|
||||
translated_phrase: str,
|
||||
position: tuple, bounding_boxes: list,
|
||||
untranslated_phrase: str) -> ImageDraw:
|
||||
"""Draw the bounding box rectangle and text on the image above the original text"""
|
||||
|
||||
if SOURCE_LANG == 'ja':
|
||||
untranslated_phrase = add_furigana(untranslated_phrase)
|
||||
romanized_phrase = romanize(untranslated_phrase, 'ja')
|
||||
else:
|
||||
romanized_phrase = romanize(untranslated_phrase, SOURCE_LANG)
|
||||
if TO_ROMANIZE:
|
||||
text_content = f"{translated_phrase}\n{romanized_phrase}\n{untranslated_phrase}"
|
||||
else:
|
||||
text_content = f"{translated_phrase}\n{untranslated_phrase}"
|
||||
|
||||
lines = text_content.split('\n')
|
||||
|
||||
lines = get_lines(untranslated_phrase, translated_phrase)
|
||||
# Draw the bounding box
|
||||
top_left, _, _, _ = position
|
||||
max_width = get_max_width(lines, FONT_FILE, FONT_SIZE)
|
||||
total_height = get_max_height(lines, FONT_SIZE, LINE_SPACING)
|
||||
top_left, _, bottom_right,_ = position
|
||||
font_size = get_font_size(top_left[1], bottom_right[1], FONT_SIZE_MAX, FONT_SIZE_MIN)
|
||||
max_width = get_max_width(lines, FONT_FILE, font_size)
|
||||
total_height = get_max_height(lines, font_size, LINE_SPACING)
|
||||
font = ImageFont.truetype(FONT_FILE, font_size)
|
||||
right_edge = REGION[2]
|
||||
|
||||
# Ensure the text is within the screen. P.S. Text on the edge may still be squished together if there are too many to translate
|
||||
@ -73,60 +71,146 @@ def draw_one_phrase_add(draw: ImageDraw,
|
||||
adjust_if_intersects(x_onscreen, y_onscreen, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
|
||||
|
||||
adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]
|
||||
draw.rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], outline="black", width=1)
|
||||
draw.rounded_rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], fill=FILL_COLOUR,outline="purple", width=1, radius=5)
|
||||
# draw.rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], outline="black", width=1)
|
||||
position = (adjusted_x,adjusted_y)
|
||||
|
||||
|
||||
for line in lines:
|
||||
if FONT_COLOUR == 'rainbow':
|
||||
rainbow_text(draw, line, *position, font)
|
||||
else:
|
||||
draw.text(position, line, fill= FONT_COLOUR, font=font)
|
||||
if ADD_OVERLAY:
|
||||
overlay.add_next_text_at_position_no_update(position[0], position[1], line, text_color=FONT_COLOUR)
|
||||
adjusted_y += FONT_SIZE + LINE_SPACING
|
||||
adjusted_y += font_size + LINE_SPACING
|
||||
position = (adjusted_x,adjusted_y)
|
||||
|
||||
|
||||
|
||||
|
||||
### Only support for horizontal text atm, vertical text is on the todo list
|
||||
def draw_one_phrase_replace(draw: ImageDraw,
|
||||
def draw_one_phrase_translation_only_cover(draw: ImageDraw,
|
||||
translated_phrase: str,
|
||||
position: tuple, bounding_boxes: list,
|
||||
untranslated_phrase: str) -> ImageDraw:
|
||||
"""Cover up old text and add translation directly on top"""
|
||||
# Draw the bounding box
|
||||
top_left, _, _, bottom_right = position
|
||||
top_left, _, bottom_right, _ = position
|
||||
bounding_boxes.append((top_left[0], top_left[1], bottom_right[0], bottom_right[1], untranslated_phrase)) # Debugging purposes
|
||||
max_width = bottom_right[0] - top_left[0]
|
||||
font_size = bottom_right[1] - top_left[1]
|
||||
draw.rectangle([top_left, bottom_right], fill=FILL_COLOUR)
|
||||
font_size = get_font_size(top_left[1], bottom_right[1], FONT_SIZE_MAX, FONT_SIZE_MIN)
|
||||
while True:
|
||||
font = ImageFont.truetype(FONT_FILE, font_size)
|
||||
if font.get_max_width < max_width:
|
||||
phrase_width = get_max_width(translated_phrase, FONT_FILE, font_size)
|
||||
rectangle = get_rectangle_coordinates(translated_phrase, top_left, FONT_FILE, font_size, LINE_SPACING)
|
||||
|
||||
if phrase_width < max_width:
|
||||
draw.rectangle(rectangle, fill=FILL_COLOUR)
|
||||
if FONT_COLOUR == 'rainbow':
|
||||
rainbow_text(draw, translated_phrase, *top_left, font)
|
||||
else:
|
||||
draw.rounded_rectangle([top_left, bottom_right], fill=FILL_COLOUR,outline="purple", width=1, radius=5)
|
||||
draw.text(top_left, translated_phrase, fill= FONT_COLOUR, font=font)
|
||||
|
||||
break
|
||||
elif font_size <= 1:
|
||||
elif font_size <= FONT_SIZE_MIN:
|
||||
break
|
||||
else:
|
||||
font_size -= 1
|
||||
|
||||
def get_max_width(lines: list, font_path, font_size) -> int:
|
||||
def draw_one_phrase_learn_cover(draw: ImageDraw,
|
||||
translated_phrase: str,
|
||||
position: tuple, bounding_boxes: list,
|
||||
untranslated_phrase: str) -> ImageDraw:
|
||||
"""Cover up old text and add translation directly on top"""
|
||||
lines = get_lines(untranslated_phrase, translated_phrase)
|
||||
# Draw the bounding box
|
||||
top_left, _, bottom_right,_ = position
|
||||
font_size = get_font_size(top_left[1], bottom_right[1], FONT_SIZE_MAX, FONT_SIZE_MIN)
|
||||
max_width = get_max_width(lines, FONT_FILE, font_size)
|
||||
total_height = get_max_height(lines, font_size, LINE_SPACING)
|
||||
font = ImageFont.truetype(FONT_FILE, font_size)
|
||||
right_edge = REGION[2]
|
||||
|
||||
# Ensure the text is within the screen. P.S. Text on the edge may still be squished together if there are too many to translate
|
||||
x_onscreen = top_left[0] if top_left[0] + max_width <= right_edge else right_edge - max_width
|
||||
y_onscreen = max(top_left[1] - int(total_height/3), 0)
|
||||
bounding_box = (x_onscreen, y_onscreen, x_onscreen + max_width, y_onscreen + total_height, untranslated_phrase)
|
||||
|
||||
adjust_if_intersects(x_onscreen, y_onscreen, bounding_box, bounding_boxes, untranslated_phrase, max_width, total_height)
|
||||
|
||||
adjusted_x, adjusted_y, adjusted_max_x, adjusted_max_y, _ = bounding_boxes[-1]
|
||||
draw.rounded_rectangle([(adjusted_x,adjusted_y), (adjusted_max_x, adjusted_max_y)], fill=FILL_COLOUR,outline="purple", width=1, radius=5)
|
||||
position = (adjusted_x,adjusted_y)
|
||||
|
||||
|
||||
for line in lines:
|
||||
if FONT_COLOUR == 'rainbow': # easter egg yay
|
||||
rainbow_text(draw, line, *position, font)
|
||||
else:
|
||||
draw.text(position, line, fill= FONT_COLOUR, font=font)
|
||||
adjusted_y += font_size + LINE_SPACING
|
||||
position = (adjusted_x,adjusted_y)
|
||||
|
||||
def draw_one_phrase_translation_only(draw: ImageDraw,
|
||||
translated_phrase: str,
|
||||
position: tuple, bounding_boxes: list,
|
||||
untranslated_phrase: str) -> ImageDraw:
|
||||
"""Cover up old text and add translation directly on top"""
|
||||
# Draw the bounding box
|
||||
pass
|
||||
|
||||
def get_rectangle_coordinates(lines: list | str, top_left: tuple | list, font_path, font_size, line_spacing, padding: int = 1) -> list:
|
||||
|
||||
"""Get the coordinates of the rectangle surrounding the text"""
|
||||
|
||||
text_width = get_max_width(lines, font_path, font_size)
|
||||
text_height = get_max_height(lines, font_size, line_spacing)
|
||||
x1 = top_left[0] - padding
|
||||
y1 = top_left[1] - padding
|
||||
x2 = top_left[0] + text_width + padding
|
||||
y2 = top_left[1] + text_height + padding
|
||||
return [(x1,y1), (x2,y2)]
|
||||
|
||||
def get_max_width(lines: list | str, font_path, font_size) -> int:
|
||||
"""Get the maximum width of the text lines"""
|
||||
font = ImageFont.truetype(font_path, font_size)
|
||||
max_width = 0
|
||||
dummy_image = Image.new("RGB", (1, 1))
|
||||
draw = ImageDraw.Draw(dummy_image)
|
||||
if isinstance(lines, list):
|
||||
for line in lines:
|
||||
bbox = draw.textbbox((0,0), line, font=font)
|
||||
line_width = bbox[2] - bbox[0]
|
||||
max_width = max(max_width, line_width)
|
||||
else:
|
||||
bbox = draw.textbbox((0,0), lines, font=font)
|
||||
max_width = bbox[2] - bbox[0]
|
||||
return max_width
|
||||
|
||||
def get_max_height(lines: list, font_size, line_spacing) -> int:
|
||||
def get_max_height(lines: list | str, font_size, line_spacing) -> int:
|
||||
"""Get the maximum height of the text lines"""
|
||||
return len(lines) * (font_size + line_spacing)
|
||||
no_of_lines = len(lines) if isinstance(lines, list) else 1
|
||||
return no_of_lines * (font_size + line_spacing)
|
||||
|
||||
def get_lines(untranslated_phrase: str, translated_phrase: str) -> list:
|
||||
"""Get the translated. untranslated and optionally the romanised text as a list"""
|
||||
if SOURCE_LANG == 'ja':
|
||||
untranslated_phrase = add_furigana(untranslated_phrase)
|
||||
romanized_phrase = romanize(untranslated_phrase, 'ja')
|
||||
else:
|
||||
romanized_phrase = romanize(untranslated_phrase, SOURCE_LANG)
|
||||
if TO_ROMANIZE:
|
||||
text_content = f"{translated_phrase}\n{romanized_phrase}\n{untranslated_phrase}"
|
||||
else:
|
||||
text_content = f"{translated_phrase}\n{untranslated_phrase}"
|
||||
return text_content.split('\n')
|
||||
|
||||
|
||||
def adjust_if_intersects(x: int, y: int,
|
||||
bounding_box: tuple, bounding_boxes: list,
|
||||
untranslated_phrase: str,
|
||||
max_width: int, total_height: int) -> tuple:
|
||||
"""Adjust the y coordinate if the bounding box intersects with any other bounding box"""
|
||||
"""Adjust the y coordinate every time the bounding box intersects with any previous bounding boxes. OCR returns results from top to bottom so it works."""
|
||||
y = np.max([y,0])
|
||||
if len(bounding_boxes) > 0:
|
||||
for box in bounding_boxes:
|
||||
@ -136,3 +220,36 @@ def adjust_if_intersects(x: int, y: int,
|
||||
bounding_boxes.append(adjusted_bounding_box)
|
||||
return adjusted_bounding_box
|
||||
|
||||
|
||||
def get_font_size(y_1, y_2, font_size_max: int, font_size_min: int) -> int:
|
||||
"""Get the average of the maximum and minimum font sizes"""
|
||||
if font_size_min > font_size_max:
|
||||
raise ValueError("Minimum font size cannot be greater than maximum font size")
|
||||
font_size = min(
|
||||
max(int(abs(2/3*(y_2-y_1))), font_size_min),
|
||||
font_size_max)
|
||||
return font_size
|
||||
|
||||
|
||||
|
||||
|
||||
def rainbow_text(draw,text,x,y,font):
|
||||
for i, letter in enumerate(text):
|
||||
# Calculate hue for rainbow effect
|
||||
# Convert HSV to RGB (using full saturation and value)
|
||||
rgb = tuple(np.random.randint(50,255,3))
|
||||
# Get the width of this letter
|
||||
|
||||
letter_bbox = draw.textbbox((x, y), letter, font=font)
|
||||
letter_width = letter_bbox[2] - letter_bbox[0]
|
||||
|
||||
# Draw the letter
|
||||
draw.text((x, y), letter, fill=rgb, font=font)
|
||||
|
||||
# Move x position for next letter
|
||||
x += letter_width
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
|
||||
BIN
fonts/Arial-Unicode-Bold.ttf
Normal file
BIN
fonts/Arial-Unicode-Bold.ttf
Normal file
Binary file not shown.
@ -1,141 +1,314 @@
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from typing import List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
from dotenv import load_dotenv
|
||||
import os , sys, torch, time, ast
|
||||
import os , sys, torch, time, ast, json, pytz
|
||||
from werkzeug.exceptions import TooManyRequests
|
||||
from multiprocessing import Process, Event, Value
|
||||
load_dotenv()
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
from config import device, GEMINI_API_KEY, GROQ_API_KEY
|
||||
from config import device, GEMINI_API_KEY, GROQ_API_KEY, MAX_TRANSLATE
|
||||
from logging_config import logger
|
||||
from groq import Groq
|
||||
from groq import Groq as Groqq
|
||||
import google.generativeai as genai
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from functools import wraps
|
||||
|
||||
from data import session, Api, Translations
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ApiModel():
|
||||
def __init__(self, model, # model name
|
||||
rate, # rate of calls per minute
|
||||
api_key, # api key for the model wrt the site
|
||||
def __init__(self, model, # model name as defined by the API
|
||||
site, # site of the model; # to be precise, use the name as defined precisely by the class names in this script, i.e. Groqq and Gemini
|
||||
api_key: Optional[str] = None, # api key for the model wrt the site
|
||||
rpmin: Optional[int] = None, # rate of calls per minute
|
||||
rph: Optional[int] = None, # rate of calls per hour
|
||||
rpd: Optional[int] = None, # rate of calls per day
|
||||
rpw: Optional[int] = None, # rate of calls per week
|
||||
rpmth: Optional[int] = None, # rate of calls per month
|
||||
rpy: Optional[int] = None # rate of calls per year
|
||||
):
|
||||
self.model = model
|
||||
self.rate = rate
|
||||
self.api_key = api_key
|
||||
self.curr_calls = Value('i', 0)
|
||||
self.time = Value('i', 0)
|
||||
self.process = None
|
||||
self.stop_event = Event()
|
||||
self.site = None
|
||||
self.model = model
|
||||
self.rpmin = rpmin
|
||||
self.rph = rph
|
||||
self.rpd = rpd
|
||||
self.rpw = rpw
|
||||
self.rpmth = rpmth
|
||||
self.rpy = rpy
|
||||
self.site = site
|
||||
self.from_lang = None
|
||||
self.target_lang = None
|
||||
self.request = None # request response from API
|
||||
self.db_table = None
|
||||
self.session_calls = 0
|
||||
self._id = None
|
||||
self._set_db_model_id() if self._get_db_model_id() else self.update_db()
|
||||
# Create the table if it does not already exist
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return f'{self.site} Model: {self.model}; Rate: {self.rate}; Current_Calls: {self.curr_calls.value} calls; Time Passed: {self.time.value} seconds.'
|
||||
return f'{self.site} Model: {self.model}; Total calls this session: {self.session_calls}; rpmin: {self.rpmin}; rph: {self.rph}; rpd: {self.rpd}; rpw: {self.rpw}; rpmth: {self.rpmth}; rpy: {self.rpy}'
|
||||
|
||||
def __str__(self):
|
||||
return self.model
|
||||
|
||||
async def __aenter__(self):
|
||||
self.session = aiohttp.ClientSession()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
def _get_db_model_id(self):
|
||||
model = session.query(Api).filter_by(model_name = self.model, site = self.site).first()
|
||||
if model:
|
||||
return model.id
|
||||
else:
|
||||
return None
|
||||
|
||||
def _set_db_model_id(self):
|
||||
self._id = self._get_db_model_id()
|
||||
|
||||
@staticmethod
|
||||
def _get_time():
|
||||
return datetime.now(tz=pytz.timezone('Australia/Sydney'))
|
||||
|
||||
def set_lang(self, from_lang, target_lang):
|
||||
self.from_lang = from_lang
|
||||
self.target_lang = target_lang
|
||||
|
||||
### CHECK MINUTELY API RATES. For working with hourly rates and monthly will need to create another file. Also just unlikely those rates will be hit
|
||||
async def api_rate_check(self):
|
||||
# Background task to manage the rate of calls to the API
|
||||
while not self.stop_event.is_set():
|
||||
start_time = time.monotonic()
|
||||
self.time.value += 5
|
||||
if self.time.value >= 60:
|
||||
self.time.value = 0
|
||||
self.curr_calls.value = 0
|
||||
elapsed = time.monotonic() - start_time
|
||||
# Sleep for exactly 5 seconds minus the elapsed time
|
||||
sleep_time = max(0, 5 - elapsed)
|
||||
await asyncio.sleep(sleep_time)
|
||||
def set_db_table(self, db_table):
|
||||
self.db_table = db_table
|
||||
|
||||
def background_task(self):
|
||||
asyncio.run(self.api_rate_check())
|
||||
|
||||
def start(self):
|
||||
# Start the background task
|
||||
self.process = Process(target=self.background_task)
|
||||
self.process.daemon = True
|
||||
self.process.start()
|
||||
logger.info(f"Background process started with PID: {self.process.pid}")
|
||||
|
||||
def stop(self):
|
||||
# Stop the background task
|
||||
logger.info(f"Stopping background process with PID: {self.process.pid}")
|
||||
self.stop_event.set()
|
||||
if self.process:
|
||||
self.process.join(timeout=5)
|
||||
if self.process.is_alive():
|
||||
self.process.terminate()
|
||||
|
||||
def request_func(request):
|
||||
@wraps(request)
|
||||
def wrapper(self, text, *args, **kwargs):
|
||||
if self.curr_calls.value < self.rate:
|
||||
# try:
|
||||
response = request(self, text, *args, **kwargs)
|
||||
self.curr_calls.value += 1
|
||||
return response
|
||||
# except Exception as e:
|
||||
#logger.error(f"Error with model {self.model} from {self.site}. Error: {e}")
|
||||
def update_db(self):
|
||||
api = session.query(Api).filter_by(model_name = self.model, site = self.site).first()
|
||||
if not api:
|
||||
api = Api(model_name = self.model,
|
||||
site = self.site,
|
||||
rpmin = self.rpmin,
|
||||
rph = self.rph,
|
||||
rpd = self.rpd,
|
||||
rpw = self.rpw,
|
||||
rpmth = self.rpmth,
|
||||
rpy = self.rpy)
|
||||
session.add(api)
|
||||
session.commit()
|
||||
self._set_db_model_id()
|
||||
else:
|
||||
logger.error(f"Rate limit reached for this model. Please wait for the rate to reset in {60 - self.time} seconds.")
|
||||
raise TooManyRequests('Rate limit reached for this model.')
|
||||
return wrapper
|
||||
api.rpmin = self.rpmin
|
||||
api.rph = self.rph
|
||||
api.rpd = self.rpd
|
||||
api.rpw = self.rpw
|
||||
api.rpmth = self.rpmth
|
||||
api.rpy = self.rpy
|
||||
session.commit()
|
||||
|
||||
@request_func
|
||||
def translate(self, request_fn, texts_to_translate):
|
||||
def _db_add_translation(self, text: list | str, translation: list, mismatch = False):
|
||||
text = json.dumps(text) if isinstance(text, list) else json.dumps([text])
|
||||
translation = json.dumps(translation)
|
||||
translation = Translations(source_texts = text, translated_texts = translation,
|
||||
model_id = self._id, source_lang = self.from_lang, target_lang = self.target_lang,
|
||||
timestamp = datetime.now(tz=pytz.timezone('Australia/Sydney')),
|
||||
translation_mismatch = mismatch)
|
||||
session.add(translation)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def _single_period_calls_check(max_calls, call_count):
|
||||
if not max_calls:
|
||||
return True
|
||||
if max_calls <= call_count:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def _are_rates_good(self):
|
||||
curr_time = self._get_time()
|
||||
min_ago = curr_time - timedelta(minutes=1)
|
||||
hour_ago = curr_time - timedelta(hours=1)
|
||||
day_ago = curr_time - timedelta(days=1)
|
||||
week_ago = curr_time - timedelta(weeks=1)
|
||||
month_ago = curr_time - timedelta(days=30)
|
||||
year_ago = curr_time - timedelta(days=365)
|
||||
min_calls = session.query(Translations).join(Api). \
|
||||
filter(Api.id==self._id,
|
||||
Translations.timestamp >= min_ago
|
||||
).count()
|
||||
hour_calls = session.query(Translations).join(Api). \
|
||||
filter(Api.id==self._id,
|
||||
Translations.timestamp >= hour_ago
|
||||
).count()
|
||||
day_calls = session.query(Translations).join(Api). \
|
||||
filter(Api.id==self._id,
|
||||
Translations.timestamp >= day_ago
|
||||
).count()
|
||||
week_calls = session.query(Translations).join(Api). \
|
||||
filter(Api.id==self._id,
|
||||
Translations.timestamp >= week_ago
|
||||
).count()
|
||||
month_calls = session.query(Translations).join(Api). \
|
||||
filter(Api.id==self._id,
|
||||
Translations.timestamp >= month_ago
|
||||
).count()
|
||||
year_calls = session.query(Translations).join(Api). \
|
||||
filter(Api.id==self._id,
|
||||
Translations.timestamp >= year_ago
|
||||
).count()
|
||||
if self._single_period_calls_check(self.rpmin, min_calls) \
|
||||
and self._single_period_calls_check(self.rph, hour_calls) \
|
||||
and self._single_period_calls_check(self.rpd, day_calls) \
|
||||
and self._single_period_calls_check(self.rpw, week_calls) \
|
||||
and self._single_period_calls_check(self.rpmth, month_calls) \
|
||||
and self._single_period_calls_check(self.rpy, year_calls):
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Rate limit reached for {self.model} from {self.site}. Current calls: {min_calls} in the last minute; {hour_calls} in the last hour; {day_calls} in the last day; {week_calls} in the last week; {month_calls} in the last month; {year_calls} in the last year.")
|
||||
return False
|
||||
|
||||
|
||||
async def translate(self, texts_to_translate, store = False) -> tuple[int, # exit code: 0 for success, 1 for incorrect response type, 2 for incorrect translation count
|
||||
list[str],
|
||||
int # number of translations that do not match the number of texts to translate
|
||||
]:
|
||||
"""Main Translation Function. All API models will need to define a new class and also define a _request function as shown below in the Gemini and Groq class models."""
|
||||
if isinstance(texts_to_translate, str):
|
||||
texts_to_translate = [texts_to_translate]
|
||||
if len(texts_to_translate) == 0:
|
||||
return []
|
||||
prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {self.from_lang} into {self.target_lang} and output as a Python list ensuring proper escaping of characters: {texts_to_translate}"
|
||||
response = request_fn(self, prompt)
|
||||
return ast.literal_eval(response.strip())
|
||||
return (0, [], 0)
|
||||
#prompt = f"Without any additional remarks, and without any code, translate the following items of the Python list from {self.from_lang} into {self.target_lang} and output as a Python list ensuring proper escaping of characters and ensuring the length of the list given is exactly equal to the length of the list you provide. Do not output in any other language other than the specified target language: {texts_to_translate}"
|
||||
prompt = f"""INSTRUCTIONS:
|
||||
- Provide ONE and ONLY ONE translation to each text provided in the JSON array given.
|
||||
- Respond using ONLY valid JSON array syntax. Do not use any Python-like dictionary syntax and therefore it must not contain any keys or curly braces.
|
||||
- Do not include explanations or additional text
|
||||
- The translations must preserve the original order.
|
||||
- Each translation must be from the Source language to the Target language
|
||||
- Source language: {self.from_lang}
|
||||
- Target language: {self.target_lang}
|
||||
- Escape special characters properly
|
||||
|
||||
class Groqq(ApiModel):
|
||||
def __init__(self, model, rate, api_key = GROQ_API_KEY):
|
||||
super().__init__(model, rate, api_key)
|
||||
self.site = "Groq"
|
||||
Input texts:
|
||||
{texts_to_translate}
|
||||
|
||||
def request(self, content):
|
||||
client = Groq()
|
||||
chat_completion = client.chat.completions.create(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
Expected format:
|
||||
["translation1", "translation2", ...]
|
||||
|
||||
Translation:"""
|
||||
|
||||
try:
|
||||
response = await self._request(prompt)
|
||||
response_list = ast.literal_eval(response.strip())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to evaluate response from {self.model} from {self.site}. Error: {e}.")
|
||||
return (1, [], 99999)
|
||||
logger.debug(repr(self))
|
||||
logger.info(f'{self.model} translated texts from: {texts_to_translate} to {response_list}.')
|
||||
if not isinstance(response_list, list):
|
||||
# raise TypeError(f"Incorrect response type. Expected list, got {type(response_list)}")
|
||||
logger.error(f"Incorrect response type. Expected list, got {type(response_list)}")
|
||||
return (1, [], 99999)
|
||||
if len(response_list) != len(texts_to_translate) and len(texts_to_translate) <= MAX_TRANSLATE:
|
||||
logger.error(f"Number of translations does not match number of texts to translate. Sent: {len(texts_to_translate)}. Received: {len(response_list)}.")
|
||||
if store:
|
||||
self._db_add_translation(texts_to_translate, response_list, mismatch=True)
|
||||
# raise ValueError(f"Number of translations does not match number of texts to translate. Sent: {len(texts_to_translate)}. Received: {len(response_list)}.")
|
||||
return (2, response_list, abs(len(texts_to_translate) - len(response_list)))
|
||||
else:
|
||||
if store:
|
||||
self._db_add_translation(texts_to_translate, response_list)
|
||||
return (0, response_list, 0)
|
||||
|
||||
class Groq(ApiModel):
|
||||
def __init__(self, # model name as defined by the API
|
||||
model,
|
||||
api_key = GROQ_API_KEY, # api key for the model wrt the site
|
||||
**kwargs):
|
||||
super().__init__(model,
|
||||
api_key = api_key,
|
||||
site = 'Groq', **kwargs)
|
||||
self.client = Groqq()
|
||||
|
||||
async def _request(self, content: str) -> str:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
"https://api.groq.com/openai/v1/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {GROQ_API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"model": self.model
|
||||
}
|
||||
],
|
||||
model=self.model
|
||||
)
|
||||
return chat_completion.choices[0].message.content
|
||||
|
||||
def translate(self, texts_to_translate):
|
||||
return super().translate(Groqq.request, texts_to_translate)
|
||||
) as response:
|
||||
response_json = await response.json()
|
||||
return response_json["choices"][0]["message"]["content"]
|
||||
# https://console.groq.com/settings/limits for limits
|
||||
|
||||
class Gemini(ApiModel):
|
||||
def __init__(self, model, rate, api_key = GEMINI_API_KEY):
|
||||
super().__init__(model, rate, api_key)
|
||||
self.site = "Gemini"
|
||||
def __init__(self, # model name as defined by the API
|
||||
model,
|
||||
api_key = GEMINI_API_KEY, # api key for the model wrt the site
|
||||
**kwargs):
|
||||
super().__init__(model,
|
||||
api_key = api_key,
|
||||
site = 'Google',
|
||||
**kwargs)
|
||||
|
||||
def request(self, content):
|
||||
genai.configure(api_key=self.api_key)
|
||||
safety_settings = {
|
||||
"HARM_CATEGORY_HARASSMENT": "BLOCK_NONE",
|
||||
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_NONE",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_NONE"}
|
||||
response = genai.GenerativeModel(self.model).generate_content(content, safety_settings=safety_settings)
|
||||
return response.text.strip()
|
||||
async def _request(self, content):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={self.api_key}",
|
||||
headers={
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"contents": [{"parts": [{"text": content}]}],
|
||||
"safetySettings": [
|
||||
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE",
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE",
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE",
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"
|
||||
}
|
||||
]
|
||||
}
|
||||
) as response:
|
||||
response_json = await response.json()
|
||||
return response_json['candidates'][0]['content']['parts'][0]['text']
|
||||
"""
|
||||
DEFINE YOUR OWN API MODELS BELOW WITH THE SAME TEMPLATE AS BELOW. All fields required are indicated by <required field>.
|
||||
|
||||
def translate(self, texts_to_translate):
|
||||
return super().translate(Gemini.request, texts_to_translate)
|
||||
class <NameOfWebsite>(ApiModel):
|
||||
def __init__(self, # model name as defined by the API
|
||||
model,
|
||||
api_key = <API_KEY>, # api key for the model wrt the site
|
||||
**kwargs):
|
||||
super().__init__(model,
|
||||
api_key = api_key,
|
||||
site = <name_of_website>,
|
||||
**kwargs)
|
||||
|
||||
async def _request(self, content):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
<API ENDPOINT e.g. https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={self.api_key}>,
|
||||
headers={
|
||||
"Content-Type": "application/json"
|
||||
<ANY OTHER HEADERS REQUIRED BY THE API separated by commas>
|
||||
},
|
||||
json={
|
||||
"contents": [{"parts": [{"text": content}]}]
|
||||
<ANY OTHER JSON PAIRS REQUIRED separated by commas>
|
||||
}
|
||||
) as response:
|
||||
response_json = await response.json()
|
||||
return <Anything needed to extract the message response from `response_json`>
|
||||
"""
|
||||
|
||||
###################################################################################################
|
||||
|
||||
### LOCAL LLM TRANSLATION
|
||||
|
||||
class TranslationDataset(Dataset):
|
||||
def __init__(self, texts: List[str], tokenizer, max_length: int = 512):
|
||||
@ -262,11 +435,13 @@ def generate_text(
|
||||
return all_generated_texts
|
||||
|
||||
if __name__ == '__main__':
|
||||
GROQ_API_KEY = os.getenv('GROQ_API_KEY')
|
||||
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
|
||||
groq = Groqq('gemma-7b-it', 15, GROQ_API_KEY)
|
||||
groq.set_lang('zh','en')
|
||||
gemini = Gemini('gemini-1.5-pro', 15, GEMINI_API_KEY)
|
||||
gemini.set_lang('zh','en')
|
||||
print(gemini.translate(['荷兰咯']))
|
||||
print(groq.translate(['荷兰咯']))
|
||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, M2M100Tokenizer, M2M100ForConditionalGeneration
|
||||
opus_model = 'Helsinki-NLP/opus-mt-en-zh'
|
||||
LOCAL_FILES_ONLY = True
|
||||
tokenizer = AutoTokenizer.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY, torch_dtype=torch.float16).to(device)
|
||||
# tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", local_files_only=LOCAL_FILES_ONLY)
|
||||
# tokenizer.src_lang = "en"
|
||||
# model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M", local_files_only=LOCAL_FILES_ONLY).to(device)
|
||||
|
||||
print(generate_text([ i.lower().capitalize() for i in ['placeholder','Story','StoRY', 'TufoRIaL', 'CovFfG', 'LoaD DaTA', 'SAME DATa', 'ReTulN@TitIE', 'View', '@niirm', 'SysceM', 'MeNu:', 'MaND', 'CoM', 'SeLEcT', 'Frogguingang', 'Tutorias', 'Back']], model, tokenizer))
|
||||
@ -21,7 +21,6 @@ def _paddle_init(paddle_lang, use_angle_cls=False, use_GPU=True, **kwargs):
|
||||
|
||||
|
||||
def _paddle_ocr(ocr, image) -> list:
|
||||
|
||||
### return a list containing the bounding box, text and confidence of the detected text
|
||||
result = ocr.ocr(image, cls=False)[0]
|
||||
if not isinstance(result, list):
|
||||
@ -32,28 +31,30 @@ def _paddle_ocr(ocr, image) -> list:
|
||||
# EasyOCR has support for many languages
|
||||
|
||||
def _easy_init(easy_languages: list, use_GPU=True, **kwargs):
|
||||
langs = []
|
||||
for lang in easy_languages:
|
||||
langs.append(standardize_lang(lang)['easyocr_lang'])
|
||||
return easyocr.Reader(langs, gpu=use_GPU, **kwargs)
|
||||
return easyocr.Reader(easy_languages, gpu=use_GPU, **kwargs)
|
||||
|
||||
def _easy_ocr(ocr,image) -> list:
|
||||
return ocr.readtext(image)
|
||||
detected_texts = ocr.readtext(image)
|
||||
return detected_texts
|
||||
|
||||
# RapidOCR mostly for mandarin and some other asian languages
|
||||
|
||||
# default only supports chinese and english
|
||||
def _rapid_init(use_GPU=True, **kwargs):
|
||||
return RapidOCR(use_gpu=use_GPU, **kwargs)
|
||||
|
||||
def _rapid_ocr(ocr, image) -> list:
|
||||
return ocr(image)
|
||||
return ocr(image)[0]
|
||||
|
||||
### Initialize the OCR model
|
||||
def init_OCR(model='paddle', easy_languages: Optional[list] = ['ch_sim','en'], paddle_lang: Optional[str] = 'ch', use_GPU=True, **kwargs):
|
||||
def init_OCR(model='paddle', easy_languages: Optional[list] = ['ch_sim','en'], paddle_lang: Optional[str] = 'ch_sim', use_GPU=True):
|
||||
if model == 'paddle':
|
||||
paddle_lang = standardize_lang(paddle_lang)['paddleocr_lang']
|
||||
return _paddle_init(paddle_lang=paddle_lang, use_GPU=use_GPU)
|
||||
elif model == 'easy':
|
||||
return _easy_init(easy_languages=easy_languages, use_GPU=use_GPU)
|
||||
langs = []
|
||||
for lang in easy_languages:
|
||||
langs.append(standardize_lang(lang)['easyocr_lang'])
|
||||
return _easy_init(easy_languages=langs, use_GPU=use_GPU)
|
||||
elif model == 'rapid':
|
||||
return _rapid_init(use_GPU=use_GPU)
|
||||
|
||||
@ -82,15 +83,16 @@ def _id_filtered(ocr, image, lang) -> list:
|
||||
return results_no_eng
|
||||
|
||||
|
||||
# ch_sim, ch_tra, ja, ko, en
|
||||
# ch_sim, ch_tra, ja, ko, en input
|
||||
def _id_lang(ocr, image, lang) -> list:
|
||||
result = _identify(ocr, image)
|
||||
lang = standardize_lang(lang)['id_model_lang']
|
||||
try:
|
||||
# try:
|
||||
logger.info(f"Filtering out phrases not in {lang}.")
|
||||
filtered = [entry for entry in result if contains_lang(entry[1], lang)]
|
||||
except:
|
||||
logger.error(f"Selected language not part of default: {default_languages}.")
|
||||
raise ValueError(f"Selected language not part of default: {default_languages}.")
|
||||
# except:
|
||||
# logger.error(f"Selected language not part of default: {default_languages}.")
|
||||
# raise ValueError(f"Selected language not part of default: {default_languages}.")
|
||||
return filtered
|
||||
|
||||
def id_keep_source_lang(ocr, image, lang) -> list:
|
||||
@ -116,9 +118,6 @@ def get_confidences(ocr_output) -> list:
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# OCR_languages = ['ch_sim','en']
|
||||
# image_old = '/home/James/Pictures/Screenshots/DP-1.jpg'
|
||||
# reader = easyocr.Reader(OCR_languages, gpu=True) # this needs to run only once to load the model into memory
|
||||
# result = reader.readtext(image_old)
|
||||
# print(result)
|
||||
print(id_keep_source_lang(init_OCR(model='paddle', paddle_lang='zh', easy_languages=['en', 'ch_sim']), '/home/James/Pictures/Screenshots/DP-1.jpg', 'ch_sim'))
|
||||
OCR_languages = ['ch_sim','en']
|
||||
reader = easyocr.Reader(OCR_languages, gpu=True)
|
||||
|
||||
@ -1,17 +1,16 @@
|
||||
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM, GPTQConfig, AutoModelForCausalLM
|
||||
import google.generativeai as genai
|
||||
import torch, os, sys, ast, json
|
||||
import torch, os, sys, ast, json, asyncio, batching, random
|
||||
from typing import List, Optional, Set
|
||||
from utils import standardize_lang
|
||||
from functools import wraps
|
||||
import random
|
||||
import batching
|
||||
from batching import generate_text, Gemini, Groq
|
||||
from batching import generate_text, Gemini, Groq, ApiModel
|
||||
from logging_config import logger
|
||||
from multiprocessing import Process,Event
|
||||
from asyncio import Task
|
||||
# root dir
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_API_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models
|
||||
from config import LOCAL_FILES_ONLY, available_langs, curr_models, BATCH_SIZE, device, GEMINI_API_KEY, MAX_INPUT_TOKENS, MAX_OUTPUT_TOKENS, seq_llm_models, api_llm_models, causal_llm_models, API_MODELS_FILEPATH
|
||||
|
||||
##############################
|
||||
# translation decorator
|
||||
@ -30,29 +29,72 @@ def translate(translation_func):
|
||||
|
||||
###############################
|
||||
def init_API_LLM(from_lang, target_lang):
|
||||
"""Initialise the API models. The models are stored in a json file. The models are instantiated, added to database/database api rates are updated and the languages are set."""
|
||||
from_lang = standardize_lang(from_lang)['translation_model_lang']
|
||||
target_lang = standardize_lang(target_lang)['translation_model_lang']
|
||||
with open('api_models.json', 'r') as f:
|
||||
with open(API_MODELS_FILEPATH, 'r') as f:
|
||||
models_and_rates = json.load(f)
|
||||
models = []
|
||||
for class_type, class_models in models_and_rates.items():
|
||||
cls = getattr(batching, class_type)
|
||||
instantiated_objects = [ cls(model, rate) for model, rate in class_models.items()]
|
||||
instantiated_objects = [ cls(model = model, **rates) for model, rates in class_models.items()]
|
||||
models.extend(instantiated_objects)
|
||||
|
||||
for model in models:
|
||||
model.start()
|
||||
model.update_db()
|
||||
model.set_lang(from_lang, target_lang)
|
||||
return models
|
||||
|
||||
def translate_API_LLM(text, models):
|
||||
async def translate_API_LLM(texts_to_translate: List[str],
|
||||
models: List[ApiModel],
|
||||
call_size: int = 2) -> List[str]:
|
||||
"""Translate the texts using the models three at a time. If the models fail to translate the text, it will try the next model in the list."""
|
||||
async def try_translate(model: ApiModel) -> Optional[List[str]]:
|
||||
result = await model.translate(texts_to_translate, store=True)
|
||||
logger.debug(f'Try_translate result: {result}')
|
||||
return result
|
||||
random.shuffle(models)
|
||||
for model in models:
|
||||
try:
|
||||
return model.translate(text)
|
||||
except:
|
||||
continue
|
||||
groups = [models[i:i+call_size] for i in range(0, len(models), call_size)]
|
||||
no_of_models = len(models)
|
||||
translation_attempts = 0
|
||||
|
||||
best_translation = None # (model, translation_errors)
|
||||
|
||||
for group in groups:
|
||||
tasks = set(asyncio.create_task(try_translate(model)) for model in group)
|
||||
while tasks:
|
||||
done, pending = await asyncio.wait(tasks,
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
logger.debug(f"Tasks done: {done}")
|
||||
logger.debug(f"Tasks remaining: {pending}")
|
||||
for task in done:
|
||||
result = await task
|
||||
logger.debug(f'Result: {result}')
|
||||
if result is not None:
|
||||
tasks.discard(task)
|
||||
translation_attempts += 1
|
||||
status_code, translations, translation_mismatches = result
|
||||
if status_code == 0:
|
||||
# Cancel remaining tasks
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
return translations
|
||||
else:
|
||||
logger.error(f"Model has failed to translate the text. Result: {result}")
|
||||
if translation_attempts == no_of_models:
|
||||
if best_translation is not None:
|
||||
return translations
|
||||
else:
|
||||
logger.error("All models have failed to translate the text.")
|
||||
raise TypeError("Models have likely all outputted garbage translations or rate limited.")
|
||||
elif status_code == 2:
|
||||
if best_translation is None:
|
||||
best_translation = (translations, translation_mismatches)
|
||||
else:
|
||||
best_translation = (translations, translation_mismatches) if len(result[2]) < len(best_translation[1]) else best_translation
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
###############################
|
||||
# Best model by far. Aya-23-8B. Gemma is relatively good. If I get the time to quantize either gemma or aya those will be good to use. llama3.2 is really good as well.
|
||||
@ -101,15 +143,17 @@ def get_OPUS_model(from_lang, target_lang):
|
||||
|
||||
def init_OPUS(from_lang = 'ch_sim', target_lang = 'en'):
|
||||
opus_model = get_OPUS_model(from_lang, target_lang)
|
||||
logger.debug(f"OPUS model: {opus_model}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(opus_model, local_files_only=LOCAL_FILES_ONLY, torch_dtype=torch.float16).to(device)
|
||||
model.eval()
|
||||
return (model, tokenizer)
|
||||
|
||||
def translate_OPUS(text: list[str], model, tokenizer) -> list[str]:
|
||||
translated_text = generate_text(model,tokenizer, text,
|
||||
translated_text = generate_text(text, model,tokenizer,
|
||||
batch_size=BATCH_SIZE, device=device,
|
||||
max_length=MAX_INPUT_TOKENS, max_new_tokens=MAX_OUTPUT_TOKENS)
|
||||
logger.debug(f"Translated text: {translated_text}")
|
||||
return translated_text
|
||||
|
||||
###############################
|
||||
@ -132,6 +176,7 @@ def translate_Seq_LLM(text,
|
||||
model,
|
||||
tokenizer,
|
||||
**kwargs):
|
||||
text = [t.lower().capitalize() for t in text]
|
||||
if model_type == 'opus':
|
||||
return translate_OPUS(text, model, tokenizer)
|
||||
elif model_type == 'm2m':
|
||||
|
||||
@ -4,7 +4,10 @@ import pyscreenshot as ImageGrab # wayland tings not sure if it will work on oth
|
||||
import mss, io, os
|
||||
from PIL import Image
|
||||
import jaconv, MeCab, unidic, pykakasi
|
||||
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
import numpy as np
|
||||
import subprocess
|
||||
# for creating furigana
|
||||
mecab = MeCab.Tagger('-d "{}"'.format(unidic.DICDIR))
|
||||
uroman = ur.Uroman()
|
||||
@ -23,33 +26,26 @@ def intercepts(x,y):
|
||||
def is_wayland():
|
||||
return 'WAYLAND_DISPLAY' in os.environ
|
||||
|
||||
# path to save screenshot of monitor to
|
||||
def printsc_wayland(region, save: bool = False, path: str = None):
|
||||
if save:
|
||||
im = ImageGrab.grab(bbox=region)
|
||||
im.save(path)
|
||||
else:
|
||||
return ImageGrab.grab(bbox=region)
|
||||
# please install grim otherwise this is way too slow for wayland
|
||||
def printsc_wayland(region: tuple, path: str):
|
||||
subprocess.run(['grim','-g', f'{region[0]},{region[1]} {region[2]-region[0]}x{region[3]-region[1]}', '-t', 'jpeg', '-q','90', path])
|
||||
|
||||
|
||||
def printsc_non_wayland(region, save: bool = False, path: str = None):
|
||||
def printsc_non_wayland(region: tuple, path: str):
|
||||
# use mss to capture the screen
|
||||
with mss.mss() as sct:
|
||||
# grab the screen
|
||||
img = sct.grab(region)
|
||||
# convert the image to a PIL image
|
||||
image = Image.frombytes("RGB", img.size, img.bgra, "raw", "BGRX")
|
||||
# save the image if save is True
|
||||
if save:
|
||||
image.save(path)
|
||||
|
||||
|
||||
def printsc(region, save: bool = False, path: str = None):
|
||||
def printsc(region: tuple, path: str):
|
||||
try:
|
||||
if is_wayland():
|
||||
return printsc_wayland(region, save, path)
|
||||
printsc_wayland(region, path)
|
||||
else:
|
||||
return printsc_non_wayland(region, save, path)
|
||||
printsc_non_wayland(region, path)
|
||||
except Exception as e:
|
||||
print(f'Error {e}')
|
||||
|
||||
@ -95,10 +91,10 @@ def contains_katakana(text):
|
||||
|
||||
# use kakasi to romanize japanese text
|
||||
def romanize(text, lang):
|
||||
if lang == 'zh':
|
||||
if lang in ['zh','ch_sim','ch_tra']:
|
||||
return ' '.join([ py[0] for py in pinyin(text, heteronym=True)])
|
||||
if lang == 'ja':
|
||||
return kks.convert(text)[0]['hepburn']
|
||||
return ' '.join([romaji['hepburn'] for romaji in kks.convert(text)])
|
||||
return uroman.romanize_string(text)
|
||||
|
||||
# check if a string contains words from a language
|
||||
@ -107,7 +103,7 @@ def contains_lang(text, lang):
|
||||
if lang == 'zh':
|
||||
return bool(re.search(r'[\u4e00-\u9fff]', text))
|
||||
elif lang == 'ja':
|
||||
return bool(re.search(r'[\u3040-\u30ff]', text))
|
||||
return bool(re.search(r'[\u3040-\u30ff]', text)) or bool(re.search(r'[\u4e00-\u9fff]', text))
|
||||
elif lang == 'ko':
|
||||
return bool(re.search(r'[\uac00-\ud7af]', text))
|
||||
elif lang == 'en':
|
||||
@ -131,13 +127,13 @@ def standardize_lang(lang):
|
||||
id_model_lang = 'zh'
|
||||
elif lang == 'ja':
|
||||
easyocr_lang = 'ja'
|
||||
paddleocr_lang = 'ja'
|
||||
paddleocr_lang = 'japan'
|
||||
rapidocr_lang = 'ja'
|
||||
translation_model_lang = 'ja'
|
||||
id_model_lang = 'ja'
|
||||
elif lang == 'ko':
|
||||
easyocr_lang = 'korean'
|
||||
paddleocr_lang = 'ko'
|
||||
easyocr_lang = 'ko'
|
||||
paddleocr_lang = 'korean'
|
||||
rapidocr_lang = 'ko'
|
||||
translation_model_lang = 'ko'
|
||||
id_model_lang = 'ko'
|
||||
@ -165,8 +161,38 @@ def which_ocr_lang(model):
|
||||
else:
|
||||
raise ValueError("Invalid OCR model. Please use one of 'easy', 'paddle', or 'rapid'.")
|
||||
|
||||
def similar_tfidf(list1,list2) -> float:
|
||||
"""Calculate cosine similarity using TF-IDF vectors."""
|
||||
if not list1 or not list2:
|
||||
return 0.0
|
||||
|
||||
vectorizer = TfidfVectorizer()
|
||||
all_texts = list1 + list2
|
||||
tfidf_matrix = vectorizer.fit_transform(all_texts)
|
||||
|
||||
# Calculate average vectors for each list
|
||||
vec1 = np.mean(tfidf_matrix[:len(list1)].toarray(), axis=0).reshape(1, -1)
|
||||
vec2 = np.mean(tfidf_matrix[len(list1):].toarray(), axis=0).reshape(1, -1)
|
||||
|
||||
return cosine_similarity(vec1, vec2)[0, 0]
|
||||
|
||||
def similar_jacard(list1, list2) -> float:
|
||||
if not list1 or not list2:
|
||||
return 0.0
|
||||
return len(set(list1).intersection(set(list2))) / len(set(list1).union(set(list2)))
|
||||
|
||||
def check_similarity(list1, list2, threshold, method = 'tfidf'):
|
||||
if method == 'tfidf':
|
||||
try:
|
||||
confidence = similar_tfidf(list1, list2)
|
||||
except ValueError:
|
||||
return True
|
||||
return True if confidence > threshold else False
|
||||
elif method == 'jacard':
|
||||
return True if similar_jacard(list1, list2) >= threshold else False
|
||||
else:
|
||||
raise ValueError("Invalid method. Please use one of 'tfidf' or 'jacard'.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
japanesetext = "本が好きにちは"
|
||||
print(add_furigana(japanesetext))
|
||||
print(romanize(lang='ja', text='世界はひろい'))
|
||||
@ -48,8 +48,8 @@ def setup_logger(
|
||||
|
||||
# Create a formatter and set it for both handlers
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - [%(levelname)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
'%(asctime)s.%(msecs)03d - %(name)s - [%(levelname)s] %(message)s',
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
console_handler.setFormatter(formatter)
|
||||
@ -64,4 +64,5 @@ def setup_logger(
|
||||
print(f"Failed to setup logger: {e}")
|
||||
return None
|
||||
|
||||
logger = setup_logger('on_screen_translator', log_file='translate.log', level=logging.DEBUG)
|
||||
logger = setup_logger('on_screen_translator', log_file='translate.log', level=logging.INFO)
|
||||
|
||||
|
||||
33
main.py
Normal file
33
main.py
Normal file
@ -0,0 +1,33 @@
|
||||
###################################################################################
|
||||
##### IMPORT LIBRARIES #####
|
||||
import os, time, sys, threading, subprocess, asyncio
|
||||
import config
|
||||
import web_app, qt_app
|
||||
from logging_config import logger
|
||||
from data import create_tables
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
||||
|
||||
|
||||
|
||||
###################################################################################
|
||||
create_tables()
|
||||
|
||||
def main(app):
|
||||
logger.info('Configuration:')
|
||||
for i in dir(config):
|
||||
if not callable(getattr(config, i)) and not i.startswith("__"):
|
||||
logger.info(f'{i}: {getattr(config, i)}')
|
||||
|
||||
if app == 'qt':
|
||||
# Start the Qt app
|
||||
qt_app.qt_app_main()
|
||||
elif app == 'web':
|
||||
threading.Thread(target=asyncio.run, args=(web_app.web_app_main(),), daemon=True).start()
|
||||
|
||||
web_app.app.run(host='0.0.0.0', port=5000, debug=False)
|
||||
################### TODO ##################
|
||||
# 3. Quantising/finetuning larger LLMs. Consider using Aya-23-8B, Gemma, llama3.2 models.
|
||||
|
||||
if __name__ == '__main__':
|
||||
main('qt')
|
||||
|
||||
149
qt_app.py
Normal file
149
qt_app.py
Normal file
@ -0,0 +1,149 @@
|
||||
import config, asyncio, sys, os, time, numpy as np, qt_app, web_app
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
||||
|
||||
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
|
||||
from utils import printsc, convert_image_to_bytes, bytes_to_image, check_similarity, is_wayland
|
||||
from ocr import get_words, init_OCR, id_keep_source_lang
|
||||
from data import Base, engine, create_tables
|
||||
from draw import modify_image
|
||||
|
||||
from config import (SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY,
|
||||
REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL,
|
||||
IMAGE_CHANGE_THRESHOLD, TEMP_IMG_PATH)
|
||||
from logging_config import logger
|
||||
from PySide6.QtWidgets import QMainWindow, QLabel, QVBoxLayout, QWidget, QApplication
|
||||
from PySide6.QtCore import Qt, QThread, Signal
|
||||
from PySide6.QtGui import QPixmap, QImage
|
||||
|
||||
|
||||
class MainWindow(QMainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setWindowTitle("Translator")
|
||||
|
||||
# Create main widget and layout
|
||||
main_widget = QWidget()
|
||||
self.setCentralWidget(main_widget)
|
||||
layout = QVBoxLayout(main_widget)
|
||||
|
||||
# Create image label
|
||||
self.image_label = QLabel()
|
||||
layout.addWidget(self.image_label)
|
||||
|
||||
# Set up image generator thread
|
||||
self.generator = qt_app.ImageGenerator()
|
||||
self.generator.image_ready.connect(self.update_image)
|
||||
self.generator.start()
|
||||
|
||||
# Set initial window size
|
||||
window_width, width_height = REGION[2] - REGION[0], REGION[3] - REGION[1]
|
||||
|
||||
self.resize(window_width, width_height)
|
||||
|
||||
def update_image(self, image_buffer):
|
||||
"""Update the displayed image directly from buffer bytes"""
|
||||
if image_buffer is None:
|
||||
return
|
||||
|
||||
# Convert buffer to QImage
|
||||
q_image = QImage.fromData(image_buffer)
|
||||
|
||||
if q_image.isNull():
|
||||
logger.error("Failed to create QImage from buffer")
|
||||
return
|
||||
|
||||
# Convert QImage to QPixmap and display it
|
||||
pixmap = QPixmap.fromImage(q_image)
|
||||
|
||||
# Scale the pixmap to fit the label while maintaining aspect ratio
|
||||
scaled_pixmap = pixmap.scaled(
|
||||
self.image_label.size(),
|
||||
Qt.KeepAspectRatio,
|
||||
Qt.SmoothTransformation
|
||||
)
|
||||
|
||||
self.image_label.setPixmap(scaled_pixmap)
|
||||
|
||||
|
||||
class ImageGenerator(QThread):
|
||||
"""Thread for generating images continuously"""
|
||||
image_ready = Signal(np.ndarray)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
printsc(REGION, TEMP_IMG_PATH)
|
||||
self.running = True
|
||||
self.OCR_LANGUAGES = [SOURCE_LANG, 'en']
|
||||
self.ocr = init_OCR(model=OCR_MODEL, paddle_lang= SOURCE_LANG, easy_languages = self.OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
|
||||
self.ocr_output = id_keep_source_lang(self.ocr, TEMP_IMG_PATH, SOURCE_LANG)
|
||||
self.models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
|
||||
# self.model, self.tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
|
||||
self.runs = 0
|
||||
self.prev_words = set()
|
||||
self.curr_words = set(get_words(self.ocr_output))
|
||||
self.translated_image = None
|
||||
|
||||
def run(self):
|
||||
asyncio.run(self.async_run())
|
||||
|
||||
async def async_run(self):
|
||||
|
||||
while self.running:
|
||||
logger.debug("Capturing screen")
|
||||
printsc(REGION, TEMP_IMG_PATH)
|
||||
logger.debug(f"Screen Captured. Proceeding to perform OCR.")
|
||||
self.ocr_output = id_keep_source_lang(self.ocr, TEMP_IMG_PATH, SOURCE_LANG) # keep only phrases containing the source language
|
||||
logger.debug(f"OCR completed. Detected {len(self.ocr_output)} phrases.")
|
||||
if self.runs == 0:
|
||||
logger.info('Initial run')
|
||||
self.prev_words = set()
|
||||
else:
|
||||
logger.debug(f'Run number: {self.runs}.')
|
||||
self.runs += 1
|
||||
|
||||
self.curr_words = set(get_words(self.ocr_output))
|
||||
logger.debug(f'Current words: {self.curr_words} Previous words: {self.prev_words}')
|
||||
### If the OCR detects different words, translate screen -> to ensure that the screen is not refreshing constantly and to save GPU power
|
||||
if self.prev_words != self.curr_words and not check_similarity(list(self.curr_words), list(self.prev_words), threshold = IMAGE_CHANGE_THRESHOLD, method="tfidf"):
|
||||
logger.info('Beginning Translation')
|
||||
|
||||
to_translate = [entry[1] for entry in self.ocr_output][:MAX_TRANSLATE]
|
||||
# translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = self.model, tokenizer = self.tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
|
||||
try:
|
||||
translation = await translate_API_LLM(to_translate, self.models, call_size = 3)
|
||||
except TypeError as e:
|
||||
logger.error(f"Failed to translate using API models. Error: {e}. Sleeping for {2*INTERVAL} seconds.")
|
||||
time.sleep(2*INTERVAL)
|
||||
continue
|
||||
logger.debug('Translation complete. Modifying image.')
|
||||
self.translated_image = modify_image(TEMP_IMG_PATH, self.ocr_output, translation)
|
||||
# view_buffer_app.show_buffer_image(translated_image, label)
|
||||
logger.debug("Image modified. Saving image.")
|
||||
self.prev_words = self.curr_words
|
||||
else:
|
||||
logger.info(f"Skipping translation. No significant change in the screen detected. Total translation attempts so far: {self.runs}.")
|
||||
logger.debug("Continuing to next iteration.")
|
||||
time.sleep(INTERVAL)
|
||||
self.image_ready.emit(self.translated_image)
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
self.wait()
|
||||
|
||||
|
||||
|
||||
def closeEvent(self, event):
|
||||
"""Clean up when closing the window"""
|
||||
self.generator.stop()
|
||||
event.accept()
|
||||
|
||||
|
||||
def qt_app_main():
|
||||
app = QApplication(sys.argv)
|
||||
window = MainWindow()
|
||||
window.show()
|
||||
sys.exit(app.exec())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
qt_app_main()
|
||||
115
qtapp.py
115
qtapp.py
@ -1,115 +0,0 @@
|
||||
###################################################################################
|
||||
##### IMPORT LIBRARIES #####
|
||||
import os, time, sys
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
||||
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
|
||||
from utils import printsc, convert_image_to_bytes, bytes_to_image
|
||||
from ocr import get_words, init_OCR, id_keep_source_lang
|
||||
from logging_config import logger
|
||||
from draw import modify_image_bytes
|
||||
from config import ADD_OVERLAY, SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL, FONT_SIZE, FONT_FILE, FONT_COLOUR
|
||||
from create_overlay import app, overlay
|
||||
from typing import Optional, List
|
||||
###################################################################################
|
||||
from PySide6.QtCore import Qt, QPoint, QRect, QTimer, QThread, Signal
|
||||
from PySide6.QtGui import (QKeySequence, QShortcut, QAction, QPainter, QFont,
|
||||
QColor, QIcon, QImage, QPixmap)
|
||||
from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
|
||||
QLabel, QSystemTrayIcon, QMenu)
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class TranslationThread(QThread):
|
||||
translation_ready = Signal(list, list) # Signal to send translation results
|
||||
start_capture = Signal()
|
||||
end_capture = Signal()
|
||||
screen_capture = Signal(int, int, int, int)
|
||||
def __init__(self, ocr, models, source_lang, target_lang, interval):
|
||||
super().__init__()
|
||||
self.ocr = ocr
|
||||
self.models = models
|
||||
self.source_lang = source_lang
|
||||
self.target_lang = target_lang
|
||||
self.interval = interval
|
||||
self.running = True
|
||||
self.prev_words = set()
|
||||
self.runs = 0
|
||||
|
||||
def run(self):
|
||||
while self.running:
|
||||
self.start_capture.emit()
|
||||
untranslated_image = printsc(REGION)
|
||||
self.end_capture.emit()
|
||||
byte_image = convert_image_to_bytes(untranslated_image)
|
||||
ocr_output = id_keep_source_lang(self.ocr, byte_image, self.source_lang)
|
||||
|
||||
if self.runs == 0:
|
||||
logger.info('Initial run')
|
||||
else:
|
||||
logger.info(f'Run number: {self.runs}.')
|
||||
self.runs += 1
|
||||
|
||||
curr_words = set(get_words(ocr_output))
|
||||
|
||||
if self.prev_words != curr_words:
|
||||
logger.info('Translating')
|
||||
to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
|
||||
translation = translate_API_LLM(to_translate, self.models)
|
||||
logger.info(f'Translation from {to_translate} to\n {translation}')
|
||||
|
||||
# Emit the translation results
|
||||
modify_image_bytes(byte_image, ocr_output, translation)
|
||||
self.translation_ready.emit(ocr_output, translation)
|
||||
|
||||
self.prev_words = curr_words
|
||||
else:
|
||||
logger.info("No new words to translate. Output will not refresh.")
|
||||
|
||||
logger.info(f'Sleeping for {self.interval} seconds')
|
||||
time.sleep(self.interval)
|
||||
|
||||
def stop(self):
|
||||
self.running = False
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Initialize OCR
|
||||
OCR_LANGUAGES = [SOURCE_LANG, TARGET_LANG, 'en']
|
||||
ocr = init_OCR(model=OCR_MODEL, easy_languages=OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
|
||||
|
||||
# Initialize translation
|
||||
models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
|
||||
|
||||
|
||||
# Create and start translation thread
|
||||
translation_thread = TranslationThread(
|
||||
ocr=ocr,
|
||||
models=models,
|
||||
source_lang=SOURCE_LANG,
|
||||
target_lang=TARGET_LANG,
|
||||
interval=INTERVAL
|
||||
)
|
||||
|
||||
# Connect translation results to overlay update
|
||||
translation_thread.start_capture.connect(overlay.prepare_for_capture)
|
||||
translation_thread.end_capture.connect(overlay.restore_after_capture)
|
||||
translation_thread.translation_ready.connect(overlay.update_translation)
|
||||
translation_thread.screen_capture.connect(overlay.capture_behind)
|
||||
# Start the translation thread
|
||||
translation_thread.start()
|
||||
|
||||
|
||||
# Start Qt event loop
|
||||
result = app.exec()
|
||||
|
||||
# Cleanup
|
||||
translation_thread.stop()
|
||||
translation_thread.wait()
|
||||
|
||||
return result
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
262
requirements.txt
Normal file
262
requirements.txt
Normal file
@ -0,0 +1,262 @@
|
||||
absl-py==2.1.0
|
||||
aiohappyeyeballs==2.4.3
|
||||
aiohttp==3.10.10
|
||||
aiosignal==1.3.1
|
||||
albucore==0.0.13
|
||||
albumentations==1.4.10
|
||||
annotated-types==0.7.0
|
||||
anyio==4.6.2.post1
|
||||
argon2-cffi==23.1.0
|
||||
argon2-cffi-bindings==21.2.0
|
||||
arrow==1.3.0
|
||||
astor==0.8.1
|
||||
asttokens==2.4.1
|
||||
astunparse==1.6.3
|
||||
async-lru==2.0.4
|
||||
attrs==24.2.0
|
||||
babel==2.16.0
|
||||
beautifulsoup4==4.12.3
|
||||
bleach==6.2.0
|
||||
blinker==1.8.2
|
||||
cachetools==5.5.0
|
||||
certifi==2024.8.30
|
||||
cffi==1.17.1
|
||||
charset-normalizer==3.4.0
|
||||
click==8.1.7
|
||||
coloredlogs==15.0.1
|
||||
comm==0.2.2
|
||||
contourpy==1.3.0
|
||||
ctranslate2==4.5.0
|
||||
cycler==0.12.1
|
||||
Cython==3.0.11
|
||||
datasets==3.1.0
|
||||
debugpy==1.8.7
|
||||
decorator==5.1.1
|
||||
defusedxml==0.7.1
|
||||
Deprecated==1.2.14
|
||||
dill==0.3.8
|
||||
distro==1.9.0
|
||||
easyocr==1.7.2
|
||||
EasyProcess==1.1
|
||||
entrypoint2==1.1
|
||||
eval_type_backport==0.2.0
|
||||
executing==2.1.0
|
||||
fastjsonschema==2.20.0
|
||||
filelock==3.16.1
|
||||
fire==0.7.0
|
||||
Flask==3.0.3
|
||||
Flask-SSE==1.0.0
|
||||
flatbuffers==24.3.25
|
||||
fonttools==4.54.1
|
||||
fqdn==1.5.1
|
||||
frozenlist==1.5.0
|
||||
fsspec==2024.10.0
|
||||
gast==0.6.0
|
||||
google==3.0.0
|
||||
google-ai-generativelanguage==0.6.10
|
||||
google-api-core==2.22.0
|
||||
google-api-python-client==2.151.0
|
||||
google-auth==2.35.0
|
||||
google-auth-httplib2==0.2.0
|
||||
google-generativeai==0.8.3
|
||||
google-pasta==0.2.0
|
||||
googleapis-common-protos==1.65.0
|
||||
greenlet==3.1.1
|
||||
groq==0.11.0
|
||||
grpcio==1.67.1
|
||||
grpcio-status==1.67.1
|
||||
h11==0.14.0
|
||||
h5py==3.12.1
|
||||
httpcore==1.0.6
|
||||
httplib2==0.22.0
|
||||
httpx==0.27.2
|
||||
huggingface-hub==0.26.2
|
||||
humanfriendly==10.0
|
||||
idna==3.10
|
||||
imageio==2.36.0
|
||||
imgaug==0.4.0
|
||||
ipykernel==6.29.5
|
||||
ipython==8.29.0
|
||||
ipywidgets==8.1.5
|
||||
isoduration==20.11.0
|
||||
itsdangerous==2.2.0
|
||||
jaconv==0.4.0
|
||||
jedi==0.19.1
|
||||
jeepney==0.8.0
|
||||
Jinja2==3.1.4
|
||||
jiter==0.7.0
|
||||
joblib==1.4.2
|
||||
json5==0.9.25
|
||||
jsonpath-python==1.0.6
|
||||
jsonpointer==3.0.0
|
||||
jsonschema==4.23.0
|
||||
jsonschema-specifications==2024.10.1
|
||||
jupyter==1.1.1
|
||||
jupyter-console==6.6.3
|
||||
jupyter-events==0.10.0
|
||||
jupyter-lsp==2.2.5
|
||||
jupyter_client==8.6.3
|
||||
jupyter_core==5.7.2
|
||||
jupyter_server==2.14.2
|
||||
jupyter_server_terminals==0.5.3
|
||||
jupyterlab==4.2.5
|
||||
jupyterlab_pygments==0.3.0
|
||||
jupyterlab_server==2.27.3
|
||||
jupyterlab_widgets==3.0.13
|
||||
keras==3.6.0
|
||||
kiwisolver==1.4.7
|
||||
langid==1.1.6
|
||||
lazy_loader==0.4
|
||||
libclang==18.1.1
|
||||
lmdb==1.5.1
|
||||
lxml==5.3.0
|
||||
Markdown==3.7
|
||||
markdown-it-py==3.0.0
|
||||
MarkupSafe==3.0.2
|
||||
matplotlib==3.9.2
|
||||
matplotlib-inline==0.1.7
|
||||
mdurl==0.1.2
|
||||
mecab-python3==1.0.10
|
||||
mistralai==1.1.0
|
||||
mistune==3.0.2
|
||||
ml-dtypes==0.4.1
|
||||
mpmath==1.3.0
|
||||
mss==9.0.2
|
||||
multidict==6.1.0
|
||||
multiprocess==0.70.16
|
||||
mypy-extensions==1.0.0
|
||||
namex==0.0.8
|
||||
nbclient==0.10.0
|
||||
nbconvert==7.16.4
|
||||
nbformat==5.10.4
|
||||
nest-asyncio==1.6.0
|
||||
networkx==3.4.2
|
||||
ninja==1.11.1.1
|
||||
notebook==7.2.2
|
||||
notebook_shim==0.2.4
|
||||
numpy==1.26.4
|
||||
nvidia-cublas-cu12==12.4.5.8
|
||||
nvidia-cuda-cupti-cu12==12.4.127
|
||||
nvidia-cuda-nvrtc-cu12==12.4.127
|
||||
nvidia-cuda-runtime-cu12==12.4.127
|
||||
nvidia-cudnn-cu12==9.1.0.70
|
||||
nvidia-cufft-cu12==11.2.1.3
|
||||
nvidia-curand-cu12==10.3.5.147
|
||||
nvidia-cusolver-cu12==11.6.1.9
|
||||
nvidia-cusparse-cu12==12.3.1.170
|
||||
nvidia-nccl-cu12==2.21.5
|
||||
nvidia-nvjitlink-cu12==12.4.127
|
||||
nvidia-nvtx-cu12==12.4.127
|
||||
onnxruntime==1.19.2
|
||||
openai==1.53.0
|
||||
opencv-contrib-python==4.10.0.84
|
||||
opt-einsum==3.3.0
|
||||
optimum==1.23.3
|
||||
optree==0.13.0
|
||||
overrides==7.7.0
|
||||
packaging==24.1
|
||||
paddleocr==2.9.1
|
||||
paddlepaddle-gpu==2.6.2
|
||||
pandas==2.2.3
|
||||
pandocfilters==1.5.1
|
||||
parso==0.8.4
|
||||
pexpect==4.9.0
|
||||
pillow==11.0.0
|
||||
pinyin==0.4.0
|
||||
plac==1.4.3
|
||||
platformdirs==4.3.6
|
||||
prometheus_client==0.21.0
|
||||
prompt_toolkit==3.0.48
|
||||
propcache==0.2.0
|
||||
proto-plus==1.25.0
|
||||
protobuf==5.28.3
|
||||
psutil==6.1.0
|
||||
ptyprocess==0.7.0
|
||||
pure_eval==0.2.3
|
||||
pyarrow==18.0.0
|
||||
pyasn1==0.6.1
|
||||
pyasn1_modules==0.4.1
|
||||
pyclipper==1.3.0.post6
|
||||
pycparser==2.22
|
||||
pydantic==2.9.2
|
||||
pydantic_core==2.23.4
|
||||
pydotenv==0.0.7
|
||||
Pygments==2.18.0
|
||||
pykakasi==2.3.0
|
||||
pyparsing==3.2.0
|
||||
pypinyin==0.53.0
|
||||
pyscreenshot==3.1
|
||||
PySide6==6.8.0.2
|
||||
PySide6_Addons==6.8.0.2
|
||||
PySide6_Essentials==6.8.0.2
|
||||
python-bidi==0.6.3
|
||||
python-dateutil==2.8.2
|
||||
python-docx==1.1.2
|
||||
python-dotenv==1.0.1
|
||||
python-json-logger==2.0.7
|
||||
pytz==2024.2
|
||||
PyYAML==6.0.2
|
||||
pyzmq==26.2.0
|
||||
RapidFuzz==3.10.1
|
||||
rapidocr-onnxruntime==1.3.25
|
||||
redis==5.2.0
|
||||
referencing==0.35.1
|
||||
regex==2024.9.11
|
||||
requests==2.32.3
|
||||
rfc3339-validator==0.1.4
|
||||
rfc3986-validator==0.1.1
|
||||
rich==13.9.3
|
||||
rpds-py==0.20.0
|
||||
rsa==4.9
|
||||
sacremoses==0.1.1
|
||||
safetensors==0.4.5
|
||||
scikit-image==0.24.0
|
||||
scikit-learn==1.5.2
|
||||
scipy==1.14.1
|
||||
Send2Trash==1.8.3
|
||||
sentencepiece==0.2.0
|
||||
setuptools==75.3.0
|
||||
shapely==2.0.6
|
||||
shiboken6==6.8.0.2
|
||||
six==1.16.0
|
||||
sniffio==1.3.1
|
||||
soupsieve==2.6
|
||||
SQLAlchemy==2.0.36
|
||||
stack-data==0.6.3
|
||||
sympy==1.13.1
|
||||
tensorboard==2.18.0
|
||||
tensorboard-data-server==0.7.2
|
||||
termcolor==2.5.0
|
||||
terminado==0.18.1
|
||||
threadpoolctl==3.5.0
|
||||
tifffile==2024.9.20
|
||||
tinycss2==1.4.0
|
||||
tokenizers==0.20.1
|
||||
tomli==2.0.2
|
||||
torch==2.5.1
|
||||
torchvision==0.20.1
|
||||
tornado==6.4.1
|
||||
tqdm==4.66.6
|
||||
traitlets==5.14.3
|
||||
transformers==4.46.1
|
||||
triton==3.1.0
|
||||
types-python-dateutil==2.9.0.20241003
|
||||
typing-inspect==0.9.0
|
||||
typing_extensions==4.12.2
|
||||
tzdata==2024.2
|
||||
unidic==1.1.0
|
||||
uri-template==1.3.0
|
||||
uritemplate==4.1.1
|
||||
urllib3==2.2.3
|
||||
uroman==1.3.1.1
|
||||
wasabi==0.10.1
|
||||
wcwidth==0.2.13
|
||||
webcolors==24.8.0
|
||||
webencodings==0.5.1
|
||||
websocket-client==1.8.0
|
||||
Werkzeug==3.0.6
|
||||
wheel==0.44.0
|
||||
widgetsnbextension==4.0.13
|
||||
wrapt==1.16.0
|
||||
xxhash==3.5.0
|
||||
yarl==1.17.1
|
||||
@ -17,7 +17,7 @@
|
||||
setInterval(function () {
|
||||
document.getElementById("live-image").src =
|
||||
"/image?" + new Date().getTime();
|
||||
}, 3500); // Update every 2 seconds
|
||||
}, 1500); // Update every 2.5 seconds. Beware that if the image fails to reload on time, the browser will continuously refresh without being able to display the images.
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
35
web.py
35
web.py
@ -1,35 +0,0 @@
|
||||
from flask import Flask, Response, render_template
|
||||
import threading
|
||||
import io
|
||||
import app
|
||||
app = Flask(__name__)
|
||||
|
||||
# Global variable to hold the current image
|
||||
def curr_image():
|
||||
return app.latest_image
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
return render_template('index.html')
|
||||
|
||||
|
||||
@app.route('/image')
|
||||
def stream_image():
|
||||
if curr_image() is None:
|
||||
return "No image generated yet.", 503
|
||||
file_object = io.BytesIO()
|
||||
curr_image().save(file_object, 'PNG')
|
||||
file_object.seek(0)
|
||||
response = Response(file_object.getvalue(), mimetype='image/png')
|
||||
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate' # HTTP 1.1
|
||||
response.headers['Pragma'] = 'no-cache' # HTTP 1.0
|
||||
response.headers['Expires'] = '0' # Proxies
|
||||
|
||||
return response
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Start the image updating thread
|
||||
threading.Thread(target=app.main, daemon=True).start()
|
||||
|
||||
# Start the Flask web server
|
||||
app.run(host='0.0.0.0', port=5000, debug=True)
|
||||
113
web_app.py
Normal file
113
web_app.py
Normal file
@ -0,0 +1,113 @@
|
||||
from flask import Flask, Response, render_template
|
||||
import threading
|
||||
import io
|
||||
import os, time, sys, threading, subprocess
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'helpers'))
|
||||
|
||||
from translation import translate_Seq_LLM, translate_API_LLM, init_API_LLM, init_Seq_LLM
|
||||
from utils import printsc, convert_image_to_bytes, bytes_to_image, check_similarity, is_wayland
|
||||
from ocr import get_words, init_OCR, id_keep_source_lang
|
||||
from data import Base, engine, create_tables
|
||||
from draw import modify_image
|
||||
import asyncio
|
||||
from config import SOURCE_LANG, TARGET_LANG, OCR_MODEL, OCR_USE_GPU, LOCAL_FILES_ONLY, REGION, INTERVAL, MAX_TRANSLATE, TRANSLATION_MODEL, IMAGE_CHANGE_THRESHOLD, TEMP_IMG_PATH
|
||||
from logging_config import logger
|
||||
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
latest_image = None
|
||||
|
||||
async def web_app_main():
|
||||
###################################################################################
|
||||
global latest_image
|
||||
# Initialisation
|
||||
##### Create the database if not present #####
|
||||
create_tables()
|
||||
|
||||
##### Initialize the OCR #####
|
||||
OCR_LANGUAGES = [SOURCE_LANG, 'en']
|
||||
ocr = init_OCR(model=OCR_MODEL, paddle_lang= SOURCE_LANG, easy_languages = OCR_LANGUAGES, use_GPU=OCR_USE_GPU)
|
||||
|
||||
##### Initialize the translation #####
|
||||
# model, tokenizer = init_Seq_LLM(TRANSLATION_MODEL, from_lang =SOURCE_LANG , target_lang = TARGET_LANG)
|
||||
models = init_API_LLM(SOURCE_LANG, TARGET_LANG)
|
||||
###################################################################################
|
||||
|
||||
runs = 0
|
||||
|
||||
# label, app = view_buffer_app.create_viewer()
|
||||
|
||||
# try:
|
||||
while True:
|
||||
logger.debug("Capturing screen")
|
||||
printsc(REGION, TEMP_IMG_PATH)
|
||||
logger.debug(f"Screen Captured. Proceeding to perform OCR.")
|
||||
ocr_output = id_keep_source_lang(ocr, TEMP_IMG_PATH, SOURCE_LANG) # keep only phrases containing the source language
|
||||
logger.debug(f"OCR completed. Detected {len(ocr_output)} phrases.")
|
||||
if runs == 0:
|
||||
logger.info('Initial run')
|
||||
prev_words = set()
|
||||
else:
|
||||
logger.debug(f'Run number: {runs}.')
|
||||
runs += 1
|
||||
|
||||
curr_words = set(get_words(ocr_output))
|
||||
logger.debug(f'Current words: {curr_words} Previous words: {prev_words}')
|
||||
|
||||
### If the OCR detects different words, translate screen -> to ensure that the screen is not refreshing constantly and to save GPU power
|
||||
if prev_words != curr_words and not check_similarity(list(curr_words),list(prev_words), threshold = IMAGE_CHANGE_THRESHOLD, method="tfidf"):
|
||||
logger.info('Beginning Translation')
|
||||
|
||||
to_translate = [entry[1] for entry in ocr_output][:MAX_TRANSLATE]
|
||||
# translation = translate_Seq_LLM(to_translate, model_type = TRANSLATION_MODEL, model = model, tokenizer = tokenizer, from_lang = SOURCE_LANG, target_lang = TARGET_LANG)
|
||||
try:
|
||||
if len(to_translate) == 0:
|
||||
logger.info("No text detected. Skipping translation. Continuing to next iteration.")
|
||||
continue
|
||||
translation = await translate_API_LLM(to_translate, models, call_size = 3)
|
||||
except TypeError as e:
|
||||
logger.error(f"Failed to translate using API models. Error: {e}. Sleeping for {2*INTERVAL} seconds.")
|
||||
time.sleep(2*INTERVAL)
|
||||
continue
|
||||
logger.debug('Translation complete. Modifying image.')
|
||||
translated_image = modify_image(TEMP_IMG_PATH, ocr_output, translation)
|
||||
latest_image = bytes_to_image(translated_image)
|
||||
logger.debug("Image modified. Saving image.")
|
||||
prev_words = curr_words
|
||||
else:
|
||||
logger.info(f"Skipping translation. No significant change in the screen detected. Total translation attempts so far: {runs}.")
|
||||
logger.debug("Continuing to next iteration.")
|
||||
time.sleep(INTERVAL)
|
||||
|
||||
# Global variable to hold the current image
|
||||
def curr_image():
|
||||
return latest_image
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
return render_template('index.html')
|
||||
|
||||
|
||||
@app.route('/image')
|
||||
def stream_image():
|
||||
if curr_image() is None:
|
||||
return "No image generated yet.", 503
|
||||
file_object = io.BytesIO()
|
||||
curr_image().save(file_object, 'PNG')
|
||||
file_object.seek(0)
|
||||
response = Response(file_object.getvalue(), mimetype='image/png')
|
||||
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate' # HTTP 1.1
|
||||
response.headers['Pragma'] = 'no-cache' # HTTP 1.0
|
||||
response.headers['Expires'] = '0' # Proxies
|
||||
|
||||
return response
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Start the image updating thread
|
||||
|
||||
threading.Thread(target=asyncio.run, args=(web_app_main(),), daemon=True).start()
|
||||
|
||||
# Start the Flask web server
|
||||
app.run(host='0.0.0.0', port=5000, debug=False)
|
||||
Loading…
Reference in New Issue
Block a user