#!/usr/bin/env python3
# coding=utf-8

import pickle

import torch

from data.parser.from_mrp.node_centric_parser import NodeCentricParser
from data.parser.from_mrp.labeled_edge_parser import LabeledEdgeParser
from data.parser.from_mrp.sequential_parser import SequentialParser
from data.parser.from_mrp.evaluation_parser import EvaluationParser
from data.parser.from_mrp.request_parser import RequestParser
from data.field.edge_field import EdgeField
from data.field.edge_label_field import EdgeLabelField
from data.field.field import Field
from data.field.mini_torchtext.field import Field as TorchTextField
from data.field.label_field import LabelField
from data.field.anchored_label_field import AnchoredLabelField
from data.field.nested_field import NestedField
from data.field.basic_field import BasicField
from data.field.bert_field import BertField
from data.field.anchor_field import AnchorField
from data.batch import Batch


def char_tokenize(word):
    return [c for i, c in enumerate(word)]  # if i < 10 or len(word) - i <= 10]


class Collate:
    def __call__(self, batch):
        # Sort by input length (longest first) so padded batches stay tight.
        batch.sort(key=lambda example: example["every_input"][0].size(0), reverse=True)
        return Batch.build(batch)


class Dataset:
    def __init__(self, args, verbose=True):
        self.verbose = verbose
        self.sos, self.eos, self.pad, self.unk = "<sos>", "<eos>", "<pad>", "<unk>"

        self.bert_input_field = BertField()
        self.scatter_field = BasicField()
        self.every_word_input_field = Field(lower=True, init_token=self.sos, eos_token=self.eos, batch_first=True, include_lengths=True)

        char_form_nesting = TorchTextField(tokenize=char_tokenize, init_token=self.sos, eos_token=self.eos, batch_first=True)
        self.char_form_field = NestedField(char_form_nesting, include_lengths=True)

        self.label_field = LabelField(preprocessing=lambda nodes: [n["label"] for n in nodes])
        self.anchored_label_field = AnchoredLabelField()

        self.id_field = Field(batch_first=True, tokenize=lambda x: [x])
        self.edge_presence_field = EdgeField()
        self.edge_label_field = EdgeLabelField()
        self.anchor_field = AnchorField()
        self.source_anchor_field = AnchorField()
        self.target_anchor_field = AnchorField()
        self.token_interval_field = BasicField()

        self.load_dataset(args)

    def log(self, text):
        if not self.verbose:
            return
        print(text, flush=True)

    def load_state_dict(self, args, d):
        for key, value in d["vocabs"].items():
            getattr(self, key).vocab = pickle.loads(value)

    def state_dict(self):
        # Serialize every field vocabulary so a checkpoint can restore them later.
        return {
            "vocabs": {key: pickle.dumps(value.vocab) for key, value in self.__dict__.items() if hasattr(value, "vocab")}
        }

    def load_sentences(self, sentences, args):
        dataset = RequestParser(
            sentences,
            args,
            fields={
                "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
                "bert input": ("input", self.bert_input_field),
                "to scatter": ("input_scatter", self.scatter_field),
                "token anchors": ("token_intervals", self.token_interval_field),
                "id": ("id", self.id_field),
            },
        )
        self.every_word_input_field.build_vocab(dataset, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
        self.id_field.build_vocab(dataset, min_freq=1, specials=[])

        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=Collate())
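
    # Usage sketch for `load_sentences` (assumption: `sentences` follows the
    # element format RequestParser consumes, `args` is the same config object
    # used elsewhere in this class, and `model` is an illustrative name):
    #
    #     loader = dataset.load_sentences(sentences, args)
    #     for batch in loader:
    #         output = model(batch)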

    def load_dataset(self, args):
        parser = {
            "sequential": SequentialParser,
            "node-centric": NodeCentricParser,
            "labeled-edge": LabeledEdgeParser,
        }[args.graph_mode]

        train = parser(
            args,
            "training",
            fields={
                "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
                "bert input": ("input", self.bert_input_field),
                "to scatter": ("input_scatter", self.scatter_field),
                "nodes": ("labels", self.label_field),
                "anchored labels": ("anchored_labels", self.anchored_label_field),
                "edge presence": ("edge_presence", self.edge_presence_field),
                "edge labels": ("edge_labels", self.edge_label_field),
                "anchor edges": ("anchor", self.anchor_field),
                "source anchor edges": ("source_anchor", self.source_anchor_field),
                "target anchor edges": ("target_anchor", self.target_anchor_field),
                "token anchors": ("token_intervals", self.token_interval_field),
                "id": ("id", self.id_field),
            },
            filter_pred=lambda example: len(example.input) <= 256,
        )

        val = parser(
            args,
            "validation",
            fields={
                "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
                "bert input": ("input", self.bert_input_field),
                "to scatter": ("input_scatter", self.scatter_field),
                "nodes": ("labels", self.label_field),
                "anchored labels": ("anchored_labels", self.anchored_label_field),
                "edge presence": ("edge_presence", self.edge_presence_field),
                "edge labels": ("edge_labels", self.edge_label_field),
                "anchor edges": ("anchor", self.anchor_field),
                "source anchor edges": ("source_anchor", self.source_anchor_field),
                "target anchor edges": ("target_anchor", self.target_anchor_field),
                "token anchors": ("token_intervals", self.token_interval_field),
                "id": ("id", self.id_field),
            },
        )

        test = EvaluationParser(
            args,
            fields={
                "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
                "bert input": ("input", self.bert_input_field),
                "to scatter": ("input_scatter", self.scatter_field),
                "token anchors": ("token_intervals", self.token_interval_field),
                "id": ("id", self.id_field),
            },
        )

        del train.data, val.data, test.data  # TODO: why?
        for f in list(train.fields.values()) + list(val.fields.values()) + list(test.fields.values()):  # TODO: why?
            if hasattr(f, "preprocessing"):
                del f.preprocessing

        self.train_size = len(train)
        self.val_size = len(val)
        self.test_size = len(test)

        self.log(f"\n{self.train_size} sentences in the train split")
        self.log(f"{self.val_size} sentences in the validation split")
        self.log(f"{self.test_size} sentences in the test split")

        self.node_count = train.node_counter
        self.token_count = train.input_count
        self.edge_count = train.edge_counter
        self.no_edge_count = train.no_edge_counter
        self.anchor_freq = train.anchor_freq
        self.source_anchor_freq = train.source_anchor_freq if hasattr(train, "source_anchor_freq") else 0.5
        self.target_anchor_freq = train.target_anchor_freq if hasattr(train, "target_anchor_freq") else 0.5

        self.log(f"{self.node_count} nodes in the train split")

        self.every_word_input_field.build_vocab(val, test, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
        self.char_form_field.build_vocab(train, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
        self.char_form_field.nesting_field.vocab = self.char_form_field.vocab
        self.id_field.build_vocab(train, val, test, min_freq=1, specials=[])
        self.label_field.build_vocab(train)
        self.anchored_label_field.vocab = self.label_field.vocab
        self.edge_label_field.build_vocab(train)
        print(list(self.edge_label_field.vocab.freqs.keys()), flush=True)
        self.char_form_vocab_size = len(self.char_form_field.vocab)

        self.create_label_freqs(args)
        self.create_edge_freqs(args)

        self.log(f"Edge frequency: {self.edge_presence_freq * 100:.2f} %")
        self.log(f"{len(self.label_field.vocab)} words in the label vocabulary")
        self.log(f"{len(self.anchored_label_field.vocab)} words in the anchored label vocabulary")
        self.log(f"{len(self.edge_label_field.vocab)} words in the edge label vocabulary")
        self.log(f"{len(self.char_form_field.vocab)} characters in the vocabulary")
        self.log(self.label_field.vocab.freqs)
        self.log(self.anchored_label_field.vocab.freqs)

        self.train = torch.utils.data.DataLoader(
            train,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            collate_fn=Collate(),
            pin_memory=True,
            drop_last=True,
        )
        self.train_size = len(self.train.dataset)

        self.val = torch.utils.data.DataLoader(
            val,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            collate_fn=Collate(),
            pin_memory=True,
        )
        self.val_size = len(self.val.dataset)

        self.test = torch.utils.data.DataLoader(
            test,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            collate_fn=Collate(),
            pin_memory=True,
        )
        self.test_size = len(self.test.dataset)

        if self.verbose:
            batch = next(iter(self.train))
            print(f"\nBatch content: {Batch.to_str(batch)}\n")
            print(flush=True)
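
    # Construction sketch (assumption: `args` provides at least graph_mode,
    # query_length, batch_size and workers; the full option set comes from the
    # project's config):
    #
    #     dataset = Dataset(args)
    #     for batch in dataset.train:  # length-sorted batches built by Collate
    #         ...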

    def create_label_freqs(self, args):
        n_rules = len(self.label_field.vocab)
        # Every input token yields args.query_length label slots; slots that are
        # not matched to a node get the blank class at index 0.
        blank_count = args.query_length * self.token_count - self.node_count
        label_counts = [blank_count] + [
            self.label_field.vocab.freqs[self.label_field.vocab.itos[i]] for i in range(n_rules)
        ]
        label_counts = torch.FloatTensor(label_counts)
        self.label_freqs = label_counts / (self.node_count + blank_count)
        self.log(f"Label frequency: {self.label_freqs}")

    def create_edge_freqs(self, args):
        edge_counter = [
            self.edge_label_field.vocab.freqs[self.edge_label_field.vocab.itos[i]]
            for i in range(len(self.edge_label_field.vocab))
        ]
        edge_counter = torch.FloatTensor(edge_counter)
        self.edge_label_freqs = edge_counter / self.edge_count
        self.edge_presence_freq = self.edge_count / (self.edge_count + self.no_edge_count)
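
# Worked example of the statistics above (illustrative numbers, not from any
# real dataset): with query_length=4, token_count=100 and node_count=150,
# blank_count = 4 * 100 - 150 = 250, so label_freqs[0], the frequency of the
# blank (no-node) class, is 250 / (150 + 250) = 0.625. Analogously,
# edge_presence_freq = edge_count / (edge_count + no_edge_count).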