"""Hugging Face ``PreTrainedModel`` wrappers around torchvision's Swin3D-Tiny."""

from typing import Dict, Optional

import torch
from torchvision.models.video import swin3d_t
from transformers import PreTrainedModel

from .configuration_swin3d_tiny import Swin3DTinyConfig


class Swin3DTiny(PreTrainedModel):
    """Swin3D-Tiny backbone exposed as a ``PreTrainedModel``, head left as built.

    The torchvision model is constructed from ``config.weights`` and used
    unchanged, so the output size is whatever ``swin3d_t`` produces for those
    weights.
    """

    config_class = Swin3DTinyConfig

    def __init__(self, config: Swin3DTinyConfig) -> None:
        """Build the torchvision model.

        Args:
            config: model configuration; ``config.weights`` is forwarded to
                torchvision's ``swin3d_t`` as its ``weights`` argument.
        """
        super().__init__(config=config)
        # `weights` is keyword-only in torchvision's swin3d_t signature;
        # positional use only works through a deprecated shim that warns.
        self.model = swin3d_t(weights=config.weights)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model and return its output tensor.

        Args:
            pixel_values: video batch passed straight to the torchvision model
                (presumably shaped (batch, channels, frames, height, width) —
                confirm against the preprocessing pipeline).

        Returns:
            The raw output of the torchvision model.
        """
        return self.model(pixel_values)


class Swin3DTinyForVideoClassification(PreTrainedModel):
    """Swin3D-Tiny with its classifier head resized to ``config.num_classes``.

    ``forward`` returns a dict with ``"logits"``, plus ``"loss"`` (cross
    entropy) when ``labels`` are supplied.
    """

    config_class = Swin3DTinyConfig

    def __init__(self, config: Swin3DTinyConfig) -> None:
        """Build the torchvision model and replace its head.

        Args:
            config: model configuration; ``config.weights`` is forwarded to
                torchvision's ``swin3d_t`` and ``config.num_classes`` sets the
                output width of the new linear head.
        """
        super().__init__(config=config)
        # Call the factory to get a model instance. The bare function object
        # has no `.head`, so reading `head.in_features` below would fail.
        self.model = swin3d_t(weights=config.weights)
        # Replace the classifier so it emits one logit per task class while
        # keeping the backbone's feature width as its input size.
        self.model.head = torch.nn.Linear(
            in_features=self.model.head.in_features,
            out_features=config.num_classes,
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Compute logits and, when labels are given, the cross-entropy loss.

        Args:
            pixel_values: video batch passed straight to the torchvision model
                (presumably shaped (batch, channels, frames, height, width) —
                confirm against the preprocessing pipeline).
            labels: optional targets accepted by
                ``torch.nn.functional.cross_entropy`` (e.g. class indices).

        Returns:
            ``{"logits": logits}``, or ``{"loss": loss, "logits": logits}``
            when ``labels`` is not ``None``.
        """
        logits = self.model(pixel_values)
        if labels is None:
            return {"logits": logits}
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}