larkkin commited on
Commit
c45d283
1 Parent(s): 4d8e00f

Add code and readme

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +70 -0
  2. app.py +97 -0
  3. config/__init__.py +0 -0
  4. config/params.py +89 -0
  5. data/__init__.py +0 -0
  6. data/batch.py +95 -0
  7. data/dataset.py +245 -0
  8. data/field/__init__.py +0 -0
  9. data/field/anchor_field.py +19 -0
  10. data/field/anchored_label_field.py +38 -0
  11. data/field/basic_field.py +11 -0
  12. data/field/bert_field.py +18 -0
  13. data/field/edge_field.py +63 -0
  14. data/field/edge_label_field.py +67 -0
  15. data/field/field.py +70 -0
  16. data/field/label_field.py +36 -0
  17. data/field/mini_torchtext/example.py +100 -0
  18. data/field/mini_torchtext/field.py +637 -0
  19. data/field/mini_torchtext/pipeline.py +86 -0
  20. data/field/mini_torchtext/utils.py +256 -0
  21. data/field/mini_torchtext/vocab.py +116 -0
  22. data/field/nested_field.py +50 -0
  23. data/parser/__init__.py +0 -0
  24. data/parser/from_mrp/__init__.py +0 -0
  25. data/parser/from_mrp/abstract_parser.py +50 -0
  26. data/parser/from_mrp/evaluation_parser.py +18 -0
  27. data/parser/from_mrp/labeled_edge_parser.py +70 -0
  28. data/parser/from_mrp/node_centric_parser.py +69 -0
  29. data/parser/from_mrp/request_parser.py +23 -0
  30. data/parser/from_mrp/sequential_parser.py +90 -0
  31. data/parser/json_parser.py +35 -0
  32. data/parser/to_mrp/__init__.py +0 -0
  33. data/parser/to_mrp/abstract_parser.py +80 -0
  34. data/parser/to_mrp/labeled_edge_parser.py +52 -0
  35. data/parser/to_mrp/node_centric_parser.py +35 -0
  36. data/parser/to_mrp/sequential_parser.py +35 -0
  37. model/__init__.py +0 -0
  38. model/head/__init__.py +0 -0
  39. model/head/abstract_head.py +274 -0
  40. model/head/labeled_edge_head.py +67 -0
  41. model/head/node_centric_head.py +25 -0
  42. model/head/sequential_head.py +24 -0
  43. model/model.py +82 -0
  44. model/module/__init__.py +0 -0
  45. model/module/anchor_classifier.py +32 -0
  46. model/module/biaffine.py +20 -0
  47. model/module/bilinear.py +43 -0
  48. model/module/char_embedding.py +42 -0
  49. model/module/edge_classifier.py +56 -0
  50. model/module/encoder.py +95 -0
README.md CHANGED
@@ -1,3 +1,73 @@
1
  ---
2
  license: apache-2.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ datasets:
4
+ - ltg/norec
5
+ language:
6
+ - 'no'
7
+ pipeline_tag: token-classification
8
+
9
+
10
+ model-index:
11
+ - name: SSA-Perin
12
+ results:
13
+ - task:
14
+ type: structured sentiment analysis
15
+ dataset:
16
+ name: NoReC
17
+ type: NoReC
18
+ metrics:
19
+ - name: Unlabeled sentiment tuple F1
20
+ type: Unlabeled sentiment tuple F1
21
+ value: 44.12%
22
+ - name: Target F1
23
+ type: Target F1
24
+ value: 56.44%
25
+ - name: Relative polarity precision
26
+ type: Relative polarity precision
27
+ value: 93.19%
28
  ---
29
+
30
+
31
+
32
+ This repository contains a pretrained model (and an easy-to-run wrapper for it) for structured sentiment analysis in Norwegian language, pre-trained on the [NoReC_fine dataset](https://github.com/ltgoslo/norec_fine).
33
+ This is an implementation of the method described in
34
+ ```bibtex
35
+ @misc{samuel2022direct,
36
+ title={Direct parsing to sentiment graphs},
37
+ author={David Samuel and Jeremy Barnes and Robin Kurtz and Stephan Oepen and Lilja Øvrelid and Erik Velldal},
38
+ year={2022},
39
+ eprint={2203.13209},
40
+ archivePrefix={arXiv},
41
+ primaryClass={cs.CL}
42
+ }
43
+ ```
44
+ The main repository that also contains the scripts for training the model, can be found on the project [github](https://github.com/jerbarnes/direct_parsing_to_sent_graph).
45
+ The model is also available in the form of a [HF space](https://huggingface.co/spaces/ltg/ssa-perin).
46
+
47
+
48
+ The sentiment graph model is based on an underlying masked language model – [NorBERT 2](https://huggingface.co/ltg/norbert2).
49
+ The proposed method suggests three different ways to encode the sentiment graph: "node-centric", "labeled-edge", and "opinion-tuple".
50
+ The current model
51
+ - uses "labeled-edge" graph encoding
52
+ - does not use character-level embedding
53
+ - all other hyperparameters are set to [default values](https://github.com/jerbarnes/direct_parsing_to_sent_graph/blob/main/perin/config/edge_norec.yaml)
54
+ , and it achieves the following results on the held-out set of the dataset:
55
+
56
+ | Unlabeled sentiment tuple F1 | Target F1 | Relative polarity precision |
57
+ |:----------------------------:|:----------:|:---------------------------:|
58
+ | 0.434 | 0.541 | 0.926 |
59
+
60
+
61
+ The model can be easily used for predicting sentiment tuples as follows:
62
+
63
+ ```python
64
+ >>> import model_wrapper
65
+ >>> model = model_wrapper.PredictionModel()
66
+ >>> model.predict(['vi liker svart kaffe'])
67
+ [{'sent_id': '0',
68
+ 'text': 'vi liker svart kaffe',
69
+ 'opinions': [{'Source': [['vi'], ['0:2']],
70
+ 'Target': [['svart', 'kaffe'], ['9:14', '15:20']],
71
+ 'Polar_expression': [['liker'], ['3:8']],
72
+ 'Polarity': 'Positive'}]}]
73
+ ```
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import model_wrapper
3
+
4
+
5
+ model = model_wrapper.PredictionModel()
6
+
7
+
8
+ def pretty_print_opinion(opinion_dict):
9
+ res = []
10
+ maxlen = max([len(key) for key in opinion_dict.keys()]) + 2
11
+ maxlen = 0
12
+ for key, value in opinion_dict.items():
13
+ if key == 'Polarity':
14
+ res.append(f'{(key + ":").ljust(maxlen)} {value}')
15
+ else:
16
+ res.append(f'{(key + ":").ljust(maxlen)} \'{" ".join(value[0])}\'')
17
+ return '\n'.join(res) + '\n'
18
+
19
+
20
+ def predict(text):
21
+ print(f'Input message "{text}"')
22
+ try:
23
+ predictions = model([text])
24
+ prediction = predictions[0]
25
+ results = []
26
+ if not prediction['opinions']:
27
+ return 'No opinions detected'
28
+ for opinion in prediction['opinions']:
29
+ results.append(pretty_print_opinion(opinion))
30
+ print(f'Successfully predicted SA for input message "{text}": {results}')
31
+ return '\n'.join(results)
32
+ except Exception as e:
33
+ print(f'Error for input message "{text}": {e}')
34
+ raise e
35
+
36
+
37
+
38
+ markdown_text = '''
39
+ <br>
40
+ <br>
41
+ This space provides a gradio demo and an easy-to-run wrapper of the pre-trained model for structured sentiment analysis in Norwegian language, pre-trained on the [NoReC dataset](https://huggingface.co/datasets/norec).
42
+ This space containt an implementation of method described in "Direct parsing to sentiment graphs" (Samuel _et al._, ACL 2022). The main repository that also contains the scripts for training the model, can be found on the project [github](https://github.com/jerbarnes/direct_parsing_to_sent_graph).
43
+
44
+ The sentiment graph model is based on an underlying masked language model – [NorBERT 2](https://huggingface.co/ltg/norbert2).
45
+ The proposed method suggests three different ways to encode the sentiment graph: "node-centric", "labeled-edge", and "opinion-tuple".
46
+ The current model
47
+ - uses "labeled-edge" graph encoding
48
+ - does not use character-level embedding
49
+ - all other hyperparameters are set to [default values](https://github.com/jerbarnes/direct_parsing_to_sent_graph/blob/main/perin/config/edge_norec.yaml)
50
+ , and it achieves the following results on the held-out set of the NoReC dataset:
51
+
52
+ | Unlabeled sentiment tuple F1 | Target F1 | Relative polarity precision |
53
+ |:----------------------------:|:----------:|:---------------------------:|
54
+ | 0.434 | 0.541 | 0.926 |
55
+
56
+
57
+ In "Word Substitution with Masked Language Models as Data Augmentation for Sentiment Analysis", we analyzed data augmentation strategies for improving performance of the model. Using masked-language modeling (MLM), we augmented the sentences with MLM-substituted words inside, outside, or inside+outside the actual sentiment tuples. The results below show that augmentation may be improve the model performance. This space, however, runs the original model trained without augmentation.
58
+
59
+ | | Augmentation rate | Unlabeled sentiment tuple F1 | Target F1 | Relative polarity precision |
60
+ |----------------|-------------------|------------------------------|-----------|-----------------------------|
61
+ | Baseline | 0% | 43.39 | 54.13 | 92.59 |
62
+ | Outside | 59% | **45.08** | 56.18 | 92.95 |
63
+ | Inside | 9% | 43.38 | 55.62 | 92.49 |
64
+ | Inside+Outside | 27% | 44.12 | **56.44** | **93.19** |
65
+
66
+
67
+
68
+ The model can be easily used for predicting sentiment tuples as follows:
69
+
70
+ ```python
71
+ >>> import model_wrapper
72
+ >>> model = model_wrapper.PredictionModel()
73
+ >>> model.predict(['vi liker svart kaffe'])
74
+ [{'sent_id': '0',
75
+ 'text': 'vi liker svart kaffe',
76
+ 'opinions': [{'Source': [['vi'], ['0:2']],
77
+ 'Target': [['svart', 'kaffe'], ['9:14', '15:20']],
78
+ 'Polar_expression': [['liker'], ['3:8']],
79
+ 'Polarity': 'Positive'}]}]
80
+ ```
81
+ '''
82
+
83
+
84
+
85
+ with gr.Blocks() as demo:
86
+ with gr.Row() as row:
87
+ text_input = gr.Textbox(label="input")
88
+ text_output = gr.Textbox(label="output")
89
+ with gr.Row() as row:
90
+ text_button = gr.Button("submit")
91
+
92
+ text_button.click(fn=predict, inputs=text_input, outputs=text_output)
93
+
94
+ gr.Markdown(markdown_text)
95
+
96
+
97
+ demo.launch()
config/__init__.py ADDED
File without changes
config/params.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+
3
+
4
+ class Params:
5
+ def __init__(self):
6
+ self.graph_mode = "sequential" # possibilities: {sequential, node-centric, edge-labeled}
7
+ self.accumulation_steps = 1 # number of gradient accumulation steps for achieving a bigger batch_size
8
+ self.activation = "relu" # transformer (decoder) activation function, supported values: {'relu', 'gelu', 'sigmoid', 'mish'}
9
+ self.predict_intensity = False
10
+ self.batch_size = 32 # batch size (further divided into multiple GPUs)
11
+ self.beta_2 = 0.98 # beta 2 parameter for Adam(W) optimizer
12
+ self.blank_weight = 1.0 # weight of cross-entropy loss for predicting an empty label
13
+ self.char_embedding = True # use character embedding in addition to bert
14
+ self.char_embedding_size = 128 # dimension of the character embedding layer in the character embedding module
15
+ self.decoder_delay_steps = 0 # number of initial steps with frozen decoder
16
+ self.decoder_learning_rate = 6e-4 # initial decoder learning rate
17
+ self.decoder_weight_decay = 1.2e-6 # amount of weight decay
18
+ self.dropout_anchor = 0.5 # dropout at the last layer of anchor classifier
19
+ self.dropout_edge_label = 0.5 # dropout at the last layer of edge label classifier
20
+ self.dropout_edge_presence = 0.5 # dropout at the last layer of edge presence classifier
21
+ self.dropout_label = 0.5 # dropout at the last layer of label classifier
22
+ self.dropout_transformer = 0.5 # dropout for the transformer layers (decoder)
23
+ self.dropout_transformer_attention = 0.1 # dropout for the transformer's attention (decoder)
24
+ self.dropout_word = 0.1 # probability of dropping out a whole word from the encoder (in favour of char embedding)
25
+ self.encoder = "xlm-roberta-base" # pretrained encoder model
26
+ self.encoder_delay_steps = 2000 # number of initial steps with frozen XLM-R
27
+ self.encoder_freeze_embedding = True # freeze the first embedding layer in XLM-R
28
+ self.encoder_learning_rate = 6e-5 # initial encoder learning rate
29
+ self.encoder_weight_decay = 1e-2 # amount of weight decay
30
+ self.lr_decay_multiplier = 100
31
+ self.epochs = 100 # number of epochs for train
32
+ self.focal = True # use focal loss for the label prediction
33
+ self.freeze_bert = False # use focal loss for the label prediction
34
+ self.group_ops = False # group 'opN' edge labels into one
35
+ self.hidden_size_ff = 4 * 768 # hidden size of the transformer feed-forward submodule
36
+ self.hidden_size_anchor = 128 # hidden size anchor biaffine layer
37
+ self.hidden_size_edge_label = 256 # hidden size for edge label biaffine layer
38
+ self.hidden_size_edge_presence = 512 # hidden size for edge label biaffine layer
39
+ self.layerwise_lr_decay = 1.0 # layerwise decay of learning rate in the encoder
40
+ self.n_attention_heads = 8 # number of attention heads in the decoding transformer
41
+ self.n_layers = 3 # number of layers in the decoder
42
+ self.query_length = 4 # number of queries genereted for each word on the input
43
+ self.pre_norm = True # use pre-normalized version of the transformer (as in Transformers without Tears)
44
+ self.warmup_steps = 6000 # number of the warm-up steps for the inverse_sqrt scheduler
45
+
46
+ def init_data_paths(self):
47
+ directory_1 = {
48
+ "sequential": "node_centric_mrp",
49
+ "node-centric": "node_centric_mrp",
50
+ "labeled-edge": "labeled_edge_mrp"
51
+ }[self.graph_mode]
52
+ directory_2 = {
53
+ ("darmstadt", "en"): "darmstadt_unis",
54
+ ("mpqa", "en"): "mpqa",
55
+ ("multibooked", "ca"): "multibooked_ca",
56
+ ("multibooked", "eu"): "multibooked_eu",
57
+ ("norec", "no"): "norec",
58
+ ("opener", "en"): "opener_en",
59
+ ("opener", "es"): "opener_es",
60
+ }[(self.framework, self.language)]
61
+
62
+ self.training_data = f"{self.data_directory}/{directory_1}/{directory_2}/train.mrp"
63
+ self.validation_data = f"{self.data_directory}/{directory_1}/{directory_2}/dev.mrp"
64
+ self.test_data = f"{self.data_directory}/{directory_1}/{directory_2}/test.mrp"
65
+
66
+ self.raw_training_data = f"{self.data_directory}/raw/{directory_2}/train.json"
67
+ self.raw_validation_data = f"{self.data_directory}/raw/{directory_2}/dev.json"
68
+
69
+ return self
70
+
71
+ def load_state_dict(self, d):
72
+ for k, v in d.items():
73
+ setattr(self, k, v)
74
+ return self
75
+
76
+ def state_dict(self):
77
+ members = [attr for attr in dir(self) if not callable(getattr(self, attr)) and not attr.startswith("__")]
78
+ return {k: self.__dict__[k] for k in members}
79
+
80
+ def load(self, args):
81
+ with open(args.config, "r", encoding="utf-8") as f:
82
+ params = yaml.safe_load(f)
83
+ self.load_state_dict(params)
84
+ self.init_data_paths()
85
+
86
+ def save(self, json_path):
87
+ with open(json_path, "w", encoding="utf-8") as f:
88
+ d = self.state_dict()
89
+ yaml.dump(d, f)
data/__init__.py ADDED
File without changes
data/batch.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class Batch:
9
+ @staticmethod
10
+ def build(data):
11
+ fields = list(data[0].keys())
12
+ transposed = {}
13
+ for field in fields:
14
+ if isinstance(data[0][field], tuple):
15
+ transposed[field] = tuple(Batch._stack(field, [example[field][i] for example in data]) for i in range(len(data[0][field])))
16
+ else:
17
+ transposed[field] = Batch._stack(field, [example[field] for example in data])
18
+
19
+ return transposed
20
+
21
+ @staticmethod
22
+ def _stack(field: str, examples):
23
+ if field == "anchored_labels":
24
+ return examples
25
+
26
+ dim = examples[0].dim()
27
+
28
+ if dim == 0:
29
+ return torch.stack(examples)
30
+
31
+ lengths = [max(example.size(i) for example in examples) for i in range(dim)]
32
+ if any(length == 0 for length in lengths):
33
+ return torch.LongTensor(len(examples), *lengths)
34
+
35
+ examples = [F.pad(example, Batch._pad_size(example, lengths)) for example in examples]
36
+ return torch.stack(examples)
37
+
38
+ @staticmethod
39
+ def _pad_size(example, total_size):
40
+ return [p for i, l in enumerate(total_size[::-1]) for p in (0, l - example.size(-1 - i))]
41
+
42
+ @staticmethod
43
+ def index_select(batch, indices):
44
+ filtered_batch = {}
45
+ for key, examples in batch.items():
46
+ if isinstance(examples, list) or isinstance(examples, tuple):
47
+ filtered_batch[key] = [example.index_select(0, indices) for example in examples]
48
+ else:
49
+ filtered_batch[key] = examples.index_select(0, indices)
50
+
51
+ return filtered_batch
52
+
53
+ @staticmethod
54
+ def to_str(batch):
55
+ string = "\n".join([f"\t{name}: {Batch._short_str(item)}" for name, item in batch.items()])
56
+ return string
57
+
58
+ @staticmethod
59
+ def to(batch, device):
60
+ converted = {}
61
+ for field in batch.keys():
62
+ converted[field] = Batch._to(batch[field], device)
63
+ return converted
64
+
65
+ @staticmethod
66
+ def _short_str(tensor):
67
+ # unwrap variable to tensor
68
+ if not torch.is_tensor(tensor):
69
+ # (1) unpack variable
70
+ if hasattr(tensor, "data"):
71
+ tensor = getattr(tensor, "data")
72
+ # (2) handle include_lengths
73
+ elif isinstance(tensor, tuple) or isinstance(tensor, list):
74
+ return str(tuple(Batch._short_str(t) for t in tensor))
75
+ # (3) fallback to default str
76
+ else:
77
+ return str(tensor)
78
+
79
+ # copied from torch _tensor_str
80
+ size_str = "x".join(str(size) for size in tensor.size())
81
+ device_str = "" if not tensor.is_cuda else " (GPU {})".format(tensor.get_device())
82
+ strt = "[{} of size {}{}]".format(torch.typename(tensor), size_str, device_str)
83
+ return strt
84
+
85
+ @staticmethod
86
+ def _to(tensor, device):
87
+ if not torch.is_tensor(tensor):
88
+ if isinstance(tensor, tuple):
89
+ return tuple(Batch._to(t, device) for t in tensor)
90
+ elif isinstance(tensor, list):
91
+ return [Batch._to(t, device) for t in tensor]
92
+ else:
93
+ raise Exception(f"unsupported type of {tensor} to be casted to cuda")
94
+
95
+ return tensor.to(device, non_blocking=True)
data/dataset.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import pickle
5
+
6
+ import torch
7
+
8
+ from data.parser.from_mrp.node_centric_parser import NodeCentricParser
9
+ from data.parser.from_mrp.labeled_edge_parser import LabeledEdgeParser
10
+ from data.parser.from_mrp.sequential_parser import SequentialParser
11
+ from data.parser.from_mrp.evaluation_parser import EvaluationParser
12
+ from data.parser.from_mrp.request_parser import RequestParser
13
+ from data.field.edge_field import EdgeField
14
+ from data.field.edge_label_field import EdgeLabelField
15
+ from data.field.field import Field
16
+ from data.field.mini_torchtext.field import Field as TorchTextField
17
+ from data.field.label_field import LabelField
18
+ from data.field.anchored_label_field import AnchoredLabelField
19
+ from data.field.nested_field import NestedField
20
+ from data.field.basic_field import BasicField
21
+ from data.field.bert_field import BertField
22
+ from data.field.anchor_field import AnchorField
23
+ from data.batch import Batch
24
+
25
+
26
+ def char_tokenize(word):
27
+ return [c for i, c in enumerate(word)] # if i < 10 or len(word) - i <= 10]
28
+
29
+
30
+ class Collate:
31
+ def __call__(self, batch):
32
+ batch.sort(key=lambda example: example["every_input"][0].size(0), reverse=True)
33
+ return Batch.build(batch)
34
+
35
+
36
+ class Dataset:
37
+ def __init__(self, args, verbose=True):
38
+ self.verbose = verbose
39
+ self.sos, self.eos, self.pad, self.unk = "<sos>", "<eos>", "<pad>", "<unk>"
40
+
41
+ self.bert_input_field = BertField()
42
+ self.scatter_field = BasicField()
43
+ self.every_word_input_field = Field(lower=True, init_token=self.sos, eos_token=self.eos, batch_first=True, include_lengths=True)
44
+
45
+ char_form_nesting = TorchTextField(tokenize=char_tokenize, init_token=self.sos, eos_token=self.eos, batch_first=True)
46
+ self.char_form_field = NestedField(char_form_nesting, include_lengths=True)
47
+
48
+ self.label_field = LabelField(preprocessing=lambda nodes: [n["label"] for n in nodes])
49
+ self.anchored_label_field = AnchoredLabelField()
50
+
51
+ self.id_field = Field(batch_first=True, tokenize=lambda x: [x])
52
+ self.edge_presence_field = EdgeField()
53
+ self.edge_label_field = EdgeLabelField()
54
+ self.anchor_field = AnchorField()
55
+ self.source_anchor_field = AnchorField()
56
+ self.target_anchor_field = AnchorField()
57
+ self.token_interval_field = BasicField()
58
+
59
+ self.load_dataset(args)
60
+
61
+ def log(self, text):
62
+ if not self.verbose:
63
+ return
64
+ print(text, flush=True)
65
+
66
+ def load_state_dict(self, args, d):
67
+ for key, value in d["vocabs"].items():
68
+ getattr(self, key).vocab = pickle.loads(value)
69
+
70
+ def state_dict(self):
71
+ return {
72
+ "vocabs": {key: pickle.dumps(value.vocab) for key, value in self.__dict__.items() if hasattr(value, "vocab")}
73
+ }
74
+
75
+ def load_sentences(self, sentences, args):
76
+ dataset = RequestParser(
77
+ sentences, args,
78
+ fields={
79
+ "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
80
+ "bert input": ("input", self.bert_input_field),
81
+ "to scatter": ("input_scatter", self.scatter_field),
82
+ "token anchors": ("token_intervals", self.token_interval_field),
83
+ "id": ("id", self.id_field),
84
+ },
85
+ )
86
+
87
+ self.every_word_input_field.build_vocab(dataset, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
88
+ self.id_field.build_vocab(dataset, min_freq=1, specials=[])
89
+
90
+ return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=Collate())
91
+
92
+ def load_dataset(self, args):
93
+ parser = {
94
+ "sequential": SequentialParser,
95
+ "node-centric": NodeCentricParser,
96
+ "labeled-edge": LabeledEdgeParser
97
+ }[args.graph_mode]
98
+
99
+ train = parser(
100
+ args, "training",
101
+ fields={
102
+ "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
103
+ "bert input": ("input", self.bert_input_field),
104
+ "to scatter": ("input_scatter", self.scatter_field),
105
+ "nodes": ("labels", self.label_field),
106
+ "anchored labels": ("anchored_labels", self.anchored_label_field),
107
+ "edge presence": ("edge_presence", self.edge_presence_field),
108
+ "edge labels": ("edge_labels", self.edge_label_field),
109
+ "anchor edges": ("anchor", self.anchor_field),
110
+ "source anchor edges": ("source_anchor", self.source_anchor_field),
111
+ "target anchor edges": ("target_anchor", self.target_anchor_field),
112
+ "token anchors": ("token_intervals", self.token_interval_field),
113
+ "id": ("id", self.id_field),
114
+ },
115
+ filter_pred=lambda example: len(example.input) <= 256,
116
+ )
117
+
118
+ val = parser(
119
+ args, "validation",
120
+ fields={
121
+ "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
122
+ "bert input": ("input", self.bert_input_field),
123
+ "to scatter": ("input_scatter", self.scatter_field),
124
+ "nodes": ("labels", self.label_field),
125
+ "anchored labels": ("anchored_labels", self.anchored_label_field),
126
+ "edge presence": ("edge_presence", self.edge_presence_field),
127
+ "edge labels": ("edge_labels", self.edge_label_field),
128
+ "anchor edges": ("anchor", self.anchor_field),
129
+ "source anchor edges": ("source_anchor", self.source_anchor_field),
130
+ "target anchor edges": ("target_anchor", self.target_anchor_field),
131
+ "token anchors": ("token_intervals", self.token_interval_field),
132
+ "id": ("id", self.id_field),
133
+ },
134
+ )
135
+
136
+ test = EvaluationParser(
137
+ args,
138
+ fields={
139
+ "input": [("every_input", self.every_word_input_field), ("char_form_input", self.char_form_field)],
140
+ "bert input": ("input", self.bert_input_field),
141
+ "to scatter": ("input_scatter", self.scatter_field),
142
+ "token anchors": ("token_intervals", self.token_interval_field),
143
+ "id": ("id", self.id_field),
144
+ },
145
+ )
146
+
147
+ del train.data, val.data, test.data # TODO: why?
148
+ for f in list(train.fields.values()) + list(val.fields.values()) + list(test.fields.values()): # TODO: why?
149
+ if hasattr(f, "preprocessing"):
150
+ del f.preprocessing
151
+
152
+ self.train_size = len(train)
153
+ self.val_size = len(val)
154
+ self.test_size = len(test)
155
+
156
+ self.log(f"\n{self.train_size} sentences in the train split")
157
+ self.log(f"{self.val_size} sentences in the validation split")
158
+ self.log(f"{self.test_size} sentences in the test split")
159
+
160
+ self.node_count = train.node_counter
161
+ self.token_count = train.input_count
162
+ self.edge_count = train.edge_counter
163
+ self.no_edge_count = train.no_edge_counter
164
+ self.anchor_freq = train.anchor_freq
165
+
166
+ self.source_anchor_freq = train.source_anchor_freq if hasattr(train, "source_anchor_freq") else 0.5
167
+ self.target_anchor_freq = train.target_anchor_freq if hasattr(train, "target_anchor_freq") else 0.5
168
+ self.log(f"{self.node_count} nodes in the train split")
169
+
170
+ self.every_word_input_field.build_vocab(val, test, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
171
+ self.char_form_field.build_vocab(train, min_freq=1, specials=[self.pad, self.unk, self.sos, self.eos])
172
+ self.char_form_field.nesting_field.vocab = self.char_form_field.vocab
173
+ self.id_field.build_vocab(train, val, test, min_freq=1, specials=[])
174
+ self.label_field.build_vocab(train)
175
+ self.anchored_label_field.vocab = self.label_field.vocab
176
+ self.edge_label_field.build_vocab(train)
177
+ print(list(self.edge_label_field.vocab.freqs.keys()), flush=True)
178
+
179
+ self.char_form_vocab_size = len(self.char_form_field.vocab)
180
+ self.create_label_freqs(args)
181
+ self.create_edge_freqs(args)
182
+
183
+ self.log(f"Edge frequency: {self.edge_presence_freq*100:.2f} %")
184
+ self.log(f"{len(self.label_field.vocab)} words in the label vocabulary")
185
+ self.log(f"{len(self.anchored_label_field.vocab)} words in the anchored label vocabulary")
186
+ self.log(f"{len(self.edge_label_field.vocab)} words in the edge label vocabulary")
187
+ self.log(f"{len(self.char_form_field.vocab)} characters in the vocabulary")
188
+
189
+ self.log(self.label_field.vocab.freqs)
190
+ self.log(self.anchored_label_field.vocab.freqs)
191
+
192
+ self.train = torch.utils.data.DataLoader(
193
+ train,
194
+ batch_size=args.batch_size,
195
+ shuffle=True,
196
+ num_workers=args.workers,
197
+ collate_fn=Collate(),
198
+ pin_memory=True,
199
+ drop_last=True
200
+ )
201
+ self.train_size = len(self.train.dataset)
202
+
203
+ self.val = torch.utils.data.DataLoader(
204
+ val,
205
+ batch_size=args.batch_size,
206
+ shuffle=False,
207
+ num_workers=args.workers,
208
+ collate_fn=Collate(),
209
+ pin_memory=True,
210
+ )
211
+ self.val_size = len(self.val.dataset)
212
+
213
+ self.test = torch.utils.data.DataLoader(
214
+ test,
215
+ batch_size=args.batch_size,
216
+ shuffle=False,
217
+ num_workers=args.workers,
218
+ collate_fn=Collate(),
219
+ pin_memory=True,
220
+ )
221
+ self.test_size = len(self.test.dataset)
222
+
223
+ if self.verbose:
224
+ batch = next(iter(self.train))
225
+ print(f"\nBatch content: {Batch.to_str(batch)}\n")
226
+ print(flush=True)
227
+
228
+ def create_label_freqs(self, args):
229
+ n_rules = len(self.label_field.vocab)
230
+ blank_count = (args.query_length * self.token_count - self.node_count)
231
+ label_counts = [blank_count] + [
232
+ self.label_field.vocab.freqs[self.label_field.vocab.itos[i]]
233
+ for i in range(n_rules)
234
+ ]
235
+ label_counts = torch.FloatTensor(label_counts)
236
+ self.label_freqs = label_counts / (self.node_count + blank_count)
237
+ self.log(f"Label frequency: {self.label_freqs}")
238
+
239
+ def create_edge_freqs(self, args):
240
+ edge_counter = [
241
+ self.edge_label_field.vocab.freqs[self.edge_label_field.vocab.itos[i]] for i in range(len(self.edge_label_field.vocab))
242
+ ]
243
+ edge_counter = torch.FloatTensor(edge_counter)
244
+ self.edge_label_freqs = edge_counter / self.edge_count
245
+ self.edge_presence_freq = self.edge_count / (self.edge_count + self.no_edge_count)
data/field/__init__.py ADDED
File without changes
data/field/anchor_field.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ from data.field.mini_torchtext.field import RawField
6
+
7
+
8
+ class AnchorField(RawField):
9
+ def process(self, batch, device=None):
10
+ tensors, masks = self.pad(batch, device)
11
+ return tensors, masks
12
+
13
+ def pad(self, anchors, device):
14
+ tensor = torch.zeros(anchors[0], anchors[1], dtype=torch.long, device=device)
15
+ for anchor in anchors[-1]:
16
+ tensor[anchor[0], anchor[1]] = 1
17
+ mask = tensor.sum(-1) == 0
18
+
19
+ return tensor, mask
data/field/anchored_label_field.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from data.field.mini_torchtext.field import RawField
3
+
4
+
5
+ class AnchoredLabelField(RawField):
6
+ def __init__(self):
7
+ super(AnchoredLabelField, self).__init__()
8
+ self.vocab = None
9
+
10
+ def process(self, example, device=None):
11
+ example = self.numericalize(example)
12
+ tensor = self.pad(example, device)
13
+ return tensor
14
+
15
+ def pad(self, example, device):
16
+ n_labels = len(self.vocab)
17
+ n_nodes, n_tokens = len(example[1]), example[0]
18
+
19
+ tensor = torch.full([n_nodes, n_tokens, n_labels + 1], 0, dtype=torch.long, device=device)
20
+ for i_node, node in enumerate(example[1]):
21
+ for anchor, rule in node:
22
+ tensor[i_node, anchor, rule + 1] = 1
23
+
24
+ return tensor
25
+
26
+ def numericalize(self, arr):
27
+ def multi_map(array, function):
28
+ if isinstance(array, tuple):
29
+ return (array[0], function(array[1]))
30
+ elif isinstance(array, list):
31
+ return [multi_map(a, function) for a in array]
32
+ else:
33
+ return array
34
+
35
+ if self.vocab is not None:
36
+ arr = multi_map(arr, lambda x: self.vocab.stoi[x] if x in self.vocab.stoi else 0)
37
+
38
+ return arr
data/field/basic_field.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ from data.field.mini_torchtext.field import RawField
6
+
7
+
8
+ class BasicField(RawField):
9
+ def process(self, example, device=None):
10
+ tensor = torch.tensor(example, dtype=torch.long, device=device)
11
+ return tensor
data/field/bert_field.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ from data.field.mini_torchtext.field import RawField
6
+
7
+
8
+ class BertField(RawField):
9
+ def __init__(self):
10
+ super(BertField, self).__init__()
11
+
12
+ def process(self, example, device=None):
13
+ attention_mask = [1] * len(example)
14
+
15
+ example = torch.LongTensor(example, device=device)
16
+ attention_mask = torch.ones_like(example)
17
+
18
+ return example, attention_mask
data/field/edge_field.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ from data.field.mini_torchtext.field import RawField
6
+ from data.field.mini_torchtext.vocab import Vocab
7
+ from collections import Counter
8
+ import types
9
+
10
+
11
+ class EdgeField(RawField):
12
+ def __init__(self):
13
+ super(EdgeField, self).__init__()
14
+ self.vocab = None
15
+
16
+ def process(self, edges, device=None):
17
+ edges = self.numericalize(edges)
18
+ tensor = self.pad(edges, device)
19
+ return tensor
20
+
21
+ def pad(self, edges, device):
22
+ tensor = torch.zeros(edges[0], edges[1], dtype=torch.long, device=device)
23
+ for edge in edges[-1]:
24
+ tensor[edge[0], edge[1]] = edge[2]
25
+
26
+ return tensor
27
+
28
+ def numericalize(self, arr):
29
+ def multi_map(array, function):
30
+ if isinstance(array, tuple):
31
+ return (array[0], array[1], function(array[2]))
32
+ elif isinstance(array, list):
33
+ return [multi_map(array[i], function) for i in range(len(array))]
34
+ else:
35
+ return array
36
+
37
+ if self.vocab is not None:
38
+ arr = multi_map(arr, lambda x: self.vocab.stoi[x] if x is not None else 0)
39
+ return arr
40
+
41
+ def build_vocab(self, *args):
42
+ def generate(l):
43
+ if isinstance(l, tuple):
44
+ yield l[2]
45
+ elif isinstance(l, list) or isinstance(l, types.GeneratorType):
46
+ for i in l:
47
+ yield from generate(i)
48
+ else:
49
+ return
50
+
51
+ counter = Counter()
52
+ sources = []
53
+ for arg in args:
54
+ if isinstance(arg, torch.utils.data.Dataset):
55
+ sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
56
+ else:
57
+ sources.append(arg)
58
+
59
+ for x in generate(sources):
60
+ if x is not None:
61
+ counter.update([x])
62
+
63
+ self.vocab = Vocab(counter, specials=[])
data/field/edge_label_field.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ from data.field.mini_torchtext.field import RawField
6
+ from data.field.mini_torchtext.vocab import Vocab
7
+ from collections import Counter
8
+ import types
9
+
10
+
11
+ class EdgeLabelField(RawField):
12
+ def process(self, edges, device=None):
13
+ edges, masks = self.numericalize(edges)
14
+ edges, masks = self.pad(edges, masks, device)
15
+
16
+ return edges, masks
17
+
18
+ def pad(self, edges, masks, device):
19
+ n_labels = len(self.vocab)
20
+
21
+ tensor = torch.zeros(edges[0], edges[1], n_labels, dtype=torch.long, device=device)
22
+ mask_tensor = torch.zeros(edges[0], edges[1], dtype=torch.bool, device=device)
23
+
24
+ for edge in edges[-1]:
25
+ tensor[edge[0], edge[1], edge[2]] = 1
26
+
27
+ for mask in masks[-1]:
28
+ mask_tensor[mask[0], mask[1]] = mask[2]
29
+
30
+ return tensor, mask_tensor
31
+
32
+ def numericalize(self, arr):
33
+ def multi_map(array, function):
34
+ if isinstance(array, tuple):
35
+ return (array[0], array[1], function(array[2]))
36
+ elif isinstance(array, list):
37
+ return [multi_map(array[i], function) for i in range(len(array))]
38
+ else:
39
+ return array
40
+
41
+ mask = multi_map(arr, lambda x: x is None)
42
+ arr = multi_map(arr, lambda x: self.vocab.stoi[x] if x in self.vocab.stoi else 0)
43
+ return arr, mask
44
+
45
+ def build_vocab(self, *args):
46
+ def generate(l):
47
+ if isinstance(l, tuple):
48
+ yield l[2]
49
+ elif isinstance(l, list) or isinstance(l, types.GeneratorType):
50
+ for i in l:
51
+ yield from generate(i)
52
+ else:
53
+ return
54
+
55
+ counter = Counter()
56
+ sources = []
57
+ for arg in args:
58
+ if isinstance(arg, torch.utils.data.Dataset):
59
+ sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
60
+ else:
61
+ sources.append(arg)
62
+
63
+ for x in generate(sources):
64
+ if x is not None:
65
+ counter.update([x])
66
+
67
+ self.vocab = Vocab(counter, specials=[])
data/field/field.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from data.field.mini_torchtext.field import Field as TorchTextField
3
+ from collections import Counter, OrderedDict
4
+
5
+
6
+ # small change of vocab building to correspond to our version of Dataset
7
+ class Field(TorchTextField):
8
+ def build_vocab(self, *args, **kwargs):
9
+ counter = Counter()
10
+ sources = []
11
+ for arg in args:
12
+ if isinstance(arg, torch.utils.data.Dataset):
13
+ sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
14
+ else:
15
+ sources.append(arg)
16
+ for data in sources:
17
+ for x in data:
18
+ if not self.sequential:
19
+ x = [x]
20
+ counter.update(x)
21
+
22
+ specials = list(
23
+ OrderedDict.fromkeys(
24
+ tok
25
+ for tok in [self.unk_token, self.pad_token, self.init_token, self.eos_token] + kwargs.pop("specials", [])
26
+ if tok is not None
27
+ )
28
+ )
29
+ self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
30
+
31
+ def process(self, example, device=None):
32
+ if self.include_lengths:
33
+ example = example, len(example)
34
+ tensor = self.numericalize(example, device=device)
35
+ return tensor
36
+
37
+ def numericalize(self, ex, device=None):
38
+ if self.include_lengths and not isinstance(ex, tuple):
39
+ raise ValueError("Field has include_lengths set to True, but input data is not a tuple of (data batch, batch lengths).")
40
+
41
+ if isinstance(ex, tuple):
42
+ ex, lengths = ex
43
+ lengths = torch.tensor(lengths, dtype=self.dtype, device=device)
44
+
45
+ if self.use_vocab:
46
+ if self.sequential:
47
+ ex = [self.vocab.stoi[x] for x in ex]
48
+ else:
49
+ ex = self.vocab.stoi[ex]
50
+
51
+ if self.postprocessing is not None:
52
+ ex = self.postprocessing(ex, self.vocab)
53
+ else:
54
+ numericalization_func = self.dtypes[self.dtype]
55
+
56
+ if not self.sequential:
57
+ ex = numericalization_func(ex) if isinstance(ex, str) else ex
58
+ if self.postprocessing is not None:
59
+ ex = self.postprocessing(ex, None)
60
+
61
+ var = torch.tensor(ex, dtype=self.dtype, device=device)
62
+
63
+ if self.sequential and not self.batch_first:
64
+ var.t_()
65
+ if self.sequential:
66
+ var = var.contiguous()
67
+
68
+ if self.include_lengths:
69
+ return var, lengths
70
+ return var
data/field/label_field.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from data.field.mini_torchtext.field import RawField
3
+ from data.field.mini_torchtext.vocab import Vocab
4
+ from collections import Counter
5
+
6
+
7
+ class LabelField(RawField):
8
+ def __self__(self, preprocessing):
9
+ super(LabelField, self).__init__(preprocessing=preprocessing)
10
+ self.vocab = None
11
+
12
+ def build_vocab(self, *args, **kwargs):
13
+ sources = []
14
+ for arg in args:
15
+ if isinstance(arg, torch.utils.data.Dataset):
16
+ sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
17
+ else:
18
+ sources.append(arg)
19
+
20
+ counter = Counter()
21
+ for data in sources:
22
+ for x in data:
23
+ counter.update(x)
24
+
25
+ self.vocab = Vocab(counter, specials=[])
26
+
27
+ def process(self, example, device=None):
28
+ tensor, lengths = self.numericalize(example, device=device)
29
+ return tensor, lengths
30
+
31
+ def numericalize(self, example, device=None):
32
+ example = [self.vocab.stoi[x] + 1 for x in example]
33
+ length = torch.LongTensor([len(example)], device=device).squeeze(0)
34
+ tensor = torch.LongTensor(example, device=device)
35
+
36
+ return tensor, length
data/field/mini_torchtext/example.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import six
2
+ import json
3
+ from functools import reduce
4
+
5
+
6
+ class Example(object):
7
+ """Defines a single training or test example.
8
+
9
+ Stores each column of the example as an attribute.
10
+ """
11
+ @classmethod
12
+ def fromJSON(cls, data, fields):
13
+ ex = cls()
14
+ obj = json.loads(data)
15
+
16
+ for key, vals in fields.items():
17
+ if vals is not None:
18
+ if not isinstance(vals, list):
19
+ vals = [vals]
20
+
21
+ for val in vals:
22
+ # for processing the key likes 'foo.bar'
23
+ name, field = val
24
+ ks = key.split('.')
25
+
26
+ def reducer(obj, key):
27
+ if isinstance(obj, list):
28
+ results = []
29
+ for data in obj:
30
+ if key not in data:
31
+ # key error
32
+ raise ValueError("Specified key {} was not found in "
33
+ "the input data".format(key))
34
+ else:
35
+ results.append(data[key])
36
+ return results
37
+ else:
38
+ # key error
39
+ if key not in obj:
40
+ raise ValueError("Specified key {} was not found in "
41
+ "the input data".format(key))
42
+ else:
43
+ return obj[key]
44
+
45
+ v = reduce(reducer, ks, obj)
46
+ setattr(ex, name, field.preprocess(v))
47
+ return ex
48
+
49
+ @classmethod
50
+ def fromdict(cls, data, fields):
51
+ ex = cls()
52
+ for key, vals in fields.items():
53
+ if key not in data:
54
+ raise ValueError("Specified key {} was not found in "
55
+ "the input data".format(key))
56
+ if vals is not None:
57
+ if not isinstance(vals, list):
58
+ vals = [vals]
59
+ for val in vals:
60
+ name, field = val
61
+ setattr(ex, name, field.preprocess(data[key]))
62
+ return ex
63
+
64
+ @classmethod
65
+ def fromCSV(cls, data, fields, field_to_index=None):
66
+ if field_to_index is None:
67
+ return cls.fromlist(data, fields)
68
+ else:
69
+ assert(isinstance(fields, dict))
70
+ data_dict = {f: data[idx] for f, idx in field_to_index.items()}
71
+ return cls.fromdict(data_dict, fields)
72
+
73
+ @classmethod
74
+ def fromlist(cls, data, fields):
75
+ ex = cls()
76
+ for (name, field), val in zip(fields, data):
77
+ if field is not None:
78
+ if isinstance(val, six.string_types):
79
+ val = val.rstrip('\n')
80
+ # Handle field tuples
81
+ if isinstance(name, tuple):
82
+ for n, f in zip(name, field):
83
+ setattr(ex, n, f.preprocess(val))
84
+ else:
85
+ setattr(ex, name, field.preprocess(val))
86
+ return ex
87
+
88
+ @classmethod
89
+ def fromtree(cls, data, fields, subtrees=False):
90
+ try:
91
+ from nltk.tree import Tree
92
+ except ImportError:
93
+ print("Please install NLTK. "
94
+ "See the docs at http://nltk.org for more information.")
95
+ raise
96
+ tree = Tree.fromstring(data)
97
+ if subtrees:
98
+ return [cls.fromlist(
99
+ [' '.join(t.leaves()), t.label()], fields) for t in tree.subtrees()]
100
+ return cls.fromlist([' '.join(tree.leaves()), tree.label()], fields)
data/field/mini_torchtext/field.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf8
2
+ from collections import Counter, OrderedDict
3
+ from itertools import chain
4
+ import six
5
+ import torch
6
+
7
+ from .pipeline import Pipeline
8
+ from .utils import get_tokenizer, dtype_to_attr, is_tokenizer_serializable
9
+ from .vocab import Vocab
10
+
11
+
12
+ class RawField(object):
13
+ """ Defines a general datatype.
14
+
15
+ Every dataset consists of one or more types of data. For instance, a text
16
+ classification dataset contains sentences and their classes, while a
17
+ machine translation dataset contains paired examples of text in two
18
+ languages. Each of these types of data is represented by a RawField object.
19
+ A RawField object does not assume any property of the data type and
20
+ it holds parameters relating to how a datatype should be processed.
21
+
22
+ Attributes:
23
+ preprocessing: The Pipeline that will be applied to examples
24
+ using this field before creating an example.
25
+ Default: None.
26
+ postprocessing: A Pipeline that will be applied to a list of examples
27
+ using this field before assigning to a batch.
28
+ Function signature: (batch(list)) -> object
29
+ Default: None.
30
+ is_target: Whether this field is a target variable.
31
+ Affects iteration over batches. Default: False
32
+ """
33
+
34
+ def __init__(self, preprocessing=None, postprocessing=None, is_target=False):
35
+ self.preprocessing = preprocessing
36
+ self.postprocessing = postprocessing
37
+ self.is_target = is_target
38
+
39
+ def preprocess(self, x):
40
+ """ Preprocess an example if the `preprocessing` Pipeline is provided. """
41
+ if hasattr(self, "preprocessing") and self.preprocessing is not None:
42
+ return self.preprocessing(x)
43
+ else:
44
+ return x
45
+
46
+ def process(self, batch, *args, **kwargs):
47
+ """ Process a list of examples to create a batch.
48
+
49
+ Postprocess the batch with user-provided Pipeline.
50
+
51
+ Args:
52
+ batch (list(object)): A list of object from a batch of examples.
53
+ Returns:
54
+ object: Processed object given the input and custom
55
+ postprocessing Pipeline.
56
+ """
57
+ if self.postprocessing is not None:
58
+ batch = self.postprocessing(batch)
59
+ return batch
60
+
61
+
62
+ class Field(RawField):
63
+ """Defines a datatype together with instructions for converting to Tensor.
64
+
65
+ Field class models common text processing datatypes that can be represented
66
+ by tensors. It holds a Vocab object that defines the set of possible values
67
+ for elements of the field and their corresponding numerical representations.
68
+ The Field object also holds other parameters relating to how a datatype
69
+ should be numericalized, such as a tokenization method and the kind of
70
+ Tensor that should be produced.
71
+
72
+ If a Field is shared between two columns in a dataset (e.g., question and
73
+ answer in a QA dataset), then they will have a shared vocabulary.
74
+
75
+ Attributes:
76
+ sequential: Whether the datatype represents sequential data. If False,
77
+ no tokenization is applied. Default: True.
78
+ use_vocab: Whether to use a Vocab object. If False, the data in this
79
+ field should already be numerical. Default: True.
80
+ init_token: A token that will be prepended to every example using this
81
+ field, or None for no initial token. Default: None.
82
+ eos_token: A token that will be appended to every example using this
83
+ field, or None for no end-of-sentence token. Default: None.
84
+ fix_length: A fixed length that all examples using this field will be
85
+ padded to, or None for flexible sequence lengths. Default: None.
86
+ dtype: The torch.dtype class that represents a batch of examples
87
+ of this kind of data. Default: torch.long.
88
+ preprocessing: The Pipeline that will be applied to examples
89
+ using this field after tokenizing but before numericalizing. Many
90
+ Datasets replace this attribute with a custom preprocessor.
91
+ Default: None.
92
+ postprocessing: A Pipeline that will be applied to examples using
93
+ this field after numericalizing but before the numbers are turned
94
+ into a Tensor. The pipeline function takes the batch as a list, and
95
+ the field's Vocab.
96
+ Default: None.
97
+ lower: Whether to lowercase the text in this field. Default: False.
98
+ tokenize: The function used to tokenize strings using this field into
99
+ sequential examples. If "spacy", the SpaCy tokenizer is
100
+ used. If a non-serializable function is passed as an argument,
101
+ the field will not be able to be serialized. Default: string.split.
102
+ tokenizer_language: The language of the tokenizer to be constructed.
103
+ Various languages currently supported only in SpaCy.
104
+ include_lengths: Whether to return a tuple of a padded minibatch and
105
+ a list containing the lengths of each examples, or just a padded
106
+ minibatch. Default: False.
107
+ batch_first: Whether to produce tensors with the batch dimension first.
108
+ Default: False.
109
+ pad_token: The string token used as padding. Default: "<pad>".
110
+ unk_token: The string token used to represent OOV words. Default: "<unk>".
111
+ pad_first: Do the padding of the sequence at the beginning. Default: False.
112
+ truncate_first: Do the truncating of the sequence at the beginning. Default: False
113
+ stop_words: Tokens to discard during the preprocessing step. Default: None
114
+ is_target: Whether this field is a target variable.
115
+ Affects iteration over batches. Default: False
116
+ """
117
+
118
+ vocab_cls = Vocab
119
+ # Dictionary mapping PyTorch tensor dtypes to the appropriate Python
120
+ # numeric type.
121
+ dtypes = {
122
+ torch.float32: float,
123
+ torch.float: float,
124
+ torch.float64: float,
125
+ torch.double: float,
126
+ torch.float16: float,
127
+ torch.half: float,
128
+
129
+ torch.uint8: int,
130
+ torch.int8: int,
131
+ torch.int16: int,
132
+ torch.short: int,
133
+ torch.int32: int,
134
+ torch.int: int,
135
+ torch.int64: int,
136
+ torch.long: int,
137
+ }
138
+
139
+ ignore = ['dtype', 'tokenize']
140
+
141
+ def __init__(self, sequential=True, use_vocab=True, init_token=None,
142
+ eos_token=None, fix_length=None, dtype=torch.long,
143
+ preprocessing=None, postprocessing=None, lower=False,
144
+ tokenize=None, tokenizer_language='en', include_lengths=False,
145
+ batch_first=False, pad_token="<pad>", unk_token="<unk>",
146
+ pad_first=False, truncate_first=False, stop_words=None,
147
+ is_target=False):
148
+ self.sequential = sequential
149
+ self.use_vocab = use_vocab
150
+ self.init_token = init_token
151
+ self.eos_token = eos_token
152
+ self.unk_token = unk_token
153
+ self.fix_length = fix_length
154
+ self.dtype = dtype
155
+ self.preprocessing = preprocessing
156
+ self.postprocessing = postprocessing
157
+ self.lower = lower
158
+ # store params to construct tokenizer for serialization
159
+ # in case the tokenizer isn't picklable (e.g. spacy)
160
+ self.tokenizer_args = (tokenize, tokenizer_language)
161
+ self.tokenize = get_tokenizer(tokenize, tokenizer_language)
162
+ self.include_lengths = include_lengths
163
+ self.batch_first = batch_first
164
+ self.pad_token = pad_token if self.sequential else None
165
+ self.pad_first = pad_first
166
+ self.truncate_first = truncate_first
167
+ try:
168
+ self.stop_words = set(stop_words) if stop_words is not None else None
169
+ except TypeError:
170
+ raise ValueError("Stop words must be convertible to a set")
171
+ self.is_target = is_target
172
+
173
+ def __getstate__(self):
174
+ str_type = dtype_to_attr(self.dtype)
175
+ if is_tokenizer_serializable(*self.tokenizer_args):
176
+ tokenize = self.tokenize
177
+ else:
178
+ # signal to restore in `__setstate__`
179
+ tokenize = None
180
+ attrs = {k: v for k, v in self.__dict__.items() if k not in self.ignore}
181
+ attrs['dtype'] = str_type
182
+ attrs['tokenize'] = tokenize
183
+
184
+ return attrs
185
+
186
+ def __setstate__(self, state):
187
+ state['dtype'] = getattr(torch, state['dtype'])
188
+ if not state['tokenize']:
189
+ state['tokenize'] = get_tokenizer(*state['tokenizer_args'])
190
+ self.__dict__.update(state)
191
+
192
+ def __hash__(self):
193
+ # we don't expect this to be called often
194
+ return 42
195
+
196
+ def __eq__(self, other):
197
+ if not isinstance(other, RawField):
198
+ return False
199
+
200
+ return self.__dict__ == other.__dict__
201
+
202
+ def preprocess(self, x):
203
+ """Load a single example using this field, tokenizing if necessary.
204
+
205
+ If the input is a Python 2 `str`, it will be converted to Unicode
206
+ first. If `sequential=True`, it will be tokenized. Then the input
207
+ will be optionally lowercased and passed to the user-provided
208
+ `preprocessing` Pipeline."""
209
+ if (six.PY2 and isinstance(x, six.string_types)
210
+ and not isinstance(x, six.text_type)):
211
+ x = Pipeline(lambda s: six.text_type(s, encoding='utf-8'))(x)
212
+ if self.sequential and isinstance(x, six.text_type):
213
+ x = self.tokenize(x.rstrip('\n'))
214
+ if self.lower:
215
+ x = Pipeline(six.text_type.lower)(x)
216
+ if self.sequential and self.use_vocab and self.stop_words is not None:
217
+ x = [w for w in x if w not in self.stop_words]
218
+ if hasattr(self, "preprocessing") and self.preprocessing is not None:
219
+ return self.preprocessing(x)
220
+ else:
221
+ return x
222
+
223
+ def process(self, batch, device=None):
224
+ """ Process a list of examples to create a torch.Tensor.
225
+
226
+ Pad, numericalize, and postprocess a batch and create a tensor.
227
+
228
+ Args:
229
+ batch (list(object)): A list of object from a batch of examples.
230
+ Returns:
231
+ torch.autograd.Variable: Processed object given the input
232
+ and custom postprocessing Pipeline.
233
+ """
234
+ padded = self.pad(batch)
235
+ tensor = self.numericalize(padded, device=device)
236
+ return tensor
237
+
238
+ def pad(self, minibatch):
239
+ """Pad a batch of examples using this field.
240
+
241
+ Pads to self.fix_length if provided, otherwise pads to the length of
242
+ the longest example in the batch. Prepends self.init_token and appends
243
+ self.eos_token if those attributes are not None. Returns a tuple of the
244
+ padded list and a list containing lengths of each example if
245
+ `self.include_lengths` is `True` and `self.sequential` is `True`, else just
246
+ returns the padded list. If `self.sequential` is `False`, no padding is applied.
247
+ """
248
+ minibatch = list(minibatch)
249
+ if not self.sequential:
250
+ return minibatch
251
+ if self.fix_length is None:
252
+ max_len = max(len(x) for x in minibatch)
253
+ else:
254
+ max_len = self.fix_length + (
255
+ self.init_token, self.eos_token).count(None) - 2
256
+ padded, lengths = [], []
257
+ for x in minibatch:
258
+ if self.pad_first:
259
+ padded.append(
260
+ [self.pad_token] * max(0, max_len - len(x))
261
+ + ([] if self.init_token is None else [self.init_token])
262
+ + list(x[-max_len:] if self.truncate_first else x[:max_len])
263
+ + ([] if self.eos_token is None else [self.eos_token]))
264
+ else:
265
+ padded.append(
266
+ ([] if self.init_token is None else [self.init_token])
267
+ + list(x[-max_len:] if self.truncate_first else x[:max_len])
268
+ + ([] if self.eos_token is None else [self.eos_token])
269
+ + [self.pad_token] * max(0, max_len - len(x)))
270
+ lengths.append(len(padded[-1]) - max(0, max_len - len(x)))
271
+ if self.include_lengths:
272
+ return (padded, lengths)
273
+ return padded
274
+
275
+ def build_vocab(self, *args, **kwargs):
276
+ """Construct the Vocab object for this field from one or more datasets.
277
+
278
+ Arguments:
279
+ Positional arguments: Dataset objects or other iterable data
280
+ sources from which to construct the Vocab object that
281
+ represents the set of possible values for this field. If
282
+ a Dataset object is provided, all columns corresponding
283
+ to this field are used; individual columns can also be
284
+ provided directly.
285
+ Remaining keyword arguments: Passed to the constructor of Vocab.
286
+ """
287
+ counter = Counter()
288
+ sources = []
289
+ for arg in args:
290
+ sources.append(arg)
291
+ for data in sources:
292
+ for x in data:
293
+ if not self.sequential:
294
+ x = [x]
295
+ try:
296
+ counter.update(x)
297
+ except TypeError:
298
+ counter.update(chain.from_iterable(x))
299
+ specials = list(OrderedDict.fromkeys(
300
+ tok for tok in [self.unk_token, self.pad_token, self.init_token,
301
+ self.eos_token] + kwargs.pop('specials', [])
302
+ if tok is not None))
303
+ self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
304
+
305
+ def numericalize(self, arr, device=None):
306
+ """Turn a batch of examples that use this field into a Variable.
307
+
308
+ If the field has include_lengths=True, a tensor of lengths will be
309
+ included in the return value.
310
+
311
+ Arguments:
312
+ arr (List[List[str]], or tuple of (List[List[str]], List[int])):
313
+ List of tokenized and padded examples, or tuple of List of
314
+ tokenized and padded examples and List of lengths of each
315
+ example if self.include_lengths is True.
316
+ device (str or torch.device): A string or instance of `torch.device`
317
+ specifying which device the Variables are going to be created on.
318
+ If left as default, the tensors will be created on cpu. Default: None.
319
+ """
320
+ if self.include_lengths and not isinstance(arr, tuple):
321
+ raise ValueError("Field has include_lengths set to True, but "
322
+ "input data is not a tuple of "
323
+ "(data batch, batch lengths).")
324
+ if isinstance(arr, tuple):
325
+ arr, lengths = arr
326
+ lengths = torch.tensor(lengths, dtype=self.dtype, device=device)
327
+
328
+ if self.use_vocab:
329
+ if self.sequential:
330
+ arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
331
+ else:
332
+ arr = [self.vocab.stoi[x] for x in arr]
333
+
334
+ if self.postprocessing is not None:
335
+ arr = self.postprocessing(arr, self.vocab)
336
+ else:
337
+ if self.dtype not in self.dtypes:
338
+ raise ValueError(
339
+ "Specified Field dtype {} can not be used with "
340
+ "use_vocab=False because we do not know how to numericalize it. "
341
+ "Please raise an issue at "
342
+ "https://github.com/pytorch/text/issues".format(self.dtype))
343
+ numericalization_func = self.dtypes[self.dtype]
344
+ # It doesn't make sense to explicitly coerce to a numeric type if
345
+ # the data is sequential, since it's unclear how to coerce padding tokens
346
+ # to a numeric type.
347
+ if not self.sequential:
348
+ arr = [numericalization_func(x) if isinstance(x, six.string_types)
349
+ else x for x in arr]
350
+ if self.postprocessing is not None:
351
+ arr = self.postprocessing(arr, None)
352
+
353
+ var = torch.tensor(arr, dtype=self.dtype, device=device)
354
+
355
+ if self.sequential and not self.batch_first:
356
+ var.t_()
357
+ if self.sequential:
358
+ var = var.contiguous()
359
+
360
+ if self.include_lengths:
361
+ return var, lengths
362
+ return var
363
+
364
+
365
+ class NestedField(Field):
366
+ """A nested field.
367
+
368
+ A nested field holds another field (called *nesting field*), accepts an untokenized
369
+ string or a list string tokens and groups and treats them as one field as described
370
+ by the nesting field. Every token will be preprocessed, padded, etc. in the manner
371
+ specified by the nesting field. Note that this means a nested field always has
372
+ ``sequential=True``. The two fields' vocabularies will be shared. Their
373
+ numericalization results will be stacked into a single tensor. And NestedField will
374
+ share the same include_lengths with nesting_field, so one shouldn't specify the
375
+ include_lengths in the nesting_field. This field is
376
+ primarily used to implement character embeddings. See ``tests/data/test_field.py``
377
+ for examples on how to use this field.
378
+
379
+ Arguments:
380
+ nesting_field (Field): A field contained in this nested field.
381
+ use_vocab (bool): Whether to use a Vocab object. If False, the data in this
382
+ field should already be numerical. Default: ``True``.
383
+ init_token (str): A token that will be prepended to every example using this
384
+ field, or None for no initial token. Default: ``None``.
385
+ eos_token (str): A token that will be appended to every example using this
386
+ field, or None for no end-of-sentence token. Default: ``None``.
387
+ fix_length (int): A fixed length that all examples using this field will be
388
+ padded to, or ``None`` for flexible sequence lengths. Default: ``None``.
389
+ dtype: The torch.dtype class that represents a batch of examples
390
+ of this kind of data. Default: ``torch.long``.
391
+ preprocessing (Pipeline): The Pipeline that will be applied to examples
392
+ using this field after tokenizing but before numericalizing. Many
393
+ Datasets replace this attribute with a custom preprocessor.
394
+ Default: ``None``.
395
+ postprocessing (Pipeline): A Pipeline that will be applied to examples using
396
+ this field after numericalizing but before the numbers are turned
397
+ into a Tensor. The pipeline function takes the batch as a list, and
398
+ the field's Vocab. Default: ``None``.
399
+ include_lengths: Whether to return a tuple of a padded minibatch and
400
+ a list containing the lengths of each examples, or just a padded
401
+ minibatch. Default: False.
402
+ tokenize: The function used to tokenize strings using this field into
403
+ sequential examples. If "spacy", the SpaCy tokenizer is
404
+ used. If a non-serializable function is passed as an argument,
405
+ the field will not be able to be serialized. Default: string.split.
406
+ tokenizer_language: The language of the tokenizer to be constructed.
407
+ Various languages currently supported only in SpaCy.
408
+ pad_token (str): The string token used as padding. If ``nesting_field`` is
409
+ sequential, this will be set to its ``pad_token``. Default: ``"<pad>"``.
410
+ pad_first (bool): Do the padding of the sequence at the beginning. Default:
411
+ ``False``.
412
+ """
413
+
414
+ def __init__(self, nesting_field, use_vocab=True, init_token=None, eos_token=None,
415
+ fix_length=None, dtype=torch.long, preprocessing=None,
416
+ postprocessing=None, tokenize=None, tokenizer_language='en',
417
+ include_lengths=False, pad_token='<pad>',
418
+ pad_first=False, truncate_first=False):
419
+ if isinstance(nesting_field, NestedField):
420
+ raise ValueError('nesting field must not be another NestedField')
421
+ if nesting_field.include_lengths:
422
+ raise ValueError('nesting field cannot have include_lengths=True')
423
+
424
+ if nesting_field.sequential:
425
+ pad_token = nesting_field.pad_token
426
+ super(NestedField, self).__init__(
427
+ use_vocab=use_vocab,
428
+ init_token=init_token,
429
+ eos_token=eos_token,
430
+ fix_length=fix_length,
431
+ dtype=dtype,
432
+ preprocessing=preprocessing,
433
+ postprocessing=postprocessing,
434
+ lower=nesting_field.lower,
435
+ tokenize=tokenize,
436
+ tokenizer_language=tokenizer_language,
437
+ batch_first=True,
438
+ pad_token=pad_token,
439
+ unk_token=nesting_field.unk_token,
440
+ pad_first=pad_first,
441
+ truncate_first=truncate_first,
442
+ include_lengths=include_lengths
443
+ )
444
+ self.nesting_field = nesting_field
445
+ # in case the user forget to do that
446
+ self.nesting_field.batch_first = True
447
+
448
+ def preprocess(self, xs):
449
+ """Preprocess a single example.
450
+
451
+ Firstly, tokenization and the supplied preprocessing pipeline is applied. Since
452
+ this field is always sequential, the result is a list. Then, each element of
453
+ the list is preprocessed using ``self.nesting_field.preprocess`` and the resulting
454
+ list is returned.
455
+
456
+ Arguments:
457
+ xs (list or str): The input to preprocess.
458
+
459
+ Returns:
460
+ list: The preprocessed list.
461
+ """
462
+ return [self.nesting_field.preprocess(x)
463
+ for x in super(NestedField, self).preprocess(xs)]
464
+
465
+ def pad(self, minibatch):
466
+ """Pad a batch of examples using this field.
467
+
468
+ If ``self.nesting_field.sequential`` is ``False``, each example in the batch must
469
+ be a list of string tokens, and pads them as if by a ``Field`` with
470
+ ``sequential=True``. Otherwise, each example must be a list of list of tokens.
471
+ Using ``self.nesting_field``, pads the list of tokens to
472
+ ``self.nesting_field.fix_length`` if provided, or otherwise to the length of the
473
+ longest list of tokens in the batch. Next, using this field, pads the result by
474
+ filling short examples with ``self.nesting_field.pad_token``.
475
+
476
+ Example:
477
+ >>> import pprint
478
+ >>> pp = pprint.PrettyPrinter(indent=4)
479
+ >>>
480
+ >>> nesting_field = Field(pad_token='<c>', init_token='<w>', eos_token='</w>')
481
+ >>> field = NestedField(nesting_field, init_token='<s>', eos_token='</s>')
482
+ >>> minibatch = [
483
+ ... [list('john'), list('loves'), list('mary')],
484
+ ... [list('mary'), list('cries')],
485
+ ... ]
486
+ >>> padded = field.pad(minibatch)
487
+ >>> pp.pprint(padded)
488
+ [ [ ['<w>', '<s>', '</w>', '<c>', '<c>', '<c>', '<c>'],
489
+ ['<w>', 'j', 'o', 'h', 'n', '</w>', '<c>'],
490
+ ['<w>', 'l', 'o', 'v', 'e', 's', '</w>'],
491
+ ['<w>', 'm', 'a', 'r', 'y', '</w>', '<c>'],
492
+ ['<w>', '</s>', '</w>', '<c>', '<c>', '<c>', '<c>']],
493
+ [ ['<w>', '<s>', '</w>', '<c>', '<c>', '<c>', '<c>'],
494
+ ['<w>', 'm', 'a', 'r', 'y', '</w>', '<c>'],
495
+ ['<w>', 'c', 'r', 'i', 'e', 's', '</w>'],
496
+ ['<w>', '</s>', '</w>', '<c>', '<c>', '<c>', '<c>'],
497
+ ['<c>', '<c>', '<c>', '<c>', '<c>', '<c>', '<c>']]]
498
+
499
+ Arguments:
500
+ minibatch (list): Each element is a list of string if
501
+ ``self.nesting_field.sequential`` is ``False``, a list of list of string
502
+ otherwise.
503
+
504
+ Returns:
505
+ list: The padded minibatch. or (padded, sentence_lens, word_lengths)
506
+ """
507
+ minibatch = list(minibatch)
508
+ if not self.nesting_field.sequential:
509
+ return super(NestedField, self).pad(minibatch)
510
+
511
+ # Save values of attributes to be monkeypatched
512
+ old_pad_token = self.pad_token
513
+ old_init_token = self.init_token
514
+ old_eos_token = self.eos_token
515
+ old_fix_len = self.nesting_field.fix_length
516
+ # Monkeypatch the attributes
517
+ if self.nesting_field.fix_length is None:
518
+ max_len = max(len(xs) for ex in minibatch for xs in ex)
519
+ fix_len = max_len + 2 - (self.nesting_field.init_token,
520
+ self.nesting_field.eos_token).count(None)
521
+ self.nesting_field.fix_length = fix_len
522
+ self.pad_token = [self.pad_token] * self.nesting_field.fix_length
523
+ if self.init_token is not None:
524
+ # self.init_token = self.nesting_field.pad([[self.init_token]])[0]
525
+ self.init_token = [self.init_token]
526
+ if self.eos_token is not None:
527
+ # self.eos_token = self.nesting_field.pad([[self.eos_token]])[0]
528
+ self.eos_token = [self.eos_token]
529
+ # Do padding
530
+ old_include_lengths = self.include_lengths
531
+ self.include_lengths = True
532
+ self.nesting_field.include_lengths = True
533
+ padded, sentence_lengths = super(NestedField, self).pad(minibatch)
534
+ padded_with_lengths = [self.nesting_field.pad(ex) for ex in padded]
535
+ word_lengths = []
536
+ final_padded = []
537
+ max_sen_len = len(padded[0])
538
+ for (pad, lens), sentence_len in zip(padded_with_lengths, sentence_lengths):
539
+ if sentence_len == max_sen_len:
540
+ lens = lens
541
+ pad = pad
542
+ elif self.pad_first:
543
+ lens[:(max_sen_len - sentence_len)] = (
544
+ [0] * (max_sen_len - sentence_len))
545
+ pad[:(max_sen_len - sentence_len)] = (
546
+ [self.pad_token] * (max_sen_len - sentence_len))
547
+ else:
548
+ lens[-(max_sen_len - sentence_len):] = (
549
+ [0] * (max_sen_len - sentence_len))
550
+ pad[-(max_sen_len - sentence_len):] = (
551
+ [self.pad_token] * (max_sen_len - sentence_len))
552
+ word_lengths.append(lens)
553
+ final_padded.append(pad)
554
+ padded = final_padded
555
+
556
+ # Restore monkeypatched attributes
557
+ self.nesting_field.fix_length = old_fix_len
558
+ self.pad_token = old_pad_token
559
+ self.init_token = old_init_token
560
+ self.eos_token = old_eos_token
561
+ self.include_lengths = old_include_lengths
562
+ if self.include_lengths:
563
+ return padded, sentence_lengths, word_lengths
564
+ return padded
565
+
566
+ def build_vocab(self, *args, **kwargs):
567
+ """Construct the Vocab object for nesting field and combine it with this field's vocab.
568
+
569
+ Arguments:
570
+ Positional arguments: Dataset objects or other iterable data
571
+ sources from which to construct the Vocab object that
572
+ represents the set of possible values for the nesting field. If
573
+ a Dataset object is provided, all columns corresponding
574
+ to this field are used; individual columns can also be
575
+ provided directly.
576
+ Remaining keyword arguments: Passed to the constructor of Vocab.
577
+ """
578
+ sources = []
579
+ for arg in args:
580
+ sources.append(arg)
581
+
582
+ flattened = []
583
+ for source in sources:
584
+ flattened.extend(source)
585
+ old_vectors = None
586
+ old_unk_init = None
587
+ old_vectors_cache = None
588
+ if "vectors" in kwargs.keys():
589
+ old_vectors = kwargs["vectors"]
590
+ kwargs["vectors"] = None
591
+ if "unk_init" in kwargs.keys():
592
+ old_unk_init = kwargs["unk_init"]
593
+ kwargs["unk_init"] = None
594
+ if "vectors_cache" in kwargs.keys():
595
+ old_vectors_cache = kwargs["vectors_cache"]
596
+ kwargs["vectors_cache"] = None
597
+ # just build vocab and does not load vector
598
+ self.nesting_field.build_vocab(*flattened, **kwargs)
599
+ super(NestedField, self).build_vocab()
600
+ self.vocab.extend(self.nesting_field.vocab)
601
+ self.vocab.freqs = self.nesting_field.vocab.freqs.copy()
602
+ if old_vectors is not None:
603
+ self.vocab.load_vectors(old_vectors,
604
+ unk_init=old_unk_init, cache=old_vectors_cache)
605
+
606
+ self.nesting_field.vocab = self.vocab
607
+
608
+ def numericalize(self, arrs, device=None):
609
+ """Convert a padded minibatch into a variable tensor.
610
+
611
+ Each item in the minibatch will be numericalized independently and the resulting
612
+ tensors will be stacked at the first dimension.
613
+
614
+ Arguments:
615
+ arr (List[List[str]]): List of tokenized and padded examples.
616
+ device (str or torch.device): A string or instance of `torch.device`
617
+ specifying which device the Variables are going to be created on.
618
+ If left as default, the tensors will be created on cpu. Default: None.
619
+ """
620
+ numericalized = []
621
+ self.nesting_field.include_lengths = False
622
+ if self.include_lengths:
623
+ arrs, sentence_lengths, word_lengths = arrs
624
+
625
+ for arr in arrs:
626
+ numericalized_ex = self.nesting_field.numericalize(
627
+ arr, device=device)
628
+ numericalized.append(numericalized_ex)
629
+ padded_batch = torch.stack(numericalized)
630
+
631
+ self.nesting_field.include_lengths = True
632
+ if self.include_lengths:
633
+ sentence_lengths = \
634
+ torch.tensor(sentence_lengths, dtype=self.dtype, device=device)
635
+ word_lengths = torch.tensor(word_lengths, dtype=self.dtype, device=device)
636
+ return (padded_batch, sentence_lengths, word_lengths)
637
+ return padded_batch
data/field/mini_torchtext/pipeline.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Pipeline(object):
2
+ """Defines a pipeline for transforming sequence data.
3
+
4
+ The input is assumed to be utf-8 encoded `str` (Python 3) or
5
+ `unicode` (Python 2).
6
+
7
+ Attributes:
8
+ convert_token: The function to apply to input sequence data.
9
+ pipes: The Pipelines that will be applied to input sequence
10
+ data in order.
11
+ """
12
+
13
+ def __init__(self, convert_token=None):
14
+ """Create a pipeline.
15
+
16
+ Arguments:
17
+ convert_token: The function to apply to input sequence data.
18
+ If None, the identity function is used. Default: None
19
+ """
20
+ if convert_token is None:
21
+ self.convert_token = Pipeline.identity
22
+ elif callable(convert_token):
23
+ self.convert_token = convert_token
24
+ else:
25
+ raise ValueError("Pipeline input convert_token {} is not None "
26
+ "or callable".format(convert_token))
27
+ self.pipes = [self]
28
+
29
+ def __call__(self, x, *args):
30
+ """Apply the the current Pipeline(s) to an input.
31
+
32
+ Arguments:
33
+ x: The input to process with the Pipeline(s).
34
+ Positional arguments: Forwarded to the `call` function
35
+ of the Pipeline(s).
36
+ """
37
+ for pipe in self.pipes:
38
+ x = pipe.call(x, *args)
39
+ return x
40
+
41
+ def call(self, x, *args):
42
+ """Apply _only_ the convert_token function of the current pipeline
43
+ to the input. If the input is a list, a list with the results of
44
+ applying the `convert_token` function to all input elements is
45
+ returned.
46
+
47
+ Arguments:
48
+ x: The input to apply the convert_token function to.
49
+ Positional arguments: Forwarded to the `convert_token` function
50
+ of the current Pipeline.
51
+ """
52
+ if isinstance(x, list):
53
+ return [self.convert_token(tok, *args) for tok in x]
54
+ return self.convert_token(x, *args)
55
+
56
+ def add_before(self, pipeline):
57
+ """Add a Pipeline to be applied before this processing pipeline.
58
+
59
+ Arguments:
60
+ pipeline: The Pipeline or callable to apply before this
61
+ Pipeline.
62
+ """
63
+ if not isinstance(pipeline, Pipeline):
64
+ pipeline = Pipeline(pipeline)
65
+ self.pipes = pipeline.pipes[:] + self.pipes[:]
66
+ return self
67
+
68
+ def add_after(self, pipeline):
69
+ """Add a Pipeline to be applied after this processing pipeline.
70
+
71
+ Arguments:
72
+ pipeline: The Pipeline or callable to apply after this
73
+ Pipeline.
74
+ """
75
+ if not isinstance(pipeline, Pipeline):
76
+ pipeline = Pipeline(pipeline)
77
+ self.pipes = self.pipes[:] + pipeline.pipes[:]
78
+ return self
79
+
80
+ @staticmethod
81
+ def identity(x):
82
+ """Return a copy of the input.
83
+
84
+ This is here for serialization compatibility with pickle.
85
+ """
86
+ return x
data/field/mini_torchtext/utils.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from contextlib import contextmanager
3
+ from copy import deepcopy
4
+ import re
5
+
6
+ from functools import partial
7
+
8
+
9
+ def _split_tokenizer(x):
10
+ return x.split()
11
+
12
+
13
+ def _spacy_tokenize(x, spacy):
14
+ return [tok.text for tok in spacy.tokenizer(x)]
15
+
16
+
17
+ _patterns = [r'\'',
18
+ r'\"',
19
+ r'\.',
20
+ r'<br \/>',
21
+ r',',
22
+ r'\(',
23
+ r'\)',
24
+ r'\!',
25
+ r'\?',
26
+ r'\;',
27
+ r'\:',
28
+ r'\s+']
29
+
30
+ _replacements = [' \' ',
31
+ '',
32
+ ' . ',
33
+ ' ',
34
+ ' , ',
35
+ ' ( ',
36
+ ' ) ',
37
+ ' ! ',
38
+ ' ? ',
39
+ ' ',
40
+ ' ',
41
+ ' ']
42
+
43
+ _patterns_dict = list((re.compile(p), r) for p, r in zip(_patterns, _replacements))
44
+
45
+
46
+ def _basic_english_normalize(line):
47
+ r"""
48
+ Basic normalization for a line of text.
49
+ Normalization includes
50
+ - lowercasing
51
+ - complete some basic text normalization for English words as follows:
52
+ add spaces before and after '\''
53
+ remove '\"',
54
+ add spaces before and after '.'
55
+ replace '<br \/>'with single space
56
+ add spaces before and after ','
57
+ add spaces before and after '('
58
+ add spaces before and after ')'
59
+ add spaces before and after '!'
60
+ add spaces before and after '?'
61
+ replace ';' with single space
62
+ replace ':' with single space
63
+ replace multiple spaces with single space
64
+
65
+ Returns a list of tokens after splitting on whitespace.
66
+ """
67
+
68
+ line = line.lower()
69
+ for pattern_re, replaced_str in _patterns_dict:
70
+ line = pattern_re.sub(replaced_str, line)
71
+ return line.split()
72
+
73
+
74
+ def get_tokenizer(tokenizer, language='en'):
75
+ r"""
76
+ Generate tokenizer function for a string sentence.
77
+
78
+ Arguments:
79
+ tokenizer: the name of tokenizer function. If None, it returns split()
80
+ function, which splits the string sentence by space.
81
+ If basic_english, it returns _basic_english_normalize() function,
82
+ which normalize the string first and split by space. If a callable
83
+ function, it will return the function. If a tokenizer library
84
+ (e.g. spacy, moses, toktok, revtok, subword), it returns the
85
+ corresponding library.
86
+ language: Default en
87
+
88
+ Examples:
89
+ >>> import torchtext
90
+ >>> from torchtext.data import get_tokenizer
91
+ >>> tokenizer = get_tokenizer("basic_english")
92
+ >>> tokens = tokenizer("You can now install TorchText using pip!")
93
+ >>> tokens
94
+ >>> ['you', 'can', 'now', 'install', 'torchtext', 'using', 'pip', '!']
95
+
96
+ """
97
+
98
+ # default tokenizer is string.split(), added as a module function for serialization
99
+ if tokenizer is None:
100
+ return _split_tokenizer
101
+
102
+ if tokenizer == "basic_english":
103
+ if language != 'en':
104
+ raise ValueError("Basic normalization is only available for Enlish(en)")
105
+ return _basic_english_normalize
106
+
107
+ # simply return if a function is passed
108
+ if callable(tokenizer):
109
+ return tokenizer
110
+
111
+ if tokenizer == "spacy":
112
+ try:
113
+ import spacy
114
+ spacy = spacy.load(language)
115
+ return partial(_spacy_tokenize, spacy=spacy)
116
+ except ImportError:
117
+ print("Please install SpaCy. "
118
+ "See the docs at https://spacy.io for more information.")
119
+ raise
120
+ except AttributeError:
121
+ print("Please install SpaCy and the SpaCy {} tokenizer. "
122
+ "See the docs at https://spacy.io for more "
123
+ "information.".format(language))
124
+ raise
125
+ elif tokenizer == "moses":
126
+ try:
127
+ from sacremoses import MosesTokenizer
128
+ moses_tokenizer = MosesTokenizer()
129
+ return moses_tokenizer.tokenize
130
+ except ImportError:
131
+ print("Please install SacreMoses. "
132
+ "See the docs at https://github.com/alvations/sacremoses "
133
+ "for more information.")
134
+ raise
135
+ elif tokenizer == "toktok":
136
+ try:
137
+ from nltk.tokenize.toktok import ToktokTokenizer
138
+ toktok = ToktokTokenizer()
139
+ return toktok.tokenize
140
+ except ImportError:
141
+ print("Please install NLTK. "
142
+ "See the docs at https://nltk.org for more information.")
143
+ raise
144
+ elif tokenizer == 'revtok':
145
+ try:
146
+ import revtok
147
+ return revtok.tokenize
148
+ except ImportError:
149
+ print("Please install revtok.")
150
+ raise
151
+ elif tokenizer == 'subword':
152
+ try:
153
+ import revtok
154
+ return partial(revtok.tokenize, decap=True)
155
+ except ImportError:
156
+ print("Please install revtok.")
157
+ raise
158
+ raise ValueError("Requested tokenizer {}, valid choices are a "
159
+ "callable that takes a single string as input, "
160
+ "\"revtok\" for the revtok reversible tokenizer, "
161
+ "\"subword\" for the revtok caps-aware tokenizer, "
162
+ "\"spacy\" for the SpaCy English tokenizer, or "
163
+ "\"moses\" for the NLTK port of the Moses tokenization "
164
+ "script.".format(tokenizer))
165
+
166
+
167
+ def is_tokenizer_serializable(tokenizer, language):
168
+ """Extend with other tokenizers which are found to not be serializable
169
+ """
170
+ if tokenizer == 'spacy':
171
+ return False
172
+ return True
173
+
174
+
175
+ def interleave_keys(a, b):
176
+ """Interleave bits from two sort keys to form a joint sort key.
177
+
178
+ Examples that are similar in both of the provided keys will have similar
179
+ values for the key defined by this function. Useful for tasks with two
180
+ text fields like machine translation or natural language inference.
181
+ """
182
+ def interleave(args):
183
+ return ''.join([x for t in zip(*args) for x in t])
184
+ return int(''.join(interleave(format(x, '016b') for x in (a, b))), base=2)
185
+
186
+
187
+ def get_torch_version():
188
+ import torch
189
+ v = torch.__version__
190
+ version_substrings = v.split('.')
191
+ major, minor = version_substrings[0], version_substrings[1]
192
+ return int(major), int(minor)
193
+
194
+
195
+ def dtype_to_attr(dtype):
196
+ # convert torch.dtype to dtype string id
197
+ # e.g. torch.int32 -> "int32"
198
+ # used for serialization
199
+ _, dtype = str(dtype).split('.')
200
+ return dtype
201
+
202
+
203
+ # TODO: Write more tests!
204
+ def ngrams_iterator(token_list, ngrams):
205
+ """Return an iterator that yields the given tokens and their ngrams.
206
+
207
+ Arguments:
208
+ token_list: A list of tokens
209
+ ngrams: the number of ngrams.
210
+
211
+ Examples:
212
+ >>> token_list = ['here', 'we', 'are']
213
+ >>> list(ngrams_iterator(token_list, 2))
214
+ >>> ['here', 'here we', 'we', 'we are', 'are']
215
+ """
216
+
217
+ def _get_ngrams(n):
218
+ return zip(*[token_list[i:] for i in range(n)])
219
+
220
+ for x in token_list:
221
+ yield x
222
+ for n in range(2, ngrams + 1):
223
+ for x in _get_ngrams(n):
224
+ yield ' '.join(x)
225
+
226
+
227
+ class RandomShuffler(object):
228
+ """Use random functions while keeping track of the random state to make it
229
+ reproducible and deterministic."""
230
+
231
+ def __init__(self, random_state=None):
232
+ self._random_state = random_state
233
+ if self._random_state is None:
234
+ self._random_state = random.getstate()
235
+
236
+ @contextmanager
237
+ def use_internal_state(self):
238
+ """Use a specific RNG state."""
239
+ old_state = random.getstate()
240
+ random.setstate(self._random_state)
241
+ yield
242
+ self._random_state = random.getstate()
243
+ random.setstate(old_state)
244
+
245
+ @property
246
+ def random_state(self):
247
+ return deepcopy(self._random_state)
248
+
249
+ @random_state.setter
250
+ def random_state(self, s):
251
+ self._random_state = s
252
+
253
+ def __call__(self, data):
254
+ """Shuffle and return a new list."""
255
+ with self.use_internal_state():
256
+ return random.sample(data, len(data))
data/field/mini_torchtext/vocab.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import unicode_literals
2
+ from collections import defaultdict
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+
8
+ class Vocab(object):
9
+ """Defines a vocabulary object that will be used to numericalize a field.
10
+
11
+ Attributes:
12
+ freqs: A collections.Counter object holding the frequencies of tokens
13
+ in the data used to build the Vocab.
14
+ stoi: A collections.defaultdict instance mapping token strings to
15
+ numerical identifiers.
16
+ itos: A list of token strings indexed by their numerical identifiers.
17
+ """
18
+
19
+ # TODO (@mttk): Populate classs with default values of special symbols
20
+ UNK = '<unk>'
21
+
22
+ def __init__(self, counter, max_size=None, min_freq=1, specials=['<unk>', '<pad>'], specials_first=True):
23
+ """Create a Vocab object from a collections.Counter.
24
+
25
+ Arguments:
26
+ counter: collections.Counter object holding the frequencies of
27
+ each value found in the data.
28
+ max_size: The maximum size of the vocabulary, or None for no
29
+ maximum. Default: None.
30
+ min_freq: The minimum frequency needed to include a token in the
31
+ vocabulary. Values less than 1 will be set to 1. Default: 1.
32
+ specials: The list of special tokens (e.g., padding or eos) that
33
+ will be prepended to the vocabulary. Default: ['<unk'>, '<pad>']
34
+ specials_first: Whether to add special tokens into the vocabulary at first.
35
+ If it is False, they are added into the vocabulary at last.
36
+ Default: True.
37
+ """
38
+ self.freqs = counter
39
+ counter = counter.copy()
40
+ min_freq = max(min_freq, 1)
41
+
42
+ self.itos = list()
43
+ self.unk_index = None
44
+ if specials_first:
45
+ self.itos = list(specials)
46
+ # only extend max size if specials are prepended
47
+ max_size = None if max_size is None else max_size + len(specials)
48
+
49
+ # frequencies of special tokens are not counted when building vocabulary
50
+ # in frequency order
51
+ for tok in specials:
52
+ del counter[tok]
53
+
54
+ # sort by frequency, then alphabetically
55
+ words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
56
+ words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
57
+
58
+ for word, freq in words_and_frequencies:
59
+ if freq < min_freq or len(self.itos) == max_size:
60
+ break
61
+ self.itos.append(word)
62
+
63
+ if Vocab.UNK in specials: # hard-coded for now
64
+ unk_index = specials.index(Vocab.UNK) # position in list
65
+ # account for ordering of specials, set variable
66
+ self.unk_index = unk_index if specials_first else len(self.itos) + unk_index
67
+ self.stoi = defaultdict(self._default_unk_index)
68
+ else:
69
+ self.stoi = defaultdict()
70
+
71
+ if not specials_first:
72
+ self.itos.extend(list(specials))
73
+
74
+ # stoi is simply a reverse dict for itos
75
+ self.stoi.update({tok: i for i, tok in enumerate(self.itos)})
76
+
77
+ def _default_unk_index(self):
78
+ return self.unk_index
79
+
80
+ def __getitem__(self, token):
81
+ return self.stoi.get(token, self.stoi.get(Vocab.UNK))
82
+
83
+ def __getstate__(self):
84
+ # avoid picking defaultdict
85
+ attrs = dict(self.__dict__)
86
+ # cast to regular dict
87
+ attrs['stoi'] = dict(self.stoi)
88
+ return attrs
89
+
90
+ def __setstate__(self, state):
91
+ if state.get("unk_index", None) is None:
92
+ stoi = defaultdict()
93
+ else:
94
+ stoi = defaultdict(self._default_unk_index)
95
+ stoi.update(state['stoi'])
96
+ state['stoi'] = stoi
97
+ self.__dict__.update(state)
98
+
99
+ def __eq__(self, other):
100
+ if self.freqs != other.freqs:
101
+ return False
102
+ if self.stoi != other.stoi:
103
+ return False
104
+ if self.itos != other.itos:
105
+ return False
106
+ return True
107
+
108
+ def __len__(self):
109
+ return len(self.itos)
110
+
111
+ def extend(self, v, sort=False):
112
+ words = sorted(v.itos) if sort else v.itos
113
+ for w in words:
114
+ if w not in self.stoi:
115
+ self.itos.append(w)
116
+ self.stoi[w] = len(self.itos) - 1
data/field/nested_field.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ from data.field.mini_torchtext.field import NestedField as TorchTextNestedField
6
+
7
+
8
+ class NestedField(TorchTextNestedField):
9
+ def pad(self, example):
10
+ self.nesting_field.include_lengths = self.include_lengths
11
+ if not self.include_lengths:
12
+ return self.nesting_field.pad(example)
13
+
14
+ sentence_length = len(example)
15
+ example, word_lengths = self.nesting_field.pad(example)
16
+ return example, sentence_length, word_lengths
17
+
18
+ def numericalize(self, arr, device=None):
19
+ numericalized = []
20
+ self.nesting_field.include_lengths = False
21
+ if self.include_lengths:
22
+ arr, sentence_length, word_lengths = arr
23
+
24
+ numericalized = self.nesting_field.numericalize(arr, device=device)
25
+
26
+ self.nesting_field.include_lengths = True
27
+ if self.include_lengths:
28
+ sentence_length = torch.tensor(sentence_length, dtype=self.dtype, device=device)
29
+ word_lengths = torch.tensor(word_lengths, dtype=self.dtype, device=device)
30
+ return (numericalized, sentence_length, word_lengths)
31
+ return numericalized
32
+
33
+ def build_vocab(self, *args, **kwargs):
34
+ sources = []
35
+ for arg in args:
36
+ if isinstance(arg, torch.utils.data.Dataset):
37
+ sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
38
+ else:
39
+ sources.append(arg)
40
+
41
+ flattened = []
42
+ for source in sources:
43
+ flattened.extend(source)
44
+
45
+ # just build vocab and does not load vector
46
+ self.nesting_field.build_vocab(*flattened, **kwargs)
47
+ super(TorchTextNestedField, self).build_vocab()
48
+ self.vocab.extend(self.nesting_field.vocab)
49
+ self.vocab.freqs = self.nesting_field.vocab.freqs.copy()
50
+ self.nesting_field.vocab = self.vocab
data/parser/__init__.py ADDED
File without changes
data/parser/from_mrp/__init__.py ADDED
File without changes
data/parser/from_mrp/abstract_parser.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ from data.parser.json_parser import example_from_json
6
+
7
+
8
+ class AbstractParser(torch.utils.data.Dataset):
9
+ def __init__(self, fields, data, filter_pred=None):
10
+ super(AbstractParser, self).__init__()
11
+
12
+ self.examples = [example_from_json(d, fields) for _, d in sorted(data.items())]
13
+
14
+ if isinstance(fields, dict):
15
+ fields, field_dict = [], fields
16
+ for field in field_dict.values():
17
+ if isinstance(field, list):
18
+ fields.extend(field)
19
+ else:
20
+ fields.append(field)
21
+
22
+ if filter_pred is not None:
23
+ make_list = isinstance(self.examples, list)
24
+ self.examples = filter(filter_pred, self.examples)
25
+ if make_list:
26
+ self.examples = list(self.examples)
27
+
28
+ self.fields = dict(fields)
29
+
30
+ # Unpack field tuples
31
+ for n, f in list(self.fields.items()):
32
+ if isinstance(n, tuple):
33
+ self.fields.update(zip(n, f))
34
+ del self.fields[n]
35
+
36
+ def __getitem__(self, i):
37
+ item = self.examples[i]
38
+ processed_item = {}
39
+ for (name, field) in self.fields.items():
40
+ if field is not None:
41
+ processed_item[name] = field.process(getattr(item, name), device=None)
42
+ return processed_item
43
+
44
+ def __len__(self):
45
+ return len(self.examples)
46
+
47
+ def get_examples(self, attr):
48
+ if attr in self.fields:
49
+ for x in self.examples:
50
+ yield getattr(x, attr)
data/parser/from_mrp/evaluation_parser.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ from data.parser.from_mrp.abstract_parser import AbstractParser
5
+ import utility.parser_utils as utils
6
+
7
+
8
+ class EvaluationParser(AbstractParser):
9
+ def __init__(self, args, fields):
10
+ path = args.test_data
11
+ self.data = utils.load_dataset(path)
12
+
13
+ for sentence in self.data.values():
14
+ sentence["token anchors"] = [[a["from"], a["to"]] for a in sentence["token anchors"]]
15
+
16
+ utils.create_bert_tokens(self.data, args.encoder)
17
+
18
+ super(EvaluationParser, self).__init__(fields, self.data)
data/parser/from_mrp/labeled_edge_parser.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ from data.parser.from_mrp.abstract_parser import AbstractParser
5
+ import utility.parser_utils as utils
6
+
7
+
8
+ class LabeledEdgeParser(AbstractParser):
9
+ def __init__(self, args, part: str, fields, filter_pred=None, **kwargs):
10
+ assert part == "training" or part == "validation"
11
+ path = args.training_data if part == "training" else args.validation_data
12
+
13
+ self.data = utils.load_dataset(path)
14
+ utils.anchor_ids_from_intervals(self.data)
15
+
16
+ self.node_counter, self.edge_counter, self.no_edge_counter = 0, 0, 0
17
+ anchor_count, n_node_token_pairs = 0, 0
18
+
19
+ for sentence_id, sentence in list(self.data.items()):
20
+ for edge in sentence["edges"]:
21
+ if "label" not in edge:
22
+ del self.data[sentence_id]
23
+ break
24
+
25
+ for node, sentence in utils.node_generator(self.data):
26
+ node["label"] = "Node"
27
+
28
+ self.node_counter += 1
29
+
30
+ utils.create_bert_tokens(self.data, args.encoder)
31
+
32
+ # create edge vectors
33
+ for sentence in self.data.values():
34
+ assert sentence["tops"] == [0], sentence
35
+ N = len(sentence["nodes"])
36
+
37
+ edge_count = utils.create_edges(sentence)
38
+ self.edge_counter += edge_count
39
+ self.no_edge_counter += N * (N - 1) - edge_count
40
+
41
+ sentence["nodes"] = sentence["nodes"][1:]
42
+ N = len(sentence["nodes"])
43
+
44
+ sentence["anchor edges"] = [N, len(sentence["input"]), []]
45
+ sentence["source anchor edges"] = [N, len(sentence["input"]), []] # dummy
46
+ sentence["target anchor edges"] = [N, len(sentence["input"]), []] # dummy
47
+ sentence["anchored labels"] = [len(sentence["input"]), []]
48
+ for i, node in enumerate(sentence["nodes"]):
49
+ anchored_labels = []
50
+
51
+ for anchor in node["anchors"]:
52
+ sentence["anchor edges"][-1].append((i, anchor))
53
+ anchored_labels.append((anchor, node["label"]))
54
+
55
+ sentence["anchored labels"][1].append(anchored_labels)
56
+
57
+ anchor_count += len(node["anchors"])
58
+ n_node_token_pairs += len(sentence["input"])
59
+
60
+ sentence["id"] = [sentence["id"]]
61
+
62
+ self.anchor_freq = anchor_count / n_node_token_pairs
63
+ self.source_anchor_freq = self.target_anchor_freq = 0.5 # dummy
64
+ self.input_count = sum(len(sentence["input"]) for sentence in self.data.values())
65
+
66
+ super(LabeledEdgeParser, self).__init__(fields, self.data, filter_pred)
67
+
68
+ @staticmethod
69
+ def node_similarity_key(node):
70
+ return tuple([node["label"]] + node["anchors"])
data/parser/from_mrp/node_centric_parser.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ from data.parser.from_mrp.abstract_parser import AbstractParser
5
+ import utility.parser_utils as utils
6
+
7
+
8
+ class NodeCentricParser(AbstractParser):
9
+ def __init__(self, args, part: str, fields, filter_pred=None, **kwargs):
10
+ assert part == "training" or part == "validation"
11
+ path = args.training_data if part == "training" else args.validation_data
12
+
13
+ self.data = utils.load_dataset(path)
14
+ utils.anchor_ids_from_intervals(self.data)
15
+
16
+ self.node_counter, self.edge_counter, self.no_edge_counter = 0, 0, 0
17
+ anchor_count, n_node_token_pairs = 0, 0
18
+
19
+ for sentence_id, sentence in list(self.data.items()):
20
+ for node in sentence["nodes"]:
21
+ if "label" not in node:
22
+ del self.data[sentence_id]
23
+ break
24
+
25
+ for node, _ in utils.node_generator(self.data):
26
+ self.node_counter += 1
27
+
28
+ # print(f"Number of unlabeled nodes: {unlabeled_count}", flush=True)
29
+
30
+ utils.create_bert_tokens(self.data, args.encoder)
31
+
32
+ # create edge vectors
33
+ for sentence in self.data.values():
34
+ N = len(sentence["nodes"])
35
+
36
+ edge_count = utils.create_edges(sentence)
37
+ self.edge_counter += edge_count
38
+ # self.no_edge_counter += len([n for n in sentence["nodes"] if n["label"] in ["Source", "Target"]]) * len([n for n in sentence["nodes"] if n["label"] not in ["Source", "Target"]]) - edge_count
39
+ self.no_edge_counter += N * (N - 1) - edge_count
40
+
41
+ sentence["anchor edges"] = [N, len(sentence["input"]), []]
42
+ sentence["source anchor edges"] = [N, len(sentence["input"]), []] # dummy
43
+ sentence["target anchor edges"] = [N, len(sentence["input"]), []] # dummy
44
+ sentence["anchored labels"] = [len(sentence["input"]), []]
45
+ for i, node in enumerate(sentence["nodes"]):
46
+ anchored_labels = []
47
+ #if len(node["anchors"]) == 0:
48
+ # print(f"Empty node in {sentence['id']}", flush=True)
49
+
50
+ for anchor in node["anchors"]:
51
+ sentence["anchor edges"][-1].append((i, anchor))
52
+ anchored_labels.append((anchor, node["label"]))
53
+
54
+ sentence["anchored labels"][1].append(anchored_labels)
55
+
56
+ anchor_count += len(node["anchors"])
57
+ n_node_token_pairs += len(sentence["input"])
58
+
59
+ sentence["id"] = [sentence["id"]]
60
+
61
+ self.anchor_freq = anchor_count / n_node_token_pairs
62
+ self.source_anchor_freq = self.target_anchor_freq = 0.5 # dummy
63
+ self.input_count = sum(len(sentence["input"]) for sentence in self.data.values())
64
+
65
+ super(NodeCentricParser, self).__init__(fields, self.data, filter_pred)
66
+
67
+ @staticmethod
68
+ def node_similarity_key(node):
69
+ return tuple([node["label"]] + node["anchors"])
data/parser/from_mrp/request_parser.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import utility.parser_utils as utils
5
+ from data.parser.from_mrp.abstract_parser import AbstractParser
6
+
7
+
8
+ class RequestParser(AbstractParser):
9
+ def __init__(self, sentences, args, fields):
10
+ self.data = {i: {"id": str(i), "sentence": sentence} for i, sentence in enumerate(sentences)}
11
+
12
+ sentences = [example["sentence"] for example in self.data.values()]
13
+
14
+ for example in self.data.values():
15
+ example["input"] = example["sentence"].strip().split(' ')
16
+ example["token anchors"], offset = [], 0
17
+ for token in example["input"]:
18
+ example["token anchors"].append([offset, offset + len(token)])
19
+ offset += len(token) + 1
20
+
21
+ utils.create_bert_tokens(self.data, args.encoder)
22
+
23
+ super(RequestParser, self).__init__(fields, self.data)
data/parser/from_mrp/sequential_parser.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ from data.parser.from_mrp.abstract_parser import AbstractParser
5
+ import utility.parser_utils as utils
6
+
7
+
8
+ class SequentialParser(AbstractParser):
9
+ def __init__(self, args, part: str, fields, filter_pred=None, **kwargs):
10
+ assert part == "training" or part == "validation"
11
+ path = args.training_data if part == "training" else args.validation_data
12
+
13
+ self.data = utils.load_dataset(path)
14
+ utils.anchor_ids_from_intervals(self.data)
15
+
16
+ self.node_counter, self.edge_counter, self.no_edge_counter = 0, 0, 0
17
+ anchor_count, source_anchor_count, target_anchor_count, n_node_token_pairs = 0, 0, 0, 0
18
+
19
+ for sentence_id, sentence in list(self.data.items()):
20
+ for node in sentence["nodes"]:
21
+ if "label" not in node:
22
+ del self.data[sentence_id]
23
+ break
24
+
25
+ for node, _ in utils.node_generator(self.data):
26
+ node["target anchors"] = []
27
+ node["source anchors"] = []
28
+
29
+ for sentence in self.data.values():
30
+ for e in sentence["edges"]:
31
+ source, target = e["source"], e["target"]
32
+
33
+ if sentence["nodes"][target]["label"] == "Target":
34
+ sentence["nodes"][source]["target anchors"] += sentence["nodes"][target]["anchors"]
35
+ elif sentence["nodes"][target]["label"] == "Source":
36
+ sentence["nodes"][source]["source anchors"] += sentence["nodes"][target]["anchors"]
37
+
38
+ for i, node in list(enumerate(sentence["nodes"]))[::-1]:
39
+ if "label" not in node or node["label"] in ["Source", "Target"]:
40
+ del sentence["nodes"][i]
41
+ sentence["edges"] = []
42
+
43
+ for node, sentence in utils.node_generator(self.data):
44
+ self.node_counter += 1
45
+
46
+ utils.create_bert_tokens(self.data, args.encoder)
47
+
48
+ # create edge vectors
49
+ for sentence in self.data.values():
50
+ N = len(sentence["nodes"])
51
+
52
+ utils.create_edges(sentence)
53
+ self.no_edge_counter += N * (N - 1)
54
+
55
+ sentence["anchor edges"] = [N, len(sentence["input"]), []]
56
+ sentence["source anchor edges"] = [N, len(sentence["input"]), []]
57
+ sentence["target anchor edges"] = [N, len(sentence["input"]), []]
58
+
59
+ sentence["anchored labels"] = [len(sentence["input"]), []]
60
+ for i, node in enumerate(sentence["nodes"]):
61
+ anchored_labels = []
62
+
63
+ for anchor in node["anchors"]:
64
+ sentence["anchor edges"][-1].append((i, anchor))
65
+ anchored_labels.append((anchor, node["label"]))
66
+
67
+ for anchor in node["source anchors"]:
68
+ sentence["source anchor edges"][-1].append((i, anchor))
69
+ for anchor in node["target anchors"]:
70
+ sentence["target anchor edges"][-1].append((i, anchor))
71
+
72
+ sentence["anchored labels"][1].append(anchored_labels)
73
+
74
+ anchor_count += len(node["anchors"])
75
+ source_anchor_count += len(node["source anchors"])
76
+ target_anchor_count += len(node["target anchors"])
77
+ n_node_token_pairs += len(sentence["input"])
78
+
79
+ sentence["id"] = [sentence["id"]]
80
+
81
+ self.anchor_freq = anchor_count / n_node_token_pairs
82
+ self.source_anchor_freq = anchor_count / n_node_token_pairs
83
+ self.target_anchor_freq = anchor_count / n_node_token_pairs
84
+ self.input_count = sum(len(sentence["input"]) for sentence in self.data.values())
85
+
86
+ super(SequentialParser, self).__init__(fields, self.data, filter_pred)
87
+
88
+ @staticmethod
89
+ def node_similarity_key(node):
90
+ return tuple([node["label"]] + node["anchors"])
data/parser/json_parser.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import reduce
2
+ from data.field.mini_torchtext.example import Example
3
+
4
+
5
+ def example_from_json(obj, fields):
6
+ ex = Example()
7
+ for key, vals in fields.items():
8
+ if vals is not None:
9
+ if not isinstance(vals, list):
10
+ vals = [vals]
11
+ for val in vals:
12
+ # for processing the key likes 'foo.bar'
13
+ name, field = val
14
+ ks = key.split(".")
15
+
16
+ def reducer(obj, key):
17
+ if isinstance(obj, list):
18
+ results = []
19
+ for data in obj:
20
+ if key not in data:
21
+ # key error
22
+ raise ValueError("Specified key {} was not found in " "the input data".format(key))
23
+ else:
24
+ results.append(data[key])
25
+ return results
26
+ else:
27
+ # key error
28
+ if key not in obj:
29
+ raise ValueError("Specified key {} was not found in " "the input data".format(key))
30
+ else:
31
+ return obj[key]
32
+
33
+ v = reduce(reducer, ks, obj)
34
+ setattr(ex, name, field.preprocess(v))
35
+ return ex
data/parser/to_mrp/__init__.py ADDED
File without changes
data/parser/to_mrp/abstract_parser.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ class AbstractParser:
5
+ def __init__(self, dataset):
6
+ self.dataset = dataset
7
+
8
+ def create_nodes(self, prediction):
9
+ return [
10
+ {"id": i, "label": self.label_to_str(l, prediction["anchors"][i], prediction)}
11
+ for i, l in enumerate(prediction["labels"])
12
+ ]
13
+
14
+ def label_to_str(self, label, anchors, prediction):
15
+ return self.dataset.label_field.vocab.itos[label - 1]
16
+
17
+ def create_edges(self, prediction, nodes):
18
+ N = len(nodes)
19
+ node_sets = [{"id": n, "set": set([n])} for n in range(N)]
20
+ _, indices = prediction["edge presence"][:N, :N].reshape(-1).sort(descending=True)
21
+ sources, targets = indices // N, indices % N
22
+
23
+ edges = []
24
+ for i in range((N - 1) * N // 2):
25
+ source, target = sources[i].item(), targets[i].item()
26
+ p = prediction["edge presence"][source, target]
27
+
28
+ if p < 0.5 and len(edges) >= N - 1:
29
+ break
30
+
31
+ if node_sets[source]["set"] is node_sets[target]["set"] and p < 0.5:
32
+ continue
33
+
34
+ self.create_edge(source, target, prediction, edges, nodes)
35
+
36
+ if node_sets[source]["set"] is not node_sets[target]["set"]:
37
+ from_set = node_sets[source]["set"]
38
+ for n in node_sets[target]["set"]:
39
+ from_set.add(n)
40
+ node_sets[n]["set"] = from_set
41
+
42
+ return edges
43
+
44
+ def create_edge(self, source, target, prediction, edges, nodes):
45
+ label = self.get_edge_label(prediction, source, target)
46
+ edge = {"source": source, "target": target, "label": label}
47
+
48
+ edges.append(edge)
49
+
50
+ def create_anchors(self, prediction, nodes, join_contiguous=True, at_least_one=False, single_anchor=False, mode="anchors"):
51
+ for i, node in enumerate(nodes):
52
+ threshold = 0.5 if not at_least_one else min(0.5, prediction[mode][i].max().item())
53
+ node[mode] = (prediction[mode][i] >= threshold).nonzero(as_tuple=False).squeeze(-1)
54
+ node[mode] = prediction["token intervals"][node[mode], :]
55
+
56
+ if single_anchor and len(node[mode]) > 1:
57
+ start = min(a[0].item() for a in node[mode])
58
+ end = max(a[1].item() for a in node[mode])
59
+ node[mode] = [{"from": start, "to": end}]
60
+ continue
61
+
62
+ node[mode] = [{"from": f.item(), "to": t.item()} for f, t in node[mode]]
63
+ node[mode] = sorted(node[mode], key=lambda a: a["from"])
64
+
65
+ if join_contiguous and len(node[mode]) > 1:
66
+ cleaned_anchors = []
67
+ end, start = node[mode][0]["from"], node[mode][0]["from"]
68
+ for anchor in node[mode]:
69
+ if end < anchor["from"]:
70
+ cleaned_anchors.append({"from": start, "to": end})
71
+ start = anchor["from"]
72
+ end = anchor["to"]
73
+ cleaned_anchors.append({"from": start, "to": end})
74
+
75
+ node[mode] = cleaned_anchors
76
+
77
+ return nodes
78
+
79
+ def get_edge_label(self, prediction, source, target):
80
+ return self.dataset.edge_label_field.vocab.itos[prediction["edge labels"][source, target].item()]
data/parser/to_mrp/labeled_edge_parser.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ from data.parser.to_mrp.abstract_parser import AbstractParser
5
+
6
+
7
+ class LabeledEdgeParser(AbstractParser):
8
+ def __init__(self, *args):
9
+ super().__init__(*args)
10
+ self.source_id = self.dataset.edge_label_field.vocab.stoi["Source"]
11
+ self.target_id = self.dataset.edge_label_field.vocab.stoi["Target"]
12
+
13
+ def parse(self, prediction):
14
+ output = {}
15
+
16
+ output["id"] = self.dataset.id_field.vocab.itos[prediction["id"].item()]
17
+ output["nodes"] = self.create_nodes(prediction)
18
+ output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=True)
19
+ output["nodes"] = [{"id": 0}] + output["nodes"]
20
+ output["edges"] = self.create_edges(prediction, output["nodes"])
21
+
22
+ return output
23
+
24
+ def create_nodes(self, prediction):
25
+ return [{"id": i + 1} for i, l in enumerate(prediction["labels"])]
26
+
27
+ def create_edges(self, prediction, nodes):
28
+ N = len(nodes)
29
+ edge_prediction = prediction["edge presence"][:N, :N]
30
+
31
+ edges = []
32
+ for target in range(1, N):
33
+ if edge_prediction[0, target] >= 0.5:
34
+ prediction["edge labels"][0, target, self.source_id] = float("-inf")
35
+ prediction["edge labels"][0, target, self.target_id] = float("-inf")
36
+ self.create_edge(0, target, prediction, edges, nodes)
37
+
38
+ for source in range(1, N):
39
+ for target in range(1, N):
40
+ if source == target:
41
+ continue
42
+ if edge_prediction[source, target] < 0.5:
43
+ continue
44
+ for i in range(prediction["edge labels"].size(2)):
45
+ if i not in [self.source_id, self.target_id]:
46
+ prediction["edge labels"][source, target, i] = float("-inf")
47
+ self.create_edge(source, target, prediction, edges, nodes)
48
+
49
+ return edges
50
+
51
+ def get_edge_label(self, prediction, source, target):
52
+ return self.dataset.edge_label_field.vocab.itos[prediction["edge labels"][source, target].argmax(-1).item()]
data/parser/to_mrp/node_centric_parser.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ from data.parser.to_mrp.abstract_parser import AbstractParser
5
+
6
+
7
+ class NodeCentricParser(AbstractParser):
8
+ def parse(self, prediction):
9
+ output = {}
10
+
11
+ output["id"] = self.dataset.id_field.vocab.itos[prediction["id"].item()]
12
+ output["nodes"] = self.create_nodes(prediction)
13
+ output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=True)
14
+ output["edges"] = self.create_edges(prediction, output["nodes"])
15
+
16
+ return output
17
+
18
+ def create_edge(self, source, target, prediction, edges, nodes):
19
+ edge = {"source": source, "target": target, "label": None}
20
+ edges.append(edge)
21
+
22
+ def create_edges(self, prediction, nodes):
23
+ N = len(nodes)
24
+ edge_prediction = prediction["edge presence"][:N, :N]
25
+
26
+ targets = [i for i, node in enumerate(nodes) if node["label"] in ["Source", "Target"]]
27
+ sources = [i for i, node in enumerate(nodes) if node["label"] not in ["Source", "Target"]]
28
+
29
+ edges = []
30
+ for target in targets:
31
+ for source in sources:
32
+ if edge_prediction[source, target] >= 0.5:
33
+ self.create_edge(source, target, prediction, edges, nodes)
34
+
35
+ return edges
data/parser/to_mrp/sequential_parser.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ from data.parser.to_mrp.abstract_parser import AbstractParser
5
+
6
+
7
+ class SequentialParser(AbstractParser):
8
+ def parse(self, prediction):
9
+ output = {}
10
+
11
+ output["id"] = self.dataset.id_field.vocab.itos[prediction["id"].item()]
12
+ output["nodes"] = self.create_nodes(prediction)
13
+ output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=True, mode="anchors")
14
+ output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=False, mode="source anchors")
15
+ output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=False, mode="target anchors")
16
+ output["edges"], output["nodes"] = self.create_targets_sources(output["nodes"])
17
+
18
+ return output
19
+
20
+ def create_targets_sources(self, nodes):
21
+ edges, new_nodes = [], []
22
+ for i, node in enumerate(nodes):
23
+ new_node_id = len(nodes) + len(new_nodes)
24
+ if len(node["source anchors"]) > 0:
25
+ new_nodes.append({"id": new_node_id, "label": "Source", "anchors": node["source anchors"]})
26
+ edges.append({"source": i, "target": new_node_id, "label": ""})
27
+ new_node_id += 1
28
+ del node["source anchors"]
29
+
30
+ if len(node["target anchors"]) > 0:
31
+ new_nodes.append({"id": new_node_id, "label": "Target", "anchors": node["target anchors"]})
32
+ edges.append({"source": i, "target": new_node_id, "label": ""})
33
+ del node["target anchors"]
34
+
35
+ return edges, nodes + new_nodes
model/__init__.py ADDED
File without changes
model/head/__init__.py ADDED
File without changes
model/head/abstract_head.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from model.module.edge_classifier import EdgeClassifier
10
+ from model.module.anchor_classifier import AnchorClassifier
11
+ from utility.cross_entropy import cross_entropy, binary_cross_entropy
12
+ from utility.hungarian_matching import get_matching, reorder, match_anchor, match_label
13
+ from utility.utils import create_padding_mask
14
+
15
+
16
+ class AbstractHead(nn.Module):
17
+ def __init__(self, dataset, args, config, initialize: bool):
18
+ super(AbstractHead, self).__init__()
19
+
20
+ self.edge_classifier = self.init_edge_classifier(dataset, args, config, initialize)
21
+ self.label_classifier = self.init_label_classifier(dataset, args, config, initialize)
22
+ self.anchor_classifier = self.init_anchor_classifier(dataset, args, config, initialize, mode="anchor")
23
+ self.source_anchor_classifier = self.init_anchor_classifier(dataset, args, config, initialize, mode="source_anchor")
24
+ self.target_anchor_classifier = self.init_anchor_classifier(dataset, args, config, initialize, mode="target_anchor")
25
+
26
+ self.query_length = args.query_length
27
+ self.focal = args.focal
28
+ self.dataset = dataset
29
+
30
+ def forward(self, encoder_output, decoder_output, encoder_mask, decoder_mask, batch):
31
+ output = {}
32
+
33
+ decoder_lens = self.query_length * batch["every_input"][1]
34
+ output["label"] = self.forward_label(decoder_output)
35
+ output["anchor"] = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="anchor") # shape: (B, T_l, T_w)
36
+ output["source_anchor"] = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="source_anchor") # shape: (B, T_l, T_w)
37
+ output["target_anchor"] = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="target_anchor") # shape: (B, T_l, T_w)
38
+
39
+ cost_matrices = self.create_cost_matrices(output, batch, decoder_lens)
40
+ matching = get_matching(cost_matrices)
41
+ decoder_output = reorder(decoder_output, matching, batch["labels"][0].size(1))
42
+ output["edge presence"], output["edge label"] = self.forward_edge(decoder_output)
43
+
44
+ return self.loss(output, batch, matching, decoder_mask)
45
+
46
+ def predict(self, encoder_output, decoder_output, encoder_mask, decoder_mask, batch, **kwargs):
47
+ every_input, word_lens = batch["every_input"]
48
+ decoder_lens = self.query_length * word_lens
49
+ batch_size = every_input.size(0)
50
+
51
+ label_pred = self.forward_label(decoder_output)
52
+ anchor_pred = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="anchor") # shape: (B, T_l, T_w)
53
+ source_anchor_pred = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="source_anchor") # shape: (B, T_l, T_w)
54
+ target_anchor_pred = self.forward_anchor(decoder_output, encoder_output, encoder_mask, mode="target_anchor") # shape: (B, T_l, T_w)
55
+
56
+ labels = [[] for _ in range(batch_size)]
57
+ anchors, source_anchors, target_anchors = [[] for _ in range(batch_size)], [[] for _ in range(batch_size)], [[] for _ in range(batch_size)]
58
+
59
+ for b in range(batch_size):
60
+ label_indices = self.inference_label(label_pred[b, :decoder_lens[b], :]).cpu()
61
+ for t in range(label_indices.size(0)):
62
+ label_index = label_indices[t].item()
63
+ if label_index == 0:
64
+ continue
65
+
66
+ decoder_output[b, len(labels[b]), :] = decoder_output[b, t, :]
67
+
68
+ labels[b].append(label_index)
69
+ if anchor_pred is None:
70
+ anchors[b].append(list(range(t // self.query_length, word_lens[b])))
71
+ else:
72
+ anchors[b].append(self.inference_anchor(anchor_pred[b, t, :word_lens[b]]).cpu())
73
+
74
+ if source_anchor_pred is None:
75
+ source_anchors[b].append(list(range(t // self.query_length, word_lens[b])))
76
+ else:
77
+ source_anchors[b].append(self.inference_anchor(source_anchor_pred[b, t, :word_lens[b]]).cpu())
78
+
79
+ if target_anchor_pred is None:
80
+ target_anchors[b].append(list(range(t // self.query_length, word_lens[b])))
81
+ else:
82
+ target_anchors[b].append(self.inference_anchor(target_anchor_pred[b, t, :word_lens[b]]).cpu())
83
+
84
+ decoder_output = decoder_output[:, : max(len(l) for l in labels), :]
85
+ edge_presence, edge_labels = self.forward_edge(decoder_output)
86
+
87
+ outputs = [
88
+ self.parser.parse(
89
+ {
90
+ "labels": labels[b],
91
+ "anchors": anchors[b],
92
+ "source anchors": source_anchors[b],
93
+ "target anchors": target_anchors[b],
94
+ "edge presence": self.inference_edge_presence(edge_presence, b),
95
+ "edge labels": self.inference_edge_label(edge_labels, b),
96
+ "id": batch["id"][b].cpu(),
97
+ "tokens": batch["every_input"][0][b, : word_lens[b]].cpu(),
98
+ "token intervals": batch["token_intervals"][b, :, :].cpu(),
99
+ },
100
+ **kwargs
101
+ )
102
+ for b in range(batch_size)
103
+ ]
104
+
105
+ return outputs
106
+
107
+ def loss(self, output, batch, matching, decoder_mask):
108
+ batch_size = batch["every_input"][0].size(0)
109
+ device = batch["every_input"][0].device
110
+ T_label = batch["labels"][0].size(1)
111
+ T_input = batch["every_input"][0].size(1)
112
+ T_edge = batch["edge_presence"].size(1)
113
+
114
+ input_mask = create_padding_mask(batch_size, T_input, batch["every_input"][1], device) # shape: (B, T_input)
115
+ label_mask = create_padding_mask(batch_size, T_label, batch["labels"][1], device) # shape: (B, T_label)
116
+ edge_mask = torch.eye(T_label, T_label, device=device, dtype=torch.bool).unsqueeze(0) # shape: (1, T_label, T_label)
117
+ edge_mask = edge_mask | label_mask.unsqueeze(1) | label_mask.unsqueeze(2) # shape: (B, T_label, T_label)
118
+ if T_edge != T_label:
119
+ edge_mask = F.pad(edge_mask, (T_edge - T_label, 0, T_edge - T_label, 0), value=0)
120
+ edge_label_mask = (batch["edge_presence"] == 0) | edge_mask
121
+
122
+ if output["edge label"] is not None:
123
+ batch["edge_labels"] = (
124
+ batch["edge_labels"][0][:, :, :, :output["edge label"].size(-1)],
125
+ batch["edge_labels"][1],
126
+ )
127
+
128
+ losses = {}
129
+ losses.update(self.loss_label(output, batch, decoder_mask, matching))
130
+ losses.update(self.loss_anchor(output, batch, input_mask, matching, mode="anchor"))
131
+ losses.update(self.loss_anchor(output, batch, input_mask, matching, mode="source_anchor"))
132
+ losses.update(self.loss_anchor(output, batch, input_mask, matching, mode="target_anchor"))
133
+ losses.update(self.loss_edge_presence(output, batch, edge_mask))
134
+ losses.update(self.loss_edge_label(output, batch, edge_label_mask.unsqueeze(-1)))
135
+
136
+ stats = {f"{key}": value.detach().cpu().item() for key, value in losses.items()}
137
+ total_loss = sum(losses.values()) / len(losses)
138
+
139
+ return total_loss, stats
140
+
141
+ @torch.no_grad()
142
+ def create_cost_matrices(self, output, batch, decoder_lens):
143
+ batch_size = len(batch["labels"][1])
144
+ decoder_lens = decoder_lens.cpu()
145
+
146
+ matrices = []
147
+ for b in range(batch_size):
148
+ label_cost_matrix = self.label_cost_matrix(output, batch, decoder_lens, b)
149
+ anchor_cost_matrix = self.anchor_cost_matrix(output, batch, decoder_lens, b)
150
+
151
+ cost_matrix = label_cost_matrix * anchor_cost_matrix
152
+ matrices.append(cost_matrix.cpu())
153
+
154
+ return matrices
155
+
156
+ def init_edge_classifier(self, dataset, args, config, initialize: bool):
157
+ if not config["edge presence"] and not config["edge label"]:
158
+ return None
159
+ return EdgeClassifier(dataset, args, initialize, presence=config["edge presence"], label=config["edge label"])
160
+
161
+ def init_label_classifier(self, dataset, args, config, initialize: bool):
162
+ if not config["label"]:
163
+ return None
164
+
165
+ classifier = nn.Sequential(
166
+ nn.Dropout(args.dropout_label),
167
+ nn.Linear(args.hidden_size, len(dataset.label_field.vocab) + 1, bias=True)
168
+ )
169
+ if initialize:
170
+ classifier[1].bias.data = dataset.label_freqs.log()
171
+
172
+ return classifier
173
+
174
+ def init_anchor_classifier(self, dataset, args, config, initialize: bool, mode="anchor"):
175
+ if not config[mode]:
176
+ return None
177
+
178
+ return AnchorClassifier(dataset, args, initialize, mode=mode)
179
+
180
+ def forward_edge(self, decoder_output):
181
+ if self.edge_classifier is None:
182
+ return None, None
183
+ return self.edge_classifier(decoder_output)
184
+
185
+ def forward_label(self, decoder_output):
186
+ if self.label_classifier is None:
187
+ return None
188
+ return torch.log_softmax(self.label_classifier(decoder_output), dim=-1)
189
+
190
+ def forward_anchor(self, decoder_output, encoder_output, encoder_mask, mode="anchor"):
191
+ classifier = getattr(self, f"{mode}_classifier")
192
+ if classifier is None:
193
+ return None
194
+ return classifier(decoder_output, encoder_output, encoder_mask)
195
+
196
+ def inference_label(self, prediction):
197
+ prediction = prediction.exp()
198
+ return torch.where(
199
+ prediction[:, 0] > prediction[:, 1:].sum(-1),
200
+ torch.zeros(prediction.size(0), dtype=torch.long, device=prediction.device),
201
+ prediction[:, 1:].argmax(dim=-1) + 1
202
+ )
203
+
204
+ def inference_anchor(self, prediction):
205
+ return prediction.sigmoid()
206
+
207
+ def inference_edge_presence(self, prediction, example_index: int):
208
+ if prediction is None:
209
+ return None
210
+
211
+ N = prediction.size(1)
212
+ mask = torch.eye(N, N, device=prediction.device, dtype=torch.bool)
213
+ return prediction[example_index, :, :].sigmoid().masked_fill(mask, 0.0).cpu()
214
+
215
+ def inference_edge_label(self, prediction, example_index: int):
216
+ if prediction is None:
217
+ return None
218
+ return prediction[example_index, :, :, :].cpu()
219
+
220
+ def loss_edge_presence(self, prediction, target, mask):
221
+ if self.edge_classifier is None or prediction["edge presence"] is None:
222
+ return {}
223
+ return {"edge presence": binary_cross_entropy(prediction["edge presence"], target["edge_presence"].float(), mask)}
224
+
225
+ def loss_edge_label(self, prediction, target, mask):
226
+ if self.edge_classifier is None or prediction["edge label"] is None:
227
+ return {}
228
+ return {"edge label": binary_cross_entropy(prediction["edge label"], target["edge_labels"][0].float(), mask)}
229
+
230
+ def loss_label(self, prediction, target, mask, matching):
231
+ if self.label_classifier is None or prediction["label"] is None:
232
+ return {}
233
+
234
+ prediction = prediction["label"]
235
+ target = match_label(
236
+ target["labels"][0], matching, prediction.shape[:-1], prediction.device, self.query_length
237
+ )
238
+ return {"label": cross_entropy(prediction, target, mask, focal=self.focal)}
239
+
240
+ def loss_anchor(self, prediction, target, mask, matching, mode="anchor"):
241
+ if getattr(self, f"{mode}_classifier") is None or prediction[mode] is None:
242
+ return {}
243
+
244
+ prediction = prediction[mode]
245
+ target, anchor_mask = match_anchor(target[mode], matching, prediction.shape, prediction.device)
246
+ mask = anchor_mask.unsqueeze(-1) | mask.unsqueeze(-2)
247
+ return {mode: binary_cross_entropy(prediction, target.float(), mask)}
248
+
249
+ def label_cost_matrix(self, output, batch, decoder_lens, b: int):
250
+ if output["label"] is None:
251
+ return 1.0
252
+
253
+ target_labels = batch["anchored_labels"][b] # shape: (num_nodes, num_inputs, num_classes)
254
+ label_prob = output["label"][b, : decoder_lens[b], :].exp().unsqueeze(0) # shape: (1, num_queries, num_classes)
255
+ tgt_label = target_labels.repeat_interleave(self.query_length, dim=1) # shape: (num_nodes, num_queries, num_classes)
256
+ cost_matrix = ((tgt_label * label_prob).sum(-1) * label_prob[:, :, 1:].sum(-1)).t().sqrt() # shape: (num_queries, num_nodes)
257
+
258
+ return cost_matrix
259
+
260
+ def anchor_cost_matrix(self, output, batch, decoder_lens, b: int):
261
+ if output["anchor"] is None:
262
+ return 1.0
263
+
264
+ num_nodes = batch["labels"][1][b]
265
+ word_lens = batch["every_input"][1]
266
+ target_anchors, _ = batch["anchor"]
267
+ pred_anchors = output["anchor"].sigmoid()
268
+
269
+ tgt_align = target_anchors[b, : num_nodes, : word_lens[b]] # shape: (num_nodes, num_inputs)
270
+ align_prob = pred_anchors[b, : decoder_lens[b], : word_lens[b]] # shape: (num_queries, num_inputs)
271
+ align_prob = align_prob.unsqueeze(1).expand(-1, num_nodes, -1) # shape: (num_queries, num_nodes, num_inputs)
272
+ align_prob = torch.where(tgt_align.unsqueeze(0).bool(), align_prob, 1.0 - align_prob) # shape: (num_queries, num_nodes, num_inputs)
273
+ cost_matrix = align_prob.log().mean(-1).exp() # shape: (num_queries, num_nodes)
274
+ return cost_matrix
model/head/labeled_edge_head.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from model.head.abstract_head import AbstractHead
8
+ from data.parser.to_mrp.labeled_edge_parser import LabeledEdgeParser
9
+ from utility.cross_entropy import binary_cross_entropy
10
+ from utility.hungarian_matching import match_label
11
+
12
+
13
+ class LabeledEdgeHead(AbstractHead):
14
+ def __init__(self, dataset, args, initialize):
15
+ config = {
16
+ "label": True,
17
+ "edge presence": True,
18
+ "edge label": True,
19
+ "anchor": True,
20
+ "source_anchor": False,
21
+ "target_anchor": False
22
+ }
23
+ super(LabeledEdgeHead, self).__init__(dataset, args, config, initialize)
24
+
25
+ self.top_node = nn.Parameter(torch.randn(1, 1, args.hidden_size), requires_grad=True)
26
+ self.parser = LabeledEdgeParser(dataset)
27
+
28
+ def init_label_classifier(self, dataset, args, config, initialize: bool):
29
+ classifier = nn.Sequential(
30
+ nn.Dropout(args.dropout_label),
31
+ nn.Linear(args.hidden_size, 1, bias=True)
32
+ )
33
+ if initialize:
34
+ bias_init = torch.tensor([dataset.label_freqs[1]])
35
+ classifier[1].bias.data = (bias_init / (1.0 - bias_init)).log()
36
+
37
+ return classifier
38
+
39
+ def forward_label(self, decoder_output):
40
+ return self.label_classifier(decoder_output)
41
+
42
+ def forward_edge(self, decoder_output):
43
+ top_node = self.top_node.expand(decoder_output.size(0), -1, -1)
44
+ decoder_output = torch.cat([top_node, decoder_output], dim=1)
45
+ return self.edge_classifier(decoder_output)
46
+
47
+ def loss_label(self, prediction, target, mask, matching):
48
+ prediction = prediction["label"]
49
+ target = match_label(
50
+ target["labels"][0], matching, prediction.shape[:-1], prediction.device, self.query_length
51
+ )
52
+ return {"label": binary_cross_entropy(prediction.squeeze(-1), target.float(), mask, focal=self.focal)}
53
+
54
+ def inference_label(self, prediction):
55
+ return (prediction.squeeze(-1) > 0.0).long()
56
+
57
+ def label_cost_matrix(self, output, batch, decoder_lens, b: int):
58
+ if output["label"] is None:
59
+ return 1.0
60
+
61
+ target_labels = batch["anchored_labels"][b] # shape: (num_nodes, num_inputs, 2)
62
+ label_prob = output["label"][b, : decoder_lens[b], :].sigmoid().unsqueeze(0) # shape: (1, num_queries, 1)
63
+ label_prob = torch.cat([1.0 - label_prob, label_prob], dim=-1) # shape: (1, num_queries, 2)
64
+ tgt_label = target_labels.repeat_interleave(self.query_length, dim=1) # shape: (num_nodes, num_queries, 2)
65
+ cost_matrix = ((tgt_label * label_prob).sum(-1) * label_prob[:, :, 1:].sum(-1)).t().sqrt() # shape: (num_queries, num_nodes)
66
+
67
+ return cost_matrix
model/head/node_centric_head.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+
6
+ from model.head.abstract_head import AbstractHead
7
+ from data.parser.to_mrp.node_centric_parser import NodeCentricParser
8
+ from utility.cross_entropy import binary_cross_entropy
9
+
10
+
11
+ class NodeCentricHead(AbstractHead):
12
+ def __init__(self, dataset, args, initialize):
13
+ config = {
14
+ "label": True,
15
+ "edge presence": True,
16
+ "edge label": False,
17
+ "anchor": True,
18
+ "source_anchor": False,
19
+ "target_anchor": False
20
+ }
21
+ super(NodeCentricHead, self).__init__(dataset, args, config, initialize)
22
+
23
+ self.source_id = dataset.label_field.vocab.stoi["Source"] + 1
24
+ self.target_id = dataset.label_field.vocab.stoi["Target"] + 1
25
+ self.parser = NodeCentricParser(dataset)
model/head/sequential_head.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from model.head.abstract_head import AbstractHead
9
+ from data.parser.to_mrp.sequential_parser import SequentialParser
10
+ from utility.cross_entropy import cross_entropy
11
+
12
+
13
+ class SequentialHead(AbstractHead):
14
+ def __init__(self, dataset, args, initialize):
15
+ config = {
16
+ "label": True,
17
+ "edge presence": False,
18
+ "edge label": False,
19
+ "anchor": True,
20
+ "source_anchor": True,
21
+ "target_anchor": True
22
+ }
23
+ super(SequentialHead, self).__init__(dataset, args, config, initialize)
24
+ self.parser = SequentialParser(dataset)
model/model.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from model.module.encoder import Encoder
8
+
9
+ from model.module.transformer import Decoder
10
+ from model.head.node_centric_head import NodeCentricHead
11
+ from model.head.labeled_edge_head import LabeledEdgeHead
12
+ from model.head.sequential_head import SequentialHead
13
+ from utility.utils import create_padding_mask
14
+
15
+
16
+ class Model(nn.Module):
17
+ def __init__(self, dataset, args, initialize=True):
18
+ super(Model, self).__init__()
19
+ self.encoder = Encoder(args, dataset)
20
+ if args.n_layers > 0:
21
+ self.decoder = Decoder(args)
22
+ else:
23
+ self.decoder = lambda x, *args: x # identity function, which ignores all arguments except the first one
24
+
25
+ if args.graph_mode == "sequential":
26
+ self.head = SequentialHead(dataset, args, initialize)
27
+ elif args.graph_mode == "node-centric":
28
+ self.head = NodeCentricHead(dataset, args, initialize)
29
+ elif args.graph_mode == "labeled-edge":
30
+ self.head = LabeledEdgeHead(dataset, args, initialize)
31
+
32
+ self.query_length = args.query_length
33
+ self.dataset = dataset
34
+ self.args = args
35
+
36
+ def forward(self, batch, inference=False, **kwargs):
37
+ every_input, word_lens = batch["every_input"]
38
+ decoder_lens = self.query_length * word_lens
39
+ batch_size, input_len = every_input.size(0), every_input.size(1)
40
+ device = every_input.device
41
+
42
+ encoder_mask = create_padding_mask(batch_size, input_len, word_lens, device)
43
+ decoder_mask = create_padding_mask(batch_size, self.query_length * input_len, decoder_lens, device)
44
+
45
+ encoder_output, decoder_input = self.encoder(batch["input"], batch["char_form_input"], batch["input_scatter"], input_len)
46
+
47
+ decoder_output = self.decoder(decoder_input, encoder_output, decoder_mask, encoder_mask)
48
+
49
+ if inference:
50
+ return self.head.predict(encoder_output, decoder_output, encoder_mask, decoder_mask, batch)
51
+ else:
52
+ return self.head(encoder_output, decoder_output, encoder_mask, decoder_mask, batch)
53
+
54
+ def get_params_for_optimizer(self, args):
55
+ encoder_decay, encoder_no_decay = self.get_encoder_parameters(args.n_encoder_layers)
56
+ decoder_decay, decoder_no_decay = self.get_decoder_parameters()
57
+
58
+ parameters = [{"params": p, "weight_decay": args.encoder_weight_decay} for p in encoder_decay]
59
+ parameters += [{"params": p, "weight_decay": 0.0} for p in encoder_no_decay]
60
+ parameters += [
61
+ {"params": decoder_decay, "weight_decay": args.decoder_weight_decay},
62
+ {"params": decoder_no_decay, "weight_decay": 0.0},
63
+ ]
64
+ return parameters
65
+
66
+ def get_decoder_parameters(self):
67
+ no_decay = ["bias", "LayerNorm.weight", "_norm.weight"]
68
+ decay_params = (p for name, p in self.named_parameters() if not any(nd in name for nd in no_decay) and not name.startswith("encoder.bert") and p.requires_grad)
69
+ no_decay_params = (p for name, p in self.named_parameters() if any(nd in name for nd in no_decay) and not name.startswith("encoder.bert") and p.requires_grad)
70
+
71
+ return decay_params, no_decay_params
72
+
73
+ def get_encoder_parameters(self, n_layers):
74
+ no_decay = ["bias", "LayerNorm.weight", "_norm.weight"]
75
+ decay_params = [
76
+ [p for name, p in self.named_parameters() if not any(nd in name for nd in no_decay) and name.startswith(f"encoder.bert.encoder.layer.{n_layers - 1 - i}.") and p.requires_grad] for i in range(n_layers)
77
+ ]
78
+ no_decay_params = [
79
+ [p for name, p in self.named_parameters() if any(nd in name for nd in no_decay) and name.startswith(f"encoder.bert.encoder.layer.{n_layers - 1 - i}.") and p.requires_grad] for i in range(n_layers)
80
+ ]
81
+
82
+ return decay_params, no_decay_params
model/module/__init__.py ADDED
File without changes
model/module/anchor_classifier.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from model.module.biaffine import Biaffine
8
+
9
+
10
+ class AnchorClassifier(nn.Module):
11
+ def __init__(self, dataset, args, initialize: bool, bias=True, mode="anchor"):
12
+ super(AnchorClassifier, self).__init__()
13
+
14
+ self.token_f = nn.Linear(args.hidden_size, args.hidden_size_anchor)
15
+ self.label_f = nn.Linear(args.hidden_size, args.hidden_size_anchor)
16
+ self.dropout = nn.Dropout(args.dropout_anchor)
17
+
18
+ if bias and initialize:
19
+ bias_init = torch.tensor([getattr(dataset, f"{mode}_freq")])
20
+ bias_init = (bias_init / (1.0 - bias_init)).log()
21
+ else:
22
+ bias_init = None
23
+
24
+ self.output = Biaffine(args.hidden_size_anchor, 1, bias=bias, bias_init=bias_init)
25
+
26
+ def forward(self, label, tokens, encoder_mask):
27
+ tokens = self.dropout(F.elu(self.token_f(tokens))) # shape: (B, T_w, H)
28
+ label = self.dropout(F.elu(self.label_f(label))) # shape: (B, T_l, H)
29
+ anchor = self.output(label, tokens).squeeze(-1) # shape: (B, T_l, T_w)
30
+
31
+ anchor = anchor.masked_fill(encoder_mask.unsqueeze(1), float("-inf")) # shape: (B, T_l, T_w)
32
+ return anchor
model/module/biaffine.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch.nn as nn
5
+ from model.module.bilinear import Bilinear
6
+
7
+
8
+ class Biaffine(nn.Module):
9
+ def __init__(self, input_dim, output_dim, bias=True, bias_init=None):
10
+ super(Biaffine, self).__init__()
11
+
12
+ self.linear_1 = nn.Linear(input_dim, output_dim, bias=False)
13
+ self.linear_2 = nn.Linear(input_dim, output_dim, bias=False)
14
+
15
+ self.bilinear = Bilinear(input_dim, input_dim, output_dim, bias=bias)
16
+ if bias_init is not None:
17
+ self.bilinear.bias.data = bias_init
18
+
19
+ def forward(self, x, y):
20
+ return self.bilinear(x, y) + self.linear_1(x).unsqueeze(2) + self.linear_2(y).unsqueeze(1)
model/module/bilinear.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from https://github.com/NLPInBLCU/BiaffineDependencyParsing/blob/master/modules/biaffine.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class Bilinear(nn.Module):
8
+ """
9
+ 使用版本
10
+ A bilinear module that deals with broadcasting for efficient memory usage.
11
+ Input: tensors of sizes (N x L1 x D1) and (N x L2 x D2)
12
+ Output: tensor of size (N x L1 x L2 x O)"""
13
+
14
+ def __init__(self, input1_size, input2_size, output_size, bias=True):
15
+ super(Bilinear, self).__init__()
16
+
17
+ self.input1_size = input1_size
18
+ self.input2_size = input2_size
19
+ self.output_size = output_size
20
+
21
+ self.weight = nn.Parameter(torch.Tensor(input1_size, input2_size, output_size))
22
+ self.bias = nn.Parameter(torch.Tensor(output_size)) if bias else None
23
+
24
+ self.reset_parameters()
25
+
26
+ def reset_parameters(self):
27
+ nn.init.zeros_(self.weight)
28
+
29
+ def forward(self, input1, input2):
30
+ input1_size = list(input1.size())
31
+ input2_size = list(input2.size())
32
+
33
+ intermediate = torch.mm(input1.view(-1, input1_size[-1]), self.weight.view(-1, self.input2_size * self.output_size),)
34
+
35
+ input2 = input2.transpose(1, 2)
36
+ output = intermediate.view(input1_size[0], input1_size[1] * self.output_size, input2_size[2]).bmm(input2)
37
+
38
+ output = output.view(input1_size[0], input1_size[1], self.output_size, input2_size[1]).transpose(2, 3)
39
+
40
+ if self.bias is not None:
41
+ output = output + self.bias
42
+
43
+ return output
model/module/char_embedding.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
7
+
8
+
9
+ class CharEmbedding(nn.Module):
10
+ def __init__(self, vocab_size: int, embedding_size: int, output_size: int):
11
+ super(CharEmbedding, self).__init__()
12
+
13
+ self.embedding = nn.Embedding(vocab_size, embedding_size, sparse=False)
14
+ self.layer_norm = nn.LayerNorm(embedding_size)
15
+ self.gru = nn.GRU(embedding_size, embedding_size, num_layers=1, bidirectional=True)
16
+ self.out_linear = nn.Linear(2*embedding_size, output_size)
17
+ self.layer_norm_2 = nn.LayerNorm(output_size)
18
+
19
+ def forward(self, words, sentence_lens, word_lens):
20
+ # input shape: (B, W, C)
21
+ n_words = words.size(1)
22
+ sentence_lens = sentence_lens.cpu()
23
+ sentence_packed = pack_padded_sequence(words, sentence_lens, batch_first=True) # shape: (B*W, C)
24
+ lens_packed = pack_padded_sequence(word_lens, sentence_lens, batch_first=True) # shape: (B*W)
25
+ word_packed = pack_padded_sequence(sentence_packed.data, lens_packed.data.cpu(), batch_first=True, enforce_sorted=False) # shape: (B*W*C)
26
+
27
+ embedded = self.embedding(word_packed.data) # shape: (B*W*C, D)
28
+ embedded = self.layer_norm(embedded) # shape: (B*W*C, D)
29
+
30
+ embedded_packed = PackedSequence(embedded, word_packed[1], word_packed[2], word_packed[3])
31
+ _, embedded = self.gru(embedded_packed) # shape: (layers * 2, B*W, D)
32
+
33
+ embedded = embedded[-2:, :, :].transpose(0, 1).flatten(1, 2) # shape: (B*W, 2*D)
34
+ embedded = F.relu(embedded)
35
+ embedded = self.out_linear(embedded)
36
+ embedded = self.layer_norm_2(embedded)
37
+
38
+ embedded, _ = pad_packed_sequence(
39
+ PackedSequence(embedded, sentence_packed[1], sentence_packed[2], sentence_packed[3]), batch_first=True, total_length=n_words,
40
+ ) # shape: (B, W, 2*D)
41
+
42
+ return embedded # shape: (B, W, 2*D)
model/module/edge_classifier.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from model.module.biaffine import Biaffine
8
+
9
+
10
+ class EdgeClassifier(nn.Module):
11
+ def __init__(self, dataset, args, initialize: bool, presence: bool, label: bool):
12
+ super(EdgeClassifier, self).__init__()
13
+
14
+ self.presence = presence
15
+ if self.presence:
16
+ if initialize:
17
+ presence_init = torch.tensor([dataset.edge_presence_freq])
18
+ presence_init = (presence_init / (1.0 - presence_init)).log()
19
+ else:
20
+ presence_init = None
21
+
22
+ self.edge_presence = EdgeBiaffine(
23
+ args.hidden_size, args.hidden_size_edge_presence, 1, args.dropout_edge_presence, bias_init=presence_init
24
+ )
25
+
26
+ self.label = label
27
+ if self.label:
28
+ label_init = (dataset.edge_label_freqs / (1.0 - dataset.edge_label_freqs)).log() if initialize else None
29
+ n_labels = len(dataset.edge_label_field.vocab)
30
+ self.edge_label = EdgeBiaffine(
31
+ args.hidden_size, args.hidden_size_edge_label, n_labels, args.dropout_edge_label, bias_init=label_init
32
+ )
33
+
34
+ def forward(self, x):
35
+ presence, label = None, None
36
+
37
+ if self.presence:
38
+ presence = self.edge_presence(x).squeeze(-1) # shape: (B, T, T)
39
+ if self.label:
40
+ label = self.edge_label(x) # shape: (B, T, T, O_1)
41
+
42
+ return presence, label
43
+
44
+
45
+ class EdgeBiaffine(nn.Module):
46
+ def __init__(self, hidden_dim, bottleneck_dim, output_dim, dropout, bias_init=None):
47
+ super(EdgeBiaffine, self).__init__()
48
+ self.hidden = nn.Linear(hidden_dim, 2 * bottleneck_dim)
49
+ self.output = Biaffine(bottleneck_dim, output_dim, bias_init=bias_init)
50
+ self.dropout = nn.Dropout(dropout)
51
+
52
+ def forward(self, x):
53
+ x = self.dropout(F.elu(self.hidden(x))) # shape: (B, T, 2H)
54
+ predecessors, current = x.chunk(2, dim=-1) # shape: (B, T, H), (B, T, H)
55
+ edge = self.output(current, predecessors) # shape: (B, T, T, O)
56
+ return edge
model/module/encoder.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # coding=utf-8
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from transformers import AutoModel
11
+ from model.module.char_embedding import CharEmbedding
12
+
13
+
14
+ class WordDropout(nn.Dropout):
15
+ def forward(self, input_tensor):
16
+ if self.p == 0:
17
+ return input_tensor
18
+
19
+ ones = input_tensor.new_ones(input_tensor.shape[:-1])
20
+ dropout_mask = torch.nn.functional.dropout(ones, self.p, self.training, inplace=False)
21
+
22
+ return dropout_mask.unsqueeze(-1) * input_tensor
23
+
24
+
25
+ class Encoder(nn.Module):
26
+ def __init__(self, args, dataset):
27
+ super(Encoder, self).__init__()
28
+
29
+ self.dim = args.hidden_size
30
+ self.n_layers = args.n_encoder_layers
31
+ self.width_factor = args.query_length
32
+
33
+ self.bert = AutoModel.from_pretrained(args.encoder, add_pooling_layer=False)
34
+ # self.bert._set_gradient_checkpointing(self.bert.encoder, value=True)
35
+ if args.encoder_freeze_embedding:
36
+ self.bert.embeddings.requires_grad_(False)
37
+ self.bert.embeddings.LayerNorm.requires_grad_(True)
38
+
39
+ if args.freeze_bert:
40
+ self.bert.requires_grad_(False)
41
+
42
+ self.use_char_embedding = args.char_embedding
43
+ if self.use_char_embedding:
44
+ self.form_char_embedding = CharEmbedding(dataset.char_form_vocab_size, args.char_embedding_size, self.dim)
45
+ self.word_dropout = WordDropout(args.dropout_word)
46
+
47
+ self.post_layer_norm = nn.LayerNorm(self.dim)
48
+ self.subword_attention = nn.Linear(self.dim, 1)
49
+
50
+ if self.width_factor > 1:
51
+ self.query_generator = nn.Linear(self.dim, self.dim * self.width_factor)
52
+ else:
53
+ self.query_generator = nn.Identity()
54
+
55
+ self.encoded_layer_norm = nn.LayerNorm(self.dim)
56
+ self.scores = nn.Parameter(torch.zeros(self.n_layers, 1, 1, 1), requires_grad=True)
57
+
58
+ def forward(self, bert_input, form_chars, to_scatter, n_words):
59
+ tokens, mask = bert_input
60
+ batch_size = tokens.size(0)
61
+
62
+ encoded = self.bert(tokens, attention_mask=mask, output_hidden_states=True).hidden_states[1:]
63
+ encoded = torch.stack(encoded, dim=0) # shape: (12, B, T, H)
64
+ encoded = self.encoded_layer_norm(encoded)
65
+
66
+ if self.training:
67
+ time_len = encoded.size(2)
68
+ scores = self.scores.expand(-1, batch_size, time_len, -1)
69
+ dropout = torch.empty(self.n_layers, batch_size, 1, 1, dtype=torch.bool, device=self.scores.device)
70
+ dropout.bernoulli_(0.1)
71
+ scores = scores.masked_fill(dropout, float("-inf"))
72
+ else:
73
+ scores = self.scores
74
+
75
+ scores = F.softmax(scores, dim=0)
76
+ encoded = (scores * encoded).sum(0) # shape: (B, T, H)
77
+ encoded = encoded.masked_fill(mask.unsqueeze(-1) == 0, 0.0) # shape: (B, T, H)
78
+
79
+ subword_attention = self.subword_attention(encoded) / math.sqrt(self.dim) # shape: (B, T, 1)
80
+ subword_attention = subword_attention.expand_as(to_scatter) # shape: (B, T_subword, T_word)
81
+ subword_attention = subword_attention.masked_fill(to_scatter == 0, float("-inf")) # shape: (B, T_subword, T_word)
82
+ subword_attention = torch.softmax(subword_attention, dim=1) # shape: (B, T_subword, T_word)
83
+ subword_attention = subword_attention.masked_fill(to_scatter.sum(1, keepdim=True) == 0, value=0.0) # shape: (B, T_subword, T_word)
84
+
85
+ encoder_output = torch.einsum("bsd,bsw->bwd", encoded, subword_attention)
86
+ encoder_output = self.post_layer_norm(encoder_output)
87
+
88
+ if self.use_char_embedding:
89
+ form_char_embedding = self.form_char_embedding(form_chars[0], form_chars[1], form_chars[2])
90
+ encoder_output = self.word_dropout(encoder_output) + form_char_embedding
91
+
92
+ decoder_input = self.query_generator(encoder_output)
93
+ decoder_input = decoder_input.view(batch_size, -1, self.width_factor, self.dim).flatten(1, 2) # shape: (B, T*Q, D)
94
+
95
+ return encoder_output, decoder_input