import torch from torchvision import transforms, models def create_vit(output_classes: int = 6, seed: int = 233): transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) model = models.vit_b_16() for param in model.parameters(): param.requires_grade = False torch.manual_seed(seed) model.heads = torch.nn.Sequential( torch.nn.Dropout(0.3), torch.nn.Linear(in_features=768, out_features=output_classes) ) return model, transform