from torch import nn
from torchvision.models import (
    EfficientNet_B2_Weights,
    ViT_B_16_Weights,
    efficientnet_b2,
    vit_b_16,
)


def _freeze_parameters(model: nn.Module) -> None:
    """Set ``requires_grad = False`` on every parameter currently in *model*."""
    for param in model.parameters():
        param.requires_grad = False


def create_effnetb2_model(num_classes: int = 3):
    """Build a frozen EfficientNet-B2 backbone with a fresh classification head.

    The model is constructed with ``EfficientNet_B2_Weights.DEFAULT``. Every
    pretrained parameter is frozen, and then ``classifier[1]`` is replaced with
    a new ``nn.Linear``. The new layer is created after freezing, so its
    parameters keep the default ``requires_grad=True`` and are the only
    trainable ones.

    Args:
        num_classes: Number of output units for the new head.

    Returns:
        A ``(model, transforms)`` tuple. ``transforms`` is the preprocessing
        pipeline returned by ``weights.transforms()`` for these weights.
    """
    weights = EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    model = efficientnet_b2(weights=weights)

    _freeze_parameters(model)

    # Take the input width from the pretrained layer being replaced (1408 for
    # B2) instead of hard-coding it, so the head always matches the backbone.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features=in_features, out_features=num_classes)
    return model, transforms


def create_vitb16_model(num_classes: int = 3):
    """Build a frozen ViT-B/16 backbone with a fresh classification head.

    The model is constructed with ``ViT_B_16_Weights.DEFAULT``. Every
    pretrained parameter is frozen, and then ``heads[0]`` is replaced with a
    new ``nn.Linear``. The new layer is created after freezing, so it is the
    only trainable part.

    Args:
        num_classes: Number of output units for the new head.

    Returns:
        A ``(model, transforms)`` tuple. ``transforms`` is the preprocessing
        pipeline returned by ``weights.transforms()`` for these weights.
    """
    weights = ViT_B_16_Weights.DEFAULT
    transforms = weights.transforms()
    model = vit_b_16(weights=weights)

    _freeze_parameters(model)

    # Take the input width from the pretrained head being replaced (768 for
    # ViT-B/16) instead of hard-coding it, so the head always matches.
    in_features = model.heads[0].in_features
    model.heads[0] = nn.Linear(in_features=in_features, out_features=num_classes)
    return model, transforms