| import torch |
| import torch.nn as nn |
| from torchvision.models import swin_t |
| from transformers import PretrainedConfig, PreTrainedModel |
|
|
| |
class SwinClassifierConfig(PretrainedConfig):
    """Hugging Face configuration object for ``SwinClassifier``.

    Args:
        num_classes: Number of target classes, i.e. the output width of the
            classifier's final linear layer. Defaults to 18.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Hugging Face type identifier for this configuration class.
    model_type = "swin_classifier"

    def __init__(self, num_classes: int = 18, **kwargs) -> None:
        super().__init__(**kwargs)
        self.num_classes = num_classes
|
|
| |
class SwinClassifier(PreTrainedModel):
    """Swin-T image classifier wrapped as a Hugging Face ``PreTrainedModel``.

    The torchvision ``swin_t`` backbone is constructed with no arguments, so it
    starts from torchvision's default (non-pretrained) initialization; weights
    are expected to come from ``from_pretrained`` or training.
    NOTE(review): confirm that a randomly initialized backbone is intended when
    the model is built directly from a config.

    The backbone's stock ``head`` is replaced by a small MLP:
    ``Linear(in_features, hidden) -> ReLU -> Dropout -> Linear(hidden, num_classes)``.

    Config attributes read:
        num_classes: Width of the final output layer (required).
        head_hidden_dim: Width of the MLP hidden layer (optional, default 256).
        head_dropout: Dropout probability before the final layer
            (optional, default 0.5).
    """

    config_class = SwinClassifierConfig

    def __init__(self, config):
        super().__init__(config)

        self.backbone = swin_t()
        # Input width of torchvision's original classification head.
        num_features = self.backbone.head.in_features

        # Optional head hyperparameters. The getattr defaults reproduce the
        # previously hard-coded values, so configs/checkpoints that predate
        # these attributes build exactly the same architecture.
        hidden_dim = getattr(config, "head_hidden_dim", 256)
        dropout = getattr(config, "head_dropout", 0.5)

        # Layer order (and thus state-dict keys backbone.head.0..3) is kept
        # unchanged for checkpoint compatibility.
        self.backbone.head = nn.Sequential(
            nn.Linear(num_features, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, config.num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image tensor passed straight to the torchvision Swin-T backbone;
                presumably shaped (batch, 3, H, W) — TODO confirm the expected
                resolution and normalization against the data pipeline.

        Returns:
            The backbone output, i.e. logits from the replacement head whose
            last dimension is ``config.num_classes``.
        """
        return self.backbone(x)