da03 commited on
Commit
e5d31e2
1 Parent(s): 313c68d
Files changed (1) hide show
  1. app.py +83 -80
app.py CHANGED
@@ -56,90 +56,93 @@ def predict_product(num1, num2):
56
  finished_per_model = {model_name: False for model_name in models}
57
  past_key_values_per_model = {model_name: None for model_name in models}
58
  predicted_annotations_per_model = {}
59
- for step in range(max(MAX_PRODUCT_DIGITS_PER_MODEL.values())): # Set a maximum limit to prevent infinite loops
60
- # Ground Truth
61
- if not valid_input:
62
- ground_truth_annotations = [('Invalid Input!', None)]
63
- else:
64
- ground_truth_annotations = [(ground_truth_digit, None) for ground_truth_digit in ground_truth_digits_reversed[:step+1]]
65
- ground_truth_annotations = ground_truth_annotations[::-1]
66
- # Predicted
67
- for model_name in models:
68
- model = models[model_name]
69
- if finished_per_model[model_name]:
70
- continue
71
- if step >= MAX_PRODUCT_DIGITS_PER_MODEL[model_name]:
72
- continue
73
- generation_kwargs = {
74
- 'input_ids': generated_ids_per_model[model_name],
75
- 'max_new_tokens': 1,
76
- 'do_sample': False,
77
- 'past_key_values': past_key_values_per_model[model_name],
78
- 'return_dict_in_generate': True,
79
- 'use_cache': True
80
- }
81
- if step == 0:
82
- del generation_kwargs['past_key_values']
83
- outputs = model.generate(**generation_kwargs)
84
- generated_ids = outputs.sequences
85
- next_token_id = generated_ids[0, -1]
86
- #print (next_token_id)
87
-
88
- if next_token_id.item() == tokenizer.eos_token_id:
89
- finished_per_model[model_name] = True
90
- if valid_input:
91
- if len([item for item in predicted_annotations_per_model[model_name] if item[1] is not None]) < len(ground_truth_digits_reversed):
92
- predicted_annotations_per_model[model_name].insert(0, ('⠀', 'wrong'))
93
- continue
 
 
 
 
94
 
95
- generated_ids_per_model[model_name] = generated_ids
96
- past_key_values_per_model[model_name] = outputs.past_key_values
97
 
98
- output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
99
- predicted_digits_reversed = output_text.strip().split(' ')
100
-
101
- predicted_annotations = []
102
- is_correct_sofar = True
103
- if model_name == 'explicit':
104
- if '=' not in predicted_digits_reversed:
105
- predicted_annotations = [(predicted_digit, None) for predicted_digit in predicted_digits_reversed]
106
- predicted_digits_reversed = []
107
- else:
108
- equal_sign_position = predicted_digits_reversed.index('=')
109
- predicted_annotations = [(predicted_digit, None) for predicted_digit in predicted_digits_reversed[:equal_sign_position+1]]
110
- predicted_digits_reversed = predicted_digits_reversed[equal_sign_position+1:]
111
-
112
- for i in range(len(predicted_digits_reversed)):
113
- predicted_digit = predicted_digits_reversed[i]
114
- if not valid_input:
115
- is_correct_digit = None
116
- elif i >= len(ground_truth_digits_reversed):
117
- if predicted_digit == '0' and is_correct_sofar:
118
- is_correct_digit = True
119
  else:
120
- is_correct_digit = False
121
- else:
122
- ground_truth_digit = ground_truth_digits_reversed[i]
123
- if predicted_digit == ground_truth_digit:
124
- is_correct_digit = True
 
 
 
 
 
 
 
 
125
  else:
126
- is_correct_digit = False
127
- if not is_correct_digit:
128
- is_correct_sofar = False
129
- if is_correct_digit is None:
130
- predicted_annotations.append((predicted_digit, None))
131
- elif is_correct_digit:
132
- predicted_annotations.append((predicted_digit, "correct"))
133
- else:
134
- predicted_annotations.append((predicted_digit, "wrong"))
135
- predicted_annotations = predicted_annotations[::-1]
136
- predicted_annotations_per_model[model_name] = predicted_annotations
137
-
138
- predicted_annotations_implicit_cot = predicted_annotations_per_model['implicit']
139
- predicted_annotations_nocot = predicted_annotations_per_model['no']
140
- predicted_annotations_explicit_cot = predicted_annotations_per_model['explicit']
141
-
142
- yield ground_truth_annotations, predicted_annotations_implicit_cot, predicted_annotations_nocot, predicted_annotations_explicit_cot
 
 
 
 
 
 
143
 
144
  color_map = {"correct": "green", "wrong": "red"}
145
 
 
56
  finished_per_model = {model_name: False for model_name in models}
57
  past_key_values_per_model = {model_name: None for model_name in models}
58
  predicted_annotations_per_model = {}
59
+ try:
60
+ for step in range(max(MAX_PRODUCT_DIGITS_PER_MODEL.values())): # Set a maximum limit to prevent infinite loops
61
+ # Ground Truth
62
+ if not valid_input:
63
+ ground_truth_annotations = [('Invalid Input!', None)]
64
+ else:
65
+ ground_truth_annotations = [(ground_truth_digit, None) for ground_truth_digit in ground_truth_digits_reversed[:step+1]]
66
+ ground_truth_annotations = ground_truth_annotations[::-1]
67
+ # Predicted
68
+ for model_name in models:
69
+ model = models[model_name]
70
+ if finished_per_model[model_name]:
71
+ continue
72
+ if step >= MAX_PRODUCT_DIGITS_PER_MODEL[model_name]:
73
+ continue
74
+ generation_kwargs = {
75
+ 'input_ids': generated_ids_per_model[model_name],
76
+ 'max_new_tokens': 1,
77
+ 'do_sample': False,
78
+ 'past_key_values': past_key_values_per_model[model_name],
79
+ 'return_dict_in_generate': True,
80
+ 'use_cache': True
81
+ }
82
+ if step == 0:
83
+ del generation_kwargs['past_key_values']
84
+ outputs = model.generate(**generation_kwargs)
85
+ generated_ids = outputs.sequences
86
+ next_token_id = generated_ids[0, -1]
87
+ #print (next_token_id)
88
+
89
+ if next_token_id.item() == tokenizer.eos_token_id:
90
+ finished_per_model[model_name] = True
91
+ if valid_input:
92
+ if len([item for item in predicted_annotations_per_model[model_name] if item[1] is not None]) < len(ground_truth_digits_reversed):
93
+ predicted_annotations_per_model[model_name].insert(0, ('⠀', 'wrong'))
94
+ continue
95
+
96
+ generated_ids_per_model[model_name] = generated_ids
97
+ past_key_values_per_model[model_name] = outputs.past_key_values
98
 
99
+ output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
100
+ predicted_digits_reversed = output_text.strip().split(' ')
101
 
102
+ predicted_annotations = []
103
+ is_correct_sofar = True
104
+ if model_name == 'explicit':
105
+ if '=' not in predicted_digits_reversed:
106
+ predicted_annotations = [(predicted_digit, None) for predicted_digit in predicted_digits_reversed]
107
+ predicted_digits_reversed = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  else:
109
+ equal_sign_position = predicted_digits_reversed.index('=')
110
+ predicted_annotations = [(predicted_digit, None) for predicted_digit in predicted_digits_reversed[:equal_sign_position+1]]
111
+ predicted_digits_reversed = predicted_digits_reversed[equal_sign_position+1:]
112
+
113
+ for i in range(len(predicted_digits_reversed)):
114
+ predicted_digit = predicted_digits_reversed[i]
115
+ if not valid_input:
116
+ is_correct_digit = None
117
+ elif i >= len(ground_truth_digits_reversed):
118
+ if predicted_digit == '0' and is_correct_sofar:
119
+ is_correct_digit = True
120
+ else:
121
+ is_correct_digit = False
122
  else:
123
+ ground_truth_digit = ground_truth_digits_reversed[i]
124
+ if predicted_digit == ground_truth_digit:
125
+ is_correct_digit = True
126
+ else:
127
+ is_correct_digit = False
128
+ if not is_correct_digit:
129
+ is_correct_sofar = False
130
+ if is_correct_digit is None:
131
+ predicted_annotations.append((predicted_digit, None))
132
+ elif is_correct_digit:
133
+ predicted_annotations.append((predicted_digit, "correct"))
134
+ else:
135
+ predicted_annotations.append((predicted_digit, "wrong"))
136
+ predicted_annotations = predicted_annotations[::-1]
137
+ predicted_annotations_per_model[model_name] = predicted_annotations
138
+
139
+ predicted_annotations_implicit_cot = predicted_annotations_per_model['implicit']
140
+ predicted_annotations_nocot = predicted_annotations_per_model['no']
141
+ predicted_annotations_explicit_cot = predicted_annotations_per_model['explicit']
142
+
143
+ yield ground_truth_annotations, predicted_annotations_implicit_cot, predicted_annotations_nocot, predicted_annotations_explicit_cot
144
+ except Exception as e:
145
+ pass
146
 
147
  color_map = {"correct": "green", "wrong": "red"}
148