pcuenq HF staff radames commited on
Commit
75f0b7a
·
verified ·
1 Parent(s): 556fb16

fix path / load face model (#12)

Browse files

- fix path (65d3e112994c7c84c285874b3009cfe02e5bdb15)
- model (9b24eb6e01f0f0ef079cf9540ce503bc61968a8d)
- Update app.py (cfd6e3456899a7bc1aadd48a69cecb022beb63de)


Co-authored-by: Radamés Ajna <radames@users.noreply.huggingface.co>

Files changed (3) hide show
  1. app.py +2 -2
  2. modelconfig.py +62 -0
  3. spiga_300wpublic.pt +3 -0
app.py CHANGED
@@ -11,7 +11,7 @@ import retinaface
11
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
12
  from diffusers import UniPCMultistepScheduler
13
 
14
- from spiga.inference.config import ModelConfig
15
  from spiga.inference.framework import SPIGAFramework
16
  import spiga.demo.analyze.track.retinasort.config as cfg
17
 
@@ -26,7 +26,7 @@ face_detector = retinaface.RetinaFaceDetector(model=config['retina']['model_name
26
  extra_features=config['retina']['extra_features'],
27
  cfg_postreat=config['retina']['postreat'])
28
  # Landmark extraction
29
- spiga_extractor = SPIGAFramework(ModelConfig("300wpublic"))
30
 
31
  uncanny_controlnet = ControlNetModel.from_pretrained(
32
  "multimodalart/uncannyfaces_25K", torch_dtype=torch.float16
 
11
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
12
  from diffusers import UniPCMultistepScheduler
13
 
14
+ from modelconfig import ModelConfig
15
  from spiga.inference.framework import SPIGAFramework
16
  import spiga.demo.analyze.track.retinasort.config as cfg
17
 
 
26
  extra_features=config['retina']['extra_features'],
27
  cfg_postreat=config['retina']['postreat'])
28
  # Landmark extraction
29
+ spiga_extractor = SPIGAFramework(ModelConfig("300wpublic", False))
30
 
31
  uncanny_controlnet = ControlNetModel.from_pretrained(
32
  "multimodalart/uncannyfaces_25K", torch_dtype=torch.float16
modelconfig.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ from spiga.data.loaders.dl_config import DatabaseStruct
4
+
5
+ MODELS_URL = {
6
+ "wflw": "https://drive.google.com/uc?export=download&confirm=yes&id=1h0qA5ysKorpeDNRXe9oYkVcVe8UYyzP7",
7
+ "300wpublic": "https://drive.google.com/uc?export=download&confirm=yes&id=1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC",
8
+ "300wprivate": "https://drive.google.com/uc?export=download&confirm=yes&id=1fYv-Ie7n14eTD0ROxJYcn6SXZY5QU9SM",
9
+ "merlrav": "https://drive.google.com/uc?export=download&confirm=yes&id=1GKS1x0tpsTVivPZUk_yrSiMhwEAcAkg6",
10
+ "cofw68": "https://drive.google.com/uc?export=download&confirm=yes&id=1fYv-Ie7n14eTD0ROxJYcn6SXZY5QU9SM",
11
+ }
12
+
13
+
14
+ class ModelConfig(object):
15
+
16
+ def __init__(self, dataset_name=None, load_model_url=True):
17
+ # Model configuration
18
+ self.model_weights = None
19
+ self.model_weights_path = "./"
20
+ self.load_model_url = load_model_url
21
+ self.model_weights_url = None
22
+ # Pretreatment
23
+ self.focal_ratio = 1.5 # Camera matrix focal length ratio.
24
+ self.target_dist = 1.6 # Target distance zoom in/out around face.
25
+ self.image_size = (256, 256)
26
+ # Outputs
27
+ self.ftmap_size = (64, 64)
28
+ # Dataset
29
+ self.dataset = None
30
+
31
+ if dataset_name is not None:
32
+ self.update_with_dataset(dataset_name)
33
+
34
+ def update_with_dataset(self, dataset_name):
35
+
36
+ config_dict = {
37
+ "dataset": DatabaseStruct(dataset_name),
38
+ "model_weights": "spiga_%s.pt" % dataset_name,
39
+ }
40
+
41
+ if dataset_name == "cofw68": # Test only
42
+ config_dict["model_weights"] = "spiga_300wprivate.pt"
43
+
44
+ if self.load_model_url:
45
+ config_dict["model_weights_url"] = MODELS_URL[dataset_name]
46
+
47
+ self.update(config_dict)
48
+
49
+ def update(self, params_dict):
50
+ state_dict = self.state_dict()
51
+ for k, v in params_dict.items():
52
+ if k in state_dict or hasattr(self, k):
53
+ setattr(self, k, v)
54
+ else:
55
+ raise Warning("Unknown option: {}: {}".format(k, v))
56
+
57
+ def state_dict(self):
58
+ state_dict = OrderedDict()
59
+ for k in self.__dict__.keys():
60
+ if not k.startswith("_"):
61
+ state_dict[k] = getattr(self, k)
62
+ return state_dict
spiga_300wpublic.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98f014611ac25d549e89083992d9e6ade15da133c634a3883473abf2953cee2d
3
+ size 254397265