Spaces:
Runtime error
Runtime error
import copy | |
import torch | |
from torchvision.transforms import Compose | |
# Registry mapping a transform's registered name to the object registered for it.
_TRANSFORM_DICT = {}


def register_transform(name):
    """Return a decorator that records the decorated object under ``name``.

    The decorated object is stored in the module-level ``_TRANSFORM_DICT``
    (replacing any earlier entry with the same name) and is returned
    unchanged, so the decorator has no effect on the object itself.
    """

    def _add_to_registry(transform_cls):
        _TRANSFORM_DICT.update({name: transform_cls})
        return transform_cls

    return _add_to_registry
def get_transform(cfg):
    """Build a ``Compose`` pipeline from a sequence of transform configs.

    Each entry of ``cfg`` is deep-copied (so the caller's config is never
    mutated), its ``'type'`` item is popped and used as the lookup key into
    ``_TRANSFORM_DICT``, and the remaining items are passed to the looked-up
    callable as keyword arguments.

    Returns ``None`` when ``cfg`` is ``None`` or empty; otherwise a
    ``Compose`` wrapping the constructed transforms in config order.
    A missing ``'type'`` item or an unregistered name propagates the
    lookup error from the underlying mapping.
    """
    if cfg is None or not len(cfg):
        return None

    def _build(entry):
        # Work on a private copy: popping 'type' must not alter the caller's config.
        kwargs = copy.deepcopy(entry)
        transform_cls = _TRANSFORM_DICT[kwargs.pop('type')]
        return transform_cls(**kwargs)

    return Compose([_build(entry) for entry in cfg])
def _index_select(v, index, n):
    """Pick the entries of ``v`` given by ``index`` if ``v`` has length ``n``.

    Args:
        v: Any value. Only tensors whose first dimension has size ``n`` and
            lists of length ``n`` are indexed; everything else passes through.
        index: Indices to select. Used directly as ``v[index]`` for tensors
            and iterated element by element for lists.
        n: The first-axis length a value must have to be indexed.

    Returns:
        The selected tensor or list when ``v`` matches length ``n``,
        otherwise ``v`` itself, unchanged.
    """
    if isinstance(v, torch.Tensor):
        # A 0-dim tensor has no first axis and ``size(0)`` would raise
        # IndexError, so scalar tensors are passed through untouched.
        if v.dim() > 0 and v.size(0) == n:
            return v[index]
        return v
    if isinstance(v, list) and len(v) == n:
        return [v[i] for i in index]
    return v
def _index_select_data(data, index):
    """Apply ``_index_select`` to every value of the ``data`` mapping.

    The reference length passed to ``_index_select`` is the first-axis size
    of ``data['aa']``; values that do not match that length are returned
    as-is by ``_index_select``. A new dict with the same keys is returned.
    """
    selected = {}
    for key, value in data.items():
        # Looked up per item, so an empty mapping never touches data['aa'].
        selected[key] = _index_select(value, index, data['aa'].size(0))
    return selected
def _mask_select(v, mask):
    """Keep the entries of ``v`` where ``mask`` is true, if lengths match.

    Args:
        v: Any value. Only tensors whose first dimension equals
            ``mask.size(0)`` and lists of that length are filtered;
            everything else passes through.
        mask: A tensor used as ``v[mask]`` for tensors and iterated
            element by element (truth-tested) for lists.

    Returns:
        The filtered tensor or list when ``v`` matches the mask length,
        otherwise ``v`` itself, unchanged.
    """
    length = mask.size(0)
    if isinstance(v, torch.Tensor):
        # A 0-dim tensor has no first axis and ``size(0)`` would raise
        # IndexError, so scalar tensors are passed through untouched.
        if v.dim() > 0 and v.size(0) == length:
            return v[mask]
        return v
    if isinstance(v, list) and len(v) == length:
        # Lengths are equal here, so pairing items with mask entries is
        # equivalent to indexing ``v`` by position.
        return [item for item, keep in zip(v, mask) if keep]
    return v
def _mask_select_data(data, mask):
    """Apply ``_mask_select`` with ``mask`` to every value of ``data``.

    Returns a new dict with the same keys; values whose length does not
    match the mask are returned as-is by ``_mask_select``.
    """
    filtered = {}
    for key, value in data.items():
        filtered[key] = _mask_select(value, mask)
    return filtered