Eshieh2 commited on
Commit
e4b15bb
·
1 Parent(s): ed503ed

add more to valid list. Allow top3. fix confidence percentages.

Browse files
Files changed (2) hide show
  1. app.py +23 -6
  2. valid.txt +6 -0
app.py CHANGED
@@ -1,10 +1,11 @@
1
  import os
2
- os.environ['TF_USE_LEGACY_KERAS']='1'
3
  import gradio as gr
4
  import tensorflow as tf
5
  import numpy as np
6
  import requests
7
  import torch
 
 
8
 
9
  from huggingface_hub import snapshot_download
10
  from huggingface_hub import hf_hub_download
@@ -28,7 +29,11 @@ detector_path = hf_hub_download(repo_id= "eshieh2/jaguarhead",
28
  filename = "jaguarheadv5.pt")
29
  detector = torch.hub.load('ultralytics/yolov5', 'custom', path = detector_path)
30
 
 
 
31
  def classify_image(in_image):
 
 
32
  width,height = in_image.size
33
  heads = detector(in_image)
34
  masks = [] # tuple of box coords and string
@@ -37,7 +42,7 @@ def classify_image(in_image):
37
  w = x2 - x
38
  h = y2 - y
39
  inp = in_image.crop((x,y,x2,y2))
40
- inp = inp.resize((480,480))
41
  inp = np.array(inp)
42
  inp = np.reshape(inp,(-1, 480, 480, 3)).astype(np.float32)
43
  inp = np.divide(inp,255.0)
@@ -45,12 +50,24 @@ def classify_image(in_image):
45
  prediction = tf.squeeze(prediction)
46
  pred = {labels[i]: float(prediction[i]) for i in range(label_count)}
47
  #print(pred)
48
- max_key = max(pred, key=pred.get)
49
  rect = (int(x),int(y),int(x2),int(y2))
50
- if show_all or max_key.lower() in valid:
51
- masks.append((rect,f"{max_key}:{pct}"))
 
 
 
 
 
 
 
 
 
52
  else:
53
- masks.append((rect,f"unknown",))
 
 
 
 
54
  return (in_image,masks)
55
 
56
  image = gr.Image(type='pil')
 
1
  import os
 
2
  import gradio as gr
3
  import tensorflow as tf
4
  import numpy as np
5
  import requests
6
  import torch
7
+ import heapq
8
+ from PIL import Image
9
 
10
  from huggingface_hub import snapshot_download
11
  from huggingface_hub import hf_hub_download
 
29
  filename = "jaguarheadv5.pt")
30
  detector = torch.hub.load('ultralytics/yolov5', 'custom', path = detector_path)
31
 
32
+ topk = 3
33
+
34
  def classify_image(in_image):
35
+ if in_image is None:
36
+ return None
37
  width,height = in_image.size
38
  heads = detector(in_image)
39
  masks = [] # tuple of box coords and string
 
42
  w = x2 - x
43
  h = y2 - y
44
  inp = in_image.crop((x,y,x2,y2))
45
+ inp = inp.resize((480,480),Image.BILINEAR)
46
  inp = np.array(inp)
47
  inp = np.reshape(inp,(-1, 480, 480, 3)).astype(np.float32)
48
  inp = np.divide(inp,255.0)
 
50
  prediction = tf.squeeze(prediction)
51
  pred = {labels[i]: float(prediction[i]) for i in range(label_count)}
52
  #print(pred)
 
53
  rect = (int(x),int(y),int(x2),int(y2))
54
+ if topk is not None:
55
+ top = heapq.nlargest(topk,pred,key=pred.get)
56
+ label = ''
57
+ for t in top:
58
+ if show_all or t.lower() in valid:
59
+ if len(label) != 0:
60
+ label += ", "
61
+ label += f"{t}:{pred[t]:.3f}"
62
+ if len(label)==0:
63
+ label = 'unknown'
64
+ masks.append((rect,label))
65
  else:
66
+ max_key = max(pred, key=pred.get)
67
+ if show_all or max_key.lower() in valid:
68
+ masks.append((rect,f"{max_key}:{pred[max_key]}"))
69
+ else:
70
+ masks.append((rect,f"unknown",))
71
  return (in_image,masks)
72
 
73
  image = gr.Image(type='pil')
valid.txt CHANGED
@@ -1,7 +1,13 @@
1
  bagua
2
  guaraci
 
 
 
3
  marcela
4
  medrosa
 
5
  oxum
6
  patricia
 
7
  ti
 
 
1
  bagua
2
  guaraci
3
+ kasimir
4
+ manath
5
+ margo
6
  marcela
7
  medrosa
8
+ ousado
9
  oxum
10
  patricia
11
+ saseka
12
  ti
13
+ tusk