|
import torch.nn as nn |
|
import numpy as np |
|
from abc import abstractmethod |
|
|
|
|
|
class BaseModel(nn.Module):
    """
    Base class for all models.

    Subclasses are expected to override :meth:`forward`. The string
    representation is the standard ``nn.Module`` one followed by the number
    of trainable parameters.

    NOTE: ``nn.Module`` does not use ``ABCMeta``, so ``@abstractmethod`` on
    ``forward`` is a marker only; calling the un-overridden method raises
    ``NotImplementedError`` at call time.
    """

    @abstractmethod
    def forward(self, *inputs):
        """
        Forward pass logic.

        :param inputs: Positional inputs of the model.
        :return: Model output
        :raises NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters.

        :return: The ``nn.Module`` string representation with a trailing
            ``Trainable parameters: <count>`` line.
        """
        # Tensor.numel() yields an exact Python int for every shape, including
        # 0-dim (scalar) parameters, so the total is never rendered as a float.
        params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return super().__str__() + '\nTrainable parameters: {}'.format(params)
|
|