#!/usr/bin/env python3
# coding=utf-8

from collections import Counter
import types

import torch

from data.field.mini_torchtext.field import RawField
from data.field.mini_torchtext.vocab import Vocab


class EdgeField(RawField):
    """Field that turns labelled edge lists into dense label-index matrices.

    An example is expected to be indexable as ``edges[0]`` / ``edges[1]``
    (the two dimensions of the output matrix) and ``edges[-1]`` (an iterable
    of ``(row, col, label)`` triples).  Labels are replaced by integer indices
    through ``self.vocab`` once :meth:`build_vocab` has been called.

    NOTE: the outer containers must be lists, not tuples -- ``numericalize``
    and ``build_vocab`` treat *every* tuple they meet as an edge triple.
    """

    def __init__(self):
        super().__init__()
        # Populated by build_vocab(); while None, numericalize() is a no-op.
        self.vocab = None

    def process(self, edges, device=None):
        """Numericalize ``edges`` and return them as a dense ``torch.long`` matrix.

        Args:
            edges: one example in the layout described on the class.
            device: torch device for the returned tensor (``None`` = torch default).

        Returns:
            The tensor produced by :meth:`pad` for the numericalized edges.
        """
        return self.pad(self.numericalize(edges), device)

    def pad(self, edges, device):
        """Scatter edge labels into a zero-filled ``(edges[0], edges[1])`` matrix.

        Args:
            edges: sizes at ``edges[0]`` / ``edges[1]``; ``edges[-1]`` iterates
                over indexable ``(row, col, label)`` entries whose labels are
                already integers.
            device: torch device for the returned tensor.

        Returns:
            A ``torch.long`` tensor where cell ``[row, col]`` holds ``label``
            and cells without an edge stay 0.  If the same coordinate appears
            more than once, the last entry wins.
        """
        tensor = torch.zeros(edges[0], edges[1], dtype=torch.long, device=device)
        # A plain loop (instead of a single advanced-indexing assignment) keeps
        # deterministic last-write-wins behaviour for duplicate coordinates.
        for edge in edges[-1]:
            tensor[edge[0], edge[1]] = edge[2]
        return tensor

    def numericalize(self, arr):
        """Replace edge labels with vocabulary indices.

        Recurses through nested lists.  Every tuple encountered is treated as
        a ``(row, col, label)`` triple and rebuilt with its label looked up in
        ``self.vocab.stoi``; ``None`` labels become 0.  Any other value (e.g.
        the matrix sizes) is returned unchanged.

        Args:
            arr: an example in the layout described on the class.

        Returns:
            The mapped structure, or ``arr`` itself if no vocabulary has been
            built yet.
        """
        if self.vocab is None:
            return arr

        stoi = self.vocab.stoi

        def to_index(label):
            # NOTE(review): pad() also uses 0 for "no edge".  If Vocab assigns
            # index 0 to a real label (specials=[] reserves no slot -- confirm
            # in Vocab), then None, absent edges and that label collide.
            return 0 if label is None else stoi[label]

        def multi_map(node):
            if isinstance(node, tuple):
                return (node[0], node[1], to_index(node[2]))
            if isinstance(node, list):
                return [multi_map(child) for child in node]
            return node

        return multi_map(arr)

    def build_vocab(self, *args):
        """Build ``self.vocab`` from every edge label found in ``args``.

        Each argument is either a ``torch.utils.data.Dataset`` -- in which case
        ``get_examples(name)`` is read for every entry of ``arg.fields`` whose
        field *is* this object -- or raw data used as-is.  Labels are taken
        from the third slot of every tuple reached by recursing through lists
        and generators.  ``None`` labels are skipped.  The vocabulary is built
        with ``specials=[]``.
        """
        def iter_labels(node):
            if isinstance(node, tuple):
                yield node[2]
            elif isinstance(node, (list, types.GeneratorType)):
                for child in node:
                    yield from iter_labels(child)
            # Anything else (ints, strings, ...) contributes no labels.

        sources = []
        for arg in args:
            if isinstance(arg, torch.utils.data.Dataset):
                sources.extend(
                    arg.get_examples(name)
                    for name, field in arg.fields.items()
                    if field is self
                )
            else:
                sources.append(arg)

        counter = Counter(
            label for label in iter_labels(sources) if label is not None
        )
        self.vocab = Vocab(counter, specials=[])