File size: 872 Bytes
6b82a67
 
 
 
1b59f08
6b82a67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import pickle

import torch

from .image import CroppedImage

class ImgFeatureStorage:
    """Container mapping image paths to feature tensors.

    The attributes are kept parallel: ``img_paths[i]`` is the i-th key of
    ``features_dict`` and ``images[i]`` is the ``CroppedImage`` built from
    that key. Row ``i`` of :meth:`get_features_tensor` therefore lines up
    with ``ix(i)``, as long as ``features_dict`` is not mutated after
    construction.
    """

    def __init__(self):
        # Mapping of image path -> feature tensor. Presumably one embedding
        # per image; TODO confirm against the code that produces it.
        self.features_dict = {}
        # Keys of features_dict, in dict iteration order.
        self.img_paths = []
        # CroppedImage objects, parallel to img_paths.
        self.images = []

    @classmethod
    def from_features_dict(cls, features_dict):
        """Build a storage from an existing ``{path: tensor}`` mapping.

        Every key is passed to ``CroppedImage.from_path``, so one image is
        loaded per entry.

        Args:
            features_dict: Mapping whose keys are image paths and whose
                values are feature tensors. It is stored by reference,
                not copied.

        Returns:
            A new ``ImgFeatureStorage`` instance.
        """
        obj = cls()
        obj.features_dict = features_dict
        # Iterating a dict yields its keys in insertion order. That is the
        # same order get_features_tensor() stacks the values in.
        obj.img_paths = list(features_dict)
        obj.images = [CroppedImage.from_path(path) for path in obj.img_paths]
        return obj

    def get_features_tensor(self):
        """Stack all feature tensors into a single tensor.

        Returns:
            ``torch.stack`` of the values of ``features_dict`` in dict
            order. The result has a new leading dimension of size
            ``len(features_dict)``.

        Raises:
            RuntimeError: Raised by ``torch.stack`` if the storage is empty
                or if the tensors do not all have the same shape.
        """
        return torch.stack(list(self.features_dict.values()))

    def ix(self, i):
        """Return the ``CroppedImage`` at position ``i`` (list indexing rules)."""
        return self.images[i]

    def export(self, filename):
        """Pickle ``features_dict`` to ``filename``.

        Only the mapping is written. :meth:`load` rebuilds ``img_paths`` and
        ``images`` from its keys.
        """
        # The context manager flushes and closes the handle deterministically
        # instead of leaving that to garbage collection.
        with open(filename, 'wb') as f:
            pickle.dump(self.features_dict, f)

    @classmethod
    def load(cls, filename):
        """Load a storage previously written by :meth:`export`.

        SECURITY: ``pickle.load`` can execute arbitrary code during
        unpickling. Only load files from trusted sources.

        Returns:
            A new ``ImgFeatureStorage`` built via :meth:`from_features_dict`.
        """
        with open(filename, 'rb') as f:
            features_dict = pickle.load(f)
        return cls.from_features_dict(features_dict)