danibalcells's picture
Clean up
1b59f08
raw
history blame
872 Bytes
import pickle
import torch
from .image import CroppedImage
class ImgFeatureStorage:
    """Container pairing per-image feature values with their source images.

    Attributes are kept as three parallel views of the same data:

    * ``features_dict`` -- mapping whose keys are treated as image paths and
      whose values are the features for that image.
    * ``img_paths`` -- the keys of ``features_dict`` in insertion order.
    * ``images`` -- one ``CroppedImage`` per entry of ``img_paths``, in the
      same order.
    """

    def __init__(self):
        """Create an empty storage with no features, paths, or images."""
        self.features_dict = {}
        self.img_paths = []
        self.images = []

    @classmethod
    def from_features_dict(cls, features_dict):
        """Build a storage from an existing ``{image_path: features}`` mapping.

        Args:
            features_dict: Mapping from image path to that image's features.
                The mapping is stored by reference, not copied.

        Returns:
            A new instance whose ``img_paths`` are the mapping's keys and
            whose ``images`` are created with ``CroppedImage.from_path`` for
            each key.
        """
        obj = cls()
        obj.features_dict = features_dict
        obj.img_paths = list(features_dict)
        obj.images = [CroppedImage.from_path(path) for path in obj.img_paths]
        return obj

    def get_features_tensor(self):
        """Return all stored feature values stacked into a single tensor.

        The values are stacked with ``torch.stack`` in the dictionary's
        iteration order, so row ``i`` corresponds to ``img_paths[i]`` when the
        instance was built via ``from_features_dict``.

        NOTE(review): ``torch.stack`` needs a non-empty sequence of
        equally-shaped tensors -- calling this on an empty storage raises.
        """
        return torch.stack(list(self.features_dict.values()))

    def ix(self, i):
        """Return the image stored at position ``i`` of ``images``."""
        return self.images[i]

    def export(self, filename):
        """Pickle ``features_dict`` to ``filename``.

        Only the feature mapping is written; ``img_paths`` and ``images`` are
        rebuilt from it by ``load``.

        Args:
            filename: Path of the file to write (opened in binary mode).
        """
        # Use a context manager so the handle is flushed and closed even if
        # pickling fails part-way through.
        with open(filename, 'wb') as fh:
            pickle.dump(self.features_dict, fh)

    @classmethod
    def load(cls, filename):
        """Load a storage previously written by ``export``.

        Args:
            filename: Path of the pickle file to read.

        Returns:
            A new instance built with ``from_features_dict`` from the
            unpickled mapping.
        """
        # SECURITY: pickle can execute arbitrary code while loading. Only call
        # this on files produced by a trusted source (e.g. ``export``).
        with open(filename, 'rb') as fh:
            features_dict = pickle.load(fh)
        return cls.from_features_dict(features_dict)