danibalcells committed on
Commit 85d0ae4
1 Parent(s): f2bf83e

Started writing classes to handle extraction / retrieval

Files changed (2)
  1. feature_extractor.py +86 -58
  2. iirwi.ipynb +2 -2
feature_extractor.py CHANGED
@@ -1,12 +1,20 @@
  import os
  import pickle
  import torch
  import matplotlib.pyplot as plt
- from fastai.vision.all import *
- # Define a new model that includes a global max pooling layer
- class FeatureExtractor(nn.Module):
      def __init__(self, original_model):
          super().__init__()
          self.features = nn.Sequential(*list(original_model.children())[:-1])
@@ -17,66 +25,86 @@ class FeatureExtractor(nn.Module):
          x = self.pooling(x)
          return x.view(x.size(0), -1)

- def dummy_loss_func(x, y):
-     return torch.tensor(0.)
-
- def get_label(file_path):
-     return os.path.basename(file_path).split('_')[0]

- def train_model(dls, n_epochs=5):
-     learn = vision_learner(dls, resnet18, metrics=error_rate)
-     learn.fine_tune(n_epochs)
-     return learn

- def get_feature_extractor(learn, dls, loss_func=dummy_loss_func):
-     model = FeatureExtractor(learn.model)
-     feature_extractor = Learner(dls, model, loss_func=loss_func)
-     return feature_extractor

- def train_feature_extractor(dataset_path, item_tfms=[Resize(224)], label_func=get_label, n_epochs=5):
-     path = Path(dataset_path)
-     dls = ImageDataLoaders.from_name_func(
-         path, get_image_files(path), valid_pct=0.2, seed=42,
-         label_func=get_label, item_tfms=item_tfms)
-     learn = train_model(dls, n_epochs=n_epochs)
-     feature_extractor = get_feature_extractor(learn, dls)
-     return feature_extractor

- def export_feature_extractor(feature_extractor, model_name):
-     feature_extractor.path = Path('.')
-     feature_extractor.export(model_name)

- def load_feature_extractor(model_name):
-     feature_extractor = load_learner(model_name)
-     return feature_extractor
- def get_dataset_features(feature_extractor):
-     dls = feature_extractor.dls
-     # Get activations for training data
-     train_activations, _ = feature_extractor.get_preds(dl=dls.train)
-     # Get activations for validation data
-     valid_activations, _ = feature_extractor.get_preds(dl=dls.valid)
-     # Concatenate the activations
-     all_activations = torch.cat([train_activations, valid_activations])
-     # Concatenate the image paths
-     all_items = dls.train.items + dls.valid.items
-     # Create a dictionary mapping image paths to features
-     features = {image: activation.clone() for image, activation in zip(all_items, all_activations)}
-     return features

- def write_features(features, filename):
-     with open(filename, 'wb') as f:
-         pickle.dump(features, f)

- def load_features(filename):
-     with open(filename, 'rb') as f:
-         features = pickle.load(f)
-     return features

- def get_features_tensor_from_dict(features):
-     # Convert the features dictionary to a list of tuples
-     features_list = list(features.items())
-     # Extract the image paths and features
-     image_paths, feature_tensors = zip(*features_list)
-     # Convert the features to a PyTorch tensor
-     features_tensor = torch.stack(feature_tensors)
-     return features_tensor, image_paths
 
  import os
  import pickle
+ from pathlib import Path
+ from collections.abc import Iterable
+
  import torch
  import matplotlib.pyplot as plt
+ import fastai.vision.all as fv
+ import torch.nn as nn
+
+ def dummy_loss_func(x, y):
+     return torch.tensor(0.)

+ def get_label(file_path):
+     return os.path.basename(file_path).split('_')[0]

+ class FeatureExtractorModel(nn.Module):
      def __init__(self, original_model):
          super().__init__()
          self.features = nn.Sequential(*list(original_model.children())[:-1])

          x = self.pooling(x)
          return x.view(x.size(0), -1)

+ class FeatureExtractor:
+     def __init__(self, dataset_path=None, dls=None, item_tfms=None, label_func=get_label, n_epochs=5):
+         item_tfms = item_tfms or [fv.Resize(224)]
+         self.dataset_path = dataset_path
+         self.dls = dls
+         self.item_tfms = item_tfms
+         self.label_func = label_func
+         self.n_epochs = n_epochs
+         if self.dataset_path and not self.dls:
+             self.dls = fv.ImageDataLoaders.from_name_func(
+                 self.dataset_path, fv.get_image_files(self.dataset_path), valid_pct=0.2, seed=42,
+                 label_func=self.label_func, item_tfms=self.item_tfms)
+
+     @classmethod
+     def from_dataset(cls, dataset_path, item_tfms=[fv.Resize(224)], label_func=get_label, n_epochs=5):
+         return cls(dataset_path=dataset_path, item_tfms=item_tfms, label_func=label_func, n_epochs=n_epochs)

+     @classmethod
+     def from_learner(cls, extractor):
+         obj = cls(dls=extractor.dls)
+         obj.extractor = extractor
+         return obj

+     @classmethod
+     def load(cls, filename):
+         extractor = fv.load_learner(filename, cpu=False)
+         return cls.from_learner(extractor)

+     def export(self, model_name, path=Path('.')):
+         self.extractor.path = path
+         self.extractor.export(model_name)
+
+     def train(self, n_epochs=None):
+         n_epochs = n_epochs or self.n_epochs
+         self.classifier = self.train_classifier(n_epochs)
+         self.extractor = self.get_extractor()

+     def train_classifier(self, n_epochs=None):
+         n_epochs = n_epochs or self.n_epochs
+         classifier = fv.vision_learner(self.dls, fv.resnet18, metrics=fv.error_rate)
+         classifier.fine_tune(n_epochs)
+         return classifier
+
+     def get_extractor(self):
+         model = FeatureExtractorModel(self.classifier.model)
+         extractor = fv.Learner(self.dls, model, loss_func=dummy_loss_func)
+         return extractor

+     def predict(self, input_images):
+         if not isinstance(input_images, Iterable) or isinstance(input_images, str):
+             input_images = [input_images]
+         with self.extractor.no_bar(), self.extractor.no_logging():
+             dl = self.extractor.dls.test_dl(input_images)
+             inp, features, _, dec = self.extractor.get_preds(dl=dl, with_input=True, with_decoded=True)
+         return features
+
+     def predict_for_dataset(self, dls=None):
+         dls = dls or self.dls
+         train_features, _ = self.extractor.get_preds(dl=dls.train)
+         valid_features, _ = self.extractor.get_preds(dl=dls.valid)
+         all_features = torch.cat([train_features, valid_features])
+         all_items = dls.train.items + dls.valid.items
+         # Create a dictionary mapping image paths to features
+         features = {image: feature.clone() for image, feature in zip(all_items, all_features)}
+         return features

+ # def write_features(features, filename):
+ #     with open(filename, 'wb') as f:
+ #         pickle.dump(features, f)

+ # def load_features(filename):
+ #     with open(filename, 'rb') as f:
+ #         features = pickle.load(f)
+ #     return features

+ # def get_features_tensor_from_dict(features):
+ #     # Convert the features dictionary to a list of tuples
+ #     features_list = list(features.items())
+ #     # Extract the image paths and features
+ #     image_paths, feature_tensors = zip(*features_list)
+ #     # Convert the features to a PyTorch tensor
+ #     features_tensor = torch.stack(feature_tensors)
+ #     return features_tensor, image_paths
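For context, a minimal usage sketch of the class-based API introduced by this commit. It only calls methods that appear in the diff above; the dataset folder, image filename, and export name are placeholders, and the image filename follows the get_label convention (label before the first underscore).

# Hypothetical usage of the new FeatureExtractor class (paths/filenames are placeholders).
from feature_extractor import FeatureExtractor

# Build DataLoaders from a labeled image folder, fine-tune a resnet18 classifier,
# and wrap it as a feature-extraction Learner (train() sets self.extractor).
fe = FeatureExtractor.from_dataset('data/images', n_epochs=5)
fe.train()

# Persist the extractor and reload it later without retraining.
fe.export('extractor.pkl')
fe_restored = FeatureExtractor.load('extractor.pkl')

# Embed a single image, or the whole train/valid dataset as a dict of path -> feature tensor.
single_features = fe_restored.predict('data/images/cat_001.jpg')
all_features = fe.predict_for_dataset()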
iirwi.ipynb CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:ac5a85b2028b1c89e732f6b61b6753fdda1947c9d699a50684ffec733bb3622a
- size 4802471
+ oid sha256:c5260d60b65f6ce374191f62eb873d125c33778263f909f515626ca70eb9ac41
+ size 12819615