keremberke commited on
Commit
651391d
1 Parent(s): 10f3130

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -22
app.py CHANGED
@@ -29,15 +29,17 @@ cls_model_id = DEFAULT_CLS_MODEL_ID
29
  def get_examples(task):
30
  examples = []
31
  Path(EXAMPLE_IMAGE_DIR).mkdir(parents=True, exist_ok=True)
 
32
  for model_id in task_to_model_ids[task]:
33
  dataset_id = get_dataset_id_from_model_id(model_id)
34
  ds = load_dataset(dataset_id, name="mini")["validation"]
35
  for ind in range(min(2, len(ds))):
36
  jpeg_image_file = ds[ind]["image"]
37
- image_file_path = str(Path(EXAMPLE_IMAGE_DIR) / f"{task}_example_{ind}.jpg")
38
  jpeg_image_file.save(image_file_path, format='JPEG', quality=100)
39
  image_path = os.path.abspath(image_file_path)
40
  examples.append([image_path, model_id, 0.25])
 
41
  return examples
42
 
43
 
@@ -124,13 +126,23 @@ with gr.Blocks() as demo:
124
  with gr.Column():
125
  detect_output = gr.Image(label="Predictions:", interactive=False)
126
  with gr.Row():
127
- detect_examples = gr.Examples(
128
- det_examples,
129
- inputs=[detect_input, detect_model_id, detect_threshold],
130
- outputs=detect_output,
131
- fn=predict,
132
- cache_examples=False,
133
- )
 
 
 
 
 
 
 
 
 
 
134
  with gr.Tab("Segmentation"):
135
  with gr.Row():
136
  with gr.Column():
@@ -141,13 +153,23 @@ with gr.Blocks() as demo:
141
  with gr.Column():
142
  segment_output = gr.Image(label="Predictions:", interactive=False)
143
  with gr.Row():
144
- segment_examples = gr.Examples(
145
- seg_examples,
146
- inputs=[segment_input, segment_model_id, segment_threshold],
147
- outputs=segment_output,
148
- fn=predict,
149
- cache_examples=False,
150
- )
 
 
 
 
 
 
 
 
 
 
151
  with gr.Tab("Classification"):
152
  with gr.Row():
153
  with gr.Column():
@@ -160,13 +182,23 @@ with gr.Blocks() as demo:
160
  label="Predictions:", show_label=True, num_top_classes=5
161
  )
162
  with gr.Row():
163
- classify_examples = gr.Examples(
164
- cls_examples,
165
- inputs=[classify_input, classify_model_id, classify_threshold],
166
- outputs=classify_output,
167
- fn=predict,
168
- cache_examples=False,
169
- )
 
 
 
 
 
 
 
 
 
 
170
 
171
  detect_button.click(
172
  predict, inputs=[detect_input, detect_model_id, detect_threshold], outputs=detect_output
 
29
  def get_examples(task):
30
  examples = []
31
  Path(EXAMPLE_IMAGE_DIR).mkdir(parents=True, exist_ok=True)
32
+ image_ind = 0
33
  for model_id in task_to_model_ids[task]:
34
  dataset_id = get_dataset_id_from_model_id(model_id)
35
  ds = load_dataset(dataset_id, name="mini")["validation"]
36
  for ind in range(min(2, len(ds))):
37
  jpeg_image_file = ds[ind]["image"]
38
+ image_file_path = str(Path(EXAMPLE_IMAGE_DIR) / f"{task}_example_{image_ind}.jpg")
39
  jpeg_image_file.save(image_file_path, format='JPEG', quality=100)
40
  image_path = os.path.abspath(image_file_path)
41
  examples.append([image_path, model_id, 0.25])
42
+ image_ind += 1
43
  return examples
44
 
45
 
 
126
  with gr.Column():
127
  detect_output = gr.Image(label="Predictions:", interactive=False)
128
  with gr.Row():
129
+ half_ind = int(len(det_examples) / 2)
130
+ with gr.Column():
131
+ detect_examples = gr.Examples(
132
+ det_examples[:half_ind],
133
+ inputs=[detect_input, detect_model_id, detect_threshold],
134
+ outputs=detect_output,
135
+ fn=predict,
136
+ cache_examples=False,
137
+ )
138
+ with gr.Column():
139
+ detect_examples = gr.Examples(
140
+ det_examples[:half_ind],
141
+ inputs=[detect_input, detect_model_id, detect_threshold],
142
+ outputs=detect_output,
143
+ fn=predict,
144
+ cache_examples=False,
145
+ )
146
  with gr.Tab("Segmentation"):
147
  with gr.Row():
148
  with gr.Column():
 
153
  with gr.Column():
154
  segment_output = gr.Image(label="Predictions:", interactive=False)
155
  with gr.Row():
156
+ half_ind = int(len(det_examples) / 2)
157
+ with gr.Column():
158
+ segment_examples = gr.Examples(
159
+ seg_examples[:half_ind],
160
+ inputs=[segment_input, segment_model_id, segment_threshold],
161
+ outputs=segment_output,
162
+ fn=predict,
163
+ cache_examples=False,
164
+ )
165
+ with gr.Column():
166
+ segment_examples = gr.Examples(
167
+ seg_examples[:half_ind],
168
+ inputs=[segment_input, segment_model_id, segment_threshold],
169
+ outputs=segment_output,
170
+ fn=predict,
171
+ cache_examples=False,
172
+ )
173
  with gr.Tab("Classification"):
174
  with gr.Row():
175
  with gr.Column():
 
182
  label="Predictions:", show_label=True, num_top_classes=5
183
  )
184
  with gr.Row():
185
+ half_ind = int(len(det_examples) / 2)
186
+ with gr.Column():
187
+ classify_examples = gr.Examples(
188
+ cls_examples[half_ind:],
189
+ inputs=[classify_input, classify_model_id, classify_threshold],
190
+ outputs=classify_output,
191
+ fn=predict,
192
+ cache_examples=False,
193
+ )
194
+ with gr.Column():
195
+ classify_examples = gr.Examples(
196
+ cls_examples[:half_ind],
197
+ inputs=[classify_input, classify_model_id, classify_threshold],
198
+ outputs=classify_output,
199
+ fn=predict,
200
+ cache_examples=False,
201
+ )
202
 
203
  detect_button.click(
204
  predict, inputs=[detect_input, detect_model_id, detect_threshold], outputs=detect_output