onscreen-translator/helpers/ocr.py
2024-11-01 15:44:12 +11:00

88 lines
2.8 KiB
Python

from paddleocr import PaddleOCR
import easyocr
from rapidocr_onnxruntime import RapidOCR
import langid
from helpers.utils import contains_lang
from concurrent.futures import ThreadPoolExecutor
# PaddleOCR
# PaddleOCR supports Chinese, English, French, German, Korean and Japanese.
# Select the language model with the `lang` parameter: `ch`, `en`, `fr`,
# `german`, `korean` or `japan` respectively.
def _paddle_init(lang='ch', use_angle_cls=False, use_GPU=True):
    """Create a PaddleOCR engine.

    Creating the engine downloads the model (on first use) and loads it into
    memory, so call this only once and reuse the returned engine.

    Args:
        lang: PaddleOCR language code ('ch', 'en', 'fr', 'german', 'korean',
            'japan').
        use_angle_cls: load PaddleOCR's text-angle classifier.
        use_GPU: run inference on the GPU.

    Returns:
        A configured ``PaddleOCR`` instance.
    """
    # PaddleOCR's own keyword is lowercase ``use_gpu``. The mixed-case
    # ``use_GPU`` is not a PaddleOCR option, so it never reached the GPU
    # setting and use_GPU=False had no effect.
    return PaddleOCR(use_angle_cls=use_angle_cls, lang=lang, use_gpu=use_GPU)
def _paddle_ocr(ocr, image, cls=False) -> list:
    """Run a PaddleOCR engine on *image* and flatten its detections.

    Args:
        ocr: engine created by ``_paddle_init``.
        image: any image input accepted by ``PaddleOCR.ocr``.
        cls: run the angle classifier. This is only useful when the engine was
            created with ``use_angle_cls=True``. Defaults to False, matching
            the previous hard-coded behaviour.

    Returns:
        A list of ``(box, text, confidence)`` tuples, or ``[]`` when the
        engine reports no text for the image.
    """
    # ocr() yields one entry per page; a single image is page 0, which may be
    # None rather than a list when nothing was detected.
    page = ocr.ocr(image, cls=cls)[0]
    if not isinstance(page, list):
        return []
    # Each detection is [box, (text, confidence)]; flatten to a 3-tuple.
    return [(box, rec[0], rec[1]) for box, rec in page]
# EasyOCR has support for many languages
def _easy_init(ocr_languages: list, use_GPU=True):
    """Build an EasyOCR reader for the given languages.

    Args:
        ocr_languages: EasyOCR language codes to load.
        use_GPU: forwarded to EasyOCR as its ``gpu`` flag.

    Returns:
        An ``easyocr.Reader`` instance.
    """
    reader = easyocr.Reader(lang_list=ocr_languages, gpu=use_GPU)
    return reader
def _easy_ocr(ocr, image) -> list:
    """Run an EasyOCR reader on *image*.

    Args:
        ocr: reader created by ``_easy_init``.
        image: image input accepted by ``Reader.readtext``.

    Returns:
        The output of ``ocr.readtext(image)``, unchanged.
    """
    detections = ocr.readtext(image)
    return detections
# RapidOCR mostly for mandarin and some other asian languages
def _rapid_init(use_GPU=True):
    """Create a RapidOCR engine (mainly for Mandarin and other Asian languages).

    Args:
        use_GPU: forwarded to RapidOCR as ``use_gpu``.

    Returns:
        A ``RapidOCR`` instance.
    """
    # NOTE(review): confirm that the installed rapidocr_onnxruntime honours a
    # ``use_gpu`` keyword; its onnxruntime backend may instead expect
    # per-stage flags (e.g. ``det_use_cuda``).
    engine = RapidOCR(use_gpu=use_GPU)
    return engine
def _rapid_ocr(ocr, image) -> list:
    """Run a RapidOCR engine on *image* and return its detections.

    Args:
        ocr: engine created by ``_rapid_init``.
        image: image input accepted by the RapidOCR engine.

    Returns:
        A list of ``(box, text, score)`` tuples, or ``[]`` when nothing was
        detected.
    """
    # RapidOCR returns ``(result, elapse)``, not the detections themselves.
    # Previously the whole tuple was returned, so callers indexing entry[1]
    # read the timing data. ``result`` is None when no text is found.
    result, _elapse = ocr(image)
    if not result:
        return []
    # Each detection is [box, text, score]; normalise to tuples.
    return [(entry[0], entry[1], entry[2]) for entry in result]
### Initialize the OCR model
def init_OCR(model='paddle', **kwargs):
    """Create an OCR engine of the requested kind.

    Args:
        model: backend name, one of 'paddle', 'easy' or 'rapid'.
        **kwargs: forwarded unchanged to the backend's initialiser
            (``_paddle_init``, ``_easy_init`` or ``_rapid_init``).

    Returns:
        The backend engine, ready to be passed to ``identify``.

    Raises:
        ValueError: if *model* is not a supported backend name.
    """
    if model == 'paddle':
        return _paddle_init(**kwargs)
    if model == 'easy':
        return _easy_init(**kwargs)
    if model == 'rapid':
        return _rapid_init(**kwargs)
    # Previously fell through and returned None, deferring the failure to
    # identify() with a less helpful message.
    raise ValueError(
        f"Unknown OCR model {model!r}; expected 'paddle', 'easy' or 'rapid'."
    )
### Perform OCR on the image
def identify(ocr, image) -> list:
    """Run *ocr* on *image*, dispatching on the engine's type.

    Args:
        ocr: an engine created by ``init_OCR``.
        image: image input accepted by the chosen backend.

    Returns:
        The detections returned by ``_paddle_ocr``, ``_easy_ocr`` or
        ``_rapid_ocr`` for the matching engine type.

    Raises:
        ValueError: if *ocr* is not a PaddleOCR, EasyOCR or RapidOCR engine.
    """
    if isinstance(ocr, PaddleOCR):
        return _paddle_ocr(ocr, image)
    if isinstance(ocr, easyocr.Reader):
        return _easy_ocr(ocr, image)
    if isinstance(ocr, RapidOCR):
        return _rapid_ocr(ocr, image)
    # The initialiser is init_OCR(); the old message pointed at a
    # non-existent init().
    raise ValueError(
        "Invalid OCR model. Please initialise the OCR model first with "
        "init_OCR() and pass it as an argument to identify()."
    )
### Filter out the results that are not in the source language
def id_filtered(ocr, image, lang) -> list:
    """Run OCR and keep only detections that langid labels as *lang*.

    Args:
        ocr: an engine created by ``init_OCR``.
        image: image to read.
        lang: language code compared with the label from ``langid.classify``.

    Returns:
        The matching detections, in their original order.
    """
    detections = identify(ocr, image)

    def _keep_if_lang(detection):
        # Only the first element of langid.classify's result (the language
        # label) is compared.
        label = langid.classify(detection[1])[0]
        return detection if label == lang else None

    # Parallelised because langid is slow.
    # NOTE(review): if langid classification is CPU-bound Python, threads may
    # give little speed-up under the GIL; worth measuring.
    with ThreadPoolExecutor() as pool:
        kept = [d for d in pool.map(_keep_if_lang, detections) if d]
    return kept
# zh, ja, ko
def id_lang(ocr, image, lang) -> list:
    """Run OCR and keep detections whose text passes ``contains_lang``.

    Args:
        ocr: an engine created by ``init_OCR``.
        image: image to read.
        lang: language code handed to ``contains_lang`` (the original notes
            list zh, ja and ko).

    Returns:
        The detections from ``identify``, in order, for which
        ``contains_lang(text, lang)`` is truthy.
    """
    return list(filter(lambda det: contains_lang(det[1], lang),
                       identify(ocr, image)))
def _column(ocr_output, index) -> list:
    """Return element *index* of every detection in *ocr_output*, in order."""
    return list(map(lambda detection: detection[index], ocr_output))


def get_words(ocr_output) -> list:
    """Return the recognised text of each detection (element 1)."""
    return _column(ocr_output, 1)


def get_positions(ocr_output) -> list:
    """Return the bounding box of each detection (element 0)."""
    return _column(ocr_output, 0)


def get_confidences(ocr_output) -> list:
    """Return the confidence score of each detection (element 2)."""
    return _column(ocr_output, 2)