File size: 2,989 Bytes
85d9fef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import hashlib
import logging
from pathlib import Path
from typing import Callable, Optional
from urllib.request import urlretrieve
from zipfile import ZipFile

from tqdm.auto import tqdm

# Name of the detector weights file, both inside the downloaded archive and
# in the model storage directory.
FILENAME = "craft_mlt_25k.pth"
# Location of the zip archive that contains FILENAME.
URL = (
    "https://github.com/JaidedAI/EasyOCR/releases/download/pre-v1.1.6/craft_mlt_25k.zip"
)
# Expected MD5 hex digest of the extracted weights file.
MD5SUM = "2f8227d2def4037cdb3b34389dcf9ec1"
# Message logged (existing file) or raised (fresh download) on a digest mismatch.
MD5MSG = "MD5 hash mismatch, possible file corruption"


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def calculate_md5(path: Path) -> str:
    """Return the hexadecimal MD5 digest of the file at *path*.

    The file is opened in binary mode and consumed in 4096-byte reads, so
    its contents are never held in memory all at once.
    """
    digest = hashlib.md5()
    with open(path, "rb") as stream:
        while True:
            block = stream.read(4096)
            if not block:
                # An empty read signals end of file.
                break
            digest.update(block)
    return digest.hexdigest()


def print_progress_bar(t: "tqdm") -> Callable[[int, int, Optional[int]], None]:
    """Build a ``urlretrieve`` report hook that drives the progress bar *t*.

    Args:
        t: Progress bar object; its ``total`` attribute is assigned and its
            ``update`` method is called with the number of new bytes.

    Returns:
        A callable ``update_to(count, block_size, total_size)`` suitable for
        the ``reporthook`` argument of ``urllib.request.urlretrieve``.
    """
    reported = 0  # bytes already pushed to the bar

    def update_to(
        count: int = 1, block_size: int = 1, total_size: Optional[int] = None
    ) -> None:
        nonlocal reported
        # urlretrieve passes -1 when the server does not announce a size;
        # treat that the same as "unknown" instead of storing it as the total.
        size_known = total_size is not None and total_size >= 0
        if size_known:
            t.total = total_size
        received = count * block_size
        if size_known:
            # The final block is usually only partly filled, so
            # count * block_size can exceed the real size; never report
            # more than the announced total.
            received = min(received, total_size)
        t.update(received - reported)
        reported = received

    return update_to


def download_and_unzip(
    url: str, filename: str, model_storage_directory: Path, verbose: bool = True
) -> None:
    """Download a zip archive and extract one member into a directory.

    The archive is saved as ``temp.zip`` inside *model_storage_directory*,
    the member *filename* is extracted next to it, and the archive is then
    removed.

    Args:
        url: Address of the zip archive to fetch.
        filename: Name of the archive member to extract.
        model_storage_directory: Directory that receives both the temporary
            archive and the extracted file.
        verbose: When false, the progress bar is disabled.

    Raises:
        Any exception raised by the download or by the extraction is
        propagated after the temporary archive has been cleaned up.
    """
    zip_path = model_storage_directory / "temp.zip"
    try:
        with tqdm(
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            miniters=1,
            disable=not verbose,
        ) as t:
            urlretrieve(url, str(zip_path), reporthook=print_progress_bar(t))
        with ZipFile(zip_path, "r") as archive:
            archive.extract(filename, str(model_storage_directory))
    finally:
        # Remove the temporary archive even when the download or extraction
        # fails, so a partial or corrupt temp.zip is never left behind.
        if zip_path.exists():
            zip_path.unlink()


def prepare_model(
    model_storage_directory: Path, download: bool = True, verbose: bool = True
) -> Path:
    """Ensure the detector weights exist locally and return their path.

    The storage directory is created if needed. An existing weights file is
    accepted only when its MD5 digest equals ``MD5SUM``; a missing or
    mismatching file is (re-)downloaded when *download* is true.

    Args:
        model_storage_directory: Directory where the weights file lives.
        download: Whether fetching the weights from ``URL`` is permitted.
        verbose: Forwarded to ``download_and_unzip`` to control the
            progress bar.

    Returns:
        Path of the verified weights file.

    Raises:
        FileNotFoundError: The file is missing or fails the MD5 check and
            *download* is false.
        ValueError: The freshly downloaded file fails the MD5 check.
    """
    model_storage_directory.mkdir(parents=True, exist_ok=True)

    detector_path = model_storage_directory / FILENAME

    if not detector_path.is_file():
        if not download:
            raise FileNotFoundError(f"Missing {detector_path} and downloads disabled")
        logger.info(
            "Downloading detection model, please wait. "
            "This may take several minutes depending upon your network connection."
        )
    elif calculate_md5(detector_path) != MD5SUM:
        logger.warning(MD5MSG)
        if not download:
            raise FileNotFoundError(
                f"MD5 mismatch for {detector_path} and downloads disabled"
            )
        # Drop the corrupt copy before fetching a replacement.
        detector_path.unlink()
        logger.info(
            "Re-downloading the detection model, please wait. "
            "This may take several minutes depending upon your network connection."
        )
    else:
        # File is present and its digest matches: nothing to download.
        return detector_path

    download_and_unzip(URL, FILENAME, model_storage_directory, verbose)
    if calculate_md5(detector_path) != MD5SUM:
        raise ValueError(MD5MSG)
    logger.info("Download complete")

    return detector_path