File size: 2,611 Bytes
55e8762
c49a9ad
 
 
cebad5c
55e8762
c49a9ad
cebad5c
 
55e8762
1c3554e
c49a9ad
 
 
cebad5c
 
c49a9ad
cebad5c
c49a9ad
 
 
 
cebad5c
 
 
 
 
c49a9ad
cebad5c
 
 
 
 
 
 
 
 
 
 
 
 
c49a9ad
 
 
 
 
cebad5c
c49a9ad
 
 
 
 
 
 
 
cebad5c
 
 
 
c49a9ad
 
 
cebad5c
c49a9ad
 
74ffb90
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import random
import time

import gradio as gr
import requests

from src.classification_model import ClassificationModel
from src.util.extract import extract_image_urls

# Module-level state shared by the layout builder and the `predict` callback.
print('start...')
clf = ClassificationModel()
model_names = clf.get_model_names()

# Pools of output components. They are filled while the Blocks layout is
# built, then read back by `predict` to route each result to its slot.
output_labels, output_images = [], []

# Number of image rows in the output column; inputs beyond this are ignored.
max_input_image = 10

def predict(models, img_url, img_files):
    """Classify every input image with every selected model.

    All image and label slots are first reset to hidden. Then, for each input
    image (up to ``max_input_image``), its image slot is shown. Each selected
    model's predictions fill a label slot in that image's row, titled with
    the model name and the measured classification time.

    Args:
        models: Model names picked in the dropdown (None/empty is tolerated).
        img_url: Raw textbox text, parsed by ``extract_image_urls``.
        img_files: Uploaded files from the ``gr.File`` input, or None.

    Returns:
        Dict mapping each output component to its ``gr.Image``/``gr.Label``
        update, matching the ``outputs`` registered in ``apply.click``.
    """
    models = models or []  # guard against an empty/None selection
    print(f'model choosen: {models}')
    model_predictions = {}
    n_models = len(model_names)

    # Hide every label slot. The layout creates one row of n_models labels per
    # image, in model_names order, so slot k belongs to model k % n_models.
    for k, label in enumerate(output_labels):
        model_predictions[label] = gr.Label(label=f'# {model_names[k % n_models]}', visible=False)
    # Hide every image slot.
    for img in output_images:
        model_predictions[img] = gr.Image(visible=False)

    sources = extract_image_urls(img_url) + (img_files or [])
    for i, source in enumerate(sources):
        print(f'{i} type: {type(source)} --> {source}')
        if i >= max_input_image:
            break  # only max_input_image output rows exist

        # Show the input image once per row, not once per model.
        model_predictions[output_images[i]] = gr.Image(visible=True, value=source, label=f'image {i}')

        for j, m in enumerate(models):
            start = time.perf_counter()
            results = clf.classify(m, source)
            elapsed = time.perf_counter() - start
            print(f'{m} --> {results}')

            # Row i owns n_models label slots; the j-th chosen model uses column j.
            idx = j + n_models * i
            label_value = {raw.class_name: raw.confidence for raw in results}
            model_predictions[output_labels[idx]] = gr.Label(
                label=f'# {m}, {elapsed:.2f} seconds', value=label_value, visible=True
            )

    return model_predictions

with gr.Blocks() as demo:
    gr.Markdown("# Image Classification Benchmark")
    # Derive the advertised limit from the number of output rows so they cannot drift apart.
    gr.Markdown(f"You can input at maximum {max_input_image} images at once (urls or files)")

    with gr.Row():
        # Left column: model selection and image inputs.
        with gr.Column(scale=1):
            model = gr.Dropdown(choices=model_names, multiselect=True, label='Choose the model')
            img_urls = gr.Textbox(label='Image Urls (separated with comma)')
            img_files = gr.File(label='Upload Files', file_count='multiple', file_types=['image'])
            apply = gr.Button("Classify", variant='primary')
        # Right column: one row per possible input image, made of the image
        # followed by one label per available model. `predict` locates a
        # label as j + len(model_names) * i, so the rows must be built from
        # the same cached model_names list.
        with gr.Column(scale=1):
            for i in range(max_input_image):
                output_images.append(gr.Image(interactive=False, visible=(i == 0)))
                for name in model_names:
                    output_labels.append(gr.Label(label=f'# {name}', visible=(i == 0)))

    apply.click(fn=predict,
                inputs=[model, img_urls, img_files],
                outputs=output_images + output_labels)


if __name__ == "__main__":
    # Serve the app with Gradio's request queue enabled. The guard lets the
    # module be imported (e.g. to reuse `demo`) without starting a server.
    demo.queue().launch()