Gabriel's picture
Update src/htr_pipeline/models.py
9429105
raw history blame
No virus
2.46 kB
import os
import torch
from huggingface_hub import snapshot_download
from mmdet.apis import DetInferencer
# from mmengine import Config
from mmocr.apis import TextRecInferencer
class HtrModels:
    """Resolve file paths for, and build, the three models of the HTR pipeline.

    The pipeline consists of:

    - a text-region detector and a text-line detector, both loaded with
      mmdet's ``DetInferencer``;
    - a SATRN text recogniser, loaded with mmocr's ``TextRecInferencer``.

    Config files are always read from ``model_folder``. Checkpoints are also
    read from ``model_folder`` unless any of them is missing. In that case
    all three are downloaded from the Hugging Face Hub by :meth:`get_config`.
    """

    # Checkpoint locations relative to a model folder or downloaded snapshot.
    # Shared by __init__ and get_config so the local and downloaded layouts
    # cannot drift apart.
    _REGION_CHECKPOINT = "RmtDet_regions/epoch_12.pth"
    _LINE_CHECKPOINT = "RmtDet_lines/epoch_12.pth"
    _MMOCR_CHECKPOINT = "SATRN/epoch_5.pth"

    # Hub token placeholder. Presumably a deployment step substitutes a real
    # token for this literal (TODO confirm). While it is still the placeholder,
    # it is never sent to the Hub (see _resolve_token).
    _BUILD_TIME_HF_TOKEN = "__INSERT__FINS_HUGGINFACE_TOKEN__"

    def __init__(self, local_run=False, model_folder="./models"):
        """Pick the compute device and resolve config and checkpoint paths.

        Args:
            local_run: Not used by this class; kept so existing callers keep
                working.
            model_folder: Directory holding the config files and, if already
                present, the checkpoints. Defaults to ``"./models"``
                (relative to the current working directory).
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.region_config = os.path.join(
            model_folder, "RmtDet_regions/rtmdet_m_textregions_2_concat.py"
        )
        self.line_config = os.path.join(
            model_folder, "RmtDet_lines/rtmdet_m_textlines_2_concat.py"
        )
        self.mmocr_config = os.path.join(model_folder, "SATRN/_base_satrn_shallow_concat.py")

        self.region_checkpoint = os.path.join(model_folder, self._REGION_CHECKPOINT)
        self.line_checkpoint = os.path.join(model_folder, self._LINE_CHECKPOINT)
        self.mmocr_checkpoint = os.path.join(model_folder, self._MMOCR_CHECKPOINT)

        # Only the checkpoints are checked. The download below fetches
        # ``*.pth`` files only, so the configs must exist locally regardless.
        local_checkpoints = (
            self.region_checkpoint,
            self.line_checkpoint,
            self.mmocr_checkpoint,
        )
        if not all(os.path.exists(path) for path in local_checkpoints):
            # Any missing checkpoint triggers a download, and all three paths
            # then point into the downloaded snapshot.
            downloaded = self.get_config()
            self.region_checkpoint = downloaded["region_checkpoint"]
            self.line_checkpoint = downloaded["line_checkpoint"]
            self.mmocr_checkpoint = downloaded["mmocr_checkpoint"]

    def load_region_model(self):
        """Build the text-region detector from its config file and checkpoint."""
        return DetInferencer(self.region_config, self.region_checkpoint, device=self.device)

    def load_line_model(self):
        """Build the text-line detector from its config file and checkpoint."""
        return DetInferencer(self.line_config, self.line_checkpoint, device=self.device)

    def load_htr_model(self):
        """Build the text-recognition (HTR) model from its config file and checkpoint."""
        return TextRecInferencer(self.mmocr_config, self.mmocr_checkpoint, device=self.device)

    @staticmethod
    def get_config(token=None, cache_dir="./"):
        """Download the model checkpoints from the Hugging Face Hub.

        Only the ``*.pth`` files of ``Riksarkivet/HTR_pipeline_models`` are
        fetched; config files are not.

        Args:
            token: Hub access token. When omitted, a build-time-substituted
                token is used if one exists. Otherwise ``None`` is passed and
                ``huggingface_hub`` applies its own credential lookup.
            cache_dir: Hub cache directory. Defaults to ``"./"``.

        Returns:
            A dict with keys ``"region_checkpoint"``, ``"line_checkpoint"``
            and ``"mmocr_checkpoint"``, each mapping to a local file path
            inside the downloaded snapshot.
        """
        path_models = snapshot_download(
            "Riksarkivet/HTR_pipeline_models",
            allow_patterns=["*.pth"],
            token=HtrModels._resolve_token(token),
            cache_dir=cache_dir,
        )
        return {
            "region_checkpoint": os.path.join(path_models, HtrModels._REGION_CHECKPOINT),
            "line_checkpoint": os.path.join(path_models, HtrModels._LINE_CHECKPOINT),
            "mmocr_checkpoint": os.path.join(path_models, HtrModels._MMOCR_CHECKPOINT),
        }

    @staticmethod
    def _resolve_token(token=None):
        """Return the Hub token to use, or None; never the unfilled placeholder."""
        if token is not None:
            return token
        baked_in = HtrModels._BUILD_TIME_HF_TOKEN
        # The prefix is written separately from the full placeholder literal,
        # so a textual substitution of that literal leaves this check intact.
        if baked_in and not baked_in.startswith("__INSERT__"):
            return baked_in
        return None
if __name__ == "__main__":
    # Import-only module: running it directly intentionally does nothing.
    ...