from torch import nn
import timm

from configuration import CFG
class ImageEncoder(nn.Module):
    """
    Encode images to a fixed size vector.

    Thin wrapper around a backbone built by ``timm.create_model``. The
    backbone is created with ``num_classes=0`` and ``global_pool="avg"``;
    per the timm documentation this removes the classifier head so the
    forward pass yields pooled features instead of class logits.

    Args:
        model_name: Architecture name forwarded to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model`` as ``pretrained``.
        trainable: Value assigned to ``requires_grad`` on every parameter
            of the backbone.
    """

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        # `pretrained` is passed by keyword so the call does not depend on
        # the positional order of timm's optional arguments.
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        # Apply the same trainable/frozen setting to every backbone parameter.
        for param in self.model.parameters():
            param.requires_grad = trainable

    def forward(self, x):
        """Run ``x`` through the backbone and return its output."""
        # assumes x is a batch of image tensors in the layout the chosen
        # timm model expects — TODO confirm against the data pipeline
        return self.model(x)