File size: 2,222 Bytes
2e237ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import math
from os import makedirs
from os.path import join, exists
from pathlib import Path
from urllib.request import urlretrieve
from huggingface_hub import snapshot_download, hf_hub_download

from configuration import service_logger, MODELS_PATH


def download_progress(count, block_size, total_size):
    """Log download progress in roughly 20% steps.

    Used as the ``reporthook`` callback of ``urlretrieve``.

    Args:
        count: Number of blocks transferred so far.
        block_size: Size of one block in bytes.
        total_size: Total size of the file in bytes. ``urlretrieve`` may
            report a non-positive value when the size is unknown.
    """
    # Without a usable size no meaningful percentage can be computed
    # (and the divisions below would fail), so stay silent.
    if total_size <= 0 or block_size <= 0:
        return

    total_counts = total_size // block_size
    # Files smaller than five blocks would give a step of 0 and make the
    # modulo below raise ZeroDivisionError; report on every block instead.
    show_counts_percentages = max(1, total_counts // 5)
    if count % show_counts_percentages != 0:
        return

    # The last block is usually only partially filled, so the raw value can
    # exceed 100; clamp it.
    percent = min(100, math.ceil(count * block_size * 100 / total_size))
    service_logger.info(f"Downloaded {percent}%")


def download_vgt_model(model_name: str):
    """Download the VGT weights for ``model_name`` into ``MODELS_PATH``.

    Does nothing when ``{model_name}_VGT_model.pth`` is already present.
    The file is first written to a ``.part`` sibling and only moved to its
    final name once the transfer finished, so an interrupted download is
    never mistaken for a complete model on the next run.

    Args:
        model_name: Name prefix of the released model file.

    Raises:
        Whatever ``urlretrieve`` raises on network or HTTP failure.
    """
    from os import remove, replace

    model_path = join(MODELS_PATH, f"{model_name}_VGT_model.pth")
    if exists(model_path):
        return

    service_logger.info(f"Downloading {model_name} model")
    download_link = f"https://github.com/AlibabaResearch/AdvancedLiterateMachinery/releases/download/v1.3.0-VGT-release/{model_name}_VGT_model.pth"
    partial_path = f"{model_path}.part"
    try:
        urlretrieve(download_link, partial_path, reporthook=download_progress)
        replace(partial_path, model_path)
    finally:
        # After a successful replace the partial file is gone; it only
        # still exists here when the download or the move failed.
        if exists(partial_path):
            remove(partial_path)


def download_embedding_model():
    """Fetch the ``microsoft/layoutlm-base-uncased`` snapshot into ``MODELS_PATH``.

    The download is skipped when the ``layoutlm-base-uncased`` directory
    already exists.
    """
    target_dir = join(MODELS_PATH, "layoutlm-base-uncased")
    if not exists(target_dir):
        makedirs(target_dir, exist_ok=True)
        service_logger.info("Embedding model is being downloaded")
        snapshot_download(
            repo_id="microsoft/layoutlm-base-uncased",
            local_dir=target_dir,
            local_dir_use_symlinks=False,
        )


def download_from_hf_hub(path: Path):
    """Download ``path.name`` from the HURIDOCS Hugging Face repo if missing.

    Args:
        path: Destination of the file. Its parent directory is created when
            needed and its file name is the name requested from the hub.
            Nothing happens when ``path`` already exists.
    """
    if not path.exists():
        parent_dir = path.parent
        makedirs(parent_dir, exist_ok=True)
        hf_hub_download(
            repo_id="HURIDOCS/pdf-document-layout-analysis",
            filename=path.name,
            local_dir=parent_dir,
            local_dir_use_symlinks=False,
        )


def download_lightgbm_models():
    """Download the LightGBM model files and their config into ``MODELS_PATH``."""
    file_names = (
        "token_type_lightgbm.model",
        "paragraph_extraction_lightgbm.model",
        "config.json",
    )
    # Order matches the original sequence of downloads.
    for file_name in file_names:
        download_from_hf_hub(Path(MODELS_PATH, file_name))


def download_models(model_name: str):
    """Download every model file needed for ``model_name``.

    Args:
        model_name: ``"fast"`` selects the LightGBM models; any other value
            is treated as a VGT model name and also pulls the embedding
            model.
    """
    makedirs(MODELS_PATH, exist_ok=True)
    if model_name != "fast":
        download_vgt_model(model_name)
        download_embedding_model()
    else:
        download_lightgbm_models()


if __name__ == "__main__":
    # When run as a script, fetch the "doclaynet" models first, then "fast".
    for requested_model in ("doclaynet", "fast"):
        download_models(requested_model)