karthickn commited on
Commit
0f7feb5
·
1 Parent(s): fae9073

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -19
app.py CHANGED
@@ -4,21 +4,11 @@ from transformers import DetrImageProcessor, DetrForObjectDetection
4
  from color import Color
5
  from color_wheel import ColorWheel
6
  from PIL import ImageDraw, ImageFont
7
- import numpy as np
8
 
9
- resnet_101_processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-101')
10
- resnet_101_model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-101')
11
- resnet_50_processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
12
- resnet_50_model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
13
-
14
- def process_image(image, margin, model):
15
- if model=='detr-resnet-101':
16
- processor = resnet_101_processor
17
- model = resnet_101_model
18
- else:
19
- processor = resnet_50_processor
20
- model = resnet_50_model
21
 
 
22
  if image is None:
23
  yield [None, None, None]
24
  return
@@ -39,7 +29,6 @@ def process_image(image, margin, model):
39
  index = 0
40
  gallery = []
41
  labels = []
42
- newlabel = {}
43
  drawImage = image.copy()
44
  draw = ImageDraw.Draw(drawImage)
45
  for score, label, box in zip(results['scores'], results['labels'], results['boxes']):
@@ -53,9 +42,8 @@ def process_image(image, margin, model):
53
  draw.rectangle([(box[0], box[1]), (box[2], box[3])], outline=colors[index], width=4)
54
  gallery.append(image.crop((box[0], box[1], box[2], box[3])))
55
  labels.append(model.config.id2label[label.item()])
56
- newlabel[model.config.id2label[label.item()]] = 1
57
  index += 1
58
- yield [drawImage, gallery, newlabel, ','.join(labels)]
59
 
60
  app = gr.Interface(
61
  title='Object Detection for Image',
@@ -63,17 +51,14 @@ app = gr.Interface(
63
  inputs=[
64
  gr.Image(type='pil'),
65
  gr.Slider(maximum=100, step=1, label='margin'),
66
- gr.Radio(["detr-resnet-50", "detr-resnet-101"], value="detr-resnet-50", label="Select the model")
67
  ],
68
  outputs=[
69
  gr.Image(label='boxes', type='pil'),
70
  gr.Gallery(label='gallery', columns=8, height=140),
71
- gr.Label(label='scores'),
72
  gr.Textbox(label='text'),
73
  ],
74
  allow_flagging='never',
75
  examples=[['examples/Wild.jpg', 0], ['examples/Football-Match.jpg', 0]],
76
- # theme='HaleyCH/HaleyCH_Theme',
77
  #cache_examples=False
78
  )
79
  app.queue(concurrency_count=20)
 
4
  from color import Color
5
  from color_wheel import ColorWheel
6
  from PIL import ImageDraw, ImageFont
 
7
 
8
+ processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-50')
9
+ model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-50')
 
 
 
 
 
 
 
 
 
 
10
 
11
+ def process_image(image, margin):
12
  if image is None:
13
  yield [None, None, None]
14
  return
 
29
  index = 0
30
  gallery = []
31
  labels = []
 
32
  drawImage = image.copy()
33
  draw = ImageDraw.Draw(drawImage)
34
  for score, label, box in zip(results['scores'], results['labels'], results['boxes']):
 
42
  draw.rectangle([(box[0], box[1]), (box[2], box[3])], outline=colors[index], width=4)
43
  gallery.append(image.crop((box[0], box[1], box[2], box[3])))
44
  labels.append(model.config.id2label[label.item()])
 
45
  index += 1
46
+ yield [drawImage, gallery, ','.join(labels)]
47
 
48
  app = gr.Interface(
49
  title='Object Detection for Image',
 
51
  inputs=[
52
  gr.Image(type='pil'),
53
  gr.Slider(maximum=100, step=1, label='margin'),
 
54
  ],
55
  outputs=[
56
  gr.Image(label='boxes', type='pil'),
57
  gr.Gallery(label='gallery', columns=8, height=140),
 
58
  gr.Textbox(label='text'),
59
  ],
60
  allow_flagging='never',
61
  examples=[['examples/Wild.jpg', 0], ['examples/Football-Match.jpg', 0]],
 
62
  #cache_examples=False
63
  )
64
  app.queue(concurrency_count=20)