from collections import OrderedDict

from spiga.data.loaders.dl_config import DatabaseStruct

# Download links for the pretrained SPIGA weights, keyed by dataset name.
MODELS_URL = {
    "wflw": "https://drive.google.com/uc?export=download&confirm=yes&id=1h0qA5ysKorpeDNRXe9oYkVcVe8UYyzP7",
    "300wpublic": "https://drive.google.com/uc?export=download&confirm=yes&id=1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC",
    "300wprivate": "https://drive.google.com/uc?export=download&confirm=yes&id=1fYv-Ie7n14eTD0ROxJYcn6SXZY5QU9SM",
    "merlrav": "https://drive.google.com/uc?export=download&confirm=yes&id=1GKS1x0tpsTVivPZUk_yrSiMhwEAcAkg6",
    # cofw68 has no dedicated model: it points at the same file as "300wprivate".
    "cofw68": "https://drive.google.com/uc?export=download&confirm=yes&id=1fYv-Ie7n14eTD0ROxJYcn6SXZY5QU9SM",
}


class ModelConfig:
    """Configuration container for a SPIGA model.

    Holds the weights location, pre-processing parameters, output feature-map
    size and the dataset description. Passing ``dataset_name`` fills in the
    dataset-dependent fields via :meth:`update_with_dataset`.

    Args:
        dataset_name: optional dataset key (e.g. one of ``MODELS_URL``).
        load_model_url: when True, also set ``model_weights_url`` from
            ``MODELS_URL`` when a dataset is configured.
    """

    def __init__(self, dataset_name=None, load_model_url=True):
        # Model configuration
        self.model_weights = None
        self.model_weights_path = "./"
        self.load_model_url = load_model_url
        self.model_weights_url = None

        # Pretreatment
        self.focal_ratio = 1.5  # Camera matrix focal length ratio.
        self.target_dist = 1.6  # Target distance zoom in/out around face.
        self.image_size = (256, 256)  # presumably (height, width) of the input crop — confirm

        # Outputs
        self.ftmap_size = (64, 64)

        # Dataset
        self.dataset = None
        if dataset_name is not None:
            self.update_with_dataset(dataset_name)

    def update_with_dataset(self, dataset_name):
        """Configure the dataset, weights filename and (optionally) weights URL.

        Args:
            dataset_name: dataset key passed to ``DatabaseStruct`` and used to
                derive the weights filename ``spiga_<dataset_name>.pt``.

        Raises:
            KeyError: if ``load_model_url`` is enabled and ``MODELS_URL`` has
                no entry for ``dataset_name``.
            Warning: propagated from :meth:`update` (not expected here, since
                all keys set are existing attributes).
        """
        config_dict = {
            "dataset": DatabaseStruct(dataset_name),
            "model_weights": f"spiga_{dataset_name}.pt",
        }
        if dataset_name == "cofw68":
            # Test only: cofw68 is evaluated with the 300W-private model.
            config_dict["model_weights"] = "spiga_300wprivate.pt"

        if self.load_model_url:
            try:
                config_dict["model_weights_url"] = MODELS_URL[dataset_name]
            except KeyError:
                # Re-raise as KeyError (same type callers may catch) with a
                # message that lists the supported datasets.
                raise KeyError(
                    f"No pretrained weights URL for dataset {dataset_name!r}; "
                    f"available: {sorted(MODELS_URL)}"
                ) from None

        self.update(config_dict)

    def update(self, params_dict):
        """Set existing attributes from ``params_dict``.

        Every key must name an attribute that already exists on the instance.
        All keys are validated before anything is assigned, so a rejected
        update leaves the configuration unchanged.

        Args:
            params_dict: mapping of attribute name -> new value.

        Raises:
            Warning: for the first key that is not an existing attribute.
        """
        # Validate first so an unknown option cannot leave a half-applied update.
        for k, v in params_dict.items():
            if not hasattr(self, k):
                raise Warning("Unknown option: {}: {}".format(k, v))
        for k, v in params_dict.items():
            setattr(self, k, v)

    def state_dict(self):
        """Return the public instance attributes as an ``OrderedDict``.

        Attributes whose name starts with ``_`` are excluded. Entries follow
        the instance ``__dict__`` insertion order.
        """
        return OrderedDict(
            (k, getattr(self, k)) for k in self.__dict__ if not k.startswith("_")
        )