Kevin Fink commited on
Commit
a479880
·
1 Parent(s): e1dcc24
Files changed (2) hide show
  1. app.py +18 -8
  2. requirements.txt +1 -0
app.py CHANGED
@@ -8,6 +8,7 @@ from sklearn.metrics import accuracy_score
8
  import numpy as np
9
  import torch
10
  import os
 
11
  from huggingface_hub import login
12
  from peft import get_peft_model, LoraConfig
13
 
@@ -28,14 +29,23 @@ model.save_pretrained(model_save_path)
28
 
29
  def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
30
  try:
31
- def compute_metrics(eval_pred):
32
- logits, labels = eval_pred
33
- predictions = np.argmax(logits, axis=1)
34
- accuracy = accuracy_score(labels, predictions)
35
- return {
36
- 'eval_accuracy': accuracy,
37
- 'eval_loss': eval_pred.loss, # If you want to include loss as well
38
- }
 
 
 
 
 
 
 
 
 
39
  login(api_key.strip())
40
 
41
 
 
8
  import numpy as np
9
  import torch
10
  import os
11
+ import evaluate
12
  from huggingface_hub import login
13
  from peft import get_peft_model, LoraConfig
14
 
 
29
 
30
  def fine_tune_model(model, dataset_name, hub_id, api_key, num_epochs, batch_size, lr, grad):
31
  try:
32
+ metric = evaluate.load("rouge", cache_dir='/cache')
33
+ def compute_metrics(eval_preds):
34
+ preds, labels = eval_preds
35
+ if isinstance(preds, tuple):
36
+ preds = preds[0]
37
+ # Replace -100s used for padding as we can't decode them
38
+ preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
39
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
40
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
41
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
42
+
43
+ result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
44
+ result = {k: round(v * 100, 4) for k, v in result.items()}
45
+ prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
46
+ result["gen_len"] = np.mean(prediction_lens)
47
+ return result
48
+
49
  login(api_key.strip())
50
 
51
 
requirements.txt CHANGED
@@ -6,3 +6,4 @@ huggingface_hub
6
  scikit-learn
7
  numpy
8
  torch
 
 
6
  scikit-learn
7
  numpy
8
  torch
9
+ evaluate