sashavor commited on
Commit
367e0cd
·
1 Parent(s): b444b38

final push for now

Browse files
Files changed (1) hide show
  1. app.py +15 -13
app.py CHANGED
@@ -32,7 +32,8 @@ def model_classify(radio, im):
32
  outputs = model(**inputs)
33
  logits = outputs.logits
34
  predicted_class_idx = logits.argmax(-1).item()
35
- return model.config.id2label[predicted_class_idx], predicted_class_idx, True
 
36
  else:
37
  return None, None, False
38
 
@@ -61,15 +62,14 @@ def check_score(pred, truth, current_score, total_score, has_guessed):
61
 
62
 
63
 
64
- def compare_score(userclass, model_class, truth, has_guessed):
65
- print(model_class)
66
- prediction= classes[int(model_class)]
67
- if userclass == classes[int(truth)] == prediction:
68
- return "Great! You and the model both got the correct answer"
69
- elif userclass == classes[int(truth)]:
70
- return "Great! You guessed it right"
71
- elif prediction == classes[int(truth)]:
72
- return "The AI model got it right this time, try again!"
73
 
74
  with gr.Blocks() as demo:
75
  user_score = gr.State(0)
@@ -80,7 +80,9 @@ with gr.Blocks() as demo:
80
  has_guessed = gr.State(False)
81
 
82
  gr.Markdown("# ImageNet Quiz")
83
- gr.Markdown("Try guessing the category of each image displayed, from the options provided below.")
 
 
84
  with gr.Row():
85
 
86
  with gr.Column(min_width= 900):
@@ -89,14 +91,14 @@ with gr.Blocks() as demo:
89
  with gr.Column():
90
  prediction = gr.Label(label="The AI model predicts:")
91
  score = gr.Label(label="Your Score")
92
- message = gr.Text(label="Who guessed it right?")
93
 
94
  btn = gr.Button("Next image")
95
 
96
  demo.load(random_image, None, [image, image_label, radio, prediction])
97
  radio.change(model_classify, [radio, image], [prediction, model_class, has_guessed])
98
  radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
99
- radio.change(compare_score, [radio, prediction, image_label, has_guessed], message)
100
  btn.click(random_image, None, [image, image_label, radio, prediction])
101
  btn.click(lambda :False, None, has_guessed)
102
 
 
32
  outputs = model(**inputs)
33
  logits = outputs.logits
34
  predicted_class_idx = logits.argmax(-1).item()
35
+ modelclass=model.config.id2label[predicted_class_idx]
36
+ return modelclass.split(',')[0], predicted_class_idx, True
37
  else:
38
  return None, None, False
39
 
 
62
 
63
 
64
 
65
+ def compare_score(userclass, truth, has_guessed):
66
+ if userclass is None:
67
+ return"Try guessing a category!"
68
+ else:
69
+ if userclass == classes[int(truth)]:
70
+ return "Great! You guessed it right"
71
+ else:
72
+ return "The right answer was " +str(classes[int(truth)])+ "! Try guessing the next image."
 
73
 
74
  with gr.Blocks() as demo:
75
  user_score = gr.State(0)
 
80
  has_guessed = gr.State(False)
81
 
82
  gr.Markdown("# ImageNet Quiz")
83
+ gr.Markdown("### ImageNet is one of the most popular datasets used for training and evaluating AI models.")
84
+ gr.Markdown("### But many of its categories are hard to guess, even for humans.")
85
+ gr.Markdown("#### Try your hand at guessing the category of each image displayed, from the options provided. Compare your answers to that of a neural network trained on the dataset, and see if you can do better!")
86
  with gr.Row():
87
 
88
  with gr.Column(min_width= 900):
 
91
  with gr.Column():
92
  prediction = gr.Label(label="The AI model predicts:")
93
  score = gr.Label(label="Your Score")
94
+ message = gr.Label(label="Did you guess it right?")
95
 
96
  btn = gr.Button("Next image")
97
 
98
  demo.load(random_image, None, [image, image_label, radio, prediction])
99
  radio.change(model_classify, [radio, image], [prediction, model_class, has_guessed])
100
  radio.change(check_score, [radio, image_label, user_score, total_score, has_guessed], [user_score, score, total_score])
101
+ radio.change(compare_score, [radio, image_label, has_guessed], message)
102
  btn.click(random_image, None, [image, image_label, radio, prediction])
103
  btn.click(lambda :False, None, has_guessed)
104