Mayanand's picture
Update model.py
295828c
raw
history blame
537 Bytes
import timm
import torch.nn as nn
from torchvision import models
class ResnetModel(nn.Module):
    """Wrap a torchvision ResNet-18 whose final ``fc`` layer is replaced.

    The network is built with ``models.resnet18()`` using that function's
    default arguments, then its ``fc`` attribute is swapped for a fresh
    ``nn.Linear`` producing ``num_classes`` outputs.

    Args:
        num_classes: Number of output features of the replacement ``fc``
            layer. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        model = models.resnet18()
        # Take the input width from the layer being replaced rather than a
        # literal, and honour the caller's num_classes (previously the head
        # was always built with 10 outputs regardless of the argument).
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        self.model = model

    def forward(self, x):
        """Pass ``x`` to the wrapped network and return its output."""
        return self.model(x)
class EffnetModel(nn.Module):
    """Wrap a timm ``efficientnet_b0`` model built for ``num_classes`` outputs.

    Args:
        num_classes: Value forwarded to ``timm.create_model`` as its
            ``num_classes`` argument. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        # Forward the caller's num_classes; it was previously ignored and the
        # model was always created with num_classes=10.
        self.model = timm.create_model('efficientnet_b0', num_classes=num_classes)

    def forward(self, x):
        """Pass ``x`` to the wrapped network and return its output."""
        return self.model(x)