alexkueck commited on
Commit
5d078f9
·
1 Parent(s): a7ac19b

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +14 -0
utils.py CHANGED
@@ -201,6 +201,20 @@ def compute_metrics(eval_pred):
201
  #Before passing your predictions to compute, you need to convert the predictions to logits (remember all Transformers models return logits):
202
  return metric.compute(predictions=predictions, references=labels)
203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  def convert_to_markdown(text):
206
  text = text.replace("$","$")
 
201
  #Before passing your predictions to compute, you need to convert the predictions to logits (remember all Transformers models return logits):
202
  return metric.compute(predictions=predictions, references=labels)
203
 
204
+
205
+ def compute_metrics2(p):
206
+ pred, labels = p
207
+ pred = np.argmax(pred, axis=1)
208
+
209
+ accuracy = accuracy_score(y_true=labels, y_pred=pred)
210
+ recall = recall_score(y_true=labels, y_pred=pred)
211
+ precision = precision_score(y_true=labels, y_pred=pred)
212
+ f1 = f1_score(y_true=labels, y_pred=pred)
213
+
214
+ return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}
215
+
216
+
217
+
218
 
219
  def convert_to_markdown(text):
220
  text = text.replace("$","$")