# isitreallyworthit / iirwi / feature_extractor.py
# (hosting-page residue, not part of the module: danibalcells's picture | Reorder imports | ffc82d7 | raw | history blame | 3.88 kB)
import os
from collections.abc import Iterable
from pathlib import Path
import pickle
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import fastai.vision.all as fv
from .storage import ImgFeatureStorage
def dummy_loss_func(x, y):
    """Placeholder loss: ignore both arguments and return a scalar zero tensor.

    Used in this file as the ``loss_func`` of the feature-extraction learner.
    """
    zero = torch.tensor(0.)
    return zero
def get_label(file_path):
    """Return the part of the file name that precedes its first underscore.

    Only the final path component is considered; directory names are ignored.
    When the file name contains no underscore, the whole file name is returned.
    """
    file_name = os.path.basename(file_path)
    label, _, _ = file_name.partition('_')
    return label
class FeatureExtractorModel(nn.Module):
    """Module that turns a model into a flat feature extractor.

    All direct children of ``original_model`` except the last one are kept, in
    their original order, as ``self.features``. Their output is max-pooled and
    flattened to one row per batch element.
    """

    def __init__(self, original_model):
        super().__init__()
        child_modules = list(original_model.children())
        # Drop the final child of the wrapped model; keep the rest as-is.
        self.features = nn.Sequential(*child_modules[:-1])
        self.pooling = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        """Run ``x`` through the kept layers, pool, and flatten per batch element."""
        # assumes the kept layers emit image-like feature maps (batch first) -- TODO confirm
        pooled = self.pooling(self.features(x))
        batch_size = pooled.size(0)
        return pooled.view(batch_size, -1)
class FeatureExtractor:
    """Train a fastai image classifier and reuse its layers as a feature extractor.

    Attributes set by ``__init__``: ``dataset_path``, ``dls``, ``item_tfms``,
    ``label_func`` and ``n_epochs``. ``classifier`` is only set by ``train``;
    ``extractor`` is set by ``train`` or ``from_learner`` (and therefore by
    ``load``). Methods that use ``self.extractor`` need one of those to have
    run first.
    """

    def __init__(self, dataset_path=None, dls=None, item_tfms=None, label_func=get_label, n_epochs=5):
        """Store the configuration and build DataLoaders when only a path is given.

        Args:
            dataset_path: Directory containing the images; optional when ``dls`` is given.
            dls: Ready-made DataLoaders; when provided they are used as-is.
            item_tfms: Item transforms; a falsy value means ``[fv.Resize(224)]``.
            label_func: Callable mapping an image file name to its label.
            n_epochs: Default number of epochs used by ``train``/``train_classifier``.
        """
        self.dataset_path = dataset_path
        self.dls = dls
        self.item_tfms = item_tfms or [fv.Resize(224)]
        self.label_func = label_func
        self.n_epochs = n_epochs
        if self.dataset_path and not self.dls:
            self.dls = self._build_dls(self.dataset_path, self.label_func, self.item_tfms)

    @staticmethod
    def _build_dls(path, label_func, item_tfms):
        """Build DataLoaders from every image under ``path`` with a fixed 80/20 seeded split."""
        return fv.ImageDataLoaders.from_name_func(
            path, fv.get_image_files(path), valid_pct=0.2, seed=42,
            label_func=label_func, item_tfms=item_tfms)

    @classmethod
    def from_dataset(cls, dataset_path, item_tfms=None, label_func=get_label, n_epochs=5):
        """Alternate constructor that builds DataLoaders from ``dataset_path``.

        ``item_tfms=None`` selects the default ``[fv.Resize(224)]``; a fresh list
        is created per call instead of sharing one mutable default.
        """
        return cls(dataset_path=dataset_path, item_tfms=item_tfms, label_func=label_func, n_epochs=n_epochs)

    @classmethod
    def from_learner(cls, extractor):
        """Wrap an existing learner (must expose ``.dls``) as the feature extractor."""
        obj = cls(dls=extractor.dls)
        obj.extractor = extractor
        return obj

    @classmethod
    def load(cls, filename, label_func=get_label, item_tfms=None, cpu=False):
        """Load an exported learner and rebuild its DataLoaders from its stored path.

        Args:
            filename: File previously written by ``export``.
            label_func: Callable mapping an image file name to its label.
            item_tfms: Item transforms; ``None`` means ``[fv.Resize(224)]``.
            cpu: Forwarded to ``fv.load_learner``; defaults to ``False`` as before.
        """
        if item_tfms is None:
            item_tfms = [fv.Resize(224)]
        # NOTE(security): fastai documents load_learner as pickle-based; only load
        # files from a trusted source.
        extractor = fv.load_learner(filename, cpu=cpu)
        extractor.dls = cls._build_dls(extractor.dls.path, label_func, item_tfms)
        obj = cls.from_learner(extractor)
        # from_learner resets these to defaults; record what the DataLoaders were built with.
        obj.label_func = label_func
        obj.item_tfms = item_tfms
        return obj

    def export(self, model_name, path=Path('.')):
        """Export the extractor learner as ``model_name`` with ``path`` as its base path."""
        # Coerce so that a plain string is accepted as well as a Path.
        self.extractor.path = Path(path)
        self.extractor.export(model_name)

    def train(self, n_epochs=None):
        """Train the classifier, then derive the feature extractor from it.

        Args:
            n_epochs: Number of fine-tuning epochs; ``None`` means ``self.n_epochs``.
        """
        self.classifier = self.train_classifier(n_epochs)
        self.extractor = self.get_extractor()

    def train_classifier(self, n_epochs=None):
        """Fine-tune a ResNet-18 vision learner on ``self.dls`` and return it.

        Args:
            n_epochs: Number of fine-tuning epochs; ``None`` means ``self.n_epochs``.
                An explicit ``0`` is honoured rather than replaced by the default.
        """
        if n_epochs is None:
            n_epochs = self.n_epochs
        classifier = fv.vision_learner(self.dls, fv.resnet18, metrics=fv.error_rate)
        classifier.fine_tune(n_epochs)
        return classifier

    def get_extractor(self):
        """Build a learner around the trained classifier's layers minus the last child.

        Requires ``self.classifier`` (set by ``train``). The learner uses
        ``dummy_loss_func`` as its loss function.
        """
        model = FeatureExtractorModel(self.classifier.model)
        return fv.Learner(self.dls, model, loss_func=dummy_loss_func)

    def predict(self, input_images):
        """Return extractor features for one image or an iterable of images.

        A single ``str``/``bytes`` value or any non-iterable object is treated
        as one image and wrapped in a list.
        """
        if isinstance(input_images, (str, bytes)) or not isinstance(input_images, Iterable):
            input_images = [input_images]
        with self.extractor.no_bar(), self.extractor.no_logging():
            dl = self.extractor.dls.test_dl(input_images)
            # Only the predictions are needed; inputs and decoded outputs are not requested.
            features, _ = self.extractor.get_preds(dl=dl)
        return features

    def predict_for_dataset(self, dls=None):
        """Compute features for the train and valid items and store them by image.

        Args:
            dls: DataLoaders to process; a falsy value means ``self.dls``.

        Returns:
            The result of ``ImgFeatureStorage.from_features_dict`` on a dict that
            maps each item to a copy of its feature tensor.
        """
        dls = dls or self.dls
        train_features, _ = self.extractor.get_preds(dl=dls.train)
        valid_features, _ = self.extractor.get_preds(dl=dls.valid)
        all_features = torch.cat([train_features, valid_features])
        all_items = dls.train.items + dls.valid.items
        # assumes get_preds returns rows in the same order as ``items`` -- TODO confirm
        features = {image: activation.clone() for image, activation in zip(all_items, all_features)}
        return ImgFeatureStorage.from_features_dict(features)