"""ImageNet Roulette: a small Gradio guessing game.

The app shows a random image together with four candidate categories (the
true one plus three distractors).  The user picks a category, a ViT
classifier's prediction is displayed alongside, and a running score is kept
per session.
"""

import csv
import random
from collections import OrderedDict
from random import sample

import gradio as gr
from datasets import load_dataset
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")

title="ImageNet Roulette"
description="Try guessing the category of each image displayed, from the options provided below.\
After 10 guesses, we will show you your accuracy!\
"

# Synset id -> first human-readable name, in file order.  The position of an
# entry in `classes` is used below as the class index.
classdict = OrderedDict()
with open('LOC_synset_mapping.txt', 'r') as mapping_file:
    for line in mapping_file:
        # Everything before the first space is the synset id; the remainder is
        # a comma-separated list of names, of which only the first is kept.
        synset_id, _, names = line.partition(' ')
        classdict[synset_id] = names.replace('\n', '').split(',')[0]
classes = list(classdict.values())

# Image file stem -> label, as read from the CSV (labels are kept as strings).
imagedict = {}
with open('image_labels.csv', 'r', newline='') as csv_file:
    reader = csv.DictReader(csv_file)
    for row in reader:
        imagedict[row['image_name']] = row['image_label']
images = list(imagedict.keys())
labels = list(set(imagedict.values()))


def model_classify(radio, im):
    """Classify ``im`` with the ViT model once the user has picked an option.

    Returns ``(predicted_label, True)`` when ``radio`` holds a selection, and
    ``(None, False)`` when the radio has been cleared (e.g. on a new image).
    The second element feeds the ``has_guessed`` state.
    """
    if radio is None:
        return None, False
    inputs = feature_extractor(images=im, return_tensors="pt")
    outputs = model(**inputs)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx], True


def random_image():
    """Pick a random image and build a shuffled list of answer options.

    Returns the PIL image, its label (string index into ``classes``), a radio
    update carrying the true class name plus up to three distractors, and
    ``None`` to clear the model-prediction display.
    """
    imname = random.choice(images)
    im = Image.open('images/' + imname + '.jpg')
    label = str(imagedict[imname])
    print(label)
    # Build the distractor pool per call instead of removing the label from
    # the shared module-level list: mutating `labels` shrank the pool on every
    # draw and raised ValueError the second time the same label came up.
    distractor_pool = [candidate for candidate in labels if candidate != label]
    options = sample(distractor_pool, min(3, len(distractor_pool)))
    options.append(label)
    random.shuffle(options)
    options = [classes[int(i)] for i in options]
    return im, label, gr.Radio.update(value=None, choices=options), None


def check_score(pred, truth, current_score, total_score, has_guessed):
    """Update the user's score after a radio selection.

    ``pred`` is the class name chosen in the radio (or ``None`` when the radio
    was cleared); ``truth`` is the image label as produced by ``random_image``
    (a string index into ``classes``).  Only the first guess per image counts:
    when ``has_guessed`` is already true the score is returned unchanged.

    Returns ``(current_score, message, total_score)``.
    """
    # NOTE(review): this handler and model_classify are both bound to
    # radio.change, so the value of has_guessed seen here depends on which
    # handler the framework runs first - confirm the intended ordering.
    if not has_guessed and pred is not None:
        total_score += 1
        # Compare against the true class name rather than a hard-coded
        # category; the radio options are class names, truth is an index.
        if truth is not None and pred == classes[int(truth)]:
            current_score += 1
    return current_score, f"Your score is {current_score} out of {total_score}", total_score


def compare_score(userclass, prediction):
    """Return a message saying whether the user's pick matches the model's.

    Only the first comma-separated name of ``prediction`` is compared.
    """
    if userclass == str(prediction).split(',')[0]:
        return "Great! You and the model agree on the category"
    return "You and the model disagree"


with gr.Blocks() as demo:
    user_score = gr.State(0)
    model_score = gr.State(0)  # not wired to any event in this file
    image_label = gr.State()
    total_score = gr.State(0)
    has_guessed = gr.State(False)

    with gr.Row():
        with gr.Column(min_width=900):
            image = gr.Image(shape=(600, 600))
            radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
        with gr.Column():
            prediction = gr.Label(label="Model Prediction")
            score = gr.Label(label="Your Score")
            #message = gr.Text()
            btn = gr.Button("Next image")

    demo.load(random_image, None, [image, image_label, radio, prediction])
    radio.change(model_classify, [radio, image], [prediction, has_guessed])
    radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
    #radio.change(compare_score, [radio, prediction], message)
    btn.click(random_image, None, [image, image_label, radio, prediction])
    # A new image means the user has not guessed yet.
    btn.click(lambda: False, None, has_guessed)

demo.launch()