added demo base code
- .gitattributes +3 -0
- __init__.py +0 -0
- app.py +87 -0
- custom_bart/__init__.py +12 -0
- custom_bart/attention_utils.py +132 -0
- custom_bart/bart_attention.py +313 -0
- custom_bart/bart_for_conditional_generation.py +205 -0
- custom_bart/bart_generation_mixin.py +0 -0
- custom_bart/bart_mask_attention.py +238 -0
- custom_bart/bart_model.py +169 -0
- custom_bart/bart_onnx.py +240 -0
- custom_bart/config.py +197 -0
- custom_bart/custom_constants.py +168 -0
- custom_bart/custom_outputs.py +142 -0
- custom_bart/decoder.py +312 -0
- custom_bart/decoder_layer.py +134 -0
- custom_bart/encoder.py +216 -0
- custom_bart/encoder_layer.py +102 -0
- custom_tokenizer/__init__.py +1 -0
- custom_tokenizer/bart_custom_tokenizer_fast.py +484 -0
- data/__init__.py +0 -0
- data/relation_utils.py +53 -0
- inference.py +349 -0
- kgs_binding/__init__.py +3 -0
- kgs_binding/conceptnet/__init__.py +1 -0
- kgs_binding/conceptnet/conceptnet_english_noun_2_noun_relations.json +3 -0
- kgs_binding/conceptnet/conceptnet_english_nouns.json +3 -0
- kgs_binding/conceptnet/conceptnet_english_nouns_simple.json +3 -0
- kgs_binding/conceptnet_handler.py +61 -0
- kgs_binding/english_stopwords.txt +1126 -0
- kgs_binding/kg_base_wrapper.py +80 -0
- kgs_binding/kg_qa_binding_utils.py +73 -0
- kgs_binding/parsing_utils.py +86 -0
- kgs_binding/relation_mapper_builder.py +164 -0
- kgs_binding/swow/__init__.py +1 -0
- kgs_binding/swow/swow_knowledge.json +0 -0
- kgs_binding/swow_handler.py +75 -0
- model_utils.py +54 -0
- requirements.txt +4 -0
- utils.py +230 -0
.gitattributes
CHANGED
@@ -29,3 +29,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+kgs_binding/conceptnet/conceptnet_english_noun_2_noun_relations.json filter=lfs diff=lfs merge=lfs -text
+kgs_binding/conceptnet/conceptnet_english_nouns.json filter=lfs diff=lfs merge=lfs -text
+kgs_binding/conceptnet/conceptnet_english_nouns_simple.json filter=lfs diff=lfs merge=lfs -text
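(Entries of this form are what `git lfs track "<pattern>"` appends to .gitattributes, so the three ConceptNet JSON files above are stored through Git LFS rather than as regular Git blobs — hence the "+3 -0" for .gitattributes in the file list.)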
__init__.py
ADDED
File without changes
app.py
ADDED
@@ -0,0 +1,87 @@
import gradio as gr
import matplotlib.pyplot as plt

from inference import RelationsInference
from utils import KGType, Model_Type

#############################
# Constants
#############################

examples = [["What's the meaning of life?", "eli5", "constraint"],
            ["boat, water, bird", "commongen", "constraint"],
            ["What flows under a bridge?", "commonsense_qa", "constraint"]]

bart = RelationsInference(
    model_path='MrVicente/commonsense_bart_commongen',
    kg_type=KGType.CONCEPTNET,
    model_type=Model_Type.RELATIONS,
    max_length=32
)

#############################
# Helper
#############################

def infer_bart(context, task_type, decoding_type_str):
    response, encoder_attentions, model_input = bart.generate_based_on_context(context, use_kg=False)
    return response[0]


def plot_attention(layer, head):
    fig = plt.figure()
    plt.plot([1, 2, 3], [2, 4, 6])
    plt.title("Things")
    plt.ylabel("Cases")
    plt.xlabel("Days since Day 0")
    return fig


#############################
# Interface
#############################

app = gr.Blocks()
with app:
    gr.Markdown(
        """
        # Demo
        ### Test Commonsense Relation-Aware BART (BART-RA) model

        Tutorial: <br>
        1) Select the possible model variations and tasks;<br>
        2) Change the inputs and click the buttons to produce results;<br>
        3) See attention visualisations by choosing a specific layer and head;<br>
        """)
    with gr.Row():
        context_input = gr.Textbox(lines=2, value="What's the meaning of life?", label='Input:')
        model_result_output = gr.Textbox(lines=2, label='Model result:')
    with gr.Column():
        task_type_choice = gr.Radio(
            ["eli5", "commongen"], value="eli5", label="What task do you want to try?"
        )
        decoding_type_choice = gr.Radio(
            ["default", "constraint"], value="default", label="What decoding strategy do you want to use?"
        )
    with gr.Row():
        model_btn = gr.Button(value="See Model Results")
    gr.Markdown(
        """
        ---
        Observe Attention
        """
    )
    with gr.Row():
        with gr.Column():
            layer = gr.Slider(0, 11, 0, step=1, label="Layer")
            head = gr.Slider(0, 15, 0, step=1, label="Head")
        with gr.Column():
            plot_output = gr.Plot()
    with gr.Row():
        vis_btn = gr.Button(value="See Attention Scores")
    model_btn.click(fn=infer_bart, inputs=[context_input, task_type_choice, decoding_type_choice],
                    outputs=[model_result_output])
    vis_btn.click(fn=plot_attention, inputs=[layer, head], outputs=[plot_output])

if __name__ == '__main__':
    app.launch()
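The plot_attention helper above is a placeholder (it draws a fixed line). A minimal sketch of a real attention plot, assuming encoder_attentions is a tuple with one tensor of shape (batch, num_heads, seq_len, seq_len) per layer — the usual Hugging Face layout when output_attentions=True — and that tokens holds the corresponding input tokens; the function name and both arguments are illustrative, not part of the committed code:

# Hypothetical attention visualisation, under the assumptions stated above.
import matplotlib.pyplot as plt

def plot_attention_scores(encoder_attentions, tokens, layer, head):
    # pick one layer/head for the first example in the batch
    scores = encoder_attentions[int(layer)][0, int(head)].detach().cpu().numpy()
    fig, ax = plt.subplots()
    im = ax.imshow(scores, cmap="viridis")
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    fig.colorbar(im, ax=ax)
    ax.set_title(f"Encoder attention - layer {int(layer)}, head {int(head)}")
    return fig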
custom_bart/__init__.py
ADDED
@@ -0,0 +1,12 @@
from .bart_attention import BartCustomAttention
from .bart_mask_attention import BartCustomMaskAttention
from .bart_for_conditional_generation import BartCustomForConditionalGeneration
from .bart_model import BartCustomModel
from .config import BartCustomConfig
from .custom_constants import BartConstants
from .decoder import *
from .decoder_layer import *
from .encoder import *
from .encoder_layer import *
from .bart_generation_mixin import *
from . import *
custom_bart/attention_utils.py
ADDED
@@ -0,0 +1,132 @@
#############################
# Imports
#############################

# Python modules

# Remote modules
import torch

# Local modules

#############################
# Constants
#############################

#############################
# Stuff
#############################

def find_head_to_mask(heads_mask) -> int:
    head_idx = torch.argmax(heads_mask)
    head_idx_simple = head_idx.item()
    return head_idx_simple

def commonsense_attention_mask_update(bsz, n_tokens, commonsense_matrix, attn_weights,
                                      num_heads=16, specific_head=0):
    commonsense_mask = torch.zeros(
        ((bsz, num_heads, n_tokens, n_tokens))
    )
    attn_weights_helper = attn_weights.reshape((num_heads, bsz, n_tokens, n_tokens))
    zeros = torch.zeros(
        ((bsz, n_tokens, n_tokens))
    )
    head_previous_attention_weights = attn_weights_helper[specific_head]
    attn_weights_helper[specific_head] = zeros
    attn_weights_helper = attn_weights_helper.reshape((bsz, num_heads, n_tokens, n_tokens))
    if commonsense_matrix is None:
        # ignore is not passed (ones -> neutral since multiplication is used)
        commonsense_matrix = torch.ones(
            ((bsz, n_tokens, n_tokens))
        )
    commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
    commonsense_mask[specific_head] = head_previous_attention_weights * commonsense_matrix
    # TODO Stupid conversion
    commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens)).to('cuda')
    return attn_weights_helper + commonsense_mask

def convert_relations_to_binary_mask(input_relations, should_clone=True):
    relations_binary_mask = input_relations
    if should_clone:
        relations_binary_mask = input_relations.clone()
    relations_binary_mask[relations_binary_mask > 1] = 1
    return relations_binary_mask

def relation_binary_2d_to_1d(relations_binary_mask):
    relations_binary_mask = relations_binary_mask.sum(dim=1)
    relations_binary_mask[relations_binary_mask > 1] = 1
    return relations_binary_mask

def create_layer_with_commonsense_on_specific_head(relation_binary_mask, bsz, num_heads, specific_head=0):
    n_tokens = relation_binary_mask.size()[-1]
    relations_mask = torch.zeros(
        (bsz, num_heads, n_tokens, n_tokens)
    )
    layer = relations_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
    layer[specific_head] = relation_binary_mask
    layer = layer.reshape((bsz, num_heads, n_tokens, n_tokens))
    return layer

def update_weights_regarding_relations_on_specific_head(layer_head_mask, attn_weights, relation_inputs, bsz, num_heads, tgt_len, src_len, verbose=True):
    #layer_head_mask = layer_head_mask.to(attn_weights.device)
    inverse_layer_head_mask = (layer_head_mask.view(num_heads, 1, 1) - 1) * -1
    #inverse_layer_head_mask = inverse_layer_head_mask.to(attn_weights.device)
    #print('layer_head_mask:', layer_head_mask)
    if verbose:
        print("==============================")
        print('layer_head_mask.shape:', layer_head_mask.shape)
        print('inverse_layer_head_mask.shape:', inverse_layer_head_mask.shape)
        print('attn_weights.shape:', attn_weights.shape)
        print('relation_inputs.shape', relation_inputs.shape)
        print("==============================")
    #print('layer_head_mask.device:', layer_head_mask.device)
    #print('inverse_layer_head_mask.device:', inverse_layer_head_mask.device)
    #print('relation_inputs.device:', relation_inputs.device)
    intermediate_weights = inverse_layer_head_mask * attn_weights.view(bsz, num_heads, tgt_len, src_len)
    relation_inputs = convert_relations_to_binary_mask(relation_inputs, should_clone=False)
    relation_weights = layer_head_mask.view(num_heads, 1, 1) * relation_inputs.view(bsz, 1, tgt_len, src_len) * attn_weights.view(bsz, num_heads,
                                                                                                                                  tgt_len, src_len)
    attn_weights = intermediate_weights + relation_weights
    # [batch, n_heads, seq_length, seq_length]
    if verbose:
        print('attn_weights_int.shape', attn_weights.shape)
    return attn_weights

"""
def create_commonsense_mask(self, bsz, n_tokens, commonsense_matrix, num_heads=16, specific_head=0):
    commonsense_mask = torch.zeros(
        ((bsz, num_heads, n_tokens, n_tokens))
    )
    if commonsense_matrix is None:
        commonsense_matrix = torch.zeros(
            ((bsz, n_tokens, n_tokens))
        )
    commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
    commonsense_mask[specific_head] = commonsense_matrix
    commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens))
    return commonsense_mask

def commonsense_attention_mask_update(self, bsz, n_tokens, commonsense_matrix, attn_weights,
                                      specific_head=0):
    num_heads = self.num_heads
    commonsense_mask = torch.zeros(
        ((bsz, num_heads, n_tokens, n_tokens))
    )
    attn_weights_helper = attn_weights.reshape((num_heads, bsz, n_tokens, n_tokens))
    zeros = torch.zeros(
        ((bsz, n_tokens, n_tokens))
    )
    head_previous_attention_weights = attn_weights_helper[specific_head]
    attn_weights_helper[specific_head] = zeros
    attn_weights_helper = attn_weights_helper.reshape((bsz, num_heads, n_tokens, n_tokens))
    if commonsense_matrix is None:
        # ignore is not passed (ones -> neutral since multiplication is used)
        commonsense_matrix = torch.ones(
            ((bsz, n_tokens, n_tokens))
        )
    commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
    commonsense_mask[specific_head] = head_previous_attention_weights * commonsense_matrix
    # TODO Stupid conversion
    commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens)).to('cuda')
    return attn_weights_helper + commonsense_mask
"""
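A toy usage sketch of update_weights_regarding_relations_on_specific_head: for the heads selected by the mask, the (already softmaxed) attention weights are kept only where the relation matrix is non-zero, while the remaining heads pass through unchanged — roughly out = (1 - m) * A + m * B * A, with m the per-head mask and B the binarised relation matrix. The shapes and values below are made up for illustration:

# Illustration only; head 0 gets relation-masked attention, head 1 is untouched.
import torch
from custom_bart.attention_utils import update_weights_regarding_relations_on_specific_head

bsz, num_heads, seq = 1, 2, 4
attn = torch.rand(bsz * num_heads, seq, seq)        # softmaxed attention weights
relations = torch.randint(0, 3, (bsz, seq, seq))    # 0 = no relation, >0 = relation id
head_mask = torch.tensor([1.0, 0.0])                # apply relations only on head 0

out = update_weights_regarding_relations_on_specific_head(
    head_mask, attn, relations, bsz, num_heads, seq, seq, verbose=False
)
print(out.shape)  # torch.Size([1, 2, 4, 4])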
custom_bart/bart_attention.py
ADDED
@@ -0,0 +1,313 @@
#############################
# Imports
#############################

# Python modules
from typing import Optional, Tuple
# Remote modules
import torch
from torch import nn

# Local modules
from .attention_utils import (
    create_layer_with_commonsense_on_specific_head,
    find_head_to_mask,
    convert_relations_to_binary_mask,
    update_weights_regarding_relations_on_specific_head
)


class BartCustomAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        num_relation_kinds: int = 0,
        use_same_relation_kv_emb: bool = True,
        heads_mask: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        if heads_mask.size() != (self.num_heads,):
            raise ValueError(
                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {heads_mask.size()}"
            )
        self.heads_mask = heads_mask

        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.num_relation_kinds = num_relation_kinds
        self.relation_k_emb = nn.Embedding(num_relation_kinds + 1, self.head_dim, padding_idx=0)
        if use_same_relation_kv_emb:
            self.relation_v_emb = self.relation_k_emb
        else:
            self.relation_v_emb = nn.Embedding(num_relation_kinds + 1, self.head_dim, padding_idx=0)

        self.k_rel_scale = 0.0
        self.v_rel_scale = 1.0


    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        relation_inputs: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        #print('device:', hidden_states.device)
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, embed_dim = hidden_states.size()

        #print(relation_inputs.shape, 'VS ', (bsz, tgt_len, tgt_len))
        if relation_inputs is None:
            # TODO
            print('oh no')
            relation_inputs = torch.zeros((bsz, tgt_len, tgt_len)).to('cuda').long()
        print(relation_inputs.shape, ' | ', (bsz, tgt_len, tgt_len))
        assert relation_inputs.shape == (bsz, tgt_len, tgt_len)

        # (batch_size, seq_length, seq_length, self.num_relation_kinds, self.inner_dim // num_relation_kinds)
        relation_k_embeds = self.relation_k_emb(relation_inputs)
        relation_v_embeds = self.relation_v_emb(relation_inputs)

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz)
        src_len = key_states.size(2)

        # compute scores
        attn_weights = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        # q_t is [batch, seq_length, n_heads, dim_per_head]
        q_t = query_states.permute(0, 2, 1, 3)
        #print('qt.shape: ', q_t.shape)
        # r_t is [batch, seq_length, dim_per_head, seq_length]
        r_t = relation_k_embeds.transpose(-2, -1)
        #print('rt.shape: ', r_t.shape)

        q_tr_t_matmul = torch.matmul(q_t, r_t)  # [batch, seq_length, n_heads, seq_length]
        q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3)  # [batch, n_heads, seq_length, seq_length]

        # Make sure impact of relation-aware only apllicable on specific heads (k-part)

        #print("==========")
        #print('first K: ', q_tr_tmatmul_t.sum())
        """
        q_tr_tmatmul_t = self.layer_heads_relation_attention_update(
            self.heads_mask,
            q_tr_tmatmul_t,
        )
        """
        #print('second K: ', q_tr_tmatmul_t.sum())
        #print("==========")

        # give weight to influence
        #q_tr_tmatmul_t = 100.0 * q_tr_tmatmul_t

        # Add to scores
        #print('attn_weights k [before]', attn_weights)
        #print('attn_weights sum k [before]', attn_weights.sum())
        attn_weights += self.k_rel_scale * q_tr_tmatmul_t
        #attn_weights += 100.0 * q_tr_tmatmul_t
        #print('attn_weights k [after]: ', attn_weights)
        #print('attn_weights sum k [after]', attn_weights.sum())
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Wrong place... gonna comment
        """
        attn_weights = self.layer_heads_relation_attention_update(layer_head_mask,
                                                                  relation_inputs,
                                                                  attn_weights,
                                                                  bsz,
                                                                  tgt_len,
                                                                  src_len)
        """
        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)


        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states.view(*proj_shape))

        #print('attn_probs.shape', attn_probs.shape)
        # w_t is [batch, seq_length, n_heads, seq_length]
        w_t = attn_probs.view(bsz, self.num_heads, tgt_len, src_len).permute(0, 2, 1, 3)
        #print('w_t.shape 1:', w_t.shape)
        #print('relation_v_embeds.shape', relation_v_embeds.shape)
        # [batch, seq_length, n_heads, seq_length]
        w_tr_matmul = torch.matmul(w_t, relation_v_embeds)
        #print('w_tr_matmul.shape 1:', w_tr_matmul.shape)
        #print('w_tr_matmul.shape 2:', w_tr_matmul.shape)
        # Make sure impact of relation-aware only apllicable on specific heads (v-part)

        #print("==========")
        #print('first V sum: ', w_tr_matmul.sum())
        #print('first V: ', w_tr_matmul[0])
        """
        w_tr_matmul = self.layer_heads_relation_attention_v_update(
            self.heads_mask,
            w_tr_matmul,
            bsz,
            tgt_len,
        )
        """
        w_tr_matmul = self.v_rel_scale * w_tr_matmul
        #print('second V sum: ', w_tr_matmul.sum())
        #print('second V: ', w_tr_matmul[0])
        #print("==========")

        w_tr_matmul = w_tr_matmul.permute(0, 2, 1, 3)
        w_tr_matmul = w_tr_matmul.reshape(bsz * self.num_heads, tgt_len, self.head_dim)

        #print('attn_output v [before]', attn_output)
        #print('attn_output sum v [before]', attn_output.sum())
        attn_output += w_tr_matmul
        #attn_output += 100.0 * w_tr_matmul
        #print('attn_output v [after]', attn_output)
        #print('attn_output sum v [after]', attn_output.sum())
        #raise Exception()


        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned aross GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value

    def layer_heads_relation_attention_update(self,
                                              layer_head_mask,
                                              data,
                                              ):
        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
                )
            #print('layer_head_mask:', layer_head_mask)
            masked_weights = layer_head_mask.view(self.num_heads, 1, 1) * data
            return masked_weights
        return data

    def layer_heads_relation_attention_v_update(self,
                                                layer_head_mask,
                                                data,
                                                bsz,
                                                tgt_len,
                                                ):
        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
                )
            #relation_binary_mask = convert_relations_to_binary_mask(relation_inputs)
            #one_dimension_mask = relation_binary_mask.sum(-1)
            #relation_binary_mask = convert_relations_to_binary_mask(one_dimension_mask)
            # [16, 128, 16, 64]
            masked_weights = layer_head_mask.view(self.num_heads, 1, 1) * data.view(bsz, self.num_heads, tgt_len, self.head_dim)
            return masked_weights.view(bsz, tgt_len, self.num_heads, self.head_dim)
        return data
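For reference, a hypothetical standalone use of BartCustomAttention; the embedding size, head count, and number of relation kinds below are illustrative, not the released checkpoint's values. Note the constructor validates the shape of heads_mask, so a tensor must be supplied:

# Illustration only, under the assumptions stated above.
import torch
from custom_bart.bart_attention import BartCustomAttention

embed_dim, num_heads, seq = 1024, 16, 8
attn = BartCustomAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_relation_kinds=40,
    heads_mask=torch.ones(num_heads),   # required: the constructor checks this tensor's shape
)
hidden = torch.rand(2, seq, embed_dim)
relations = torch.randint(0, 41, (2, seq, seq))   # relation id per token pair, 0 = none
out, weights, _ = attn(hidden, relation_inputs=relations, output_attentions=True)
print(out.shape, weights.shape)  # torch.Size([2, 8, 1024]) torch.Size([2, 16, 8, 8])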
custom_bart/bart_for_conditional_generation.py
ADDED
@@ -0,0 +1,205 @@
#############################
# Imports
#############################

# Python modules
from typing import (
    Optional,
    Tuple,
    Union,
    List,
)

# Remote modules
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (
    BartConfig,
    BartPretrainedModel,
)
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.bart.modeling_bart import shift_tokens_right

from transformers.utils import (
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

from .bart_model import BartCustomModel
from .config import BartCustomConfig
from .custom_constants import BartConstants
from .bart_generation_mixin import GenerationMixin
from .custom_outputs import CustomSeq2SeqLMOutput

logger = logging.get_logger(__name__)

@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.", BartConstants.BART_START_DOCSTRING
)
class BartCustomForConditionalGeneration(BartPretrainedModel, GenerationMixin):
    base_model_prefix = "model"
    _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]

    def __init__(self, config: BartCustomConfig):
        super().__init__(config)
        self.model = BartCustomModel(config)
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        new_embeddings = super().resize_token_embeddings(new_num_tokens)
        self._resize_final_logits_bias(new_num_tokens)
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(BartConstants.BART_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=BartConstants.CONFIG_FOR_DOC)
    @add_end_docstrings(BartConstants.BART_GENERATION_EXAMPLE)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        input_commonsense_relations: Optional[torch.Tensor] = None,
        reduce_ce=True,
    ) -> Union[Tuple, CustomSeq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            relation_inputs=input_commonsense_relations
        )
        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(reduce=reduce_ce, ignore_index=self.config.pad_token_id)  # added ignore_index=self.config.pad_token_id
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return CustomSeq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
            head_mask=outputs.encoder_head_mask
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
        # cut decoder_input_ids if past is used
        if past is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            # cached cross_attention states don't have to be reordered -> they are always the same
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
            )
        return reordered_past
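A hedged loading sketch, assuming the 'MrVicente/commonsense_bart_commongen' checkpoint referenced in app.py was saved from BartCustomForConditionalGeneration and therefore ships a compatible BartCustomConfig:

# Illustration only; loading details depend on how the checkpoint was exported.
from custom_bart import BartCustomConfig, BartCustomForConditionalGeneration

model_name = 'MrVicente/commonsense_bart_commongen'
config = BartCustomConfig.from_pretrained(model_name)
model = BartCustomForConditionalGeneration.from_pretrained(model_name, config=config)
model.eval()
print(sum(p.numel() for p in model.parameters()))  # quick sanity check on the loaded weights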
custom_bart/bart_generation_mixin.py
ADDED
The diff for this file is too large to render.
custom_bart/bart_mask_attention.py
ADDED
@@ -0,0 +1,238 @@
#############################
# Imports
#############################

# Python modules
from typing import Optional, Tuple

# Remote modules
import torch
from torch import nn

# Local modules
from .attention_utils import update_weights_regarding_relations_on_specific_head


class BartCustomMaskAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        num_relation_kinds: int = 0,
        heads_mask: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        if heads_mask.size() != (self.num_heads,):
            raise ValueError(
                f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {heads_mask.size()}"
            )
        self.heads_mask = heads_mask

        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.num_relation_kinds = num_relation_kinds


    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        relation_inputs: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, embed_dim = hidden_states.size()

        #print(relation_inputs.shape, 'VS ', (bsz, tgt_len, tgt_len))
        if relation_inputs is None:
            # TODO
            relation_inputs = torch.zeros((bsz, tgt_len, tgt_len)).to('cuda').long()
        assert relation_inputs.shape == (bsz, tgt_len, tgt_len)

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if self.heads_mask is not None:  # and layer_head_mask is not None:
            if self.heads_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
                )
            h_mask = layer_head_mask
            #print('h_mask: ', h_mask)
            if layer_head_mask is None:
                h_mask = self.heads_mask
            #h_mask.to(attn_weights.device)
            attn_weights = update_weights_regarding_relations_on_specific_head(h_mask, attn_weights,
                                                                               relation_inputs, bsz, self.num_heads, tgt_len,
                                                                               src_len, verbose=False)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        elif layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)


        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned aross GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value

    def find_head_to_mask(self, heads_mask) -> int:
        head_idx = torch.argmax(heads_mask)
        head_idx_simple = head_idx.item()
        return head_idx_simple

    def create_commonsense_mask(self, bsz, n_tokens, commonsense_matrix, num_heads=16, specific_head=0):
        commonsense_mask = torch.zeros(
            ((bsz, num_heads, n_tokens, n_tokens))
        )
        if commonsense_matrix is None:
            commonsense_matrix = torch.zeros(
                ((bsz, n_tokens, n_tokens))
            )
        commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
        commonsense_mask[specific_head] = commonsense_matrix
        commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens))
        return commonsense_mask

    def commonsense_attention_mask_update(self, bsz, n_tokens, commonsense_matrix, attn_weights,
                                          specific_head=0):
        num_heads = self.num_heads
        commonsense_mask = torch.zeros(
            ((bsz, num_heads, n_tokens, n_tokens))
        )
        attn_weights_helper = attn_weights.reshape((num_heads, bsz, n_tokens, n_tokens))
        zeros = torch.zeros(
            ((bsz, n_tokens, n_tokens))
        )
        head_previous_attention_weights = attn_weights_helper[specific_head]
        attn_weights_helper[specific_head] = zeros
        attn_weights_helper = attn_weights_helper.reshape((bsz, num_heads, n_tokens, n_tokens))
        if commonsense_matrix is None:
            # ignore is not passed (ones -> neutral since multiplication is used)
            commonsense_matrix = torch.ones(
                ((bsz, n_tokens, n_tokens))
            )
        commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
        commonsense_mask[specific_head] = head_previous_attention_weights * commonsense_matrix
        # TODO Stupid conversion
        commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens)).to('cuda')
        return attn_weights_helper + commonsense_mask

    def convert_relations_to_binary_mask(self, input_relations):
        relations_binary_mask = input_relations.clone()
        relations_binary_mask[relations_binary_mask > 1] = 1
        return relations_binary_mask
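In short, the two attention variants use the relation matrix differently: BartCustomAttention looks relation ids up in dedicated relation_k_emb/relation_v_emb embeddings and adds the resulting terms to the attention scores and outputs (scaled by k_rel_scale/v_rel_scale), whereas BartCustomMaskAttention has no relation embeddings and instead applies the binarised relation matrix as a multiplicative mask on the heads selected by heads_mask, via update_weights_regarding_relations_on_specific_head.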
custom_bart/bart_model.py
ADDED
@@ -0,0 +1,169 @@
1 |
+
#############################
|
2 |
+
# Imports
|
3 |
+
#############################
|
4 |
+
|
5 |
+
# Python modules
|
6 |
+
from typing import (
|
7 |
+
Optional,
|
8 |
+
Tuple,
|
9 |
+
Union,
|
10 |
+
List,
|
11 |
+
)
|
12 |
+
|
13 |
+
# Remote modules
|
14 |
+
import torch
|
15 |
+
from torch import nn
|
16 |
+
from transformers import (
|
17 |
+
BartConfig,
|
18 |
+
BartPretrainedModel,
|
19 |
+
)
|
20 |
+
from transformers.modeling_outputs import (
|
21 |
+
BaseModelOutput, Seq2SeqModelOutput,
|
22 |
+
)
|
23 |
+
from transformers.models.bart.modeling_bart import shift_tokens_right
|
24 |
+
|
25 |
+
from transformers.utils import (
|
26 |
+
add_code_sample_docstrings,
|
27 |
+
add_end_docstrings,
|
28 |
+
add_start_docstrings,
|
29 |
+
add_start_docstrings_to_model_forward,
|
30 |
+
logging,
|
31 |
+
replace_return_docstrings,
|
32 |
+
)
|
33 |
+
|
34 |
+
# Local modules
|
35 |
+
from .config import BartCustomConfig
|
36 |
+
from .encoder import BartCustomEncoder
|
37 |
+
from .decoder import BartCustomDecoder
|
38 |
+
from .custom_constants import BartConstants
|
39 |
+
from .custom_outputs import CustomSeq2SeqModelOutput
|
40 |
+
|
41 |
+
@add_start_docstrings(
|
42 |
+
"The bare BART Model outputting raw hidden-states without any specific head on top.",
|
43 |
+
BartConstants.BART_START_DOCSTRING,
|
44 |
+
)
|
45 |
+
class BartCustomModel(BartPretrainedModel):
|
46 |
+
def __init__(self, config: BartCustomConfig):
|
47 |
+
super().__init__(config)
|
48 |
+
|
49 |
+
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
50 |
+
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
|
51 |
+
|
52 |
+
self.encoder = BartCustomEncoder(config, self.shared)
|
53 |
+
self.decoder = BartCustomDecoder(config, self.shared)
|
54 |
+
|
55 |
+
# Initialize weights and apply final processing
|
56 |
+
self.post_init()
|
57 |
+
|
58 |
+
def get_input_embeddings(self):
|
59 |
+
return self.shared
|
60 |
+
|
61 |
+
def set_input_embeddings(self, value):
|
62 |
+
self.shared = value
|
63 |
+
self.encoder.embed_tokens = self.shared
|
64 |
+
self.decoder.embed_tokens = self.shared
|
65 |
+
|
66 |
+
def get_encoder(self):
|
67 |
+
return self.encoder
|
68 |
+
|
69 |
+
def get_decoder(self):
|
70 |
+
return self.decoder
|
71 |
+
|
72 |
+
@add_start_docstrings_to_model_forward(BartConstants.BART_INPUTS_DOCSTRING)
|
73 |
+
@add_code_sample_docstrings(
|
74 |
+
processor_class= BartConstants.TOKENIZER_FOR_DOC,
|
75 |
+
checkpoint= BartConstants.CHECKPOINT_FOR_DOC,
|
76 |
+
output_type= Seq2SeqModelOutput,
|
77 |
+
config_class= BartConstants.CONFIG_FOR_DOC,
|
78 |
+
expected_output= BartConstants.EXPECTED_OUTPUT_SHAPE,
|
79 |
+
)
|
80 |
+
def forward(
|
81 |
+
self,
|
82 |
+
input_ids: torch.LongTensor = None,
|
83 |
+
attention_mask: Optional[torch.Tensor] = None,
|
84 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
85 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
86 |
+
head_mask: Optional[torch.Tensor] = None,
|
87 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
88 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
89 |
+
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
90 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
91 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
92 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
93 |
+
use_cache: Optional[bool] = None,
|
94 |
+
output_attentions: Optional[bool] = None,
|
95 |
+
output_hidden_states: Optional[bool] = None,
|
96 |
+
return_dict: Optional[bool] = None,
|
97 |
+
relation_inputs: Optional[torch.Tensor] = None,
|
98 |
+
) -> Union[Tuple, CustomSeq2SeqModelOutput]:
|
99 |
+
|
100 |
+
# different to other models, Bart automatically creates decoder_input_ids from
|
101 |
+
# input_ids if no decoder_input_ids are provided
|
102 |
+
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
103 |
+
if input_ids is None:
|
104 |
+
raise ValueError(
|
105 |
+
"If no `decoder_input_ids` or `decoder_inputs_embeds` are "
|
106 |
+
"passed, `input_ids` cannot be `None`. Please pass either "
|
107 |
+
"`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
|
108 |
+
)
|
109 |
+
|
110 |
+
decoder_input_ids = shift_tokens_right(
|
111 |
+
input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
|
112 |
+
)
|
113 |
+
|
114 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
115 |
+
output_hidden_states = (
|
116 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
117 |
+
)
|
118 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
119 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
120 |
+
|
121 |
+
if encoder_outputs is None:
|
122 |
+
encoder_outputs = self.encoder(
|
123 |
+
input_ids=input_ids,
|
124 |
+
attention_mask=attention_mask,
|
125 |
+
head_mask=head_mask,
|
126 |
+
inputs_embeds=inputs_embeds,
|
127 |
+
output_attentions=output_attentions,
|
128 |
+
output_hidden_states=output_hidden_states,
|
129 |
+
return_dict=return_dict,
|
130 |
+
relation_inputs=relation_inputs
|
131 |
+
)
|
132 |
+
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
|
133 |
+
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
134 |
+
encoder_outputs = BaseModelOutput(
|
135 |
+
last_hidden_state=encoder_outputs[0],
|
136 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
137 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
138 |
+
)
|
139 |
+
|
140 |
+
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
|
141 |
+
decoder_outputs = self.decoder(
|
142 |
+
input_ids=decoder_input_ids,
|
143 |
+
attention_mask=decoder_attention_mask,
|
144 |
+
encoder_hidden_states=encoder_outputs[0],
|
145 |
+
encoder_attention_mask=attention_mask,
|
146 |
+
head_mask=decoder_head_mask,
|
147 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
148 |
+
past_key_values=past_key_values,
|
149 |
+
inputs_embeds=decoder_inputs_embeds,
|
150 |
+
use_cache=use_cache,
|
151 |
+
output_attentions=output_attentions,
|
152 |
+
output_hidden_states=output_hidden_states,
|
153 |
+
return_dict=return_dict,
|
154 |
+
)
|
155 |
+
|
156 |
+
if not return_dict:
|
157 |
+
return decoder_outputs + encoder_outputs
|
158 |
+
|
159 |
+
return CustomSeq2SeqModelOutput(
|
160 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
161 |
+
past_key_values=decoder_outputs.past_key_values,
|
162 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
163 |
+
decoder_attentions=decoder_outputs.attentions,
|
164 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
165 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
166 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
167 |
+
encoder_attentions=encoder_outputs.attentions,
|
168 |
+
encoder_head_mask=head_mask
|
169 |
+
)
|
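Not part of the commit, but a minimal smoke-test sketch of the forward pass above may help reviewers. The import paths follow this repo's file layout; `num_relation_kinds=8` and the all-zero relation tensor are placeholder assumptions, and the relation tensor shape follows the `(batch, encoder_sequence, encoder_sequence)` axes declared for `input_commonsense_relations` in `bart_onnx.py` below.

```python
# Minimal smoke-test sketch (not in the commit); values below are placeholders.
import torch

from custom_bart.config import BartSmallCustomConfig
from custom_bart.bart_model import BartCustomModel

config = BartSmallCustomConfig(num_relation_kinds=8)  # assumed > 0 so relation embeddings exist
model = BartCustomModel(config).eval()

batch, seq_len = 2, 6
input_ids = torch.randint(10, config.vocab_size, (batch, seq_len))
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
# token-pair relation ids; 0 is assumed to mean "no relation"
relation_inputs = torch.zeros(batch, seq_len, seq_len, dtype=torch.long)

with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        relation_inputs=relation_inputs,
        return_dict=True,
    )
print(outputs.last_hidden_state.shape)  # (batch, seq_len, config.d_model)
```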
custom_bart/bart_onnx.py
ADDED
@@ -0,0 +1,240 @@
1 |
+
|
2 |
+
from collections import OrderedDict
|
3 |
+
from typing import Any, Mapping, Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import PreTrainedTokenizer
|
7 |
+
from transformers.onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
|
8 |
+
from transformers.onnx.utils import compute_effective_axis_dimension
|
9 |
+
from transformers.utils.generic import TensorType
|
10 |
+
from transformers.utils.import_utils import is_torch_available
|
11 |
+
|
12 |
+
class BartCustumOnnxConfig(OnnxSeq2SeqConfigWithPast):
|
13 |
+
@property
|
14 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
15 |
+
if self.task in ["default", "seq2seq-lm"]:
|
16 |
+
common_inputs = OrderedDict(
|
17 |
+
[
|
18 |
+
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
19 |
+
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
20 |
+
("input_commonsense_relations", {0: "batch", 1: "encoder_sequence", 2: "encoder_sequence"}),
|
21 |
+
]
|
22 |
+
)
|
23 |
+
|
24 |
+
if self.use_past:
|
25 |
+
common_inputs["decoder_input_ids"] = {0: "batch"}
|
26 |
+
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
27 |
+
else:
|
28 |
+
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
|
29 |
+
common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
|
30 |
+
|
31 |
+
if self.use_past:
|
32 |
+
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
33 |
+
elif self.task == "causal-lm":
|
34 |
+
# TODO: figure this case out.
|
35 |
+
common_inputs = OrderedDict(
|
36 |
+
[
|
37 |
+
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
38 |
+
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
39 |
+
]
|
40 |
+
)
|
41 |
+
if self.use_past:
|
42 |
+
num_encoder_layers, _ = self.num_layers
|
43 |
+
for i in range(num_encoder_layers):
|
44 |
+
common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
|
45 |
+
common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
|
46 |
+
else:
|
47 |
+
common_inputs = OrderedDict(
|
48 |
+
[
|
49 |
+
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
50 |
+
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
51 |
+
("input_commonsense_relations", {0: "batch", 2: "encoder_sequence", 3: "encoder_sequence"}),
|
52 |
+
("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
|
53 |
+
("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
|
54 |
+
]
|
55 |
+
)
|
56 |
+
|
57 |
+
return common_inputs
|
58 |
+
|
59 |
+
@property
|
60 |
+
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
61 |
+
if self.task in ["default", "seq2seq-lm"]:
|
62 |
+
common_outputs = super().outputs
|
63 |
+
else:
|
64 |
+
common_outputs = super(OnnxConfigWithPast, self).outputs
|
65 |
+
if self.use_past:
|
66 |
+
num_encoder_layers, _ = self.num_layers
|
67 |
+
for i in range(num_encoder_layers):
|
68 |
+
common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
|
69 |
+
common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
|
70 |
+
return common_outputs
|
71 |
+
|
72 |
+
def _generate_dummy_inputs_for_default_and_seq2seq_lm(
|
73 |
+
self,
|
74 |
+
tokenizer: PreTrainedTokenizer,
|
75 |
+
batch_size: int = -1,
|
76 |
+
seq_length: int = -1,
|
77 |
+
is_pair: bool = False,
|
78 |
+
framework: Optional[TensorType] = None,
|
79 |
+
) -> Mapping[str, Any]:
|
80 |
+
encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
81 |
+
tokenizer, batch_size, seq_length, is_pair, framework
|
82 |
+
)
|
83 |
+
|
84 |
+
# Generate decoder inputs
|
85 |
+
decoder_seq_length = seq_length if not self.use_past else 1
|
86 |
+
decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
87 |
+
tokenizer, batch_size, decoder_seq_length, is_pair, framework
|
88 |
+
)
|
89 |
+
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
90 |
+
common_inputs = dict(**encoder_inputs, **decoder_inputs)
|
91 |
+
|
92 |
+
if self.use_past:
|
93 |
+
if not is_torch_available():
|
94 |
+
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
95 |
+
else:
|
96 |
+
import torch
|
97 |
+
batch, encoder_seq_length = common_inputs["input_ids"].shape
|
98 |
+
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
|
99 |
+
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
|
100 |
+
encoder_shape = (
|
101 |
+
batch,
|
102 |
+
num_encoder_attention_heads,
|
103 |
+
encoder_seq_length,
|
104 |
+
self._config.hidden_size // num_encoder_attention_heads,
|
105 |
+
)
|
106 |
+
decoder_past_length = decoder_seq_length + 3
|
107 |
+
decoder_shape = (
|
108 |
+
batch,
|
109 |
+
num_decoder_attention_heads,
|
110 |
+
decoder_past_length,
|
111 |
+
self._config.hidden_size // num_decoder_attention_heads,
|
112 |
+
)
|
113 |
+
|
114 |
+
common_inputs["decoder_attention_mask"] = torch.cat(
|
115 |
+
[common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
|
116 |
+
)
|
117 |
+
|
118 |
+
common_inputs["past_key_values"] = []
|
119 |
+
# If the number of encoder and decoder layers are present in the model configuration, both are considered
|
120 |
+
num_encoder_layers, num_decoder_layers = self.num_layers
|
121 |
+
min_num_layers = min(num_encoder_layers, num_decoder_layers)
|
122 |
+
max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
|
123 |
+
remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
|
124 |
+
|
125 |
+
for _ in range(min_num_layers):
|
126 |
+
common_inputs["past_key_values"].append(
|
127 |
+
(
|
128 |
+
torch.zeros(decoder_shape),
|
129 |
+
torch.zeros(decoder_shape),
|
130 |
+
torch.zeros(encoder_shape),
|
131 |
+
torch.zeros(encoder_shape),
|
132 |
+
)
|
133 |
+
)
|
134 |
+
# TODO: test this.
|
135 |
+
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
|
136 |
+
for _ in range(min_num_layers, max_num_layers):
|
137 |
+
common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
|
138 |
+
return common_inputs
|
139 |
+
|
140 |
+
def _generate_dummy_inputs_for_causal_lm(
|
141 |
+
self,
|
142 |
+
tokenizer: PreTrainedTokenizer,
|
143 |
+
batch_size: int = -1,
|
144 |
+
seq_length: int = -1,
|
145 |
+
is_pair: bool = False,
|
146 |
+
framework: Optional[TensorType] = None,
|
147 |
+
) -> Mapping[str, Any]:
|
148 |
+
common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
149 |
+
tokenizer, batch_size, seq_length, is_pair, framework
|
150 |
+
)
|
151 |
+
|
152 |
+
if self.use_past:
|
153 |
+
if not is_torch_available():
|
154 |
+
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
|
155 |
+
else:
|
156 |
+
import torch
|
157 |
+
batch, seqlen = common_inputs["input_ids"].shape
|
158 |
+
# Not using the same length for past_key_values
|
159 |
+
past_key_values_length = seqlen + 2
|
160 |
+
num_encoder_layers, _ = self.num_layers
|
161 |
+
num_encoder_attention_heads, _ = self.num_attention_heads
|
162 |
+
past_shape = (
|
163 |
+
batch,
|
164 |
+
num_encoder_attention_heads,
|
165 |
+
past_key_values_length,
|
166 |
+
self._config.hidden_size // num_encoder_attention_heads,
|
167 |
+
)
|
168 |
+
|
169 |
+
mask_dtype = common_inputs["attention_mask"].dtype
|
170 |
+
common_inputs["attention_mask"] = torch.cat(
|
171 |
+
[common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
|
172 |
+
)
|
173 |
+
common_inputs["past_key_values"] = [
|
174 |
+
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
|
175 |
+
]
|
176 |
+
return common_inputs
|
177 |
+
|
178 |
+
def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
179 |
+
self,
|
180 |
+
tokenizer: PreTrainedTokenizer,
|
181 |
+
batch_size: int = -1,
|
182 |
+
seq_length: int = -1,
|
183 |
+
is_pair: bool = False,
|
184 |
+
framework: Optional[TensorType] = None,
|
185 |
+
) -> Mapping[str, Any]:
|
186 |
+
# Copied from OnnxConfig.generate_dummy_inputs
|
187 |
+
# Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
|
188 |
+
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
189 |
+
batch_size = compute_effective_axis_dimension(
|
190 |
+
batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
|
191 |
+
)
|
192 |
+
|
193 |
+
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
|
194 |
+
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
|
195 |
+
seq_length = compute_effective_axis_dimension(
|
196 |
+
seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
|
197 |
+
)
|
198 |
+
|
199 |
+
# Generate dummy inputs according to compute batch and sequence
|
200 |
+
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
201 |
+
tmp_seq_length = seq_length + 2
|
202 |
+
commonsense_relation= torch.IntTensor([[[0] * tmp_seq_length] * tmp_seq_length]* batch_size)
|
203 |
+
common_inputs = dict(tokenizer(dummy_input,
|
204 |
+
return_tensors=framework))
|
205 |
+
common_inputs['input_commonsense_relations'] = commonsense_relation
|
206 |
+
print('here:', common_inputs)
|
207 |
+
return common_inputs
|
208 |
+
|
209 |
+
def generate_dummy_inputs(
|
210 |
+
self,
|
211 |
+
tokenizer: PreTrainedTokenizer,
|
212 |
+
batch_size: int = -1,
|
213 |
+
seq_length: int = -1,
|
214 |
+
is_pair: bool = False,
|
215 |
+
framework: Optional[TensorType] = None,
|
216 |
+
) -> Mapping[str, Any]:
|
217 |
+
if self.task in ["default", "seq2seq-lm"]:
|
218 |
+
common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
|
219 |
+
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
220 |
+
)
|
221 |
+
|
222 |
+
elif self.task == "causal-lm":
|
223 |
+
common_inputs = self._generate_dummy_inputs_for_causal_lm(
|
224 |
+
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
225 |
+
)
|
226 |
+
else:
|
227 |
+
common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
228 |
+
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
229 |
+
)
|
230 |
+
if 'decoder_input_commonsense_relations' in common_inputs:
|
231 |
+
del common_inputs['decoder_input_commonsense_relations']
|
232 |
+
return common_inputs
|
233 |
+
|
234 |
+
def _flatten_past_key_values_(self, flattened_output, name, idx, t):
|
235 |
+
if self.task in ["default", "seq2seq-lm"]:
|
236 |
+
flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
|
237 |
+
else:
|
238 |
+
flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
|
239 |
+
flattened_output, name, idx, t
|
240 |
+
)
|
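A hedged usage sketch for the ONNX config above (not in the commit): it only exercises `generate_dummy_inputs`, which is defined in this file; the `facebook/bart-base` tokenizer checkpoint is an assumption.

```python
# Hedged sketch (not in the commit): inspect the dummy inputs the ONNX config builds.
from transformers import BartTokenizer, TensorType

from custom_bart.config import BartCustomConfig
from custom_bart.bart_onnx import BartCustumOnnxConfig

onnx_config = BartCustumOnnxConfig(BartCustomConfig(), task="seq2seq-lm")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

dummy = onnx_config.generate_dummy_inputs(
    tokenizer, batch_size=2, seq_length=8, framework=TensorType.PYTORCH
)
# Expect input_ids, attention_mask, input_commonsense_relations,
# decoder_input_ids, decoder_attention_mask (plus past_key_values when use_past=True).
print(sorted(dummy.keys()))
print(dummy["input_commonsense_relations"].shape)  # square (seq x seq) relation grid per batch item
```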
custom_bart/config.py
ADDED
@@ -0,0 +1,197 @@
1 |
+
from transformers import BartConfig
|
2 |
+
|
3 |
+
class BartCustomConfig(BartConfig):
|
4 |
+
def __init__(
|
5 |
+
self,
|
6 |
+
model_type='bart',
|
7 |
+
vocab_size=50265,
|
8 |
+
max_position_embeddings=1024,
|
9 |
+
encoder_layers=12,
|
10 |
+
encoder_ffn_dim=4096,
|
11 |
+
encoder_attention_heads=16,
|
12 |
+
decoder_layers=12,
|
13 |
+
decoder_ffn_dim=4096,
|
14 |
+
decoder_attention_heads=16,
|
15 |
+
encoder_layerdrop=0.0,
|
16 |
+
decoder_layerdrop=0.0,
|
17 |
+
activation_function="gelu",
|
18 |
+
d_model=1024,
|
19 |
+
dropout=0.1,
|
20 |
+
attention_dropout=0.1,
|
21 |
+
activation_dropout=0.1,
|
22 |
+
init_std=0.02,
|
23 |
+
classifier_dropout=0.0,
|
24 |
+
classif_dropout=0.1,
|
25 |
+
scale_embedding=False,
|
26 |
+
use_cache=True,
|
27 |
+
num_labels=3,
|
28 |
+
pad_token_id=1,
|
29 |
+
bos_token_id=0,
|
30 |
+
eos_token_id=2,
|
31 |
+
is_encoder_decoder=True,
|
32 |
+
decoder_start_token_id=2,
|
33 |
+
forced_eos_token_id=2,
|
34 |
+
forced_bos_token_id=0,
|
35 |
+
no_repeat_ngram_size=3, # adding
|
36 |
+
num_hidden_layers=12,
|
37 |
+
normalize_before=False,
|
38 |
+
num_beams=4,
|
39 |
+
add_bias_logits=False,
|
40 |
+
add_final_layer_norm=False,
|
41 |
+
early_stopping=True,
|
42 |
+
gradient_checkpointing=False,
|
43 |
+
num_relation_kinds = 0,
|
44 |
+
use_same_relation_kv_emb = True,
|
45 |
+
is_simple_mask_commonsense = False,
|
46 |
+
should_embed_positions = False,
|
47 |
+
heads_mask = None,
|
48 |
+
**kwargs
|
49 |
+
):
|
50 |
+
super(BartCustomConfig, self).__init__(
|
51 |
+
model_type=model_type,
|
52 |
+
vocab_size=vocab_size,
|
53 |
+
max_position_embeddings=max_position_embeddings,
|
54 |
+
encoder_layers=encoder_layers,
|
55 |
+
encoder_ffn_dim=encoder_ffn_dim,
|
56 |
+
encoder_attention_heads=encoder_attention_heads,
|
57 |
+
decoder_layers=decoder_layers,
|
58 |
+
decoder_ffn_dim=decoder_ffn_dim,
|
59 |
+
decoder_attention_heads=decoder_attention_heads,
|
60 |
+
encoder_layerdrop=encoder_layerdrop,
|
61 |
+
decoder_layerdrop=decoder_layerdrop,
|
62 |
+
activation_function=activation_function,
|
63 |
+
d_model=d_model,
|
64 |
+
dropout=dropout,
|
65 |
+
attention_dropout=attention_dropout,
|
66 |
+
activation_dropout=activation_dropout,
|
67 |
+
init_std=init_std,
|
68 |
+
classifier_dropout=classifier_dropout,
|
69 |
+
classif_dropout=classif_dropout,
|
70 |
+
scale_embedding=scale_embedding,
|
71 |
+
use_cache=use_cache,
|
72 |
+
num_labels=num_labels,
|
73 |
+
pad_token_id = pad_token_id,
|
74 |
+
bos_token_id = bos_token_id,
|
75 |
+
eos_token_id = eos_token_id,
|
76 |
+
is_encoder_decoder = is_encoder_decoder,
|
77 |
+
decoder_start_token_id = decoder_start_token_id,
|
78 |
+
forced_eos_token_id = forced_eos_token_id,
|
79 |
+
forced_bos_token_id=forced_bos_token_id,
|
80 |
+
no_repeat_ngram_size=no_repeat_ngram_size, # Adding
|
81 |
+
normalize_before=normalize_before,
|
82 |
+
num_hidden_layers=num_hidden_layers,
|
83 |
+
num_beams=num_beams,
|
84 |
+
add_bias_logits=add_bias_logits,
|
85 |
+
add_final_layer_norm=add_final_layer_norm,
|
86 |
+
early_stopping=early_stopping,
|
87 |
+
gradient_checkpointing=gradient_checkpointing,
|
88 |
+
num_relation_kinds = num_relation_kinds,
|
89 |
+
use_same_relation_kv_emb = use_same_relation_kv_emb,
|
90 |
+
is_simple_mask_commonsense = is_simple_mask_commonsense,
|
91 |
+
heads_mask = heads_mask,
|
92 |
+
should_embed_positions=should_embed_positions,
|
93 |
+
**kwargs
|
94 |
+
)
|
95 |
+
self.num_relation_kinds = num_relation_kinds
|
96 |
+
self.use_same_relation_kv_emb = use_same_relation_kv_emb
|
97 |
+
self.is_simple_mask_commonsense = is_simple_mask_commonsense
|
98 |
+
self.heads_mask = heads_mask
|
99 |
+
self.should_embed_positions = should_embed_positions
|
100 |
+
|
101 |
+
class BartSmallCustomConfig(BartConfig):
|
102 |
+
def __init__(
|
103 |
+
self,
|
104 |
+
vocab_size=50265,
|
105 |
+
max_position_embeddings=1024,
|
106 |
+
encoder_layers=6,
|
107 |
+
encoder_ffn_dim=3072,
|
108 |
+
encoder_attention_heads=12,
|
109 |
+
decoder_layers=12,
|
110 |
+
decoder_ffn_dim=3072,
|
111 |
+
decoder_attention_heads=12,
|
112 |
+
encoder_layerdrop=0.0,
|
113 |
+
decoder_layerdrop=0.0,
|
114 |
+
activation_function="gelu",
|
115 |
+
d_model=768,
|
116 |
+
dropout=0.1,
|
117 |
+
attention_dropout=0.1,
|
118 |
+
activation_dropout=0.1,
|
119 |
+
init_std=0.02,
|
120 |
+
classifier_dropout=0.0,
|
121 |
+
classif_dropout= 0.1,
|
122 |
+
scale_embedding=False,
|
123 |
+
use_cache=True,
|
124 |
+
num_labels=3,
|
125 |
+
pad_token_id=1,
|
126 |
+
bos_token_id=0,
|
127 |
+
eos_token_id=2,
|
128 |
+
is_encoder_decoder=True,
|
129 |
+
decoder_start_token_id=2,
|
130 |
+
forced_eos_token_id=2,
|
131 |
+
forced_bos_token_id=0,
|
132 |
+
no_repeat_ngram_size=3, #adding
|
133 |
+
num_hidden_layers=6,
|
134 |
+
normalize_before=False,
|
135 |
+
num_beams=4,
|
136 |
+
add_bias_logits=False,
|
137 |
+
add_final_layer_norm=False,
|
138 |
+
_name_or_path="bart-base",
|
139 |
+
early_stopping=True,
|
140 |
+
gradient_checkpointing=False,
|
141 |
+
num_relation_kinds = 0,
|
142 |
+
use_same_relation_kv_emb = True,
|
143 |
+
is_simple_mask_commonsense = False,
|
144 |
+
should_embed_positions = True,
|
145 |
+
heads_mask = None,
|
146 |
+
**kwargs
|
147 |
+
):
|
148 |
+
super(BartSmallCustomConfig, self).__init__(
|
149 |
+
vocab_size=vocab_size,
|
150 |
+
max_position_embeddings=max_position_embeddings,
|
151 |
+
encoder_layers=encoder_layers,
|
152 |
+
encoder_ffn_dim=encoder_ffn_dim,
|
153 |
+
encoder_attention_heads=encoder_attention_heads,
|
154 |
+
decoder_layers=decoder_layers,
|
155 |
+
decoder_ffn_dim=decoder_ffn_dim,
|
156 |
+
decoder_attention_heads=decoder_attention_heads,
|
157 |
+
encoder_layerdrop=encoder_layerdrop,
|
158 |
+
decoder_layerdrop=decoder_layerdrop,
|
159 |
+
activation_function=activation_function,
|
160 |
+
d_model=d_model,
|
161 |
+
dropout=dropout,
|
162 |
+
attention_dropout=attention_dropout,
|
163 |
+
activation_dropout=activation_dropout,
|
164 |
+
init_std=init_std,
|
165 |
+
classifier_dropout=classifier_dropout,
|
166 |
+
classif_dropout=classif_dropout,
|
167 |
+
scale_embedding=scale_embedding,
|
168 |
+
use_cache=use_cache,
|
169 |
+
num_labels=num_labels,
|
170 |
+
pad_token_id = pad_token_id,
|
171 |
+
bos_token_id = bos_token_id,
|
172 |
+
eos_token_id = eos_token_id,
|
173 |
+
is_encoder_decoder = is_encoder_decoder,
|
174 |
+
decoder_start_token_id = decoder_start_token_id,
|
175 |
+
forced_eos_token_id = forced_eos_token_id,
|
176 |
+
forced_bos_token_id=forced_bos_token_id,
|
177 |
+
no_repeat_ngram_size = no_repeat_ngram_size, #Adding
|
178 |
+
normalize_before = normalize_before,
|
179 |
+
num_hidden_layers=num_hidden_layers,
|
180 |
+
num_beams=num_beams,
|
181 |
+
add_bias_logits=add_bias_logits,
|
182 |
+
add_final_layer_norm=add_final_layer_norm,
|
183 |
+
_name_or_path=_name_or_path,
|
184 |
+
early_stopping=early_stopping,
|
185 |
+
gradient_checkpointing=gradient_checkpointing,
|
186 |
+
num_relation_kinds = num_relation_kinds,
|
187 |
+
use_same_relation_kv_emb = use_same_relation_kv_emb,
|
188 |
+
is_simple_mask_commonsense = is_simple_mask_commonsense,
|
189 |
+
heads_mask = heads_mask,
|
190 |
+
should_embed_positions=should_embed_positions,
|
191 |
+
**kwargs
|
192 |
+
)
|
193 |
+
self.num_relation_kinds = num_relation_kinds
|
194 |
+
self.use_same_relation_kv_emb = use_same_relation_kv_emb
|
195 |
+
self.is_simple_mask_commonsense = is_simple_mask_commonsense
|
196 |
+
self.heads_mask = heads_mask
|
197 |
+
self.should_embed_positions = should_embed_positions
|
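A short, illustrative sketch of what the custom configs add on top of `BartConfig` (values are examples, not taken from the commit):

```python
# Illustrative values only (not from the commit); the custom config is a drop-in
# BartConfig carrying extra knowledge-graph fields.
from custom_bart.config import BartSmallCustomConfig

config = BartSmallCustomConfig(
    num_relation_kinds=40,           # assumed: number of relation types, 0 reserved for "no relation"
    use_same_relation_kv_emb=True,   # share the relation embedding between keys and values
    is_simple_mask_commonsense=False,
    should_embed_positions=True,
)
print(config.d_model, config.encoder_layers, config.num_relation_kinds)
config.save_pretrained("./bart_custom_config")  # serialises the extra fields alongside the usual BartConfig ones
```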
custom_bart/custom_constants.py
ADDED
@@ -0,0 +1,168 @@
1 |
+
|
2 |
+
class BartConstants:
|
3 |
+
CHECKPOINT_FOR_DOC = "facebook/bart-base"
|
4 |
+
CONFIG_FOR_DOC = "BartConfig"
|
5 |
+
TOKENIZER_FOR_DOC = "BartTokenizer"
|
6 |
+
|
7 |
+
# Base model docstring
|
8 |
+
EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
|
9 |
+
|
10 |
+
BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
11 |
+
"facebook/bart-large",
|
12 |
+
]
|
13 |
+
|
14 |
+
BART_START_DOCSTRING = r"""
|
15 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
16 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
17 |
+
etc.)
|
18 |
+
|
19 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
20 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
21 |
+
and behavior.
|
22 |
+
|
23 |
+
Parameters:
|
24 |
+
config ([`BartConfig`]):
|
25 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
26 |
+
load the weights associated with the model, only the configuration. Check out the
|
27 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
28 |
+
"""
|
29 |
+
BART_INPUTS_DOCSTRING = r"""
|
30 |
+
Args:
|
31 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
32 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
33 |
+
it.
|
34 |
+
|
35 |
+
Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
36 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
37 |
+
|
38 |
+
[What are input IDs?](../glossary#input-ids)
|
39 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
40 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
41 |
+
|
42 |
+
- 1 for tokens that are **not masked**,
|
43 |
+
- 0 for tokens that are **masked**.
|
44 |
+
|
45 |
+
[What are attention masks?](../glossary#attention-mask)
|
46 |
+
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
47 |
+
Indices of decoder input sequence tokens in the vocabulary.
|
48 |
+
|
49 |
+
Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
50 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
51 |
+
|
52 |
+
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
53 |
+
|
54 |
+
Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
|
55 |
+
is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
|
56 |
+
|
57 |
+
For translation and summarization training, `decoder_input_ids` should be provided. If no
|
58 |
+
`decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
|
59 |
+
for denoising pre-training following the paper.
|
60 |
+
decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
61 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
62 |
+
be used by default.
|
63 |
+
|
64 |
+
If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_inputs`] and
|
65 |
+
modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information
|
66 |
+
on the default strategy.
|
67 |
+
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
68 |
+
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
|
69 |
+
|
70 |
+
- 1 indicates the head is **not masked**,
|
71 |
+
- 0 indicates the head is **masked**.
|
72 |
+
|
73 |
+
decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
74 |
+
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
|
75 |
+
|
76 |
+
- 1 indicates the head is **not masked**,
|
77 |
+
- 0 indicates the head is **masked**.
|
78 |
+
|
79 |
+
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
80 |
+
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
|
81 |
+
1]`:
|
82 |
+
|
83 |
+
- 1 indicates the head is **not masked**,
|
84 |
+
- 0 indicates the head is **masked**.
|
85 |
+
|
86 |
+
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
|
87 |
+
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
88 |
+
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
89 |
+
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
90 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
91 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
92 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
93 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
94 |
+
|
95 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
96 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
97 |
+
|
98 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
99 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
100 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape
|
101 |
+
`(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you
|
102 |
+
can choose to directly pass an embedded representation. This is useful if you want more control over how to
|
103 |
+
convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
|
104 |
+
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
|
105 |
+
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
|
106 |
+
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
|
107 |
+
input (see `past_key_values`). This is useful if you want more control over how to convert
|
108 |
+
`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
|
109 |
+
|
110 |
+
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
|
111 |
+
of `inputs_embeds`.
|
112 |
+
use_cache (`bool`, *optional*):
|
113 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
114 |
+
`past_key_values`).
|
115 |
+
output_attentions (`bool`, *optional*):
|
116 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
117 |
+
tensors for more detail.
|
118 |
+
output_hidden_states (`bool`, *optional*):
|
119 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
120 |
+
more detail.
|
121 |
+
return_dict (`bool`, *optional*):
|
122 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
123 |
+
"""
|
124 |
+
BART_GENERATION_EXAMPLE = r"""
|
125 |
+
Summarization example:
|
126 |
+
|
127 |
+
```python
|
128 |
+
>>> from transformers import BartTokenizer, BartForConditionalGeneration
|
129 |
+
|
130 |
+
>>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
|
131 |
+
>>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
|
132 |
+
|
133 |
+
>>> ARTICLE_TO_SUMMARIZE = (
|
134 |
+
... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
|
135 |
+
... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
|
136 |
+
... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
|
137 |
+
... )
|
138 |
+
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")
|
139 |
+
|
140 |
+
>>> # Generate Summary
|
141 |
+
>>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
|
142 |
+
>>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
143 |
+
'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'
|
144 |
+
```
|
145 |
+
|
146 |
+
Mask filling example:
|
147 |
+
|
148 |
+
```python
|
149 |
+
>>> from transformers import BartTokenizer, BartForConditionalGeneration
|
150 |
+
|
151 |
+
>>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
|
152 |
+
>>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
|
153 |
+
|
154 |
+
>>> TXT = "My friends are <mask> but they eat too many carbs."
|
155 |
+
>>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
|
156 |
+
>>> logits = model(input_ids).logits
|
157 |
+
|
158 |
+
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
|
159 |
+
>>> probs = logits[0, masked_index].softmax(dim=0)
|
160 |
+
>>> values, predictions = probs.topk(5)
|
161 |
+
|
162 |
+
>>> tokenizer.decode(predictions).split()
|
163 |
+
['not', 'good', 'healthy', 'great', 'very']
|
164 |
+
```
|
165 |
+
"""
|
166 |
+
|
167 |
+
|
168 |
+
|
custom_bart/custom_outputs.py
ADDED
@@ -0,0 +1,142 @@
1 |
+
#############################
|
2 |
+
# Imports
|
3 |
+
#############################
|
4 |
+
|
5 |
+
# Python modules
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from typing import Optional, Tuple
|
8 |
+
|
9 |
+
# Remote modules
|
10 |
+
import torch
|
11 |
+
from transformers.modeling_outputs import ModelOutput
|
12 |
+
|
13 |
+
# Local modules
|
14 |
+
|
15 |
+
#############################
|
16 |
+
# Constants
|
17 |
+
#############################
|
18 |
+
|
19 |
+
#############################
|
20 |
+
# Stuff
|
21 |
+
#############################
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class CustomSeq2SeqLMOutput(ModelOutput):
|
25 |
+
"""
|
26 |
+
Base class for sequence-to-sequence language models outputs.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
30 |
+
Language modeling loss.
|
31 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
32 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
33 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
34 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
35 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
36 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
37 |
+
|
38 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
39 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
40 |
+
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
41 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
42 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
43 |
+
|
44 |
+
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
45 |
+
decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
46 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
47 |
+
sequence_length)`.
|
48 |
+
|
49 |
+
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
50 |
+
self-attention heads.
|
51 |
+
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
52 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
53 |
+
sequence_length)`.
|
54 |
+
|
55 |
+
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
56 |
+
weighted average in the cross-attention heads.
|
57 |
+
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
58 |
+
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
59 |
+
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
60 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
61 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
62 |
+
|
63 |
+
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
64 |
+
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
65 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
66 |
+
sequence_length)`.
|
67 |
+
|
68 |
+
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
69 |
+
self-attention heads.
|
70 |
+
"""
|
71 |
+
|
72 |
+
loss: Optional[torch.FloatTensor] = None
|
73 |
+
logits: torch.FloatTensor = None
|
74 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
75 |
+
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
76 |
+
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
77 |
+
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
78 |
+
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
79 |
+
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
80 |
+
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
81 |
+
head_mask: Optional[Tuple[torch.FloatTensor]] = None
|
82 |
+
|
83 |
+
@dataclass
|
84 |
+
class CustomSeq2SeqModelOutput(ModelOutput):
|
85 |
+
"""
|
86 |
+
Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
|
87 |
+
decoding.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
91 |
+
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
92 |
+
|
93 |
+
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
|
94 |
+
hidden_size)` is output.
|
95 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
96 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
97 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
98 |
+
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
99 |
+
|
100 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
101 |
+
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
102 |
+
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
103 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
104 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
105 |
+
|
106 |
+
Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
|
107 |
+
decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
108 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
109 |
+
sequence_length)`.
|
110 |
+
|
111 |
+
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
112 |
+
self-attention heads.
|
113 |
+
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
114 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
115 |
+
sequence_length)`.
|
116 |
+
|
117 |
+
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
118 |
+
weighted average in the cross-attention heads.
|
119 |
+
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
120 |
+
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
121 |
+
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
122 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
123 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
124 |
+
|
125 |
+
Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
|
126 |
+
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
127 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
128 |
+
sequence_length)`.
|
129 |
+
|
130 |
+
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
131 |
+
self-attention heads.
|
132 |
+
"""
|
133 |
+
|
134 |
+
last_hidden_state: torch.FloatTensor = None
|
135 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
136 |
+
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
137 |
+
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
138 |
+
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
139 |
+
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
140 |
+
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
141 |
+
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
142 |
+
encoder_head_mask: Optional[Tuple[torch.FloatTensor]] = None
|
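A small sketch (not in the commit) showing that the custom output classes behave like any `ModelOutput`, with the extra head-mask field riding along; the placeholder tensors reuse the `[1, 8, 768]` shape from `EXPECTED_OUTPUT_SHAPE` in `custom_constants.py`.

```python
# Sketch (not in the commit): the custom outputs support attribute, key, and index access.
import torch

from custom_bart.custom_outputs import CustomSeq2SeqModelOutput

out = CustomSeq2SeqModelOutput(
    last_hidden_state=torch.zeros(1, 8, 768),
    encoder_last_hidden_state=torch.zeros(1, 8, 768),
    encoder_head_mask=None,  # the extra field added on top of the stock Seq2SeqModelOutput
)
assert out["last_hidden_state"] is out.last_hidden_state  # dict-style access
assert out[0].shape == (1, 8, 768)                        # tuple-style access skips None fields
```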
custom_bart/decoder.py
ADDED
@@ -0,0 +1,312 @@
1 |
+
#############################
|
2 |
+
# Imports
|
3 |
+
#############################
|
4 |
+
|
5 |
+
# Python modules
|
6 |
+
from typing import (
|
7 |
+
Optional,
|
8 |
+
Tuple,
|
9 |
+
Union,
|
10 |
+
List,
|
11 |
+
)
|
12 |
+
import math
|
13 |
+
import random
|
14 |
+
|
15 |
+
# Remote modules
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
from transformers import (
|
19 |
+
BartConfig,
|
20 |
+
BartPretrainedModel,
|
21 |
+
)
|
22 |
+
from transformers.modeling_outputs import (
|
23 |
+
BaseModelOutput,
|
24 |
+
BaseModelOutputWithPastAndCrossAttentions
|
25 |
+
)
|
26 |
+
from transformers.models.bart.modeling_bart import (
|
27 |
+
BartLearnedPositionalEmbedding,
|
28 |
+
_expand_mask,
|
29 |
+
_make_causal_mask
|
30 |
+
)
|
31 |
+
from transformers.utils import (
|
32 |
+
logging,
|
33 |
+
)
|
34 |
+
|
35 |
+
# Local modules
|
36 |
+
from .config import BartCustomConfig
|
37 |
+
from .decoder_layer import BartCustomDecoderLayer
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__)
|
40 |
+
|
41 |
+
class BartCustomDecoder(BartPretrainedModel):
|
42 |
+
"""
|
43 |
+
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`]
|
44 |
+
|
45 |
+
Args:
|
46 |
+
config: BartConfig
|
47 |
+
embed_tokens (nn.Embedding): output embedding
|
48 |
+
"""
|
49 |
+
|
50 |
+
def __init__(self, config: BartCustomConfig, embed_tokens: Optional[nn.Embedding] = None):
|
51 |
+
super().__init__(config)
|
52 |
+
self.dropout = config.dropout
|
53 |
+
self.layerdrop = config.decoder_layerdrop
|
54 |
+
self.padding_idx = config.pad_token_id
|
55 |
+
self.max_target_positions = config.max_position_embeddings
|
56 |
+
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
57 |
+
|
58 |
+
if embed_tokens is not None:
|
59 |
+
self.embed_tokens = embed_tokens
|
60 |
+
else:
|
61 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
62 |
+
|
63 |
+
self.embed_positions = BartLearnedPositionalEmbedding(
|
64 |
+
config.max_position_embeddings,
|
65 |
+
config.d_model,
|
66 |
+
)
|
67 |
+
self.layers = nn.ModuleList([BartCustomDecoderLayer(config) for _ in range(config.decoder_layers)])
|
68 |
+
self.layernorm_embedding = nn.LayerNorm(config.d_model)
|
69 |
+
|
70 |
+
self.gradient_checkpointing = False
|
71 |
+
# Initialize weights and apply final processing
|
72 |
+
self.post_init()
|
73 |
+
|
74 |
+
def get_input_embeddings(self):
|
75 |
+
return self.embed_tokens
|
76 |
+
|
77 |
+
def set_input_embeddings(self, value):
|
78 |
+
self.embed_tokens = value
|
79 |
+
|
80 |
+
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
81 |
+
# create causal mask
|
82 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
83 |
+
combined_attention_mask = None
|
84 |
+
if input_shape[-1] > 1:
|
85 |
+
combined_attention_mask = _make_causal_mask(
|
86 |
+
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
|
87 |
+
).to(self.device)
|
88 |
+
|
89 |
+
if attention_mask is not None:
|
90 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
91 |
+
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
|
92 |
+
combined_attention_mask = (
|
93 |
+
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
94 |
+
)
|
95 |
+
|
96 |
+
return combined_attention_mask
|
97 |
+
|
98 |
+
def forward(
|
99 |
+
self,
|
100 |
+
input_ids: torch.LongTensor = None,
|
101 |
+
attention_mask: Optional[torch.Tensor] = None,
|
102 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
103 |
+
encoder_attention_mask: Optional[torch.LongTensor] = None,
|
104 |
+
head_mask: Optional[torch.Tensor] = None,
|
105 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
106 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
107 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
108 |
+
use_cache: Optional[bool] = None,
|
109 |
+
output_attentions: Optional[bool] = None,
|
110 |
+
output_hidden_states: Optional[bool] = None,
|
111 |
+
return_dict: Optional[bool] = None,
|
112 |
+
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
113 |
+
r"""
|
114 |
+
Args:
|
115 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
116 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
117 |
+
provide it.
|
118 |
+
|
119 |
+
Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
120 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
121 |
+
|
122 |
+
[What are input IDs?](../glossary#input-ids)
|
123 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
124 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
125 |
+
|
126 |
+
- 1 for tokens that are **not masked**,
|
127 |
+
- 0 for tokens that are **masked**.
|
128 |
+
|
129 |
+
[What are attention masks?](../glossary#attention-mask)
|
130 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
|
131 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
132 |
+
of the decoder.
|
133 |
+
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
|
134 |
+
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
|
135 |
+
selected in `[0, 1]`:
|
136 |
+
|
137 |
+
- 1 for tokens that are **not masked**,
|
138 |
+
- 0 for tokens that are **masked**.
|
139 |
+
|
140 |
+
[What are attention masks?](../glossary#attention-mask)
|
141 |
+
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
142 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
143 |
+
|
144 |
+
- 1 indicates the head is **not masked**,
|
145 |
+
- 0 indicates the head is **masked**.
|
146 |
+
|
147 |
+
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
148 |
+
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
|
149 |
+
cross-attention on hidden heads. Mask values selected in `[0, 1]`:
|
150 |
+
|
151 |
+
- 1 indicates the head is **not masked**,
|
152 |
+
- 0 indicates the head is **masked**.
|
153 |
+
|
154 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
155 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
156 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
157 |
+
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
158 |
+
|
159 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
160 |
+
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
161 |
+
|
162 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
163 |
+
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
164 |
+
all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
|
165 |
+
shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
|
166 |
+
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
|
167 |
+
control over how to convert `input_ids` indices into associated vectors than the model's internal
|
168 |
+
embedding lookup matrix.
|
169 |
+
output_attentions (`bool`, *optional*):
|
170 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
171 |
+
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])

        # embed positions
        positions = self.embed_positions(input_shape, past_key_values_length)

        hidden_states = inputs_embeds + positions
        hidden_states = self.layernorm_embedding(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != (len(self.layers)):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, use_cache)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                )
            else:

                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
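Note (illustration, not part of the commit): when `return_dict=False` the decoder drops every `None` entry from the tuple above, so the position of each field depends on which flags were enabled. A minimal, self-contained sketch of that filtering behavior, using placeholder strings instead of tensors:

# Illustrative only: mimics `tuple(v for v in [...] if v is not None)` above.
candidates = ["last_hidden_state", None, "all_hidden_states", None, None]  # e.g. use_cache and attentions disabled
returned = tuple(v for v in candidates if v is not None)
print(returned)  # ('last_hidden_state', 'all_hidden_states')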
custom_bart/decoder_layer.py
ADDED
@@ -0,0 +1,134 @@
#############################
# Imports
#############################

# Python modules
from typing import Optional, Tuple

# Remote modules
import torch
from torch import nn
from transformers import BartConfig
from transformers.activations import ACT2FN

# Local modules
from transformers.models.bart.modeling_bart import BartAttention

from .config import BartCustomConfig


class BartCustomDecoderLayer(nn.Module):
    def __init__(self, config: BartCustomConfig):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BartAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.encoder_attn = BartAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
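Note (illustration, not part of the commit): the layer returns its outputs positionally, which is why the decoder above reads the key/value cache at index 3 when attentions are returned and at index 1 otherwise. A minimal sketch of that index arithmetic with placeholder values:

# With output_attentions=True and use_cache=True the layer yields
# (hidden_states, self_attn_weights, cross_attn_weights, present_key_value);
# without attentions it yields (hidden_states, present_key_value).
layer_outputs_with_attn = ("h", "self_attn", "cross_attn", "kv_cache")
layer_outputs_no_attn = ("h", "kv_cache")
for output_attentions, outs in [(True, layer_outputs_with_attn), (False, layer_outputs_no_attn)]:
    cache = outs[3 if output_attentions else 1]
    assert cache == "kv_cache"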
custom_bart/encoder.py
ADDED
@@ -0,0 +1,216 @@
#############################
# Imports
#############################

# Python modules
from typing import (
    Optional,
    Tuple,
    Union,
)
import math
import random

# Remote modules
import torch
from torch import nn
from transformers import (
    BartConfig,
    BartPretrainedModel,
)
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.bart.modeling_bart import (
    BartLearnedPositionalEmbedding,
    _expand_mask
)

# Local modules
from .config import BartCustomConfig
from .encoder_layer import BartCustomEncoderLayer


class BartCustomEncoder(BartPretrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`BartEncoderLayer`].

    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: BartCustomConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)

        if not config.should_embed_positions:
            self.embed_positions = None
        else:
            self.embed_positions = BartLearnedPositionalEmbedding(
                config.max_position_embeddings,
                embed_dim,
            )
        device = self.device
        self.layers = nn.ModuleList([BartCustomEncoderLayer(config, heads_mask=torch.Tensor(config.heads_mask[i]).to(device))
                                     for i in range(config.encoder_layers)])
        self.layernorm_embedding = nn.LayerNorm(embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
        self.run_config = config

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        relation_inputs: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        # Important for datasets where the order of the words does not matter (e.g. commongen)
        if self.run_config.should_embed_positions:
            embed_pos = self.embed_positions(input_shape)
            hidden_states = inputs_embeds + embed_pos
        else:
            hidden_states = inputs_embeds

        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions, relation_inputs=relation_inputs)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        attention_mask,
                        (head_mask[idx] if head_mask is not None else None),
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                        relation_inputs=relation_inputs,
                    )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
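Note (illustration, not part of the commit): the encoder builds each layer from one row of `config.heads_mask`, so that field is expected to hold one per-head mask per encoder layer. A minimal sketch of building such a nested list; the layer/head counts and the "only head 0 is selected" choice are illustrative assumptions, not values from the commit:

# Illustrative only: one mask row per encoder layer, one entry per attention head.
import torch

encoder_layers, encoder_attention_heads = 12, 16
heads_mask = [[1.0 if head == 0 else 0.0 for head in range(encoder_attention_heads)]
              for _ in range(encoder_layers)]
print(torch.Tensor(heads_mask[0]).shape)  # torch.Size([16]) -- what each BartCustomEncoderLayer receives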
custom_bart/encoder_layer.py
ADDED
@@ -0,0 +1,102 @@
#############################
# Imports
#############################

# Python modules
from typing import Optional, Tuple

# Remote modules
import torch
from torch import nn
from transformers import BartConfig
from transformers.activations import ACT2FN

# Local modules
from .bart_attention import BartCustomAttention
from .bart_mask_attention import BartCustomMaskAttention
from .config import BartCustomConfig


class BartCustomEncoderLayer(nn.Module):
    def __init__(self, config: BartCustomConfig, heads_mask: Optional[torch.Tensor]):
        super().__init__()
        self.embed_dim = config.d_model
        is_simple_mask_commonsense = config.is_simple_mask_commonsense
        if not is_simple_mask_commonsense:
            print("Selecting complex relation attention")
            self.self_attn = BartCustomAttention(
                embed_dim=self.embed_dim,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                num_relation_kinds=config.num_relation_kinds,
                use_same_relation_kv_emb=config.use_same_relation_kv_emb,
                heads_mask=heads_mask,
            )
        else:
            print("Selecting simple (MASK) relation attention")
            self.self_attn = BartCustomMaskAttention(
                embed_dim=self.embed_dim,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                num_relation_kinds=config.num_relation_kinds,
                heads_mask=heads_mask,
            )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_head_mask: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
        relation_inputs: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            relation_inputs=relation_inputs,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
custom_tokenizer/__init__.py
ADDED
@@ -0,0 +1 @@
from .bart_custom_tokenizer_fast import *
custom_tokenizer/bart_custom_tokenizer_fast.py
ADDED
@@ -0,0 +1,484 @@
# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import List, Optional, Tuple, Dict
from collections import deque

import torch
import numpy as np

from tokenizers import pre_tokenizers, processors

from transformers.tokenization_utils_base import AddedToken, BatchEncoding
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import logging
from transformers.models.bart.tokenization_bart import BartTokenizer


logger = logging.get_logger(__name__)


VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}

# See all BART models at https://huggingface.co/models?filter=bart
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/vocab.json",
        "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/vocab.json",
        "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/vocab.json",
        "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/vocab.json",
        "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/vocab.json",
        "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/vocab.json",
    },
    "merges_file": {
        "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/merges.txt",
        "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/merges.txt",
        "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/merges.txt",
        "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/merges.txt",
        "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/merges.txt",
        "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/merges.txt",
    },
    "tokenizer_file": {
        "facebook/bart-base": "https://huggingface.co/facebook/bart-base/resolve/main/tokenizer.json",
        "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/tokenizer.json",
        "facebook/bart-large-mnli": "https://huggingface.co/facebook/bart-large-mnli/resolve/main/tokenizer.json",
        "facebook/bart-large-cnn": "https://huggingface.co/facebook/bart-large-cnn/resolve/main/tokenizer.json",
        "facebook/bart-large-xsum": "https://huggingface.co/facebook/bart-large-xsum/resolve/main/tokenizer.json",
        "yjernite/bart_eli5": "https://huggingface.co/yjernite/bart_eli5/resolve/main/tokenizer.json",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "facebook/bart-base": 1024,
    "facebook/bart-large": 1024,
    "facebook/bart-large-mnli": 1024,
    "facebook/bart-large-cnn": 1024,
    "facebook/bart-large-xsum": 1024,
    "yjernite/bart_eli5": 1024,
}


class BartCustomTokenizerFast(PreTrainedTokenizerFast):
    r"""
    Construct a "fast" BART tokenizer (backed by HuggingFace's *tokenizers* library), derived from the GPT-2 tokenizer,
    using byte-level Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```
    >>> from transformers import BartTokenizerFast
    >>> tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")
    >>> tokenizer("Hello world")['input_ids']
    [0, 31414, 232, 2]
    >>> tokenizer(" Hello world")['input_ids']
    [0, 20920, 232, 2]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
    call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.

            </Tip>

        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.

            <Tip>

            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.

            </Tip>

        sep_token (`str`, *optional*, defaults to `"</s>"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        cls_token (`str`, *optional*, defaults to `"<s>"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        mask_token (`str`, *optional*, defaults to `"<mask>"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (The BART tokenizer detects the beginning of words by the preceding space.)
        trim_offsets (`bool`, *optional*, defaults to `True`):
            Whether the post processing step should trim offsets to avoid including whitespaces.
    """
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask", "input_commonsense_relations", "commonsense_mask"]
    slow_tokenizer_class = BartTokenizer

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        errors="replace",
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        add_prefix_space=False,
        trim_offsets=True,
        **kwargs
    ):
        super().__init__(
            vocab_file,
            merges_file,
            tokenizer_file=tokenizer_file,
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            trim_offsets=trim_offsets,
            **kwargs,
        )

        self.relational_kind_to_index = None
        self.there_is_difference_between_relations = True

        pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
        if pre_tok_state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
            pre_tok_class = getattr(pre_tokenizers, pre_tok_state.pop("type"))
            pre_tok_state["add_prefix_space"] = add_prefix_space
            self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)

        self.add_prefix_space = add_prefix_space

        # the pre_tokenizer is already updated in the GPT2TokenizerFast `__init__`
        tokenizer_component = "post_processor"
        tokenizer_component_instance = getattr(self.backend_tokenizer, tokenizer_component, None)
        if tokenizer_component_instance:
            state = json.loads(tokenizer_component_instance.__getstate__())

            # The lists 'sep' and 'cls' must be cast to tuples for the object `post_processor_class`
            if "sep" in state:
                state["sep"] = tuple(state["sep"])
            if "cls" in state:
                state["cls"] = tuple(state["cls"])

            changes_to_apply = False

            if state.get("add_prefix_space", add_prefix_space) != add_prefix_space:
                state["add_prefix_space"] = add_prefix_space
                changes_to_apply = True

            if state.get("trim_offsets", trim_offsets) != trim_offsets:
                state["trim_offsets"] = trim_offsets
                changes_to_apply = True

            if changes_to_apply:
                component_class = getattr(processors, state.pop("type"))
                new_value = component_class(**state)
                setattr(self.backend_tokenizer, tokenizer_component, new_value)

    def __call__(self, *args, **kwargs):
        input_commonsense_relations = kwargs.get('input_commonsense_relations', None)
        if 'input_commonsense_relations' in kwargs:
            kwargs.pop('input_commonsense_relations')
        out = super(BartCustomTokenizerFast, self).__call__(*args, **kwargs)
        if out.get('input_commonsense_relations') is None:
            out = self._post_process_tokenization(input_commonsense_relations, out)
        return out

    def set_known_relation_names(self, known_relations_names: List[str]):
        self.relational_kind_to_index = {t: i + 1 for i, t in enumerate(known_relations_names)}

    def set_operation_mode(self, there_is_difference_between_relations=True):
        self.there_is_difference_between_relations = there_is_difference_between_relations

    @property
    def mask_token(self) -> str:
        """
        `str`: Mask token, to use when training a model with masked-language modeling. Log an error if used while not
        having been set.

        BART tokenizer has a special mask token to be usable in the fill-mask pipeline. The mask token will greedily
        comprise the space before the *<mask>*.
        """
        if self._mask_token is None and self.verbose:
            logger.error("Using mask_token, but it is not set yet.")
            return None
        return str(self._mask_token)

    @mask_token.setter
    def mask_token(self, value):
        """
        Overriding the default behavior of the mask token to have it eat the space before it.

        This is needed to preserve backward compatibility with all the previously used models based on Bart.
        """
        # The mask token behaves like a normal word, i.e. includes the space before it,
        # so we set lstrip to True
        value = AddedToken(value, lstrip=True, rstrip=False) if isinstance(value, str) else value
        self._mask_token = value

    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)

        if is_split_into_words and not self.add_prefix_space:
            raise ValueError(
                f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
                "to use it with pretokenized inputs."
            )
        input_commonsense_relations = kwargs.get('input_commonsense_relations', None)
        if 'input_commonsense_relations' in kwargs:
            kwargs.pop('input_commonsense_relations')
        out = super()._batch_encode_plus(*args, **kwargs)
        if out.get('input_commonsense_relations') is None:
            out = self._post_process_tokenization(input_commonsense_relations, out)
        return out

    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)

        if is_split_into_words and not self.add_prefix_space:
            raise ValueError(
                f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True "
                "to use it with pretokenized inputs."
            )

        input_commonsense_relations = kwargs.get('input_commonsense_relations', None)
        if 'input_commonsense_relations' in kwargs:
            kwargs.pop('input_commonsense_relations')
        out = super()._encode_plus(*args, **kwargs)
        if out.get('input_commonsense_relations') is None:
            out = self._post_process_tokenization(input_commonsense_relations, out)
        return out

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        if token_ids_1 is None:
            return output

        return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. BART does not
        make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    def _post_process_tokenization(self, input_commonsense_relations, out: BatchEncoding) -> BatchEncoding:
        new_input_relations = self.get_new_input_relation_kinds(
            tokenizer_outputs=out, input_relations=input_commonsense_relations
        )
        out['input_commonsense_relations'] = new_input_relations
        return out

    def find_new_tokens_span_for_multiword(self, pair, aux_dict):
        old_start, old_end = pair
        keys = list(aux_dict.keys())
        new_start, new_end = old_start, old_end
        for (start, end) in keys:
            if old_start >= start and old_end <= end:
                new_start, new_end = start, end
                break
        return new_start, new_end

    def find_new_tokens_incoming_span_for_multiword(self, pair, aux_dict):
        old_start, old_end = pair
        incoming_rels = list([coord for v in aux_dict.values() for coord, relation in v.items()])
        new_start, new_end = old_start, old_end
        for (start, end) in incoming_rels:
            if old_start >= start and old_end <= end:
                new_start, new_end = start, end
                break
        return new_start, new_end

    def get_new_input_relation_kinds(
        self,
        tokenizer_outputs: BatchEncoding,
        input_relations: Optional[List[Dict[Tuple[int, int], Dict[Tuple[int, int], str]]]] = None
    ) -> torch.Tensor:

        n_examples = len(tokenizer_outputs['input_ids'])
        n_tokens = len(tokenizer_outputs['input_ids'][0])
        aux_input_relation_kinds = np.zeros(
            (n_examples, n_tokens, n_tokens),
            dtype=np.int64
        )
        if not input_relations and input_relations is not None:
            return torch.from_numpy(aux_input_relation_kinds)
        elif not input_relations:
            return None
        assert 'offset_mapping' in tokenizer_outputs, "Run tokenizer with return_offsets_mapping=True"
        if input_relations is not None:
            # if input_relations is dirty, clean it
            if isinstance(input_relations, dict):
                input_relations = [input_relations]
            mappings = tokenizer_outputs['offset_mapping']
            assert len(mappings) == len(input_relations)
            mappings = [[tuple(x) for x in mappings[idx].cpu().detach().tolist()] for idx in range(n_examples)]
            examples_mappings = []
            max_idx = 0
            for idx, mapping in enumerate(mappings):
                words = tokenizer_outputs.word_ids(batch_index=idx)
                tokens_to_words = deque(words)
                token_idx_2_word_span = {}
                for token_idx, (_char_i, _char_j) in enumerate(mapping):
                    word_idx_of_token = tokens_to_words.popleft()
                    if word_idx_of_token is None:
                        continue
                    token_span = tokenizer_outputs.word_to_chars(word_idx_of_token)
                    token_idx_2_word_span[token_idx] = (token_span.start, token_span.end)  # unsure whether the -1 is still needed here (it used to be -1)
                    max_idx = max(token_idx, max_idx)
                ##### Multiword ######
                token_idx_2_word_span_multiword = {}
                d = input_relations[idx]
                for k, v in token_idx_2_word_span.items():
                    new_start, new_end = self.find_new_tokens_span_for_multiword(v, d)
                    token_idx_2_word_span_multiword[k] = (new_start, new_end)
                    if v[0] == new_start and v[1] == new_end:
                        new_start, new_end = self.find_new_tokens_incoming_span_for_multiword(v, d)
                        token_idx_2_word_span_multiword[k] = (new_start, new_end)
                ##### ######
                examples_mappings.append(token_idx_2_word_span_multiword)
            for i_example in range(n_examples):
                token_idx_2_word_span = examples_mappings[i_example]
                possible_relations = input_relations[i_example]
                for token_i_idx in range(max_idx + 1):
                    for token_j_idx in range(max_idx + 1):
                        fixed_word_range = token_idx_2_word_span.get(token_i_idx, None)
                        other_word_range = token_idx_2_word_span.get(token_j_idx, None)
                        if not fixed_word_range or not other_word_range:
                            continue
                        relations = possible_relations.get(fixed_word_range, None)
                        if not relations:
                            continue
                        relation_kind = relations.get(other_word_range, None)
                        if not relation_kind:
                            continue
                        if self.there_is_difference_between_relations:
                            aux_input_relation_kinds[i_example, token_i_idx, token_j_idx] = self.relational_kind_to_index[relation_kind]
                        else:
                            # basic relation | only matters that a relation exists between tokens
                            aux_input_relation_kinds[i_example, token_i_idx, token_j_idx] = 1
        aux_input_relation_kinds = torch.from_numpy(aux_input_relation_kinds)
        return aux_input_relation_kinds

    def create_commonsense_mask(self, tokenizer_outputs, commonsense_matrix, num_heads=16, specific_head=0):
        bsz = len(tokenizer_outputs['input_ids'])
        n_tokens = len(tokenizer_outputs['input_ids'][0])
        commonsense_mask = np.zeros(
            ((bsz, num_heads, n_tokens, n_tokens)),
            dtype=np.int64
        )
        if commonsense_matrix is None:
            commonsense_matrix = np.zeros(
                ((bsz, n_tokens, n_tokens)),
                dtype=np.int64
            )
        commonsense_mask = commonsense_mask.reshape((num_heads, bsz, n_tokens, n_tokens))
        # commonsense_matrix.shape: (bsz, src_len, tgt_len)
        commonsense_mask[specific_head] = commonsense_matrix
        commonsense_mask = commonsense_mask.reshape((bsz, num_heads, n_tokens, n_tokens))
        return commonsense_mask
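Note (illustration, not part of the commit): `get_new_input_relation_kinds` expects the relations to be keyed by character spans of the source text and requires offset mappings, so a call would look roughly like the sketch below. The model path reuse of `facebook/bart-large` and the relation label `"related_to"` are placeholder assumptions, not values taken from this repository.

# Illustrative only: feeding character-span relations to the custom tokenizer.
from custom_tokenizer import BartCustomTokenizerFast

tokenizer = BartCustomTokenizerFast.from_pretrained("facebook/bart-large")   # placeholder checkpoint
tokenizer.set_known_relation_names(["related_to"])                            # "related_to" is mapped to index 1
tokenizer.set_operation_mode(there_is_difference_between_relations=True)

text = "boat water bird"
relations = {(0, 4): {(5, 10): "related_to"}}                                  # "boat" -> "water", character spans
enc = tokenizer(text, return_tensors="pt", return_offsets_mapping=True,
                input_commonsense_relations=relations)
print(enc["input_commonsense_relations"].shape)                               # (1, seq_len, seq_len) relation-kind matrix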
data/__init__.py
ADDED
File without changes
data/relation_utils.py
ADDED
@@ -0,0 +1,53 @@

#############################
# Imports
#############################

# Python modules
from collections import deque
from ast import literal_eval

# Remote modules
import torch

# Local modules

#############################
# Constants
#############################

##########################################################
# Helper functions for Relations in dict format
##########################################################

def clean_relations(word_relations):
    new_relations = deque()
    for r in word_relations:
        rel = {}
        for r_key, r_value in r.items():
            normal_k = literal_eval(r_key)
            rel_d = {}
            for r_d_key, r_d_value in r_value.items():
                normal_d_k = literal_eval(r_d_key)
                rel_d[normal_d_k] = r_d_value
            rel[normal_k] = rel_d
        new_relations.append(rel)
    list_new_relations = list(new_relations)
    return list_new_relations

##########################################################
# Helper functions for Relations in Matrix format
##########################################################

def relation_binary_2d_to_1d(relations_binary_mask, dim=1):
    relations_binary_mask = relations_binary_mask.sum(dim=dim)
    relations_binary_mask[relations_binary_mask > 1] = 1
    return relations_binary_mask

def tokens_with_relations(relations_binary_mask):
    relations_binary_mask_dim1 = relations_binary_mask.sum(dim=0)
    relations_binary_mask_dim2 = relations_binary_mask.sum(dim=1)
    tokens_with_rels = relations_binary_mask_dim1 + relations_binary_mask_dim2
    tokens_with_rels[tokens_with_rels > 1] = 1
    mask_rels = torch.tensor(tokens_with_rels, dtype=torch.bool)
    return mask_rels
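Note (illustration, not part of the commit): `clean_relations` turns the JSON-friendly string keys such as "(0, 4)" back into tuple keys via `literal_eval`. A minimal, self-contained sketch of the same transformation on one example (the relation label is a placeholder):

# Illustrative only: mirrors what clean_relations does to each dict.
from ast import literal_eval

raw = [{"(0, 4)": {"(5, 10)": "related_to"}}]
cleaned = [{literal_eval(k): {literal_eval(k2): v2 for k2, v2 in v.items()} for k, v in r.items()}
           for r in raw]
print(cleaned)  # [{(0, 4): {(5, 10): 'related_to'}}]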
inference.py
ADDED
@@ -0,0 +1,349 @@
#############################
# Imports
#############################

# Python modules
from typing import List

# Remote modules
import numpy as np
import torch

# Local modules
from kgs_binding.relation_mapper_builder import RelationsMapperBuilder
from kgs_binding.kg_qa_binding_utils import load_kg_handler
from data.relation_utils import clean_relations
from model_utils import create_layers_head_mask

from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    BartConfig,
    DisjunctiveConstraint,
)

from utils import get_jump_chunks

#############################
# Constants
#############################

#############################
# Stuff
#############################
from custom_tokenizer import BartCustomTokenizerFast
from custom_bart import BartCustomConfig, BartCustomForConditionalGeneration
from utils import get_device, KGType, Model_Type

from kgs_binding.kg_base_wrapper import KGBaseHandler
from kgs_binding.swow_handler import SwowHandler
from kgs_binding.conceptnet_handler import ConceptNetHandler

class Inference:
    def __init__(self, model_path: str, max_length=32):
        self.device = get_device()
        self.tokenizer = self.prepare_tokenizer()
        self.model = self.prepare_model(model_path)
        self.max_length = max_length

    def prepare_tokenizer(self):
        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        return tokenizer

    def prepare_model(self, model_path):
        config = BartConfig.from_pretrained(model_path)
        model = BartForConditionalGeneration.from_pretrained(model_path, config=config).to(self.device)
        model.eval()
        return model

    def pre_process_context(self, context):
        context = context.lower()
        context_tokenized = self.tokenizer(context, padding='max_length',
                                           truncation='longest_first', max_length=self.max_length,
                                           return_tensors="pt",
                                           )
        return context_tokenized

    def generate_based_on_context(self, context):
        model_input = self.pre_process_context(context)
        generated_answers_encoded = self.model.generate(input_ids=model_input["input_ids"].to(self.device),
                                                        attention_mask=model_input["attention_mask"].to(self.device),
                                                        min_length=1,
                                                        max_length=self.max_length,
                                                        do_sample=True,
                                                        early_stopping=True,
                                                        num_beams=4,
                                                        temperature=1.0,
                                                        top_k=None,
                                                        top_p=None,
                                                        # eos_token_id=tokenizer.eos_token_id,
                                                        no_repeat_ngram_size=2,
                                                        num_return_sequences=1,
                                                        return_dict_in_generate=True,
                                                        output_attentions=True,
                                                        output_scores=True)
        # print(f'Scores: {generated_answers_encoded}')
        response = self.tokenizer.batch_decode(generated_answers_encoded['sequences'], skip_special_tokens=True,
                                               clean_up_tokenization_spaces=True)
        encoder_attentions = generated_answers_encoded['encoder_attentions']
        return response, encoder_attentions, model_input

    def prepare_context_for_visualization(self, context):
        examples = []
        response, encoder_outputs, model_input = self.generate_based_on_context(context)
        encoder_outputs = torch.stack(encoder_outputs)
        n_layers, batch_size, n_heads, src, tgt = encoder_outputs.size()
        print(encoder_outputs.size())
        encoder_attentions = encoder_outputs.view(batch_size, n_layers, n_heads, src, tgt)
        for i, ex in enumerate(encoder_attentions):
            d = {}
            indices = model_input['input_ids'][i].detach().cpu()
            all_tokens = self.tokenizer.convert_ids_to_tokens(indices)
            useful_indeces = indices != self.tokenizer.pad_token_id
            all_tokens = np.array(all_tokens)[useful_indeces]
            all_tokens = [tok.replace('Ġ', '') for tok in all_tokens]
            d['words'] = all_tokens
            d['attentions'] = ex.detach().cpu().numpy()
            examples.append(d)
            print(d['words'])
        return response, examples

class RelationsInference:
    def __init__(self, model_path: str, kg_type: KGType, model_type: Model_Type, max_length=32):
        self.device = get_device()
        kg_handler: KGBaseHandler = load_kg_handler(kg_type)
        self.kg_handler = kg_handler
        relation_names = kg_handler.get_relation_types()
        self.tokenizer = self.prepare_tokenizer(relation_names, model_type)
        self.simple_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        self.model, self.config = self.prepare_model(relation_names, model_path, model_type)
        self.relation_mapper_builder = RelationsMapperBuilder(knowledge=kg_handler)
        self.max_length = max_length

    def prepare_tokenizer(self, relation_names: List[str], model_type: Model_Type):
|
124 |
+
tokenizer = BartCustomTokenizerFast.from_pretrained('facebook/bart-large')
|
125 |
+
tokenizer.set_known_relation_names(relation_names)
|
126 |
+
tokenizer.set_operation_mode(there_is_difference_between_relations=model_type.there_is_difference_between_relations())
|
127 |
+
return tokenizer
|
128 |
+
|
129 |
+
def prepare_model(self, relation_names: List[str], model_path, model_type:Model_Type):
|
130 |
+
config = BartCustomConfig.from_pretrained(model_path, revision='master')
|
131 |
+
print('config.heads_mask:', config.heads_mask)
|
132 |
+
if config.num_relation_kinds is None:
|
133 |
+
config.num_relation_kinds = len(relation_names)
|
134 |
+
if config.is_simple_mask_commonsense is None:
|
135 |
+
config.is_simple_mask_commonsense = model_type.is_simple_mask_commonsense()
|
136 |
+
if config.heads_mask is None:
|
137 |
+
config.heads_mask = create_layers_head_mask(config)#, heads_mask_type, specific_heads)
|
138 |
+
model = BartCustomForConditionalGeneration.from_pretrained(model_path, config=config, revision='master').to(self.device)
|
139 |
+
model.eval()
|
140 |
+
return model, config
|
141 |
+
|
142 |
+
def pre_process_context(self, context):
|
143 |
+
context = context.lower()
|
144 |
+
# process context in search for relations
|
145 |
+
commonsense_relations = self.relation_mapper_builder.get_relations_mapping_complex(context=[context], clear_common_wds=True)
|
146 |
+
# clean relation
|
147 |
+
commonsense_relation = clean_relations(commonsense_relations)[0]
|
148 |
+
# convert this relations to matrices
|
149 |
+
print(commonsense_relation)
|
150 |
+
context_tokenized = self.tokenizer(context, padding='max_length',
|
151 |
+
truncation='longest_first', max_length=self.max_length,
|
152 |
+
return_tensors="pt", return_offsets_mapping=True,
|
153 |
+
input_commonsense_relations=commonsense_relation,
|
154 |
+
)
|
155 |
+
return context_tokenized
|
156 |
+
|
157 |
+
def get_relations_information(self, phrase_generated):
|
158 |
+
all_concepts = self.relation_mapper_builder.get_kg_concepts_from_context([phrase_generated], clear_common_wds=True)[0]
|
159 |
+
words = phrase_generated.strip().split(' ') # all words
|
160 |
+
concepts_with_relations = self.relation_mapper_builder.get_concepts_from_context(phrase_generated, clear_common_wds=True)
|
161 |
+
concepts_with_no_relations = list(set(all_concepts).difference(concepts_with_relations))
|
162 |
+
#print('without_relations:', concepts_with_no_relations)
|
163 |
+
print("====== RELATIONS SUMMARY ======")
|
164 |
+
print('phrase_generated:', phrase_generated)
|
165 |
+
print('words:', words)
|
166 |
+
print('all_concepts:', all_concepts)
|
167 |
+
print('concepts_with_relations:', concepts_with_relations)
|
168 |
+
print('without_relations:', concepts_with_no_relations)
|
169 |
+
print("\n== STATS:")
|
170 |
+
print('n_words:', len(words))
|
171 |
+
print('n_concepts:', len(all_concepts))
|
172 |
+
print('n_concepts_with_relations:', len(concepts_with_relations))
|
173 |
+
print('n_c_without_relations:', len(concepts_with_no_relations))
|
174 |
+
print("====== ================= ======")
|
175 |
+
return words, all_concepts, concepts_with_relations, concepts_with_no_relations
|
176 |
+
|
177 |
+
def remove_subsets(self, l):
|
178 |
+
l2 = l[:]
|
179 |
+
for m in l:
|
180 |
+
for n in l:
|
181 |
+
if set(m).issubset(set(n)) and m != n:
|
182 |
+
l2.remove(m)
|
183 |
+
break
|
184 |
+
return l2
|
185 |
+
|
186 |
+
def generate_based_on_context(self, context, use_kg=False):
|
187 |
+
model_input = self.pre_process_context(context)
|
188 |
+
#print(model_input)
|
189 |
+
gen_kwargs = {}
|
190 |
+
if "input_commonsense_relations" in model_input:
|
191 |
+
#print(model_input['input_commonsense_relations'].sum())
|
192 |
+
gen_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations").to(self.device)
|
193 |
+
|
194 |
+
constraints = None
|
195 |
+
if use_kg:
|
196 |
+
constraints = []
|
197 |
+
concepts_from_context = self.relation_mapper_builder.get_concepts_from_context(context=context, clear_common_wds=True)
|
198 |
+
useful_concepts = [self.relation_mapper_builder.knowledge.get_related_concepts(concept) for concept in concepts_from_context]
|
199 |
+
if not useful_concepts:
|
200 |
+
useful_concepts = [self.kg_handler.get_related_concepts(concept) for concept in concepts_from_context]
|
201 |
+
useful_concepts = [[f'{phrase}' for phrase in concepts] for concepts in useful_concepts] # add spaces
|
202 |
+
#useful_concepts = [[phrase for phrase in concepts if len(phrase.split(' ')) == 1] for concepts in useful_concepts]
|
203 |
+
#useful_concepts = list(itertools.chain.from_iterable(useful_concepts))
|
204 |
+
#print('useful_concepts:', useful_concepts[:5])
|
205 |
+
if concepts_from_context:
|
206 |
+
for context_concept, neighbour_concepts in zip(concepts_from_context, useful_concepts):
|
207 |
+
print('neighbour:', neighbour_concepts[:20])
|
208 |
+
#flexible_words = self.most_similar_words(context_concept, neighbour_concepts) # limit the upperbound
|
209 |
+
#flexible_words = [word for word in flexible_words if word not in context_concept] # remove input concepts
|
210 |
+
flexible_words = [word for word in neighbour_concepts if word not in context_concept] # remove input concepts
|
211 |
+
flexible_words_ids: List[List[int]] = self.simple_tokenizer(flexible_words, add_prefix_space=True,add_special_tokens=False).input_ids
|
212 |
+
flexible_words_ids = self.remove_subsets(flexible_words_ids)
|
213 |
+
#add_prefix_space=True
|
214 |
+
#flexible_words_ids = [x for x in flexible_words_ids if len(x) == 1] # problem with subsets
|
215 |
+
flexible_words_ids = flexible_words_ids[:10]
|
216 |
+
print('flexible_words_ids:', flexible_words_ids[:3])
|
217 |
+
constraint = DisjunctiveConstraint(flexible_words_ids)
|
218 |
+
constraints.append(constraint)
|
219 |
+
else:
|
220 |
+
constraints = None
|
221 |
+
|
222 |
+
generated_answers_encoded = self.model.generate(input_ids=model_input["input_ids"].to(self.device),
|
223 |
+
attention_mask=model_input["attention_mask"].to(self.device),
|
224 |
+
constraints=constraints,
|
225 |
+
min_length=1,
|
226 |
+
max_length=self.max_length,
|
227 |
+
do_sample=False,
|
228 |
+
early_stopping=True,
|
229 |
+
num_beams=8,
|
230 |
+
temperature=1.0,
|
231 |
+
top_k=None,
|
232 |
+
top_p=None,
|
233 |
+
# eos_token_id=tokenizer.eos_token_id,
|
234 |
+
no_repeat_ngram_size=2,
|
235 |
+
num_return_sequences=1,
|
236 |
+
return_dict_in_generate=True,
|
237 |
+
output_attentions=True,
|
238 |
+
output_scores=True,
|
239 |
+
**gen_kwargs,
|
240 |
+
)
|
241 |
+
# print(f'Scores: {generated_answers_encoded}')
|
242 |
+
response = self.tokenizer.batch_decode(generated_answers_encoded['sequences'], skip_special_tokens=True,
|
243 |
+
clean_up_tokenization_spaces=True)
|
244 |
+
encoder_attentions = generated_answers_encoded['encoder_attentions']
|
245 |
+
return response, encoder_attentions, model_input
|
246 |
+
|
247 |
+
def get_related_concepts_list(self, knowledge, list_concepts):
|
248 |
+
other_concepts = []
|
249 |
+
for concept in list_concepts:
|
250 |
+
other_near_concepts = knowledge.get_related_concepts(concept)
|
251 |
+
other_concepts.extend(other_near_concepts)
|
252 |
+
return other_concepts
|
253 |
+
|
254 |
+
|
255 |
+
def generate_contrained_based_on_context(self, contexts, use_kg=True, max_concepts=1):
|
256 |
+
model_inputs = [self.pre_process_context(context) for context in contexts]
|
257 |
+
constraints = None
|
258 |
+
if use_kg:
|
259 |
+
constraints = []
|
260 |
+
concepts_from_contexts = [self.relation_mapper_builder.get_concepts_from_context(context=context, clear_common_wds=True) for context in contexts]
|
261 |
+
neighbours_contexts = []#[self.get_related_concepts_list(self.relation_mapper_builder.knowledge, context) for context in concepts_from_contexts]
|
262 |
+
if not neighbours_contexts:
|
263 |
+
neighbours_contexts = [self.get_related_concepts_list(self.kg_handler, context) for context in concepts_from_contexts]
|
264 |
+
all_constraints = []
|
265 |
+
for context_neighbours in neighbours_contexts:
|
266 |
+
# context_neighbours is a collection of concepts
|
267 |
+
# lets create sub collections of concepts
|
268 |
+
context_neighbours = [f' {concept}' for concept in context_neighbours if len(concept) > 3]
|
269 |
+
n_size_chuncks = len(context_neighbours) // max_concepts
|
270 |
+
n_size_chuncks = n_size_chuncks if n_size_chuncks > 0 else 1
|
271 |
+
sub_concepts_collection = list(get_jump_chunks(context_neighbours, jump=n_size_chuncks))
|
272 |
+
constraints = []
|
273 |
+
for sub_concepts in sub_concepts_collection[:max_concepts]:
|
274 |
+
flexible_words_ids: List[List[int]] = self.tokenizer(sub_concepts,
|
275 |
+
add_special_tokens=False).input_ids # add_prefix_space=True,
|
276 |
+
# flexible_words_ids = self.remove_subsets(flexible_words_ids)
|
277 |
+
flexible_words_ids = [[word_ids[0]] for word_ids in flexible_words_ids]
|
278 |
+
disjunctive_set = list(map(list, set(map(frozenset, flexible_words_ids))))
|
279 |
+
if not any(disjunctive_set):
|
280 |
+
continue
|
281 |
+
constraint = DisjunctiveConstraint(disjunctive_set)
|
282 |
+
constraints.append(constraint)
|
283 |
+
if not any(constraints):
|
284 |
+
constraints = None
|
285 |
+
all_constraints.append(constraints)
|
286 |
+
else:
|
287 |
+
all_constraints = None
|
288 |
+
if not all_constraints:
|
289 |
+
all_constraints = None
|
290 |
+
|
291 |
+
generated_answers_encoded = []
|
292 |
+
encoder_attentions_list = []
|
293 |
+
for i, contraints in enumerate(all_constraints):
|
294 |
+
#print('contraints.token_ids:', [x.token_ids for x in contraints])
|
295 |
+
gen_kwargs = {}
|
296 |
+
inputs = model_inputs[i]
|
297 |
+
if "input_commonsense_relations" in inputs:
|
298 |
+
# print(model_input['input_commonsense_relations'].sum())
|
299 |
+
gen_kwargs["relation_inputs"] = inputs.get("input_commonsense_relations").to(self.device)
|
300 |
+
#print('model_kwargs.get("attention_mask"):', model_kwargs.get("attention_mask"))
|
301 |
+
gen = self.model.generate(input_ids=inputs["input_ids"].to(self.device),
|
302 |
+
attention_mask=inputs["attention_mask"].to(self.device),
|
303 |
+
constraints=constraints,
|
304 |
+
min_length=1,
|
305 |
+
max_length=self.max_length,
|
306 |
+
do_sample=False,
|
307 |
+
early_stopping=True,
|
308 |
+
num_beams=8,
|
309 |
+
temperature=1.0,
|
310 |
+
top_k=None,
|
311 |
+
top_p=None,
|
312 |
+
# eos_token_id=tokenizer.eos_token_id,
|
313 |
+
no_repeat_ngram_size=2,
|
314 |
+
num_return_sequences=1,
|
315 |
+
return_dict_in_generate=True,
|
316 |
+
output_attentions=True,
|
317 |
+
output_scores=True,
|
318 |
+
**gen_kwargs,
|
319 |
+
)
|
320 |
+
# print('[gen]:', gen)
|
321 |
+
# print(tokenizer.batch_decode(gen))
|
322 |
+
generated_answers_encoded.append(gen['sequences'][0].detach().cpu())
|
323 |
+
encoder_attentions_list.append(gen['encoder_attentions'][0].detach().cpu())
|
324 |
+
# print(f'Scores: {generated_answers_encoded}')
|
325 |
+
text_results = self.tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
|
326 |
+
clean_up_tokenization_spaces=True)
|
327 |
+
return text_results, encoder_attentions_list, model_inputs
|
328 |
+
|
329 |
+
def prepare_context_for_visualization(self, context):
|
330 |
+
examples, relations = [], []
|
331 |
+
response, encoder_outputs, model_input = self.generate_based_on_context(context)
|
332 |
+
input_commonsense_relations = model_input.get("input_commonsense_relations")
|
333 |
+
encoder_outputs = torch.stack(encoder_outputs)
|
334 |
+
n_layers, batch_size, n_heads, src, tgt = encoder_outputs.size()
|
335 |
+
print(encoder_outputs.size())
|
336 |
+
encoder_attentions = encoder_outputs.view(batch_size, n_layers, n_heads, src, tgt)
|
337 |
+
for i, ex in enumerate(encoder_attentions):
|
338 |
+
d = {}
|
339 |
+
indices = model_input['input_ids'][i].detach().cpu()
|
340 |
+
all_tokens = self.tokenizer.convert_ids_to_tokens(indices)
|
341 |
+
useful_indeces = indices != self.tokenizer.pad_token_id
|
342 |
+
all_tokens = np.array(all_tokens)[useful_indeces]
|
343 |
+
all_tokens = [tok.replace('Ġ', '') for tok in all_tokens]
|
344 |
+
d['words'] = all_tokens
|
345 |
+
d['attentions'] = ex.detach().cpu().numpy()
|
346 |
+
examples.append(d)
|
347 |
+
relations.append(input_commonsense_relations[i])
|
348 |
+
print(d['words'])
|
349 |
+
return response, examples, relations
|
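For orientation, a minimal usage sketch of the two classes defined in this file; the checkpoint path and the input sentence are placeholders, and the KGType / Model_Type members are assumed to be the ones exposed by utils.py:

from inference import Inference, RelationsInference
from utils import KGType, Model_Type

# plain BART generation, without any knowledge-graph machinery
baseline = Inference(model_path='facebook/bart-large', max_length=32)
answers, enc_attentions, model_input = baseline.generate_based_on_context('a dog catches a frisbee')

# relation-aware generation: the context is scanned for ConceptNet relations, which the
# custom tokenizer turns into matrices passed to the model as `relation_inputs`
kg_bart = RelationsInference(model_path='path/to/commonsense-bart-checkpoint',  # placeholder
                             kg_type=KGType.CONCEPTNET,
                             model_type=Model_Type.RELATIONS,
                             max_length=32)
answers, enc_attentions, model_input = kg_bart.generate_based_on_context('a dog catches a frisbee')
print(answers[0])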
kgs_binding/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .kg_base_wrapper import KGBaseHandler
from .relation_mapper_builder import RelationsMapperBuilder
from . import *
kgs_binding/conceptnet/__init__.py
ADDED
@@ -0,0 +1 @@
from . import *
kgs_binding/conceptnet/conceptnet_english_noun_2_noun_relations.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8b82686e2cb4a32a827d3c0a0c63a91d5d102fe5813fe898cabd9a117aa7374c
size 186932142
kgs_binding/conceptnet/conceptnet_english_nouns.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b90ab07ca7445623bcd90489367c4016ca3b4ed743816a99b730f22e13ac339c
size 140804377
kgs_binding/conceptnet/conceptnet_english_nouns_simple.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ad6e76d470432dc3c9c7c0ebbf340eda3c4f69008c9f8a27df97f8e005e5db02
size 22419586
kgs_binding/conceptnet_handler.py
ADDED
@@ -0,0 +1,61 @@
#############################
# Imports
#############################

# Python modules
from typing import Tuple, Optional, List
# Remote modules

# Local modules
from .kg_base_wrapper import KGBaseHandler
from utils import read_json_file_2_dict

#############################
# Constants
#############################

#############################
# Handler
#############################

class ConceptNetHandler(KGBaseHandler):
    def __init__(self, database=""):
        super(ConceptNetHandler, self).__init__()
        _store_dir = 'kgs_binding/conceptnet'
        self.conceptnet_concepts = read_json_file_2_dict('conceptnet_english_nouns_simple.json', store_dir=_store_dir)
        self.relations_concepts = read_json_file_2_dict('conceptnet_english_noun_2_noun_relations.json', store_dir=_store_dir)
        self.concept_2_concepts = read_json_file_2_dict('conceptnet_english_nouns.json', store_dir=_store_dir)

    def get_relation_types(self) -> List[str]:
        updated_relation_names = ['not_has_property', 'not_desires', 'external_u_r_l', 'created_by',
                                  'not_capable_of', 'antonym', 'has_first_subevent', 'located_near',
                                  'desires', 'has_prerequisite', 'has_last_subevent', 'synonym', 'is_a',
                                  'manner_of', 'has_a', 'motivated_by_goal', 'instance_of',
                                  'etymologically_derived_from', 'capable_of', 'for', 'at_location',
                                  'has_subevent', 'causes', 'has_context', 'symbol_of', 'derived_from',
                                  'made_of', 'causes_desire', 'has_property', 'similar_to', 'used_for', 'by',
                                  'entails', 'form_of', 'receives_action', 'distinct_from', 'related_to',
                                  'part_of', 'defined_as', 'etymologically_related_to']
        return updated_relation_names

    def exists_relation_between(self, concept, other_concept) -> bool:
        left_2_right, right_2_left = self.relation_between(concept, other_concept)
        return left_2_right is not None or right_2_left is not None

    def relation_between(self, concept, other_concept) -> Tuple[Optional[str], Optional[str]]:
        left_2_right_txt = f'{concept}|{other_concept}'
        right_2_left_txt = f'{other_concept}|{concept}'
        left_2_right_relations = self.relations_concepts.get(left_2_right_txt, None)
        right_2_left_relations = self.relations_concepts.get(right_2_left_txt, None)
        left_2_right_relation, right_2_left_relation = None, None
        if left_2_right_relations:
            left_2_right_relation = self.ignore_less_relevant_connection(left_2_right_relations)
        if right_2_left_relations:
            right_2_left_relation = self.ignore_less_relevant_connection(right_2_left_relations)
        return left_2_right_relation, right_2_left_relation

    def get_related_concepts(self, concept) -> Optional[List[str]]:
        return self.concept_2_concepts.get(concept, [])

    def does_concept_exist(self, concept) -> bool:
        return concept in self.conceptnet_concepts
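A short sketch of how the handler is typically queried; it assumes the Git LFS JSON files above have been fetched and the script runs from the repository root so the kgs_binding/conceptnet store directory resolves, and the concept pair is only illustrative:

from kgs_binding.conceptnet_handler import ConceptNetHandler

kg = ConceptNetHandler()
# relations are stored under 'concept|other_concept' keys, so both directions are checked
left_to_right, right_to_left = kg.relation_between('bird', 'water')
print(left_to_right, right_to_left)          # e.g. a relation name such as 'at_location', or None
print(kg.exists_relation_between('bird', 'water'))
print(kg.does_concept_exist('bird'))         # membership in the simple nouns vocabulary
print(kg.get_related_concepts('bird')[:5])   # neighbouring concepts, [] for unknown concepts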
kgs_binding/english_stopwords.txt
ADDED
@@ -0,0 +1,1126 @@
1 |
+
'll
|
2 |
+
'tis
|
3 |
+
'twas
|
4 |
+
've
|
5 |
+
a
|
6 |
+
a's
|
7 |
+
able
|
8 |
+
ableabout
|
9 |
+
about
|
10 |
+
above
|
11 |
+
abroad
|
12 |
+
abst
|
13 |
+
accordance
|
14 |
+
according
|
15 |
+
accordingly
|
16 |
+
across
|
17 |
+
act
|
18 |
+
actually
|
19 |
+
ad
|
20 |
+
added
|
21 |
+
adj
|
22 |
+
adopted
|
23 |
+
ae
|
24 |
+
af
|
25 |
+
affected
|
26 |
+
affecting
|
27 |
+
ag
|
28 |
+
ah
|
29 |
+
ai
|
30 |
+
ain't
|
31 |
+
aint
|
32 |
+
al
|
33 |
+
all
|
34 |
+
almost
|
35 |
+
along
|
36 |
+
alongside
|
37 |
+
also
|
38 |
+
although
|
39 |
+
am
|
40 |
+
amid
|
41 |
+
amidst
|
42 |
+
among
|
43 |
+
amongst
|
44 |
+
amoungst
|
45 |
+
an
|
46 |
+
and
|
47 |
+
another
|
48 |
+
any
|
49 |
+
anybody
|
50 |
+
anyhow
|
51 |
+
anymore
|
52 |
+
anyone
|
53 |
+
anything
|
54 |
+
anyway
|
55 |
+
anyways
|
56 |
+
anywhere
|
57 |
+
ao
|
58 |
+
apart
|
59 |
+
apparently
|
60 |
+
appear
|
61 |
+
appreciate
|
62 |
+
appropriate
|
63 |
+
approximately
|
64 |
+
aq
|
65 |
+
ar
|
66 |
+
are
|
67 |
+
aren
|
68 |
+
aren't
|
69 |
+
arent
|
70 |
+
arise
|
71 |
+
around
|
72 |
+
arpa
|
73 |
+
as
|
74 |
+
aside
|
75 |
+
ask
|
76 |
+
asked
|
77 |
+
asking
|
78 |
+
asks
|
79 |
+
associated
|
80 |
+
at
|
81 |
+
au
|
82 |
+
auth
|
83 |
+
aw
|
84 |
+
away
|
85 |
+
awfully
|
86 |
+
az
|
87 |
+
b
|
88 |
+
ba
|
89 |
+
back
|
90 |
+
backed
|
91 |
+
backing
|
92 |
+
backs
|
93 |
+
bb
|
94 |
+
bd
|
95 |
+
be
|
96 |
+
became
|
97 |
+
because
|
98 |
+
become
|
99 |
+
becomes
|
100 |
+
becoming
|
101 |
+
been
|
102 |
+
beforehand
|
103 |
+
began
|
104 |
+
beginning
|
105 |
+
beginnings
|
106 |
+
begins
|
107 |
+
behind
|
108 |
+
being
|
109 |
+
beings
|
110 |
+
believe
|
111 |
+
below
|
112 |
+
beside
|
113 |
+
besides
|
114 |
+
best
|
115 |
+
better
|
116 |
+
between
|
117 |
+
beyond
|
118 |
+
bf
|
119 |
+
bg
|
120 |
+
bh
|
121 |
+
bi
|
122 |
+
biol
|
123 |
+
bj
|
124 |
+
bm
|
125 |
+
bn
|
126 |
+
bo
|
127 |
+
both
|
128 |
+
br
|
129 |
+
brief
|
130 |
+
briefly
|
131 |
+
bs
|
132 |
+
bt
|
133 |
+
but
|
134 |
+
buy
|
135 |
+
bv
|
136 |
+
bw
|
137 |
+
by
|
138 |
+
bz
|
139 |
+
c
|
140 |
+
c'mon
|
141 |
+
c's
|
142 |
+
ca
|
143 |
+
call
|
144 |
+
came
|
145 |
+
can
|
146 |
+
can't
|
147 |
+
cannot
|
148 |
+
cant
|
149 |
+
caption
|
150 |
+
case
|
151 |
+
cases
|
152 |
+
cause
|
153 |
+
causes
|
154 |
+
cc
|
155 |
+
cd
|
156 |
+
certain
|
157 |
+
certainly
|
158 |
+
cf
|
159 |
+
cg
|
160 |
+
ch
|
161 |
+
changes
|
162 |
+
ci
|
163 |
+
ck
|
164 |
+
cl
|
165 |
+
clear
|
166 |
+
clearly
|
167 |
+
cm
|
168 |
+
cmon
|
169 |
+
cn
|
170 |
+
co
|
171 |
+
co.
|
172 |
+
com
|
173 |
+
come
|
174 |
+
comes
|
175 |
+
con
|
176 |
+
consequently
|
177 |
+
contain
|
178 |
+
containing
|
179 |
+
contains
|
180 |
+
copy
|
181 |
+
corresponding
|
182 |
+
could
|
183 |
+
could've
|
184 |
+
couldn
|
185 |
+
couldn't
|
186 |
+
couldnt
|
187 |
+
cr
|
188 |
+
cs
|
189 |
+
cu
|
190 |
+
currently
|
191 |
+
cv
|
192 |
+
cx
|
193 |
+
cy
|
194 |
+
cz
|
195 |
+
d
|
196 |
+
dare
|
197 |
+
daren't
|
198 |
+
darent
|
199 |
+
de
|
200 |
+
dear
|
201 |
+
definitely
|
202 |
+
describe
|
203 |
+
described
|
204 |
+
despite
|
205 |
+
detail
|
206 |
+
did
|
207 |
+
didn
|
208 |
+
didn't
|
209 |
+
didnt
|
210 |
+
differ
|
211 |
+
different
|
212 |
+
differently
|
213 |
+
directly
|
214 |
+
dj
|
215 |
+
dk
|
216 |
+
dm
|
217 |
+
do
|
218 |
+
does
|
219 |
+
doesn
|
220 |
+
doesn't
|
221 |
+
doesnt
|
222 |
+
doing
|
223 |
+
don
|
224 |
+
don't
|
225 |
+
done
|
226 |
+
dont
|
227 |
+
downed
|
228 |
+
downing
|
229 |
+
due
|
230 |
+
during
|
231 |
+
dz
|
232 |
+
e
|
233 |
+
each
|
234 |
+
ec
|
235 |
+
ed
|
236 |
+
edu
|
237 |
+
ee
|
238 |
+
eg
|
239 |
+
eh
|
240 |
+
either
|
241 |
+
else
|
242 |
+
elsewhere
|
243 |
+
enough
|
244 |
+
entirely
|
245 |
+
er
|
246 |
+
es
|
247 |
+
especially
|
248 |
+
et
|
249 |
+
et-al
|
250 |
+
etc
|
251 |
+
even
|
252 |
+
evenly
|
253 |
+
ever
|
254 |
+
evermore
|
255 |
+
every
|
256 |
+
everybody
|
257 |
+
everyone
|
258 |
+
everything
|
259 |
+
everywhere
|
260 |
+
ex
|
261 |
+
exactly
|
262 |
+
example
|
263 |
+
except
|
264 |
+
f
|
265 |
+
fairly
|
266 |
+
far
|
267 |
+
farther
|
268 |
+
felt
|
269 |
+
few
|
270 |
+
fewer
|
271 |
+
ff
|
272 |
+
fi
|
273 |
+
fify
|
274 |
+
fj
|
275 |
+
fk
|
276 |
+
fm
|
277 |
+
fo
|
278 |
+
for
|
279 |
+
forever
|
280 |
+
formerly
|
281 |
+
forth
|
282 |
+
found
|
283 |
+
fr
|
284 |
+
from
|
285 |
+
front
|
286 |
+
full
|
287 |
+
fully
|
288 |
+
further
|
289 |
+
furthered
|
290 |
+
furthering
|
291 |
+
furthermore
|
292 |
+
furthers
|
293 |
+
fx
|
294 |
+
g
|
295 |
+
ga
|
296 |
+
gave
|
297 |
+
gb
|
298 |
+
gd
|
299 |
+
ge
|
300 |
+
generally
|
301 |
+
gf
|
302 |
+
gg
|
303 |
+
gh
|
304 |
+
gi
|
305 |
+
gl
|
306 |
+
gm
|
307 |
+
gmt
|
308 |
+
gn
|
309 |
+
go
|
310 |
+
got
|
311 |
+
gotten
|
312 |
+
gov
|
313 |
+
gp
|
314 |
+
gq
|
315 |
+
gr
|
316 |
+
great
|
317 |
+
greater
|
318 |
+
greatest
|
319 |
+
greetings
|
320 |
+
group
|
321 |
+
grouped
|
322 |
+
grouping
|
323 |
+
groups
|
324 |
+
gs
|
325 |
+
gt
|
326 |
+
gu
|
327 |
+
gw
|
328 |
+
gy
|
329 |
+
h
|
330 |
+
had
|
331 |
+
hadn't
|
332 |
+
hadnt
|
333 |
+
half
|
334 |
+
happens
|
335 |
+
hardly
|
336 |
+
has
|
337 |
+
hasn
|
338 |
+
hasn't
|
339 |
+
hasnt
|
340 |
+
have
|
341 |
+
haven
|
342 |
+
haven't
|
343 |
+
havent
|
344 |
+
having
|
345 |
+
he
|
346 |
+
he'd
|
347 |
+
he'll
|
348 |
+
he's
|
349 |
+
hed
|
350 |
+
hell
|
351 |
+
hello
|
352 |
+
help
|
353 |
+
hence
|
354 |
+
her
|
355 |
+
here
|
356 |
+
here's
|
357 |
+
hereafter
|
358 |
+
hereby
|
359 |
+
herein
|
360 |
+
heres
|
361 |
+
hereupon
|
362 |
+
hers
|
363 |
+
herself
|
364 |
+
herse”
|
365 |
+
hes
|
366 |
+
hi
|
367 |
+
hid
|
368 |
+
high
|
369 |
+
higher
|
370 |
+
highest
|
371 |
+
him
|
372 |
+
himself
|
373 |
+
himse”
|
374 |
+
his
|
375 |
+
hither
|
376 |
+
hk
|
377 |
+
hm
|
378 |
+
hn
|
379 |
+
hopefully
|
380 |
+
how
|
381 |
+
how'd
|
382 |
+
how'll
|
383 |
+
how's
|
384 |
+
howbeit
|
385 |
+
however
|
386 |
+
hr
|
387 |
+
ht
|
388 |
+
htm
|
389 |
+
hu
|
390 |
+
i
|
391 |
+
i'd
|
392 |
+
i'll
|
393 |
+
i'm
|
394 |
+
i've
|
395 |
+
i.e.
|
396 |
+
id
|
397 |
+
ie
|
398 |
+
if
|
399 |
+
ignored
|
400 |
+
ii
|
401 |
+
il
|
402 |
+
ill
|
403 |
+
im
|
404 |
+
immediate
|
405 |
+
immediately
|
406 |
+
importance
|
407 |
+
important
|
408 |
+
in
|
409 |
+
inasmuch
|
410 |
+
inc
|
411 |
+
inc.
|
412 |
+
indeed
|
413 |
+
index
|
414 |
+
indicate
|
415 |
+
indicated
|
416 |
+
indicates
|
417 |
+
information
|
418 |
+
inner
|
419 |
+
inside
|
420 |
+
insofar
|
421 |
+
instead
|
422 |
+
int
|
423 |
+
interest
|
424 |
+
interested
|
425 |
+
interesting
|
426 |
+
interests
|
427 |
+
into
|
428 |
+
inward
|
429 |
+
io
|
430 |
+
iq
|
431 |
+
ir
|
432 |
+
is
|
433 |
+
isn
|
434 |
+
isn't
|
435 |
+
isnt
|
436 |
+
it
|
437 |
+
it'd
|
438 |
+
it'll
|
439 |
+
it's
|
440 |
+
itd
|
441 |
+
itll
|
442 |
+
its
|
443 |
+
itself
|
444 |
+
itse”
|
445 |
+
ive
|
446 |
+
j
|
447 |
+
je
|
448 |
+
jm
|
449 |
+
jo
|
450 |
+
join
|
451 |
+
jp
|
452 |
+
just
|
453 |
+
k
|
454 |
+
ke
|
455 |
+
keep
|
456 |
+
keeps
|
457 |
+
kept
|
458 |
+
kg
|
459 |
+
kh
|
460 |
+
ki
|
461 |
+
kind
|
462 |
+
km
|
463 |
+
kn
|
464 |
+
knew
|
465 |
+
know
|
466 |
+
known
|
467 |
+
knows
|
468 |
+
kp
|
469 |
+
kr
|
470 |
+
kw
|
471 |
+
ky
|
472 |
+
kz
|
473 |
+
l
|
474 |
+
la
|
475 |
+
large
|
476 |
+
largely
|
477 |
+
last
|
478 |
+
lately
|
479 |
+
later
|
480 |
+
latest
|
481 |
+
latter
|
482 |
+
latterly
|
483 |
+
lb
|
484 |
+
lc
|
485 |
+
least
|
486 |
+
less
|
487 |
+
lest
|
488 |
+
let
|
489 |
+
let's
|
490 |
+
lets
|
491 |
+
li
|
492 |
+
like
|
493 |
+
liked
|
494 |
+
likely
|
495 |
+
likewise
|
496 |
+
line
|
497 |
+
lk
|
498 |
+
ll
|
499 |
+
look
|
500 |
+
looking
|
501 |
+
looks
|
502 |
+
lower
|
503 |
+
lr
|
504 |
+
ls
|
505 |
+
lt
|
506 |
+
ltd
|
507 |
+
lu
|
508 |
+
lv
|
509 |
+
ly
|
510 |
+
m
|
511 |
+
ma
|
512 |
+
made
|
513 |
+
mainly
|
514 |
+
make
|
515 |
+
makes
|
516 |
+
making
|
517 |
+
many
|
518 |
+
may
|
519 |
+
maybe
|
520 |
+
mayn't
|
521 |
+
maynt
|
522 |
+
mc
|
523 |
+
md
|
524 |
+
me
|
525 |
+
mean
|
526 |
+
means
|
527 |
+
meantime
|
528 |
+
meanwhile
|
529 |
+
member
|
530 |
+
members
|
531 |
+
merely
|
532 |
+
mg
|
533 |
+
mh
|
534 |
+
might
|
535 |
+
might've
|
536 |
+
mightn't
|
537 |
+
mightnt
|
538 |
+
mil
|
539 |
+
mill
|
540 |
+
mine
|
541 |
+
miss
|
542 |
+
mk
|
543 |
+
ml
|
544 |
+
mm
|
545 |
+
mn
|
546 |
+
mo
|
547 |
+
more
|
548 |
+
moreover
|
549 |
+
most
|
550 |
+
mostly
|
551 |
+
move
|
552 |
+
mp
|
553 |
+
mq
|
554 |
+
mr
|
555 |
+
mrs
|
556 |
+
ms
|
557 |
+
msie
|
558 |
+
mt
|
559 |
+
mu
|
560 |
+
much
|
561 |
+
mug
|
562 |
+
must
|
563 |
+
must've
|
564 |
+
mustn't
|
565 |
+
mustnt
|
566 |
+
mv
|
567 |
+
mw
|
568 |
+
mx
|
569 |
+
my
|
570 |
+
myself
|
571 |
+
myse”
|
572 |
+
mz
|
573 |
+
n
|
574 |
+
na
|
575 |
+
namely
|
576 |
+
nay
|
577 |
+
nc
|
578 |
+
nd
|
579 |
+
ne
|
580 |
+
nearly
|
581 |
+
necessarily
|
582 |
+
necessary
|
583 |
+
need
|
584 |
+
needed
|
585 |
+
needing
|
586 |
+
needn't
|
587 |
+
neednt
|
588 |
+
needs
|
589 |
+
neither
|
590 |
+
net
|
591 |
+
never
|
592 |
+
neverf
|
593 |
+
neverless
|
594 |
+
nevertheless
|
595 |
+
newer
|
596 |
+
newest
|
597 |
+
nf
|
598 |
+
ng
|
599 |
+
ni
|
600 |
+
nl
|
601 |
+
no
|
602 |
+
no-one
|
603 |
+
nobody
|
604 |
+
non
|
605 |
+
none
|
606 |
+
nonetheless
|
607 |
+
noone
|
608 |
+
nor
|
609 |
+
normally
|
610 |
+
nos
|
611 |
+
not
|
612 |
+
noted
|
613 |
+
nothing
|
614 |
+
notwithstanding
|
615 |
+
nowhere
|
616 |
+
np
|
617 |
+
nr
|
618 |
+
nu
|
619 |
+
null
|
620 |
+
nz
|
621 |
+
o
|
622 |
+
obtain
|
623 |
+
obtained
|
624 |
+
obviously
|
625 |
+
of
|
626 |
+
off
|
627 |
+
often
|
628 |
+
oh
|
629 |
+
ok
|
630 |
+
okay
|
631 |
+
om
|
632 |
+
omitted
|
633 |
+
on
|
634 |
+
once
|
635 |
+
one
|
636 |
+
one's
|
637 |
+
ones
|
638 |
+
only
|
639 |
+
onto
|
640 |
+
open
|
641 |
+
opened
|
642 |
+
opening
|
643 |
+
opens
|
644 |
+
opposite
|
645 |
+
or
|
646 |
+
ord
|
647 |
+
order
|
648 |
+
ordered
|
649 |
+
ordering
|
650 |
+
orders
|
651 |
+
org
|
652 |
+
other
|
653 |
+
others
|
654 |
+
otherwise
|
655 |
+
ought
|
656 |
+
oughtn't
|
657 |
+
oughtnt
|
658 |
+
our
|
659 |
+
ours
|
660 |
+
ourselves
|
661 |
+
out
|
662 |
+
over
|
663 |
+
overall
|
664 |
+
owing
|
665 |
+
own
|
666 |
+
p
|
667 |
+
pa
|
668 |
+
part
|
669 |
+
parted
|
670 |
+
particular
|
671 |
+
particularly
|
672 |
+
parting
|
673 |
+
parts
|
674 |
+
past
|
675 |
+
pe
|
676 |
+
per
|
677 |
+
perhaps
|
678 |
+
pf
|
679 |
+
pg
|
680 |
+
ph
|
681 |
+
pk
|
682 |
+
pl
|
683 |
+
place
|
684 |
+
placed
|
685 |
+
places
|
686 |
+
please
|
687 |
+
pm
|
688 |
+
pmid
|
689 |
+
pn
|
690 |
+
pointed
|
691 |
+
pointing
|
692 |
+
poorly
|
693 |
+
possible
|
694 |
+
possibly
|
695 |
+
potentially
|
696 |
+
pp
|
697 |
+
pr
|
698 |
+
predominantly
|
699 |
+
present
|
700 |
+
presented
|
701 |
+
presenting
|
702 |
+
presents
|
703 |
+
presumably
|
704 |
+
previously
|
705 |
+
primarily
|
706 |
+
probably
|
707 |
+
problem
|
708 |
+
problems
|
709 |
+
promptly
|
710 |
+
proud
|
711 |
+
provided
|
712 |
+
provides
|
713 |
+
pt
|
714 |
+
put
|
715 |
+
puts
|
716 |
+
pw
|
717 |
+
py
|
718 |
+
q
|
719 |
+
qa
|
720 |
+
que
|
721 |
+
quickly
|
722 |
+
quite
|
723 |
+
qv
|
724 |
+
r
|
725 |
+
rather
|
726 |
+
rd
|
727 |
+
re
|
728 |
+
readily
|
729 |
+
really
|
730 |
+
reasonably
|
731 |
+
recent
|
732 |
+
recently
|
733 |
+
ref
|
734 |
+
refs
|
735 |
+
regarding
|
736 |
+
regardless
|
737 |
+
regards
|
738 |
+
related
|
739 |
+
relatively
|
740 |
+
reserved
|
741 |
+
respectively
|
742 |
+
resulted
|
743 |
+
resulting
|
744 |
+
results
|
745 |
+
ro
|
746 |
+
ru
|
747 |
+
rw
|
748 |
+
s
|
749 |
+
sa
|
750 |
+
said
|
751 |
+
same
|
752 |
+
saw
|
753 |
+
saying
|
754 |
+
says
|
755 |
+
sb
|
756 |
+
sc
|
757 |
+
sd
|
758 |
+
se
|
759 |
+
sec
|
760 |
+
section
|
761 |
+
see
|
762 |
+
seeing
|
763 |
+
seem
|
764 |
+
seemed
|
765 |
+
seeming
|
766 |
+
seems
|
767 |
+
seen
|
768 |
+
sees
|
769 |
+
self
|
770 |
+
selves
|
771 |
+
sensible
|
772 |
+
sent
|
773 |
+
serious
|
774 |
+
seriously
|
775 |
+
several
|
776 |
+
sg
|
777 |
+
sh
|
778 |
+
shall
|
779 |
+
shan't
|
780 |
+
shant
|
781 |
+
she
|
782 |
+
she'd
|
783 |
+
she'll
|
784 |
+
she's
|
785 |
+
shed
|
786 |
+
shell
|
787 |
+
shes
|
788 |
+
should
|
789 |
+
should've
|
790 |
+
shouldn
|
791 |
+
shouldn't
|
792 |
+
shouldnt
|
793 |
+
showed
|
794 |
+
showing
|
795 |
+
shown
|
796 |
+
showns
|
797 |
+
si
|
798 |
+
side
|
799 |
+
sides
|
800 |
+
significant
|
801 |
+
significantly
|
802 |
+
similar
|
803 |
+
similarly
|
804 |
+
since
|
805 |
+
sincere
|
806 |
+
site
|
807 |
+
sj
|
808 |
+
sk
|
809 |
+
sl
|
810 |
+
slightly
|
811 |
+
sm
|
812 |
+
sn
|
813 |
+
so
|
814 |
+
some
|
815 |
+
somebody
|
816 |
+
someday
|
817 |
+
somehow
|
818 |
+
someone
|
819 |
+
somethan
|
820 |
+
something
|
821 |
+
sometime
|
822 |
+
sometimes
|
823 |
+
somewhat
|
824 |
+
somewhere
|
825 |
+
specifically
|
826 |
+
specified
|
827 |
+
specify
|
828 |
+
specifying
|
829 |
+
sr
|
830 |
+
st
|
831 |
+
state
|
832 |
+
states
|
833 |
+
still
|
834 |
+
stop
|
835 |
+
strongly
|
836 |
+
su
|
837 |
+
sub
|
838 |
+
substantially
|
839 |
+
successfully
|
840 |
+
such
|
841 |
+
sufficiently
|
842 |
+
suggest
|
843 |
+
sup
|
844 |
+
sure
|
845 |
+
sv
|
846 |
+
sy
|
847 |
+
sz
|
848 |
+
t
|
849 |
+
t's
|
850 |
+
take
|
851 |
+
taken
|
852 |
+
taking
|
853 |
+
tc
|
854 |
+
td
|
855 |
+
tell
|
856 |
+
tends
|
857 |
+
tf
|
858 |
+
tg
|
859 |
+
th
|
860 |
+
than
|
861 |
+
thank
|
862 |
+
thanks
|
863 |
+
thanx
|
864 |
+
that
|
865 |
+
that'll
|
866 |
+
that's
|
867 |
+
that've
|
868 |
+
thatll
|
869 |
+
thats
|
870 |
+
thatve
|
871 |
+
the
|
872 |
+
their
|
873 |
+
theirs
|
874 |
+
them
|
875 |
+
themselves
|
876 |
+
then
|
877 |
+
thence
|
878 |
+
there
|
879 |
+
there'd
|
880 |
+
there'll
|
881 |
+
there're
|
882 |
+
there's
|
883 |
+
there've
|
884 |
+
thereafter
|
885 |
+
thereby
|
886 |
+
thered
|
887 |
+
therefore
|
888 |
+
therein
|
889 |
+
therell
|
890 |
+
thereof
|
891 |
+
therere
|
892 |
+
theres
|
893 |
+
thereto
|
894 |
+
thereupon
|
895 |
+
thereve
|
896 |
+
these
|
897 |
+
they
|
898 |
+
they'd
|
899 |
+
they'll
|
900 |
+
they're
|
901 |
+
they've
|
902 |
+
theyd
|
903 |
+
theyll
|
904 |
+
theyre
|
905 |
+
theyve
|
906 |
+
thick
|
907 |
+
thin
|
908 |
+
thing
|
909 |
+
things
|
910 |
+
think
|
911 |
+
thinks
|
912 |
+
third
|
913 |
+
thirty
|
914 |
+
this
|
915 |
+
thorough
|
916 |
+
thoroughly
|
917 |
+
those
|
918 |
+
thou
|
919 |
+
though
|
920 |
+
thoughh
|
921 |
+
thought
|
922 |
+
thoughts
|
923 |
+
thousand
|
924 |
+
throug
|
925 |
+
through
|
926 |
+
throughout
|
927 |
+
thru
|
928 |
+
thus
|
929 |
+
til
|
930 |
+
till
|
931 |
+
tis
|
932 |
+
tj
|
933 |
+
tk
|
934 |
+
tm
|
935 |
+
tn
|
936 |
+
to
|
937 |
+
today
|
938 |
+
together
|
939 |
+
too
|
940 |
+
took
|
941 |
+
tp
|
942 |
+
tr
|
943 |
+
tried
|
944 |
+
tries
|
945 |
+
truly
|
946 |
+
trying
|
947 |
+
ts
|
948 |
+
tt
|
949 |
+
turn
|
950 |
+
turned
|
951 |
+
turning
|
952 |
+
turns
|
953 |
+
tw
|
954 |
+
twas
|
955 |
+
tz
|
956 |
+
u
|
957 |
+
ua
|
958 |
+
ug
|
959 |
+
uk
|
960 |
+
um
|
961 |
+
un
|
962 |
+
underneath
|
963 |
+
undoing
|
964 |
+
unfortunately
|
965 |
+
unless
|
966 |
+
unlike
|
967 |
+
unlikely
|
968 |
+
until
|
969 |
+
unto
|
970 |
+
upon
|
971 |
+
ups
|
972 |
+
us
|
973 |
+
use
|
974 |
+
used
|
975 |
+
useful
|
976 |
+
usefully
|
977 |
+
usefulness
|
978 |
+
uses
|
979 |
+
using
|
980 |
+
usually
|
981 |
+
uucp
|
982 |
+
uy
|
983 |
+
uz
|
984 |
+
v
|
985 |
+
va
|
986 |
+
value
|
987 |
+
various
|
988 |
+
vc
|
989 |
+
ve
|
990 |
+
versus
|
991 |
+
very
|
992 |
+
vg
|
993 |
+
vi
|
994 |
+
via
|
995 |
+
viz
|
996 |
+
vn
|
997 |
+
vol
|
998 |
+
vols
|
999 |
+
vs
|
1000 |
+
vu
|
1001 |
+
w
|
1002 |
+
want
|
1003 |
+
wanted
|
1004 |
+
wanting
|
1005 |
+
wants
|
1006 |
+
was
|
1007 |
+
wasn
|
1008 |
+
wasn't
|
1009 |
+
wasnt
|
1010 |
+
way
|
1011 |
+
ways
|
1012 |
+
we
|
1013 |
+
we'd
|
1014 |
+
we'll
|
1015 |
+
we're
|
1016 |
+
we've
|
1017 |
+
web
|
1018 |
+
wed
|
1019 |
+
welcome
|
1020 |
+
well
|
1021 |
+
wells
|
1022 |
+
went
|
1023 |
+
were
|
1024 |
+
weren
|
1025 |
+
weren't
|
1026 |
+
werent
|
1027 |
+
weve
|
1028 |
+
wf
|
1029 |
+
what
|
1030 |
+
what'd
|
1031 |
+
what'll
|
1032 |
+
what's
|
1033 |
+
what've
|
1034 |
+
whatever
|
1035 |
+
whatll
|
1036 |
+
whats
|
1037 |
+
whatve
|
1038 |
+
when
|
1039 |
+
when'd
|
1040 |
+
when'll
|
1041 |
+
when's
|
1042 |
+
whence
|
1043 |
+
whenever
|
1044 |
+
where
|
1045 |
+
where'd
|
1046 |
+
where'll
|
1047 |
+
where's
|
1048 |
+
whereafter
|
1049 |
+
whereas
|
1050 |
+
whereby
|
1051 |
+
wherein
|
1052 |
+
wheres
|
1053 |
+
whereupon
|
1054 |
+
wherever
|
1055 |
+
whether
|
1056 |
+
which
|
1057 |
+
whichever
|
1058 |
+
while
|
1059 |
+
whilst
|
1060 |
+
whim
|
1061 |
+
whither
|
1062 |
+
who
|
1063 |
+
who'd
|
1064 |
+
who'll
|
1065 |
+
who's
|
1066 |
+
whod
|
1067 |
+
whoever
|
1068 |
+
whole
|
1069 |
+
wholl
|
1070 |
+
whom
|
1071 |
+
whomever
|
1072 |
+
whos
|
1073 |
+
whose
|
1074 |
+
why
|
1075 |
+
why'd
|
1076 |
+
why'll
|
1077 |
+
why's
|
1078 |
+
widely
|
1079 |
+
width
|
1080 |
+
will
|
1081 |
+
willing
|
1082 |
+
with
|
1083 |
+
within
|
1084 |
+
without
|
1085 |
+
won
|
1086 |
+
won't
|
1087 |
+
wonder
|
1088 |
+
wont
|
1089 |
+
words
|
1090 |
+
worked
|
1091 |
+
working
|
1092 |
+
works
|
1093 |
+
world
|
1094 |
+
would
|
1095 |
+
would've
|
1096 |
+
wouldn
|
1097 |
+
wouldn't
|
1098 |
+
wouldnt
|
1099 |
+
ws
|
1100 |
+
www
|
1101 |
+
x
|
1102 |
+
y
|
1103 |
+
ye
|
1104 |
+
year
|
1105 |
+
years
|
1106 |
+
yes
|
1107 |
+
yet
|
1108 |
+
you
|
1109 |
+
you'd
|
1110 |
+
you'll
|
1111 |
+
you're
|
1112 |
+
you've
|
1113 |
+
youd
|
1114 |
+
youll
|
1115 |
+
your
|
1116 |
+
youre
|
1117 |
+
yours
|
1118 |
+
yourself
|
1119 |
+
yourselves
|
1120 |
+
youve
|
1121 |
+
yt
|
1122 |
+
yu
|
1123 |
+
z
|
1124 |
+
za
|
1125 |
+
zm
|
1126 |
+
zr
|
kgs_binding/kg_base_wrapper.py
ADDED
@@ -0,0 +1,80 @@
#############################
# Imports
#############################

# Python modules
from abc import ABC, abstractmethod
from typing import Tuple, Optional, List

# Remote modules
from nltk.stem import WordNetLemmatizer

# Local modules

#############################
# Constants
#############################

class KGBaseHandler(ABC):
    def __init__(self):
        super().__init__()
        self.st = WordNetLemmatizer()

    def normalize_noun(self, ent):
        try:
            noun = self.st.lemmatize(ent, pos='n')
            noun = self.st.lemmatize(noun, pos='v')
        except Exception as _:
            noun = ent[:-1] if ent[-1] == 's' else ent
        return noun

    def normalize_nouns(self, ent):
        local_ent = ent[:]
        nouns = local_ent.split(' ')
        if len(nouns) == 1:
            return ' '.join([self.normalize_noun(e) for e in nouns])
        return local_ent

    def ignore_less_relevant_connection(self, relations):
        if len(relations) >= 2:
            for r in relations:
                if r != 'related_to':
                    return r
        return relations[0]

    @abstractmethod
    def get_relation_types(self) -> List[str]:
        pass

    @abstractmethod
    def exists_relation_between(self, concept, other_concept) -> bool:
        pass

    @abstractmethod
    def relation_between(self, concept, other_concept) -> Tuple[Optional[str], Optional[str]]:
        pass

    @abstractmethod
    def get_related_concepts(self, concept) -> Optional[List[str]]:
        pass

    @abstractmethod
    def does_concept_exist(self, concept) -> bool:
        pass

class NoKnowledge(KGBaseHandler):
    def __init__(self):
        super(NoKnowledge, self).__init__()

    def get_relation_types(self) -> List[str]:
        return []

    def exists_relation_between(self, concept, other_concept) -> bool:
        return False

    def relation_between(self, concept, other_concept) -> Tuple[Optional[str], Optional[str]]:
        return (None, None)

    def does_concept_exist(self, concept) -> bool:
        return False
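For context, a minimal sketch of what a new knowledge-graph binding would have to implement on top of the abstract interface above; the tiny in-memory dictionary is invented purely for illustration:

from typing import Tuple, Optional, List

from kgs_binding.kg_base_wrapper import KGBaseHandler

class ToyKG(KGBaseHandler):
    # invented toy data: a single directed relation
    RELATIONS = {('dog', 'animal'): 'is_a'}

    def get_relation_types(self) -> List[str]:
        return ['is_a']

    def exists_relation_between(self, concept, other_concept) -> bool:
        left, right = self.relation_between(concept, other_concept)
        return left is not None or right is not None

    def relation_between(self, concept, other_concept) -> Tuple[Optional[str], Optional[str]]:
        return (self.RELATIONS.get((concept, other_concept)),
                self.RELATIONS.get((other_concept, concept)))

    def get_related_concepts(self, concept) -> Optional[List[str]]:
        return [b for (a, b) in self.RELATIONS if a == concept]

    def does_concept_exist(self, concept) -> bool:
        return any(concept in pair for pair in self.RELATIONS)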
kgs_binding/kg_qa_binding_utils.py
ADDED
@@ -0,0 +1,73 @@
#############################
# Imports
#############################

# Python modules
from typing import List, Tuple
from enum import Enum

# Remote modules

# Local modules
from .kg_base_wrapper import KGBaseHandler
from .swow_handler import SwowHandler
from .conceptnet_handler import ConceptNetHandler
from utils import read_json_file_2_dict, Data_Type

#############################
# Constants
#############################

#############################
# Stuff
#############################

class KGType(Enum):
    SWOW = 'swow'
    CSKG = 'cskg'
    CONCEPTNET = 'conceptnet'

def load_kg_handler(kg_type: KGType):
    if kg_type.value == KGType.SWOW.value:
        return SwowHandler()
    elif kg_type.value == KGType.CONCEPTNET.value:
        return ConceptNetHandler()
    else:
        raise NotImplementedError()

def _load_data_paths_metadata():
    try:
        data = read_json_file_2_dict('data_config.json', store_dir='run_config')
    except:
        data = None
    return data

def from_relations_path_2_relations(dataset_types: List[Data_Type], metadata):
    relations = []
    print('metadata:', metadata)
    for dataset_type in dataset_types:
        qa_meta_data = metadata[dataset_type.value]
        filename_path, dir_data = qa_meta_data['local']
        print(filename_path, dir)
        data = read_json_file_2_dict(filename_path, dir_data)
        relations.extend(data)
    return relations

def KGHandler_to_str(kg_handler: KGBaseHandler) -> str:
    if isinstance(kg_handler, SwowHandler):
        return 'swow'
    elif isinstance(kg_handler, ConceptNetHandler):
        return 'conceptnet'
    else:
        raise NotImplementedError()

def get_kg_qa_data_metadata(kg_handler: KGBaseHandler) -> Tuple[str, str]:
    kg_qa_data_path = _load_data_paths_metadata()
    if isinstance(kg_handler, SwowHandler):
        swow = kg_qa_data_path["swow"]
        return swow
    elif isinstance(kg_handler, ConceptNetHandler):
        conceptnet = kg_qa_data_path["conceptnet"]
        return conceptnet
    else:
        raise NotImplementedError()
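A brief sketch of the loader above; it assumes the nltk WordNet data and the bundled KG files are available locally, and anything other than SWOW or ConceptNet raises NotImplementedError:

from kgs_binding.kg_qa_binding_utils import KGType, KGHandler_to_str, load_kg_handler

handler = load_kg_handler(KGType.CONCEPTNET)   # -> ConceptNetHandler; KGType.SWOW -> SwowHandler
print(KGHandler_to_str(handler))               # 'conceptnet'
print(len(handler.get_relation_types()))       # number of relation labels this KG exposes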
kgs_binding/parsing_utils.py
ADDED
@@ -0,0 +1,86 @@
#############################
# Imports
#############################

# Python modules
import re
import string

# Remote modules

# Local modules
from utils import (
    read_simple_text_file_2_vec
)

#############################
# Utils
#############################

class ParsingUtils:

    STOPWORDS = read_simple_text_file_2_vec('english_stopwords.txt', store_dir='kgs_binding')

    @staticmethod
    def remove_pontuation(text):
        text = re.sub(r"[^a-zA-Z]", " ", text)
        return text.translate(str.maketrans('', '', string.punctuation))

    @staticmethod
    def clear_common_words(index_with_words):
        return [(word, (s, e)) for (word, (s, e)) in index_with_words if word not in ParsingUtils.STOPWORDS]

    @staticmethod
    def is_word_a_relevant_one(ignore_common_words, word):
        if ignore_common_words:
            return word not in ParsingUtils.STOPWORDS
        else:
            return True

    @staticmethod
    def get_word_range_mapping(context, word_token):
        word_token_splitted = word_token.split(' ')
        if len(word_token_splitted) == 1:
            word_token_start = context.index(word_token)
            word_token_end = word_token_start + len(word_token) - 1  # inclusive end
        else:
            word_token_start = context.index(word_token_splitted[0])
            word_token_end = word_token_start + len(word_token) - 1  # inclusive end
        return word_token_start, word_token_end

    @staticmethod
    def n_grams(words_vector, n):
        grams = [words_vector[i:i + n] for i in range(len(words_vector) - n + 1)]
        print(grams)
        return [' '.join(x) for x in grams]

    @staticmethod
    def n_grams_with_idx(words_vector, n):
        grams = [words_vector[i:i + n] for i in range(len(words_vector) - n + 1)]
        return [(' '.join([pair[0] for pair in x]), (x[0][1], x[-1][1]+len(x[-1][0]))) for x in grams]

    @staticmethod
    def n_grams_context_producer_simple(context, n_gram=2):
        context_tokens = context.strip().split(' ')
        #context_tokens = [w for w in context_tokens if w not in STOPWORDS]
        n_grams_context = []
        for i in range(n_gram):
            n_gram_content = ParsingUtils.n_grams(context_tokens, n_gram-i)
            n_grams_context.append(n_gram_content)
        return n_grams_context

    @staticmethod
    def n_grams_n_words_extractor(context, n_gram=3):
        context_tokens = context.strip().split(' ')
        context_tokens_with_index_info = []
        word_idx = 0
        for word in context_tokens:
            context_tokens_with_index_info.append((word, word_idx))
            word_idx += len(word) + 1
        #context_tokens = [w for w in context_tokens if w not in STOPWORDS]
        n_grams_context = []
        for i in range(n_gram):
            n_gram_content = ParsingUtils.n_grams_with_idx(context_tokens_with_index_info, n_gram-i)
            n_grams_context.extend(n_gram_content)
        return n_grams_context
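To make the index bookkeeping concrete, a small sketch of what n_grams_n_words_extractor returns for a three-word context; the phrase is arbitrary and the offsets are worked out by hand from the code above:

from kgs_binding.parsing_utils import ParsingUtils  # needs kgs_binding/english_stopwords.txt on disk

spans = ParsingUtils.n_grams_n_words_extractor('boat water bird')
# each entry is (phrase, (start_char, end_char_exclusive)), longest n-grams first:
# [('boat water bird', (0, 15)), ('boat water', (0, 10)), ('water bird', (5, 15)),
#  ('boat', (0, 4)), ('water', (5, 10)), ('bird', (11, 15))]
print(spans)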
kgs_binding/relation_mapper_builder.py
ADDED
@@ -0,0 +1,164 @@
#############################
# Imports
#############################

# Python modules
from collections import deque
from collections import defaultdict
from typing import List, Dict, Optional
from ast import literal_eval
from random import sample

# Remote modules

# Local modules
from .kg_base_wrapper import KGBaseHandler
from .swow_handler import SwowHandler

from utils import (
    read_json_file_2_dict,
    Data_Type,
)
from .parsing_utils import ParsingUtils

#############################
# Constants
#############################

#############################
# Stuff
#############################

class RelationsMapperBuilder:
    def __init__(self, knowledge: KGBaseHandler,
                 filename: Optional[str] = None,
                 file_dir: Optional[str] = None,
                 datatype: Optional[Data_Type] = None,
                 tok_sep: str = '</s>',
                 use_extra_relations=True):
        self.tok_sep = tok_sep
        self.knowledge = knowledge
        self.swow_knowledge = SwowHandler()
        self.use_extra_relations = use_extra_relations
        if filename and file_dir and datatype:
            full_context = self.load_data(filename, file_dir)
            self.relevant_context = self.fetch_relevant_context_from_data(data=full_context, datatype=datatype)

    def load_data(self, filename='commongen_qa_final.json', store_dir='./'):
        data = read_json_file_2_dict(filename=filename, store_dir=store_dir)
        print('data[0]:', data[0])
        return data

    def fetch_relevant_context_from_data(self, data: List[Dict], datatype: Data_Type = Data_Type.COMMONGEN_QA):
        if datatype == Data_Type.COMMONGEN_QA:
            model_input = [data_unit.get('title').lower() for data_unit in data]
        elif datatype in [Data_Type.ELI5, Data_Type.STACK_EXCHANGE]:
            model_input = [data_unit.get('question').lower() for data_unit in data]
        elif datatype in [Data_Type.COMMONSENSE_QA]:
            #questions = [data_unit.get('question').lower() for data_unit in data]
            #model_input = datasets_parsing_utils.compose_commonsenseqa_data(data)
            model_input = [data_unit.get('input_data') for data_unit in data]
        elif datatype in [Data_Type.COMMONGEN]:
            #questions = [data_unit.get('input_data').lower() for data_unit in data]
            #model_input = datasets_parsing_utils.compose_commongen_data(data)
            model_input = [data_unit.get('input_data') for data_unit in data]
        else:
            model_input = []
        return model_input

    def get_kg_concepts_from_context(self, context=None, clear_common_wds=False):
        if not context:
            context = self.relevant_context
        context_words = []
        for q_id, question in enumerate(context):
            simple_question = ParsingUtils.remove_pontuation(question)
            n_grams = ParsingUtils.n_grams_n_words_extractor(simple_question)
            words = self.relevant_entities_extractor(n_grams)
            if clear_common_wds:
                words = ParsingUtils.clear_common_words(words)
            simple_words = [word[0] for word in words]
            context_words.append(simple_words)
        return context_words

    def obtain_concept_neighbours(self, context_concepts: List[str], n_neighbours=20):
        """
        Use swow to get connected concepts, but then refer back to conceptnet for rich relations
        """
        neighbours = []
        for concept in context_concepts:
            external_neighbour_concepts = self.swow_knowledge.get_related_concepts(concept)
            relevant_concepts = external_neighbour_concepts
            #local_neighbour_concepts = self.knowledge.get_related_concepts(concept)
            #relevant_concepts = [ext_concept for ext_concept in external_neighbour_concepts if ext_concept in local_neighbour_concepts]
            neighbours.extend(relevant_concepts)
        n_neighbours = min(n_neighbours, len(neighbours))
        some_neighbours = sample(neighbours, n_neighbours)
        #print('context_concepts:', context_concepts)
        #print('some_neighbours:', some_neighbours)
        return some_neighbours

    def get_relations_mapping_complex(self, context=None, clear_common_wds=False):
        if not context:
            context = self.relevant_context
        relations_info = deque()
        for q_id, question in enumerate(context):
            simple_question = ParsingUtils.remove_pontuation(question)
            n_grams = ParsingUtils.n_grams_n_words_extractor(simple_question)
            words = self.relevant_entities_extractor(n_grams)
            if clear_common_wds:
                words = ParsingUtils.clear_common_words(words)
            #print(f'question: {question}')
            #print(f'words: {words}')
            relation_context_between_words = defaultdict(dict)
            known_tokens = set()
            for token_i, (first_word_token, first_word_range) in enumerate(words[:-1]):
                known_tokens.add(first_word_token)
                first_word_range_str = str(first_word_range)
                # normalize
                first_word_phrase_normalized = self.knowledge.normalize_nouns(first_word_token)
                for (second_word_token, second_word_range) in [w for w in words[token_i + 1:] if w not in known_tokens]:
                    second_word_range_str = str(second_word_range)
                    second_word_phrase_normalized = self.knowledge.normalize_nouns(second_word_token)
                    left_2_right, right_2_left = self.knowledge.relation_between(first_word_phrase_normalized, second_word_phrase_normalized)
                    #print(first_word_token, second_word_token, left_2_right, right_2_left)
                    if left_2_right:
                        relation_context_between_words[first_word_range_str][second_word_range_str] = left_2_right
                    if right_2_left:
                        relation_context_between_words[second_word_range_str][first_word_range_str] = right_2_left
            relations_info.append(dict(relation_context_between_words))
        return list(relations_info)

    def get_concepts_from_context(self, context=None, clear_common_wds=False, alignment=0):
        relations_info = self.get_relations_mapping_complex(context=[context], clear_common_wds=clear_common_wds)
        words = []
        #print('relations_info here:', relations_info)
        for rels in relations_info:
            for coords, v in rels.items():
                coords_tuple = literal_eval(coords)
                i, j = coords_tuple
                words.append(context[i+alignment:j+alignment])
                for coords_other, rel in v.items():
                    coords_other_tuple = literal_eval(coords_other)
                    i_other, j_other = coords_other_tuple
                    words.append(context[i_other+alignment: j_other+alignment])
        returning_words = list(set(words))
        #print('returning_words:', returning_words)
        return returning_words

    def relevant_entities_extractor(self, n_grams_n_words, verbose_output=True):
        non_overlapping_knowledge = {}
        # print(n_grams_n_words)
        for concept, (idx_start, idx_end) in n_grams_n_words:
154 |
+
normalized_concept = self.knowledge.normalize_nouns(concept)
|
155 |
+
exists = self.knowledge.does_concept_exist(normalized_concept)
|
156 |
+
#print('exists: ', concept, normalized_concept, exists)
|
157 |
+
if exists and idx_start not in non_overlapping_knowledge and \
|
158 |
+
idx_end not in non_overlapping_knowledge:
|
159 |
+
non_overlapping_knowledge[idx_start] = (concept, idx_start, idx_end, 'start_idx')
|
160 |
+
non_overlapping_knowledge[idx_end] = (concept, idx_end, idx_end, 'end_idx')
|
161 |
+
if verbose_output:
|
162 |
+
return [(value[0], (value[1], value[2])) for k, value in sorted(non_overlapping_knowledge.items()) if value[-1] == 'start_idx']
|
163 |
+
else:
|
164 |
+
return [value[0] for k, value in sorted(non_overlapping_knowledge.items()) if value[-1] == 'start_idx']
|
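For orientation, RelationsMapperBuilder is the piece that turns raw question text into knowledge-graph concepts and pairwise relations. The sketch below is illustrative only and not part of this commit; it assumes the handler class exposed by kgs_binding/conceptnet_handler.py is named ConceptNetHandler and uses a made-up example sentence.

    from kgs_binding.conceptnet_handler import ConceptNetHandler  # assumed class name
    from kgs_binding.relation_mapper_builder import RelationsMapperBuilder

    # Build the mapper on top of a concrete KG handler (no dataset file needed for ad-hoc use)
    mapper = RelationsMapperBuilder(knowledge=ConceptNetHandler())

    sentence = 'a bird flies over the water near a boat'
    # Concepts from the sentence that exist in the KG
    concepts = mapper.get_concepts_from_context(context=sentence, clear_common_wds=True)
    # Character-range pairs mapped to the relations that link them
    relations = mapper.get_relations_mapping_complex(context=[sentence], clear_common_wds=True)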
kgs_binding/swow/__init__.py
ADDED
@@ -0,0 +1 @@
from . import *
kgs_binding/swow/swow_knowledge.json
ADDED
The diff for this file is too large to render.
See raw diff
kgs_binding/swow_handler.py
ADDED
@@ -0,0 +1,75 @@

#############################
# Imports
#############################

# Python modules
import random
from typing import Tuple, Optional, List

# Remote modules

# Local modules
from .kg_base_wrapper import KGBaseHandler


from utils import read_json_file_2_dict

#############################
# Constants
#############################

#############################
# Stuff
#############################

class SwowHandler(KGBaseHandler):
    def __init__(self, store_dir='kgs_binding/swow'):
        super(SwowHandler, self).__init__()
        self.swow: dict = self.load_stored_data(store_dir=store_dir)

    def get_relation_types(self) -> List[str]:
        return ['related_to']

    def load_stored_data(self, filename='swow_knowledge.json', store_dir='kgs_binding/swow'):
        self.swow = read_json_file_2_dict(filename, store_dir)
        return self.swow

    def exists_relation_between(self, concept, other_concept):
        connections = self.swow.get(concept)
        if not connections:
            return False
        for connection in connections:
            if connection == other_concept:
                return True
        return False

    def does_concept_exist(self, concept):
        return self.swow.get(concept, None) is not None

    def relation_between(self, concept, other_concept) -> Tuple[Optional[str], Optional[str]]:
        exists_left_right = self.exists_relation_between(concept, other_concept)
        exists_right_left = self.exists_relation_between(other_concept, concept)
        relation = None
        if exists_left_right or exists_right_left:
            relation = 'related_to'
        return relation, relation

    def get_related_concepts(self, concept) -> Optional[List[str]]:
        return self.swow.get(concept, [])

    def simple_knowledge_prediction(self, knowledge):
        kw = list(knowledge)
        idx = random.randint(0, len(knowledge)-1)  # 0-1-2
        kw[idx] = '<mask>'
        textual_knowledge_input = f'{kw[0]} {kw[1]} {kw[2]}'
        label = f'{knowledge[0]} {knowledge[1]} {knowledge[2]}'
        return f'{textual_knowledge_input},{label}\n', label

    def create_mask_knowledge_for_model(self):
        with open(f'bart_input/swow_bart.txt', 'w') as f:
            for subject, objects in self.swow.items():
                for obj in objects:
                    knowledge = (subject, 'is related to', obj)
                    w_kw, label = self.simple_knowledge_prediction(knowledge)
                    f.write(w_kw)
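SwowHandler wraps the SWOW (Small World of Words) association graph as a single dict of concept -> related concepts, with one untyped relation, 'related_to'. A quick usage sketch (illustrative only; whether the example concepts exist depends on the shipped swow_knowledge.json):

    from kgs_binding.swow_handler import SwowHandler

    swow = SwowHandler()                      # loads kgs_binding/swow/swow_knowledge.json
    swow.does_concept_exist('water')          # True only if 'water' is a SWOW node
    swow.relation_between('water', 'boat')    # ('related_to', 'related_to') or (None, None)
    swow.get_related_concepts('water')        # list of associated concepts, [] if unknown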
model_utils.py
ADDED
@@ -0,0 +1,54 @@

#############################
# Imports
#############################

# Python modules
from typing import List
from random import randint

# Remote modules
import torch

# Local modules
from utils import Head_Mask

#############################
# Constants
#############################

#############################
# Stuff
#############################

def create_layers_head_mask(config, head_mask_type: Head_Mask = Head_Mask.ALL, specific_heads: List[int] = None):
    mask_heads = torch.zeros((config.encoder_layers, config.encoder_attention_heads))
    if head_mask_type == Head_Mask.RANDOM:
        for i in range(config.encoder_layers):
            rand_idx = randint(0, config.encoder_attention_heads-1)
            mask_heads[i, rand_idx] = 1
    elif head_mask_type == Head_Mask.NONE:
        mask_heads[:, :] = 1
    elif head_mask_type == Head_Mask.ALL:
        pass
    elif head_mask_type == Head_Mask.SPECIFIC:
        if specific_heads:
            for layer_i in range(len(mask_heads)):
                specific_head = specific_heads[layer_i] - 1
                mask_heads[layer_i][specific_head] = 1
        else:
            mask_heads = torch.Tensor([[0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0],
                                       [1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0],
                                       [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0],
                                       [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
                                       [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
                                       [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
                                       [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0],
                                       [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
                                       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1],
                                       [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1],
                                       [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1],
                                       [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1]])
    else:
        raise NotImplementedError()
    return mask_heads.tolist()
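create_layers_head_mask builds an (encoder_layers x encoder_attention_heads) 0/1 mask, returned as nested lists. A minimal sketch of how it might be called (not part of this commit; it assumes a standard BART config, which for bart-large has 12 encoder layers of 16 heads each):

    from transformers import BartConfig
    from model_utils import create_layers_head_mask
    from utils import Head_Mask

    config = BartConfig.from_pretrained('facebook/bart-large')
    # One randomly chosen active head per encoder layer
    random_mask = create_layers_head_mask(config, head_mask_type=Head_Mask.RANDOM)
    # Explicit per-layer head choice (1-indexed), one entry per encoder layer
    specific_mask = create_layers_head_mask(config, head_mask_type=Head_Mask.SPECIFIC,
                                            specific_heads=[1] * config.encoder_layers)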
requirements.txt
ADDED
@@ -0,0 +1,4 @@
transformers
torch
numpy
matplotlib
utils.py
ADDED
@@ -0,0 +1,230 @@
#############################
# Imports and Constants #
#############################

# Python modules
from enum import Enum
import os
import json
import time

# Remote packages
import torch

#############################
# utilities
#############################

class ScoringType(Enum):
    DEFAULT = 'default'
    MAX_PROB = 'max-prob'
    INTERPOL = 'interpol'
    CONSTRAINT = 'constraint'
    MULTIPLE_CHOICE = 'multiple_choice'

class LossType(Enum):
    DEFAULT = 'default'
    CP_RP_DEF = 'cp-rp-def'
    CP_DEF = 'cp-def'
    PRP_NRP_DEF = 'prp-nrp-def'

class Head_Mask(Enum):
    ALL = 'all'
    NONE = 'none'
    RANDOM = 'random'
    SPECIFIC = 'specific'

class KGType(Enum):
    SWOW = 'swow'
    CSKG = 'cskg'
    CONCEPTNET = 'conceptnet'

class Model_Type(Enum):
    RELATIONS = 'relations'
    MASK = 'mask'
    DEFAULT = 'default'

    def is_simple_mask_commonsense(self):
        return self == Model_Type.MASK

    def there_is_difference_between_relations(self):
        return self == Model_Type.RELATIONS

class Data_Type(Enum):
    ELI5 = 'eli5'
    COMMONSENSE_QA = 'commonsense_qa'
    COMMONGEN_QA = 'commongen_qa'
    STACK_EXCHANGE = 'stackexchange_qa'
    ASK_SCIENCE = 'ask_science_qa'
    NATURAL_QUESTIONS = 'natural_questions'
    LAMA = 'lama'
    CONCEPTNET = 'conceptnet'
    CUSTOM = 'custom'
    COMMONGEN = 'commongen'

    @staticmethod
    def data_types_to_str(data_types):
        datasets_str = '-'.join([x.value for x in data_types])
        return datasets_str

#############################
# Models
#############################

MODELS_PRETRAINING_NAME = {
    "bart_large": "facebook/bart-large",
    "bart_large_fp32": "patrickvonplaten/bart-large-fp32",
    "bart_large_tweak": "",
    "bart_base": "facebook/bart-base"
}

CURRENT_PRETRAINING_NAME = MODELS_PRETRAINING_NAME.get('bart_large_fp32')

#############################
# Files Management #
#############################

def create_directory(output_dir):
    # Create output directory if needed
    if not os.path.exists(output_dir):
        try:
            os.makedirs(output_dir)
        except FileExistsError as _:
            return
    else:
        print(f"Output directory {output_dir} already exists")

def read_simple_text_file_2_vec(filename, store_dir='.'):
    with open(f'{store_dir}/{filename}', 'r') as f:
        return f.read().split('\n')

def write_dict_2_json_file(json_object, filename, store_dir='.'):
    create_directory(store_dir)
    with open(f'{store_dir}/{filename}', 'w', encoding='utf-8') as file:
        json.dump(json_object, file, ensure_ascii=False, indent=4)


def read_json_file_2_dict(filename, store_dir='.'):
    with open(f'{store_dir}/{filename}', 'r', encoding='utf-8') as file:
        return json.load(file)

def read_jsonl_file_2_dict(filename, store_dir='.'):
    elements = []
    with open(f'{store_dir}/{filename}', 'r', encoding='utf-8') as file:
        for line in file:
            elements.append(json.loads(line))
    return elements

def read_txt_2_list(filename, store_dir='.'):
    with open(f'{store_dir}/{filename}', 'r', encoding='utf-8') as file:
        return file.read().split('\n')

#############################
# Data Structures helper functions
#############################

def get_chunks(lst, n):
    """Yield n successive chunks from lst."""
    jump = len(lst)//n
    for i in range(0, len(lst), jump):
        yield lst[i:i + jump]

def get_jump_chunks(lst, jump):
    """Yield successive jump-sized chunks from lst."""
    for i in range(0, len(lst), jump):
        yield lst[i:i + jump]

def join_str_first(sep_str, lis):
    return '{1}{0}'.format(sep_str.join(lis), sep_str).strip()

#############################
# Huggingface
#############################

def inputs_introspection_print(tokenizer, inputs):
    input_ids = inputs.get('input_ids', None)
    input_text = tokenizer.batch_decode(input_ids, skip_special_tokens=False)
    labels_ids = inputs.get('labels', None)
    labels_text = tokenizer.batch_decode(labels_ids, skip_special_tokens=False)
    print('original input:', input_text[:2])
    print("::::::::::::::::::::::::::")
    print('original labels:', labels_text[:2])
    print("==========|||||==========")

def tok_data_2_text(tokenizer, all_inputs):
    def clean_input_text(text):
        real_text = text.split(tokenizer.eos_token)[0]
        real_text = real_text.replace(tokenizer.bos_token, '').strip()
        return real_text
    all_input_text, all_labels_text = [], []
    for inputs in all_inputs:
        input_ids = inputs.get('input_ids', None)
        input_text = tokenizer.decode(input_ids, skip_special_tokens=False)
        labels_ids = inputs.get('labels', None)
        labels_text = tokenizer.decode(labels_ids, skip_special_tokens=True)
        #print('input_text:', input_text)
        #print('labels_text:', labels_text)
        input_text = clean_input_text(input_text)
        all_input_text.append(input_text)
        all_labels_text.append(labels_text)
    return all_input_text, all_labels_text

#############################
# Torch
#############################

def get_device(verbose: bool = True):
    # If there's a GPU available...
    if torch.cuda.is_available():
        device = torch.device("cuda")
        n_gpus = torch.cuda.device_count()
        first_gpu = torch.cuda.get_device_name(0)
        if verbose:
            print(f'There are {n_gpus} GPU(s) available.')
            print(f'GPU to be used: {first_gpu}')
    else:
        if verbose:
            print('No GPU available, using the CPU instead.')
        device = torch.device("cpu")
    return device

#############################
# Timing
#############################

def timing_decorator(func):
    def wrapper(*args, **kwargs):
        start = time.time()
        original_return_val = func(*args, **kwargs)
        end = time.time()
        print("time elapsed in ", func.__name__, ": ", end - start, sep='')
        return original_return_val

    return wrapper

#############################
# PRINTING UTILS
#############################

class LOGGER_COLORS:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    INFOCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

def print_info(logger, message):
    logger.info(f'{LOGGER_COLORS.INFOCYAN}[INFO]{LOGGER_COLORS.ENDC}: {message}')

def print_success(logger, message):
    logger.info(f'{LOGGER_COLORS.OKGREEN}[SUCCESS]{LOGGER_COLORS.ENDC}: {message}')

def print_warning(logger, message):
    logger.info(f'{LOGGER_COLORS.WARNING}[WARNING]{LOGGER_COLORS.ENDC}: {message}')

def print_fail(logger, message):
    logger.info(f'{LOGGER_COLORS.FAIL}[FAIL]{LOGGER_COLORS.ENDC}: {message}')
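These helpers are consumed throughout the repo; as a quick illustration of the file and timing utilities (illustrative only; the directory, file name, and decorated function are made up for this sketch):

    from utils import timing_decorator, write_dict_2_json_file, read_json_file_2_dict, get_device

    device = get_device(verbose=False)

    @timing_decorator
    def build_cache():
        # round-trip a small dict through the JSON helpers
        write_dict_2_json_file({'hello': 'world'}, 'cache.json', store_dir='./tmp')
        return read_json_file_2_dict('cache.json', store_dir='./tmp')

    cache = build_cache()  # prints "time elapsed in build_cache: ..."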