import random
import time

import gradio as gr
import requests

from src.classification_model import ClassificationModel
from src.util.extract import extract_image_urls

# only for dummy data
# response = requests.get("https://git.io/JJkYN")
# labels = response.text.split("\n")

print('start...')
clf = ClassificationModel()
model_names = clf.get_model_names()

# Output components, populated while the UI is built below.
# output_images[i] displays input image i. output_labels is laid out
# row-major: one row of len(model_names) labels per input image, so the
# slot for image i and model column j is output_labels[j + len(model_names) * i].
output_labels = []
output_images = []
max_input_image = 10


def predict(models, img_url, img_files):
    """Classify every input image with every selected model.

    Args:
        models: model names selected in the dropdown; may be None or empty,
            in which case only the input images are shown.
        img_url: raw text from the URL textbox, passed to extract_image_urls.
        img_files: files from the upload component, or None.

    Returns:
        A dict mapping each output component to its updated state. Every
        component is hidden first; only the slots that receive an image or
        a classification result are made visible again.
    """
    print(f'model choosen: {models}')
    models = models or []  # dropdown yields None/[] when nothing is selected
    model_predictions = {}

    # Hide every label and image so results from a previous run don't linger.
    # Only visibility is updated here; the old code relabelled hidden slots
    # via a global leaked from the UI-building loop.
    for label in output_labels:
        model_predictions[label] = gr.Label(visible=False)
    for img in output_images:
        model_predictions[img] = gr.Image(visible=False)

    sources = extract_image_urls(img_url) + (img_files or [])
    # Only max_input_image output rows exist; extra sources are ignored.
    for i, source in enumerate(sources[:max_input_image]):
        print(f'{i} type: {type(source)} --> {source}')
        for j, m in enumerate(models):
            start = time.perf_counter()
            results = clf.classify(m, source)
            elapsed = time.perf_counter() - start
            print(f'{m} --> {results}')
            idx = j + (len(model_names) * i)  # label slot for (image i, model column j)
            label_value = {raw.class_name: raw.confidence for raw in results}
            model_predictions[output_labels[idx]] = gr.Label(
                label=f'# {m}, {elapsed:.2f} seconds',
                value=label_value,
                visible=True,
            )
        # Show the input image for this row.
        model_predictions[output_images[i]] = gr.Image(visible=True, value=source, label=f'image {i}')
    return model_predictions


with gr.Blocks() as demo:
    gr.Markdown("# Image Classification Benchmark")
    gr.Markdown(f"You can input at maximum {max_input_image} images at once (urls or files)")
    with gr.Row():
        with gr.Column(scale=1):
            model = gr.Dropdown(choices=model_names, multiselect=True, label='Choose the model')
            img_urls = gr.Textbox(label='Image Urls (separated with comma)')
            img_files = gr.File(label='Upload Files', file_count='multiple', file_types=['image'])
            apply = gr.Button("Classify", variant='primary')
        with gr.Column(scale=1):
            # Pre-create every output slot (one image plus one label per model,
            # per input row). Only the first row starts visible; predict()
            # toggles visibility for the slots it fills.
            for i in range(max_input_image):
                output_images.append(gr.Image(interactive=False, visible=(i == 0)))
                # Reuse model_names so the row width matches the index math in predict().
                for name in model_names:
                    output_labels.append(gr.Label(label=f'# {name}', visible=(i == 0)))
    apply.click(fn=predict, inputs=[model, img_urls, img_files], outputs=output_images + output_labels)

# demo.launch()
demo.queue().launch()