# NOTE: removed scraped web-page banner ("Spaces: Sleeping") that was not part
# of the code — it would be a SyntaxError in a Python file.
import torch
import torchvision
from torch import nn
def create_ViT():
    """Build a frozen ViT-B/16 transfer-learning model plus its train transforms.

    Loads torchvision's ViT-B/16 with the DEFAULT pre-trained weights, freezes
    the entire backbone, raises the self-attention dropout of every encoder
    layer to 0.4 for regularisation, and replaces the classification head with
    a single-logit head (binary classification).

    Returns:
        tuple: ``(model, transforms)`` where ``model`` is the modified
        ``nn.Module`` and ``transforms`` is a ``torchvision.transforms.Compose``
        augmentation/pre-processing pipeline for training images.
    """
    vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    vit_model = torchvision.models.vit_b_16(weights=vit_weights)

    # Freeze all pre-trained parameters; only the new head will be trained.
    for param in vit_model.parameters():
        param.requires_grad = False

    # BUG FIX: the original code did
    #     attn_module.add_module('my_dropout', nn.Dropout(p=0.4))
    # which merely *registers* a child module — nothing ever calls it in
    # forward(), so it had no effect. nn.MultiheadAttention instead reads its
    # float `dropout` attribute at forward time, so assigning it directly is
    # what actually applies attention dropout.
    # (Assumes each encoder layer's `self_attention` is nn.MultiheadAttention,
    # as in torchvision's EncoderBlock — confirm if torchvision changes.)
    for layer in vit_model.encoder.layers:
        layer.self_attention.dropout = 0.4

    # New classification head: one logit for binary classification.
    # in_features comes from the model itself instead of a hard-coded 768,
    # so this stays correct if a different ViT variant is swapped in.
    vit_model.heads = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(in_features=vit_model.hidden_dim, out_features=1, bias=True),
    )

    # Training-time augmentation + resize pipeline (224x224 matches ViT-B/16).
    manual_transforms = torchvision.transforms.Compose([
        torchvision.transforms.RandomRotation(25),
        torchvision.transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), shear=15),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomVerticalFlip(),
        torchvision.transforms.ColorJitter(brightness=(0.9, 1.5)),
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
    ])

    return vit_model, manual_transforms