Spaces:
Build error
Build error
| import gradio as gr | |
| from datasets import load_dataset | |
| from PIL import Image | |
| from collections import OrderedDict | |
| from random import sample | |
| import csv | |
| from transformers import AutoFeatureExtractor, AutoModelForImageClassification | |
| import random | |
# Pretrained ViT checkpoint used both to preprocess quiz images and to classify them.
_VIT_CHECKPOINT = "google/vit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(_VIT_CHECKPOINT)
model = AutoModelForImageClassification.from_pretrained(_VIT_CHECKPOINT)
# Map each synset ID (the text before the first space on a line) to its primary
# human-readable name (the text after that space, up to the first comma).
# Entries keep file order, so -- assuming synset IDs are unique -- classes[i] is
# the name from the i-th non-blank line of the mapping file.
classdict = OrderedDict()
with open('LOC_synset_mapping.txt', 'r', encoding='utf-8') as synset_file:
    for line in synset_file:
        line = line.rstrip('\n')
        if not line:
            # Skip blank lines instead of registering an empty '' -> '' entry.
            continue
        synset_id, _, names = line.partition(' ')
        classdict[synset_id] = names.split(',')[0]
classes = list(classdict.values())
# Map image file stem -> ground-truth label string. Elsewhere in this file the
# label is passed through int() and used to index `classes`, so it is presumably
# a class index -- confirm against image_labels.csv.
# newline='' is what the csv module expects for files handed to its readers.
with open('image_labels.csv', 'r', newline='', encoding='utf-8') as csv_file:
    imagedict = {
        row['image_name']: row['image_label']
        for row in csv.DictReader(csv_file)
    }
images = list(imagedict.keys())
# Distinct labels: the pool that distractor options are drawn from.
labels = list(set(imagedict.values()))
def model_classify(radio, im):
    """Classify the displayed image with the ViT model once the user has picked an option.

    Args:
        radio: The option currently selected in the quiz radio, or None when
            nothing is selected (e.g. right after a new image is shown).
        im: The displayed image from the ``gr.Image`` component (presumably a
            numpy array, Gradio's default -- confirm), or None if absent.

    Returns:
        ``(label, class_idx, has_guessed)``: the model's top-1 label (first
        comma-separated name only), its class index, and whether the user has
        made a guess. ``(None, None, False)`` when nothing is selected, and
        ``(None, None, True)`` when an option is selected but no image exists.
    """
    if radio is None:
        return None, None, False
    if im is None:
        # The user guessed, but there is nothing to run the model on.
        return None, None, True

    import torch  # only needed on the inference path

    inputs = feature_extractor(images=im, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    modelclass = model.config.id2label[predicted_class_idx]
    return modelclass.split(',')[0], predicted_class_idx, True
def random_image():
    """Pick a random quiz image and build four shuffled answer options for it.

    Returns:
        A 4-tuple for the ``[image, image_label, radio, prediction]`` outputs:
        the loaded PIL image, its ground-truth label (a string index into
        ``classes``), a radio update that clears the selection and offers the
        option names, and ``None`` to clear the model prediction.
    """
    imname = random.choice(images)
    # Copy the pixels out and close the file right away; a bare Image.open keeps
    # the handle open until the image object is garbage collected.
    with Image.open('images/' + imname + '.jpg') as img:
        im = img.copy()
    label = str(imagedict[imname])
    # Draw distractors from a filtered copy. Removing `label` from the shared
    # module-level `labels` list in place would shrink it on every call and
    # raise ValueError as soon as a label comes up a second time.
    distractors = [other for other in labels if other != label]
    # Three wrong answers (fewer only if the dataset has under four labels).
    options = sample(distractors, min(3, len(distractors)))
    options.append(label)
    random.shuffle(options)
    option_names = [classes[int(idx)] for idx in options]
    return im, label, gr.Radio.update(value=None, choices=option_names), None
def check_score(pred, truth, current_score, total_score, has_guessed):
    """Update the user's running score after the radio selection changes.

    Only a non-empty selection made before the current image was guessed is
    scored. When ``has_guessed`` is already set, or the selection was cleared
    (``pred is None``), both counts are returned unchanged and ``truth`` is not
    consulted (so an unset ``truth`` of None is harmless there).

    Args:
        pred: Option name the user selected, or None if the selection is empty.
        truth: Ground-truth label of the image (a string index into ``classes``).
        current_score: Correct guesses so far.
        total_score: Scored guesses so far.
        has_guessed: Whether the current image has already been guessed.

    Returns:
        ``(current_score, message, total_score)`` with the updated counts and a
        human-readable score message.
    """
    if not has_guessed and pred is not None:
        total_score += 1
        if pred == classes[int(truth)]:
            current_score += 1
    message = f"Your score is {current_score} out of {total_score}!"
    return current_score, message, total_score
def compare_score(userclass, truth):
    """Return feedback text comparing the user's pick with the correct answer.

    Args:
        userclass: Option name the user selected, or None if nothing is selected.
        truth: Ground-truth label of the image (a string index into ``classes``).

    Returns:
        A prompt to guess when nothing is selected; otherwise a confirmation or
        a message naming the correct category.
    """
    if userclass is None:
        return "Try guessing a category!"
    answer = classes[int(truth)]
    if userclass != answer:
        return "The right answer was " + str(answer) + "! Try guessing the next image."
    return "Great! You guessed it right"
with gr.Blocks() as demo:
    # Per-session values that the handlers below read as inputs and overwrite
    # as outputs.
    user_score = gr.State(0)       # correct guesses so far
    image_label = gr.State()       # ground-truth label of the displayed image
    model_class = gr.State()       # model's predicted class index (written, never read)
    total_score = gr.State(0)      # scored guesses so far
    has_guessed = gr.State(False)  # set once the displayed image has been guessed

    gr.Markdown("# ImageNet Quiz")
    gr.Markdown("### ImageNet is one of the most popular datasets used for training and evaluating AI models.")
    gr.Markdown("### But many of its categories are hard to guess, even for humans.")
    gr.Markdown("#### Try your hand at guessing the category of each image displayed, from the options provided. Compare your answers to that of a neural network trained on the dataset, and see if you can do better!")
    with gr.Row():
        # Left column: the quiz image and the answer options.
        with gr.Column(min_width=900):
            image = gr.Image(shape=(600, 600))
            # Placeholder choices; random_image replaces them when the page loads.
            radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
        # Right column: model prediction, running score, per-guess feedback.
        with gr.Column():
            prediction = gr.Label(label="The AI model predicts:")
            score = gr.Label(label="Your Score")
            message = gr.Label(label="Did you guess it right?")
            btn = gr.Button("Next image")

    # Show the first image and its options as soon as the page loads.
    demo.load(random_image, None, [image, image_label, radio, prediction])

    # Each selection change triggers three independent listeners (presumably
    # also when random_image resets the radio to None).
    # NOTE(review): scoring only the first pick per image relies on check_score
    # receiving has_guessed as it was before model_classify sets it to True;
    # confirm this holds under Gradio's event scheduling.
    radio.change(model_classify, [radio, image], [prediction, model_class, has_guessed])
    radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
    radio.change(compare_score, [radio, image_label], message)

    # "Next image": load a new image and options, and let the next pick be scored.
    btn.click(random_image, None, [image, image_label, radio, prediction])
    btn.click(lambda: False, None, has_guessed)
demo.launch()