Spaces:
Running
Running
from collections import Counter | |
import io | |
import os | |
import pickle | |
import random | |
from boltons.iterutils import chunked | |
import lmdb | |
import numpy as np | |
from PIL import Image | |
import pysaliency | |
from pysaliency.datasets import create_subset | |
from pysaliency.utils import remove_trailing_nans | |
import torch | |
from tqdm import tqdm | |
def ensure_color_image(image):
    """Return `image` with a channel axis, replicating 2-D input three times.

    A two-dimensional array is stacked three times along a new last axis,
    giving shape (H, W, 3). Any other input is returned unchanged.
    """
    is_grayscale = len(image.shape) == 2
    if not is_grayscale:
        return image
    return np.dstack((image,) * 3)
def x_y_to_sparse_indices(xs, ys):
    """Convert x and y coordinate lists into sparse-mask indices and values.

    Each distinct (x, y) pair is listed once, in order of first appearance,
    and its value is the number of times the pair occurs in the input.

    Returns:
        A tuple ``(indices, values)`` where ``indices`` is an array whose
        first row holds the y coordinates and whose second row holds the x
        coordinates, and ``values`` is a list of occurrence counts.
    """
    # Counter keeps keys in first-seen order, like the index lists it replaces.
    occurrences = Counter(zip(xs, ys))
    cols = [x for x, _ in occurrences]
    rows = [y for _, y in occurrences]
    return np.array([rows, cols]), list(occurrences.values())
class ImageDataset(torch.utils.data.Dataset):
    """Dataset with one item per stimulus, carrying all of its fixations.

    Each item is a dict with the keys ``"image"``, ``"x"``, ``"y"``,
    ``"centerbias"`` and ``"weight"``. If a ``transform`` is given, it is
    applied to a shallow copy of that dict and its result is returned.

    Args:
        stimuli: stimuli collection; must support ``len()`` and provide
            ``.stimuli`` and ``.sizes``.
        fixations: fixations object providing ``x_int``, ``y_int`` and ``n``.
        centerbias_model: model whose ``log_density(image)`` supplies the
            center bias prediction.
        lmdb_path: optional path of an LMDB database. When given, the
            dataset is exported to it and image data is read back from it.
        transform: optional callable applied to every item.
        cached: whether to keep assembled items in memory. ``None`` enables
            caching only when there are fewer than 100 stimuli. Ignored
            (forced off) when ``lmdb_path`` is given.
        average: ``'image'`` gives every item weight 1.0; any other value
            weights an item by its number of fixations.
    """

    def __init__(
        self,
        stimuli,
        fixations,
        centerbias_model=None,
        lmdb_path=None,
        transform=None,
        cached=None,
        average='fixation'
    ):
        self.stimuli = stimuli
        self.fixations = fixations
        self.centerbias_model = centerbias_model
        self.lmdb_path = lmdb_path
        self.transform = transform
        self.average = average

        # cache only short dataset
        if cached is None:
            cached = len(self.stimuli) < 100

        cache_fixation_data = cached

        if lmdb_path is not None:
            # The exporter expands "~" in the path, so open the very same
            # expanded path here; otherwise a different location is read.
            expanded_path = os.path.expanduser(lmdb_path)
            _export_dataset_to_lmdb(stimuli, centerbias_model, expanded_path)
            self.lmdb_env = lmdb.open(
                expanded_path, subdir=os.path.isdir(expanded_path),
                readonly=True, lock=False,
                readahead=False, meminit=False
            )
            cached = False
            cache_fixation_data = True
        else:
            self.lmdb_env = None

        self.cached = cached
        if cached:
            self._cache = {}

        self.cache_fixation_data = cache_fixation_data
        if cache_fixation_data:
            print("Populating fixations cache")
            self._xs_cache = {}
            self._ys_cache = {}

            # Group fixation coordinates by the index of their stimulus.
            for x, y, n in zip(self.fixations.x_int, self.fixations.y_int, tqdm(self.fixations.n)):
                self._xs_cache.setdefault(n, []).append(x)
                self._ys_cache.setdefault(n, []).append(y)

            for key in list(self._xs_cache):
                self._xs_cache[key] = np.array(self._xs_cache[key], dtype=int)
            for key in list(self._ys_cache):
                self._ys_cache[key] = np.array(self._ys_cache[key], dtype=int)

    def get_shapes(self):
        """Return the list of stimulus sizes, one entry per stimulus."""
        return list(self.stimuli.sizes)

    def _get_image_data(self, n):
        """Return ``(image, centerbias_prediction)`` for stimulus ``n``."""
        if self.lmdb_env is not None:
            # The LMDB helper already returns the image channel-first.
            image, centerbias_prediction = _get_image_data_from_lmdb(self.lmdb_env, n)
        else:
            image = np.array(self.stimuli.stimuli[n])
            centerbias_prediction = self.centerbias_model.log_density(image)

            image = ensure_color_image(image).astype(np.float32)
            image = image.transpose(2, 0, 1)

        return image, centerbias_prediction

    def _get_fixation_coordinates(self, key):
        """Return the ``(xs, ys)`` integer arrays of stimulus ``key``.

        A stimulus without any fixation yields two empty arrays on every
        code path.
        """
        if not self.cache_fixation_data:
            inds = self.fixations.n == key
            xs = np.array(self.fixations.x_int[inds], dtype=int)
            ys = np.array(self.fixations.y_int[inds], dtype=int)
            return xs, ys

        if self.cached:
            # The assembled item is cached afterwards, so the per-stimulus
            # entries can be dropped from the fixation cache.
            xs = self._xs_cache.pop(key, None)
            ys = self._ys_cache.pop(key, None)
        else:
            xs = self._xs_cache.get(key)
            ys = self._ys_cache.get(key)

        # Stimuli without fixations never got an entry in the cache.
        if xs is None:
            xs = np.array([], dtype=int)
        if ys is None:
            ys = np.array([], dtype=int)
        return xs, ys

    def __getitem__(self, key):
        if not self.cached or key not in self._cache:
            image, centerbias_prediction = self._get_image_data(key)
            centerbias_prediction = centerbias_prediction.astype(np.float32)

            xs, ys = self._get_fixation_coordinates(key)

            data = {
                "image": image,
                "x": xs,
                "y": ys,
                "centerbias": centerbias_prediction,
            }

            if self.average == 'image':
                data['weight'] = 1.0
            else:
                data['weight'] = float(len(xs))

            if self.cached:
                self._cache[key] = data
        else:
            data = self._cache[key]

        if self.transform is not None:
            # Hand the transform a shallow copy so that it may pop keys
            # without altering a cached item.
            return self.transform(dict(data))

        return data

    def __len__(self):
        return len(self.stimuli)
class FixationDataset(torch.utils.data.Dataset):
    """Dataset with one item per fixation, including its fixation history.

    Each item is a dict with the keys ``"image"``, ``"x"``, ``"y"``,
    ``"x_hist"``, ``"y_hist"``, ``"centerbias"`` and ``"weight"``. If a
    ``transform`` is given, its result for that dict is returned instead.

    Args:
        stimuli: stimuli collection; must support ``len()`` and provide
            ``.stimuli`` and ``.sizes``.
        fixations: fixations object providing ``n``, ``x_int``, ``y_int``,
            ``x_hist`` and ``y_hist`` and supporting ``len()``.
        centerbias_model: model whose ``log_density(image)`` supplies the
            center bias prediction.
        lmdb_path: optional path of an LMDB database. When given, the
            dataset is exported to it and image data is read back from it.
        transform: optional callable applied to every item.
        included_fixations: indices into the fixation history to include.
            A negative integer ``-k`` selects the last ``k`` history
            entries, most recent first. Non-negative integers are not
            supported. Any non-integer value is used as the index list.
        allow_missing_fixations: if true, history positions that do not
            exist are filled with NaN instead of raising.
        average: ``'image'`` weights each fixation by one over the number
            of fixations on its stimulus; any other value uses weight 1.0.
        cache_image_data: preload image data for all stimuli. Ignored
            (forced off) when ``lmdb_path`` is given.

    Raises:
        NotImplementedError: if ``included_fixations`` is a non-negative
            integer.
    """

    def __init__(
        self,
        stimuli, fixations,
        centerbias_model=None,
        lmdb_path=None,
        transform=None,
        included_fixations=-2,
        allow_missing_fixations=False,
        average='fixation',
        cache_image_data=False,
    ):
        self.stimuli = stimuli
        self.fixations = fixations
        self.centerbias_model = centerbias_model
        self.lmdb_path = lmdb_path

        if lmdb_path is not None:
            # The exporter expands "~" in the path, so open the very same
            # expanded path here; otherwise a different location is read.
            expanded_path = os.path.expanduser(lmdb_path)
            _export_dataset_to_lmdb(stimuli, centerbias_model, expanded_path)
            self.lmdb_env = lmdb.open(
                expanded_path, subdir=os.path.isdir(expanded_path),
                readonly=True, lock=False,
                readahead=False, meminit=False
            )
            cache_image_data = False
        else:
            self.lmdb_env = None

        self.transform = transform
        self.average = average

        self._shapes = None

        if isinstance(included_fixations, int):
            if included_fixations < 0:
                # -k becomes [-1, -2, ..., -k]
                included_fixations = [-1 - i for i in range(-included_fixations)]
            else:
                raise NotImplementedError()

        self.included_fixations = included_fixations
        self.allow_missing_fixations = allow_missing_fixations
        self.fixation_counts = Counter(fixations.n)

        self.cache_image_data = cache_image_data

        if self.cache_image_data:
            self.image_data_cache = {}

            print("Populating image cache")
            for n in tqdm(range(len(self.stimuli))):
                self.image_data_cache[n] = self._get_image_data(n)

    def get_shapes(self):
        """Return the stimulus size for every fixation (computed once)."""
        if self._shapes is None:
            shapes = list(self.stimuli.sizes)
            self._shapes = [shapes[n] for n in self.fixations.n]

        return self._shapes

    def _get_image_data(self, n):
        """Return ``(image, centerbias_prediction)`` for stimulus ``n``."""
        if self.lmdb_env is not None:
            return _get_image_data_from_lmdb(self.lmdb_env, n)

        image = np.array(self.stimuli.stimuli[n])
        centerbias_prediction = self.centerbias_model.log_density(image)

        image = ensure_color_image(image).astype(np.float32)
        image = image.transpose(2, 0, 1)

        return image, centerbias_prediction

    def __getitem__(self, key):
        n = self.fixations.n[key]

        if self.cache_image_data:
            image, centerbias_prediction = self.image_data_cache[n]
        else:
            image, centerbias_prediction = self._get_image_data(n)

        centerbias_prediction = centerbias_prediction.astype(np.float32)

        x_hist = remove_trailing_nans(self.fixations.x_hist[key])
        y_hist = remove_trailing_nans(self.fixations.y_hist[key])

        if self.allow_missing_fixations:
            history_length = len(x_hist)
            _x_hist = []
            _y_hist = []
            for fixation_index in self.included_fixations:
                if -history_length <= fixation_index < history_length:
                    _x_hist.append(x_hist[fixation_index])
                    _y_hist.append(y_hist[fixation_index])
                else:
                    # This history position does not exist for this fixation.
                    _x_hist.append(np.nan)
                    _y_hist.append(np.nan)
            x_hist = np.array(_x_hist)
            y_hist = np.array(_y_hist)
        else:
            x_hist = x_hist[self.included_fixations]
            y_hist = y_hist[self.included_fixations]

        data = {
            "image": image,
            "x": np.array([self.fixations.x_int[key]], dtype=int),
            "y": np.array([self.fixations.y_int[key]], dtype=int),
            "x_hist": x_hist,
            "y_hist": y_hist,
            "centerbias": centerbias_prediction,
        }

        if self.average == 'image':
            data['weight'] = 1.0 / self.fixation_counts[n]
        else:
            data['weight'] = 1.0

        if self.transform is not None:
            return self.transform(data)

        return data

    def __len__(self):
        return len(self.fixations)
class FixationMaskTransform(object):
    """Replace an item's ``'x'``/``'y'`` coordinates by a fixation mask.

    The mask has the spatial size of ``item['image']`` (its second and
    third dimensions) and holds, at every (y, x) position, the number of
    fixations at that position.

    Args:
        sparse: if true the mask stays a sparse COO tensor, otherwise it is
            converted to a dense tensor.
    """

    def __init__(self, sparse=True):
        super().__init__()
        self.sparse = sparse

    def __call__(self, item):
        """Pop ``'x'`` and ``'y'`` from ``item`` and add ``'fixation_mask'``.

        ``item`` is modified in place and also returned.
        """
        shape = torch.Size([item['image'].shape[1], item['image'].shape[2]])
        x = item.pop('x')
        y = item.pop('y')

        # Sparse COO indices have to be int64, whatever the platform's
        # default integer width is.
        inds = torch.as_tensor(np.array([y, x]), dtype=torch.long)
        values = torch.tensor(np.ones(len(y), dtype=int))

        # torch.sparse_coo_tensor replaces the deprecated legacy
        # torch.sparse.IntTensor(indices, values, shape) constructor.
        mask = torch.sparse_coo_tensor(inds, values, shape)
        # Coalescing sums the entries of repeated positions into counts.
        mask = mask.coalesce()

        # sparse tensors don't work with workers...
        if not self.sparse:
            mask = mask.to_dense()

        item['fixation_mask'] = mask

        return item
class ImageDatasetSampler(torch.utils.data.Sampler):
    """Batch sampler yielding batches of indices that share one shape.

    Indices are grouped by the shape reported by
    ``data_source.get_shapes()`` and every group is cut into batches of at
    most ``batch_size`` indices. The batches are built once, at
    construction time.

    Args:
        data_source: object providing ``get_shapes()``, one hashable and
            sortable shape per index.
        batch_size: maximum number of indices per batch.
        ratio_used: fraction of the batches yielded per iteration.
        shuffle: if true, indices are shuffled within each shape group at
            construction time and the batch order is re-drawn on every
            iteration.
    """

    def __init__(self, data_source, batch_size=1, ratio_used=1.0, shuffle=True):
        self.ratio_used = ratio_used
        self.shuffle = shuffle

        # Group the indices by shape in a single pass.
        indices_by_shape = {}
        for index, shape in enumerate(data_source.get_shapes()):
            indices_by_shape.setdefault(shape, []).append(index)

        self.batches = []
        # Visit the shapes in sorted order so that the batch layout (and
        # the sequence of random.shuffle calls) is deterministic.
        for shape in sorted(indices_by_shape):
            indices = indices_by_shape[shape]
            if self.shuffle:
                random.shuffle(indices)
            self.batches.extend(
                indices[start:start + batch_size]
                for start in range(0, len(indices), batch_size)
            )

    def __iter__(self):
        if self.shuffle:
            order = torch.randperm(len(self.batches)).tolist()
        else:
            order = list(range(len(self.batches)))

        if self.ratio_used < 1.0:
            order = order[:int(self.ratio_used * len(order))]

        return iter(self.batches[i] for i in order)

    def __len__(self):
        return int(self.ratio_used * len(self.batches))
def _export_dataset_to_lmdb(stimuli: pysaliency.FileStimuli, centerbias_model: pysaliency.Model, lmdb_path, write_frequency=100):
    """Write every stimulus file and its center bias into an LMDB database.

    Entries are keyed by the ASCII-encoded stimulus index. Indices that
    already have a non-empty entry are skipped, so an existing database is
    only completed, not rewritten.

    Args:
        stimuli: stimuli to export; iterated for the stimulus objects and
            indexed through ``stimuli.filenames`` for the image files.
        centerbias_model: model whose ``log_density(stimulus)`` is stored
            next to each image.
        lmdb_path: database location; ``~`` is expanded.
        write_frequency: the write transaction is committed and reopened
            after writing an index that is a multiple of this value.
    """
    lmdb_path = os.path.expanduser(lmdb_path)
    isdir = os.path.isdir(lmdb_path)

    print("Generate LMDB to %s" % lmdb_path)
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776 * 2, readonly=False,
                   meminit=False, map_async=True)

    # Close the environment even if the export fails half-way, so that the
    # handle does not leak.
    try:
        txn = db.begin(write=True)
        for idx, stimulus in enumerate(tqdm(stimuli)):
            key = '{}'.format(idx).encode('ascii')

            previous_data = txn.get(key)
            if previous_data:
                continue

            stimulus_filename = stimuli.filenames[idx]
            centerbias = centerbias_model.log_density(stimulus)

            txn.put(
                key,
                _encode_filestimulus_item(stimulus_filename, centerbias)
            )
            if idx % write_frequency == 0:
                txn.commit()
                txn = db.begin(write=True)

        # finish iterating through dataset
        txn.commit()

        print("Flushing database ...")
        db.sync()
    finally:
        db.close()
def _encode_filestimulus_item(filename, centerbias): | |
with open(filename, 'rb') as f: | |
image_bytes = f.read() | |
buffer = io.BytesIO() | |
pickle.dump({'image': image_bytes, 'centerbias': centerbias}, buffer) | |
buffer.seek(0) | |
return buffer.read() | |
def _get_image_data_from_lmdb(lmdb_env, n): | |
key = '{}'.format(n).encode('ascii') | |
with lmdb_env.begin(write=False) as txn: | |
byteflow = txn.get(key) | |
data = pickle.loads(byteflow) | |
buffer = io.BytesIO(data['image']) | |
buffer.seek(0) | |
image = np.array(Image.open(buffer).convert('RGB')) | |
centerbias_prediction = data['centerbias'] | |
image = image.transpose(2, 0, 1) | |
return image, centerbias_prediction |