# SkinGuard / model.py
# Initial Commit for SkinGuard (#1) — commit 8f4809f (verified), author: Roofus
import torch
import torchvision
from torch import nn
def create_ViT():
    """Build a frozen pretrained ViT-B/16 binary classifier and its transforms.

    Loads torchvision's ImageNet-pretrained ViT-B/16, freezes every backbone
    parameter, and swaps the classification head for a Dropout + single-logit
    Linear layer. The new head is created *after* the freeze, so its
    parameters keep ``requires_grad=True`` and are the only trainable ones.

    Returns:
        tuple: ``(model, transforms)`` where ``model`` emits one logit per
        image (pair with ``BCEWithLogitsLoss``) and ``transforms`` is a
        ``Compose`` of training augmentations that yields a 224x224 float
        tensor.
    """
    vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    vit_model = torchvision.models.vit_b_16(weights=vit_weights)

    # Freeze the pretrained backbone; only the replacement head will train.
    for param in vit_model.parameters():
        param.requires_grad = False

    # NOTE(review): the original code attached an nn.Dropout to each encoder
    # layer's self_attention via add_module('my_dropout', ...). That only
    # *registers* a child module — nn.MultiheadAttention.forward never calls
    # it, so it had no effect on the forward pass (and nn.Dropout carries no
    # parameters, so the state_dict is unchanged by its removal). The dead
    # loop has been deleted. To genuinely enable attention dropout, build the
    # backbone with torchvision.models.vit_b_16(..., attention_dropout=0.4)
    # instead (but note that is incompatible with loading pretrained
    # attention behavior unchanged).

    # Replacement head: heavy dropout + one output logit for binary
    # classification. 768 is ViT-B/16's hidden dimension.
    vit_model.heads = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(in_features=768, out_features=1, bias=True),
    )

    # Training-time augmentation pipeline; Resize/ToTensor come last so the
    # geometric/color jitter operates on the PIL image.
    # NOTE(review): pretrained ViT weights expect ImageNet normalization
    # (cf. ViT_B_16_Weights.DEFAULT.transforms()); no Normalize is applied
    # here — confirm that omission is intentional.
    manual_transforms = torchvision.transforms.Compose([
        torchvision.transforms.RandomRotation(25),
        torchvision.transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), shear=15),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomVerticalFlip(),
        torchvision.transforms.ColorJitter(brightness=(0.9, 1.5)),
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
    ])

    return vit_model, manual_transforms