# SayNoToCancer / model.py
# Author: Chaitanya Garg
# Commit ef55f93: "Complete Project" (original file size: 548 bytes)
import torch
import torchvision
from torch import nn
from helper import setAllSeeds
def getEffNetModel(seed, numClasses, *, dropout=0.3, freezeBackbone=True):
    """Build an EfficientNet-B2 transfer-learning model with a new classifier head.

    Loads torchvision's EfficientNet-B2 with its ``DEFAULT`` pretrained
    weights, optionally freezes every pretrained parameter, and replaces the
    classifier with ``Dropout -> Linear`` producing ``numClasses`` outputs.

    Args:
        seed: Passed to ``setAllSeeds`` before the model is built so the
            randomly initialised head is reproducible for a given seed
            (assumes ``setAllSeeds`` seeds torch's RNG -- see helper.py).
        numClasses: Number of output units of the new linear layer.
        dropout: Dropout probability applied before the linear layer
            (keyword-only; default 0.3).
        freezeBackbone: If True (default), set ``requires_grad = False`` on
            every pretrained parameter so only the new head is trained
            (keyword-only).

    Returns:
        ``(effNet, effNetTransforms)``: the modified EfficientNet-B2 model and
        the preprocessing transform bundled with the pretrained weights.
    """
    setAllSeeds(seed)
    effNetWeights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    effNetTransforms = effNetWeights.transforms()
    effNet = torchvision.models.efficientnet_b2(weights=effNetWeights)

    if freezeBackbone:
        for param in effNet.parameters():
            param.requires_grad = False

    # Take the head's input width from the pretrained classifier
    # (Sequential(Dropout, Linear)) rather than hard-coding 1408, so the new
    # head always matches the backbone's final feature size.
    inFeatures = effNet.classifier[1].in_features

    # These modules are created after the freeze loop, so their parameters
    # keep the default requires_grad=True and the head remains trainable.
    effNet.classifier = nn.Sequential(
        nn.Dropout(p=dropout, inplace=True),
        nn.Linear(inFeatures, numClasses, bias=True),
    )
    return effNet, effNetTransforms