88 lines
2.8 KiB
Python
88 lines
2.8 KiB
Python
from paddleocr import PaddleOCR
|
|
import easyocr
|
|
from rapidocr_onnxruntime import RapidOCR
|
|
import langid
|
|
from helpers.utils import contains_lang
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
# PaddleOCR
|
|
# PaddleOCR supports Chinese, English, French, German, Korean and Japanese.
|
|
# You can set the parameter `lang` as `ch`, `en`, `fr`, `german`, `korean`, `japan`
|
|
# to switch the language model in order.
|
|
# Initialisation needs to run only once: it downloads the model and loads it into memory.
|
|
|
|
def _paddle_init(lang='ch', use_angle_cls=False, use_GPU=True):
    """Create a PaddleOCR engine.

    Args:
        lang: PaddleOCR language code selecting the model to load.
        use_angle_cls: Whether to load the text-angle classifier.
        use_GPU: Whether inference should run on the GPU.

    Returns:
        The constructed ``PaddleOCR`` instance.
    """
    # PaddleOCR spells this option ``use_gpu`` (lowercase); the wrapper keeps
    # ``use_GPU`` as its own parameter name so existing callers are unaffected.
    return PaddleOCR(use_angle_cls=use_angle_cls, lang=lang, use_gpu=use_GPU)
|
|
|
|
|
|
def _paddle_ocr(ocr, image) -> list:
    """Run PaddleOCR on ``image`` and flatten its output.

    Only the first element of the engine output is used, and the angle
    classifier is disabled for the call (``cls=False``).

    Returns:
        A list of ``(bounding box, text, confidence)`` tuples, or an empty
        list when the first element of the engine output is not a list.
    """
    first_page = ocr.ocr(image, cls=False)[0]
    if not isinstance(first_page, list):
        return []

    flattened = []
    for box, recognised in first_page:
        # ``recognised`` is indexed as (text, confidence).
        flattened.append((box, recognised[0], recognised[1]))
    return flattened
|
|
|
|
# EasyOCR has support for many languages
|
|
|
|
def _easy_init(ocr_languages: list, use_GPU=True):
    """Build an EasyOCR reader.

    Args:
        ocr_languages: Language codes handed to ``easyocr.Reader``.
        use_GPU: Forwarded to the reader as its ``gpu`` argument.

    Returns:
        The constructed ``easyocr.Reader``.
    """
    reader = easyocr.Reader(ocr_languages, gpu=use_GPU)
    return reader
|
|
|
|
def _easy_ocr(ocr, image) -> list:
    """Run the EasyOCR reader on ``image`` and return its ``readtext`` output unchanged."""
    detections = ocr.readtext(image)
    return detections
|
|
|
|
# RapidOCR is mostly for Mandarin and some other Asian languages
|
|
|
|
def _rapid_init(use_GPU=True):
    """Build a RapidOCR engine, forwarding ``use_GPU`` as the ``use_gpu`` keyword.

    Returns:
        The constructed ``RapidOCR`` instance.
    """
    # NOTE(review): confirm that ``use_gpu`` is an option RapidOCR recognises;
    # verify against the installed rapidocr_onnxruntime version.
    engine = RapidOCR(use_gpu=use_GPU)
    return engine
|
|
|
|
def _rapid_ocr(ocr, image) -> list:
    """Run RapidOCR on ``image`` and return only the detections.

    Args:
        ocr: A callable RapidOCR engine.
        image: The image to recognise, passed straight to the engine.

    Returns:
        A list of detections, each indexable as (box, text, score); an empty
        list when the engine reports nothing.
    """
    output = ocr(image)
    # rapidocr_onnxruntime returns ``(result, elapse)``; keep only the
    # detections so the output lines up with the other engines.
    detections = output[0] if isinstance(output, tuple) else output
    # The engine reports "no text found" as None rather than an empty list.
    if not detections:
        return []
    return list(detections)
|
|
|
|
### Initialize the OCR model
|
|
def init_OCR(model='paddle', **kwargs):
    """Create the OCR engine selected by ``model``.

    Args:
        model: One of ``'paddle'``, ``'easy'`` or ``'rapid'``.
        **kwargs: Forwarded unchanged to the matching initialiser.

    Returns:
        The engine built by the matching initialiser, or ``None`` when
        ``model`` is not one of the recognised names.
    """
    engine = None
    if model == 'paddle':
        engine = _paddle_init(**kwargs)
    elif model == 'easy':
        engine = _easy_init(**kwargs)
    elif model == 'rapid':
        engine = _rapid_init(**kwargs)
    return engine
|
|
|
|
### Perform OCR on the image
|
|
def identify(ocr, image) -> list:
    """Run OCR on ``image`` with whichever supported engine ``ocr`` is.

    Args:
        ocr: A ``PaddleOCR``, ``easyocr.Reader`` or ``RapidOCR`` instance.
        image: The image to recognise, passed straight to the engine wrapper.

    Returns:
        The list produced by the matching engine wrapper.

    Raises:
        ValueError: If ``ocr`` is not an instance of a supported engine.
    """
    if isinstance(ocr, PaddleOCR):
        return _paddle_ocr(ocr, image)
    if isinstance(ocr, easyocr.Reader):
        return _easy_ocr(ocr, image)
    if isinstance(ocr, RapidOCR):
        return _rapid_ocr(ocr, image)
    # The message names init_OCR(), the initialiser this module defines.
    raise ValueError(
        "Invalid OCR model. Please initialise the OCR model first with "
        "init_OCR() and pass it as an argument to identify()."
    )
|
|
|
|
|
|
### Filter out the results that are not in the source language
|
|
def id_filtered(ocr, image, lang) -> list:
    """Run OCR and keep only the entries whose text langid classifies as ``lang``.

    Args:
        ocr: An initialised OCR engine accepted by ``identify``.
        image: The image to recognise.
        lang: Language code compared against langid's predicted label.

    Returns:
        The OCR entries whose text (``entry[1]``) was classified as ``lang``,
        in their original order.
    """
    detections = identify(ocr, image)

    def is_source_language(entry):
        # langid.classify returns the predicted label first.
        return langid.classify(entry[1])[0] == lang

    # Classification is fanned out over a thread pool because langid is slow.
    with ThreadPoolExecutor() as pool:
        keep_flags = list(pool.map(is_source_language, detections))

    return [entry for entry, keep in zip(detections, keep_flags) if keep]
|
|
|
|
|
|
# zh, ja, ko
|
|
def id_lang(ocr, image, lang) -> list:
    """Run OCR and keep only the entries whose text passes ``contains_lang``.

    Args:
        ocr: An initialised OCR engine accepted by ``identify``.
        image: The image to recognise.
        lang: Language code handed to ``contains_lang`` with each entry's text.

    Returns:
        The OCR entries for which ``contains_lang(entry[1], lang)`` is truthy,
        in their original order.
    """
    matching = []
    for entry in identify(ocr, image):
        if contains_lang(entry[1], lang):
            matching.append(entry)
    return matching
|
|
|
|
def get_words(ocr_output) -> list:
    """Return the second field (the text) of every OCR entry, in order."""
    words = []
    for entry in ocr_output:
        words.append(entry[1])
    return words
|
|
|
|
def get_positions(ocr_output) -> list:
    """Return the first field (the position) of every OCR entry, in order."""
    positions = []
    for entry in ocr_output:
        positions.append(entry[0])
    return positions
|
|
|
|
def get_confidences(ocr_output) -> list:
    """Return the third field (the confidence) of every OCR entry, in order."""
    confidences = []
    for entry in ocr_output:
        confidences.append(entry[2])
    return confidences
|