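"""Image feature extraction built on fastai.

A resnet18 classifier is fine-tuned on a labelled image dataset, its
classification head is removed, and the remaining network body is used to
turn images into feature vectors stored via ImgFeatureStorage.
"""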
import os
from collections.abc import Iterable
from pathlib import Path
import pickle

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import fastai.vision.all as fv

from .storage import ImgFeatureStorage

def dummy_loss_func(x, y):
    """Placeholder loss: the extractor Learner is never trained, but fastai requires a loss function."""
    return torch.tensor(0.)


def get_label(file_path):
    """Default labelling function: the label is the filename part before the first underscore."""
    return os.path.basename(file_path).split('_')[0]

class FeatureExtractorModel(nn.Module):
    """Wraps a trained classifier, dropping its final layer so the forward pass yields feature vectors."""

    def __init__(self, original_model):
        super().__init__()
        # Keep everything except the classification head.
        self.features = nn.Sequential(*list(original_model.children())[:-1])
        self.pooling = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        x = self.features(x)
        x = self.pooling(x)
        # Flatten to (batch_size, n_features).
        return x.view(x.size(0), -1)

class FeatureExtractor:
    """Trains an image classifier with fastai and exposes it as an image feature extractor."""

    def __init__(self, dataset_path=None, dls=None, item_tfms=None, label_func=get_label, n_epochs=5):
        item_tfms = item_tfms or [fv.Resize(224)]
        self.dataset_path = dataset_path
        self.dls = dls
        self.item_tfms = item_tfms
        self.label_func = label_func
        self.n_epochs = n_epochs
        if self.dataset_path and not self.dls:
            self.dls = fv.ImageDataLoaders.from_name_func(
                self.dataset_path, fv.get_image_files(self.dataset_path), valid_pct=0.2, seed=42,
                label_func=self.label_func, item_tfms=self.item_tfms)

    @classmethod
    def from_dataset(cls, dataset_path, item_tfms=None, label_func=get_label, n_epochs=5):
        return cls(dataset_path=dataset_path, item_tfms=item_tfms, label_func=label_func, n_epochs=n_epochs)

    @classmethod
    def from_learner(cls, extractor):
        obj = cls(dls=extractor.dls)
        obj.extractor = extractor
        return obj

    @classmethod
    def load(cls, filename, label_func=get_label, item_tfms=None):
        item_tfms = item_tfms or [fv.Resize(224)]
        extractor = fv.load_learner(filename, cpu=False)
        # An exported Learner carries no data, so rebuild the DataLoaders from its stored path.
        dls = fv.ImageDataLoaders.from_name_func(
            extractor.dls.path, fv.get_image_files(extractor.dls.path), valid_pct=0.2, seed=42,
            label_func=label_func, item_tfms=item_tfms)
        extractor.dls = dls
        return cls.from_learner(extractor)

    def export(self, model_name, path=Path('.')):
        self.extractor.path = path
        self.extractor.export(model_name)

    def train(self, n_epochs=None):
        n_epochs = n_epochs or self.n_epochs
        self.classifier = self.train_classifier(n_epochs)
        self.extractor = self.get_extractor()

    def train_classifier(self, n_epochs=None):
        n_epochs = n_epochs or self.n_epochs
        classifier = fv.vision_learner(self.dls, fv.resnet18, metrics=fv.error_rate)
        classifier.fine_tune(n_epochs)
        return classifier

    def get_extractor(self):
        # Strip the classification head and wrap the body in a Learner so fastai's
        # prediction utilities (test_dl, get_preds) can be reused for feature extraction.
        model = FeatureExtractorModel(self.classifier.model)
        extractor = fv.Learner(self.dls, model, loss_func=dummy_loss_func)
        return extractor

    def predict(self, input_images):
        # Accept either a single image (path or string) or an iterable of images.
        if not isinstance(input_images, Iterable) or isinstance(input_images, str):
            input_images = [input_images]
        with self.extractor.no_bar(), self.extractor.no_logging():
            dl = self.extractor.dls.test_dl(input_images)
            features, _ = self.extractor.get_preds(dl=dl)
        return features

    def predict_for_dataset(self, dls=None):
        dls = dls or self.dls
        train_features, _ = self.extractor.get_preds(dl=dls.train)
        valid_features, _ = self.extractor.get_preds(dl=dls.valid)
        all_features = torch.cat([train_features, valid_features])
        all_items = dls.train.items + dls.valid.items
        # Create a dictionary mapping image paths to features.
        features = {image: activation.clone() for image, activation in zip(all_items, all_features)}
        return ImgFeatureStorage.from_features_dict(features)
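
# --- Usage sketch (illustrative, not part of the library) ---------------------
# The paths and filenames below are placeholders; they assume a directory of
# images whose names encode the label before the first underscore
# (e.g. "cat_001.jpg"), matching the default get_label function.
#
#   fe = FeatureExtractor.from_dataset('path/to/dataset', n_epochs=5)
#   fe.train()                                      # fine-tune resnet18, then build the extractor
#   fe.export('extractor.pkl')                      # save the extractor Learner for later reuse
#   storage = fe.predict_for_dataset()              # ImgFeatureStorage for every dataset image
#   vec = fe.predict('path/to/dataset/cat_001.jpg') # feature tensor for a single image
#
#   fe2 = FeatureExtractor.load('extractor.pkl')    # round-trip: reload the exported extractor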