Spaces:
Build error
Build error
| import torch | |
| def to_view(data, i): | |
| return {k + i: v for k, v in data.items()} | |
| def get_view(data, i): | |
| data_g = {k: v for k, v in data.items() if not k[-1].isnumeric()} | |
| data_i = {k[:-1]: v for k, v in data.items() if k[-1] == i} | |
| return {**data_g, **data_i} | |
| def get_twoview(data, idx): | |
| li = idx[0] | |
| ri = idx[-1] | |
| assert idx == f"{li}to{ri}" | |
| data_lr = {k[:-4] + "0to1": v for k, v in data.items() if k[-4:] == f"{li}to{ri}"} | |
| data_rl = {k[:-4] + "1to0": v for k, v in data.items() if k[-4:] == f"{ri}ito{li}"} | |
| data_l = { | |
| k[:-1] + "0": v for k, v in data.items() if k[-1:] == li and k[-3:-1] != "to" | |
| } | |
| data_r = { | |
| k[:-1] + "1": v for k, v in data.items() if k[-1:] == ri and k[-3:-1] != "to" | |
| } | |
| return {**data_lr, **data_rl, **data_l, **data_r} | |
| def stack_twoviews(data, indices=["0to1", "0to2", "1to2"]): | |
| idx0 = indices[0] | |
| m_data = data[idx0] if idx0 in data else get_twoview(data, idx0) | |
| # stack on dim=0 | |
| for idx in indices[1:]: | |
| data_i = data[idx] if idx in data else get_twoview(data, idx) | |
| for k, v in data_i.items(): | |
| m_data[k] = torch.cat([m_data[k], v], dim=0) | |
| return m_data | |
| def unstack_twoviews(data, B, indices=["0to1", "0to2", "1to2"]): | |
| out = {} | |
| for i, idx in enumerate(indices): | |
| out[idx] = {k: v[i * B : (i + 1) * B] for k, v in data.items()} | |
| return out | |