# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Represents a model repository, including pre-trained models and bags of models.
A repo can either be the main remote repository stored in AWS, or a local repository
with your own models.
"""

from hashlib import sha256
from pathlib import Path
import typing as tp

import torch
import yaml

from .apply import BagOfModels, Model
from .states import load_model


AnyModel = tp.Union[Model, BagOfModels]


class ModelLoadingError(RuntimeError):
    """Raised when a model or a bag of models cannot be found or validated."""


def check_checksum(path: Path, checksum: str):
    """Verify that the SHA256 digest of a file starts with `checksum`.

    Only the first `len(checksum)` hexadecimal characters of the digest are
    compared, so `checksum` may be a truncated prefix of the full digest.

    Args:
        path: file whose content is hashed.
        checksum: expected (possibly truncated) hexadecimal SHA256 digest.

    Raises:
        ModelLoadingError: if the computed prefix differs from `checksum`.
    """
    sha = sha256()
    with open(path, 'rb') as file:
        # Hash in 1 MiB chunks so the whole file is never held in memory at once.
        for buf in iter(lambda: file.read(2**20), b''):
            sha.update(buf)
    actual_checksum = sha.hexdigest()[:len(checksum)]
    if actual_checksum != checksum:
        raise ModelLoadingError(f'Invalid checksum for file {path}, '
                                f'expected {checksum} but got {actual_checksum}')


class ModelOnlyRepo:
    """Base class for all model only repos.

    Subclasses must implement both `has_model` and `get_model`.
    """

    def has_model(self, sig: str) -> bool:
        """Return True if a model with signature `sig` is known to this repo."""
        raise NotImplementedError()

    def get_model(self, sig: str) -> Model:
        """Load and return the model with signature `sig`."""
        raise NotImplementedError()


class RemoteRepo(ModelOnlyRepo):
    """Repo whose models are fetched from URLs.

    Args:
        models: mapping from model signature to the URL of its checkpoint.
    """

    def __init__(self, models: tp.Dict[str, str]):
        self._models = models

    def has_model(self, sig: str) -> bool:
        """Return True if a URL is registered for signature `sig`."""
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        """Download (through `torch.hub`) and load the model with signature `sig`.

        Raises:
            ModelLoadingError: if no URL is registered for `sig`.
        """
        try:
            url = self._models[sig]
        except KeyError:
            # The KeyError carries no extra information, hide it from the traceback.
            raise ModelLoadingError(
                f'Could not find a pre-trained model with signature {sig}.') from None
        # `check_hash=True` delegates integrity verification of the download to torch.hub.
        pkg = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
        return load_model(pkg)


class LocalRepo(ModelOnlyRepo):
    """Repo backed by the `.th` files found directly inside a folder.

    Files are named either `<sig>.th` or `<sig>-<checksum>.th`. When a checksum
    is present in the name, it is verified against the file content before loading.

    Args:
        root: folder that is scanned for `.th` files.
    """

    def __init__(self, root: Path):
        self.root = root
        self.scan()

    def scan(self):
        """(Re)build the signature -> file and signature -> checksum indexes from `root`.

        Raises:
            ModelLoadingError: if two files resolve to the same signature.
        """
        self._models: tp.Dict[str, Path] = {}
        self._checksums: tp.Dict[str, str] = {}
        for file in self.root.iterdir():
            if file.suffix != '.th':
                continue
            if '-' in file.stem:
                # Split on the last dash only: the checksum is the trailing part,
                # so a signature that itself contains dashes no longer crashes
                # the scan with an unpacking ValueError.
                xp_sig, _, checksum = file.stem.rpartition('-')
                self._checksums[xp_sig] = checksum
            else:
                xp_sig = file.stem
            if xp_sig in self._models:
                raise ModelLoadingError(
                    f'Duplicate pre-trained model exist for signature {xp_sig}. '
                    'Please delete all but one.')
            self._models[xp_sig] = file

    def has_model(self, sig: str) -> bool:
        """Return True if a file was found for signature `sig` during the last scan."""
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        """Load the model with signature `sig`, verifying its checksum when one is known.

        Raises:
            ModelLoadingError: if `sig` is unknown, or if the file content does not
                match the checksum embedded in its name.
        """
        try:
            file = self._models[sig]
        except KeyError:
            # The KeyError carries no extra information, hide it from the traceback.
            raise ModelLoadingError(
                f'Could not find pre-trained model with signature {sig}.') from None
        if sig in self._checksums:
            check_checksum(file, self._checksums[sig])
        return load_model(file)


class BagOnlyRepo:
    """Handles only YAML files containing bag of models, leaving the actual
    model loading to some Repo.

    Args:
        root: folder that is scanned for `.yaml` bag definitions.
        model_repo: repo used to load each individual model listed in a bag.
    """

    def __init__(self, root: Path, model_repo: ModelOnlyRepo):
        self.root = root
        self.model_repo = model_repo
        self.scan()

    def scan(self):
        """(Re)build the bag name -> YAML file index from `root`."""
        self._bags = {
            file.stem: file
            for file in self.root.iterdir()
            if file.suffix == '.yaml'
        }

    def has_model(self, name: str) -> bool:
        """Return True if a YAML file named `name` was found during the last scan."""
        return name in self._bags

    def get_model(self, name: str) -> BagOfModels:
        """Build the bag of models described by the YAML file named `name`.

        The YAML document must provide a `models` entry (list of signatures) and
        may provide `weights` and `segment`, which are forwarded to `BagOfModels`
        (as `None` when absent).

        Raises:
            ModelLoadingError: if no bag named `name` exists, or if one of its
                models cannot be loaded by `model_repo`.
        """
        try:
            yaml_file = self._bags[name]
        except KeyError:
            # The KeyError carries no extra information, hide it from the traceback.
            raise ModelLoadingError(f'{name} is neither a single pre-trained model or '
                                    'a bag of models.') from None
        # Context manager so the file handle is closed deterministically.
        with open(yaml_file) as stream:
            bag = yaml.safe_load(stream)
        signatures = bag['models']
        models = [self.model_repo.get_model(sig) for sig in signatures]
        weights = bag.get('weights')
        segment = bag.get('segment')
        return BagOfModels(models, weights, segment)


class AnyModelRepo:
    """Single entry point resolving a name to either a single model or a bag of models.

    Single models take precedence over bags when both repos know the same name.

    Args:
        model_repo: repo providing single models.
        bag_repo: repo providing bags of models.
    """

    def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
        self.model_repo = model_repo
        self.bag_repo = bag_repo

    def has_model(self, name_or_sig: str) -> bool:
        """Return True if either underlying repo knows `name_or_sig`."""
        return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)

    def get_model(self, name_or_sig: str) -> AnyModel:
        """Return the single model if `model_repo` has it, else defer to `bag_repo`.

        Raises:
            ModelLoadingError: propagated from the underlying repos, in particular
                when `name_or_sig` is neither a known model nor a known bag.
        """
        if self.model_repo.has_model(name_or_sig):
            return self.model_repo.get_model(name_or_sig)
        return self.bag_repo.get_model(name_or_sig)