Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
# coding=utf-8
import torch | |
from data.field.mini_torchtext.field import NestedField as TorchTextNestedField | |
class NestedField(TorchTextNestedField):
    """Nested field whose ``pad``/``numericalize`` delegate to ``nesting_field``.

    Both methods hand their input straight to ``self.nesting_field`` and, when
    ``self.include_lengths`` is truthy, additionally carry a sentence length
    and the word lengths reported by the nesting field.
    """

    def pad(self, example):
        """Pad ``example`` through the nesting field.

        Args:
            example: Object passed unchanged to ``self.nesting_field.pad``;
                its ``len()`` is reported as the sentence length.

        Returns:
            The nesting field's padded result when ``self.include_lengths``
            is falsy; otherwise the tuple
            ``(padded, sentence_length, word_lengths)``.
        """
        # Mirror our own setting so the nesting field returns lengths only
        # when this field was asked for them.
        self.nesting_field.include_lengths = self.include_lengths
        if not self.include_lengths:
            return self.nesting_field.pad(example)

        # Measure before padding: ``example`` is rebound to the padded data.
        sentence_length = len(example)
        example, word_lengths = self.nesting_field.pad(example)
        return example, sentence_length, word_lengths

    def numericalize(self, arr, device=None):
        """Numericalize the output of :meth:`pad` through the nesting field.

        Args:
            arr: The padded data, or the ``(padded, sentence_length,
                word_lengths)`` tuple when ``self.include_lengths`` is truthy.
            device: Forwarded to ``self.nesting_field.numericalize`` and to
                ``torch.tensor`` for the length tensors.

        Returns:
            The nesting field's numericalized result, or a tuple of that
            result plus ``sentence_length`` and ``word_lengths`` converted to
            tensors of ``self.dtype`` when ``self.include_lengths`` is truthy.
        """
        # The nesting field is handed the bare padded data (lengths are
        # unpacked here), so its own length handling is switched off first.
        self.nesting_field.include_lengths = False
        if self.include_lengths:
            arr, sentence_length, word_lengths = arr

        numericalized = self.nesting_field.numericalize(arr, device=device)

        # NOTE: the flag is reset to True unconditionally, not restored to
        # its previous value; ``pad`` re-synchronises it on its next call.
        self.nesting_field.include_lengths = True
        if not self.include_lengths:
            return numericalized

        sentence_length = torch.tensor(sentence_length, dtype=self.dtype, device=device)
        word_lengths = torch.tensor(word_lengths, dtype=self.dtype, device=device)
        return (numericalized, sentence_length, word_lengths)

    def build_vocab(self, *args, **kwargs):
        """Build one vocabulary shared by this field and its nesting field.

        Args:
            *args: Data sources. A ``torch.utils.data.Dataset`` contributes
                ``get_examples(name)`` for every entry of its ``fields``
                mapping that is this very field object; any other argument is
                used as a source directly.
            **kwargs: Forwarded to ``self.nesting_field.build_vocab``.
        """
        sources = []
        for arg in args:
            if isinstance(arg, torch.utils.data.Dataset):
                sources.extend(
                    [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
                )
            else:
                sources.append(arg)

        # Flatten one level: every element of every source becomes its own
        # positional argument for the nesting field.
        flattened = [item for source in sources for item in source]

        # just build vocab and does not load vector
        self.nesting_field.build_vocab(*flattened, **kwargs)
        # Start the MRO lookup *after* TorchTextNestedField so that class's
        # own build_vocab is bypassed; called without data sources.
        super(TorchTextNestedField, self).build_vocab()
        self.vocab.extend(self.nesting_field.vocab)
        self.vocab.freqs = self.nesting_field.vocab.freqs.copy()
        # From here on both fields reference the same vocab object.
        self.nesting_field.vocab = self.vocab