# mold-detection-api / model.py
# AdarshRajDS
# Fix ConvNeXt checkpoint loading and Grad-CAM layer selection
# 7a5f7fb
import torch
import torch.nn as nn
from torchvision import models
def find_last_conv2d(module: nn.Module) -> nn.Conv2d | None:
"""
Returns the last nn.Conv2d found in a module traversal.
Important: we do NOT attach this as a child module on the model instance,
otherwise it becomes part of state_dict and breaks checkpoint loading.
"""
last = None
for m in module.modules():
if isinstance(m, nn.Conv2d):
last = m
return last
class MultiTaskResNet50(nn.Module):
    """ResNet-50 backbone feeding two linear heads.

    The backbone's own ``fc`` layer is replaced with ``nn.Identity`` so the
    backbone output is passed directly to both heads:

    - ``class_head``: ``num_classes`` outputs
    - ``bio_head``: 2 outputs
    """

    def __init__(self, num_classes=9):
        super().__init__()
        # Built without pretrained weights (weights=None).
        resnet = models.resnet50(weights=None)
        # Read the feature width before the stock classifier is removed.
        embed_dim = resnet.fc.in_features
        resnet.fc = nn.Identity()
        self.backbone = resnet

        self.class_head = nn.Linear(embed_dim, num_classes)
        self.bio_head = nn.Linear(embed_dim, 2)

    def forward(self, x: torch.Tensor):
        """Run the backbone once and apply both heads to its output.

        Returns:
            A dict with keys ``"class"`` and ``"bio"`` holding the outputs of
            ``class_head`` and ``bio_head`` respectively.
        """
        shared = self.backbone(x)
        class_out = self.class_head(shared)
        bio_out = self.bio_head(shared)
        return {"class": class_out, "bio": bio_out}
class MultiTaskConvNeXt(nn.Module):
    """ConvNeXt-Base feature extractor shared by two linear heads.

    Heads:
      - ``class_head``: N-class structural/mold classifier
      - ``bio_head``: 2-class biological vs non-biological classifier

    Mirrors the training setup from the ConvNeXt Kaggle notebook.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        # Task-specific weights are loaded afterwards, so ImageNet weights
        # are not requested here.
        convnext = models.convnext_base(weights=None)
        # The stock classifier is [LayerNorm2d, Flatten, Linear]; index 2 is
        # the Linear layer, whose input width is the feature dimension.
        embed_dim = convnext.classifier[2].in_features
        convnext.classifier = nn.Identity()
        self.backbone = convnext

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.class_head = nn.Linear(embed_dim, num_classes)
        self.bio_head = nn.Linear(embed_dim, 2)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x: torch.Tensor):
        """Extract features, pool, flatten, apply dropout, then run both heads.

        Only ``backbone.features`` is invoked; the pooling used is this
        module's own ``pool`` layer.

        Returns:
            A dict with keys ``"class"`` and ``"bio"`` holding the outputs of
            ``class_head`` and ``bio_head`` respectively.
        """
        pooled_map = self.pool(self.backbone.features(x))
        # Flatten everything after the batch dimension.
        embedding = self.dropout(torch.flatten(pooled_map, start_dim=1))
        class_out = self.class_head(embedding)
        bio_out = self.bio_head(embedding)
        return {"class": class_out, "bio": bio_out}