JosephTK commited on
Commit
22c47ed
1 Parent(s): 189df2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -9
app.py CHANGED
@@ -7,6 +7,18 @@ import torch
7
  image_processor = AutoImageProcessor.from_pretrained('hustvl/yolos-small')
8
  model = AutoModelForObjectDetection.from_pretrained('hustvl/yolos-small')
9
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def detect(image):
11
  inputs = image_processor(images=image, return_tensors="pt")
12
  outputs = model(**inputs)
@@ -22,23 +34,32 @@ def detect(image):
22
  # label and the count
23
  counts = {}
24
 
25
- for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
26
- box = [round(i, 4) for i in box.tolist()]
27
  label_name = model.config.id2label[label.item()]
28
  if label_name not in counts:
29
  counts[label_name] = 0
30
  counts[label_name] += 1
31
 
32
- x1, y1, x2, y2 = tuple(box)
33
- draw.rectangle((x1, y1, x2, y2), outline=(128, 128, 50), width=2)
34
- draw.text((x1, y1), label_name, fill="white")
 
 
 
 
 
 
 
 
 
 
35
 
36
  df = pd.DataFrame({
37
- 'label': [label for label in counts],
38
- 'counts': [counts[label] for label in counts]
39
  })
40
-
41
- return image, df, counts
42
 
43
  demo = gr.Interface(
44
  fn=detect,
 
7
  image_processor = AutoImageProcessor.from_pretrained('hustvl/yolos-small')
8
  model = AutoModelForObjectDetection.from_pretrained('hustvl/yolos-small')
9
 
10
+ colors = ["red",
11
+ "orange",
12
+ "yellow",
13
+ "green",
14
+ "blue",
15
+ "indigo",
16
+ "violet",
17
+ "brown",
18
+ "black",
19
+ "slategray",
20
+ ]
21
+
22
  def detect(image):
23
  inputs = image_processor(images=image, return_tensors="pt")
24
  outputs = model(**inputs)
 
34
  # label and the count
35
  counts = {}
36
 
37
+ for score, label in zip(results["scores"], results["labels"]):
 
38
  label_name = model.config.id2label[label.item()]
39
  if label_name not in counts:
40
  counts[label_name] = 0
41
  counts[label_name] += 1
42
 
43
+ count_results = {k: v for k, v in sorted(counts.items(), key=lambda item: item[1])}[:10]
44
+ label2color = {}
45
+ for idx, label in enumerate(count_results):
46
+ label2color[label] = colors[idx]
47
+
48
+ for label, box in zip(results["labels"], results["boxes"]):
49
+ label_name = model.config.id2label[label.item()]
50
+
51
+ if label_name in count_results:
52
+ box = [round(i, 4) for i in box.tolist()]
53
+ x1, y1, x2, y2 = tuple(box)
54
+ draw.rectangle((x1, y1, x2, y2), outline=label2color[label_name], width=2)
55
+ draw.text((x1, y1), label_name, fill="white")
56
 
57
  df = pd.DataFrame({
58
+ 'label': [label for label in count_results],
59
+ 'counts': [counts[label] for label in count_results]
60
  })
61
+
62
+ return image, df, count_results
63
 
64
  demo = gr.Interface(
65
  fn=detect,