da03 commited on
Commit
7c08a74
·
1 Parent(s): 85fafa0
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -54,8 +54,11 @@ def predict_product(num1, num2):
54
  predicted_annotations_per_model = {}
55
  for step in range(max(MAX_PRODUCT_DIGITS_PER_MODEL.values())): # Set a maximum limit to prevent infinite loops
56
  # Ground Truth
57
- ground_truth_annotations = [(ground_truth_digit, None) for ground_truth_digit in ground_truth_digits_reversed[:step+1]]
58
- ground_truth_annotations = ground_truth_annotations[::-1]
 
 
 
59
  # Predicted
60
  for model_name in models:
61
  model = models[model_name]
@@ -101,7 +104,9 @@ def predict_product(num1, num2):
101
 
102
  for i in range(len(predicted_digits_reversed)):
103
  predicted_digit = predicted_digits_reversed[i]
104
- if i >= len(ground_truth_digits_reversed):
 
 
105
  if predicted_digit == '0' and is_correct_sofar:
106
  is_correct_digit = True
107
  else:
@@ -114,7 +119,9 @@ def predict_product(num1, num2):
114
  is_correct_digit = False
115
  if not is_correct_digit:
116
  is_correct_sofar = False
117
- if is_correct_digit:
 
 
118
  predicted_annotations.append((predicted_digit, "correct"))
119
  else:
120
  predicted_annotations.append((predicted_digit, "wrong"))
 
54
  predicted_annotations_per_model = {}
55
  for step in range(max(MAX_PRODUCT_DIGITS_PER_MODEL.values())): # Set a maximum limit to prevent infinite loops
56
  # Ground Truth
57
+ if not valid_input:
58
+ ground_truth_annotations = []
59
+ else:
60
+ ground_truth_annotations = [(ground_truth_digit, None) for ground_truth_digit in ground_truth_digits_reversed[:step+1]]
61
+ ground_truth_annotations = ground_truth_annotations[::-1]
62
  # Predicted
63
  for model_name in models:
64
  model = models[model_name]
 
104
 
105
  for i in range(len(predicted_digits_reversed)):
106
  predicted_digit = predicted_digits_reversed[i]
107
+ if not valid_input:
108
+ is_correct_digit = None
109
+ elif i >= len(ground_truth_digits_reversed):
110
  if predicted_digit == '0' and is_correct_sofar:
111
  is_correct_digit = True
112
  else:
 
119
  is_correct_digit = False
120
  if not is_correct_digit:
121
  is_correct_sofar = False
122
+ if is_correct_digit is None:
123
+ predicted_annotations.append((predicted_digit, None))
124
+ elif is_correct_digit:
125
  predicted_annotations.append((predicted_digit, "correct"))
126
  else:
127
  predicted_annotations.append((predicted_digit, "wrong"))