Spaces:
Sleeping
Sleeping
import os | |
import torch | |
from huggingface_hub import snapshot_download | |
from mmdet.apis import DetInferencer | |
# from mmengine import Config | |
from mmocr.apis import TextRecInferencer | |
class HtrModels:
    """Resolve model file paths for the HTR pipeline and build its inferencers.

    Three models are handled: a text-region detector and a text-line detector
    (both built with ``DetInferencer``) and a text recognizer (built with
    ``TextRecInferencer``).
    """

    def __init__(self, local_run=False):
        """Select the device and locate config and checkpoint files.

        Checkpoints are first looked up under ``./models``. If any of the
        three is missing, all three checkpoint paths are replaced by the ones
        returned from ``get_config``. The config paths always point into
        ``./models``.

        Args:
            local_run: Currently unused; accepted so existing call sites that
                pass it keep working.
        """
        # Use the first CUDA device when available, otherwise the CPU.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        model_folder = "./models"
        region_dir = f"{model_folder}/RmtDet_regions"
        line_dir = f"{model_folder}/RmtDet_lines"
        recognizer_dir = f"{model_folder}/SATRN"

        self.region_config = f"{region_dir}/rtmdet_m_textregions_2_concat.py"
        self.region_checkpoint = f"{region_dir}/epoch_12.pth"

        self.line_config = f"{line_dir}/rtmdet_m_textlines_2_concat.py"
        self.line_checkpoint = f"{line_dir}/epoch_12.pth"

        self.mmocr_config = f"{recognizer_dir}/_base_satrn_shallow_concat.py"
        self.mmocr_checkpoint = f"{recognizer_dir}/epoch_5.pth"

        # Fall back to downloaded checkpoints unless every local one exists.
        local_checkpoints = (
            self.region_checkpoint,
            self.line_checkpoint,
            self.mmocr_checkpoint,
        )
        if not all(os.path.exists(path) for path in local_checkpoints):
            downloaded = self.get_config()
            self.region_checkpoint = downloaded["region_checkpoint"]
            self.line_checkpoint = downloaded["line_checkpoint"]
            self.mmocr_checkpoint = downloaded["mmocr_checkpoint"]

    def load_region_model(self):
        """Build the text-region detector from its config and checkpoint."""
        return DetInferencer(self.region_config, self.region_checkpoint, device=self.device)

    def load_line_model(self):
        """Build the text-line detector from its config and checkpoint."""
        return DetInferencer(self.line_config, self.line_checkpoint, device=self.device)

    def load_htr_model(self):
        """Build the text recognizer from its config and checkpoint."""
        return TextRecInferencer(self.mmocr_config, self.mmocr_checkpoint, device=self.device)

    @staticmethod
    def get_config():
        """Download the checkpoints from the Hugging Face Hub.

        Declared static because it takes no ``self``; without the decorator
        the ``self.get_config()`` call in ``__init__`` raises ``TypeError``.

        Returns:
            A dict with the keys ``"region_checkpoint"``,
            ``"line_checkpoint"`` and ``"mmocr_checkpoint"``, each mapped to
            a path inside the downloaded snapshot directory.
        """
        path_models = snapshot_download(
            "Riksarkivet/HTR_pipeline_models",
            # Only checkpoint files are requested.
            allow_patterns=["*.pth"],
            # NOTE(review): this looks like a placeholder that is presumably
            # substituted with a real token before deployment - confirm.
            token="__INSERT__FINS_HUGGINFACE_TOKEN__",
            cache_dir="./",
        )

        return {
            "region_checkpoint": os.path.join(path_models, "RmtDet_regions/epoch_12.pth"),
            "line_checkpoint": os.path.join(path_models, "RmtDet_lines/epoch_12.pth"),
            "mmocr_checkpoint": os.path.join(path_models, "SATRN/epoch_5.pth"),
        }
if __name__ == "__main__":
    # No script behaviour is defined: running this file directly does nothing.
    pass