| import torch
|
| import torch.nn as nn
|
| from torchvision.models import swin_t
|
|
|
class SwinTransformerMultiLabel(nn.Module):
    """Swin-Tiny backbone with its classifier head replaced for multi-label output.

    Loads ImageNet-pretrained Swin-T weights and swaps the final linear head
    for one emitting ``num_classes`` raw logits. For multi-label training,
    pair the output with ``nn.BCEWithLogitsLoss`` (sigmoid applied externally).
    """

    def __init__(self, num_classes):
        """Build the model.

        Args:
            num_classes: Number of output logits of the new classification head.
        """
        super(SwinTransformerMultiLabel, self).__init__()
        # NOTE(review): downloads pretrained weights on first use.
        self.model = swin_t(weights="IMAGENET1K_V1")

        # Replace the 1000-class ImageNet head with a fresh multi-label head.
        in_features = self.model.head.in_features
        self.model.head = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return (B, num_classes) logits for an input batch of images.

        Args:
            x: Image tensor of shape (B, 3, H, W); the pretrained weights
               expect 224x224 inputs.
        """
        # torchvision's swin_t produces channels-last feature maps: (B, H, W, C).
        x = self.model.features(x)
        print(f"Feature map shape before flattening: {x.shape}")

        # BUG FIX: the original skipped the backbone's final LayerNorm
        # (self.model.norm), which torchvision applies before pooling and
        # which the pretrained weights expect.
        x = self.model.norm(x)

        # Global average pool over the spatial dims (H, W) -> (B, C).
        x = x.mean(dim=[1, 2])
        print(f"Feature shape after GAP: {x.shape}")

        x = self.model.head(x)
        return x
|
|
|
|
|
def main():
    """Smoke-test SwinTransformerMultiLabel on random inputs of several batch sizes.

    Builds the model, runs dummy 224x224 batches through it, and prints the
    resulting output shapes and the replaced classification head.
    """
    num_classes = 2

    model = SwinTransformerMultiLabel(num_classes)
    model.eval()

    # Pure shape-check inference: no gradients needed.
    with torch.no_grad():
        dummy_input = torch.randn(5, 3, 224, 224)
        output = model(dummy_input)

        # BUG FIX: these f-strings were split across physical lines (a syntax
        # error) and their leading emoji were mojibake; rewritten as valid
        # single-line f-strings with plain-text labels.
        print(f"Model output shape: {output.shape}")
        print(f"Model classification head: {model.model.head}")

        # Verify the model handles a range of batch sizes.
        for batch_size in [1, 8, 16]:
            dummy_input = torch.randn(batch_size, 3, 224, 224)
            output = model(dummy_input)
            print(f"Batch Size {batch_size} -> Output Shape: {output.shape}")
|
|
|
# Script entry point: run the smoke test only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|
|
|