# ZeroShot-AD / utils/feature_extractor.py — uploaded by HoomKh, commit e5461d8 (verified)
# utils/feature_extractor.py
import torch
import torch.nn as nn
from torchvision import models
class FeatureExtractor(nn.Module):
    """Wrap a torchvision backbone and expose its last convolutional feature map.

    Only ResNet-50 is supported. The classification head is stripped so that
    ``forward`` returns spatial features instead of class logits.

    Args:
        backbone: Name of the backbone architecture. Only ``'resnet50'`` is
            implemented.
        pretrained: If True (the default, matching the original behaviour),
            load the ImageNet-1k weights that the legacy ``pretrained=True``
            flag selected (``IMAGENET1K_V1``); torchvision downloads them on
            first use. Pass False to build a randomly initialised backbone
            without any download.

    Raises:
        NotImplementedError: If ``backbone`` is not ``'resnet50'``.
    """

    def __init__(self, backbone: str = 'resnet50', pretrained: bool = True) -> None:
        super().__init__()
        if backbone == 'resnet50':
            # Kept as an attribute for backward compatibility (callers or saved
            # checkpoints may reference ``self.model``). Its submodules are
            # shared with ``self.features`` rather than copied, so
            # ``state_dict()`` lists each tensor under both the ``model.`` and
            # ``features.`` prefixes.
            self.model = self._build_resnet50(pretrained)
            # Drop the last two children -- the global average pool AND the
            # fully connected classifier -- so the output keeps its spatial
            # layout (the final residual stage's feature map).
            self.features = nn.Sequential(*list(self.model.children())[:-2])
        else:
            raise NotImplementedError(f"Backbone {backbone} is not implemented.")

    @staticmethod
    def _build_resnet50(pretrained: bool) -> nn.Module:
        """Construct a torchvision ResNet-50, preferring the ``weights=`` API."""
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            return models.resnet50(pretrained=pretrained)
        # ``pretrained=True`` is deprecated since torchvision 0.13. IMAGENET1K_V1
        # is exactly what it used to load; ``DEFAULT`` would pick V2 and
        # silently change the extracted features.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.resnet50(weights=weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone's final convolutional feature map for ``x``.

        Assumes ``x`` is an (N, 3, H, W) image batch preprocessed the way the
        backbone expects (e.g. ImageNet normalisation) -- TODO confirm with
        callers. For ResNet-50 the result has 2048 channels and a spatial
        size reduced by roughly a factor of 32.
        """
        return self.features(x)