import torch
import torch.nn.functional as F


class Batch:
    @staticmethod
    def build(data):
        # Collate a list of example dicts into a single dict of batched fields.
        fields = list(data[0].keys())
        transposed = {}
        for field in fields:
            if isinstance(data[0][field], tuple):
                # Tuple-valued fields are batched position by position.
                transposed[field] = tuple(
                    Batch._stack(field, [example[field][i] for example in data])
                    for i in range(len(data[0][field]))
                )
            else:
                transposed[field] = Batch._stack(field, [example[field] for example in data])

        return transposed

    @staticmethod
    def _stack(field: str, examples):
        # "anchored_labels" entries are left unstacked, as a plain list.
        if field == "anchored_labels":
            return examples

        dim = examples[0].dim()

        # Scalars can be stacked directly, with no padding.
        if dim == 0:
            return torch.stack(examples)

        # Pad every example up to the per-dimension maximum before stacking.
        lengths = [max(example.size(i) for example in examples) for i in range(dim)]
        if any(length == 0 for length in lengths):
            # Some dimension is empty, so the batch holds no elements;
            # an uninitialized tensor of the right shape suffices.
            return torch.LongTensor(len(examples), *lengths)

        examples = [F.pad(example, Batch._pad_size(example, lengths)) for example in examples]
        return torch.stack(examples)

    @staticmethod
    def _pad_size(example, total_size):
        # F.pad expects (left, right) pairs starting from the last dimension,
        # so walk total_size in reverse and pad each dimension on the right only.
        return [p for i, l in enumerate(total_size[::-1]) for p in (0, l - example.size(-1 - i))]

    @staticmethod
    def index_select(batch, indices):
        # Select the rows at `indices` along the batch dimension of every field.
        filtered_batch = {}
        for key, examples in batch.items():
            if isinstance(examples, (list, tuple)):
                filtered_batch[key] = [example.index_select(0, indices) for example in examples]
            else:
                filtered_batch[key] = examples.index_select(0, indices)

        return filtered_batch

    @staticmethod
    def to_str(batch):
        # Render each field of the batch as a compact one-line summary.
        return "\n".join(f"\t{name}: {Batch._short_str(item)}" for name, item in batch.items())

    @staticmethod
    def to(batch, device):
        # Move every field of the batch to the target device.
        converted = {}
        for field, value in batch.items():
            converted[field] = Batch._to(value, device)
        return converted

    @staticmethod
    def _short_str(tensor):
        if not torch.is_tensor(tensor):
            if hasattr(tensor, "data"):
                # Unwrap wrappers that expose their underlying tensor via .data.
                tensor = tensor.data
            elif isinstance(tensor, (tuple, list)):
                return str(tuple(Batch._short_str(t) for t in tensor))
            else:
                return str(tensor)

        size_str = "x".join(str(size) for size in tensor.size())
        device_str = " (GPU {})".format(tensor.get_device()) if tensor.is_cuda else ""
        return "[{} of size {}{}]".format(torch.typename(tensor), size_str, device_str)

    @staticmethod
    def _to(tensor, device):
        if not torch.is_tensor(tensor):
            # Recurse into containers; anything else cannot be moved.
            if isinstance(tensor, tuple):
                return tuple(Batch._to(t, device) for t in tensor)
            elif isinstance(tensor, list):
                return [Batch._to(t, device) for t in tensor]
            else:
                raise TypeError(f"unsupported type {type(tensor).__name__}: cannot move {tensor!r} to {device}")

        return tensor.to(device, non_blocking=True)
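

# A minimal usage sketch (not part of the original module; the field names
# "tokens" and "label" are hypothetical). It collates two variable-length
# examples, prints the resulting shapes, and moves the batch to a device.
if __name__ == "__main__":
    examples = [
        {"tokens": torch.tensor([1, 2, 3]), "label": torch.tensor(0)},
        {"tokens": torch.tensor([4, 5]), "label": torch.tensor(1)},
    ]
    batch = Batch.build(examples)  # "tokens" is right-padded with zeros to 2x3
    print(Batch.to_str(batch))
    batch = Batch.to(batch, torch.device("cpu"))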