wangjin2000 commited on
Commit
e474658
·
verified ·
1 Parent(s): a8846d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -60,7 +60,7 @@ def compute_metrics(p):
60
 
61
  return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
62
 
63
- def compute_loss(model, inputs):
64
  """Custom compute_loss function."""
65
  logits = model(**inputs).logits
66
  labels = inputs["labels"]
@@ -76,9 +76,9 @@ def compute_loss(model, inputs):
76
  # Define Custom Trainer Class
77
  # Since we are using class weights, due to the imbalance between non-binding residues and binding residues, we will need a custom weighted trainer.
78
  class WeightedTrainer(Trainer):
79
- def compute_loss(self, model, inputs, return_outputs=False):
80
  outputs = model(**inputs)
81
- loss = compute_loss(model, inputs)
82
  return (loss, outputs) if return_outputs else loss
83
 
84
  # fine-tuning function
@@ -185,7 +185,8 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
185
  seed=8893,
186
  fp16=True,
187
  #report_to='wandb'
188
- report_to=None
 
189
  )
190
 
191
  # Initialize Trainer
 
60
 
61
  return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
62
 
63
+ def compute_loss(model, inputs, class_weights): #compute_loss(model, inputs): add class_weights as input, jw 20240628
64
  """Custom compute_loss function."""
65
  logits = model(**inputs).logits
66
  labels = inputs["labels"]
 
76
  # Define Custom Trainer Class
77
  # Since we are using class weights, due to the imbalance between non-binding residues and binding residues, we will need a custom weighted trainer.
78
  class WeightedTrainer(Trainer):
79
+ def compute_loss(self, model, inputs, class_weights, return_outputs=False): #add class_weights as input, jw 20240628
80
  outputs = model(**inputs)
81
+ loss = compute_loss(model, inputs, class_weights) #add class_weights as input, jw 20240628
82
  return (loss, outputs) if return_outputs else loss
83
 
84
  # fine-tuning function
 
185
  seed=8893,
186
  fp16=True,
187
  #report_to='wandb'
188
+ report_to=None,
189
+ class_weights=class_weights, #jw, 20240628
190
  )
191
 
192
  # Initialize Trainer