OFA-OCR / ezocr /easyocrlite /utils /download_utils.py
JustinLin610's picture
add easyocr
85d9fef
raw
history blame
No virus
2.99 kB
import hashlib
import logging
from pathlib import Path
from typing import Callable, Optional
from urllib.request import urlretrieve
from zipfile import ZipFile
from tqdm.auto import tqdm
# Module-level logger; this module configures no handlers or levels itself.
logger = logging.getLogger(__name__)

# Name of the weights file that is extracted from the downloaded archive.
FILENAME = "craft_mlt_25k.pth"
# Location of the zip archive that contains FILENAME.
URL = (
    "https://github.com/JaidedAI/EasyOCR/releases/download/pre-v1.1.6/craft_mlt_25k.zip"
)
# Expected MD5 hex digest of FILENAME; compared against calculate_md5() output.
MD5SUM = "2f8227d2def4037cdb3b34389dcf9ec1"
# Message logged / raised when the digest of FILENAME does not equal MD5SUM.
MD5MSG = "MD5 hash mismatch, possible file corruption"
def calculate_md5(path: Path) -> str:
    """Return the hexadecimal MD5 digest of the file at ``path``.

    The file is opened in binary mode and consumed in 4096-byte pieces, so
    its whole content is never held in memory at once.

    Args:
        path: Location of the file to hash.

    Returns:
        The MD5 digest as a lowercase hexadecimal string.
    """
    digest = hashlib.md5()
    with open(path, "rb") as stream:
        while True:
            block = stream.read(4096)
            if not block:
                # An empty read signals end of file.
                break
            digest.update(block)
    return digest.hexdigest()
def print_progress_bar(t: tqdm) -> Callable[[int, int, Optional[int]], None]:
    """Build a ``urlretrieve`` report hook that drives the progress bar ``t``.

    Args:
        t: Progress bar whose ``total`` attribute is set and whose
            ``update`` method is called as blocks arrive.

    Returns:
        A callable ``update_to(count, block_size, total_size)`` suitable for
        the ``reporthook`` argument of ``urllib.request.urlretrieve``.
    """
    # Block count seen on the previous call; lets each call report a delta.
    last = 0

    def update_to(
        count: int = 1, block_size: int = 1, total_size: Optional[int] = None
    ) -> None:
        """Advance the bar by the blocks transferred since the last call.

        Args:
            count: Number of blocks transferred so far.
            block_size: Size of one block in bytes.
            total_size: Total size in bytes; ``None`` or ``-1`` when unknown.
        """
        nonlocal last
        # urlretrieve reports an unknown size as -1 rather than None; writing
        # that into the bar would give it a negative total, so skip it too.
        if total_size not in (None, -1):
            t.total = total_size
        t.update((count - last) * block_size)
        last = count

    return update_to
def download_and_unzip(
    url: str, filename: str, model_storage_directory: Path, verbose: bool = True
) -> None:
    """Download the zip archive at ``url`` and extract ``filename`` from it.

    The archive is stored as ``temp.zip`` inside ``model_storage_directory``,
    the member ``filename`` is extracted into that same directory, and the
    archive is then deleted. The archive is also deleted when the download
    or the extraction fails, so a partial ``temp.zip`` is never left behind.

    Args:
        url: Address of the zip archive to fetch.
        filename: Name of the archive member to extract.
        model_storage_directory: Directory that receives both the temporary
            archive and the extracted file.
        verbose: When false, the progress bar is disabled.

    Raises:
        Any exception raised by ``urlretrieve`` or ``ZipFile`` propagates
        unchanged after the temporary archive has been cleaned up.
    """
    zip_path = model_storage_directory / "temp.zip"
    try:
        with tqdm(
            unit="B", unit_scale=True, unit_divisor=1024, miniters=1, disable=not verbose
        ) as t:
            reporthook = print_progress_bar(t)
            urlretrieve(url, str(zip_path), reporthook=reporthook)
        with ZipFile(zip_path, "r") as zipObj:
            zipObj.extract(filename, str(model_storage_directory))
    finally:
        # Clean up on success and on failure alike; the archive may not exist
        # at all if the download failed before the file was created.
        if zip_path.exists():
            zip_path.unlink()
def prepare_model(
    model_storage_directory: Path, download: bool = True, verbose: bool = True
) -> Path:
    """Make sure the detection model file is on disk and return its path.

    The directory is created if needed. If ``FILENAME`` is already present
    there and its MD5 digest equals ``MD5SUM``, its path is returned right
    away. Otherwise the file is downloaded (after deleting a copy whose
    digest does not match) and verified again.

    Args:
        model_storage_directory: Directory that holds, or will hold, the
            model file.
        download: Whether a missing or corrupt file may be downloaded.
        verbose: Forwarded to ``download_and_unzip`` to control the
            progress bar.

    Returns:
        Path of the verified model file inside ``model_storage_directory``.

    Raises:
        FileNotFoundError: The file is missing, or its digest does not
            match, and ``download`` is false.
        ValueError: The freshly downloaded file fails the MD5 check.
    """
    model_storage_directory.mkdir(parents=True, exist_ok=True)
    detector_path = model_storage_directory / FILENAME

    if detector_path.is_file():
        if calculate_md5(detector_path) == MD5SUM:
            # Existing copy is intact; nothing to download.
            return detector_path
        logger.warning(MD5MSG)
        if not download:
            raise FileNotFoundError(
                f"MD5 mismatch for {detector_path} and downloads disabled"
            )
        # Drop the corrupt copy before fetching a replacement.
        detector_path.unlink()
        logger.info(
            "Re-downloading the detection model, please wait. "
            "This may take several minutes depending upon your network connection."
        )
    else:
        if not download:
            raise FileNotFoundError(f"Missing {detector_path} and downloads disabled")
        logger.info(
            "Downloading detection model, please wait. "
            "This may take several minutes depending upon your network connection."
        )

    download_and_unzip(URL, FILENAME, model_storage_directory, verbose)
    # Verify the download itself so a truncated or tampered file is rejected.
    if calculate_md5(detector_path) != MD5SUM:
        raise ValueError(MD5MSG)
    logger.info("Download complete")
    return detector_path