wangjin2000 commited on
Commit
6296772
·
verified ·
1 Parent(s): d9e008a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -60,7 +60,7 @@ def compute_metrics(p):
60
 
61
  return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
62
 
63
- def compute_loss(model, inputs, class_weights): #compute_loss(model, inputs): add class_weights as input, jw 20240628
64
  """Custom compute_loss function."""
65
  logits = model(**inputs).logits
66
  labels = inputs["labels"]
@@ -76,11 +76,12 @@ def compute_loss(model, inputs, class_weights): #compute_loss(model, inputs): a
76
  # Define Custom Trainer Class
77
  # Since we are using class weights, due to the imbalance between non-binding residues and binding residues, we will need a custom weighted trainer.
78
  class WeightedTrainer(Trainer):
79
- def compute_loss(self, model, inputs, class_weights, return_outputs=False): #add class_weights as input, jw 20240628
80
  outputs = model(**inputs)
81
- loss = compute_loss(model, inputs, class_weights) #add class_weights as input, jw 20240628
82
  return (loss, outputs) if return_outputs else loss
83
-
 
84
  # fine-tuning function
85
  def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset):
86
 
@@ -196,8 +197,7 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
196
  eval_dataset=test_dataset,
197
  tokenizer=tokenizer,
198
  data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
199
- compute_metrics=compute_metrics,
200
- class_weights=class_weights, #add class_weights as input, jw 20240628
201
  )
202
 
203
  # Train and Save Model
 
60
 
61
  return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
62
 
63
+ def compute_loss(model, inputs):
64
  """Custom compute_loss function."""
65
  logits = model(**inputs).logits
66
  labels = inputs["labels"]
 
76
  # Define Custom Trainer Class
77
  # Since we are using class weights, due to the imbalance between non-binding residues and binding residues, we will need a custom weighted trainer.
78
  class WeightedTrainer(Trainer):
79
+ def compute_loss(self, model, inputs, return_outputs=False):
80
  outputs = model(**inputs)
81
+ loss = compute_loss(model, inputs)
82
  return (loss, outputs) if return_outputs else loss
83
+
84
+ #
85
  # fine-tuning function
86
  def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset):
87
 
 
197
  eval_dataset=test_dataset,
198
  tokenizer=tokenizer,
199
  data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
200
+ compute_metrics=compute_metrics
 
201
  )
202
 
203
  # Train and Save Model