Spaces:
Running
Running
File size: 34,272 Bytes
faf90bc |
|
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "f7c4548e-c6b2-48dc-9df1-58a7c06481d8",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from sklearn.metrics import classification_report, confusion_matrix, accuracy_score\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9bd4822d-3ccd-416d-a03f-a0eda7b2bfaa",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Import custom modules\n",
"from models.resnet_model import MalariaResNet50\n",
"from data_prep import get_dataloaders"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "91ac3fe5-c5d7-497d-9a6d-2903f6b97bf7",
"metadata": {},
"outputs": [],
"source": [
"MODEL_PATH = \"models/malaria_model.pth\"  # checkpoint read by evaluate(); relative to the working directory"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1c7e9e80-2986-4979-8450-6821e6e0a3a8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Total Classes: 2\n",
"Train batches: 689, Val batches: 87, Test batches: 87\n",
"Using device: cuda\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.12/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model loaded from models/malaria_model.pth\n",
"Running inference on test set...\n",
"\n",
"Test Accuracy: 0.9699\n",
"\n",
"Classification Report:\n",
" precision recall f1-score support\n",
"\n",
" Parasitized 0.97 0.97 0.97 1378\n",
" Uninfected 0.97 0.97 0.97 1378\n",
"\n",
" accuracy 0.97 2756\n",
" macro avg 0.97 0.97 0.97 2756\n",
"weighted avg 0.97 0.97 0.97 2756\n",
"\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 600x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def _collect_predictions(model, loader, device):\n",
"    \"\"\"Run `model` over every batch of `loader` and return (y_true, y_pred).\n",
"\n",
"    Assumes `loader` yields `(inputs, labels)` pairs and that `model(inputs)`\n",
"    returns class scores of shape (batch, num_classes) -- TODO confirm.\n",
"    Both returned lists hold integer class indices in loader order.\n",
"    \"\"\"\n",
"    y_true, y_pred = [], []\n",
"    with torch.no_grad():\n",
"        for inputs, labels in loader:\n",
"            outputs = model(inputs.to(device))\n",
"            preds = outputs.argmax(dim=1)\n",
"            # Labels are only collected for scoring, so they never need to visit `device`.\n",
"            y_true.extend(labels.cpu().tolist())\n",
"            y_pred.extend(preds.cpu().tolist())\n",
"    return y_true, y_pred\n",
"\n",
"\n",
"def _report_metrics(y_true, y_pred, classes):\n",
"    \"\"\"Print accuracy and a per-class report, then plot the confusion matrix.\n",
"\n",
"    `classes[i]` is the display name for label index `i`.\n",
"    \"\"\"\n",
"    # Pin the label set so target_names and the heatmap tick labels stay aligned\n",
"    # with the matrix rows/columns even if a class is absent from y_true and y_pred.\n",
"    label_ids = list(range(len(classes)))\n",
"\n",
"    acc = accuracy_score(y_true, y_pred)\n",
"    print(f\"\\nTest Accuracy: {acc:.4f}\")\n",
"\n",
"    print(\"\\nClassification Report:\")\n",
"    print(classification_report(y_true, y_pred, labels=label_ids, target_names=classes))\n",
"\n",
"    cm = confusion_matrix(y_true, y_pred, labels=label_ids)\n",
"    _, ax = plt.subplots(figsize=(6, 5))\n",
"    sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\",\n",
"                xticklabels=classes, yticklabels=classes, ax=ax)\n",
"    ax.set(xlabel=\"Predicted\", ylabel=\"True\", title=\"Confusion Matrix\")\n",
"    plt.show()\n",
"\n",
"\n",
"def evaluate():\n",
"    \"\"\"Load the checkpoint at MODEL_PATH, run it on the test split, and report metrics.\n",
"\n",
"    Prints the device, test accuracy and a classification report, and shows a\n",
"    confusion-matrix heatmap. Returns None.\n",
"    \"\"\"\n",
"    # Only the test split is needed; the train/val loaders and datasets are discarded.\n",
"    _, _, test_loader, _, _, test_dataset = get_dataloaders()\n",
"\n",
"    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"    print(f\"Using device: {device}\")\n",
"\n",
"    # Label index i maps to classes[i] as defined by the dataset; the order is\n",
"    # taken from the dataset rather than hard-coded here.\n",
"    classes = test_dataset.classes\n",
"\n",
"    # Size the head from the dataset so model outputs and class names agree.\n",
"    model = MalariaResNet50(num_classes=len(classes))\n",
"    model.load(MODEL_PATH)\n",
"    model = model.to(device)\n",
"    model.eval()  # Set to evaluation mode\n",
"\n",
"    print(\"Running inference on test set...\")\n",
"    y_true, y_pred = _collect_predictions(model, test_loader, device)\n",
"\n",
"    _report_metrics(y_true, y_pred, classes)\n",
"\n",
"\n",
"if __name__ == '__main__':\n",
"    evaluate()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b8ec551-2713-4b2e-b33c-cdc7930f8c54",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|