Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -60,7 +60,7 @@ def compute_metrics(p):
|
|
60 |
|
61 |
return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
|
62 |
|
63 |
-
def compute_loss(model, inputs):
|
64 |
"""Custom compute_loss function."""
|
65 |
logits = model(**inputs).logits
|
66 |
labels = inputs["labels"]
|
@@ -76,9 +76,9 @@ def compute_loss(model, inputs):
|
|
76 |
# Define Custom Trainer Class
|
77 |
# Since we are using class weights, due to the imbalance between non-binding residues and binding residues, we will need a custom weighted trainer.
|
78 |
class WeightedTrainer(Trainer):
|
79 |
-
def compute_loss(self, model, inputs, return_outputs=False):
|
80 |
outputs = model(**inputs)
|
81 |
-
loss = compute_loss(model, inputs)
|
82 |
return (loss, outputs) if return_outputs else loss
|
83 |
|
84 |
# fine-tuning function
|
@@ -185,7 +185,8 @@ def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset)
|
|
185 |
seed=8893,
|
186 |
fp16=True,
|
187 |
#report_to='wandb'
|
188 |
-
report_to=None
|
|
|
189 |
)
|
190 |
|
191 |
# Initialize Trainer
|
|
|
60 |
|
61 |
return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
|
62 |
|
63 |
+
def compute_loss(model, inputs, class_weights): #compute_loss(model, inputs): add class_weights as input, jw 20240628
|
64 |
"""Custom compute_loss function."""
|
65 |
logits = model(**inputs).logits
|
66 |
labels = inputs["labels"]
|
|
|
76 |
# Define Custom Trainer Class
|
77 |
# Since we are using class weights, due to the imbalance between non-binding residues and binding residues, we will need a custom weighted trainer.
|
78 |
class WeightedTrainer(Trainer):
|
79 |
+
def compute_loss(self, model, inputs, class_weights, return_outputs=False): #add class_weights as input, jw 20240628
|
80 |
outputs = model(**inputs)
|
81 |
+
loss = compute_loss(model, inputs, class_weights) #add class_weights as input, jw 20240628
|
82 |
return (loss, outputs) if return_outputs else loss
|
83 |
|
84 |
# fine-tuning function
|
|
|
185 |
seed=8893,
|
186 |
fp16=True,
|
187 |
#report_to='wandb'
|
188 |
+
report_to=None,
|
189 |
+
class_weights=class_weights, #jw, 20240628
|
190 |
)
|
191 |
|
192 |
# Initialize Trainer
|