leygit commited on
Commit
c9f11fd
·
verified ·
1 Parent(s): a0ed149

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -21
app.py CHANGED
@@ -136,31 +136,14 @@ def evaluate_model_with_report(val_loader):
136
 
137
  # Performance metrics
138
  def generate_performance_metrics():
139
- model.eval() # Set model to evaluation mode
140
-
141
- y_true = [] # True labels
142
- y_pred = [] # Predicted labels
143
-
144
- with torch.no_grad():
145
- for batch in val_loader:
146
- inputs = {key: val.to(device) for key, val in batch.items()}
147
- labels = inputs.pop("labels").to(device)
148
-
149
- outputs = model(**inputs)
150
- prediction = torch.argmax(outputs.logits, dim=1).item()
151
-
152
- y_true.append(label)
153
- y_pred.append(prediction)
154
-
155
- # Compute accuracy and classification report
156
- accuracy = accuracy_score(y_true, y_pred)
157
- report = classification_report(y_true, y_pred, output_dict=True)
158
-
159
  return {
160
  "accuracy": f"{accuracy:.2%}",
161
  "precision": f"{report['1']['precision']:.2%}",
162
  "recall": f"{report['1']['recall']:.2%}",
163
- "f1_score": f"{report['1']['f1-score']:.2%}",
164
  }
165
 
166
  # Gradio Interface
 
136
 
137
  # Performance metrics
138
  def generate_performance_metrics():
139
+ y_pred = model.predict(X_test)
140
+ accuracy = evaluate_model_with_report(val_loader)
141
+ report = classification_report(y_true, y_pred, target_names=["Ham", "Spam"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  return {
143
  "accuracy": f"{accuracy:.2%}",
144
  "precision": f"{report['1']['precision']:.2%}",
145
  "recall": f"{report['1']['recall']:.2%}",
146
+ "f1_score": f"{report['1']['f1-score']:.2%}"
147
  }
148
 
149
  # Gradio Interface