Antoine101 commited on
Commit
ec2a6ff
·
verified ·
1 Parent(s): 1828143

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -2
app.py CHANGED
@@ -1,4 +1,38 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  def greet(name):
4
  return "Hello " + name + "!!"
@@ -10,8 +44,8 @@ DETR model finetuned on "anindya64/hardhat" for hard hats detection.
10
 
11
  demo = gr.Interface(
12
  fn=greet,
13
- inputs="text",
14
- outputs="text",
15
  title=title,
16
  description=description
17
  )
 
1
  import gradio as gr
2
+ from PIL import Image, ImageDraw
3
+ from transformers import pipeline
4
+
5
+
6
+ def plot_results(image, results, threshold=0.7):
7
+ image = Image.fromarray(np.uint8(image))
8
+ draw = ImageDraw.Draw(image)
9
+ for result in results:
10
+ score = result["score"]
11
+ label = result["label"]
12
+ box = list(result["box"].values())
13
+ if score > threshold:
14
+ x, y, x2, y2 = tuple(box)
15
+ draw.rectangle((x, y, x2, y2), outline="red", width=1)
16
+ draw.text((x, y), label, fill="white")
17
+ draw.text(
18
+ (x + 0.5, y - 0.5),
19
+ text=str(score),
20
+ fill="green" if score > 0.7 else "red",
21
+ )
22
+ return image
23
+
24
+ def predict(image):
25
+ # make the object detection pipeline
26
+ obj_detector = pipeline(
27
+ "object-detection", model="anindya64/detr-resnet-50-dc5-hardhat-finetuned"
28
+ )
29
+ results = obj_detector(train_dataset[0]["image"])
30
+ return plot_results(image)
31
+
32
+
33
+
34
+ results = obj_detector(image)
35
+ plot_results(image, results)
36
 
37
  def greet(name):
38
  return "Hello " + name + "!!"
 
44
 
45
  demo = gr.Interface(
46
  fn=greet,
47
+ inputs=gr.Image(type="filepath", label="Input Image"),
48
+ outputs="image",
49
  title=title,
50
  description=description
51
  )