sashavor commited on
Commit
a0a2e9f
1 Parent(s): 9e1d8e2

trying total score

Browse files
Files changed (1) hide show
  1. app.py +19 -15
app.py CHANGED
@@ -31,18 +31,19 @@ with open('image_labels.csv', 'r') as csv_file:
31
  images= list(imagedict.keys())
32
  labels = list(set(imagedict.values()))
33
 
34
- def model_classify(im):
35
- inputs = feature_extractor(images=im, return_tensors="pt")
36
- outputs = model(**inputs)
37
- logits = outputs.logits
38
- predicted_class_idx = logits.argmax(-1).item()
39
- return model.config.id2label[predicted_class_idx]
40
-
41
 
42
  def random_image():
43
  imname = random.choice(images)
44
  im = Image.open('images/'+ imname +'.jpg')
45
  label = str(imagedict[imname])
 
46
  labels.remove(label)
47
  options = sample(labels,3)
48
  options.append(label)
@@ -50,14 +51,16 @@ def random_image():
50
  options = [classes[int(i)] for i in options]
51
  return im, label, gr.Radio.update(value=None, choices=options), None
52
 
53
- def check_score(pred, truth, current_score):
54
  if pred == classes[int(truth)]:
55
- return current_score + 1, f"Your score is {current_score+1}"
56
- return current_score, f"Your score is {current_score}"
 
 
 
 
57
 
58
  def compare_score(userclass, prediction):
59
- print(userclass)
60
- print(prediction)
61
  if userclass == str(prediction).split(',')[0]:
62
  return "Great! You and the model agree on the category"
63
  return "You and the model disagree"
@@ -67,10 +70,11 @@ with gr.Blocks() as demo:
67
  model_score = gr.State(0)
68
  image_label = gr.State()
69
  prediction = gr.State()
 
70
 
71
  with gr.Row():
72
  with gr.Column(min_width= 900):
73
- image = gr.Image(shape=(224, 224))
74
  radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
75
  with gr.Column():
76
  prediction = gr.Label(label="Model Prediction")
@@ -80,8 +84,8 @@ with gr.Blocks() as demo:
80
  btn = gr.Button("Next image")
81
 
82
  demo.load(random_image, None, [image, image_label, radio, prediction])
83
- radio.change(model_classify, image, prediction)
84
- radio.change(check_score, [radio, image_label, user_score], [user_score, score])
85
  #radio.change(compare_score, [radio, prediction], message)
86
  btn.click(random_image, None, [image, image_label, radio, prediction])
87
 
 
31
  images= list(imagedict.keys())
32
  labels = list(set(imagedict.values()))
33
 
34
+ def model_classify(radio, im):
35
+ if radio is not None:
36
+ inputs = feature_extractor(images=im, return_tensors="pt")
37
+ outputs = model(**inputs)
38
+ logits = outputs.logits
39
+ predicted_class_idx = logits.argmax(-1).item()
40
+ return model.config.id2label[predicted_class_idx]
41
 
42
  def random_image():
43
  imname = random.choice(images)
44
  im = Image.open('images/'+ imname +'.jpg')
45
  label = str(imagedict[imname])
46
+ print(label)
47
  labels.remove(label)
48
  options = sample(labels,3)
49
  options.append(label)
 
51
  options = [classes[int(i)] for i in options]
52
  return im, label, gr.Radio.update(value=None, choices=options), None
53
 
54
+ def check_score(pred, truth, current_score, total_score):
55
  if pred == classes[int(truth)]:
56
+ total_score +=1
57
+ return current_score + 1, f"Your score is {current_score+1} out of {total_score}"
58
+ else:
59
+ total_score +=1
60
+ return current_score, f"Your score is {current_score} out of {total_score}"
61
+
62
 
63
  def compare_score(userclass, prediction):
 
 
64
  if userclass == str(prediction).split(',')[0]:
65
  return "Great! You and the model agree on the category"
66
  return "You and the model disagree"
 
70
  model_score = gr.State(0)
71
  image_label = gr.State()
72
  prediction = gr.State()
73
+ total_score = gr.State(0)
74
 
75
  with gr.Row():
76
  with gr.Column(min_width= 900):
77
+ image = gr.Image(shape=(600, 600))
78
  radio = gr.Radio(["option1", "option2", "option3"], label="Pick a category", interactive=True)
79
  with gr.Column():
80
  prediction = gr.Label(label="Model Prediction")
 
84
  btn = gr.Button("Next image")
85
 
86
  demo.load(random_image, None, [image, image_label, radio, prediction])
87
+ radio.change(model_classify, [radio, image], prediction)
88
+ radio.change(check_score, [radio, image_label, user_score, total_score], [user_score, score])
89
  #radio.change(compare_score, [radio, prediction], message)
90
  btn.click(random_image, None, [image, image_label, radio, prediction])
91