File size: 3,722 Bytes
f37cfad
 
 
 
 
 
0490df4
27938aa
0490df4
27938aa
0490df4
f37cfad
 
 
 
 
 
 
 
 
 
27938aa
f37cfad
 
27938aa
f37cfad
 
 
 
 
27938aa
 
f37cfad
a0a2e9f
 
 
 
 
 
f2c12ed
 
 
f37cfad
27938aa
 
 
 
a0a2e9f
27938aa
 
 
 
 
dc0eec0
820d8a7
f2c12ed
 
 
 
 
 
 
 
 
a0a2e9f
f2c12ed
 
a0a2e9f
f37cfad
27938aa
 
 
 
f37cfad
 
27938aa
 
 
 
a0a2e9f
f2c12ed
27938aa
f37cfad
9e1d8e2
a0a2e9f
27938aa
 
 
 
9e1d8e2
27938aa
 
f37cfad
27938aa
f2c12ed
 
9e1d8e2
27938aa
f2c12ed
f37cfad
dc0eec0
f37cfad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
from datasets import load_dataset
from PIL import Image
from collections import OrderedDict
from random import sample
import csv
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import random

# Image preprocessor and classifier are loaded from one shared checkpoint name so
# the preprocessing always matches the model it feeds.
_CHECKPOINT = "google/vit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(_CHECKPOINT)


title = "ImageNet Roulette"
# Implicit string concatenation instead of a backslash continuation inside the
# literal: the continuation used to embed the next line's indentation (runs of
# spaces) into the text itself.
description = (
    "Try guessing the category of each image displayed, from the options provided below. "
    "After 10 guesses, we will show you your accuracy!"
)

# Map the first space-separated token of each line of LOC_synset_mapping.txt
# (presumably a WordNet synset id) to the first comma-separated name after it,
# e.g. "n01440764 tench, Tinca tinca" -> {"n01440764": "tench"}.
# ``classes`` keeps file order; other code indexes it with the integer labels
# from image_labels.csv.
classdict = OrderedDict()
with open('LOC_synset_mapping.txt', 'r', encoding='utf-8') as synset_file:
    for line in synset_file:
        line = line.rstrip('\n')
        if not line.strip():
            # A blank line would otherwise add a bogus empty class name.
            continue
        synset_id, _, names = line.partition(' ')
        classdict[synset_id] = names.split(',')[0]
classes = list(classdict.values())
# image_labels.csv maps each image file stem (``image_name``) to its label
# (``image_label``); random_image later uses that label as an index into ``classes``.
with open('image_labels.csv', 'r') as labels_file:
    imagedict = {
        record['image_name']: record['image_label']
        for record in csv.DictReader(labels_file)
    }
# Candidate images to show, and the distinct labels used to draw distractors.
images = list(imagedict)
labels = list(set(imagedict.values()))

def model_classify(radio, im):
    """Run the classifier on the displayed image once the user has picked an option.

    Args:
        radio: The option currently selected in the "Pick a category" radio, or
            ``None`` when nothing is selected (e.g. right after a new image loads).
        im: The image currently shown by the ``gr.Image`` component.

    Returns:
        A ``(label, guessed)`` pair. When a selection exists, ``label`` is the
        model's top-1 label from ``model.config.id2label`` and ``guessed`` is
        ``True``. Otherwise, including when no image is loaded, it is ``(None, False)``.
    """
    if radio is None or im is None:
        return None, False
    import torch  # local import keeps the module's import block unchanged
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        inputs = feature_extractor(images=im, return_tensors="pt")
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx], True

def random_image():
    """Pick a random image and build a shuffled multiple-choice question for it.

    Returns:
        A 4-tuple matching the ``[image, image_label, radio, prediction]`` outputs:
        the opened PIL image; the image's label string (an index into ``classes``);
        a radio update with no selection, offering the correct class name plus
        three distractors in random order; and ``None`` to clear the model prediction.
    """
    imname = random.choice(images)
    im = Image.open(f"images/{imname}.jpg")
    label = str(imagedict[imname])
    # Filter into a new list instead of calling labels.remove(label): mutating the
    # shared module-level list shrank it on every call and raised ValueError as
    # soon as an image's label had already been removed by an earlier call.
    distractors = [candidate for candidate in labels if candidate != label]
    options = sample(distractors, 3)
    options.append(label)
    random.shuffle(options)
    choices = [classes[int(i)] for i in options]
    return im, label, gr.Radio.update(value=None, choices=choices), None

def check_score(pred, truth, current_score, total_score, has_guessed):
    """Update the user's running score after a radio selection.

    Only the first selection for an image counts. Once ``has_guessed`` is set, and
    when the radio was merely cleared (``pred is None``), the score is unchanged.

    Args:
        pred: Class name selected in the radio, or ``None``.
        truth: The current image's label as produced by ``random_image``: the
            ``image_label`` value from image_labels.csv, an index into ``classes``.
        current_score: Number of correct guesses so far.
        total_score: Number of counted guesses so far.
        has_guessed: Whether a guess was already recorded for this image.

    Returns:
        ``(current_score, message, total_score)`` with the updated counts and a
        human-readable score message.
    """
    if not has_guessed and pred is not None:
        total_score += 1
        # The radio shows class names while ``truth`` is an index into ``classes``;
        # random_image maps options with the same classes[int(...)] lookup.
        if truth is not None and pred == classes[int(truth)]:
            current_score += 1
    return current_score, f"Your score is {current_score} out of {total_score}", total_score
        


def compare_score(userclass, prediction):
    """Report whether the user's pick matches the model's prediction.

    Only the text before the first comma of ``prediction`` (after ``str()``
    conversion) is compared, so "tabby, tabby cat" matches "tabby".
    """
    model_name = str(prediction).partition(',')[0]
    if userclass != model_name:
        return "You and the model disagree"
    return "Great! You and the model agree on the category"

with gr.Blocks() as demo:
    # Per-session state. (A gr.State named ``prediction`` was shadowed by the
    # Label below before any use, and ``model_score`` was never read; both removed.)
    user_score = gr.State(0)        # correct guesses so far, updated by check_score
    image_label = gr.State()        # label of the current image, set by random_image
    total_score = gr.State(0)       # counted guesses so far, updated by check_score
    has_guessed = gr.State(False)   # set by model_classify, reset on "Next image"

    with gr.Row():
        with gr.Column(min_width=900):
            image = gr.Image(shape=(600, 600))
            radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
        with gr.Column():
            prediction = gr.Label(label="Model Prediction")
            score = gr.Label(label="Your Score")

    btn = gr.Button("Next image")

    # Show a first image on page load; random_image also resets the radio and
    # clears the model prediction.
    demo.load(random_image, None, [image, image_label, radio, prediction])
    # A selection triggers both the model's prediction and the score update.
    radio.change(model_classify, [radio, image], [prediction, has_guessed])
    radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
    # NOTE: compare_score is defined but not wired to any event.
    btn.click(random_image, None, [image, image_label, radio, prediction])
    btn.click(lambda: False, None, has_guessed)


demo.launch()