import torch
import torchvision
from torch import nn


def create_ViT():
    # Load ViT-B/16 with its default pretrained weights
    ViT_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    ViT_model = torchvision.models.vit_b_16(weights=ViT_weights)

    # Freeze pre-trained weights
    for param in ViT_model.parameters():
        param.requires_grad = False

    # Find the encoder module and its layers
    encoder = ViT_model.encoder
    encoder_layers = encoder.layers

    # Raise the dropout applied after self-attention in each encoder layer.
    # Note: calling add_module() on the attention module only registers a
    # submodule; it is never invoked in forward(), so it would have no effect.
    # torchvision's EncoderBlock already applies layer.dropout (p=0.0 by
    # default) right after self-attention, so replacing that module is the
    # way to make the dropout actually run.
    # (Equivalently, vit_b_16 accepts a dropout= kwarg at construction time.)
    for layer in encoder_layers:
        layer.dropout = nn.Dropout(p=0.4)

    # Add the new classification head: a single logit for binary classification
    ViT_model.heads = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(in_features=768, out_features=1, bias=True)
    )

    manual_transforms = torchvision.transforms.Compose([
        torchvision.transforms.RandomRotation(25),
        torchvision.transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), shear=15),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomVerticalFlip(),
        torchvision.transforms.ColorJitter(brightness=(0.9, 1.5)),
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        # The pretrained ViT weights expect ImageNet-normalized inputs
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    ])

    return ViT_model, manual_transforms
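

# A hedged sketch of one training epoch consistent with the head above: a single
# logit per image pairs naturally with nn.BCEWithLogitsLoss for binary
# classification, and only the unfrozen head parameters are handed to the
# optimizer. The train_loader, learning rate, and float labels in {0, 1} are
# illustrative assumptions, not part of create_ViT() itself.
def train_one_epoch(model, train_loader, device="cpu"):
    model.to(device).train()
    loss_fn = nn.BCEWithLogitsLoss()
    # Only the new head has requires_grad=True; the frozen backbone is skipped
    optimizer = torch.optim.Adam(
        (p for p in model.parameters() if p.requires_grad), lr=1e-3
    )
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device).float().unsqueeze(1)  # shape (B, 1) to match logits
        logits = model(images)
        loss = loss_fn(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()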
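

# A minimal smoke test of create_ViT(): build the model and transforms, confirm
# the encoder dropout was swapped in, and push a dummy batch through the new
# head. The batch size and random input are illustrative assumptions.
if __name__ == "__main__":
    model, transforms = create_ViT()

    # Each encoder block should now carry the higher dropout probability
    print(model.encoder.layers[0].dropout)  # Dropout(p=0.4, inplace=False)

    # Dummy forward pass: batch of 8 RGB images at the expected 224x224 size
    model.eval()
    with torch.inference_mode():
        dummy = torch.randn(8, 3, 224, 224)
        logits = model(dummy)
    print(logits.shape)  # torch.Size([8, 1]) -- one logit per image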