UNIQ-DEV's picture
Update app.py
1c3554e verified
import gradio as gr
import requests
import random
from src.classification_model import ClassificationModel
from src.util.extract import extract_image_urls
#only for dummy data
# response = requests.get("https://git.io/JJkYN")
# labels = response.text.split("\n")
print('start...')

# One classifier instance is shared by every request; the list of model names
# it reports drives both the dropdown choices and the output label grid.
clf = ClassificationModel()
model_names = clf.get_model_names()

# Upper bound on how many images a single request may classify.
max_input_image = 10

# Output components, appended to while the UI is built further down.
output_labels, output_images = [], []
def predict(models, img_url, img_files):
    """Classify each submitted image with each selected model.

    Args:
        models: Model names chosen in the dropdown. ``None`` or an empty
            list means no classification is run.
        img_url: Raw text from the URL textbox; it is handed to
            ``extract_image_urls`` to obtain the image URLs.
        img_files: Uploaded files, or ``None`` when nothing was uploaded.

    Returns:
        A dict keyed by output component (every entry of ``output_images``
        and ``output_labels``) whose values are the updates to apply.
        Components that are not used by this request are hidden.
    """
    print(f'model choosen: {models}')
    # Guard against nothing being selected so the loop below never
    # iterates over None.
    models = models or []

    # Start with every label and image hidden. Only visibility is touched
    # here: each label that is shown again below gets its own caption.
    model_predictions = {label: gr.Label(visible=False) for label in output_labels}
    model_predictions.update({img: gr.Image(visible=False) for img in output_images})

    # URLs first, then uploads; anything beyond the UI capacity is dropped
    # because there is no output slot to show it in.
    sources = (extract_image_urls(img_url) + (img_files or []))[:max_input_image]

    # The label grid holds len(model_names) consecutive slots per image.
    labels_per_image = len(model_names)
    for i, source in enumerate(sources):
        print(f'{i} type: {type(source)} --> {source}')
        for j, m in enumerate(models):
            results = clf.classify(m, source)
            print(f'{m} --> {results}')
            # j-th selected model goes into the j-th slot of image i's group.
            idx = labels_per_image * i + j
            label_value = {raw.class_name: raw.confidence for raw in results}
            # NOTE(review): "3 seconds" is a fixed caption, not a measured time.
            model_predictions[output_labels[idx]] = gr.Label(
                label=f'# {m}, 3 seconds', value=label_value, visible=True)
        # Reveal the image slot that belongs to this source.
        model_predictions[output_images[i]] = gr.Image(
            visible=True, value=source, label=f'image {i}')
    return model_predictions
# UI layout: inputs in the left column, one image slot followed by one label
# per model in the right column, repeated max_input_image times.
with gr.Blocks() as demo:
    gr.Markdown("# Image Classification Benchmark")
    # Keep the advertised limit in sync with the actual limit.
    gr.Markdown(f"You can input at maximum {max_input_image} images at once (urls or files)")
    with gr.Row():
        with gr.Column(scale=1):
            model = gr.Dropdown(choices=model_names, multiselect=True, label='Choose the model')
            img_urls = gr.Textbox(label='Image Urls (separated with comma)')
            img_files = gr.File(label='Upload Files', file_count='multiple', file_types=['image'])
            apply = gr.Button("Classify", variant='primary')
        with gr.Column(scale=1):
            for i in range(max_input_image):
                # Only the first image slot (and its labels) is visible
                # before any prediction has been made.
                output_images.append(gr.Image(interactive=False, visible=(i == 0)))
                # Build from the cached model_names so the number of labels
                # per image matches the len(model_names) stride used when
                # predictions are written back.
                for name in model_names:
                    output_labels.append(gr.Label(label=f'# {name}', visible=(i == 0)))
    # predict returns a dict keyed by component, so every image and label
    # it may update has to be listed as an output.
    apply.click(fn=predict,
                inputs=[model, img_urls, img_files],
                outputs=output_images + output_labels)

# Enable the request queue before launching the app.
demo.queue().launch()