Spaces:
Runtime error
Runtime error
from collections import OrderedDict | |
from spiga.data.loaders.dl_config import DatabaseStruct | |
# Download URL of the pretrained weights file for each supported dataset name.
# ModelConfig.update_with_dataset() looks the dataset name up in this table
# when it is asked to resolve a weights URL.
MODELS_URL = {
    "wflw": "https://drive.google.com/uc?export=download&confirm=yes&id=1h0qA5ysKorpeDNRXe9oYkVcVe8UYyzP7",
    "300wpublic": "https://drive.google.com/uc?export=download&confirm=yes&id=1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC",
    "300wprivate": "https://drive.google.com/uc?export=download&confirm=yes&id=1fYv-Ie7n14eTD0ROxJYcn6SXZY5QU9SM",
    "merlrav": "https://drive.google.com/uc?export=download&confirm=yes&id=1GKS1x0tpsTVivPZUk_yrSiMhwEAcAkg6",
    # Same file id as "300wprivate": cofw68 reuses the 300wprivate weights
    # (see the cofw68 special case in ModelConfig.update_with_dataset).
    "cofw68": "https://drive.google.com/uc?export=download&confirm=yes&id=1fYv-Ie7n14eTD0ROxJYcn6SXZY5QU9SM",
}
class ModelConfig(object):
    """Container for model, preprocessing and dataset settings.

    All settings are plain public instance attributes. They can be
    overwritten in bulk with :meth:`update` and exported with
    :meth:`state_dict`.
    """

    def __init__(self, dataset_name=None, load_model_url=True):
        """Set the default configuration.

        Args:
            dataset_name: Optional dataset identifier. When given, the
                dataset-specific fields are filled in through
                :meth:`update_with_dataset`.
            load_model_url: When true, :meth:`update_with_dataset` also
                stores the weights download URL taken from ``MODELS_URL``.
        """
        # Model configuration
        self.model_weights = None
        self.model_weights_path = "./"
        self.load_model_url = load_model_url
        self.model_weights_url = None
        # Pretreatment
        self.focal_ratio = 1.5  # Camera matrix focal length ratio.
        self.target_dist = 1.6  # Target distance zoom in/out around face.
        self.image_size = (256, 256)
        # Outputs
        self.ftmap_size = (64, 64)
        # Dataset
        self.dataset = None
        if dataset_name is not None:
            self.update_with_dataset(dataset_name)

    def update_with_dataset(self, dataset_name):
        """Fill in the dataset-dependent fields for ``dataset_name``.

        Sets ``dataset`` to ``DatabaseStruct(dataset_name)`` and
        ``model_weights`` to ``"spiga_<dataset_name>.pt"``. When
        ``self.load_model_url`` is true, ``model_weights_url`` is taken from
        ``MODELS_URL``.

        Args:
            dataset_name: Dataset identifier; must be a key of ``MODELS_URL``
                when ``self.load_model_url`` is true.

        Raises:
            KeyError: If a URL is requested for a name missing from
                ``MODELS_URL``.
        """
        config_dict = {
            "dataset": DatabaseStruct(dataset_name),
            "model_weights": "spiga_%s.pt" % dataset_name,
        }
        if dataset_name == "cofw68":  # Test only
            # cofw68 has no weights file of its own; it reuses 300wprivate's.
            config_dict["model_weights"] = "spiga_300wprivate.pt"
        if self.load_model_url:
            config_dict["model_weights_url"] = MODELS_URL[dataset_name]
        self.update(config_dict)

    def update(self, params_dict):
        """Overwrite existing attributes with the values in ``params_dict``.

        The update is all-or-nothing: every key is checked before any
        attribute is assigned, so a rejected call leaves the object
        unchanged.

        Args:
            params_dict: Mapping of attribute name to new value. Each key
                must name an attribute that already exists on the instance.

        Raises:
            Warning: On the first key (in iteration order) that does not
                name an existing attribute.
        """
        # Validate first so an unknown key cannot leave a partial update.
        # hasattr() also covers every public key reported by state_dict().
        for key, value in params_dict.items():
            if not hasattr(self, key):
                raise Warning("Unknown option: {}: {}".format(key, value))
        for key, value in params_dict.items():
            setattr(self, key, value)

    def state_dict(self):
        """Return the public instance attributes as an ``OrderedDict``.

        Attributes whose name starts with ``"_"`` are skipped. Entries
        follow the order in which the attributes were first assigned.
        """
        return OrderedDict(
            (name, value)
            for name, value in vars(self).items()
            if not name.startswith("_")
        )