tomas-gajarsky commited on
Commit
78c79c8
1 Parent(s): a093564

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -12,11 +12,15 @@ from torch.nn.functional import cosine_similarity
12
  cfg = OmegaConf.load("config.merged.yml")
13
  analyzer = FaceAnalyzer(cfg.analyzer)
14
 
15
- def get_sim_dict_str(response: ImageData, pred_name: str = "verify", index: int = 0)-> str:
16
- base_emb = response.faces[index].preds[pred_name].logits
17
- sim_dict = {face.indx: cosine_similarity(base_emb, face.preds[pred_name].logits, dim=0).item() for face in response.faces}
18
- sim_dict_sort = dict(sorted(sim_dict.items(), key=operator.itemgetter(1),reverse=True))
19
- sim_dict_sort_str = str(sim_dict_sort)
 
 
 
 
20
  return sim_dict_sort_str
21
 
22
 
@@ -36,8 +40,8 @@ def inference(path_image: str) -> Tuple:
36
  deepfake_dict_str = str({face.indx: face.preds["deepfake"].label for face in response.faces})
37
  response_str = str(response)
38
 
39
- sim_dict_str_embed = get_sim_dict_str(response, pred_name="embed", index=0)
40
- sim_dict_str_verify = get_sim_dict_str(response, pred_name="verify", index=0)
41
 
42
  out_tuple = (pil_image, fer_dict_str, deepfake_dict_str, sim_dict_str_embed, sim_dict_str_verify, response_str)
43
  return out_tuple
 
12
  cfg = OmegaConf.load("config.merged.yml")
13
  analyzer = FaceAnalyzer(cfg.analyzer)
14
 
15
+ def gen_sim_dict_str(response: ImageData, pred_name: str = "verify", index: int = 0)-> str:
16
+ if len(response.faces) > 0
17
+ base_emb = response.faces[index].preds[pred_name].logits
18
+ sim_dict = {face.indx: cosine_similarity(base_emb, face.preds[pred_name].logits, dim=0).item() for face in response.faces}
19
+ sim_dict_sort = dict(sorted(sim_dict.items(), key=operator.itemgetter(1),reverse=True))
20
+ sim_dict_sort_str = str(sim_dict_sort)
21
+ else:
22
+ sim_dict_sort_str = ""
23
+
24
  return sim_dict_sort_str
25
 
26
 
 
40
  deepfake_dict_str = str({face.indx: face.preds["deepfake"].label for face in response.faces})
41
  response_str = str(response)
42
 
43
+ sim_dict_str_embed = gen_sim_dict_str(response, pred_name="embed", index=0)
44
+ sim_dict_str_verify = gen_sim_dict_str(response, pred_name="verify", index=0)
45
 
46
  out_tuple = (pil_image, fer_dict_str, deepfake_dict_str, sim_dict_str_embed, sim_dict_str_verify, response_str)
47
  return out_tuple