ssa-perin / data /field /anchor_field.py
larkkin's picture
Add supporting code from perin
7daaa6b
raw
history blame contribute delete
525 Bytes
#!/usr/bin/env python3
# coding=utf-8
import torch
from data.field.mini_torchtext.field import RawField
class AnchorField(RawField):
    """Field that turns an anchor description into a 0/1 matrix plus a row mask.

    The raw value is read positionally: its first two items give the matrix
    dimensions and its last item is an iterable of index pairs whose cells
    are set to 1.
    NOTE(review): presumably rows are graph nodes and columns are input
    tokens (node-to-surface anchoring) -- confirm against the code that
    builds these tuples.
    """

    def process(self, batch, device=None):
        """Convert ``batch`` into ``(matrix, mask)`` via :meth:`pad`.

        Args:
            batch: named for the torchtext-style field interface, but it is
                handed straight to :meth:`pad` as a single anchor description.
            device: optional torch device on which the tensors are allocated.

        Returns:
            The ``(matrix, mask)`` pair produced by :meth:`pad`.
        """
        matrix, empty_rows = self.pad(batch, device)
        return matrix, empty_rows

    def pad(self, anchors, device):
        """Materialize ``anchors`` as a ``torch.long`` indicator matrix.

        Args:
            anchors: indexable whose ``[0]`` and ``[1]`` entries are the row
                and column counts, and whose ``[-1]`` entry is an iterable of
                pairs; each pair's ``[0]``/``[1]`` entries select a cell to
                set to 1.
            device: torch device (or ``None``) for the allocated tensors.

        Returns:
            ``(matrix, mask)``: ``matrix`` has shape ``(anchors[0], anchors[1])``
            with 1 at every listed cell and 0 elsewhere; ``mask`` is a boolean
            tensor with one entry per row, ``True`` where the row has no 1.
        """
        n_rows, n_cols = anchors[0], anchors[1]
        matrix = torch.zeros(n_rows, n_cols, dtype=torch.long, device=device)

        # Mark every listed cell; a repeated pair just writes 1 again.
        for pair in anchors[-1]:
            row, col = pair[0], pair[1]
            matrix[row, col] = 1

        # Entries are only 0 or 1, so a zero row sum means "no anchor here".
        empty_rows = matrix.sum(dim=-1).eq(0)
        return matrix, empty_rows