|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class Encoder(nn.Module):
    """Backbone that exposes every intermediate output of an EfficientNet-style model.

    The base model is loaded through ``torch.hub`` from a local repository
    directory named ``efficientnet_repo`` located next to this file, with
    ``pretrained=False``. Its ``global_pool`` and ``classifier`` sub-modules
    are replaced by ``nn.Identity`` so that ``forward`` returns feature maps
    rather than pooled/classified outputs.
    """

    def __init__(self, basemodel_name='tf_efficientnet_b5_ap'):
        """Load the base model and neutralise its pooling/classification head.

        Args:
            basemodel_name: Name of the hub entrypoint to load from the local
                ``efficientnet_repo`` directory. Defaults to
                ``'tf_efficientnet_b5_ap'``, the model this class originally
                hard-coded, so ``Encoder()`` behaves as before.
        """
        super().__init__()

        print('Loading base model ({})...'.format(basemodel_name), end='')
        repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo')
        # source='local' loads from the directory above instead of GitHub.
        basemodel = torch.hub.load(repo_path, basemodel_name, pretrained=False, source='local')
        print('Done.')

        # The head modules are swapped for pass-throughs rather than deleted,
        # so the model's child list keeps the same entries and forward() can
        # still iterate over them unchanged.
        print('Removing last two layers (global_pool & classifier).')
        basemodel.global_pool = nn.Identity()
        basemodel.classifier = nn.Identity()

        self.original_model = basemodel

    def forward(self, x):
        """Run ``x`` through the base model, keeping every intermediate output.

        Top-level children of ``self.original_model`` are applied in
        registration order, each to the previous output. The child named
        ``blocks`` is unrolled: each of its own children is applied and
        recorded separately.

        Args:
            x: Input tensor. Presumably an image batch of shape (N, C, H, W)
                -- TODO confirm against callers.

        Returns:
            list: ``[x, out_1, out_2, ...]``, i.e. the input followed by the
            output of every applied module. When the model was built by
            ``__init__``, the entries for ``global_pool`` and ``classifier``
            come from ``nn.Identity`` and equal the entry before them.
        """
        features = [x]
        # ``_modules`` is iterated directly (not ``named_children()``, which
        # skips None entries and de-duplicates shared modules) so the list
        # keeps one entry per registered slot; downstream code presumably
        # indexes this list positionally.
        for name, module in self.original_model._modules.items():
            if name == 'blocks':
                # Unroll the container so each child's output is kept.
                for sub in module._modules.values():
                    features.append(sub(features[-1]))
            else:
                features.append(module(features[-1]))
        return features
|
|
|
|
|
|
|