| """ |
| OurNet Model Definition with ConvNeXt backbone |
| |
| This model is designed for image forgery detection using a ConvNeXt backbone |
| with dual projection heads and a detection head. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import timm |
|
|
|
|
class OurNet(nn.Module):
    """Forgery-detection network: a timm backbone plus projection and detection heads.

    The backbone's classifier is replaced by ``nn.Identity``. Every head
    consumes the globally average-pooled output of ``backbone.forward_features``:

    * ``aux_fc1`` / ``aux_fc2`` -- two MLP projection heads (see ``forward_proj``).
    * ``det_fc1`` -- MLP projection head computed next to detection (``forward_det``).
    * ``det_fc2`` -- MLP with a single output unit; no activation is applied
      to it here (``forward_det``).

    Args:
        config: Optional mapping. Only its ``"backbone"`` sub-mapping is read:
            ``name`` (timm model name, default ``"convnext_base"``) and
            ``pretrained`` (passed to ``timm.create_model``, default ``False``).
            A ``backbone.n_features`` key is accepted but ignored. The feature
            width is always taken from the instantiated backbone, so the heads
            match its real output.

    Raises:
        ValueError: if the backbone has no ``head``/``fc``/``classifier``
            attribute, or its feature width cannot be determined.
    """

    # Output width of every projection head (aux_fc1, aux_fc2, det_fc1).
    PROJ_DIM = 128

    # Attribute names under which timm models store their classifier,
    # checked in this order.
    _CLASSIFIER_ATTRS = ("head", "fc", "classifier")

    def __init__(self, config=None):
        super().__init__()

        # ``or {}`` also tolerates an explicit ``"backbone": None`` entry.
        backbone_cfg = (config or {}).get("backbone") or {}
        backbone_name = backbone_cfg.get("name", "convnext_base")
        pretrained = backbone_cfg.get("pretrained", False)

        self.backbone = timm.create_model(backbone_name, pretrained=pretrained)
        self.n_features = self._strip_classifier(self.backbone)

        # Keep this creation order. Parameter initialisation consumes the
        # global RNG sequentially, so reordering would change the weights
        # produced for a given seed.
        self.aux_fc1 = self._make_proj_head(self.n_features)
        self.aux_fc2 = self._make_proj_head(self.n_features)
        self.det_fc1 = self._make_proj_head(self.n_features)
        self.det_fc2 = nn.Sequential(
            nn.Linear(self.n_features, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
        )

    @classmethod
    def _strip_classifier(cls, backbone):
        """Replace *backbone*'s classifier with ``nn.Identity``; return its input width."""
        for attr in cls._CLASSIFIER_ATTRS:
            if hasattr(backbone, attr):
                break
        else:
            raise ValueError("Unsupported backbone architecture")

        classifier = getattr(backbone, attr)
        # A plain ``nn.Linear`` exposes ``in_features``. Composite heads
        # (e.g. an ``nn.Sequential``) may not, so fall back to timm's
        # ``num_features`` instead of failing with AttributeError.
        n_features = getattr(classifier, "in_features", None)
        if n_features is None:
            n_features = getattr(backbone, "num_features", None)
        if n_features is None:
            raise ValueError("Unsupported backbone architecture")

        setattr(backbone, attr, nn.Identity())
        return n_features

    @classmethod
    def _make_proj_head(cls, n_features):
        """Build the ``n_features -> n_features -> PROJ_DIM`` MLP used by the projection heads."""
        return nn.Sequential(
            nn.Linear(n_features, n_features),
            nn.ReLU(),
            nn.Linear(n_features, cls.PROJ_DIM),
        )

    def _pooled_features(self, x):
        """Run the backbone trunk and average over the last two dimensions.

        Assumes ``forward_features`` returns an (N, C, H, W) feature map, as
        timm CNN backbones such as ConvNeXt do -- TODO confirm for any
        non-CNN backbone.
        """
        return self.backbone.forward_features(x).mean(dim=(-2, -1))

    def forward_det(self, x):
        """Detection forward pass.

        Returns:
            ``(homo_head, det_head)``: the ``det_fc1`` projection of shape
            (N, PROJ_DIM) and the ``det_fc2`` output of shape (N, 1).
        """
        feats = self._pooled_features(x)
        homo_head = self.det_fc1(feats)
        det_head = self.det_fc2(feats)
        return homo_head, det_head

    def forward_proj(self, x):
        """Projection forward pass.

        Returns:
            ``(heter_head, homo_head)``: the ``aux_fc1`` and ``aux_fc2``
            projections, each of shape (N, PROJ_DIM).
        """
        feats = self._pooled_features(x)
        heter_head = self.aux_fc1(feats)
        homo_head = self.aux_fc2(feats)
        return heter_head, homo_head
|
|