larkkin committed on
Commit
8044721
1 Parent(s): 21ea231

Add application code and models, update README

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +39 -7
  2. app.py +80 -0
  3. config/__init__.py +0 -0
  4. config/params.py +89 -0
  5. data/__init__.py +0 -0
  6. data/batch.py +95 -0
  7. data/dataset.py +245 -0
  8. data/field/__init__.py +0 -0
  9. data/field/anchor_field.py +19 -0
  10. data/field/anchored_label_field.py +38 -0
  11. data/field/basic_field.py +11 -0
  12. data/field/bert_field.py +18 -0
  13. data/field/edge_field.py +63 -0
  14. data/field/edge_label_field.py +67 -0
  15. data/field/field.py +70 -0
  16. data/field/label_field.py +36 -0
  17. data/field/mini_torchtext/example.py +100 -0
  18. data/field/mini_torchtext/field.py +637 -0
  19. data/field/mini_torchtext/pipeline.py +86 -0
  20. data/field/mini_torchtext/utils.py +256 -0
  21. data/field/mini_torchtext/vocab.py +116 -0
  22. data/field/nested_field.py +50 -0
  23. data/parser/__init__.py +0 -0
  24. data/parser/from_mrp/__init__.py +0 -0
  25. data/parser/from_mrp/abstract_parser.py +50 -0
  26. data/parser/from_mrp/evaluation_parser.py +18 -0
  27. data/parser/from_mrp/labeled_edge_parser.py +70 -0
  28. data/parser/from_mrp/node_centric_parser.py +69 -0
  29. data/parser/from_mrp/request_parser.py +23 -0
  30. data/parser/from_mrp/sequential_parser.py +90 -0
  31. data/parser/json_parser.py +35 -0
  32. data/parser/to_mrp/__init__.py +0 -0
  33. data/parser/to_mrp/abstract_parser.py +80 -0
  34. data/parser/to_mrp/labeled_edge_parser.py +52 -0
  35. data/parser/to_mrp/node_centric_parser.py +35 -0
  36. data/parser/to_mrp/sequential_parser.py +35 -0
  37. model/__init__.py +0 -0
  38. model/head/__init__.py +0 -0
  39. model/head/abstract_head.py +274 -0
  40. model/head/labeled_edge_head.py +67 -0
  41. model/head/node_centric_head.py +25 -0
  42. model/head/sequential_head.py +24 -0
  43. model/model.py +82 -0
  44. model/module/__init__.py +0 -0
  45. model/module/anchor_classifier.py +32 -0
  46. model/module/biaffine.py +20 -0
  47. model/module/bilinear.py +43 -0
  48. model/module/char_embedding.py +42 -0
  49. model/module/edge_classifier.py +56 -0
  50. model/module/encoder.py +95 -0
README.md CHANGED
@@ -1,13 +1,45 @@
1
  ---
2
- title: Ssa Perin
3
- emoji:
4
- colorFrom: red
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.8.0
8
  app_file: app.py
9
  pinned: false
10
- license: cc-by-4.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: Sentiment Analysis
3
+ emoji: 🤔
4
+ colorFrom: purple
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.1.7
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
+ This space provides a Gradio demo and an easy-to-run wrapper of a model for structured sentiment analysis in Norwegian, pre-trained on the [NoReC dataset](https://huggingface.co/datasets/norec).
13
+ This space contains an implementation of the method described in "Direct parsing to sentiment graphs" (Samuel _et al._, ACL 2022). The main repository, which also contains the scripts for training the model, can be found on the project [GitHub](https://github.com/jerbarnes/direct_parsing_to_sent_graph).
14
+
15
+ The paper proposes three different ways to encode the sentiment graph: "node-centric", "labeled-edge", and "opinion-tuple". The current model uses the "labeled-edge" encoding and achieves the following results on the held-out set of the NoReC dataset:
16
+
17
+ | Unlabeled sentiment tuple F1 | Target F1 | Relative polarity precision |
18
+ |:----------------------------:|:----------:|:---------------------------:|
19
+ | 0.434 | 0.541 | 0.926 |
20
+
21
+
22
+ In "Word Substitution with Masked Language Models as Data Augmentation for Sentiment Analysis", we analyzed data augmentation strategies for improving performance of the model. Using masked-language modeling (MLM), we augmented the sentences with MLM-substituted words inside, outside, or inside+outside the actual sentiment tuples. The results below show that augmentation may be improve the model performance. This space, however, runs the original model trained without augmentation.
23
+
24
+ | | Augmentation rate | Unlabeled sentiment tuple F1 | Target F1 | Relative polarity precision |
25
+ |----------------|-------------------|------------------------------|-----------|-----------------------------|
26
+ | Baseline | 0% | 43.39 | 54.13 | 92.59 |
27
+ | Outside | 59% | **45.08** | 56.18 | 92.95 |
28
+ | Inside | 9% | 43.38 | 55.62 | 92.49 |
29
+ | Inside+Outside | 27% | 44.12 | **56.44** | **93.19** |
30
+
31
+
32
+
33
+ The model can be easily used for predicting sentiment tuples as follows:
34
+
35
+ ```python
36
+ >>> import model_wrapper
37
+ >>> model = model_wrapper.PredictionModel()
38
+ >>> model.predict(['vi liker svart kaffe'])
39
+ [{'sent_id': '0',
40
+ 'text': 'vi liker svart kaffe',
41
+ 'opinions': [{'Source': [['vi'], ['0:2']],
42
+ 'Target': [['svart', 'kaffe'], ['9:14', '15:20']],
43
+ 'Polar_expression': [['liker'], ['3:8']],
44
+ 'Polarity': 'Positive'}]}]
45
+ ```
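The spans in the prediction above are character offsets into the input sentence, given as `"start:end"` strings. Below is a minimal sketch of decoding them back into text; `spans_to_text` is a hypothetical helper written for illustration, not part of this repository:

```python
def spans_to_text(text, spans):
    # each span is a "start:end" character offset into the original sentence
    return [text[int(start):int(end)] for start, end in (s.split(":") for s in spans)]


prediction = {
    "text": "vi liker svart kaffe",
    "opinions": [{"Source": [["vi"], ["0:2"]],
                  "Target": [["svart", "kaffe"], ["9:14", "15:20"]],
                  "Polar_expression": [["liker"], ["3:8"]],
                  "Polarity": "Positive"}],
}

opinion = prediction["opinions"][0]
print(spans_to_text(prediction["text"], opinion["Target"][1]))  # ['svart', 'kaffe']
```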
app.py ADDED
@@ -0,0 +1,80 @@
1
+ import gradio as gr
2
+ import model_wrapper
3
+
4
+
5
+ model = model_wrapper.PredictionModel()
6
+
7
+
8
+ def pretty_print_opinion(opinion_dict):
9
+ res = []
10
+ maxlen = max([len(key) for key in opinion_dict.keys()]) + 2
11
+ maxlen = 0
12
+ for key, value in opinion_dict.items():
13
+ if key == 'Polarity':
14
+ res.append(f'{(key + ":").ljust(maxlen)} {value}')
15
+ else:
16
+ res.append(f'{(key + ":").ljust(maxlen)} \'{" ".join(value[0])}\'')
17
+ return '\n'.join(res) + '\n'
18
+
19
+
20
+ def predict(text):
21
+ print(f'Input message "{text}"')
22
+ try:
23
+ predictions = model.predict([text])
24
+ prediction = predictions[0]
25
+ results = []
26
+ if not prediction['opinions']:
27
+ return 'No opinions detected'
28
+ for opinion in prediction['opinions']:
29
+ results.append(pretty_print_opinion(opinion))
30
+ print(f'Successfully predicted SA for input message "{text}": {results}')
31
+ return '\n'.join(results)
32
+ except Exception as e:
33
+ print(f'Error for input message "{text}": {e}')
34
+ raise e
35
+
36
+
37
+
38
+ markdown_text = '''
39
+ <br>
40
+ <br>
41
+ This space provides a Gradio demo and an easy-to-run wrapper of a model for structured sentiment analysis in Norwegian, pre-trained on the [NoReC dataset](https://huggingface.co/datasets/norec).
42
+ This model is an implementation of the method from the paper "Direct parsing to sentiment graphs" (Samuel _et al._, ACL 2022). The main repository, which also contains the scripts for training the model, can be found on the project [GitHub](https://github.com/jerbarnes/direct_parsing_to_sent_graph).
43
+
44
+ The current model uses the 'labeled-edge' graph encoding, and achieves the following results on the NoReC dataset:
45
+
46
+ | Unlabeled sentiment tuple F1 | Target F1 | Relative polarity precision |
47
+ |:----------------------------:|:----------:|:---------------------------:|
48
+ | 0.393 | 0.468 | 0.939 |
49
+
50
+
51
+ The model can be easily used for predicting sentiment tuples as follows:
52
+
53
+ ```python
54
+ >>> import model_wrapper
55
+ >>> model = model_wrapper.PredictionModel()
56
+ >>> model.predict(['vi liker svart kaffe'])
57
+ [{'sent_id': '0',
58
+ 'text': 'vi liker svart kaffe',
59
+ 'opinions': [{'Source': [['vi'], ['0:2']],
60
+ 'Target': [['svart', 'kaffe'], ['9:14', '15:20']],
61
+ 'Polar_expression': [['liker'], ['3:8']],
62
+ 'Polarity': 'Positive'}]}]
63
+ ```
64
+ '''
65
+
66
+
67
+
68
+ with gr.Blocks() as demo:
69
+ with gr.Row() as row:
70
+ text_input = gr.Textbox(label="input")
71
+ text_output = gr.Textbox(label="output")
72
+ with gr.Row() as row:
73
+ text_button = gr.Button("submit")
74
+
75
+ text_button.click(fn=predict, inputs=text_input, outputs=text_output)
76
+
77
+ gr.Markdown(markdown_text)
78
+
79
+
80
+ demo.launch()
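For illustration, the `pretty_print_opinion` helper above renders one opinion dictionary as a small text block. A minimal sketch of what it produces for the example prediction from the README (with the current code, `maxlen` is reset to 0, so no column alignment is applied):

```python
# assumes pretty_print_opinion from app.py is already in scope (e.g. pasted into a REPL);
# importing app directly would also launch the Gradio demo
opinion = {"Source": [["vi"], ["0:2"]],
           "Target": [["svart", "kaffe"], ["9:14", "15:20"]],
           "Polar_expression": [["liker"], ["3:8"]],
           "Polarity": "Positive"}

print(pretty_print_opinion(opinion))
# Source: 'vi'
# Target: 'svart kaffe'
# Polar_expression: 'liker'
# Polarity: Positive
```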
config/__init__.py ADDED
File without changes
config/params.py ADDED
@@ -0,0 +1,89 @@
1
+ import yaml
2
+
3
+
4
+ class Params:
5
+ def __init__(self):
6
+ self.graph_mode = "sequential" # possibilities: {sequential, node-centric, labeled-edge}
7
+ self.accumulation_steps = 1 # number of gradient accumulation steps for achieving a bigger batch_size
8
+ self.activation = "relu" # transformer (decoder) activation function, supported values: {'relu', 'gelu', 'sigmoid', 'mish'}
9
+ self.predict_intensity = False
10
+ self.batch_size = 32 # batch size (further divided into multiple GPUs)
11
+ self.beta_2 = 0.98 # beta 2 parameter for Adam(W) optimizer
12
+ self.blank_weight = 1.0 # weight of cross-entropy loss for predicting an empty label
13
+ self.char_embedding = True # use character embedding in addition to bert
14
+ self.char_embedding_size = 128 # dimension of the character embedding layer in the character embedding module
15
+ self.decoder_delay_steps = 0 # number of initial steps with frozen decoder
16
+ self.decoder_learning_rate = 6e-4 # initial decoder learning rate
17
+ self.decoder_weight_decay = 1.2e-6 # amount of weight decay
18
+ self.dropout_anchor = 0.5 # dropout at the last layer of anchor classifier
19
+ self.dropout_edge_label = 0.5 # dropout at the last layer of edge label classifier
20
+ self.dropout_edge_presence = 0.5 # dropout at the last layer of edge presence classifier
21
+ self.dropout_label = 0.5 # dropout at the last layer of label classifier
22
+ self.dropout_transformer = 0.5 # dropout for the transformer layers (decoder)
23
+ self.dropout_transformer_attention = 0.1 # dropout for the transformer's attention (decoder)
24
+ self.dropout_word = 0.1 # probability of dropping out a whole word from the encoder (in favour of char embedding)
25
+ self.encoder = "xlm-roberta-base" # pretrained encoder model
26
+ self.encoder_delay_steps = 2000 # number of initial steps with frozen XLM-R
27
+ self.encoder_freeze_embedding = True # freeze the first embedding layer in XLM-R
28
+ self.encoder_learning_rate = 6e-5 # initial encoder learning rate
29
+ self.encoder_weight_decay = 1e-2 # amount of weight decay
30
+ self.lr_decay_multiplier = 100
31
+ self.epochs = 100 # number of training epochs
32
+ self.focal = True # use focal loss for the label prediction
33
+ self.freeze_bert = False # freeze the pretrained encoder instead of fine-tuning it
34
+ self.group_ops = False # group 'opN' edge labels into one
35
+ self.hidden_size_ff = 4 * 768 # hidden size of the transformer feed-forward submodule
36
+ self.hidden_size_anchor = 128 # hidden size of the anchor biaffine layer
37
+ self.hidden_size_edge_label = 256 # hidden size for edge label biaffine layer
38
+ self.hidden_size_edge_presence = 512 # hidden size for the edge presence biaffine layer
39
+ self.layerwise_lr_decay = 1.0 # layerwise decay of learning rate in the encoder
40
+ self.n_attention_heads = 8 # number of attention heads in the decoding transformer
41
+ self.n_layers = 3 # number of layers in the decoder
42
+ self.query_length = 4 # number of queries generated for each input word
43
+ self.pre_norm = True # use pre-normalized version of the transformer (as in Transformers without Tears)
44
+ self.warmup_steps = 6000 # number of the warm-up steps for the inverse_sqrt scheduler
45
+
46
+ def init_data_paths(self):
47
+ directory_1 = {
48
+ "sequential": "node_centric_mrp",
49
+ "node-centric": "node_centric_mrp",
50
+ "labeled-edge": "labeled_edge_mrp"
51
+ }[self.graph_mode]
52
+ directory_2 = {
53
+ ("darmstadt", "en"): "darmstadt_unis",
54
+ ("mpqa", "en"): "mpqa",
55
+ ("multibooked", "ca"): "multibooked_ca",
56
+ ("multibooked", "eu"): "multibooked_eu",
57
+ ("norec", "no"): "norec",
58
+ ("opener", "en"): "opener_en",
59
+ ("opener", "es"): "opener_es",
60
+ }[(self.framework, self.language)]
61
+
62
+ self.training_data = f"{self.data_directory}/{directory_1}/{directory_2}/train.mrp"
63
+ self.validation_data = f"{self.data_directory}/{directory_1}/{directory_2}/dev.mrp"
64
+ self.test_data = f"{self.data_directory}/{directory_1}/{directory_2}/test.mrp"
65
+
66
+ self.raw_training_data = f"{self.data_directory}/raw/{directory_2}/train.json"
67
+ self.raw_validation_data = f"{self.data_directory}/raw/{directory_2}/dev.json"
68
+
69
+ return self
70
+
71
+ def load_state_dict(self, d):
72
+ for k, v in d.items():
73
+ setattr(self, k, v)
74
+ return self
75
+
76
+ def state_dict(self):
77
+ members = [attr for attr in dir(self) if not callable(getattr(self, attr)) and not attr.startswith("__")]
78
+ return {k: self.__dict__[k] for k in members}
79
+
80
+ def load(self, args):
81
+ with open(args.config, "r", encoding="utf-8") as f:
82
+ params = yaml.safe_load(f)
83
+ self.load_state_dict(params)
84
+ self.init_data_paths()
85
+
86
+ def save(self, json_path):
87
+ with open(json_path, "w", encoding="utf-8") as f:
88
+ d = self.state_dict()
89
+ yaml.dump(d, f)
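A sketch of how `Params` is typically populated. Note that `framework`, `language`, and `data_directory` are not set in `__init__`, so `init_data_paths()` relies on the loaded config supplying them; the values below are hypothetical and only illustrate the path construction:

```python
from config.params import Params

params = Params()
params.load_state_dict({
    "graph_mode": "labeled-edge",   # selects the labeled_edge_mrp directory
    "framework": "norec",           # hypothetical framework/language pair
    "language": "no",
    "data_directory": "data",
})
params.init_data_paths()

print(params.training_data)  # data/labeled_edge_mrp/norec/train.mrp
```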
data/__init__.py ADDED
File without changes
data/batch.py ADDED
@@ -0,0 +1,95 @@
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class Batch:
9
+ @staticmethod
10
+ def build(data):
11
+ fields = list(data[0].keys())
12
+ transposed = {}
13
+ for field in fields:
14
+ if isinstance(data[0][field], tuple):
15
+ transposed[field] = tuple(Batch._stack(field, [example[field][i] for example in data]) for i in range(len(data[0][field])))
16
+ else:
17
+ transposed[field] = Batch._stack(field, [example[field] for example in data])
18
+
19
+ return transposed
20
+
21
+ @staticmethod
22
+ def _stack(field: str, examples):
23
+ if field == "anchored_labels":
24
+ return examples
25
+
26
+ dim = examples[0].dim()
27
+
28
+ if dim == 0:
29
+ return torch.stack(examples)
30
+
31
+ lengths = [max(example.size(i) for example in examples) for i in range(dim)]
32
+ if any(length == 0 for length in lengths):
33
+ return torch.LongTensor(len(examples), *lengths)
34
+
35
+ examples = [F.pad(example, Batch._pad_size(example, lengths)) for example in examples]
36
+ return torch.stack(examples)
37
+
38
+ @staticmethod
39
+ def _pad_size(example, total_size):
40
+ return [p for i, l in enumerate(total_size[::-1]) for p in (0, l - example.size(-1 - i))]
41
+
42
+ @staticmethod
43
+ def index_select(batch, indices):
44
+ filtered_batch = {}
45
+ for key, examples in batch.items():
46
+ if isinstance(examples, list) or isinstance(examples, tuple):
47
+ filtered_batch[key] = [example.index_select(0, indices) for example in examples]
48
+ else:
49
+ filtered_batch[key] = examples.index_select(0, indices)
50
+
51
+ return filtered_batch
52
+
53
+ @staticmethod
54
+ def to_str(batch):
55
+ string = "\n".join([f"\t{name}: {Batch._short_str(item)}" for name, item in batch.items()])
56
+ return string
57
+
58
+ @staticmethod
59
+ def to(batch, device):
60
+ converted = {}
61
+ for field in batch.keys():
62
+ converted[field] = Batch._to(batch[field], device)
63
+ return converted
64
+
65
+ @staticmethod
66
+ def _short_str(tensor):
67
+ # unwrap variable to tensor
68
+ if not torch.is_tensor(tensor):
69
+ # (1) unpack variable
70
+ if hasattr(tensor, "data"):
71
+ tensor = getattr(tensor, "data")
72
+ # (2) handle include_lengths
73
+ elif isinstance(tensor, tuple) or isinstance(tensor, list):
74
+ return str(tuple(Batch._short_str(t) for t in tensor))
75
+ # (3) fallback to default str
76
+ else:
77
+ return str(tensor)
78
+
79
+ # copied from torch _tensor_str
80
+ size_str = "x".join(str(size) for size in tensor.size())
81
+ device_str = "" if not tensor.is_cuda else " (GPU {})".format(tensor.get_device())
82
+ strt = "[{} of size {}{}]".format(torch.typename(tensor), size_str, device_str)
83
+ return strt
84
+
85
+ @staticmethod
86
+ def _to(tensor, device):
87
+ if not torch.is_tensor(tensor):
88
+ if isinstance(tensor, tuple):
89
+ return tuple(Batch._to(t, device) for t in tensor)
90
+ elif isinstance(tensor, list):
91
+ return [Batch._to(t, device) for t in tensor]
92
+ else:
93
+ raise Exception(f"unsupported type of {tensor} to be casted to cuda")
94
+
95
+ return tensor.to(device, non_blocking=True)
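A toy illustration of `Batch.build`: tensors that share a field name are zero-padded to a common shape and stacked into one batch tensor (a minimal sketch, assuming the module is importable as `data.batch`, as it is in `data/dataset.py`):

```python
import torch

from data.batch import Batch

examples = [
    {"labels": torch.tensor([1, 2, 3])},
    {"labels": torch.tensor([4, 5])},
]

batch = Batch.build(examples)
print(batch["labels"])
# tensor([[1, 2, 3],
#         [4, 5, 0]])  <- the shorter example is right-padded with zeros
```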
data/dataset.py ADDED
@@ -0,0 +1,245 @@
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import pickle
5
+
6
+ import torch
7
+
8
+ from data.parser.from_mrp.node_centric_parser import NodeCentricParser
9
+ from data.parser.from_mrp.labeled_edge_parser import LabeledEdgeParser
10
+ from data.parser.from_mrp.sequential_parser import SequentialParser
11
+ from data.parser.from_mrp.evaluation_parser import EvaluationParser
12
+ from data.parser.from_mrp.request_parser import RequestParser
13
+ from data.field.edge_field import EdgeField
14
+ from data.field.edge_label_field import EdgeLabelField
15
+ from data.field.field import Field
16
+ from data.field.mini_torchtext.field import Field as TorchTextField
17
+ from data.field.label_field import LabelField
18
+ from data.field.anchored_label_field import AnchoredLabelField
19
+ from data.field.nested_field import NestedField
20
+ from data.field.basic_field import BasicField
21
+ from data.field.bert_field import BertField
22
+ from data.field.anchor_field import AnchorField
23
+ from data.batch import Batch
24
+
25
+
26
+ def char_tokenize(word):
27
+ return [c for i, c in enumerate(word)] # if i < 10 or len(word) - i <= 10]
28
+
29
+
30
+ class Collate:
31
+ def __call__(self, batch):
32
+ batch.sort(key=lambda example: example["every_input"][0].size(0), reverse=True)
33
+ return Batch.build(batch)
34
+
35
+
36
+ class Dataset:
37
+ def __init__(self, args, verbose=True):
38
+ self.verbose = verbose
39
+ self.sos, self.eos, self.pad, self.unk = "<sos>", "<eos>", "<pad>", "<unk>"
40
+
41
+ self.bert_input_field = BertField()
42
+ self.scatter_field = BasicField()
43
+ self.every_word_input_field = Field(lower=True, init_token=self.sos, eos_token=self.eos, batch_first=True, include_lengths=True)
44
+
45
+ char_form_nesting = TorchTextField(tokenize=char_tokenize, init_token=self.sos, eos_token=self.eos, batch_first=True)
46
+ self.char_form_field = NestedField(char_form_nesting, include_lengths=True)
47
+
48
+ self.label_field = LabelField(preprocessing=lambda nodes: [n["label"] for n in nodes])
49
+ self.anchored_label_field = AnchoredLabelField()
50
+
51
+ self.id_field = Field(batch_first=True, tokenize=lambda x: [x])
52
+ self.edge_presence_field = EdgeField()
53
+ self.edge_label_field = EdgeLabelField()
54
+ self.anchor_field = AnchorField()
55
+ self.source_anchor_field = AnchorField()
56
+ self.target_anchor_field = AnchorField()
57
+ self.token_interval_field = BasicField()
58
+
59
+ self.load_dataset(args)
60
+
61
+ def log(self, text):
62
+ if not self.verbose:
63
+ return
64
+ print(text, flush=True)
65
+
66
+ def load_state_dict(self, args, d):
67
+ for key, value in d["vocabs"].items():
68
+ getattr(self, key).vocab = pickle.loads(value)
69
+
70
+ def state_dict(self):
71
+ return {
72
+ "vocabs": {key: pickle.dumps(value.vocab) for key, value in self.__dict__.items() if hasattr(value, "vocab")}
73
+ }
74
+
75
+ def load_sentences(self, sentences, args):
76
+ dataset = RequestParser(
77
+ sentences, args,
78
+ fields={
79
+ "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
80
+ "bert input": ("input", self.bert_input_field),
81
+ "to scatter": ("input_scatter", self.scatter_field),
82
+ "token anchors": ("token_intervals", self.token_interval_field),
83
+ "id": ("id", self.id_field),
84
+ },
85
+ )
86
+
87
+ self.every_word_input_field.build_vocab(dataset, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
88
+ self.id_field.build_vocab(dataset, min_freq=1, specials=[])
89
+
90
+ return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=Collate())
91
+
92
+ def load_dataset(self, args):
93
+ parser = {
94
+ "sequential": SequentialParser,
95
+ "node-centric": NodeCentricParser,
96
+ "labeled-edge": LabeledEdgeParser
97
+ }[args.graph_mode]
98
+
99
+ train = parser(
100
+ args, "training",
101
+ fields={
102
+ "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
103
+ "bert input": ("input", self.bert_input_field),
104
+ "to scatter": ("input_scatter", self.scatter_field),
105
+ "nodes": ("labels", self.label_field),
106
+ "anchored labels": ("anchored_labels", self.anchored_label_field),
107
+ "edge presence": ("edge_presence", self.edge_presence_field),
108
+ "edge labels": ("edge_labels", self.edge_label_field),
109
+ "anchor edges": ("anchor", self.anchor_field),
110
+ "source anchor edges": ("source_anchor", self.source_anchor_field),
111
+ "target anchor edges": ("target_anchor", self.target_anchor_field),
112
+ "token anchors": ("token_intervals", self.token_interval_field),
113
+ "id": ("id", self.id_field),
114
+ },
115
+ filter_pred=lambda example: len(example.input) <= 256,
116
+ )
117
+
118
+ val = parser(
119
+ args, "validation",
120
+ fields={
121
+ "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
122
+ "bert input": ("input", self.bert_input_field),
123
+ "to scatter": ("input_scatter", self.scatter_field),
124
+ "nodes": ("labels", self.label_field),
125
+ "anchored labels": ("anchored_labels", self.anchored_label_field),
126
+ "edge presence": ("edge_presence", self.edge_presence_field),
127
+ "edge labels": ("edge_labels", self.edge_label_field),
128
+ "anchor edges": ("anchor", self.anchor_field),
129
+ "source anchor edges": ("source_anchor", self.source_anchor_field),
130
+ "target anchor edges": ("target_anchor", self.target_anchor_field),
131
+ "token anchors": ("token_intervals", self.token_interval_field),
132
+ "id": ("id", self.id_field),
133
+ },
134
+ )
135
+
136
+ test = EvaluationParser(
137
+ args,
138
+ fields={
139
+ "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
140
+ "bert input": ("input", self.bert_input_field),
141
+ "to scatter": ("input_scatter", self.scatter_field),
142
+ "token anchors": ("token_intervals", self.token_interval_field),
143
+ "id": ("id", self.id_field),
144
+ },
145
+ )
146
+
147
+ del train.data, val.data, test.data # TODO: why?
148
+ for f in list(train.fields.values()) + list(val.fields.values()) + list(test.fields.values()): # TODO: why?
149
+ if hasattr(f, "preprocessing"):
150
+ del f.preprocessing
151
+
152
+ self.train_size = len(train)
153
+ self.val_size = len(val)
154
+ self.test_size = len(test)
155
+
156
+ self.log(f"\n{self.train_size} sentences in the train split")
157
+ self.log(f"{self.val_size} sentences in the validation split")
158
+ self.log(f"{self.test_size} sentences in the test split")
159
+
160
+ self.node_count = train.node_counter
161
+ self.token_count = train.input_count
162
+ self.edge_count = train.edge_counter
163
+ self.no_edge_count = train.no_edge_counter
164
+ self.anchor_freq = train.anchor_freq
165
+
166
+ self.source_anchor_freq = train.source_anchor_freq if hasattr(train, "source_anchor_freq") else 0.5
167
+ self.target_anchor_freq = train.target_anchor_freq if hasattr(train, "target_anchor_freq") else 0.5
168
+ self.log(f"{self.node_count} nodes in the train split")
169
+
170
+ self.every_word_input_field.build_vocab(val, test, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
171
+ self.char_form_field.build_vocab(train, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
172
+ self.char_form_field.nesting_field.vocab = self.char_form_field.vocab
173
+ self.id_field.build_vocab(train, val, test, min_freq=1, specials=[])
174
+ self.label_field.build_vocab(train)
175
+ self.anchored_label_field.vocab = self.label_field.vocab
176
+ self.edge_label_field.build_vocab(train)
177
+ print(list(self.edge_label_field.vocab.freqs.keys()), flush=True)
178
+
179
+ self.char_form_vocab_size = len(self.char_form_field.vocab)
180
+ self.create_label_freqs(args)
181
+ self.create_edge_freqs(args)
182
+
183
+ self.log(f"Edge frequency: {self.edge_presence_freq*100:.2f} %")
184
+ self.log(f"{len(self.label_field.vocab)} words in the label vocabulary")
185
+ self.log(f"{len(self.anchored_label_field.vocab)} words in the anchored label vocabulary")
186
+ self.log(f"{len(self.edge_label_field.vocab)} words in the edge label vocabulary")
187
+ self.log(f"{len(self.char_form_field.vocab)} characters in the vocabulary")
188
+
189
+ self.log(self.label_field.vocab.freqs)
190
+ self.log(self.anchored_label_field.vocab.freqs)
191
+
192
+ self.train = torch.utils.data.DataLoader(
193
+ train,
194
+ batch_size=args.batch_size,
195
+ shuffle=True,
196
+ num_workers=args.workers,
197
+ collate_fn=Collate(),
198
+ pin_memory=True,
199
+ drop_last=True
200
+ )
201
+ self.train_size = len(self.train.dataset)
202
+
203
+ self.val = torch.utils.data.DataLoader(
204
+ val,
205
+ batch_size=args.batch_size,
206
+ shuffle=False,
207
+ num_workers=args.workers,
208
+ collate_fn=Collate(),
209
+ pin_memory=True,
210
+ )
211
+ self.val_size = len(self.val.dataset)
212
+
213
+ self.test = torch.utils.data.DataLoader(
214
+ test,
215
+ batch_size=args.batch_size,
216
+ shuffle=False,
217
+ num_workers=args.workers,
218
+ collate_fn=Collate(),
219
+ pin_memory=True,
220
+ )
221
+ self.test_size = len(self.test.dataset)
222
+
223
+ if self.verbose:
224
+ batch = next(iter(self.train))
225
+ print(f"\nBatch content: {Batch.to_str(batch)}\n")
226
+ print(flush=True)
227
+
228
+ def create_label_freqs(self, args):
229
+ n_rules = len(self.label_field.vocab)
230
+ blank_count = (args.query_length * self.token_count - self.node_count)
231
+ label_counts = [blank_count] + [
232
+ self.label_field.vocab.freqs[self.label_field.vocab.itos[i]]
233
+ for i in range(n_rules)
234
+ ]
235
+ label_counts = torch.FloatTensor(label_counts)
236
+ self.label_freqs = label_counts / (self.node_count + blank_count)
237
+ self.log(f"Label frequency: {self.label_freqs}")
238
+
239
+ def create_edge_freqs(self, args):
240
+ edge_counter = [
241
+ self.edge_label_field.vocab.freqs[self.edge_label_field.vocab.itos[i]] for i in range(len(self.edge_label_field.vocab))
242
+ ]
243
+ edge_counter = torch.FloatTensor(edge_counter)
244
+ self.edge_label_freqs = edge_counter / self.edge_count
245
+ self.edge_presence_freq = self.edge_count / (self.edge_count + self.no_edge_count)
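`Collate` sorts the examples of a batch by the length of their `every_input` tensor (longest first) and then delegates padding and stacking to `Batch.build`. A toy sketch with hand-made tensors; real examples carry many more fields:

```python
import torch

from data.dataset import Collate

collate = Collate()
batch = collate([
    {"every_input": (torch.tensor([7, 8]), torch.tensor(2))},
    {"every_input": (torch.tensor([4, 5, 6]), torch.tensor(3))},
])

print(batch["every_input"][0])  # tensor([[4, 5, 6], [7, 8, 0]]) -- longest example first
print(batch["every_input"][1])  # tensor([3, 2])                 -- corresponding lengths
```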
data/field/__init__.py ADDED
File without changes
data/field/anchor_field.py ADDED
@@ -0,0 +1,19 @@
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ from data.field.mini_torchtext.field import RawField
6
+
7
+
8
+ class AnchorField(RawField):
9
+ def process(self, batch, device=None):
10
+ tensors, masks = self.pad(batch, device)
11
+ return tensors, masks
12
+
13
+ def pad(self, anchors, device):
14
+ tensor = torch.zeros(anchors[0], anchors[1], dtype=torch.long, device=device)
15
+ for anchor in anchors[-1]:
16
+ tensor[anchor[0], anchor[1]] = 1
17
+ mask = tensor.sum(-1) == 0
18
+
19
+ return tensor, mask
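`AnchorField.process` takes a triple of two sizes and a list of index pairs and turns it into a binary matrix plus a mask of rows that received no anchor (here read as nodes × tokens). A minimal sketch:

```python
from data.field.anchor_field import AnchorField

field = AnchorField()
tensor, mask = field.process((2, 3, [(0, 1), (1, 0), (1, 2)]))

print(tensor)  # tensor([[0, 1, 0],
               #         [1, 0, 1]])
print(mask)    # tensor([False, False])  -- True would mark a row with no anchors at all
```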
data/field/anchored_label_field.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch
2
+ from data.field.mini_torchtext.field import RawField
3
+
4
+
5
+ class AnchoredLabelField(RawField):
6
+ def __init__(self):
7
+ super(AnchoredLabelField, self).__init__()
8
+ self.vocab = None
9
+
10
+ def process(self, example, device=None):
11
+ example = self.numericalize(example)
12
+ tensor = self.pad(example, device)
13
+ return tensor
14
+
15
+ def pad(self, example, device):
16
+ n_labels = len(self.vocab)
17
+ n_nodes, n_tokens = len(example[1]), example[0]
18
+
19
+ tensor = torch.full([n_nodes, n_tokens, n_labels + 1], 0, dtype=torch.long, device=device)
20
+ for i_node, node in enumerate(example[1]):
21
+ for anchor, rule in node:
22
+ tensor[i_node, anchor, rule + 1] = 1
23
+
24
+ return tensor
25
+
26
+ def numericalize(self, arr):
27
+ def multi_map(array, function):
28
+ if isinstance(array, tuple):
29
+ return (array[0], function(array[1]))
30
+ elif isinstance(array, list):
31
+ return [multi_map(a, function) for a in array]
32
+ else:
33
+ return array
34
+
35
+ if self.vocab is not None:
36
+ arr = multi_map(arr, lambda x: self.vocab.stoi[x] if x in self.vocab.stoi else 0)
37
+
38
+ return arr
data/field/basic_field.py ADDED
@@ -0,0 +1,11 @@
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ from data.field.mini_torchtext.field import RawField
6
+
7
+
8
+ class BasicField(RawField):
9
+ def process(self, example, device=None):
10
+ tensor = torch.tensor(example, dtype=torch.long, device=device)
11
+ return tensor
data/field/bert_field.py ADDED
@@ -0,0 +1,18 @@
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ from data.field.mini_torchtext.field import RawField
6
+
7
+
8
+ class BertField(RawField):
9
+ def __init__(self):
10
+ super(BertField, self).__init__()
11
+
12
+ def process(self, example, device=None):
13
+ attention_mask = [1] * len(example)
14
+
15
+ example = torch.LongTensor(example, device=device)
16
+ attention_mask = torch.ones_like(example)
17
+
18
+ return example, attention_mask
data/field/edge_field.py ADDED
@@ -0,0 +1,63 @@
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ from data.field.mini_torchtext.field import RawField
6
+ from data.field.mini_torchtext.vocab import Vocab
7
+ from collections import Counter
8
+ import types
9
+
10
+
11
+ class EdgeField(RawField):
12
+ def __init__(self):
13
+ super(EdgeField, self).__init__()
14
+ self.vocab = None
15
+
16
+ def process(self, edges, device=None):
17
+ edges = self.numericalize(edges)
18
+ tensor = self.pad(edges, device)
19
+ return tensor
20
+
21
+ def pad(self, edges, device):
22
+ tensor = torch.zeros(edges[0], edges[1], dtype=torch.long, device=device)
23
+ for edge in edges[-1]:
24
+ tensor[edge[0], edge[1]] = edge[2]
25
+
26
+ return tensor
27
+
28
+ def numericalize(self, arr):
29
+ def multi_map(array, function):
30
+ if isinstance(array, tuple):
31
+ return (array[0], array[1], function(array[2]))
32
+ elif isinstance(array, list):
33
+ return [multi_map(array[i], function) for i in range(len(array))]
34
+ else:
35
+ return array
36
+
37
+ if self.vocab is not None:
38
+ arr = multi_map(arr, lambda x: self.vocab.stoi[x] if x is not None else 0)
39
+ return arr
40
+
41
+ def build_vocab(self, *args):
42
+ def generate(l):
43
+ if isinstance(l, tuple):
44
+ yield l[2]
45
+ elif isinstance(l, list) or isinstance(l, types.GeneratorType):
46
+ for i in l:
47
+ yield from generate(i)
48
+ else:
49
+ return
50
+
51
+ counter = Counter()
52
+ sources = []
53
+ for arg in args:
54
+ if isinstance(arg, torch.utils.data.Dataset):
55
+ sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
56
+ else:
57
+ sources.append(arg)
58
+
59
+ for x in generate(sources):
60
+ if x is not None:
61
+ counter.update([x])
62
+
63
+ self.vocab = Vocab(counter, specials=[])
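`EdgeField` turns sparse `(source, target, label)` triples into a dense matrix; if no vocabulary has been built, the label in each triple is used as an integer id directly. A minimal sketch:

```python
from data.field.edge_field import EdgeField

field = EdgeField()  # vocab is None here, so labels are taken as-is
tensor = field.process((2, 2, [(0, 1, 1), (1, 0, 2)]))

print(tensor)
# tensor([[0, 1],
#         [2, 0]])
```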
data/field/edge_label_field.py ADDED
@@ -0,0 +1,67 @@
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ from data.field.mini_torchtext.field import RawField
6
+ from data.field.mini_torchtext.vocab import Vocab
7
+ from collections import Counter
8
+ import types
9
+
10
+
11
+ class EdgeLabelField(RawField):
12
+ def process(self, edges, device=None):
13
+ edges, masks = self.numericalize(edges)
14
+ edges, masks = self.pad(edges, masks, device)
15
+
16
+ return edges, masks
17
+
18
+ def pad(self, edges, masks, device):
19
+ n_labels = len(self.vocab)
20
+
21
+ tensor = torch.zeros(edges[0], edges[1], n_labels, dtype=torch.long, device=device)
22
+ mask_tensor = torch.zeros(edges[0], edges[1], dtype=torch.bool, device=device)
23
+
24
+ for edge in edges[-1]:
25
+ tensor[edge[0], edge[1], edge[2]] = 1
26
+
27
+ for mask in masks[-1]:
28
+ mask_tensor[mask[0], mask[1]] = mask[2]
29
+
30
+ return tensor, mask_tensor
31
+
32
+ def numericalize(self, arr):
33
+ def multi_map(array, function):
34
+ if isinstance(array, tuple):
35
+ return (array[0], array[1], function(array[2]))
36
+ elif isinstance(array, list):
37
+ return [multi_map(array[i], function) for i in range(len(array))]
38
+ else:
39
+ return array
40
+
41
+ mask = multi_map(arr, lambda x: x is None)
42
+ arr = multi_map(arr, lambda x: self.vocab.stoi[x] if x in self.vocab.stoi else 0)
43
+ return arr, mask
44
+
45
+ def build_vocab(self, *args):
46
+ def generate(l):
47
+ if isinstance(l, tuple):
48
+ yield l[2]
49
+ elif isinstance(l, list) or isinstance(l, types.GeneratorType):
50
+ for i in l:
51
+ yield from generate(i)
52
+ else:
53
+ return
54
+
55
+ counter = Counter()
56
+ sources = []
57
+ for arg in args:
58
+ if isinstance(arg, torch.utils.data.Dataset):
59
+ sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
60
+ else:
61
+ sources.append(arg)
62
+
63
+ for x in generate(sources):
64
+ if x is not None:
65
+ counter.update([x])
66
+
67
+ self.vocab = Vocab(counter, specials=[])
data/field/field.py ADDED
@@ -0,0 +1,70 @@
1
+ import torch
2
+ from data.field.mini_torchtext.field import Field as TorchTextField
3
+ from collections import Counter, OrderedDict
4
+
5
+
6
+ # small change of vocab building to correspond to our version of Dataset
7
+ class Field(TorchTextField):
8
+ def build_vocab(self, *args, **kwargs):
9
+ counter = Counter()
10
+ sources = []
11
+ for arg in args:
12
+ if isinstance(arg, torch.utils.data.Dataset):
13
+ sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
14
+ else:
15
+ sources.append(arg)
16
+ for data in sources:
17
+ for x in data:
18
+ if not self.sequential:
19
+ x = [x]
20
+ counter.update(x)
21
+
22
+ specials = list(
23
+ OrderedDict.fromkeys(
24
+ tok
25
+ for tok in [self.unk_token, self.pad_token, self.init_token, self.eos_token] + kwargs.pop("specials", [])
26
+ if tok is not None
27
+ )
28
+ )
29
+ self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
30
+
31
+ def process(self, example, device=None):
32
+ if self.include_lengths:
33
+ example = example, len(example)
34
+ tensor = self.numericalize(example, device=device)
35
+ return tensor
36
+
37
+ def numericalize(self, ex, device=None):
38
+ if self.include_lengths and not isinstance(ex, tuple):
39
+ raise ValueError("Field has include_lengths set to True, but input data is not a tuple of (data batch, batch lengths).")
40
+
41
+ if isinstance(ex, tuple):
42
+ ex, lengths = ex
43
+ lengths = torch.tensor(lengths, dtype=self.dtype, device=device)
44
+
45
+ if self.use_vocab:
46
+ if self.sequential:
47
+ ex = [self.vocab.stoi[x] for x in ex]
48
+ else:
49
+ ex = self.vocab.stoi[ex]
50
+
51
+ if self.postprocessing is not None:
52
+ ex = self.postprocessing(ex, self.vocab)
53
+ else:
54
+ numericalization_func = self.dtypes[self.dtype]
55
+
56
+ if not self.sequential:
57
+ ex = numericalization_func(ex) if isinstance(ex, str) else ex
58
+ if self.postprocessing is not None:
59
+ ex = self.postprocessing(ex, None)
60
+
61
+ var = torch.tensor(ex, dtype=self.dtype, device=device)
62
+
63
+ if self.sequential and not self.batch_first:
64
+ var.t_()
65
+ if self.sequential:
66
+ var = var.contiguous()
67
+
68
+ if self.include_lengths:
69
+ return var, lengths
70
+ return var
data/field/label_field.py ADDED
@@ -0,0 +1,36 @@
1
+ import torch
2
+ from data.field.mini_torchtext.field import RawField
3
+ from data.field.mini_torchtext.vocab import Vocab
4
+ from collections import Counter
5
+
6
+
7
+ class LabelField(RawField):
8
+ def __init__(self, preprocessing):
9
+ super(LabelField, self).__init__(preprocessing=preprocessing)
10
+ self.vocab = None
11
+
12
+ def build_vocab(self, *args, **kwargs):
13
+ sources = []
14
+ for arg in args:
15
+ if isinstance(arg, torch.utils.data.Dataset):
16
+ sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
17
+ else:
18
+ sources.append(arg)
19
+
20
+ counter = Counter()
21
+ for data in sources:
22
+ for x in data:
23
+ counter.update(x)
24
+
25
+ self.vocab = Vocab(counter, specials=[])
26
+
27
+ def process(self, example, device=None):
28
+ tensor, lengths = self.numericalize(example, device=device)
29
+ return tensor, lengths
30
+
31
+ def numericalize(self, example, device=None):
32
+ example = [self.vocab.stoi[x] + 1 for x in example]
33
+ length = torch.LongTensor([len(example)], device=device).squeeze(0)
34
+ tensor = torch.LongTensor(example, device=device)
35
+
36
+ return tensor, length
data/field/mini_torchtext/example.py ADDED
@@ -0,0 +1,100 @@
1
+ import six
2
+ import json
3
+ from functools import reduce
4
+
5
+
6
+ class Example(object):
7
+ """Defines a single training or test example.
8
+
9
+ Stores each column of the example as an attribute.
10
+ """
11
+ @classmethod
12
+ def fromJSON(cls, data, fields):
13
+ ex = cls()
14
+ obj = json.loads(data)
15
+
16
+ for key, vals in fields.items():
17
+ if vals is not None:
18
+ if not isinstance(vals, list):
19
+ vals = [vals]
20
+
21
+ for val in vals:
22
+ # for processing the key likes 'foo.bar'
23
+ name, field = val
24
+ ks = key.split('.')
25
+
26
+ def reducer(obj, key):
27
+ if isinstance(obj, list):
28
+ results = []
29
+ for data in obj:
30
+ if key not in data:
31
+ # key error
32
+ raise ValueError("Specified key {} was not found in "
33
+ "the input data".format(key))
34
+ else:
35
+ results.append(data[key])
36
+ return results
37
+ else:
38
+ # key error
39
+ if key not in obj:
40
+ raise ValueError("Specified key {} was not found in "
41
+ "the input data".format(key))
42
+ else:
43
+ return obj[key]
44
+
45
+ v = reduce(reducer, ks, obj)
46
+ setattr(ex, name, field.preprocess(v))
47
+ return ex
48
+
49
+ @classmethod
50
+ def fromdict(cls, data, fields):
51
+ ex = cls()
52
+ for key, vals in fields.items():
53
+ if key not in data:
54
+ raise ValueError("Specified key {} was not found in "
55
+ "the input data".format(key))
56
+ if vals is not None:
57
+ if not isinstance(vals, list):
58
+ vals = [vals]
59
+ for val in vals:
60
+ name, field = val
61
+ setattr(ex, name, field.preprocess(data[key]))
62
+ return ex
63
+
64
+ @classmethod
65
+ def fromCSV(cls, data, fields, field_to_index=None):
66
+ if field_to_index is None:
67
+ return cls.fromlist(data, fields)
68
+ else:
69
+ assert(isinstance(fields, dict))
70
+ data_dict = {f: data[idx] for f, idx in field_to_index.items()}
71
+ return cls.fromdict(data_dict, fields)
72
+
73
+ @classmethod
74
+ def fromlist(cls, data, fields):
75
+ ex = cls()
76
+ for (name, field), val in zip(fields, data):
77
+ if field is not None:
78
+ if isinstance(val, six.string_types):
79
+ val = val.rstrip('\n')
80
+ # Handle field tuples
81
+ if isinstance(name, tuple):
82
+ for n, f in zip(name, field):
83
+ setattr(ex, n, f.preprocess(val))
84
+ else:
85
+ setattr(ex, name, field.preprocess(val))
86
+ return ex
87
+
88
+ @classmethod
89
+ def fromtree(cls, data, fields, subtrees=False):
90
+ try:
91
+ from nltk.tree import Tree
92
+ except ImportError:
93
+ print("Please install NLTK. "
94
+ "See the docs at http://nltk.org for more information.")
95
+ raise
96
+ tree = Tree.fromstring(data)
97
+ if subtrees:
98
+ return [cls.fromlist(
99
+ [' '.join(t.leaves()), t.label()], fields) for t in tree.subtrees()]
100
+ return cls.fromlist([' '.join(tree.leaves()), tree.label()], fields)
data/field/mini_torchtext/field.py ADDED
@@ -0,0 +1,637 @@
1
+ # coding: utf8
2
+ from collections import Counter, OrderedDict
3
+ from itertools import chain
4
+ import six
5
+ import torch
6
+
7
+ from .pipeline import Pipeline
8
+ from .utils import get_tokenizer, dtype_to_attr, is_tokenizer_serializable
9
+ from .vocab import Vocab
10
+
11
+
12
+ class RawField(object):
13
+ """ Defines a general datatype.
14
+
15
+ Every dataset consists of one or more types of data. For instance, a text
16
+ classification dataset contains sentences and their classes, while a
17
+ machine translation dataset contains paired examples of text in two
18
+ languages. Each of these types of data is represented by a RawField object.
19
+ A RawField object does not assume any property of the data type and
20
+ it holds parameters relating to how a datatype should be processed.
21
+
22
+ Attributes:
23
+ preprocessing: The Pipeline that will be applied to examples
24
+ using this field before creating an example.
25
+ Default: None.
26
+ postprocessing: A Pipeline that will be applied to a list of examples
27
+ using this field before assigning to a batch.
28
+ Function signature: (batch(list)) -> object
29
+ Default: None.
30
+ is_target: Whether this field is a target variable.
31
+ Affects iteration over batches. Default: False
32
+ """
33
+
34
+ def __init__(self, preprocessing=None, postprocessing=None, is_target=False):
35
+ self.preprocessing = preprocessing
36
+ self.postprocessing = postprocessing
37
+ self.is_target = is_target
38
+
39
+ def preprocess(self, x):
40
+ """ Preprocess an example if the `preprocessing` Pipeline is provided. """
41
+ if hasattr(self, "preprocessing") and self.preprocessing is not None:
42
+ return self.preprocessing(x)
43
+ else:
44
+ return x
45
+
46
+ def process(self, batch, *args, **kwargs):
47
+ """ Process a list of examples to create a batch.
48
+
49
+ Postprocess the batch with user-provided Pipeline.
50
+
51
+ Args:
52
+ batch (list(object)): A list of object from a batch of examples.
53
+ Returns:
54
+ object: Processed object given the input and custom
55
+ postprocessing Pipeline.
56
+ """
57
+ if self.postprocessing is not None:
58
+ batch = self.postprocessing(batch)
59
+ return batch
60
+
61
+
62
+ class Field(RawField):
63
+ """Defines a datatype together with instructions for converting to Tensor.
64
+
65
+ Field class models common text processing datatypes that can be represented
66
+ by tensors. It holds a Vocab object that defines the set of possible values
67
+ for elements of the field and their corresponding numerical representations.
68
+ The Field object also holds other parameters relating to how a datatype
69
+ should be numericalized, such as a tokenization method and the kind of
70
+ Tensor that should be produced.
71
+
72
+ If a Field is shared between two columns in a dataset (e.g., question and
73
+ answer in a QA dataset), then they will have a shared vocabulary.
74
+
75
+ Attributes:
76
+ sequential: Whether the datatype represents sequential data. If False,
77
+ no tokenization is applied. Default: True.
78
+ use_vocab: Whether to use a Vocab object. If False, the data in this
79
+ field should already be numerical. Default: True.
80
+ init_token: A token that will be prepended to every example using this
81
+ field, or None for no initial token. Default: None.
82
+ eos_token: A token that will be appended to every example using this
83
+ field, or None for no end-of-sentence token. Default: None.
84
+ fix_length: A fixed length that all examples using this field will be
85
+ padded to, or None for flexible sequence lengths. Default: None.
86
+ dtype: The torch.dtype class that represents a batch of examples
87
+ of this kind of data. Default: torch.long.
88
+ preprocessing: The Pipeline that will be applied to examples
89
+ using this field after tokenizing but before numericalizing. Many
90
+ Datasets replace this attribute with a custom preprocessor.
91
+ Default: None.
92
+ postprocessing: A Pipeline that will be applied to examples using
93
+ this field after numericalizing but before the numbers are turned
94
+ into a Tensor. The pipeline function takes the batch as a list, and
95
+ the field's Vocab.
96
+ Default: None.
97
+ lower: Whether to lowercase the text in this field. Default: False.
98
+ tokenize: The function used to tokenize strings using this field into
99
+ sequential examples. If "spacy", the SpaCy tokenizer is
100
+ used. If a non-serializable function is passed as an argument,
101
+ the field will not be able to be serialized. Default: string.split.
102
+ tokenizer_language: The language of the tokenizer to be constructed.
103
+ Various languages currently supported only in SpaCy.
104
+ include_lengths: Whether to return a tuple of a padded minibatch and
105
+ a list containing the lengths of each examples, or just a padded
106
+ minibatch. Default: False.
107
+ batch_first: Whether to produce tensors with the batch dimension first.
108
+ Default: False.
109
+ pad_token: The string token used as padding. Default: "<pad>".
110
+ unk_token: The string token used to represent OOV words. Default: "<unk>".
111
+ pad_first: Do the padding of the sequence at the beginning. Default: False.
112
+ truncate_first: Do the truncating of the sequence at the beginning. Default: False
113
+ stop_words: Tokens to discard during the preprocessing step. Default: None
114
+ is_target: Whether this field is a target variable.
115
+ Affects iteration over batches. Default: False
116
+ """
117
+
118
+ vocab_cls = Vocab
119
+ # Dictionary mapping PyTorch tensor dtypes to the appropriate Python
120
+ # numeric type.
121
+ dtypes = {
122
+ torch.float32: float,
123
+ torch.float: float,
124
+ torch.float64: float,
125
+ torch.double: float,
126
+ torch.float16: float,
127
+ torch.half: float,
128
+
129
+ torch.uint8: int,
130
+ torch.int8: int,
131
+ torch.int16: int,
132
+ torch.short: int,
133
+ torch.int32: int,
134
+ torch.int: int,
135
+ torch.int64: int,
136
+ torch.long: int,
137
+ }
138
+
139
+ ignore = ['dtype', 'tokenize']
140
+
141
+ def __init__(self, sequential=True, use_vocab=True, init_token=None,
142
+ eos_token=None, fix_length=None, dtype=torch.long,
143
+ preprocessing=None, postprocessing=None, lower=False,
144
+ tokenize=None, tokenizer_language='en', include_lengths=False,
145
+ batch_first=False, pad_token="<pad>", unk_token="<unk>",
146
+ pad_first=False, truncate_first=False, stop_words=None,
147
+ is_target=False):
148
+ self.sequential = sequential
149
+ self.use_vocab = use_vocab
150
+ self.init_token = init_token
151
+ self.eos_token = eos_token
152
+ self.unk_token = unk_token
153
+ self.fix_length = fix_length
154
+ self.dtype = dtype
155
+ self.preprocessing = preprocessing
156
+ self.postprocessing = postprocessing
157
+ self.lower = lower
158
+ # store params to construct tokenizer for serialization
159
+ # in case the tokenizer isn't picklable (e.g. spacy)
160
+ self.tokenizer_args = (tokenize, tokenizer_language)
161
+ self.tokenize = get_tokenizer(tokenize, tokenizer_language)
162
+ self.include_lengths = include_lengths
163
+ self.batch_first = batch_first
164
+ self.pad_token = pad_token if self.sequential else None
165
+ self.pad_first = pad_first
166
+ self.truncate_first = truncate_first
167
+ try:
168
+ self.stop_words = set(stop_words) if stop_words is not None else None
169
+ except TypeError:
170
+ raise ValueError("Stop words must be convertible to a set")
171
+ self.is_target = is_target
172
+
173
+ def __getstate__(self):
174
+ str_type = dtype_to_attr(self.dtype)
175
+ if is_tokenizer_serializable(*self.tokenizer_args):
176
+ tokenize = self.tokenize
177
+ else:
178
+ # signal to restore in `__setstate__`
179
+ tokenize = None
180
+ attrs = {k: v for k, v in self.__dict__.items() if k not in self.ignore}
181
+ attrs['dtype'] = str_type
182
+ attrs['tokenize'] = tokenize
183
+
184
+ return attrs
185
+
186
+ def __setstate__(self, state):
187
+ state['dtype'] = getattr(torch, state['dtype'])
188
+ if not state['tokenize']:
189
+ state['tokenize'] = get_tokenizer(*state['tokenizer_args'])
190
+ self.__dict__.update(state)
191
+
192
+ def __hash__(self):
193
+ # we don't expect this to be called often
194
+ return 42
195
+
196
+ def __eq__(self, other):
197
+ if not isinstance(other, RawField):
198
+ return False
199
+
200
+ return self.__dict__ == other.__dict__
201
+
202
+ def preprocess(self, x):
203
+ """Load a single example using this field, tokenizing if necessary.
204
+
205
+ If the input is a Python 2 `str`, it will be converted to Unicode
206
+ first. If `sequential=True`, it will be tokenized. Then the input
207
+ will be optionally lowercased and passed to the user-provided
208
+ `preprocessing` Pipeline."""
209
+ if (six.PY2 and isinstance(x, six.string_types)
210
+ and not isinstance(x, six.text_type)):
211
+ x = Pipeline(lambda s: six.text_type(s, encoding='utf-8'))(x)
212
+ if self.sequential and isinstance(x, six.text_type):
213
+ x = self.tokenize(x.rstrip('\n'))
214
+ if self.lower:
215
+ x = Pipeline(six.text_type.lower)(x)
216
+ if self.sequential and self.use_vocab and self.stop_words is not None:
217
+ x = [w for w in x if w not in self.stop_words]
218
+ if hasattr(self, "preprocessing") and self.preprocessing is not None:
219
+ return self.preprocessing(x)
220
+ else:
221
+ return x
222
+
223
+ def process(self, batch, device=None):
224
+ """ Process a list of examples to create a torch.Tensor.
225
+
226
+ Pad, numericalize, and postprocess a batch and create a tensor.
227
+
228
+ Args:
229
+ batch (list(object)): A list of object from a batch of examples.
230
+ Returns:
231
+ torch.autograd.Variable: Processed object given the input
232
+ and custom postprocessing Pipeline.
233
+ """
234
+ padded = self.pad(batch)
235
+ tensor = self.numericalize(padded, device=device)
236
+ return tensor
237
+
238
+ def pad(self, minibatch):
239
+ """Pad a batch of examples using this field.
240
+
241
+ Pads to self.fix_length if provided, otherwise pads to the length of
242
+ the longest example in the batch. Prepends self.init_token and appends
243
+ self.eos_token if those attributes are not None. Returns a tuple of the
244
+ padded list and a list containing lengths of each example if
245
+ `self.include_lengths` is `True` and `self.sequential` is `True`, else just
246
+ returns the padded list. If `self.sequential` is `False`, no padding is applied.
247
+ """
248
+ minibatch = list(minibatch)
249
+ if not self.sequential:
250
+ return minibatch
251
+ if self.fix_length is None:
252
+ max_len = max(len(x) for x in minibatch)
253
+ else:
254
+ max_len = self.fix_length + (
255
+ self.init_token, self.eos_token).count(None) - 2
256
+ padded, lengths = [], []
257
+ for x in minibatch:
258
+ if self.pad_first:
259
+ padded.append(
260
+ [self.pad_token] * max(0, max_len - len(x))
261
+ + ([] if self.init_token is None else [self.init_token])
262
+ + list(x[-max_len:] if self.truncate_first else x[:max_len])
263
+ + ([] if self.eos_token is None else [self.eos_token]))
264
+ else:
265
+ padded.append(
266
+ ([] if self.init_token is None else [self.init_token])
267
+ + list(x[-max_len:] if self.truncate_first else x[:max_len])
268
+ + ([] if self.eos_token is None else [self.eos_token])
269
+ + [self.pad_token] * max(0, max_len - len(x)))
270
+ lengths.append(len(padded[-1]) - max(0, max_len - len(x)))
271
+ if self.include_lengths:
272
+ return (padded, lengths)
273
+ return padded
274
+
275
+ def build_vocab(self, *args, **kwargs):
276
+ """Construct the Vocab object for this field from one or more datasets.
277
+
278
+ Arguments:
279
+ Positional arguments: Dataset objects or other iterable data
280
+ sources from which to construct the Vocab object that
281
+ represents the set of possible values for this field. If
282
+ a Dataset object is provided, all columns corresponding
283
+ to this field are used; individual columns can also be
284
+ provided directly.
285
+ Remaining keyword arguments: Passed to the constructor of Vocab.
286
+ """
287
+ counter = Counter()
288
+ sources = []
289
+ for arg in args:
290
+ sources.append(arg)
291
+ for data in sources:
292
+ for x in data:
293
+ if not self.sequential:
294
+ x = [x]
295
+ try:
296
+ counter.update(x)
297
+ except TypeError:
298
+ counter.update(chain.from_iterable(x))
299
+ specials = list(OrderedDict.fromkeys(
300
+ tok for tok in [self.unk_token, self.pad_token, self.init_token,
301
+ self.eos_token] + kwargs.pop('specials', [])
302
+ if tok is not None))
303
+ self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
304
+
305
+ def numericalize(self, arr, device=None):
306
+ """Turn a batch of examples that use this field into a Variable.
307
+
308
+ If the field has include_lengths=True, a tensor of lengths will be
309
+ included in the return value.
310
+
311
+ Arguments:
312
+ arr (List[List[str]], or tuple of (List[List[str]], List[int])):
313
+ List of tokenized and padded examples, or tuple of List of
314
+ tokenized and padded examples and List of lengths of each
315
+ example if self.include_lengths is True.
316
+ device (str or torch.device): A string or instance of `torch.device`
317
+ specifying which device the Variables are going to be created on.
318
+ If left as default, the tensors will be created on cpu. Default: None.
319
+ """
320
+ if self.include_lengths and not isinstance(arr, tuple):
321
+ raise ValueError("Field has include_lengths set to True, but "
322
+ "input data is not a tuple of "
323
+ "(data batch, batch lengths).")
324
+ if isinstance(arr, tuple):
325
+ arr, lengths = arr
326
+ lengths = torch.tensor(lengths, dtype=self.dtype, device=device)
327
+
328
+ if self.use_vocab:
329
+ if self.sequential:
330
+ arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
331
+ else:
332
+ arr = [self.vocab.stoi[x] for x in arr]
333
+
334
+ if self.postprocessing is not None:
335
+ arr = self.postprocessing(arr, self.vocab)
336
+ else:
337
+ if self.dtype not in self.dtypes:
338
+ raise ValueError(
339
+ "Specified Field dtype {} can not be used with "
340
+ "use_vocab=False because we do not know how to numericalize it. "
341
+ "Please raise an issue at "
342
+ "https://github.com/pytorch/text/issues".format(self.dtype))
343
+ numericalization_func = self.dtypes[self.dtype]
344
+ # It doesn't make sense to explicitly coerce to a numeric type if
345
+ # the data is sequential, since it's unclear how to coerce padding tokens
346
+ # to a numeric type.
347
+ if not self.sequential:
348
+ arr = [numericalization_func(x) if isinstance(x, six.string_types)
349
+ else x for x in arr]
350
+ if self.postprocessing is not None:
351
+ arr = self.postprocessing(arr, None)
352
+
353
+ var = torch.tensor(arr, dtype=self.dtype, device=device)
354
+
355
+ if self.sequential and not self.batch_first:
356
+ var.t_()
357
+ if self.sequential:
358
+ var = var.contiguous()
359
+
360
+ if self.include_lengths:
361
+ return var, lengths
362
+ return var
363
+
364
+
365
+ class NestedField(Field):
366
+ """A nested field.
367
+
368
+ A nested field holds another field (called *nesting field*), accepts an untokenized
369
+ string or a list string tokens and groups and treats them as one field as described
370
+ by the nesting field. Every token will be preprocessed, padded, etc. in the manner
371
+ specified by the nesting field. Note that this means a nested field always has
372
+ ``sequential=True``. The two fields' vocabularies will be shared. Their
373
+ numericalization results will be stacked into a single tensor. And NestedField will
374
+ share the same include_lengths with nesting_field, so one shouldn't specify the
375
+ include_lengths in the nesting_field. This field is
376
+ primarily used to implement character embeddings. See ``tests/data/test_field.py``
377
+ for examples on how to use this field.
378
+
379
+ Arguments:
380
+ nesting_field (Field): A field contained in this nested field.
381
+ use_vocab (bool): Whether to use a Vocab object. If False, the data in this
382
+ field should already be numerical. Default: ``True``.
383
+ init_token (str): A token that will be prepended to every example using this
384
+ field, or None for no initial token. Default: ``None``.
385
+ eos_token (str): A token that will be appended to every example using this
386
+ field, or None for no end-of-sentence token. Default: ``None``.
387
+ fix_length (int): A fixed length that all examples using this field will be
388
+ padded to, or ``None`` for flexible sequence lengths. Default: ``None``.
389
+ dtype: The torch.dtype class that represents a batch of examples
390
+ of this kind of data. Default: ``torch.long``.
391
+ preprocessing (Pipeline): The Pipeline that will be applied to examples
392
+ using this field after tokenizing but before numericalizing. Many
393
+ Datasets replace this attribute with a custom preprocessor.
394
+ Default: ``None``.
395
+ postprocessing (Pipeline): A Pipeline that will be applied to examples using
396
+ this field after numericalizing but before the numbers are turned
397
+ into a Tensor. The pipeline function takes the batch as a list, and
398
+ the field's Vocab. Default: ``None``.
399
+ include_lengths: Whether to return a tuple of a padded minibatch and
400
+ a list containing the lengths of each examples, or just a padded
401
+ minibatch. Default: False.
402
+ tokenize: The function used to tokenize strings using this field into
403
+ sequential examples. If "spacy", the SpaCy tokenizer is
404
+ used. If a non-serializable function is passed as an argument,
405
+ the field will not be able to be serialized. Default: string.split.
406
+ tokenizer_language: The language of the tokenizer to be constructed.
407
+ Various languages currently supported only in SpaCy.
408
+ pad_token (str): The string token used as padding. If ``nesting_field`` is
409
+ sequential, this will be set to its ``pad_token``. Default: ``"<pad>"``.
410
+ pad_first (bool): Do the padding of the sequence at the beginning. Default:
411
+ ``False``.
412
+ """
413
+
414
+ def __init__(self, nesting_field, use_vocab=True, init_token=None, eos_token=None,
415
+ fix_length=None, dtype=torch.long, preprocessing=None,
416
+ postprocessing=None, tokenize=None, tokenizer_language='en',
417
+ include_lengths=False, pad_token='<pad>',
418
+ pad_first=False, truncate_first=False):
419
+ if isinstance(nesting_field, NestedField):
420
+ raise ValueError('nesting field must not be another NestedField')
421
+ if nesting_field.include_lengths:
422
+ raise ValueError('nesting field cannot have include_lengths=True')
423
+
424
+ if nesting_field.sequential:
425
+ pad_token = nesting_field.pad_token
426
+ super(NestedField, self).__init__(
427
+ use_vocab=use_vocab,
428
+ init_token=init_token,
429
+ eos_token=eos_token,
430
+ fix_length=fix_length,
431
+ dtype=dtype,
432
+ preprocessing=preprocessing,
433
+ postprocessing=postprocessing,
434
+ lower=nesting_field.lower,
435
+ tokenize=tokenize,
436
+ tokenizer_language=tokenizer_language,
437
+ batch_first=True,
438
+ pad_token=pad_token,
439
+ unk_token=nesting_field.unk_token,
440
+ pad_first=pad_first,
441
+ truncate_first=truncate_first,
442
+ include_lengths=include_lengths
443
+ )
444
+ self.nesting_field = nesting_field
445
+ # in case the user forget to do that
446
+ self.nesting_field.batch_first = True
447
+
448
+ def preprocess(self, xs):
449
+ """Preprocess a single example.
450
+
451
+ Firstly, tokenization and the supplied preprocessing pipeline is applied. Since
452
+ this field is always sequential, the result is a list. Then, each element of
453
+ the list is preprocessed using ``self.nesting_field.preprocess`` and the resulting
454
+ list is returned.
455
+
456
+ Arguments:
457
+ xs (list or str): The input to preprocess.
458
+
459
+ Returns:
460
+ list: The preprocessed list.
461
+ """
462
+ return [self.nesting_field.preprocess(x)
463
+ for x in super(NestedField, self).preprocess(xs)]
464
+
465
+ def pad(self, minibatch):
466
+ """Pad a batch of examples using this field.
467
+
468
+ If ``self.nesting_field.sequential`` is ``False``, each example in the batch must
469
+ be a list of string tokens, and pads them as if by a ``Field`` with
470
+ ``sequential=True``. Otherwise, each example must be a list of list of tokens.
471
+ Using ``self.nesting_field``, pads the list of tokens to
472
+ ``self.nesting_field.fix_length`` if provided, or otherwise to the length of the
473
+ longest list of tokens in the batch. Next, using this field, pads the result by
474
+ filling short examples with ``self.nesting_field.pad_token``.
475
+
476
+ Example:
477
+ >>> import pprint
478
+ >>> pp = pprint.PrettyPrinter(indent=4)
479
+ >>>
480
+ >>> nesting_field = Field(pad_token='<c>', init_token='<w>', eos_token='</w>')
481
+ >>> field = NestedField(nesting_field, init_token='<s>', eos_token='</s>')
482
+ >>> minibatch = [
483
+ ... [list('john'), list('loves'), list('mary')],
484
+ ... [list('mary'), list('cries')],
485
+ ... ]
486
+ >>> padded = field.pad(minibatch)
487
+ >>> pp.pprint(padded)
488
+ [ [ ['<w>', '<s>', '</w>', '<c>', '<c>', '<c>', '<c>'],
489
+ ['<w>', 'j', 'o', 'h', 'n', '</w>', '<c>'],
490
+ ['<w>', 'l', 'o', 'v', 'e', 's', '</w>'],
491
+ ['<w>', 'm', 'a', 'r', 'y', '</w>', '<c>'],
492
+ ['<w>', '</s>', '</w>', '<c>', '<c>', '<c>', '<c>']],
493
+ [ ['<w>', '<s>', '</w>', '<c>', '<c>', '<c>', '<c>'],
494
+ ['<w>', 'm', 'a', 'r', 'y', '</w>', '<c>'],
495
+ ['<w>', 'c', 'r', 'i', 'e', 's', '</w>'],
496
+ ['<w>', '</s>', '</w>', '<c>', '<c>', '<c>', '<c>'],
497
+ ['<c>', '<c>', '<c>', '<c>', '<c>', '<c>', '<c>']]]
498
+
499
+ Arguments:
500
+ minibatch (list): Each element is a list of string if
501
+ ``self.nesting_field.sequential`` is ``False``, a list of list of string
502
+ otherwise.
503
+
504
+ Returns:
505
+ list: The padded minibatch. or (padded, sentence_lens, word_lengths)
506
+ """
507
+ minibatch = list(minibatch)
508
+ if not self.nesting_field.sequential:
509
+ return super(NestedField, self).pad(minibatch)
510
+
511
+ # Save values of attributes to be monkeypatched
512
+ old_pad_token = self.pad_token
513
+ old_init_token = self.init_token
514
+ old_eos_token = self.eos_token
515
+ old_fix_len = self.nesting_field.fix_length
516
+ # Monkeypatch the attributes
517
+ if self.nesting_field.fix_length is None:
518
+ max_len = max(len(xs) for ex in minibatch for xs in ex)
519
+ fix_len = max_len + 2 - (self.nesting_field.init_token,
520
+ self.nesting_field.eos_token).count(None)
521
+ self.nesting_field.fix_length = fix_len
522
+ self.pad_token = [self.pad_token] * self.nesting_field.fix_length
523
+ if self.init_token is not None:
524
+ # self.init_token = self.nesting_field.pad([[self.init_token]])[0]
525
+ self.init_token = [self.init_token]
526
+ if self.eos_token is not None:
527
+ # self.eos_token = self.nesting_field.pad([[self.eos_token]])[0]
528
+ self.eos_token = [self.eos_token]
529
+ # Do padding
530
+ old_include_lengths = self.include_lengths
531
+ self.include_lengths = True
532
+ self.nesting_field.include_lengths = True
533
+ padded, sentence_lengths = super(NestedField, self).pad(minibatch)
534
+ padded_with_lengths = [self.nesting_field.pad(ex) for ex in padded]
535
+ word_lengths = []
536
+ final_padded = []
537
+ max_sen_len = len(padded[0])
538
+ for (pad, lens), sentence_len in zip(padded_with_lengths, sentence_lengths):
539
+ if sentence_len == max_sen_len:
540
+ lens = lens
541
+ pad = pad
542
+ elif self.pad_first:
543
+ lens[:(max_sen_len - sentence_len)] = (
544
+ [0] * (max_sen_len - sentence_len))
545
+ pad[:(max_sen_len - sentence_len)] = (
546
+ [self.pad_token] * (max_sen_len - sentence_len))
547
+ else:
548
+ lens[-(max_sen_len - sentence_len):] = (
549
+ [0] * (max_sen_len - sentence_len))
550
+ pad[-(max_sen_len - sentence_len):] = (
551
+ [self.pad_token] * (max_sen_len - sentence_len))
552
+ word_lengths.append(lens)
553
+ final_padded.append(pad)
554
+ padded = final_padded
555
+
556
+ # Restore monkeypatched attributes
557
+ self.nesting_field.fix_length = old_fix_len
558
+ self.pad_token = old_pad_token
559
+ self.init_token = old_init_token
560
+ self.eos_token = old_eos_token
561
+ self.include_lengths = old_include_lengths
562
+ if self.include_lengths:
563
+ return padded, sentence_lengths, word_lengths
564
+ return padded
565
+
566
+ def build_vocab(self, *args, **kwargs):
567
+ """Construct the Vocab object for nesting field and combine it with this field's vocab.
568
+
569
+ Arguments:
570
+ Positional arguments: Dataset objects or other iterable data
571
+ sources from which to construct the Vocab object that
572
+ represents the set of possible values for the nesting field. If
573
+ a Dataset object is provided, all columns corresponding
574
+ to this field are used; individual columns can also be
575
+ provided directly.
576
+ Remaining keyword arguments: Passed to the constructor of Vocab.
577
+ """
578
+ sources = []
579
+ for arg in args:
580
+ sources.append(arg)
581
+
582
+ flattened = []
583
+ for source in sources:
584
+ flattened.extend(source)
585
+ old_vectors = None
586
+ old_unk_init = None
587
+ old_vectors_cache = None
588
+ if "vectors" in kwargs.keys():
589
+ old_vectors = kwargs["vectors"]
590
+ kwargs["vectors"] = None
591
+ if "unk_init" in kwargs.keys():
592
+ old_unk_init = kwargs["unk_init"]
593
+ kwargs["unk_init"] = None
594
+ if "vectors_cache" in kwargs.keys():
595
+ old_vectors_cache = kwargs["vectors_cache"]
596
+ kwargs["vectors_cache"] = None
597
+ # just build vocab and does not load vector
598
+ self.nesting_field.build_vocab(*flattened, **kwargs)
599
+ super(NestedField, self).build_vocab()
600
+ self.vocab.extend(self.nesting_field.vocab)
601
+ self.vocab.freqs = self.nesting_field.vocab.freqs.copy()
602
+ if old_vectors is not None:
603
+ self.vocab.load_vectors(old_vectors,
604
+ unk_init=old_unk_init, cache=old_vectors_cache)
605
+
606
+ self.nesting_field.vocab = self.vocab
607
+
608
+ def numericalize(self, arrs, device=None):
609
+ """Convert a padded minibatch into a variable tensor.
610
+
611
+ Each item in the minibatch will be numericalized independently and the resulting
612
+ tensors will be stacked at the first dimension.
613
+
614
+ Arguments:
615
+ arr (List[List[str]]): List of tokenized and padded examples.
616
+ device (str or torch.device): A string or instance of `torch.device`
617
+ specifying which device the Variables are going to be created on.
618
+ If left as default, the tensors will be created on cpu. Default: None.
619
+ """
620
+ numericalized = []
621
+ self.nesting_field.include_lengths = False
622
+ if self.include_lengths:
623
+ arrs, sentence_lengths, word_lengths = arrs
624
+
625
+ for arr in arrs:
626
+ numericalized_ex = self.nesting_field.numericalize(
627
+ arr, device=device)
628
+ numericalized.append(numericalized_ex)
629
+ padded_batch = torch.stack(numericalized)
630
+
631
+ self.nesting_field.include_lengths = True
632
+ if self.include_lengths:
633
+ sentence_lengths = \
634
+ torch.tensor(sentence_lengths, dtype=self.dtype, device=device)
635
+ word_lengths = torch.tensor(word_lengths, dtype=self.dtype, device=device)
636
+ return (padded_batch, sentence_lengths, word_lengths)
637
+ return padded_batch
data/field/mini_torchtext/pipeline.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Pipeline(object):
2
+ """Defines a pipeline for transforming sequence data.
3
+
4
+ The input is assumed to be utf-8 encoded `str` (Python 3) or
5
+ `unicode` (Python 2).
6
+
7
+ Attributes:
8
+ convert_token: The function to apply to input sequence data.
9
+ pipes: The Pipelines that will be applied to input sequence
10
+ data in order.
11
+ """
12
+
13
+ def __init__(self, convert_token=None):
14
+ """Create a pipeline.
15
+
16
+ Arguments:
17
+ convert_token: The function to apply to input sequence data.
18
+ If None, the identity function is used. Default: None
19
+ """
20
+ if convert_token is None:
21
+ self.convert_token = Pipeline.identity
22
+ elif callable(convert_token):
23
+ self.convert_token = convert_token
24
+ else:
25
+ raise ValueError("Pipeline input convert_token {} is not None "
26
+ "or callable".format(convert_token))
27
+ self.pipes = [self]
28
+
29
+ def __call__(self, x, *args):
30
+ """Apply the the current Pipeline(s) to an input.
31
+
32
+ Arguments:
33
+ x: The input to process with the Pipeline(s).
34
+ Positional arguments: Forwarded to the `call` function
35
+ of the Pipeline(s).
36
+ """
37
+ for pipe in self.pipes:
38
+ x = pipe.call(x, *args)
39
+ return x
40
+
41
+ def call(self, x, *args):
42
+ """Apply _only_ the convert_token function of the current pipeline
43
+ to the input. If the input is a list, a list with the results of
44
+ applying the `convert_token` function to all input elements is
45
+ returned.
46
+
47
+ Arguments:
48
+ x: The input to apply the convert_token function to.
49
+ Positional arguments: Forwarded to the `convert_token` function
50
+ of the current Pipeline.
51
+ """
52
+ if isinstance(x, list):
53
+ return [self.convert_token(tok, *args) for tok in x]
54
+ return self.convert_token(x, *args)
55
+
56
+ def add_before(self, pipeline):
57
+ """Add a Pipeline to be applied before this processing pipeline.
58
+
59
+ Arguments:
60
+ pipeline: The Pipeline or callable to apply before this
61
+ Pipeline.
62
+ """
63
+ if not isinstance(pipeline, Pipeline):
64
+ pipeline = Pipeline(pipeline)
65
+ self.pipes = pipeline.pipes[:] + self.pipes[:]
66
+ return self
67
+
68
+ def add_after(self, pipeline):
69
+ """Add a Pipeline to be applied after this processing pipeline.
70
+
71
+ Arguments:
72
+ pipeline: The Pipeline or callable to apply after this
73
+ Pipeline.
74
+ """
75
+ if not isinstance(pipeline, Pipeline):
76
+ pipeline = Pipeline(pipeline)
77
+ self.pipes = self.pipes[:] + pipeline.pipes[:]
78
+ return self
79
+
80
+ @staticmethod
81
+ def identity(x):
82
+ """Return a copy of the input.
83
+
84
+ This is here for serialization compatibility with pickle.
85
+ """
86
+ return x
data/field/mini_torchtext/utils.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from contextlib import contextmanager
3
+ from copy import deepcopy
4
+ import re
5
+
6
+ from functools import partial
7
+
8
+
9
+ def _split_tokenizer(x):
10
+ return x.split()
11
+
12
+
13
+ def _spacy_tokenize(x, spacy):
14
+ return [tok.text for tok in spacy.tokenizer(x)]
15
+
16
+
17
+ _patterns = [r'\'',
18
+ r'\"',
19
+ r'\.',
20
+ r'<br \/>',
21
+ r',',
22
+ r'\(',
23
+ r'\)',
24
+ r'\!',
25
+ r'\?',
26
+ r'\;',
27
+ r'\:',
28
+ r'\s+']
29
+
30
+ _replacements = [' \' ',
31
+ '',
32
+ ' . ',
33
+ ' ',
34
+ ' , ',
35
+ ' ( ',
36
+ ' ) ',
37
+ ' ! ',
38
+ ' ? ',
39
+ ' ',
40
+ ' ',
41
+ ' ']
42
+
43
+ _patterns_dict = list((re.compile(p), r) for p, r in zip(_patterns, _replacements))
44
+
45
+
46
+ def _basic_english_normalize(line):
47
+ r"""
48
+ Basic normalization for a line of text.
49
+ Normalization includes
50
+ - lowercasing
51
+ - complete some basic text normalization for English words as follows:
52
+ add spaces before and after '\''
53
+ remove '\"',
54
+ add spaces before and after '.'
55
+ replace '<br \/>'with single space
56
+ add spaces before and after ','
57
+ add spaces before and after '('
58
+ add spaces before and after ')'
59
+ add spaces before and after '!'
60
+ add spaces before and after '?'
61
+ replace ';' with single space
62
+ replace ':' with single space
63
+ replace multiple spaces with single space
64
+
65
+ Returns a list of tokens after splitting on whitespace.
66
+ """
67
+
68
+ line = line.lower()
69
+ for pattern_re, replaced_str in _patterns_dict:
70
+ line = pattern_re.sub(replaced_str, line)
71
+ return line.split()
72
+
73
+
74
+ def get_tokenizer(tokenizer, language='en'):
75
+ r"""
76
+ Generate tokenizer function for a string sentence.
77
+
78
+ Arguments:
79
+ tokenizer: the name of tokenizer function. If None, it returns split()
80
+ function, which splits the string sentence by space.
81
+ If basic_english, it returns _basic_english_normalize() function,
82
+ which normalize the string first and split by space. If a callable
83
+ function, it will return the function. If a tokenizer library
84
+ (e.g. spacy, moses, toktok, revtok, subword), it returns the
85
+ corresponding library.
86
+ language: Default en
87
+
88
+ Examples:
89
+ >>> import torchtext
90
+ >>> from torchtext.data import get_tokenizer
91
+ >>> tokenizer = get_tokenizer("basic_english")
92
+ >>> tokens = tokenizer("You can now install TorchText using pip!")
93
+ >>> tokens
94
+ >>> ['you', 'can', 'now', 'install', 'torchtext', 'using', 'pip', '!']
95
+
96
+ """
97
+
98
+ # default tokenizer is string.split(), added as a module function for serialization
99
+ if tokenizer is None:
100
+ return _split_tokenizer
101
+
102
+ if tokenizer == "basic_english":
103
+ if language != 'en':
104
+ raise ValueError("Basic normalization is only available for Enlish(en)")
105
+ return _basic_english_normalize
106
+
107
+ # simply return if a function is passed
108
+ if callable(tokenizer):
109
+ return tokenizer
110
+
111
+ if tokenizer == "spacy":
112
+ try:
113
+ import spacy
114
+ spacy = spacy.load(language)
115
+ return partial(_spacy_tokenize, spacy=spacy)
116
+ except ImportError:
117
+ print("Please install SpaCy. "
118
+ "See the docs at https://spacy.io for more information.")
119
+ raise
120
+ except AttributeError:
121
+ print("Please install SpaCy and the SpaCy {} tokenizer. "
122
+ "See the docs at https://spacy.io for more "
123
+ "information.".format(language))
124
+ raise
125
+ elif tokenizer == "moses":
126
+ try:
127
+ from sacremoses import MosesTokenizer
128
+ moses_tokenizer = MosesTokenizer()
129
+ return moses_tokenizer.tokenize
130
+ except ImportError:
131
+ print("Please install SacreMoses. "
132
+ "See the docs at https://github.com/alvations/sacremoses "
133
+ "for more information.")
134
+ raise
135
+ elif tokenizer == "toktok":
136
+ try:
137
+ from nltk.tokenize.toktok import ToktokTokenizer
138
+ toktok = ToktokTokenizer()
139
+ return toktok.tokenize
140
+ except ImportError:
141
+ print("Please install NLTK. "
142
+ "See the docs at https://nltk.org for more information.")
143
+ raise
144
+ elif tokenizer == 'revtok':
145
+ try:
146
+ import revtok
147
+ return revtok.tokenize
148
+ except ImportError:
149
+ print("Please install revtok.")
150
+ raise
151
+ elif tokenizer == 'subword':
152
+ try:
153
+ import revtok
154
+ return partial(revtok.tokenize, decap=True)
155
+ except ImportError:
156
+ print("Please install revtok.")
157
+ raise
158
+ raise ValueError("Requested tokenizer {}, valid choices are a "
159
+ "callable that takes a single string as input, "
160
+ "\"revtok\" for the revtok reversible tokenizer, "
161
+ "\"subword\" for the revtok caps-aware tokenizer, "
162
+ "\"spacy\" for the SpaCy English tokenizer, or "
163
+ "\"moses\" for the NLTK port of the Moses tokenization "
164
+ "script.".format(tokenizer))
165
+
166
+
167
+ def is_tokenizer_serializable(tokenizer, language):
168
+ """Extend with other tokenizers which are found to not be serializable
169
+ """
170
+ if tokenizer == 'spacy':
171
+ return False
172
+ return True
173
+
174
+
175
+ def interleave_keys(a, b):
176
+ """Interleave bits from two sort keys to form a joint sort key.
177
+
178
+ Examples that are similar in both of the provided keys will have similar
179
+ values for the key defined by this function. Useful for tasks with two
180
+ text fields like machine translation or natural language inference.
181
+ """
182
+ def interleave(args):
183
+ return ''.join([x for t in zip(*args) for x in t])
184
+ return int(''.join(interleave(format(x, '016b') for x in (a, b))), base=2)
185
+
186
+
187
+ def get_torch_version():
188
+ import torch
189
+ v = torch.__version__
190
+ version_substrings = v.split('.')
191
+ major, minor = version_substrings[0], version_substrings[1]
192
+ return int(major), int(minor)
193
+
194
+
195
+ def dtype_to_attr(dtype):
196
+ # convert torch.dtype to dtype string id
197
+ # e.g. torch.int32 -> "int32"
198
+ # used for serialization
199
+ _, dtype = str(dtype).split('.')
200
+ return dtype
201
+
202
+
203
+ # TODO: Write more tests!
204
+ def ngrams_iterator(token_list, ngrams):
205
+ """Return an iterator that yields the given tokens and their ngrams.
206
+
207
+ Arguments:
208
+ token_list: A list of tokens
209
+ ngrams: the number of ngrams.
210
+
211
+ Examples:
212
+ >>> token_list = ['here', 'we', 'are']
213
+ >>> list(ngrams_iterator(token_list, 2))
214
+ >>> ['here', 'here we', 'we', 'we are', 'are']
215
+ """
216
+
217
+ def _get_ngrams(n):
218
+ return zip(*[token_list[i:] for i in range(n)])
219
+
220
+ for x in token_list:
221
+ yield x
222
+ for n in range(2, ngrams + 1):
223
+ for x in _get_ngrams(n):
224
+ yield ' '.join(x)
225
+
226
+
227
+ class RandomShuffler(object):
228
+ """Use random functions while keeping track of the random state to make it
229
+ reproducible and deterministic."""
230
+
231
+ def __init__(self, random_state=None):
232
+ self._random_state = random_state
233
+ if self._random_state is None:
234
+ self._random_state = random.getstate()
235
+
236
+ @contextmanager
237
+ def use_internal_state(self):
238
+ """Use a specific RNG state."""
239
+ old_state = random.getstate()
240
+ random.setstate(self._random_state)
241
+ yield
242
+ self._random_state = random.getstate()
243
+ random.setstate(old_state)
244
+
245
+ @property
246
+ def random_state(self):
247
+ return deepcopy(self._random_state)
248
+
249
+ @random_state.setter
250
+ def random_state(self, s):
251
+ self._random_state = s
252
+
253
+ def __call__(self, data):
254
+ """Shuffle and return a new list."""
255
+ with self.use_internal_state():
256
+ return random.sample(data, len(data))
data/field/mini_torchtext/vocab.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import unicode_literals
2
+ from collections import defaultdict
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+
8
+ class Vocab(object):
9
+ """Defines a vocabulary object that will be used to numericalize a field.
10
+
11
+ Attributes:
12
+ freqs: A collections.Counter object holding the frequencies of tokens
13
+ in the data used to build the Vocab.
14
+ stoi: A collections.defaultdict instance mapping token strings to
15
+ numerical identifiers.
16
+ itos: A list of token strings indexed by their numerical identifiers.
17
+ """
18
+
19
+ # TODO (@mttk): Populate classs with default values of special symbols
20
+ UNK = '<unk>'
21
+
22
+ def __init__(self, counter, max_size=None, min_freq=1, specials=['<unk>', '<pad>'], specials_first=True):
23
+ """Create a Vocab object from a collections.Counter.
24
+
25
+ Arguments:
26
+ counter: collections.Counter object holding the frequencies of
27
+ each value found in the data.
28
+ max_size: The maximum size of the vocabulary, or None for no
29
+ maximum. Default: None.
30
+ min_freq: The minimum frequency needed to include a token in the
31
+ vocabulary. Values less than 1 will be set to 1. Default: 1.
32
+ specials: The list of special tokens (e.g., padding or eos) that
33
+ will be prepended to the vocabulary. Default: ['<unk'>, '<pad>']
34
+ specials_first: Whether to add special tokens into the vocabulary at first.
35
+ If it is False, they are added into the vocabulary at last.
36
+ Default: True.
37
+ """
38
+ self.freqs = counter
39
+ counter = counter.copy()
40
+ min_freq = max(min_freq, 1)
41
+
42
+ self.itos = list()
43
+ self.unk_index = None
44
+ if specials_first:
45
+ self.itos = list(specials)
46
+ # only extend max size if specials are prepended
47
+ max_size = None if max_size is None else max_size + len(specials)
48
+
49
+ # frequencies of special tokens are not counted when building vocabulary
50
+ # in frequency order
51
+ for tok in specials:
52
+ del counter[tok]
53
+
54
+ # sort by frequency, then alphabetically
55
+ words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
56
+ words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
57
+
58
+ for word, freq in words_and_frequencies:
59
+ if freq < min_freq or len(self.itos) == max_size:
60
+ break
61
+ self.itos.append(word)
62
+
63
+ if Vocab.UNK in specials: # hard-coded for now
64
+ unk_index = specials.index(Vocab.UNK) # position in list
65
+ # account for ordering of specials, set variable
66
+ self.unk_index = unk_index if specials_first else len(self.itos) + unk_index
67
+ self.stoi = defaultdict(self._default_unk_index)
68
+ else:
69
+ self.stoi = defaultdict()
70
+
71
+ if not specials_first:
72
+ self.itos.extend(list(specials))
73
+
74
+ # stoi is simply a reverse dict for itos
75
+ self.stoi.update({tok: i for i, tok in enumerate(self.itos)})
76
+
77
+ def _default_unk_index(self):
78
+ return self.unk_index
79
+
80
+ def __getitem__(self, token):
81
+ return self.stoi.get(token, self.stoi.get(Vocab.UNK))
82
+
83
+ def __getstate__(self):
84
+ # avoid picking defaultdict
85
+ attrs = dict(self.__dict__)
86
+ # cast to regular dict
87
+ attrs['stoi'] = dict(self.stoi)
88
+ return attrs
89
+
90
+ def __setstate__(self, state):
91
+ if state.get("unk_index", None) is None:
92
+ stoi = defaultdict()
93
+ else:
94
+ stoi = defaultdict(self._default_unk_index)
95
+ stoi.update(state['stoi'])
96
+ state['stoi'] = stoi
97
+ self.__dict__.update(state)
98
+
99
+ def __eq__(self, other):
100
+ if self.freqs != other.freqs:
101
+ return False
102
+ if self.stoi != other.stoi:
103
+ return False
104
+ if self.itos != other.itos:
105
+ return False
106
+ return True
107
+
108
+ def __len__(self):
109
+ return len(self.itos)
110
+
111
+ def extend(self, v, sort=False):
112
+ words = sorted(v.itos) if sort else v.itos
113
+ for w in words:
114
+ if w not in self.stoi:
115
+ self.itos.append(w)
116
+ self.stoi[w] = len(self.itos) - 1
data/field/nested_field.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ from data.field.mini_torchtext.field import NestedField as TorchTextNestedField
6
+
7
+
8
+ class NestedField(TorchTextNestedField):
9
+ def pad(self, example):
10
+ self.nesting_field.include_lengths = self.include_lengths
11
+ if not self.include_lengths:
12
+ return self.nesting_field.pad(example)
13
+
14
+ sentence_length = len(example)
15
+ example, word_lengths = self.nesting_field.pad(example)
16
+ return example, sentence_length, word_lengths
17
+
18
+ def numericalize(self, arr, device=None):
19
+ numericalized = []
20
+ self.nesting_field.include_lengths = False
21
+ if self.include_lengths:
22
+ arr, sentence_length, word_lengths = arr
23
+
24
+ numericalized = self.nesting_field.numericalize(arr, device=device)
25
+
26
+ self.nesting_field.include_lengths = True
27
+ if self.include_lengths:
28
+ sentence_length = torch.tensor(sentence_length, dtype=self.dtype, device=device)
29
+ word_lengths = torch.tensor(word_lengths, dtype=self.dtype, device=device)
30
+ return (numericalized, sentence_length, word_lengths)
31
+ return numericalized
32
+
33
+ def build_vocab(self, *args, **kwargs):
34
+ sources = []
35
+ for arg in args:
36
+ if isinstance(arg, torch.utils.data.Dataset):
37
+ sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
38
+ else:
39
+ sources.append(arg)
40
+
41
+ flattened = []
42
+ for source in sources:
43
+ flattened.extend(source)
44
+
45
+ # just build vocab and does not load vector
46
+ self.nesting_field.build_vocab(*flattened, **kwargs)
47
+ super(TorchTextNestedField, self).build_vocab()
48
+ self.vocab.extend(self.nesting_field.vocab)
49
+ self.vocab.freqs = self.nesting_field.vocab.freqs.copy()
50
+ self.nesting_field.vocab = self.vocab
data/parser/__init__.py ADDED
File without changes
data/parser/from_mrp/__init__.py ADDED
File without changes
data/parser/from_mrp/abstract_parser.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ from data.parser.json_parser import example_from_json
6
+
7
+
8
+ class AbstractParser(torch.utils.data.Dataset):
9
+ def __init__(self, fields, data, filter_pred=None):
10
+ super(AbstractParser, self).__init__()
11
+
12
+ self.examples = [example_from_json(d, fields) for _, d in sorted(data.items())]
13
+
14
+ if isinstance(fields, dict):
15
+ fields, field_dict = [], fields
16
+ for field in field_dict.values():
17
+ if isinstance(field, list):
18
+ fields.extend(field)
19
+ else:
20
+ fields.append(field)
21
+
22
+ if filter_pred is not None:
23
+ make_list = isinstance(self.examples, list)
24
+ self.examples = filter(filter_pred, self.examples)
25
+ if make_list:
26
+ self.examples = list(self.examples)
27
+
28
+ self.fields = dict(fields)
29
+
30
+ # Unpack field tuples
31
+ for n, f in list(self.fields.items()):
32
+ if isinstance(n, tuple):
33
+ self.fields.update(zip(n, f))
34
+ del self.fields[n]
35
+
36
+ def __getitem__(self, i):
37
+ item = self.examples[i]
38
+ processed_item = {}
39
+ for (name, field) in self.fields.items():
40
+ if field is not None:
41
+ processed_item[name] = field.process(getattr(item, name), device=None)
42
+ return processed_item
43
+
44
+ def __len__(self):
45
+ return len(self.examples)
46
+
47
+ def get_examples(self, attr):
48
+ if attr in self.fields:
49
+ for x in self.examples:
50
+ yield getattr(x, attr)
data/parser/from_mrp/evaluation_parser.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ from data.parser.from_mrp.abstract_parser import AbstractParser
5
+ import utility.parser_utils as utils
6
+
7
+
8
+ class EvaluationParser(AbstractParser):
9
+ def __init__(self, args, fields):
10
+ path = args.test_data
11
+ self.data = utils.load_dataset(path)
12
+
13
+ for sentence in self.data.values():
14
+ sentence["token anchors"] = [[a["from"], a["to"]] for a in sentence["token anchors"]]
15
+
16
+ utils.create_bert_tokens(self.data, args.encoder)
17
+
18
+ super(EvaluationParser, self).__init__(fields, self.data)
data/parser/from_mrp/labeled_edge_parser.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ from data.parser.from_mrp.abstract_parser import AbstractParser
5
+ import utility.parser_utils as utils
6
+
7
+
8
+ class LabeledEdgeParser(AbstractParser):
9
+ def __init__(self, args, part: str, fields, filter_pred=None, **kwargs):
10
+ assert part == "training" or part == "validation"
11
+ path = args.training_data if part == "training" else args.validation_data
12
+
13
+ self.data = utils.load_dataset(path)
14
+ utils.anchor_ids_from_intervals(self.data)
15
+
16
+ self.node_counter, self.edge_counter, self.no_edge_counter = 0, 0, 0
17
+ anchor_count, n_node_token_pairs = 0, 0
18
+
19
+ for sentence_id, sentence in list(self.data.items()):
20
+ for edge in sentence["edges"]:
21
+ if "label" not in edge:
22
+ del self.data[sentence_id]
23
+ break
24
+
25
+ for node, sentence in utils.node_generator(self.data):
26
+ node["label"] = "Node"
27
+
28
+ self.node_counter += 1
29
+
30
+ utils.create_bert_tokens(self.data, args.encoder)
31
+
32
+ # create edge vectors
33
+ for sentence in self.data.values():
34
+ assert sentence["tops"] == [0], sentence
35
+ N = len(sentence["nodes"])
36
+
37
+ edge_count = utils.create_edges(sentence)
38
+ self.edge_counter += edge_count
39
+ self.no_edge_counter += N * (N - 1) - edge_count
40
+
41
+ sentence["nodes"] = sentence["nodes"][1:]
42
+ N = len(sentence["nodes"])
43
+
44
+ sentence["anchor edges"] = [N, len(sentence["input"]), []]
45
+ sentence["source anchor edges"] = [N, len(sentence["input"]), []] # dummy
46
+ sentence["target anchor edges"] = [N, len(sentence["input"]), []] # dummy
47
+ sentence["anchored labels"] = [len(sentence["input"]), []]
48
+ for i, node in enumerate(sentence["nodes"]):
49
+ anchored_labels = []
50
+
51
+ for anchor in node["anchors"]:
52
+ sentence["anchor edges"][-1].append((i, anchor))
53
+ anchored_labels.append((anchor, node["label"]))
54
+
55
+ sentence["anchored labels"][1].append(anchored_labels)
56
+
57
+ anchor_count += len(node["anchors"])
58
+ n_node_token_pairs += len(sentence["input"])
59
+
60
+ sentence["id"] = [sentence["id"]]
61
+
62
+ self.anchor_freq = anchor_count / n_node_token_pairs
63
+ self.source_anchor_freq = self.target_anchor_freq = 0.5 # dummy
64
+ self.input_count = sum(len(sentence["input"]) for sentence in self.data.values())
65
+
66
+ super(LabeledEdgeParser, self).__init__(fields, self.data, filter_pred)
67
+
68
+ @staticmethod
69
+ def node_similarity_key(node):
70
+ return tuple([node["label"]] + node["anchors"])
data/parser/from_mrp/node_centric_parser.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ from data.parser.from_mrp.abstract_parser import AbstractParser
5
+ import utility.parser_utils as utils
6
+
7
+
8
+ class NodeCentricParser(AbstractParser):
9
+ def __init__(self, args, part: str, fields, filter_pred=None, **kwargs):
10
+ assert part == "training" or part == "validation"
11
+ path = args.training_data if part == "training" else args.validation_data
12
+
13
+ self.data = utils.load_dataset(path)
14
+ utils.anchor_ids_from_intervals(self.data)
15
+
16
+ self.node_counter, self.edge_counter, self.no_edge_counter = 0, 0, 0
17
+ anchor_count, n_node_token_pairs = 0, 0
18
+
19
+ for sentence_id, sentence in list(self.data.items()):
20
+ for node in sentence["nodes"]:
21
+ if "label" not in node:
22
+ del self.data[sentence_id]
23
+ break
24
+
25
+ for node, _ in utils.node_generator(self.data):
26
+ self.node_counter += 1
27
+
28
+ # print(f"Number of unlabeled nodes: {unlabeled_count}", flush=True)
29
+
30
+ utils.create_bert_tokens(self.data, args.encoder)
31
+
32
+ # create edge vectors
33
+ for sentence in self.data.values():
34
+ N = len(sentence["nodes"])
35
+
36
+ edge_count = utils.create_edges(sentence)
37
+ self.edge_counter += edge_count
38
+ # self.no_edge_counter += len([n for n in sentence["nodes"] if n["label"] in ["Source", "Target"]]) * len([n for n in sentence["nodes"] if n["label"] not in ["Source", "Target"]]) - edge_count
39
+ self.no_edge_counter += N * (N - 1) - edge_count
40
+
41
+ sentence["anchor edges"] = [N, len(sentence["input"]), []]
42
+ sentence["source anchor edges"] = [N, len(sentence["input"]), []] # dummy
43
+ sentence["target anchor edges"] = [N, len(sentence["input"]), []] # dummy
44
+ sentence["anchored labels"] = [len(sentence["input"]), []]
45
+ for i, node in enumerate(sentence["nodes"]):
46
+ anchored_labels = []
47
+ #if len(node["anchors"]) == 0:
48
+ # print(f"Empty node in {sentence['id']}", flush=True)
49
+
50
+ for anchor in node["anchors"]:
51
+ sentence["anchor edges"][-1].append((i, anchor))
52
+ anchored_labels.append((anchor, node["label"]))
53
+
54
+ sentence["anchored labels"][1].append(anchored_labels)
55
+
56
+ anchor_count += len(node["anchors"])
57
+ n_node_token_pairs += len(sentence["input"])
58
+
59
+ sentence["id"] = [sentence["id"]]
60
+
61
+ self.anchor_freq = anchor_count / n_node_token_pairs
62
+ self.source_anchor_freq = self.target_anchor_freq = 0.5 # dummy
63
+ self.input_count = sum(len(sentence["input"]) for sentence in self.data.values())
64
+
65
+ super(NodeCentricParser, self).__init__(fields, self.data, filter_pred)
66
+
67
+ @staticmethod
68
+ def node_similarity_key(node):
69
+ return tuple([node["label"]] + node["anchors"])
data/parser/from_mrp/request_parser.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import utility.parser_utils as utils
5
+ from data.parser.from_mrp.abstract_parser import AbstractParser
6
+
7
+
8
+ class RequestParser(AbstractParser):
9
+ def __init__(self, sentences, args, fields):
10
+ self.data = {i: {"id": str(i), "sentence": sentence} for i, sentence in enumerate(sentences)}
11
+
12
+ sentences = [example["sentence"] for example in self.data.values()]
13
+
14
+ for example in self.data.values():
15
+ example["input"] = example["sentence"].strip().split(' ')
16
+ example["token anchors"], offset = [], 0
17
+ for token in example["input"]:
18
+ example["token anchors"].append([offset, offset + len(token)])
19
+ offset += len(token) + 1
20
+
21
+ utils.create_bert_tokens(self.data, args.encoder)
22
+
23
+ super(RequestParser, self).__init__(fields, self.data)
data/parser/from_mrp/sequential_parser.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ from data.parser.from_mrp.abstract_parser import AbstractParser
5
+ import utility.parser_utils as utils
6
+
7
+
8
+ class SequentialParser(AbstractParser):
9
+ def __init__(self, args, part: str, fields, filter_pred=None, **kwargs):
10
+ assert part == "training" or part == "validation"
11
+ path = args.training_data if part == "training" else args.validation_data
12
+
13
+ self.data = utils.load_dataset(path)
14
+ utils.anchor_ids_from_intervals(self.data)
15
+
16
+ self.node_counter, self.edge_counter, self.no_edge_counter = 0, 0, 0
17
+ anchor_count, source_anchor_count, target_anchor_count, n_node_token_pairs = 0, 0, 0, 0
18
+
19
+ for sentence_id, sentence in list(self.data.items()):
20
+ for node in sentence["nodes"]:
21
+ if "label" not in node:
22
+ del self.data[sentence_id]
23
+ break
24
+
25
+ for node, _ in utils.node_generator(self.data):
26
+ node["target anchors"] = []
27
+ node["source anchors"] = []
28
+
29
+ for sentence in self.data.values():
30
+ for e in sentence["edges"]:
31
+ source, target = e["source"], e["target"]
32
+
33
+ if sentence["nodes"][target]["label"] == "Target":
34
+ sentence["nodes"][source]["target anchors"] += sentence["nodes"][target]["anchors"]
35
+ elif sentence["nodes"][target]["label"] == "Source":
36
+ sentence["nodes"][source]["source anchors"] += sentence["nodes"][target]["anchors"]
37
+
38
+ for i, node in list(enumerate(sentence["nodes"]))[::-1]:
39
+ if "label" not in node or node["label"] in ["Source", "Target"]:
40
+ del sentence["nodes"][i]
41
+ sentence["edges"] = []
42
+
43
+ for node, sentence in utils.node_generator(self.data):
44
+ self.node_counter += 1
45
+
46
+ utils.create_bert_tokens(self.data, args.encoder)
47
+
48
+ # create edge vectors
49
+ for sentence in self.data.values():
50
+ N = len(sentence["nodes"])
51
+
52
+ utils.create_edges(sentence)
53
+ self.no_edge_counter += N * (N - 1)
54
+
55
+ sentence["anchor edges"] = [N, len(sentence["input"]), []]
56
+ sentence["source anchor edges"] = [N, len(sentence["input"]), []]
57
+ sentence["target anchor edges"] = [N, len(sentence["input"]), []]
58
+
59
+ sentence["anchored labels"] = [len(sentence["input"]), []]
60
+ for i, node in enumerate(sentence["nodes"]):
61
+ anchored_labels = []
62
+
63
+ for anchor in node["anchors"]:
64
+ sentence["anchor edges"][-1].append((i, anchor))
65
+ anchored_labels.append((anchor, node["label"]))
66
+
67
+ for anchor in node["source anchors"]:
68
+ sentence["source anchor edges"][-1].append((i, anchor))
69
+ for anchor in node["target anchors"]:
70
+ sentence["target anchor edges"][-1].append((i, anchor))
71
+
72
+ sentence["anchored labels"][1].append(anchored_labels)
73
+
74
+ anchor_count += len(node["anchors"])
75
+ source_anchor_count += len(node["source anchors"])
76
+ target_anchor_count += len(node["target anchors"])
77
+ n_node_token_pairs += len(sentence["input"])
78
+
79
+ sentence["id"] = [sentence["id"]]
80
+
81
+ self.anchor_freq = anchor_count / n_node_token_pairs
82
+ self.source_anchor_freq = anchor_count / n_node_token_pairs
83
+ self.target_anchor_freq = anchor_count / n_node_token_pairs
84
+ self.input_count = sum(len(sentence["input"]) for sentence in self.data.values())
85
+
86
+ super(SequentialParser, self).__init__(fields, self.data, filter_pred)
87
+
88
+ @staticmethod
89
+ def node_similarity_key(node):
90
+ return tuple([node["label"]] + node["anchors"])
data/parser/json_parser.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import reduce
2
+ from data.field.mini_torchtext.example import Example
3
+
4
+
5
+ def example_from_json(obj, fields):
6
+ ex = Example()
7
+ for key, vals in fields.items():
8
+ if vals is not None:
9
+ if not isinstance(vals, list):
10
+ vals = [vals]
11
+ for val in vals:
12
+ # for processing the key likes 'foo.bar'
13
+ name, field = val
14
+ ks = key.split(".")
15
+
16
+ def reducer(obj, key):
17
+ if isinstance(obj, list):
18
+ results = []
19
+ for data in obj:
20
+ if key not in data:
21
+ # key error
22
+ raise ValueError("Specified key {} was not found in " "the input data".format(key))
23
+ else:
24
+ results.append(data[key])
25
+ return results
26
+ else:
27
+ # key error
28
+ if key not in obj:
29
+ raise ValueError("Specified key {} was not found in " "the input data".format(key))
30
+ else:
31
+ return obj[key]
32
+
33
+ v = reduce(reducer, ks, obj)
34
+ setattr(ex, name, field.preprocess(v))
35
+ return ex
data/parser/to_mrp/__init__.py ADDED
File without changes
data/parser/to_mrp/abstract_parser.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ class AbstractParser:
5
+ def __init__(self, dataset):
6
+ self.dataset = dataset
7
+
8
+ def create_nodes(self, prediction):
9
+ return [
10
+ {"id": i, "label": self.label_to_str(l, prediction["anchors"][i], prediction)}
11
+ for i, l in enumerate(prediction["labels"])
12
+ ]
13
+
14
+ def label_to_str(self, label, anchors, prediction):
15
+ return self.dataset.label_field.vocab.itos[label - 1]
16
+
17
+ def create_edges(self, prediction, nodes):
18
+ N = len(nodes)
19
+ node_sets = [{"id": n, "set": set([n])} for n in range(N)]
20
+ _, indices = prediction["edge presence"][:N, :N].reshape(-1).sort(descending=True)
21
+ sources, targets = indices // N, indices % N
22
+
23
+ edges = []
24
+ for i in range((N - 1) * N // 2):
25
+ source, target = sources[i].item(), targets[i].item()
26
+ p = prediction["edge presence"][source, target]
27
+
28
+ if p < 0.5 and len(edges) >= N - 1:
29
+ break
30
+
31
+ if node_sets[source]["set"] is node_sets[target]["set"] and p < 0.5:
32
+ continue
33
+
34
+ self.create_edge(source, target, prediction, edges, nodes)
35
+
36
+ if node_sets[source]["set"] is not node_sets[target]["set"]:
37
+ from_set = node_sets[source]["set"]
38
+ for n in node_sets[target]["set"]:
39
+ from_set.add(n)
40
+ node_sets[n]["set"] = from_set
41
+
42
+ return edges
43
+
44
+ def create_edge(self, source, target, prediction, edges, nodes):
45
+ label = self.get_edge_label(prediction, source, target)
46
+ edge = {"source": source, "target": target, "label": label}
47
+
48
+ edges.append(edge)
49
+
50
+ def create_anchors(self, prediction, nodes, join_contiguous=True, at_least_one=False, single_anchor=False, mode="anchors"):
51
+ for i, node in enumerate(nodes):
52
+ threshold = 0.5 if not at_least_one else min(0.5, prediction[mode][i].max().item())
53
+ node[mode] = (prediction[mode][i] >= threshold).nonzero(as_tuple=False).squeeze(-1)
54
+ node[mode] = prediction["token intervals"][node[mode], :]
55
+
56
+ if single_anchor and len(node[mode]) > 1:
57
+ start = min(a[0].item() for a in node[mode])
58
+ end = max(a[1].item() for a in node[mode])
59
+ node[mode] = [{"from": start, "to": end}]
60
+ continue
61
+
62
+ node[mode] = [{"from": f.item(), "to": t.item()} for f, t in node[mode]]
63
+ node[mode] = sorted(node[mode], key=lambda a: a["from"])
64
+
65
+ if join_contiguous and len(node[mode]) > 1:
66
+ cleaned_anchors = []
67
+ end, start = node[mode][0]["from"], node[mode][0]["from"]
68
+ for anchor in node[mode]:
69
+ if end < anchor["from"]:
70
+ cleaned_anchors.append({"from": start, "to": end})
71
+ start = anchor["from"]
72
+ end = anchor["to"]
73
+ cleaned_anchors.append({"from": start, "to": end})
74
+
75
+ node[mode] = cleaned_anchors
76
+
77
+ return nodes
78
+
79
+ def get_edge_label(self, prediction, source, target):
80
+ return self.dataset.edge_label_field.vocab.itos[prediction["edge labels"][source, target].item()]
data/parser/to_mrp/labeled_edge_parser.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ from data.parser.to_mrp.abstract_parser import AbstractParser
5
+
6
+
7
+ class LabeledEdgeParser(AbstractParser):
8
+ def __init__(self, *args):
9
+ super().__init__(*args)
10
+ self.source_id = self.dataset.edge_label_field.vocab.stoi["Source"]
11
+ self.target_id = self.dataset.edge_label_field.vocab.stoi["Target"]
12
+
13
+ def parse(self, prediction):
14
+ output = {}
15
+
16
+ output["id"] = self.dataset.id_field.vocab.itos[prediction["id"].item()]
17
+ output["nodes"] = self.create_nodes(prediction)
18
+ output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=True)
19
+ output["nodes"] = [{"id": 0}] + output["nodes"]
20
+ output["edges"] = self.create_edges(prediction, output["nodes"])
21
+
22
+ return output
23
+
24
+ def create_nodes(self, prediction):
25
+ return [{"id": i + 1} for i, l in enumerate(prediction["labels"])]
26
+
27
+ def create_edges(self, prediction, nodes):
28
+ N = len(nodes)
29
+ edge_prediction = prediction["edge presence"][:N, :N]
30
+
31
+ edges = []
32
+ for target in range(1, N):
33
+ if edge_prediction[0, target] >= 0.5:
34
+ prediction["edge labels"][0, target, self.source_id] = float("-inf")
35
+ prediction["edge labels"][0, target, self.target_id] = float("-inf")
36
+ self.create_edge(0, target, prediction, edges, nodes)
37
+
38
+ for source in range(1, N):
39
+ for target in range(1, N):
40
+ if source == target:
41
+ continue
42
+ if edge_prediction[source, target] < 0.5:
43
+ continue
44
+ for i in range(prediction["edge labels"].size(2)):
45
+ if i not in [self.source_id, self.target_id]:
46
+ prediction["edge labels"][source, target, i] = float("-inf")
47
+ self.create_edge(source, target, prediction, edges, nodes)
48
+
49
+ return edges
50
+
51
+ def get_edge_label(self, prediction, source, target):
52
+ return self.dataset.edge_label_field.vocab.itos[prediction["edge labels"][source, target].argmax(-1).item()]
data/parser/to_mrp/node_centric_parser.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ from data.parser.to_mrp.abstract_parser import AbstractParser
5
+
6
+
7
+ class NodeCentricParser(AbstractParser):
8
+ def parse(self, prediction):
9
+ output = {}
10
+
11
+ output["id"] = self.dataset.id_field.vocab.itos[prediction["id"].item()]
12
+ output["nodes"] = self.create_nodes(prediction)
13
+ output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=True)
14
+ output["edges"] = self.create_edges(prediction, output["nodes"])
15
+
16
+ return output
17
+
18
+ def create_edge(self, source, target, prediction, edges, nodes):
19
+ edge = {"source": source, "target": target, "label": None}
20
+ edges.append(edge)
21
+
22
+ def create_edges(self, prediction, nodes):
23
+ N = len(nodes)
24
+ edge_prediction = prediction["edge presence"][:N, :N]
25
+
26
+ targets = [i for i, node in enumerate(nodes) if node["label"] in ["Source", "Target"]]
27
+ sources = [i for i, node in enumerate(nodes) if node["label"] not in ["Source", "Target"]]
28
+
29
+ edges = []
30
+ for target in targets:
31
+ for source in sources:
32
+ if edge_prediction[source, target] >= 0.5:
33
+ self.create_edge(source, target, prediction, edges, nodes)
34
+
35
+ return edges
data/parser/to_mrp/sequential_parser.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ from data.parser.to_mrp.abstract_parser import AbstractParser
5
+
6
+
7
+ class SequentialParser(AbstractParser):
8
+ def parse(self, prediction):
9
+ output = {}
10
+
11
+ output["id"] = self.dataset.id_field.vocab.itos[prediction["id"].item()]
12
+ output["nodes"] = self.create_nodes(prediction)
13
+ output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=True, mode="anchors")
14
+ output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=False, mode="source anchors")
15
+ output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=False, mode="target anchors")
16
+ output["edges"], output["nodes"] = self.create_targets_sources(output["nodes"])
17
+
18
+ return output
19
+
20
+ def create_targets_sources(self, nodes):
21
+ edges, new_nodes = [], []
22
+ for i, node in enumerate(nodes):
23
+ new_node_id = len(nodes) + len(new_nodes)
24
+ if len(node["source anchors"]) > 0:
25
+ new_nodes.append({"id": new_node_id, "label": "Source", "anchors": node["source anchors"]})
26
+ edges.append({"source": i, "target": new_node_id, "label": ""})
27
+ new_node_id += 1
28
+ del node["source anchors"]
29
+
30
+ if len(node["target anchors"]) > 0:
31
+ new_nodes.append({"id": new_node_id, "label": "Target", "anchors": node["target anchors"]})
32
+ edges.append({"source": i, "target": new_node_id, "label": ""})
33
+ del node["target anchors"]
34
+
35
+ return edges, nodes + new_nodes
model/__init__.py ADDED
File without changes
model/head/__init__.py ADDED
File without changes
model/head/abstract_head.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from model.module.edge_classifier import EdgeClassifier
10
+ from model.module.anchor_classifier import AnchorClassifier
11
+ from utility.cross_entropy import cross_entropy, binary_cross_entropy
12
+ from utility.hungarian_matching import get_matching, reorder, match_anchor, match_label
13
+ from utility.utils import create_padding_mask
14
+
15
+
16
+ class AbstractHead(nn.Module):
17
+ def __init__(self, dataset, args, config, initialize: bool):
18
+ super(AbstractHead, self).__init__()
19
+
20
+ self.edge_classifier = self.init_edge_classifier(dataset, args, config, initialize)
21
+ self.label_classifier = self.init_label_classifier(dataset, args, config, initialize)
22
+ self.anchor_classifier = self.init_anchor_classifier(dataset, args, config, initialize, mode="anchor")
23
+ self.source_anchor_classifier = self.init_anchor_classifier(dataset, args, config, initialize, mode="source_anchor")
24
+ self.target_anchor_classifier = self.init_anchor_classifier(dataset, args, config, initialize, mode="target_anchor")
25
+
26
+ self.query_length = args.query_length
27
+ self.focal = args.focal
28
+ self.dataset = dataset
29
+
30
+ def forward(self, encoder_output, decoder_output, encoder_mask, decoder_mask, batch):
31
+ output = {}
32
+
33
+ decoder_lens = self.query_length * batch["every_input"][1]
34
+ output["label"] = self.forward_label(decoder_output)
35
+ output["anchor"] = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="anchor") # shape: (B, T_l, T_w)
36
+ output["source_anchor"] = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="source_anchor") # shape: (B, T_l, T_w)
37
+ output["target_anchor"] = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="target_anchor") # shape: (B, T_l, T_w)
38
+
39
+ cost_matrices = self.create_cost_matrices(output, batch, decoder_lens)
40
+ matching = get_matching(cost_matrices)
41
+ decoder_output = reorder(decoder_output, matching, batch["labels"][0].size(1))
42
+ output["edge presence"], output["edge label"] = self.forward_edge(decoder_output)
43
+
44
+ return self.loss(output, batch, matching, decoder_mask)
45
+
46
+ def predict(self, encoder_output, decoder_output, encoder_mask, decoder_mask, batch, **kwargs):
47
+ every_input, word_lens = batch["every_input"]
48
+ decoder_lens = self.query_length * word_lens
49
+ batch_size = every_input.size(0)
50
+
51
+ label_pred = self.forward_label(decoder_output)
52
+ anchor_pred = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="anchor") # shape: (B, T_l, T_w)
53
+ source_anchor_pred = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="source_anchor") # shape: (B, T_l, T_w)
54
+ target_anchor_pred = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="target_anchor") # shape: (B, T_l, T_w)
55
+
56
+ labels = [[] for _ in range(batch_size)]
57
+ anchors, source_anchors, target_anchors = [[] for _ in range(batch_size)], [[] for _ in range(batch_size)], [[] for _ in range(batch_size)]
58
+
59
+ for b in range(batch_size):
60
+ label_indices = self.inference_label(label_pred[b, :decoder_lens[b], :]).cpu()
61
+ for t in range(label_indices.size(0)):
62
+ label_index = label_indices[t].item()
63
+ if label_index == 0:
64
+ continue
65
+
66
+ decoder_output[b, len(labels[b]), :] = decoder_output[b, t, :]
67
+
68
+ labels[b].append(label_index)
69
+ if anchor_pred is None:
70
+ anchors[b].append(list(range(t // self.query_length, word_lens[b])))
71
+ else:
72
+ anchors[b].append(self.inference_anchor(anchor_pred[b, t, :word_lens[b]]).cpu())
73
+
74
+ if source_anchor_pred is None:
75
+ source_anchors[b].append(list(range(t // self.query_length, word_lens[b])))
76
+ else:
77
+ source_anchors[b].append(self.inference_anchor(source_anchor_pred[b, t, :word_lens[b]]).cpu())
78
+
79
+ if target_anchor_pred is None:
80
+ target_anchors[b].append(list(range(t // self.query_length, word_lens[b])))
81
+ else:
82
+ target_anchors[b].append(self.inference_anchor(target_anchor_pred[b, t, :word_lens[b]]).cpu())
83
+
84
+ decoder_output = decoder_output[:, : max(len(l) for l in labels), :]
85
+ edge_presence, edge_labels = self.forward_edge(decoder_output)
86
+
87
+ outputs = [
88
+ self.parser.parse(
89
+ {
90
+ "labels": labels[b],
91
+ "anchors": anchors[b],
92
+ "source anchors": source_anchors[b],
93
+ "target anchors": target_anchors[b],
94
+ "edge presence": self.inference_edge_presence(edge_presence, b),
95
+ "edge labels": self.inference_edge_label(edge_labels, b),
96
+ "id": batch["id"][b].cpu(),
97
+ "tokens": batch["every_input"][0][b, : word_lens[b]].cpu(),
98
+ "token intervals": batch["token_intervals"][b, :, :].cpu(),
99
+ },
100
+ **kwargs
101
+ )
102
+ for b in range(batch_size)
103
+ ]
104
+
105
+ return outputs
106
+
107
+ def loss(self, output, batch, matching, decoder_mask):
108
+ batch_size = batch["every_input"][0].size(0)
109
+ device = batch["every_input"][0].device
110
+ T_label = batch["labels"][0].size(1)
111
+ T_input = batch["every_input"][0].size(1)
112
+ T_edge = batch["edge_presence"].size(1)
113
+
114
+ input_mask = create_padding_mask(batch_size, T_input, batch["every_input"][1], device) # shape: (B, T_input)
115
+ label_mask = create_padding_mask(batch_size, T_label, batch["labels"][1], device) # shape: (B, T_label)
116
+ edge_mask = torch.eye(T_label, T_label, device=device, dtype=torch.bool).unsqueeze(0) # shape: (1, T_label, T_label)
117
+ edge_mask = edge_mask | label_mask.unsqueeze(1) | label_mask.unsqueeze(2) # shape: (B, T_label, T_label)
118
+ if T_edge != T_label:
119
+ edge_mask = F.pad(edge_mask, (T_edge - T_label, 0, T_edge - T_label, 0), value=0)
120
+ edge_label_mask = (batch["edge_presence"] == 0) | edge_mask
121
+
122
+ if output["edge label"] is not None:
123
+ batch["edge_labels"] = (
124
+ batch["edge_labels"][0][:, :, :, :output["edge label"].size(-1)],
125
+ batch["edge_labels"][1],
126
+ )
127
+
128
+ losses = {}
129
+ losses.update(self.loss_label(output, batch, decoder_mask, matching))
130
+ losses.update(self.loss_anchor(output, batch, input_mask, matching, mode="anchor"))
131
+ losses.update(self.loss_anchor(output, batch, input_mask, matching, mode="source_anchor"))
132
+ losses.update(self.loss_anchor(output, batch, input_mask, matching, mode="target_anchor"))
133
+ losses.update(self.loss_edge_presence(output, batch, edge_mask))
134
+ losses.update(self.loss_edge_label(output, batch, edge_label_mask.unsqueeze(-1)))
135
+
136
+ stats = {f"{key}": value.detach().cpu().item() for key, value in losses.items()}
137
+ total_loss = sum(losses.values()) / len(losses)
138
+
139
+ return total_loss, stats
140
+
141
+ @torch.no_grad()
142
+ def create_cost_matrices(self, output, batch, decoder_lens):
143
+ batch_size = len(batch["labels"][1])
144
+ decoder_lens = decoder_lens.cpu()
145
+
146
+ matrices = []
147
+ for b in range(batch_size):
148
+ label_cost_matrix = self.label_cost_matrix(output, batch, decoder_lens, b)
149
+ anchor_cost_matrix = self.anchor_cost_matrix(output, batch, decoder_lens, b)
150
+
151
+ cost_matrix = label_cost_matrix * anchor_cost_matrix
152
+ matrices.append(cost_matrix.cpu())
153
+
154
+ return matrices
155
+
156
+ def init_edge_classifier(self, dataset, args, config, initialize: bool):
157
+ if not config["edge presence"] and not config["edge label"]:
158
+ return None
159
+ return EdgeClassifier(dataset, args, initialize, presence=config["edge presence"], label=config["edge label"])
160
+
161
+ def init_label_classifier(self, dataset, args, config, initialize: bool):
162
+ if not config["label"]:
163
+ return None
164
+
165
+ classifier = nn.Sequential(
166
+ nn.Dropout(args.dropout_label),
167
+ nn.Linear(args.hidden_size, len(dataset.label_field.vocab) + 1, bias=True)
168
+ )
169
+ if initialize:
170
+ classifier[1].bias.data = dataset.label_freqs.log()
171
+
172
+ return classifier
173
+
174
+ def init_anchor_classifier(self, dataset, args, config, initialize: bool, mode="anchor"):
175
+ if not config[mode]:
176
+ return None
177
+
178
+ return AnchorClassifier(dataset, args, initialize, mode=mode)
179
+
180
+ def forward_edge(self, decoder_output):
181
+ if self.edge_classifier is None:
182
+ return None, None
183
+ return self.edge_classifier(decoder_output)
184
+
185
+ def forward_label(self, decoder_output):
186
+ if self.label_classifier is None:
187
+ return None
188
+ return torch.log_softmax(self.label_classifier(decoder_output), dim=-1)
189
+
190
+ def forward_anchor(self, decoder_output, encoder_output, encoder_mask, mode="anchor"):
191
+ classifier = getattr(self, f"{mode}_classifier")
192
+ if classifier is None:
193
+ return None
194
+ return classifier(decoder_output, encoder_output, encoder_mask)
195
+
196
+ def inference_label(self, prediction):
197
+ prediction = prediction.exp()
198
+ return torch.where(
199
+ prediction[:, 0] > prediction[:, 1:].sum(-1),
200
+ torch.zeros(prediction.size(0), dtype=torch.long, device=prediction.device),
201
+ prediction[:, 1:].argmax(dim=-1) + 1
202
+ )
203
+
204
+ def inference_anchor(self, prediction):
205
+ return prediction.sigmoid()
206
+
207
+ def inference_edge_presence(self, prediction, example_index: int):
208
+ if prediction is None:
209
+ return None
210
+
211
+ N = prediction.size(1)
212
+ mask = torch.eye(N, N, device=prediction.device, dtype=torch.bool)
213
+ return prediction[example_index, :, :].sigmoid().masked_fill(mask, 0.0).cpu()
214
+
215
+ def inference_edge_label(self, prediction, example_index: int):
216
+ if prediction is None:
217
+ return None
218
+ return prediction[example_index, :, :, :].cpu()
219
+
220
+ def loss_edge_presence(self, prediction, target, mask):
221
+ if self.edge_classifier is None or prediction["edge presence"] is None:
222
+ return {}
223
+ return {"edge presence": binary_cross_entropy(prediction["edge presence"], target["edge_presence"].float(), mask)}
224
+
225
+ def loss_edge_label(self, prediction, target, mask):
226
+ if self.edge_classifier is None or prediction["edge label"] is None:
227
+ return {}
228
+ return {"edge label": binary_cross_entropy(prediction["edge label"], target["edge_labels"][0].float(), mask)}
229
+
230
+ def loss_label(self, prediction, target, mask, matching):
231
+ if self.label_classifier is None or prediction["label"] is None:
232
+ return {}
233
+
234
+ prediction = prediction["label"]
235
+ target = match_label(
236
+ target["labels"][0], matching, prediction.shape[:-1], prediction.device, self.query_length
237
+ )
238
+ return {"label": cross_entropy(prediction, target, mask, focal=self.focal)}
239
+
240
+ def loss_anchor(self, prediction, target, mask, matching, mode="anchor"):
241
+ if getattr(self, f"{mode}_classifier") is None or prediction[mode] is None:
242
+ return {}
243
+
244
+ prediction = prediction[mode]
245
+ target, anchor_mask = match_anchor(target[mode], matching, prediction.shape, prediction.device)
246
+ mask = anchor_mask.unsqueeze(-1) | mask.unsqueeze(-2)
247
+ return {mode: binary_cross_entropy(prediction, target.float(), mask)}
248
+
249
+ def label_cost_matrix(self, output, batch, decoder_lens, b: int):
250
+ if output["label"] is None:
251
+ return 1.0
252
+
253
+ target_labels = batch["anchored_labels"][b] # shape: (num_nodes, num_inputs, num_classes)
254
+ label_prob = output["label"][b, : decoder_lens[b], :].exp().unsqueeze(0) # shape: (1, num_queries, num_classes)
255
+ tgt_label = target_labels.repeat_interleave(self.query_length, dim=1) # shape: (num_nodes, num_queries, num_classes)
256
+ cost_matrix = ((tgt_label * label_prob).sum(-1) * label_prob[:, :, 1:].sum(-1)).t().sqrt() # shape: (num_queries, num_nodes)
257
+
258
+ return cost_matrix
259
+
260
+ def anchor_cost_matrix(self, output, batch, decoder_lens, b: int):
261
+ if output["anchor"] is None:
262
+ return 1.0
263
+
264
+ num_nodes = batch["labels"][1][b]
265
+ word_lens = batch["every_input"][1]
266
+ target_anchors, _ = batch["anchor"]
267
+ pred_anchors = output["anchor"].sigmoid()
268
+
269
+ tgt_align = target_anchors[b, : num_nodes, : word_lens[b]] # shape: (num_nodes, num_inputs)
270
+ align_prob = pred_anchors[b, : decoder_lens[b], : word_lens[b]] # shape: (num_queries, num_inputs)
271
+ align_prob = align_prob.unsqueeze(1).expand(-1, num_nodes, -1) # shape: (num_queries, num_nodes, num_inputs)
272
+ align_prob = torch.where(tgt_align.unsqueeze(0).bool(), align_prob, 1.0 - align_prob) # shape: (num_queries, num_nodes, num_inputs)
273
+ cost_matrix = align_prob.log().mean(-1).exp() # shape: (num_queries, num_nodes)
274
+ return cost_matrix
model/head/labeled_edge_head.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from model.head.abstract_head import AbstractHead
8
+ from data.parser.to_mrp.labeled_edge_parser import LabeledEdgeParser
9
+ from utility.cross_entropy import binary_cross_entropy
10
+ from utility.hungarian_matching import match_label
11
+
12
+
13
+ class LabeledEdgeHead(AbstractHead):
14
+ def __init__(self, dataset, args, initialize):
15
+ config = {
16
+ "label": True,
17
+ "edge presence": True,
18
+ "edge label": True,
19
+ "anchor": True,
20
+ "source_anchor": False,
21
+ "target_anchor": False
22
+ }
23
+ super(LabeledEdgeHead, self).__init__(dataset, args, config, initialize)
24
+
25
+ self.top_node = nn.Parameter(torch.randn(1, 1, args.hidden_size), requires_grad=True)
26
+ self.parser = LabeledEdgeParser(dataset)
27
+
28
+ def init_label_classifier(self, dataset, args, config, initialize: bool):
29
+ classifier = nn.Sequential(
30
+ nn.Dropout(args.dropout_label),
31
+ nn.Linear(args.hidden_size, 1, bias=True)
32
+ )
33
+ if initialize:
34
+ bias_init = torch.tensor([dataset.label_freqs[1]])
35
+ classifier[1].bias.data = (bias_init / (1.0 - bias_init)).log()
36
+
37
+ return classifier
38
+
39
+ def forward_label(self, decoder_output):
40
+ return self.label_classifier(decoder_output)
41
+
42
+ def forward_edge(self, decoder_output):
43
+ top_node = self.top_node.expand(decoder_output.size(0), -1, -1)
44
+ decoder_output = torch.cat([top_node, decoder_output], dim=1)
45
+ return self.edge_classifier(decoder_output)
46
+
47
+ def loss_label(self, prediction, target, mask, matching):
48
+ prediction = prediction["label"]
49
+ target = match_label(
50
+ target["labels"][0], matching, prediction.shape[:-1], prediction.device, self.query_length
51
+ )
52
+ return {"label": binary_cross_entropy(prediction.squeeze(-1), target.float(), mask, focal=self.focal)}
53
+
54
+ def inference_label(self, prediction):
55
+ return (prediction.squeeze(-1) > 0.0).long()
56
+
57
+ def label_cost_matrix(self, output, batch, decoder_lens, b: int):
58
+ if output["label"] is None:
59
+ return 1.0
60
+
61
+ target_labels = batch["anchored_labels"][b] # shape: (num_nodes, num_inputs, 2)
62
+ label_prob = output["label"][b, : decoder_lens[b], :].sigmoid().unsqueeze(0) # shape: (1, num_queries, 1)
63
+ label_prob = torch.cat([1.0 - label_prob, label_prob], dim=-1) # shape: (1, num_queries, 2)
64
+ tgt_label = target_labels.repeat_interleave(self.query_length, dim=1) # shape: (num_nodes, num_queries, 2)
65
+ cost_matrix = ((tgt_label * label_prob).sum(-1) * label_prob[:, :, 1:].sum(-1)).t().sqrt() # shape: (num_queries, num_nodes)
66
+
67
+ return cost_matrix
model/head/node_centric_head.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+
6
+ from model.head.abstract_head import AbstractHead
7
+ from data.parser.to_mrp.node_centric_parser import NodeCentricParser
8
+ from utility.cross_entropy import binary_cross_entropy
9
+
10
+
11
+ class NodeCentricHead(AbstractHead):
12
+ def __init__(self, dataset, args, initialize):
13
+ config = {
14
+ "label": True,
15
+ "edge presence": True,
16
+ "edge label": False,
17
+ "anchor": True,
18
+ "source_anchor": False,
19
+ "target_anchor": False
20
+ }
21
+ super(NodeCentricHead, self).__init__(dataset, args, config, initialize)
22
+
23
+ self.source_id = dataset.label_field.vocab.stoi["Source"] + 1
24
+ self.target_id = dataset.label_field.vocab.stoi["Target"] + 1
25
+ self.parser = NodeCentricParser(dataset)
model/head/sequential_head.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from model.head.abstract_head import AbstractHead
9
+ from data.parser.to_mrp.sequential_parser import SequentialParser
10
+ from utility.cross_entropy import cross_entropy
11
+
12
+
13
+ class SequentialHead(AbstractHead):
14
+ def __init__(self, dataset, args, initialize):
15
+ config = {
16
+ "label": True,
17
+ "edge presence": False,
18
+ "edge label": False,
19
+ "anchor": True,
20
+ "source_anchor": True,
21
+ "target_anchor": True
22
+ }
23
+ super(SequentialHead, self).__init__(dataset, args, config, initialize)
24
+ self.parser = SequentialParser(dataset)
model/model.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from model.module.encoder import Encoder
8
+
9
+ from model.module.transformer import Decoder
10
+ from model.head.node_centric_head import NodeCentricHead
11
+ from model.head.labeled_edge_head import LabeledEdgeHead
12
+ from model.head.sequential_head import SequentialHead
13
+ from utility.utils import create_padding_mask
14
+
15
+
16
+ class Model(nn.Module):
17
+ def __init__(self, dataset, args, initialize=True):
18
+ super(Model, self).__init__()
19
+ self.encoder = Encoder(args, dataset)
20
+ if args.n_layers > 0:
21
+ self.decoder = Decoder(args)
22
+ else:
23
+ self.decoder = lambda x, *args: x # identity function, which ignores all arguments except the first one
24
+
25
+ if args.graph_mode == "sequential":
26
+ self.head = SequentialHead(dataset, args, initialize)
27
+ elif args.graph_mode == "node-centric":
28
+ self.head = NodeCentricHead(dataset, args, initialize)
29
+ elif args.graph_mode == "labeled-edge":
30
+ self.head = LabeledEdgeHead(dataset, args, initialize)
31
+
32
+ self.query_length = args.query_length
33
+ self.dataset = dataset
34
+ self.args = args
35
+
36
+ def forward(self, batch, inference=False, **kwargs):
37
+ every_input, word_lens = batch["every_input"]
38
+ decoder_lens = self.query_length * word_lens
39
+ batch_size, input_len = every_input.size(0), every_input.size(1)
40
+ device = every_input.device
41
+
42
+ encoder_mask = create_padding_mask(batch_size, input_len, word_lens, device)
43
+ decoder_mask = create_padding_mask(batch_size, self.query_length * input_len, decoder_lens, device)
44
+
45
+ encoder_output, decoder_input = self.encoder(batch["input"], batch["char_form_input"], batch["input_scatter"], input_len)
46
+
47
+ decoder_output = self.decoder(decoder_input, encoder_output, decoder_mask, encoder_mask)
48
+
49
+ if inference:
50
+ return self.head.predict(encoder_output, decoder_output, encoder_mask, decoder_mask, batch)
51
+ else:
52
+ return self.head(encoder_output, decoder_output, encoder_mask, decoder_mask, batch)
53
+
54
+ def get_params_for_optimizer(self, args):
55
+ encoder_decay, encoder_no_decay = self.get_encoder_parameters(args.n_encoder_layers)
56
+ decoder_decay, decoder_no_decay = self.get_decoder_parameters()
57
+
58
+ parameters = [{"params": p, "weight_decay": args.encoder_weight_decay} for p in encoder_decay]
59
+ parameters += [{"params": p, "weight_decay": 0.0} for p in encoder_no_decay]
60
+ parameters += [
61
+ {"params": decoder_decay, "weight_decay": args.decoder_weight_decay},
62
+ {"params": decoder_no_decay, "weight_decay": 0.0},
63
+ ]
64
+ return parameters
65
+
66
+ def get_decoder_parameters(self):
67
+ no_decay = ["bias", "LayerNorm.weight", "_norm.weight"]
68
+ decay_params = (p for name, p in self.named_parameters() if not any(nd in name for nd in no_decay) and not name.startswith("encoder.bert") and p.requires_grad)
69
+ no_decay_params = (p for name, p in self.named_parameters() if any(nd in name for nd in no_decay) and not name.startswith("encoder.bert") and p.requires_grad)
70
+
71
+ return decay_params, no_decay_params
72
+
73
+ def get_encoder_parameters(self, n_layers):
74
+ no_decay = ["bias", "LayerNorm.weight", "_norm.weight"]
75
+ decay_params = [
76
+ [p for name, p in self.named_parameters() if not any(nd in name for nd in no_decay) and name.startswith(f"encoder.bert.encoder.layer.{n_layers - 1 - i}.") and p.requires_grad] for i in range(n_layers)
77
+ ]
78
+ no_decay_params = [
79
+ [p for name, p in self.named_parameters() if any(nd in name for nd in no_decay) and name.startswith(f"encoder.bert.encoder.layer.{n_layers - 1 - i}.") and p.requires_grad] for i in range(n_layers)
80
+ ]
81
+
82
+ return decay_params, no_decay_params
model/module/__init__.py ADDED
File without changes
model/module/anchor_classifier.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from model.module.biaffine import Biaffine
8
+
9
+
10
+ class AnchorClassifier(nn.Module):
11
+ def __init__(self, dataset, args, initialize: bool, bias=True, mode="anchor"):
12
+ super(AnchorClassifier, self).__init__()
13
+
14
+ self.token_f = nn.Linear(args.hidden_size, args.hidden_size_anchor)
15
+ self.label_f = nn.Linear(args.hidden_size, args.hidden_size_anchor)
16
+ self.dropout = nn.Dropout(args.dropout_anchor)
17
+
18
+ if bias and initialize:
19
+ bias_init = torch.tensor([getattr(dataset, f"{mode}_freq")])
20
+ bias_init = (bias_init / (1.0 - bias_init)).log()
21
+ else:
22
+ bias_init = None
23
+
24
+ self.output = Biaffine(args.hidden_size_anchor, 1, bias=bias, bias_init=bias_init)
25
+
26
+ def forward(self, label, tokens, encoder_mask):
27
+ tokens = self.dropout(F.elu(self.token_f(tokens))) # shape: (B, T_w, H)
28
+ label = self.dropout(F.elu(self.label_f(label))) # shape: (B, T_l, H)
29
+ anchor = self.output(label, tokens).squeeze(-1) # shape: (B, T_l, T_w)
30
+
31
+ anchor = anchor.masked_fill(encoder_mask.unsqueeze(1), float("-inf")) # shape: (B, T_l, T_w)
32
+ return anchor
model/module/biaffine.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch.nn as nn
5
+ from model.module.bilinear import Bilinear
6
+
7
+
8
+ class Biaffine(nn.Module):
9
+ def __init__(self, input_dim, output_dim, bias=True, bias_init=None):
10
+ super(Biaffine, self).__init__()
11
+
12
+ self.linear_1 = nn.Linear(input_dim, output_dim, bias=False)
13
+ self.linear_2 = nn.Linear(input_dim, output_dim, bias=False)
14
+
15
+ self.bilinear = Bilinear(input_dim, input_dim, output_dim, bias=bias)
16
+ if bias_init is not None:
17
+ self.bilinear.bias.data = bias_init
18
+
19
+ def forward(self, x, y):
20
+ return self.bilinear(x, y) + self.linear_1(x).unsqueeze(2) + self.linear_2(y).unsqueeze(1)
model/module/bilinear.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from https://github.com/NLPInBLCU/BiaffineDependencyParsing/blob/master/modules/biaffine.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class Bilinear(nn.Module):
8
+ """
9
+ 使用版本
10
+ A bilinear module that deals with broadcasting for efficient memory usage.
11
+ Input: tensors of sizes (N x L1 x D1) and (N x L2 x D2)
12
+ Output: tensor of size (N x L1 x L2 x O)"""
13
+
14
+ def __init__(self, input1_size, input2_size, output_size, bias=True):
15
+ super(Bilinear, self).__init__()
16
+
17
+ self.input1_size = input1_size
18
+ self.input2_size = input2_size
19
+ self.output_size = output_size
20
+
21
+ self.weight = nn.Parameter(torch.Tensor(input1_size, input2_size, output_size))
22
+ self.bias = nn.Parameter(torch.Tensor(output_size)) if bias else None
23
+
24
+ self.reset_parameters()
25
+
26
+ def reset_parameters(self):
27
+ nn.init.zeros_(self.weight)
28
+
29
+ def forward(self, input1, input2):
30
+ input1_size = list(input1.size())
31
+ input2_size = list(input2.size())
32
+
33
+ intermediate = torch.mm(input1.view(-1, input1_size[-1]), self.weight.view(-1, self.input2_size * self.output_size),)
34
+
35
+ input2 = input2.transpose(1, 2)
36
+ output = intermediate.view(input1_size[0], input1_size[1] * self.output_size, input2_size[2]).bmm(input2)
37
+
38
+ output = output.view(input1_size[0], input1_size[1], self.output_size, input2_size[1]).transpose(2, 3)
39
+
40
+ if self.bias is not None:
41
+ output = output + self.bias
42
+
43
+ return output
model/module/char_embedding.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
7
+
8
+
9
+ class CharEmbedding(nn.Module):
10
+ def __init__(self, vocab_size: int, embedding_size: int, output_size: int):
11
+ super(CharEmbedding, self).__init__()
12
+
13
+ self.embedding = nn.Embedding(vocab_size, embedding_size, sparse=False)
14
+ self.layer_norm = nn.LayerNorm(embedding_size)
15
+ self.gru = nn.GRU(embedding_size, embedding_size, num_layers=1, bidirectional=True)
16
+ self.out_linear = nn.Linear(2*embedding_size, output_size)
17
+ self.layer_norm_2 = nn.LayerNorm(output_size)
18
+
19
+ def forward(self, words, sentence_lens, word_lens):
20
+ # input shape: (B, W, C)
21
+ n_words = words.size(1)
22
+ sentence_lens = sentence_lens.cpu()
23
+ sentence_packed = pack_padded_sequence(words, sentence_lens, batch_first=True) # shape: (B*W, C)
24
+ lens_packed = pack_padded_sequence(word_lens, sentence_lens, batch_first=True) # shape: (B*W)
25
+ word_packed = pack_padded_sequence(sentence_packed.data, lens_packed.data.cpu(), batch_first=True, enforce_sorted=False) # shape: (B*W*C)
26
+
27
+ embedded = self.embedding(word_packed.data) # shape: (B*W*C, D)
28
+ embedded = self.layer_norm(embedded) # shape: (B*W*C, D)
29
+
30
+ embedded_packed = PackedSequence(embedded, word_packed[1], word_packed[2], word_packed[3])
31
+ _, embedded = self.gru(embedded_packed) # shape: (layers * 2, B*W, D)
32
+
33
+ embedded = embedded[-2:, :, :].transpose(0, 1).flatten(1, 2) # shape: (B*W, 2*D)
34
+ embedded = F.relu(embedded)
35
+ embedded = self.out_linear(embedded)
36
+ embedded = self.layer_norm_2(embedded)
37
+
38
+ embedded, _ = pad_packed_sequence(
39
+ PackedSequence(embedded, sentence_packed[1], sentence_packed[2], sentence_packed[3]), batch_first=True, total_length=n_words,
40
+ ) # shape: (B, W, 2*D)
41
+
42
+ return embedded # shape: (B, W, 2*D)
model/module/edge_classifier.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from model.module.biaffine import Biaffine
8
+
9
+
10
+ class EdgeClassifier(nn.Module):
11
+ def __init__(self, dataset, args, initialize: bool, presence: bool, label: bool):
12
+ super(EdgeClassifier, self).__init__()
13
+
14
+ self.presence = presence
15
+ if self.presence:
16
+ if initialize:
17
+ presence_init = torch.tensor([dataset.edge_presence_freq])
18
+ presence_init = (presence_init / (1.0 - presence_init)).log()
19
+ else:
20
+ presence_init = None
21
+
22
+ self.edge_presence = EdgeBiaffine(
23
+ args.hidden_size, args.hidden_size_edge_presence, 1, args.dropout_edge_presence, bias_init=presence_init
24
+ )
25
+
26
+ self.label = label
27
+ if self.label:
28
+ label_init = (dataset.edge_label_freqs / (1.0 - dataset.edge_label_freqs)).log() if initialize else None
29
+ n_labels = len(dataset.edge_label_field.vocab)
30
+ self.edge_label = EdgeBiaffine(
31
+ args.hidden_size, args.hidden_size_edge_label, n_labels, args.dropout_edge_label, bias_init=label_init
32
+ )
33
+
34
+ def forward(self, x):
35
+ presence, label = None, None
36
+
37
+ if self.presence:
38
+ presence = self.edge_presence(x).squeeze(-1) # shape: (B, T, T)
39
+ if self.label:
40
+ label = self.edge_label(x) # shape: (B, T, T, O_1)
41
+
42
+ return presence, label
43
+
44
+
45
+ class EdgeBiaffine(nn.Module):
46
+ def __init__(self, hidden_dim, bottleneck_dim, output_dim, dropout, bias_init=None):
47
+ super(EdgeBiaffine, self).__init__()
48
+ self.hidden = nn.Linear(hidden_dim, 2 * bottleneck_dim)
49
+ self.output = Biaffine(bottleneck_dim, output_dim, bias_init=bias_init)
50
+ self.dropout = nn.Dropout(dropout)
51
+
52
+ def forward(self, x):
53
+ x = self.dropout(F.elu(self.hidden(x))) # shape: (B, T, 2H)
54
+ predecessors, current = x.chunk(2, dim=-1) # shape: (B, T, H), (B, T, H)
55
+ edge = self.output(current, predecessors) # shape: (B, T, T, O)
56
+ return edge
model/module/encoder.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from transformers import AutoModel
11
+ from model.module.char_embedding import CharEmbedding
12
+
13
+
14
+ class WordDropout(nn.Dropout):
15
+ def forward(self, input_tensor):
16
+ if self.p == 0:
17
+ return input_tensor
18
+
19
+ ones = input_tensor.new_ones(input_tensor.shape[:-1])
20
+ dropout_mask = torch.nn.functional.dropout(ones, self.p, self.training, inplace=False)
21
+
22
+ return dropout_mask.unsqueeze(-1) * input_tensor
23
+
24
+
25
+ class Encoder(nn.Module):
26
+ def __init__(self, args, dataset):
27
+ super(Encoder, self).__init__()
28
+
29
+ self.dim = args.hidden_size
30
+ self.n_layers = args.n_encoder_layers
31
+ self.width_factor = args.query_length
32
+
33
+ self.bert = AutoModel.from_pretrained(args.encoder, add_pooling_layer=False)
34
+ # self.bert._set_gradient_checkpointing(self.bert.encoder, value=True)
35
+ if args.encoder_freeze_embedding:
36
+ self.bert.embeddings.requires_grad_(False)
37
+ self.bert.embeddings.LayerNorm.requires_grad_(True)
38
+
39
+ if args.freeze_bert:
40
+ self.bert.requires_grad_(False)
41
+
42
+ self.use_char_embedding = args.char_embedding
43
+ if self.use_char_embedding:
44
+ self.form_char_embedding = CharEmbedding(dataset.char_form_vocab_size, args.char_embedding_size, self.dim)
45
+ self.word_dropout = WordDropout(args.dropout_word)
46
+
47
+ self.post_layer_norm = nn.LayerNorm(self.dim)
48
+ self.subword_attention = nn.Linear(self.dim, 1)
49
+
50
+ if self.width_factor > 1:
51
+ self.query_generator = nn.Linear(self.dim, self.dim * self.width_factor)
52
+ else:
53
+ self.query_generator = nn.Identity()
54
+
55
+ self.encoded_layer_norm = nn.LayerNorm(self.dim)
56
+ self.scores = nn.Parameter(torch.zeros(self.n_layers, 1, 1, 1), requires_grad=True)
57
+
58
+ def forward(self, bert_input, form_chars, to_scatter, n_words):
59
+ tokens, mask = bert_input
60
+ batch_size = tokens.size(0)
61
+
62
+ encoded = self.bert(tokens, attention_mask=mask, output_hidden_states=True).hidden_states[1:]
63
+ encoded = torch.stack(encoded, dim=0) # shape: (12, B, T, H)
64
+ encoded = self.encoded_layer_norm(encoded)
65
+
66
+ if self.training:
67
+ time_len = encoded.size(2)
68
+ scores = self.scores.expand(-1, batch_size, time_len, -1)
69
+ dropout = torch.empty(self.n_layers, batch_size, 1, 1, dtype=torch.bool, device=self.scores.device)
70
+ dropout.bernoulli_(0.1)
71
+ scores = scores.masked_fill(dropout, float("-inf"))
72
+ else:
73
+ scores = self.scores
74
+
75
+ scores = F.softmax(scores, dim=0)
76
+ encoded = (scores * encoded).sum(0) # shape: (B, T, H)
77
+ encoded = encoded.masked_fill(mask.unsqueeze(-1) == 0, 0.0) # shape: (B, T, H)
78
+
79
+ subword_attention = self.subword_attention(encoded) / math.sqrt(self.dim) # shape: (B, T, 1)
80
+ subword_attention = subword_attention.expand_as(to_scatter) # shape: (B, T_subword, T_word)
81
+ subword_attention = subword_attention.masked_fill(to_scatter == 0, float("-inf")) # shape: (B, T_subword, T_word)
82
+ subword_attention = torch.softmax(subword_attention, dim=1) # shape: (B, T_subword, T_word)
83
+ subword_attention = subword_attention.masked_fill(to_scatter.sum(1, keepdim=True) == 0, value=0.0) # shape: (B, T_subword, T_word)
84
+
85
+ encoder_output = torch.einsum("bsd,bsw->bwd", encoded, subword_attention)
86
+ encoder_output = self.post_layer_norm(encoder_output)
87
+
88
+ if self.use_char_embedding:
89
+ form_char_embedding = self.form_char_embedding(form_chars[0], form_chars[1], form_chars[2])
90
+ encoder_output = self.word_dropout(encoder_output) + form_char_embedding
91
+
92
+ decoder_input = self.query_generator(encoder_output)
93
+ decoder_input = decoder_input.view(batch_size, -1, self.width_factor, self.dim).flatten(1, 2) # shape: (B, T*Q, D)
94
+
95
+ return encoder_output, decoder_input