LightChen2333 committed
Commit 37b9e99
1 Parent(s): 0882d77
Upload 34 files
- app.py +42 -0
- common/__init__.py +1 -0
- common/config.py +191 -0
- common/loader.py +333 -0
- common/logger.py +197 -0
- common/metric.py +343 -0
- common/model_manager.py +324 -0
- common/tokenizer.py +311 -0
- common/utils.py +489 -0
- model/__init__.py +3 -0
- model/decoder/__init__.py +5 -0
- model/decoder/agif_decoder.py +16 -0
- model/decoder/base_decoder.py +94 -0
- model/decoder/classifier.py +321 -0
- model/decoder/decoder_utils.py +155 -0
- model/decoder/gl_gin_decoder.py +47 -0
- model/decoder/interaction/__init__.py +10 -0
- model/decoder/interaction/agif_interaction.py +132 -0
- model/decoder/interaction/base_interaction.py +9 -0
- model/decoder/interaction/bi_model_interaction.py +74 -0
- model/decoder/interaction/dca_net_interaction.py +176 -0
- model/decoder/interaction/gl_gin_interaction.py +227 -0
- model/decoder/interaction/slot_gated_interaction.py +59 -0
- model/decoder/interaction/stack_interaction.py +36 -0
- model/encoder/__init__.py +5 -0
- model/encoder/auto_encoder.py +37 -0
- model/encoder/base_encoder.py +41 -0
- model/encoder/non_pretrained_encoder.py +212 -0
- model/encoder/pretrained_encoder.py +39 -0
- model/open_slu_model.py +64 -0
- save/stack/label.json +1 -0
- save/stack/model.pkl +3 -0
- save/stack/outputs.jsonl +0 -0
- save/stack/tokenizer.json +1 -0
app.py
ADDED
@@ -0,0 +1,42 @@
import gradio as gr
import os
from common.config import Config
from common.model_manager import ModelManager

config = Config.load_from_yaml("config/app.yaml")
model_manager = ModelManager(config)
model_manager.load()


def text_analysis(text):
    print(text)
    data = model_manager.predict(text)
    html = """<link href="https://cdn.staticfile.org/twitter-bootstrap/5.1.1/css/bootstrap.min.css" rel="stylesheet">
<script src="https://cdn.staticfile.org/twitter-bootstrap/5.1.1/js/bootstrap.bundle.min.js"></script>"""
    html += """<div style="background: white; padding: 16px;"><b>Intent:</b>"""

    for intent in data["intent"]:
        html += """<button type="button" class="btn btn-white">
        <span class="badge text-dark btn-light">""" + intent + """</span> </button>"""
    html += """<br /> <b>Slot:</b>"""
    for t, slot in zip(data["text"], data["slot"]):
        html += """<button type="button" class="btn btn-white">""" + t + """<span class="badge text-dark" style="background-color: rgb(255, 255, 255);
        color: rgb(62 62 62);
        box-shadow: 2px 2px 7px 1px rgba(210, 210, 210, 0.42);">""" + slot + """</span>
        </button>"""
    html += "</div>"
    return html


demo = gr.Interface(
    text_analysis,
    gr.Textbox(placeholder="Enter sentence here..."),
    ["html"],
    examples=[
        ["What a beautiful morning for a walk!"],
        ["It was the best of times, it was the worst of times."],
    ],
)

demo.launch()
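For context, `text_analysis` above assumes `model_manager.predict` returns a dict with parallel `text`/`slot` lists plus an `intent` list. That shape is inferred from how the function indexes the result; the values below are made up for illustration:

# Assumed return shape of model_manager.predict(text), inferred from the
# indexing in text_analysis above; the example values are illustrative only.
prediction = {
    "text": ["what", "flights", "leave", "denver"],   # tokens of the input
    "slot": ["O", "O", "O", "B-fromloc.city_name"],   # one slot tag per token
    "intent": ["atis_flight"],                        # predicted intent label(s)
}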
common/__init__.py
ADDED
@@ -0,0 +1 @@
(a single blank line)
common/config.py
ADDED
@@ -0,0 +1,191 @@
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-26 10:55:43
Description: Configuration class to manage all processes in OpenSLU, like model construction, the learning process and so on.

'''
import re

from ruamel import yaml
import datetime

class Config(dict):
    def __init__(self, *args, **kwargs):
        """ init with dict as args
        """
        dict.__init__(self, *args, **kwargs)
        self.__dict__ = self
        self.start_time = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
        self.__autowired()

    @staticmethod
    def load_from_yaml(file_path: str) -> "Config":
        """load config file from path

        Args:
            file_path (str): yaml configuration file path.

        Returns:
            Config: config object.
        """
        with open(file_path) as stream:
            try:
                return Config(yaml.safe_load(stream))
            except yaml.YAMLError as exc:
                print(exc)

    @staticmethod
    def load_from_args(args) -> "Config":
        """ load args to replace item values in the config file assigned with '--config_path' or '--model'

        Args:
            args (Any): command line args.

        Returns:
            Config: config object with command line overrides applied.
        """
        if args.model is not None:
            args.config_path = "config/" + args.model + ".yaml"
        config = Config.load_from_yaml(args.config_path)
        if args.dataset is not None:
            config.__update_dataset(args.dataset)
        if args.device is not None:
            config["base"]["device"] = args.device
        if args.learning_rate is not None:
            config["optimizer"]["lr"] = args.learning_rate
        if args.epoch_num is not None:
            config["base"]["epoch_num"] = args.epoch_num
        return config

    def autoload_template(self):
        """ search for '{*}' templates to execute as python code; supports substituting any configuration item as a variable
        """
        self.__autoload_template(self.__dict__)

    def __get_autoload_value(self, matched):
        keys = matched.group()[1:-1].split(".")
        temp = self.__dict__
        for k in keys:
            temp = temp[k]
        return str(temp)

    def __autoload_template(self, config: dict):
        for k in config:
            if isinstance(config, dict):
                sub_config = config[k]
            elif isinstance(config, list):
                sub_config = k
            else:
                continue
            if isinstance(sub_config, dict) or isinstance(sub_config, list):
                self.__autoload_template(sub_config)
            if isinstance(sub_config, str) and "{" in sub_config and "}" in sub_config:
                res = re.sub(r'{.*?}', self.__get_autoload_value, config[k])
                res_dict = {"res": None}
                exec("res=" + res, res_dict)
                config[k] = res_dict["res"]

    def __update_dataset(self, dataset_name):
        if dataset_name is not None and isinstance(dataset_name, str):
            self.__dict__["dataset"]["dataset_name"] = dataset_name

    def get_model_config(self):
        return self.__dict__["model"]

    def __autowired(self):
        # Set encoder
        encoder_config = self.__dict__["model"]["encoder"]
        encoder_type = encoder_config["_model_target_"].split(".")[-1]

        def get_output_dim(encoder_config):
            encoder_type = encoder_config["_model_target_"].split(".")[-1]
            if (encoder_type == "AutoEncoder" and encoder_config["encoder_name"] in ["lstm", "self-attention-lstm",
                                                                                     "bi-encoder"]) or encoder_type == "NoPretrainedEncoder":
                output_dim = 0
                if encoder_config.get("lstm"):
                    output_dim += encoder_config["lstm"]["output_dim"]
                if encoder_config.get("attention"):
                    output_dim += encoder_config["attention"]["output_dim"]
                return output_dim
            else:
                return encoder_config["output_dim"]

        if encoder_type == "BiEncoder":
            output_dim = get_output_dim(encoder_config["intent_encoder"]) + \
                         get_output_dim(encoder_config["slot_encoder"])
        else:
            output_dim = get_output_dim(encoder_config)
        self.__dict__["model"]["encoder"]["output_dim"] = output_dim

        # Set interaction
        if "interaction" in self.__dict__["model"]["decoder"] and self.__dict__["model"]["decoder"]["interaction"].get(
                "input_dim") is None:
            self.__dict__["model"]["decoder"]["interaction"]["input_dim"] = output_dim
            interaction_type = self.__dict__["model"]["decoder"]["interaction"]["_model_target_"].split(".")[-1]
            if not ((encoder_type == "AutoEncoder" and encoder_config[
                "encoder_name"] == "self-attention-lstm") or encoder_type == "SelfAttentionLSTMEncoder") and interaction_type != "BiModelWithoutDecoderInteraction":
                output_dim = self.__dict__["model"]["decoder"]["interaction"]["output_dim"]

        # Set classifier
        if "slot_classifier" in self.__dict__["model"]["decoder"]:
            if self.__dict__["model"]["decoder"]["slot_classifier"].get("input_dim") is None:
                self.__dict__["model"]["decoder"]["slot_classifier"]["input_dim"] = output_dim
            self.__dict__["model"]["decoder"]["slot_classifier"]["use_slot"] = True
        if "intent_classifier" in self.__dict__["model"]["decoder"]:
            if self.__dict__["model"]["decoder"]["intent_classifier"].get("input_dim") is None:
                self.__dict__["model"]["decoder"]["intent_classifier"]["input_dim"] = output_dim
            self.__dict__["model"]["decoder"]["intent_classifier"]["use_intent"] = True

    def get_intent_label_num(self):
        """ get the number of intent labels.
        """
        classifier_conf = self.__dict__["model"]["decoder"]["intent_classifier"]
        return classifier_conf["intent_label_num"] if "intent_label_num" in classifier_conf else 0

    def get_slot_label_num(self):
        """ get the number of slot labels.
        """
        classifier_conf = self.__dict__["model"]["decoder"]["slot_classifier"]
        return classifier_conf["slot_label_num"] if "slot_label_num" in classifier_conf else 0

    def set_intent_label_num(self, intent_label_num):
        """ set the number of intent labels.

        Args:
            intent_label_num (int): the number of intent labels
        """
        self.__dict__["base"]["intent_label_num"] = intent_label_num
        self.__dict__["model"]["decoder"]["intent_classifier"]["intent_label_num"] = intent_label_num
        if "interaction" in self.__dict__["model"]["decoder"]:
            self.__dict__["model"]["decoder"]["interaction"]["intent_label_num"] = intent_label_num
            if self.__dict__["model"]["decoder"]["interaction"]["_model_target_"].split(".")[
                    -1] == "StackInteraction":
                self.__dict__["model"]["decoder"]["slot_classifier"]["input_dim"] += intent_label_num

    def set_slot_label_num(self, slot_label_num: int) -> None:
        """set the number of slot labels

        Args:
            slot_label_num (int): the number of slot labels
        """
        self.__dict__["base"]["slot_label_num"] = slot_label_num
        self.__dict__["model"]["decoder"]["slot_classifier"]["slot_label_num"] = slot_label_num
        if "interaction" in self.__dict__["model"]["decoder"]:
            self.__dict__["model"]["decoder"]["interaction"]["slot_label_num"] = slot_label_num

    def set_vocab_size(self, vocab_size):
        """set the size of the vocabulary of a non-pretrained tokenizer

        Args:
            vocab_size (int): the size of the vocabulary
        """
        encoder_type = self.__dict__["model"]["encoder"]["_model_target_"].split(".")[-1]
        encoder_name = self.__dict__["model"]["encoder"].get("encoder_name")
        if encoder_type == "BiEncoder" or (encoder_type == "AutoEncoder" and encoder_name == "bi-encoder"):
            self.__dict__["model"]["encoder"]["intent_encoder"]["embedding"]["vocab_size"] = vocab_size
            self.__dict__["model"]["encoder"]["slot_encoder"]["embedding"]["vocab_size"] = vocab_size
        elif self.__dict__["model"]["encoder"].get("embedding"):
            self.__dict__["model"]["encoder"]["embedding"]["vocab_size"] = vocab_size
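The `{*}` templates that `autoload_template` resolves let one configuration item reference another and then be evaluated as a Python expression. A rough standalone sketch of that mechanism follows; the config keys are invented for illustration, and `eval` stands in for the `exec("res=" + res, ...)` call above:

import re

config = {"base": {"epoch_num": 300},
          "scheduler": {"num_training_steps": "{base.epoch_num} * 100"}}

def resolve(matched):
    # walk the config dict along the dotted path inside "{...}"
    value = config
    for key in matched.group()[1:-1].split("."):
        value = value[key]
    return str(value)

expr = re.sub(r"{.*?}", resolve, config["scheduler"]["num_training_steps"])
config["scheduler"]["num_training_steps"] = eval(expr)  # "300 * 100" -> 30000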
common/loader.py
ADDED
@@ -0,0 +1,333 @@
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-07 19:26:06
Description: all classes for loading data.

'''
import os
import torch
import json
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader

from common.utils import InputData

ABS_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../")

class DataFactory(object):
    def __init__(self, tokenizer, use_multi_intent=False, to_lower_case=True):
        """create a data factory to load datasets and build dataloaders

        Args:
            tokenizer (Tokenizer): tokenizer used to encode the input text
            use_multi_intent (bool, optional): whether to enable multi-intent mode. Defaults to False.
        """
        self.tokenizer = tokenizer
        self.slot_label_list = []
        self.intent_label_list = []
        self.use_multi = use_multi_intent
        self.to_lower_case = to_lower_case
        self.slot_label_dict = None
        self.intent_label_dict = None

    def __is_supported_datasets(self, dataset_name: str) -> bool:
        return dataset_name.lower() in ["atis", "snips", "mix-atis", "mix-snips"]

    def load_dataset(self, dataset_config, split="train"):
        # TODO: disable use_auth_token
        dataset_name = None
        if split not in dataset_config:
            dataset_name = dataset_config.get("dataset_name")
        elif self.__is_supported_datasets(dataset_config[split]):
            dataset_name = dataset_config[split].lower()
        if dataset_name is not None:
            return load_dataset("LightChen2333/OpenSLU", dataset_name, split=split, use_auth_token=True)
        else:
            data_file = dataset_config[split]
            data_dict = {"text": [], "slot": [], "intent": []}
            with open(data_file, encoding="utf-8") as f:
                for line in f:
                    row = json.loads(line)
                    data_dict["text"].append(row["text"])
                    data_dict["slot"].append(row["slot"])
                    data_dict["intent"].append(row["intent"])
            return Dataset.from_dict(data_dict)

    def update_label_names(self, dataset):
        for intent_labels in dataset["intent"]:
            if self.use_multi:
                intent_label = intent_labels.split("#")
            else:
                intent_label = [intent_labels]
            for x in intent_label:
                if x not in self.intent_label_list:
                    self.intent_label_list.append(x)
        for slot_label in dataset["slot"]:
            for x in slot_label:
                if x not in self.slot_label_list:
                    self.slot_label_list.append(x)
        self.intent_label_dict = {key: index for index,
                                  key in enumerate(self.intent_label_list)}
        self.slot_label_dict = {key: index for index,
                                key in enumerate(self.slot_label_list)}

    def update_vocabulary(self, dataset):
        if self.tokenizer.name_or_path in ["word_tokenizer"]:
            for data in dataset:
                self.tokenizer.add_instance(data["text"])

    @staticmethod
    def fast_align_data(text, padding_side="right"):
        for i in range(len(text.input_ids)):
            desired_output = []
            for word_id in text.word_ids(i):
                if word_id is not None:
                    start, end = text.word_to_tokens(
                        i, word_id, sequence_index=0 if padding_side == "right" else 1)
                    if start == end - 1:
                        tokens = [start]
                    else:
                        tokens = [start, end - 1]
                    if len(desired_output) == 0 or desired_output[-1] != tokens:
                        desired_output.append(tokens)
            yield desired_output

    def fast_align(self,
                   batch,
                   ignore_index=-100,
                   device="cuda",
                   config=None,
                   enable_label=True,
                   label2tensor=True):
        if self.to_lower_case:
            input_list = [[t.lower() for t in x["text"]] for x in batch]
        else:
            input_list = [x["text"] for x in batch]
        text = self.tokenizer(input_list,
                              return_tensors="pt",
                              padding=True,
                              is_split_into_words=True,
                              truncation=True,
                              **config).to(device)
        if enable_label:
            if label2tensor:
                slot_mask = torch.ones_like(text.input_ids) * ignore_index
                for i, offsets in enumerate(
                        DataFactory.fast_align_data(text, padding_side=self.tokenizer.padding_side)):
                    num = 0
                    assert len(offsets) == len(batch[i]["text"])
                    assert len(offsets) == len(batch[i]["slot"])
                    for off in offsets:
                        slot_mask[i][off[0]] = self.slot_label_dict[batch[i]["slot"][num]]
                        num += 1
                slot = slot_mask.clone()
                attention_id = 0 if self.tokenizer.padding_side == "right" else 1
                for i, slot_batch in enumerate(slot):
                    for j, x in enumerate(slot_batch):
                        if x == ignore_index and text.attention_mask[i][j] == attention_id and (text.input_ids[i][
                                j] not in self.tokenizer.all_special_ids or text.input_ids[i][j] == self.tokenizer.unk_token_id):
                            slot[i][j] = slot[i][j - 1]
                slot = slot.to(device)
                if not self.use_multi:
                    intent = torch.tensor(
                        [self.intent_label_dict[x["intent"]] for x in batch]).to(device)
                else:
                    one_hot = torch.zeros(
                        (len(batch), len(self.intent_label_list)), dtype=torch.float)
                    for index, b in enumerate(batch):
                        for x in b["intent"].split("#"):
                            one_hot[index][self.intent_label_dict[x]] = 1.
                    intent = one_hot.to(device)
            else:
                slot_mask = None
                slot = [['#' for _ in range(text.input_ids.shape[1])]
                        for _ in range(text.input_ids.shape[0])]
                for i, offsets in enumerate(DataFactory.fast_align_data(text)):
                    num = 0
                    for off in offsets:
                        slot[i][off[0]] = batch[i]["slot"][num]
                        num += 1
                if not self.use_multi:
                    intent = [x["intent"] for x in batch]
                else:
                    intent = [
                        [x for x in b["intent"].split("#")] for b in batch]
            return InputData((text, slot, intent))
        else:
            return InputData((text, None, None))

    def general_align_data(self, split_text_list, raw_text_list, encoded_text):
        for i in range(len(split_text_list)):
            desired_output = []
            jdx = 0
            offset = encoded_text.offset_mapping[i].tolist()
            split_texts = split_text_list[i]
            raw_text = raw_text_list[i]
            last = 0
            temp_offset = []
            for off in offset:
                s, e = off
                if len(temp_offset) > 0 and (e != 0 and last == s):
                    len_1 = off[1] - off[0]
                    len_2 = temp_offset[-1][1] - temp_offset[-1][0]
                    if len_1 > len_2:
                        temp_offset.pop(-1)
                        temp_offset.append([0, 0])
                        temp_offset.append(off)
                    continue
                temp_offset.append(off)
                last = s
            offset = temp_offset
            for split_text in split_texts:
                while jdx < len(offset) and offset[jdx][0] == 0 and offset[jdx][1] == 0:
                    jdx += 1
                if jdx == len(offset):
                    continue
                start_, end_ = offset[jdx]
                tokens = None
                if split_text == raw_text[start_:end_].strip():
                    tokens = [jdx]
                else:
                    # Compute "xxx" -> "xx" "#x"
                    temp_jdx = jdx
                    last_str = raw_text[start_:end_].strip()
                    while last_str != split_text and temp_jdx < len(offset) - 1:
                        temp_jdx += 1
                        last_str += raw_text[offset[temp_jdx][0]:offset[temp_jdx][1]].strip()

                    if temp_jdx == jdx:
                        raise ValueError("Illegal Input data")
                    elif last_str == split_text:
                        tokens = [jdx, temp_jdx]
                        jdx = temp_jdx
                    else:
                        jdx -= 1
                jdx += 1
                if tokens is not None:
                    desired_output.append(tokens)
            yield desired_output

    def general_align(self,
                      batch,
                      ignore_index=-100,
                      device="cuda",
                      config=None,
                      enable_label=True,
                      label2tensor=True,
                      locale="en-US"):
        if self.to_lower_case:
            raw_data = [" ".join(x["text"]).lower() if locale not in ['ja-JP', 'zh-CN', 'zh-TW'] else "".join(x["text"]) for x in
                        batch]
            input_list = [[t.lower() for t in x["text"]] for x in batch]
        else:
            input_list = [x["text"] for x in batch]
            raw_data = [" ".join(x["text"]) if locale not in ['ja-JP', 'zh-CN', 'zh-TW'] else "".join(x["text"]) for x in
                        batch]
        text = self.tokenizer(raw_data,
                              return_tensors="pt",
                              padding=True,
                              truncation=True,
                              return_offsets_mapping=True,
                              **config).to(device)
        if enable_label:
            if label2tensor:
                slot_mask = torch.ones_like(text.input_ids) * ignore_index
                for i, offsets in enumerate(
                        self.general_align_data(input_list, raw_data, encoded_text=text)):
                    num = 0
                    # if len(offsets) != len(batch[i]["text"]) or len(offsets) != len(batch[i]["slot"]):
                    #     if
                    for off in offsets:
                        slot_mask[i][off[0]] = self.slot_label_dict[batch[i]["slot"][num]]
                        num += 1
                # slot = slot_mask.clone()
                # attention_id = 0 if self.tokenizer.padding_side == "right" else 1
                # for i, slot_batch in enumerate(slot):
                #     for j, x in enumerate(slot_batch):
                #         if x == ignore_index and text.attention_mask[i][j] == attention_id and text.input_ids[i][
                #                 j] not in self.tokenizer.all_special_ids:
                #             slot[i][j] = slot[i][j - 1]
                slot = slot_mask.to(device)
                if not self.use_multi:
                    intent = torch.tensor(
                        [self.intent_label_dict[x["intent"]] for x in batch]).to(device)
                else:
                    one_hot = torch.zeros(
                        (len(batch), len(self.intent_label_list)), dtype=torch.float)
                    for index, b in enumerate(batch):
                        for x in b["intent"].split("#"):
                            one_hot[index][self.intent_label_dict[x]] = 1.
                    intent = one_hot.to(device)
            else:
                slot_mask = None
                slot = [['#' for _ in range(text.input_ids.shape[1])]
                        for _ in range(text.input_ids.shape[0])]
                for i, offsets in enumerate(self.general_align_data(input_list, raw_data, encoded_text=text)):
                    num = 0
                    for off in offsets:
                        slot[i][off[0]] = batch[i]["slot"][num]
                        num += 1
                if not self.use_multi:
                    intent = [x["intent"] for x in batch]
                else:
                    intent = [
                        [x for x in b["intent"].split("#")] for b in batch]
            return InputData((text, slot, intent))
        else:
            return InputData((text, None, None))

    def batch_fn(self,
                 batch,
                 ignore_index=-100,
                 device="cuda",
                 config=None,
                 align_mode="fast",
                 enable_label=True,
                 label2tensor=True):
        if align_mode == "fast":
            # try:
            return self.fast_align(batch,
                                   ignore_index=ignore_index,
                                   device=device,
                                   config=config,
                                   enable_label=enable_label,
                                   label2tensor=label2tensor)
            # except:
            #     return self.general_align(batch,
            #                               ignore_index=ignore_index,
            #                               device=device,
            #                               config=config,
            #                               enable_label=enable_label,
            #                               label2tensor=label2tensor)
        else:
            return self.general_align(batch,
                                      ignore_index=ignore_index,
                                      device=device,
                                      config=config,
                                      enable_label=enable_label,
                                      label2tensor=label2tensor)

    def get_data_loader(self,
                        dataset,
                        batch_size,
                        shuffle=False,
                        device="cuda",
                        enable_label=True,
                        align_mode="fast",
                        label2tensor=True, **config):
        data_loader = DataLoader(dataset,
                                 shuffle=shuffle,
                                 batch_size=batch_size,
                                 collate_fn=lambda x: self.batch_fn(x,
                                                                    device=device,
                                                                    config=config,
                                                                    enable_label=enable_label,
                                                                    align_mode=align_mode,
                                                                    label2tensor=label2tensor))
        return data_loader
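As a usage sketch (not part of the commit), `DataFactory` can be pointed at a local jsonl split; the tokenizer name and file path below are placeholders, and each jsonl row is assumed to look like {"text": [...], "slot": [...], "intent": "..."}, matching what `load_dataset` reads:

from transformers import AutoTokenizer
from common.loader import DataFactory

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder model
df = DataFactory(tokenizer=tokenizer, use_multi_intent=False)
train_dataset = df.load_dataset({"train": "data/train.jsonl"}, split="train")
df.update_label_names(train_dataset)   # build intent/slot label vocabularies
train_loader = df.get_data_loader(train_dataset, batch_size=16,
                                  shuffle=True, device="cpu")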
common/logger.py
ADDED
@@ -0,0 +1,197 @@
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-02 16:29:13
Description: log manager

'''
import json
import os
import time
from common.config import Config

def mkdirs(dir_names):
    for dir_name in dir_names:
        if not os.path.exists(dir_name):
            os.mkdir(dir_name)



class Logger():
    """ log information via [wandb, fitlog, local file]
    """
    def __init__(self,
                 logger_type: str,
                 logger_name: str,
                 logging_level="INFO",
                 start_time='',
                 accelerator=None):
        """ create logger

        Args:
            logger_type (str): supported types = ["wandb", "fitlog", "local"]
            logger_name (str): logger name, i.e. the project name in wandb and the logging file name
            logging_level (str, optional): logging level. Defaults to "INFO".
            start_time (str, optional): start time string. Defaults to ''.
        """
        self.logger_type = logger_type
        times = time.localtime()
        self.output_dir = "logs/" + logger_name + "/" + str(times.tm_year) + start_time
        self.accelerator = accelerator
        self.logger_name = logger_name
        if accelerator is not None:
            from accelerate.logging import get_logger
            self.logging = get_logger(logger_name)
        else:
            if self.logger_type == "wandb":
                import wandb
                self.logger = wandb
                mkdirs(["logs", "logs/" + logger_name, self.output_dir])
                self.logger.init(project=logger_name)
            elif self.logger_type == "fitlog":
                import fitlog
                self.logger = fitlog
                mkdirs(["logs", "logs/" + logger_name, self.output_dir])
                self.logger.set_log_dir("logs/" + logger_name)
            else:
                mkdirs(["logs", "logs/" + logger_name, self.output_dir])
                self.config_file = os.path.join(self.output_dir, "config.jsonl")
                with open(self.config_file, "w", encoding="utf8") as f:
                    print(f"Config will be written to {self.config_file}")

                self.loss_file = os.path.join(self.output_dir, "loss.jsonl")
                with open(self.loss_file, "w", encoding="utf8") as f:
                    print(f"Loss Result will be written to {self.loss_file}")

                self.metric_file = os.path.join(self.output_dir, "metric.jsonl")
                with open(self.metric_file, "w", encoding="utf8") as f:
                    print(f"Metric Result will be written to {self.metric_file}")

                self.other_log_file = os.path.join(self.output_dir, "other_log.jsonl")
                with open(self.other_log_file, "w", encoding="utf8") as f:
                    print(f"Other Log Result will be written to {self.other_log_file}")
            import logging
            LOGGING_LEVEL_MAP = {
                "CRITICAL": logging.CRITICAL,
                "FATAL": logging.FATAL,
                "ERROR": logging.ERROR,
                "WARNING": logging.WARNING,
                "WARN": logging.WARN,
                "INFO": logging.INFO,
                "DEBUG": logging.DEBUG,
                "NOTSET": logging.NOTSET,
            }
            logging.basicConfig(format='[%(levelname)s - %(asctime)s]\t%(message)s', datefmt='%m/%d/%Y %I:%M:%S %p',
                                filename=os.path.join(self.output_dir, "log.log"), level=LOGGING_LEVEL_MAP[logging_level])
            self.logging = logging

    def set_config(self, config: Config):
        """save config

        Args:
            config (Config): configuration object to save
        """
        if self.accelerator is not None:
            self.accelerator.init_trackers(self.logger_name, config=config)
        elif self.logger_type == "wandb":
            self.logger.config.update(config)
        elif self.logger_type == "fitlog":
            self.logger.add_hyper(config)
        else:
            with open(self.config_file, "a", encoding="utf8") as f:
                f.write(json.dumps(config) + "\n")

    def log(self, data, step=0):
        """log data and step

        Args:
            data (Any): data to log
            step (int, optional): step num. Defaults to 0.
        """
        if self.accelerator is not None:
            self.accelerator.log(data, step=step)
        elif self.logger_type == "wandb":
            self.logger.log(data, step=step)
        elif self.logger_type == "fitlog":
            self.logger.add_other({"data": data, "step": step})
        else:
            with open(self.other_log_file, "a", encoding="utf8") as f:
                f.write(json.dumps({"data": data, "step": step}) + "\n")

    def log_metric(self, metric, metric_split="dev", step=0):
        """log metric

        Args:
            metric (Any): metric
            metric_split (str, optional): dataset split. Defaults to 'dev'.
            step (int, optional): step num. Defaults to 0.
        """
        if self.accelerator is not None:
            self.accelerator.log({metric_split: metric}, step=step)
        elif self.logger_type == "wandb":
            self.logger.log({metric_split: metric}, step=step)
        elif self.logger_type == "fitlog":
            self.logger.add_metric({metric_split: metric}, step=step)
        else:
            with open(self.metric_file, "a", encoding="utf8") as f:
                f.write(json.dumps({metric_split: metric, "step": step}) + "\n")

    def log_loss(self, loss, loss_name="Loss", step=0):
        """log loss

        Args:
            loss (Any): loss
            loss_name (str, optional): loss description. Defaults to 'Loss'.
            step (int, optional): step num. Defaults to 0.
        """
        if self.accelerator is not None:
            self.accelerator.log({loss_name: loss}, step=step)
        elif self.logger_type == "wandb":
            self.logger.log({loss_name: loss}, step=step)
        elif self.logger_type == "fitlog":
            self.logger.add_loss(loss, name=loss_name, step=step)
        else:
            with open(self.loss_file, "a", encoding="utf8") as f:
                f.write(json.dumps({loss_name: loss, "step": step}) + "\n")

    def finish(self):
        """finish logging
        """
        if self.logger_type == "fitlog":
            self.logger.finish()

    def info(self, message: str):
        """ Log a message with severity 'INFO' to a local file / console.

        Args:
            message (str): message to log
        """
        self.logging.info(message)

    def warning(self, message):
        """ Log a message with severity 'WARNING' to a local file / console.

        Args:
            message (str): message to log
        """
        self.logging.warning(message)

    def error(self, message):
        """ Log a message with severity 'ERROR' to a local file / console.

        Args:
            message (str): message to log
        """
        self.logging.error(message)

    def debug(self, message):
        """ Log a message with severity 'DEBUG' to a local file / console.

        Args:
            message (str): message to log
        """
        self.logging.debug(message)

    def critical(self, message):
        self.logging.critical(message)
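A minimal usage sketch for the local-file backend (illustrative, not part of the commit; the timestamp string is a placeholder):

from common.logger import Logger

logger = Logger(logger_type="local", logger_name="OpenSLU",
                start_time="20230101000000")       # placeholder timestamp
logger.set_config({"lr": 1e-3})                    # appended to config.jsonl
logger.log_loss(0.42, loss_name="Loss", step=1)    # appended to loss.jsonl
logger.log_metric({"slot_f1": 0.95}, metric_split="dev", step=1)
logger.info("training finished")                   # goes to log.log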
common/metric.py
ADDED
@@ -0,0 +1,343 @@
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-26 12:12:55
Description: Metric calculation class

'''
from collections import Counter
from typing import List, Dict

import numpy as np
from sklearn.metrics import f1_score

from common.utils import InputData, OutputData


class Evaluator(object):
    """Evaluation metric functions library class
    supported metrics:
        - slot_f1
        - intent_acc
        - exactly_match_accuracy
        - intent_f1 (defaults to "macro_intent_f1")
        - macro_intent_f1
        - micro_intent_f1
    """
    @staticmethod
    def exactly_match_accuracy(pred_slot: List[List[str or int]],
                               real_slot: List[List[str or int]],
                               pred_intent: List[List[str or int] or str or int],
                               real_intent: List[List[str or int] or str or int]) -> float:
        """Compute the accuracy based on the whole prediction for a given sentence, including slot and intent.
        (both support str or int indices as the representation of slots and intents)
        Args:
            pred_slot (List[List[str or int]]): predicted sequence of slot list
            real_slot (List[List[str or int]]): golden sequence of slot list.
            pred_intent (List[List[str or int] or str or int]): predicted intent list / predicted multi intent list.
            real_intent (List[List[str or int] or str or int]): golden intent list / golden multi intent list.

        Returns:
            float: exact match accuracy score
        """
        total_count, correct_count = 0.0, 0.0
        for p_slot, r_slot, p_intent, r_intent in zip(pred_slot, real_slot, pred_intent, real_intent):
            if isinstance(p_intent, list):
                p_intent, r_intent = set(p_intent), set(r_intent)
            if p_slot == r_slot and p_intent == r_intent:
                correct_count += 1.0
            total_count += 1.0

        return 1.0 * correct_count / total_count


    @staticmethod
    def intent_accuracy(pred_list: List, real_list: List) -> float:
        """Get intent accuracy measured by predictions and ground truths. Supports both multi intent and single intent.

        Args:
            pred_list (List): predicted intent list
            real_list (List): golden intent list

        Returns:
            float: intent accuracy score
        """
        total_count, correct_count = 0.0, 0.0
        for p_intent, r_intent in zip(pred_list, real_list):
            if isinstance(p_intent, list):
                p_intent, r_intent = set(p_intent), set(r_intent)
            if p_intent == r_intent:
                correct_count += 1.0
            total_count += 1.0

        return 1.0 * correct_count / total_count

    @staticmethod
    def intent_f1(pred_list: List[List[int]], real_list: List[List[int]], num_intent: int, average='macro') -> float:
        """Get intent F1 measured by predictions and ground truths. Supports both multi intent and single intent.
        (Only multi intent is supported directly, but you can wrap labels as [[intent1], [intent2], ...] to compute intent F1 for single intent)
        Args:
            pred_list (List[List[int]]): predicted multi intent list.
            real_list (List[List[int]]): golden multi intent list.
            num_intent (int): number of intent labels.
            average (str): supports "micro" and "macro"

        Returns:
            float: intent F1 score
        """
        return f1_score(Evaluator.__instance2onehot(num_intent, real_list),
                        Evaluator.__instance2onehot(num_intent, pred_list),
                        average=average,
                        zero_division=0)

    @staticmethod
    def __multilabel2one_hot(labels, nums):
        res = [0.] * nums
        if len(labels) == 0:
            return res
        if isinstance(labels[0], list):
            for label in labels[0]:
                res[label] = 1.
            return res
        for label in labels:
            res[label] = 1.
        return res

    @staticmethod
    def __instance2onehot(num_intent, data):
        res = []
        for intents in data:
            res.append(Evaluator.__multilabel2one_hot(intents, num_intent))
        return np.array(res)

    @staticmethod
    def __startOfChunk(prevTag, tag, prevTagType, tagType, chunkStart=False):
        if prevTag == 'B' and tag == 'B':
            chunkStart = True
        if prevTag == 'I' and tag == 'B':
            chunkStart = True
        if prevTag == 'O' and tag == 'B':
            chunkStart = True
        if prevTag == 'O' and tag == 'I':
            chunkStart = True

        if prevTag == 'E' and tag == 'E':
            chunkStart = True
        if prevTag == 'E' and tag == 'I':
            chunkStart = True
        if prevTag == 'O' and tag == 'E':
            chunkStart = True
        if prevTag == 'O' and tag == 'I':
            chunkStart = True

        if tag != 'O' and tag != '.' and prevTagType != tagType:
            chunkStart = True
        return chunkStart

    @staticmethod
    def __endOfChunk(prevTag, tag, prevTagType, tagType, chunkEnd=False):
        if prevTag == 'B' and tag == 'B':
            chunkEnd = True
        if prevTag == 'B' and tag == 'O':
            chunkEnd = True
        if prevTag == 'I' and tag == 'B':
            chunkEnd = True
        if prevTag == 'I' and tag == 'O':
            chunkEnd = True

        if prevTag == 'E' and tag == 'E':
            chunkEnd = True
        if prevTag == 'E' and tag == 'I':
            chunkEnd = True
        if prevTag == 'E' and tag == 'O':
            chunkEnd = True
        if prevTag == 'I' and tag == 'O':
            chunkEnd = True

        if prevTag != 'O' and prevTag != '.' and prevTagType != tagType:
            chunkEnd = True
        return chunkEnd

    @staticmethod
    def __splitTagType(tag):
        s = tag.split('-')
        if len(s) > 2 or len(s) == 0:
            raise ValueError('tag format wrong. it must be B-xxx.xxx')
        if len(s) == 1:
            tag = s[0]
            tagType = ""
        else:
            tag = s[0]
            tagType = s[1]
        return tag, tagType

    @staticmethod
    def computeF1Score(correct_slots: List[List[str]], pred_slots: List[List[str]]) -> float:
        """compute f1 score; modified from conlleval.pl

        Args:
            correct_slots (List[List[str]]): golden slot string list
            pred_slots (List[List[str]]): predicted slot string list

        Returns:
            float: slot f1 score
        """
        correctChunk = {}
        correctChunkCnt = 0.0
        foundCorrect = {}
        foundCorrectCnt = 0.0
        foundPred = {}
        foundPredCnt = 0.0
        correctTags = 0.0
        tokenCount = 0.0
        for correct_slot, pred_slot in zip(correct_slots, pred_slots):
            inCorrect = False
            lastCorrectTag = 'O'
            lastCorrectType = ''
            lastPredTag = 'O'
            lastPredType = ''
            for c, p in zip(correct_slot, pred_slot):
                correctTag, correctType = Evaluator.__splitTagType(c)
                predTag, predType = Evaluator.__splitTagType(p)

                if inCorrect == True:
                    if Evaluator.__endOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) == True and \
                            Evaluator.__endOfChunk(lastPredTag, predTag, lastPredType, predType) == True and \
                            (lastCorrectType == lastPredType):
                        inCorrect = False
                        correctChunkCnt += 1.0
                        if lastCorrectType in correctChunk:
                            correctChunk[lastCorrectType] += 1.0
                        else:
                            correctChunk[lastCorrectType] = 1.0
                    elif Evaluator.__endOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) != \
                            Evaluator.__endOfChunk(lastPredTag, predTag, lastPredType, predType) or \
                            (correctType != predType):
                        inCorrect = False

                if Evaluator.__startOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) == True and \
                        Evaluator.__startOfChunk(lastPredTag, predTag, lastPredType, predType) == True and \
                        (correctType == predType):
                    inCorrect = True

                if Evaluator.__startOfChunk(lastCorrectTag, correctTag, lastCorrectType, correctType) == True:
                    foundCorrectCnt += 1
                    if correctType in foundCorrect:
                        foundCorrect[correctType] += 1.0
                    else:
                        foundCorrect[correctType] = 1.0

                if Evaluator.__startOfChunk(lastPredTag, predTag, lastPredType, predType) == True:
                    foundPredCnt += 1.0
                    if predType in foundPred:
                        foundPred[predType] += 1.0
                    else:
                        foundPred[predType] = 1.0

                if correctTag == predTag and correctType == predType:
                    correctTags += 1.0

                tokenCount += 1.0

                lastCorrectTag = correctTag
                lastCorrectType = correctType
                lastPredTag = predTag
                lastPredType = predType

            if inCorrect == True:
                correctChunkCnt += 1.0
                if lastCorrectType in correctChunk:
                    correctChunk[lastCorrectType] += 1.0
                else:
                    correctChunk[lastCorrectType] = 1.0

        if foundPredCnt > 0:
            precision = 1.0 * correctChunkCnt / foundPredCnt
        else:
            precision = 0

        if foundCorrectCnt > 0:
            recall = 1.0 * correctChunkCnt / foundCorrectCnt
        else:
            recall = 0

        if (precision + recall) > 0:
            f1 = (2.0 * precision * recall) / (precision + recall)
        else:
            f1 = 0

        return f1

    @staticmethod
    def max_freq_predict(sample):
        """Max frequency prediction.
        """
        predict = []
        for items in sample:
            predict.append(Counter(items).most_common(1)[0][0])
        return predict

    @staticmethod
    def __token_map(indexes, token_label_map):
        return [[token_label_map[idx] if idx in token_label_map else -1 for idx in index] for index in indexes]

    @staticmethod
    def compute_all_metric(inps: InputData,
                           output: OutputData,
                           intent_label_map: dict = None,
                           metric_list: List = None) -> Dict:
        """Automatically compute all metrics mentioned in 'metric_list'

        Args:
            inps (InputData): input golden slot and intent labels
            output (OutputData): output predicted slot and intent labels
            intent_label_map (dict, Optional): dict like {"intent1": 0, "intent2": 1, ...}, which maps intent strings to indices
            metric_list (List): supported metrics are ["slot_f1", "intent_acc", "intent_f1", "macro_intent_f1", "micro_intent_f1", "EMA"]

        Returns:
            Dict: all metrics mentioned in 'metric_list', like {'EMA': 0.7, ...}


        Example:
            if computing slot metrics:

                inps.slot = [["slot1", "slot2", ...], ...]; output.slot_ids = [["slot1", "slot2", ...], ...];

            if computing intent metrics:

                [Multi Intent] inps.intent = [["intent1", "intent2", ...], ...]; output.intent_ids = [["intent1", "intent2", ...], ...]

                [Single Intent] inps.intent = ["intent1", ...]; output.intent_ids = ["intent1", ...]
        """
        if not metric_list:
            metric_list = ["slot_f1", "intent_acc", "EMA"]
        res_dict = {}
        use_slot = output.slot_ids is not None and len(output.slot_ids) > 0
        use_intent = output.intent_ids is not None and len(
            output.intent_ids) > 0
        if use_slot and "slot_f1" in metric_list:
            res_dict["slot_f1"] = Evaluator.computeF1Score(
                output.slot_ids, inps.slot)
        if use_intent and "intent_acc" in metric_list:
            res_dict["intent_acc"] = Evaluator.intent_accuracy(
                output.intent_ids, inps.intent)
            if isinstance(output.intent_ids[0], list):
                if "intent_f1" in metric_list:
                    res_dict["intent_f1"] = Evaluator.intent_f1(Evaluator.__token_map(output.intent_ids, intent_label_map),
                                                                Evaluator.__token_map(
                                                                    inps.intent, intent_label_map),
                                                                len(intent_label_map.keys()))
                elif "macro_intent_f1" in metric_list:
                    res_dict["macro_intent_f1"] = Evaluator.intent_f1(Evaluator.__token_map(output.intent_ids, intent_label_map),
                                                                      Evaluator.__token_map(inps.intent, intent_label_map),
                                                                      len(intent_label_map.keys()), average="macro")
                if "micro_intent_f1" in metric_list:
                    res_dict["micro_intent_f1"] = Evaluator.intent_f1(Evaluator.__token_map(output.intent_ids, intent_label_map),
                                                                      Evaluator.__token_map(inps.intent, intent_label_map),
                                                                      len(intent_label_map.keys()), average="micro")

        if use_slot and use_intent and "EMA" in metric_list:
            res_dict["EMA"] = Evaluator.exactly_match_accuracy(output.slot_ids, inps.slot, output.intent_ids,
                                                               inps.intent)
        return res_dict
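For reference, the public static methods can be called directly on label strings; a tiny sketch with made-up labels (not from any dataset run):

from common.metric import Evaluator

gold_slot = [["O", "B-fromloc.city_name", "O"]]
pred_slot = [["O", "B-fromloc.city_name", "O"]]
gold_intent = ["atis_flight"]
pred_intent = ["atis_flight"]

print(Evaluator.computeF1Score(gold_slot, pred_slot))              # 1.0
print(Evaluator.intent_accuracy(pred_intent, gold_intent))         # 1.0
print(Evaluator.exactly_match_accuracy(pred_slot, gold_slot,
                                       pred_intent, gold_intent))  # 1.0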
common/model_manager.py
ADDED
@@ -0,0 +1,324 @@
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-02-07 21:36:06
Description: manage all processes of model training and prediction.

'''
import os
import random

import numpy as np
import torch
from tqdm import tqdm


from common import utils
from common.loader import DataFactory
from common.logger import Logger
from common.metric import Evaluator
from common.tokenizer import get_tokenizer, get_tokenizer_class, load_embedding
from common.utils import InputData, instantiate
from common.utils import OutputData
from common.config import Config
import dill


class ModelManager(object):
    def __init__(self, config: Config):
        """create model manager by config

        Args:
            config (Config): configuration to manage all processes in OpenSLU
        """
        # init config
        self.config = config
        self.__set_seed(self.config.base.get("seed"))
        self.device = self.config.base.get("device")

        # enable accelerator
        if "accelerator" in self.config and self.config["accelerator"].get("use_accelerator"):
            from accelerate import Accelerator
            self.accelerator = Accelerator(log_with="wandb")
        else:
            self.accelerator = None
        if self.config.base.get("train"):
            self.tokenizer = get_tokenizer(
                self.config.tokenizer.get("_tokenizer_name_"))
            self.logger = Logger(
                "wandb", self.config.base["name"], start_time=config.start_time, accelerator=self.accelerator)

            # init dataloader & load data
            if self.config.base.get("save_dir"):
                self.model_save_dir = self.config.base["save_dir"]
            else:
                if not os.path.exists("save/"):
                    os.mkdir("save/")
                self.model_save_dir = "save/" + config.start_time
            if not os.path.exists(self.model_save_dir):
                os.mkdir(self.model_save_dir)
            batch_size = self.config.base["batch_size"]
            df = DataFactory(tokenizer=self.tokenizer,
                             use_multi_intent=self.config.base.get("multi_intent"),
                             to_lower_case=self.config.base.get("_to_lower_case_"))
            train_dataset = df.load_dataset(self.config.dataset, split="train")

            # update label and vocabulary
            df.update_label_names(train_dataset)
            df.update_vocabulary(train_dataset)

            # init tokenizer config and dataloaders
            tokenizer_config = {key: self.config.tokenizer[key]
                                for key in self.config.tokenizer if key[0] != "_" and key[-1] != "_"}
            self.train_dataloader = df.get_data_loader(train_dataset,
                                                       batch_size,
                                                       shuffle=True,
                                                       device=self.device,
                                                       enable_label=True,
                                                       align_mode=self.config.tokenizer.get(
                                                           "_align_mode_"),
                                                       label2tensor=True,
                                                       **tokenizer_config)
            dev_dataset = df.load_dataset(
                self.config.dataset, split="validation")
            self.dev_dataloader = df.get_data_loader(dev_dataset,
                                                     batch_size,
                                                     shuffle=False,
                                                     device=self.device,
                                                     enable_label=True,
                                                     align_mode=self.config.tokenizer.get(
                                                         "_align_mode_"),
                                                     label2tensor=False,
                                                     **tokenizer_config)
            df.update_vocabulary(dev_dataset)
            # add intent label num and slot label num to config
            if int(self.config.get_intent_label_num()) == 0 or int(self.config.get_slot_label_num()) == 0:
                self.intent_list = df.intent_label_list
                self.intent_dict = df.intent_label_dict
                self.config.set_intent_label_num(len(self.intent_list))
                self.slot_list = df.slot_label_list
                self.slot_dict = df.slot_label_dict
                self.config.set_slot_label_num(len(self.slot_list))
                self.config.set_vocab_size(self.tokenizer.vocab_size)

            # autoload embedding for non-pretrained encoder
            if self.config["model"]["encoder"].get("embedding") and self.config["model"]["encoder"]["embedding"].get(
                    "load_embedding_name"):
                self.config["model"]["encoder"]["embedding"]["embedding_matrix"] = load_embedding(
                    self.tokenizer,
                    self.config["model"]["encoder"]["embedding"].get("load_embedding_name"))
            # fill template in config
            self.config.autoload_template()
            # save config
            self.logger.set_config(self.config)

            self.model = None
            self.optimizer = None
            self.total_step = None
            self.lr_scheduler = None
            if self.config.tokenizer.get("_tokenizer_name_") == "word_tokenizer":
                self.tokenizer.save(os.path.join(self.model_save_dir, "tokenizer.json"))
            utils.save_json(os.path.join(
                self.model_save_dir, "label.json"), {"intent": self.intent_list, "slot": self.slot_list})
            if self.config.base.get("test"):
                self.test_dataset = df.load_dataset(
                    self.config.dataset, split="test")
                self.test_dataloader = df.get_data_loader(self.test_dataset,
                                                          batch_size,
                                                          shuffle=False,
                                                          device=self.device,
                                                          enable_label=True,
                                                          align_mode=self.config.tokenizer.get(
                                                              "_align_mode_"),
                                                          label2tensor=False,
                                                          **tokenizer_config)

    def init_model(self, model):
        """init model, optimizer, lr_scheduler

        Args:
            model (Any): pytorch model
        """
        self.model = model
        self.model.to(self.device)
        if self.config.base.get("train"):
            self.optimizer = instantiate(
                self.config["optimizer"])(self.model.parameters())
            self.total_step = int(self.config.base.get(
                "epoch_num")) * len(self.train_dataloader)
            self.lr_scheduler = instantiate(self.config["scheduler"])(
                optimizer=self.optimizer,
                num_training_steps=self.total_step
            )
            if self.accelerator is not None:
                self.model, self.optimizer, self.train_dataloader, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.train_dataloader, self.lr_scheduler)
                if self.config.base.get("load_dir_path"):
                    self.accelerator.load_state(self.config.base.get("load_dir_path"))
                # self.dev_dataloader = self.accelerator.prepare(self.dev_dataloader)

    def eval(self, step: int, best_metric: float) -> float:
        """ evaluate models.

        Args:
            step (int): which step the model has trained to
            best_metric (float): the last best metric value, used to decide whether to test and save the model

        Returns:
            float: updated best metric value
        """
        # TODO: save dev
        _, res = self.__evaluate(self.model, self.dev_dataloader)
        self.logger.log_metric(res, metric_split="dev", step=step)
        if res[self.config.base.get("best_key")] > best_metric:
            best_metric = res[self.config.base.get("best_key")]
            outputs, test_res = self.__evaluate(
                self.model, self.test_dataloader)
            if not os.path.exists(self.model_save_dir):
                os.mkdir(self.model_save_dir)
            if self.accelerator is None:
                torch.save(self.model, os.path.join(
                    self.model_save_dir, "model.pkl"))
                torch.save(self.optimizer, os.path.join(
                    self.model_save_dir, "optimizer.pkl"))
                torch.save(self.lr_scheduler, os.path.join(
                    self.model_save_dir, "lr_scheduler.pkl"), pickle_module=dill)
                torch.save(step, os.path.join(
                    self.model_save_dir, "step.pkl"))
            else:
                self.accelerator.wait_for_everyone()
                unwrapped_model = self.accelerator.unwrap_model(self.model)
                self.accelerator.save(unwrapped_model.state_dict(),
                                      os.path.join(self.model_save_dir, "model.pkl"))
                self.accelerator.save_state(output_dir=self.model_save_dir)
            outputs.save(self.model_save_dir, self.test_dataset)
            self.logger.log_metric(test_res, metric_split="test", step=step)
        return best_metric

    def train(self) -> float:
        """ train models.

        Returns:
            float: updated best metric value
        """
        step = 0
        best_metric = 0
        progress_bar = tqdm(range(self.total_step))
        for _ in range(int(self.config.base.get("epoch_num"))):
            for data in self.train_dataloader:
                if step == 0:
                    self.logger.info(data.get_item(
                        0, tokenizer=self.tokenizer, intent_map=self.intent_list, slot_map=self.slot_list))
|
215 |
+
output = self.model(data)
|
216 |
+
if self.accelerator is not None and hasattr(self.model, "module"):
|
217 |
+
loss, intent_loss, slot_loss = self.model.module.compute_loss(
|
218 |
+
pred=output, target=data)
|
219 |
+
else:
|
220 |
+
loss, intent_loss, slot_loss = self.model.compute_loss(
|
221 |
+
pred=output, target=data)
|
222 |
+
self.logger.log_loss(loss, "Loss", step=step)
|
223 |
+
self.logger.log_loss(intent_loss, "Intent Loss", step=step)
|
224 |
+
self.logger.log_loss(slot_loss, "Slot Loss", step=step)
|
225 |
+
self.optimizer.zero_grad()
|
226 |
+
|
227 |
+
if self.accelerator is not None:
|
228 |
+
self.accelerator.backward(loss)
|
229 |
+
else:
|
230 |
+
loss.backward()
|
231 |
+
self.optimizer.step()
|
232 |
+
self.lr_scheduler.step()
|
233 |
+
if not self.config.base.get("eval_by_epoch") and step % self.config.base.get(
|
234 |
+
"eval_step") == 0 and step != 0:
|
235 |
+
best_metric = self.eval(step, best_metric)
|
236 |
+
step += 1
|
237 |
+
progress_bar.update(1)
|
238 |
+
if self.config.base.get("eval_by_epoch"):
|
239 |
+
best_metric = self.eval(step, best_metric)
|
240 |
+
self.logger.finish()
|
241 |
+
return best_metric
|
242 |
+
|
243 |
+
def __set_seed(self, seed_value: int):
|
244 |
+
"""Manually set random seeds.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
seed_value (int): random seed
|
248 |
+
"""
|
249 |
+
random.seed(seed_value)
|
250 |
+
np.random.seed(seed_value)
|
251 |
+
torch.manual_seed(seed_value)
|
252 |
+
torch.random.manual_seed(seed_value)
|
253 |
+
os.environ['PYTHONHASHSEED'] = str(seed_value)
|
254 |
+
if torch.cuda.is_available():
|
255 |
+
torch.cuda.manual_seed(seed_value)
|
256 |
+
torch.cuda.manual_seed_all(seed_value)
|
257 |
+
torch.backends.cudnn.deterministic = True
|
258 |
+
torch.backends.cudnn.benchmark = True
|
259 |
+
return
|
260 |
+
|
261 |
+
def __evaluate(self, model, dataloader):
|
262 |
+
model.eval()
|
263 |
+
inps = InputData()
|
264 |
+
outputs = OutputData()
|
265 |
+
for data in dataloader:
|
266 |
+
torch.cuda.empty_cache()
|
267 |
+
output = model(data)
|
268 |
+
if self.accelerator is not None and hasattr(self.model, "module"):
|
269 |
+
decode_output = model.module.decode(output, data)
|
270 |
+
else:
|
271 |
+
decode_output = model.decode(output, data)
|
272 |
+
|
273 |
+
decode_output.map_output(slot_map=self.slot_list,
|
274 |
+
intent_map=self.intent_list)
|
275 |
+
data, decode_output = utils.remove_slot_ignore_index(
|
276 |
+
data, decode_output, ignore_index="#")
|
277 |
+
|
278 |
+
inps.merge_input_data(data)
|
279 |
+
outputs.merge_output_data(decode_output)
|
280 |
+
if "metric" in self.config:
|
281 |
+
res = Evaluator.compute_all_metric(
|
282 |
+
inps, outputs, intent_label_map=self.intent_dict, metric_list=self.config.metric)
|
283 |
+
else:
|
284 |
+
res = Evaluator.compute_all_metric(
|
285 |
+
inps, outputs, intent_label_map=self.intent_dict)
|
286 |
+
model.train()
|
287 |
+
return outputs, res
|
288 |
+
|
289 |
+
def load(self):
|
290 |
+
|
291 |
+
self.model = torch.load(os.path.join(self.config.base["model_dir"], "model.pkl"))
|
292 |
+
if self.config.tokenizer["_tokenizer_name_"] == "word_tokenizer":
|
293 |
+
self.tokenizer = get_tokenizer_class(self.config.tokenizer["_tokenizer_name_"]).from_file(
|
294 |
+
os.path.join(self.config.base["model_dir"], "tokenizer.json"))
|
295 |
+
else:
|
296 |
+
self.tokenizer = get_tokenizer(self.config.tokenizer["_tokenizer_name_"])
|
297 |
+
self.model.to(self.device)
|
298 |
+
label = utils.load_json(os.path.join(self.config.base["model_dir"], "label.json"))
|
299 |
+
self.intent_list = label["intent"]
|
300 |
+
self.slot_list = label["slot"]
|
301 |
+
self.data_factory=DataFactory(tokenizer=self.tokenizer,
|
302 |
+
use_multi_intent=self.config.base.get("multi_intent"),
|
303 |
+
to_lower_case=self.config.tokenizer.get("_to_lower_case_"))
|
304 |
+
|
305 |
+
def predict(self, text_data):
|
306 |
+
self.model.eval()
|
307 |
+
tokenizer_config = {key: self.config.tokenizer[key]
|
308 |
+
for key in self.config.tokenizer if key[0] != "_" and key[-1] != "_"}
|
309 |
+
align_mode = self.config.tokenizer.get("_align_mode_")
|
310 |
+
inputs = self.data_factory.batch_fn(batch=[{"text": text_data.split(" ")}],
|
311 |
+
device=self.device,
|
312 |
+
config=tokenizer_config,
|
313 |
+
enable_label=False,
|
314 |
+
align_mode= align_mode if align_mode is not None else "general",
|
315 |
+
label2tensor=False)
|
316 |
+
output = self.model(inputs)
|
317 |
+
decode_output = self.model.decode(output, inputs)
|
318 |
+
decode_output.map_output(slot_map=self.slot_list,
|
319 |
+
intent_map=self.intent_list)
|
320 |
+
if self.config.base.get("multi_intent"):
|
321 |
+
intent = decode_output.intent_ids[0]
|
322 |
+
else:
|
323 |
+
intent = [decode_output.intent_ids[0]]
|
324 |
+
return {"intent": intent, "slot": decode_output.slot_ids[0], "text": self.tokenizer.decode(inputs.input_ids[0])}
|
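For orientation, a minimal sketch of how this manager is driven for training. It is not part of the commit: the config path is hypothetical, and building the model with `instantiate(config["model"])` is an assumption based on the `_model_target_` convention used elsewhere in the configs.

# Hypothetical training entry point (sketch only, not part of this commit).
from common.config import Config
from common.model_manager import ModelManager
from common.utils import instantiate

config = Config.load_from_yaml("config/stack-propagation.yaml")  # hypothetical path
manager = ModelManager(config)
model = instantiate(config["model"])  # assumption: the model section uses "_model_target_" keys
manager.init_model(model)             # sets up optimizer and lr_scheduler when base.train is set
best_metric = manager.train()         # evaluates on dev and checkpoints the best model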
common/tokenizer.py
ADDED
@@ -0,0 +1,311 @@
import json
import os
from collections import Counter
from collections import OrderedDict
from typing import List

import torch
from ordered_set import OrderedSet
from transformers import AutoTokenizer

from common.utils import download, unzip_file


def get_tokenizer(tokenizer_name: str):
    """auto get tokenizer

    Args:
        tokenizer_name (str): supports "word_tokenizer" and any pretrained tokenizer on Hugging Face.

    Returns:
        Any: Tokenizer Object
    """
    if tokenizer_name == "word_tokenizer":
        return WordTokenizer(tokenizer_name)
    else:
        return AutoTokenizer.from_pretrained(tokenizer_name)


def get_tokenizer_class(tokenizer_name: str):
    """auto get tokenizer class

    Args:
        tokenizer_name (str): supports "word_tokenizer" and any pretrained tokenizer on Hugging Face.

    Returns:
        Any: Tokenizer Class
    """
    if tokenizer_name == "word_tokenizer":
        return WordTokenizer
    else:
        return AutoTokenizer.from_pretrained


BATCH_STATE = 1
INSTANCE_STATE = 2


class WordTokenizer(object):

    def __init__(self, name):
        self.__name = name
        self.index2instance = OrderedSet()
        self.instance2index = OrderedDict()
        # Counter object recording the frequency
        # with which each element occurs in raw text.
        self.counter = Counter()

        self.__sign_pad = "[PAD]"
        self.add_instance(self.__sign_pad)
        self.__sign_unk = "[UNK]"
        self.add_instance(self.__sign_unk)

    @property
    def padding_side(self):
        return "right"

    @property
    def all_special_ids(self):
        return [self.unk_token_id, self.pad_token_id]

    @property
    def name_or_path(self):
        return self.__name

    @property
    def vocab_size(self):
        return len(self.instance2index)

    @property
    def pad_token_id(self):
        return self.instance2index[self.__sign_pad]

    @property
    def unk_token_id(self):
        return self.instance2index[self.__sign_unk]

    def add_instance(self, instance):
        """ Add instances to the alphabet.

        1, We support any iterable data structure which
        contains elements of str type.

        2, We count added instances, which influences
        the serialization of unknown instances.

        Args:
            instance: the given instance or a list of them.
        """
        if isinstance(instance, (list, tuple)):
            for element in instance:
                self.add_instance(element)
            return

        # We only support elements of str type.
        assert isinstance(instance, str)

        # count the frequency of instances.
        self.counter[instance] += 1

        if instance not in self.index2instance:
            self.instance2index[instance] = len(self.index2instance)
            self.index2instance.append(instance)

    def __call__(self, instance,
                 return_tensors="pt",
                 is_split_into_words=True,
                 padding=True,
                 add_special_tokens=False,
                 truncation=True,
                 max_length=512,
                 **config):
        if isinstance(instance, (list, tuple)) and isinstance(instance[0], str) and is_split_into_words:
            res = self.get_index(instance)
            state = INSTANCE_STATE
        elif isinstance(instance, str) and not is_split_into_words:
            res = self.get_index(instance.split(" "))
            state = INSTANCE_STATE
        elif not is_split_into_words and isinstance(instance, (list, tuple)):
            res = [self.get_index(ins.split(" ")) for ins in instance]
            state = BATCH_STATE
        else:
            res = [self.get_index(ins) for ins in instance]
            state = BATCH_STATE
        if state == BATCH_STATE:
            res = [r[:max_length] for r in res]
        else:
            # single-instance input: res is a flat list of ids
            res = res[:max_length]
        pad_id = self.get_index(self.__sign_pad)
        if padding and state == BATCH_STATE:
            # pad to the longest (possibly truncated) sequence in the batch
            max_len = max(len(x) for x in res)

            for i in range(len(res)):
                res[i] = res[i] + [pad_id] * (max_len - len(res[i]))
        if return_tensors == "pt":
            input_ids = torch.Tensor(res).long()
            attention_mask = (input_ids != pad_id).long()
        elif state == BATCH_STATE:
            input_ids = res
            attention_mask = [[1 if r != pad_id else 0 for r in batch] for batch in res]
        else:
            input_ids = res
            attention_mask = [1 if r != pad_id else 0 for r in res]
        return TokenizedData(input_ids, token_type_ids=attention_mask, attention_mask=attention_mask)

    def get_index(self, instance):
        """ Serialize the given instance and return its index.

        For unknown words, the index of the "[UNK]" sign is returned.

        Args:
            instance (Any): the given instance or a list of them.
        Return:
            Any: the serialization of the query instance.
        """
        if isinstance(instance, (list, tuple)):
            return [self.get_index(elem) for elem in instance]

        assert isinstance(instance, str)

        try:
            return self.instance2index[instance]
        except KeyError:
            return self.instance2index[self.__sign_unk]

    def decode(self, index):
        """ Get the corresponding instance of a query index.

        If the index is invalid, an exception is thrown.

        Args:
            index (int): the query index, possibly iterable.
        Returns:
            the corresponding instance.
        """
        if isinstance(index, list):
            return [self.decode(elem) for elem in index]
        if isinstance(index, torch.Tensor):
            index = index.tolist()
            return self.decode(index)
        return self.index2instance[index]

    def save(self, path):
        """ Save the content of the alphabet to a file.

        The saved file is a dictionary file whose elements
        are sorted by their serialized index.

        Args:
            path (str): the path to save the object to.
        """
        with open(path, 'w', encoding="utf8") as fw:
            fw.write(json.dumps({"name": self.__name, "token_map": self.instance2index}))

    @staticmethod
    def from_file(path):
        with open(path, 'r', encoding="utf8") as fw:
            obj = json.load(fw)
            tokenizer = WordTokenizer(obj["name"])
            tokenizer.instance2index = OrderedDict(obj["token_map"])
            tokenizer.counter = len(tokenizer.instance2index)
            tokenizer.index2instance = OrderedSet(tokenizer.instance2index.keys())
            return tokenizer

    def __len__(self):
        return len(self.index2instance)

    def __str__(self):
        return 'Alphabet {} contains about {} words: \n\t{}'.format(self.name_or_path, len(self), self.index2instance)

    def convert_tokens_to_ids(self, tokens):
        """convert a token sequence to an input ids sequence

        Args:
            tokens (Any): token sequence

        Returns:
            Any: input ids sequence
        """
        try:
            if isinstance(tokens, (list, tuple)):
                return [self.instance2index[x] for x in tokens]
            return self.instance2index[tokens]

        except KeyError:
            return self.instance2index[self.__sign_unk]


class TokenizedData():
    """tokenized output data with input_ids, token_type_ids, attention_mask
    """
    def __init__(self, input_ids, token_type_ids, attention_mask):
        self.input_ids = input_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask

    def word_ids(self, index: int) -> List[int or None]:
        """ get word id list

        Args:
            index (int): word index in sequence

        Returns:
            List[int or None]: word id list
        """
        return [j if self.attention_mask[index][j] != 0 else None for j, x in enumerate(self.input_ids[index])]

    def word_to_tokens(self, index, word_id, **kwargs):
        """map word and tokens

        Args:
            index (int): unused
            word_id (int): word index in sequence
        """
        return (word_id, word_id + 1)

    def to(self, device):
        """set device

        Args:
            device (str): supports ["cpu", "cuda"]
        """
        self.input_ids = self.input_ids.to(device)
        self.token_type_ids = self.token_type_ids.to(device)
        self.attention_mask = self.attention_mask.to(device)
        return self


def load_embedding(tokenizer: WordTokenizer, glove_name: str):
    """ load embedding from the Stanford server or a local cache.

    Args:
        tokenizer (WordTokenizer): non-pretrained tokenizer
        glove_name (str): GloVe file name, e.g. "glove.6B.300d.txt"

    Returns:
        Any: word embedding
    """
    save_path = "save/" + glove_name + ".zip"
    if not os.path.exists(save_path):
        download("http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip#" + glove_name, save_path)
        unzip_file(save_path, "save/" + glove_name)
    dim = int(glove_name.split(".")[-2][:-1])
    embedding_list = torch.rand((tokenizer.vocab_size, dim))
    embedding_list[tokenizer.pad_token_id] = torch.zeros((1, dim))
    with open("save/" + glove_name + "/" + glove_name, "r", encoding="utf8") as f:
        for line in f.readlines():
            datas = line.split(" ")
            word = datas[0]
            embedding = torch.Tensor([float(datas[i + 1]) for i in range(len(datas) - 1)])
            tokenized = tokenizer.convert_tokens_to_ids(word)
            if isinstance(tokenized, int) and tokenized != tokenizer.unk_token_id:
                embedding_list[tokenized] = embedding

    return embedding_list
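A minimal usage sketch for the tokenizer above (not part of the commit); it only relies on calls defined in this file:

# Sketch: building a vocabulary and tokenizing a small batch.
from common.tokenizer import WordTokenizer

tok = WordTokenizer("word_tokenizer")
tok.add_instance("what a beautiful morning".split(" "))   # grow the vocabulary
batch = tok(["what a beautiful morning", "a walk"], is_split_into_words=False)
print(batch.input_ids)                 # LongTensor, right-padded with the [PAD] id
print(tok.decode(batch.input_ids[1]))  # "walk" was never added, so it decodes as [UNK]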
common/utils.py
ADDED
@@ -0,0 +1,489 @@
import functools
import importlib
import json
import os
import tarfile
from typing import List, Tuple
import zipfile
from collections.abc import Callable  # "from collections import Callable" was removed in Python 3.10
from ruamel import yaml
import requests
import torch
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from torch import Tensor


class InputData():
    """input data class
    """
    def __init__(self, inputs: List = None):
        """init input data class

        if inputs is None:
            this class can be used to save all InputData in the history by 'merge_input_data(X: InputData)'
        else:
            this class can be used for model input.

        Args:
            inputs (List, optional): inputs with [tokenized_data, slot, intent]. Defaults to None.
        """
        if inputs is None:
            self.slot = []
            self.intent = []
            self.input_ids = None
            self.token_type_ids = None
            self.attention_mask = None
            self.seq_lens = None
        else:
            self.input_ids = inputs[0].input_ids
            self.token_type_ids = None
            if hasattr(inputs[0], "token_type_ids"):
                self.token_type_ids = inputs[0].token_type_ids
            self.attention_mask = inputs[0].attention_mask
            if len(inputs) >= 2:
                self.slot = inputs[1]
            if len(inputs) >= 3:
                self.intent = inputs[2]
            self.seq_lens = self.attention_mask.sum(-1)

    def get_inputs(self):
        """ get tokenized_data

        Returns:
            dict: tokenized data
        """
        res = {
            "input_ids": self.input_ids,
            "attention_mask": self.attention_mask
        }
        if self.token_type_ids is not None:
            res["token_type_ids"] = self.token_type_ids
        return res

    def merge_input_data(self, inp: "InputData"):
        """merge another InputData object with slot and intent

        Args:
            inp (InputData): another InputData object
        """
        self.slot += inp.slot
        self.intent += inp.intent

    def get_slot_mask(self, ignore_index: int) -> Tensor:
        """get slot mask

        Args:
            ignore_index (int): ignore index used in slot padding

        Returns:
            Tensor: mask tensor
        """
        mask = self.slot != ignore_index
        mask[:, 0] = torch.ones_like(mask[:, 0]).to(self.slot.device)
        return mask

    def get_item(self, index, tokenizer=None, intent_map=None, slot_map=None, ignore_index=-100):
        res = {"input_ids": self.input_ids[index]}
        if tokenizer is not None:
            res["tokens"] = [tokenizer.decode(x) for x in self.input_ids[index]]
        if intent_map is not None:
            intents = self.intent.tolist()
            if isinstance(intents[index], list):
                res["intent"] = [intent_map[int(x)] for x in intents[index]]
            else:
                res["intent"] = intent_map[intents[index]]
        if slot_map is not None:
            res["slot"] = [slot_map[x] if x != ignore_index else "#" for x in self.slot.tolist()[index]]
        return res


class OutputData():
    """output data class
    """
    def __init__(self, intent_ids=None, slot_ids=None):
        """init output data class

        if intent_ids is None and slot_ids is None:
            this class can be used to save all OutputData in the history by 'merge_output_data(X: OutputData)'
        else:
            this class can be used for model output management.

        Args:
            intent_ids (Any, optional): list(Tensor) of intent ids / logits / strings. Defaults to None.
            slot_ids (Any, optional): list(Tensor) of slot ids / logits / strings. Defaults to None.
        """
        if intent_ids is None and slot_ids is None:
            self.intent_ids = []
            self.slot_ids = []
        else:
            if isinstance(intent_ids, ClassifierOutputData):
                self.intent_ids = intent_ids.classifier_output
            else:
                self.intent_ids = intent_ids
            if isinstance(slot_ids, ClassifierOutputData):
                self.slot_ids = slot_ids.classifier_output
            else:
                self.slot_ids = slot_ids

    def map_output(self, slot_map=None, intent_map=None):
        """ map intent or slot ids to intent or slot strings.

        Args:
            slot_map (dict, optional): slot id-to-string map. Defaults to None.
            intent_map (dict, optional): intent id-to-string map. Defaults to None.
        """
        if self.slot_ids is not None:
            if slot_map:
                self.slot_ids = [[slot_map[x] if x >= 0 else "#" for x in sid] for sid in self.slot_ids]
        if self.intent_ids is not None:
            if intent_map:
                self.intent_ids = [[intent_map[x] for x in sid] if isinstance(sid, list) else intent_map[sid] for sid in
                                   self.intent_ids]

    def merge_output_data(self, output: "OutputData"):
        """merge another OutputData object with slot and intent

        Args:
            output (OutputData): another OutputData object
        """
        if output.slot_ids is not None:
            self.slot_ids += output.slot_ids
        if output.intent_ids is not None:
            self.intent_ids += output.intent_ids

    def save(self, path: str, original_dataset=None):
        """ save all OutputData in the history

        Args:
            path (str): save dir path
            original_dataset (Iterable): original dataset
        """
        # with open(f"{path}/intent.jsonl", "w") as f:
        #     for x in self.intent_ids:
        #         f.write(json.dumps(x) + "\n")
        with open(f"{path}/outputs.jsonl", "w") as f:
            if original_dataset is not None:
                for i, s, d in zip(self.intent_ids, self.slot_ids, original_dataset):
                    f.write(json.dumps({"pred_intent": i, "pred_slot": s, "text": d["text"],
                                        "golden_intent": d["intent"], "golden_slot": d["slot"]}) + "\n")
            else:
                for i, s in zip(self.intent_ids, self.slot_ids):
                    f.write(json.dumps({"pred_intent": i, "pred_slot": s}) + "\n")


class HiddenData():
    """Interactive data structure for all model components
    """
    def __init__(self, intent_hidden, slot_hidden):
        """init hidden data structure

        Args:
            intent_hidden (Any): sentence-level or intent hidden state
            slot_hidden (Any): token-level or slot hidden state
        """
        self.intent_hidden = intent_hidden
        self.slot_hidden = slot_hidden
        self.inputs = None
        self.embedding = None

    def get_intent_hidden_state(self):
        """get intent hidden state

        Returns:
            Any: intent hidden state
        """
        return self.intent_hidden

    def get_slot_hidden_state(self):
        """get slot hidden state

        Returns:
            Any: slot hidden state
        """
        return self.slot_hidden

    def update_slot_hidden_state(self, hidden_state):
        """update slot hidden state

        Args:
            hidden_state (Any): slot hidden state to update
        """
        self.slot_hidden = hidden_state

    def update_intent_hidden_state(self, hidden_state):
        """update intent hidden state

        Args:
            hidden_state (Any): intent hidden state to update
        """
        self.intent_hidden = hidden_state

    def add_input(self, inputs: InputData or "HiddenData"):
        """add the last model component's input information to the next model component

        Args:
            inputs (InputData or HiddenData): last model component input
        """
        self.inputs = inputs

    def add_embedding(self, embedding):
        self.embedding = embedding


class ClassifierOutputData():
    """Classifier output data structure of all classifier components
    """
    def __init__(self, classifier_output):
        self.classifier_output = classifier_output
        self.output_embedding = None


def remove_slot_ignore_index(inputs: InputData, outputs: OutputData, ignore_index=-100):
    """ remove padding or extra tokens from input ids and output ids

    Args:
        inputs (InputData): input data with input ids
        outputs (OutputData): output data with decoded output ids
        ignore_index (int, optional): ignore index used in the slot sequences. Defaults to -100.

    Returns:
        InputData: input data with padding and extra tokens removed
        OutputData: output data with padding and extra tokens removed
    """
    for index, (inp_ss, out_ss) in enumerate(zip(inputs.slot, outputs.slot_ids)):
        temp_inp = []
        temp_out = []
        for inp_s, out_s in zip(list(inp_ss), list(out_ss)):
            if inp_s != ignore_index:
                temp_inp.append(inp_s)
                temp_out.append(out_s)

        inputs.slot[index] = temp_inp
        outputs.slot_ids[index] = temp_out
    return inputs, outputs


def pack_sequence(inputs: Tensor, seq_len: Tensor or List) -> Tensor:
    """pack sequence data to packed data without padding.

    Args:
        inputs (Tensor): list(Tensor) of padded sequence inputs
        seq_len (Tensor or List): list(Tensor) of sequence lengths

    Returns:
        Tensor: packed inputs

    Examples:
        inputs = [[x, y, z, PAD, PAD], [x, y, PAD, PAD, PAD]]

        seq_len = [3, 2]

        return -> [x, y, z, x, y]
    """
    output = []
    for index, batch in enumerate(inputs):
        output.append(batch[:seq_len[index]])
    return torch.cat(output, dim=0)


def unpack_sequence(inputs: Tensor, seq_lens: Tensor or List, padding_value=0) -> Tensor:
    """unpack sequence data.

    Args:
        inputs (Tensor): list(Tensor) of packed sequence inputs
        seq_lens (Tensor or List): list(Tensor) of sequence lengths
        padding_value (int, optional): padding value. Defaults to 0.

    Returns:
        Tensor: unpacked inputs

    Examples:
        inputs = [x, y, z, x, y]

        seq_len = [3, 2]

        return -> [[x, y, z, PAD, PAD], [x, y, PAD, PAD, PAD]]
    """
    last_idx = 0
    output = []
    for _, seq_len in enumerate(seq_lens):
        output.append(inputs[last_idx:last_idx + seq_len])
        last_idx = last_idx + seq_len
    return pad_sequence(output, batch_first=True, padding_value=padding_value)


def get_dict_with_key_prefix(input_dict: dict, prefix=""):
    res = {}
    for t in input_dict:
        res[t + prefix] = input_dict[t]
    return res


def download(url: str, fname: str):
    """download a file from url to fname

    Args:
        url (str): remote server url path
        fname (str): local path to save to
    """
    resp = requests.get(url, stream=True)
    total = int(resp.headers.get('content-length', 0))
    with open(fname, 'wb') as file, tqdm(
            desc=fname,
            total=total,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
    ) as bar:
        for data in resp.iter_content(chunk_size=1024):
            size = file.write(data)
            bar.update(size)


def tar_gz_data(file_name: str):
    """use the "tar.gz" format to compress data

    Args:
        file_name (str): file path to tar
    """
    t = tarfile.open(f"{file_name}.tar.gz", "w:gz")

    for root, dir, files in os.walk(f"{file_name}"):
        print(root, dir, files)
        for file in files:
            fullpath = os.path.join(root, file)
            t.add(fullpath)
    t.close()


def untar(fname: str, dirs: str):
    """ uncompress a "tar.gz" file

    Args:
        fname (str): file path to untar
        dirs (str): target dir path
    """
    t = tarfile.open(fname)
    t.extractall(path=dirs)


def unzip_file(zip_src: str, dst_dir: str):
    """ uncompress a "zip" file

    Args:
        zip_src (str): file path to unzip
        dst_dir (str): target dir path
    """
    r = zipfile.is_zipfile(zip_src)
    if r:
        if not os.path.exists(dst_dir):
            os.mkdir(dst_dir)
        fz = zipfile.ZipFile(zip_src, 'r')
        for file in fz.namelist():
            fz.extract(file, dst_dir)
    else:
        print('This is not a zip file')


def find_callable(target: str) -> Callable:
    """ find a callable function / class to instantiate

    Args:
        target (str): class/module path

    Raises:
        e: can not import module

    Returns:
        Callable: the found function / class
    """
    target_module_path, target_callable_path = target.rsplit(".", 1)
    target_callable_paths = [target_callable_path]

    target_module = None
    while len(target_module_path):
        try:
            target_module = importlib.import_module(target_module_path)
            break
        except Exception as e:
            raise e
    target_callable = target_module
    for attr in reversed(target_callable_paths):
        target_callable = getattr(target_callable, attr)

    return target_callable


def instantiate(config, target="_model_target_", partial="_model_partial_"):
    """ instantiate an object by config.

    Modified from https://github.com/HIT-SCIR/ltp/blob/main/python/core/ltp_core/models/utils/instantiate.py.

    Args:
        config (Any): configuration
        target (str, optional): key assigning the class to be instantiated. Defaults to "_model_target_".
        partial (str, optional): key deciding whether the object should be instantiated partially. Defaults to "_model_partial_".

    Returns:
        Any: instantiated object
    """
    if isinstance(config, dict) and target in config:
        target_path = config.get(target)
        target_callable = find_callable(target_path)

        is_partial = config.get(partial, False)
        target_args = {
            key: instantiate(value)
            for key, value in config.items()
            if key not in [target, partial]
        }

        if is_partial:
            return functools.partial(target_callable, **target_args)
        else:
            return target_callable(**target_args)
    elif isinstance(config, dict):
        return {key: instantiate(value) for key, value in config.items()}
    else:
        return config


def load_yaml(file):
    """ load data from yaml files.

    Args:
        file (str): yaml file path.

    Returns:
        Any: data
    """
    with open(file, encoding="utf-8") as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            raise exc


def from_configured(configure_name_or_file: str, model_class: Callable, config_prefix="./config/", **input_config):
    """load a module from pre-configured data

    Args:
        configure_name_or_file (str): config path -> {config_prefix}/{configure_name_or_file}.yaml
        model_class (Callable): module class
        config_prefix (str, optional): configuration root path. Defaults to "./config/".

    Returns:
        Any: instantiated object.
    """
    if os.path.exists(configure_name_or_file):
        configure_file = configure_name_or_file
    else:
        configure_file = os.path.join(config_prefix, configure_name_or_file + ".yaml")
    config = load_yaml(configure_file)
    config.update(input_config)
    return model_class(**config)


def save_json(file_path, obj):
    with open(file_path, 'w', encoding="utf8") as fw:
        fw.write(json.dumps(obj))


def load_json(file_path):
    with open(file_path, 'r', encoding="utf8") as fw:
        res = json.load(fw)
    return res
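As a quick illustration of the instantiate() helper above: the values below are hypothetical, but this mirrors how model_manager.py builds its optimizer.

# Sketch: resolving a config dict into a callable (hypothetical values).
from common.utils import instantiate

optimizer_config = {
    "_model_target_": "torch.optim.Adam",  # resolved by find_callable()
    "_model_partial_": True,               # return a functools.partial instead of calling it
    "lr": 1e-3,                            # hypothetical value; passed through as a kwarg
}
make_optimizer = instantiate(optimizer_config)
# later, once a model exists: optimizer = make_optimizer(model.parameters())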
model/__init__.py
ADDED
@@ -0,0 +1,3 @@
from model.open_slu_model import OpenSLUModel

__all__ = ["OpenSLUModel"]
model/decoder/__init__.py
ADDED
@@ -0,0 +1,5 @@
from model.decoder.agif_decoder import AGIFDecoder
from model.decoder.base_decoder import StackPropagationDecoder, BaseDecoder, DCANetDecoder
from model.decoder.gl_gin_decoder import GLGINDecoder

__all__ = ["StackPropagationDecoder", "BaseDecoder", "DCANetDecoder", "AGIFDecoder", "GLGINDecoder"]
model/decoder/agif_decoder.py
ADDED
@@ -0,0 +1,16 @@
from common.utils import HiddenData, OutputData
from model.decoder.base_decoder import BaseDecoder


class AGIFDecoder(BaseDecoder):
    def forward(self, hidden: HiddenData, **kwargs):
        # hidden = self.interaction(hidden)
        pred_intent = self.intent_classifier(hidden)
        intent_index = self.intent_classifier.decode(OutputData(pred_intent, None),
                                                     return_list=False,
                                                     return_sentence_level=True)
        interact_args = {"intent_index": intent_index,
                         "batch_size": pred_intent.classifier_output.shape[0],
                         "intent_label_num": self.intent_classifier.config["intent_label_num"]}
        pred_slot = self.slot_classifier(hidden, internal_interaction=self.interaction, **interact_args)
        return OutputData(pred_intent, pred_slot)
model/decoder/base_decoder.py
ADDED
@@ -0,0 +1,94 @@
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-31 18:22:36
Description:

'''
from torch import nn

from common.utils import HiddenData, OutputData, InputData


class BaseDecoder(nn.Module):
    """Base class for all decoder modules.

    Notice: it is often only necessary to change this module and its sub-modules.
    """
    def __init__(self, intent_classifier, slot_classifier, interaction=None):
        super().__init__()
        self.intent_classifier = intent_classifier
        self.slot_classifier = slot_classifier
        self.interaction = interaction

    def forward(self, hidden: HiddenData):
        """forward

        Args:
            hidden (HiddenData): encoded data

        Returns:
            OutputData: prediction logits
        """
        if self.interaction is not None:
            hidden = self.interaction(hidden)
        return OutputData(self.intent_classifier(hidden), self.slot_classifier(hidden))

    def decode(self, output: OutputData, target: InputData = None):
        """decode output logits

        Args:
            output (OutputData): output logits data
            target (InputData, optional): input data with attention mask. Defaults to None.

        Returns:
            List: decoded sequence ids
        """
        return OutputData(self.intent_classifier.decode(output, target), self.slot_classifier.decode(output, target))

    def compute_loss(self, pred: OutputData, target: InputData, compute_intent_loss=True, compute_slot_loss=True):
        """compute loss.
        Notice: the intent and slot loss weights can be set by adding a 'weight' config item in the corresponding classifier configuration.

        Args:
            pred (OutputData): output logits data
            target (InputData): input golden data
            compute_intent_loss (bool, optional): whether to compute intent loss. Defaults to True.
            compute_slot_loss (bool, optional): whether to compute slot loss. Defaults to True.

        Returns:
            Tensor: loss result
        """
        intent_loss = self.intent_classifier.compute_loss(pred, target) if compute_intent_loss else None
        slot_loss = self.slot_classifier.compute_loss(pred, target) if compute_slot_loss else None
        slot_weight = self.slot_classifier.config.get("weight")
        slot_weight = slot_weight if slot_weight is not None else 1.
        intent_weight = self.intent_classifier.config.get("weight")
        intent_weight = intent_weight if intent_weight is not None else 1.
        loss = 0
        if intent_loss is not None:
            loss += intent_loss * intent_weight
        if slot_loss is not None:
            loss += slot_loss * slot_weight
        return loss, intent_loss, slot_loss


class StackPropagationDecoder(BaseDecoder):

    def forward(self, hidden: HiddenData):
        # hidden = self.interaction(hidden)
        pred_intent = self.intent_classifier(hidden)
        # embedding = pred_intent.output_embedding if pred_intent.output_embedding is not None else pred_intent.classifier_output
        # hidden.update_intent_hidden_state(torch.cat([hidden.get_slot_hidden_state(), embedding], dim=-1))
        hidden = self.interaction(pred_intent, hidden)
        pred_slot = self.slot_classifier(hidden)
        return OutputData(pred_intent, pred_slot)


class DCANetDecoder(BaseDecoder):

    def forward(self, hidden: HiddenData):
        if self.interaction is not None:
            hidden = self.interaction(hidden, intent_emb=self.intent_classifier, slot_emb=self.slot_classifier)
        return OutputData(self.intent_classifier(hidden), self.slot_classifier(hidden))
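To make the composition concrete, here is a sketch (not part of the commit) of wiring a BaseDecoder from two classifiers. The config keys follow the classifier code in the next file, but the exact values are hypothetical:

# Sketch: composing a decoder from two LinearClassifiers (hypothetical config values).
from model.decoder import BaseDecoder
from model.decoder.classifier import LinearClassifier

decoder = BaseDecoder(
    intent_classifier=LinearClassifier(mode="intent", use_intent=True,
                                       input_dim=256, intent_label_num=21,
                                       ignore_index=-100),
    slot_classifier=LinearClassifier(mode="slot", use_slot=True,
                                     input_dim=256, slot_label_num=120,
                                     ignore_index=-100),
)
# decoder(hidden) then returns OutputData(intent_logits, slot_logits)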
model/decoder/classifier.py
ADDED
@@ -0,0 +1,321 @@
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-31 20:07:00
Description:

'''
import random

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss

from model.decoder import decoder_utils

from torchcrf import CRF

from common.utils import HiddenData, OutputData, InputData, ClassifierOutputData, unpack_sequence, pack_sequence, \
    instantiate


class BaseClassifier(nn.Module):
    """Base class for all classifier modules.
    """
    def __init__(self, **config):
        super().__init__()
        self.config = config
        if config.get("loss_fn"):
            self.loss_fn = instantiate(config.get("loss_fn"))
        else:
            self.loss_fn = CrossEntropyLoss(ignore_index=self.config.get("ignore_index"))

    def forward(self, *args, **kwargs):
        raise NotImplementedError("No implemented classifier.")

    def decode(self, output: OutputData,
               target: InputData = None,
               return_list=True,
               return_sentence_level=None):
        """decode output logits

        Args:
            output (OutputData): output logits data
            target (InputData, optional): input data with attention mask. Defaults to None.
            return_list (bool, optional): if True, return a list; otherwise return a torch Tensor. Defaults to True.
            return_sentence_level (bool, optional): if True, decode sentence-level intent; otherwise decode token-level intent. Defaults to None.

        Returns:
            List or Tensor: decoded sequence ids
        """
        if self.config.get("return_sentence_level") is not None and return_sentence_level is None:
            return_sentence_level = self.config.get("return_sentence_level")
        elif self.config.get("return_sentence_level") is None and return_sentence_level is None:
            return_sentence_level = False
        return decoder_utils.decode(output, target,
                                    return_list=return_list,
                                    return_sentence_level=return_sentence_level,
                                    pred_type=self.config.get("mode"),
                                    use_multi=self.config.get("use_multi"),
                                    multi_threshold=self.config.get("multi_threshold"))

    def compute_loss(self, pred: OutputData, target: InputData):
        """compute loss

        Args:
            pred (OutputData): output logits data
            target (InputData): input golden data

        Returns:
            Tensor: loss result
        """
        _CRF = None
        if self.config.get("use_crf"):
            _CRF = self.CRF
        return decoder_utils.compute_loss(pred, target, criterion_type=self.config["mode"],
                                          use_crf=_CRF is not None,
                                          ignore_index=self.config["ignore_index"],
                                          use_multi=self.config.get("use_multi"),
                                          loss_fn=self.loss_fn,
                                          CRF=_CRF)


class LinearClassifier(BaseClassifier):
    """
    Decoder structure based on Linear.
    """
    def __init__(self, **config):
        """Construction function for LinearClassifier

        Args:
            config (dict):
                input_dim (int): hidden state dim.
                use_slot (bool): whether to classify slot labels.
                slot_label_num (int, optional): the number of slot labels. Enabled if use_slot is True.
                use_intent (bool): whether to classify intent labels.
                intent_label_num (int, optional): the number of intent labels. Enabled if use_intent is True.
                use_crf (bool): whether to use a CRF for slots.
        """
        super().__init__(**config)
        self.config = config
        if config.get("use_slot"):
            self.slot_classifier = nn.Linear(config["input_dim"], config["slot_label_num"])
            if self.config.get("use_crf"):
                self.CRF = CRF(num_tags=config["slot_label_num"], batch_first=True)
        if config.get("use_intent"):
            self.intent_classifier = nn.Linear(config["input_dim"], config["intent_label_num"])

    def forward(self, hidden: HiddenData):
        if self.config.get("use_intent"):
            return ClassifierOutputData(self.intent_classifier(hidden.get_intent_hidden_state()))
        if self.config.get("use_slot"):
            return ClassifierOutputData(self.slot_classifier(hidden.get_slot_hidden_state()))


class AutoregressiveLSTMClassifier(BaseClassifier):
    """
    Decoder structure based on unidirectional LSTM.
    """

    def __init__(self, **config):
        """ Construction function for Decoder.

        Args:
            config (dict):
                input_dim (int): input dimension of Decoder. In fact, it's the encoder hidden size.
                use_slot (bool): whether to classify slot labels.
                slot_label_num (int, optional): the number of slot labels. Enabled if use_slot is True.
                use_intent (bool): whether to classify intent labels.
                intent_label_num (int, optional): the number of intent labels. Enabled if use_intent is True.
                use_crf (bool): whether to use a CRF for slots.
                hidden_dim (int): hidden dimension of the iterative LSTM.
                embedding_dim (int): if it's not None, the input and output are relevant.
                dropout_rate (float): dropout rate of the network, which is only useful for embedding.
        """

        super(AutoregressiveLSTMClassifier, self).__init__(**config)
        if config.get("use_slot") and config.get("use_crf"):
            self.CRF = CRF(num_tags=config["slot_label_num"], batch_first=True)
        self.input_dim = config["input_dim"]
        self.hidden_dim = config["hidden_dim"]
        if config.get("use_intent"):
            self.output_dim = config["intent_label_num"]
        if config.get("use_slot"):
            self.output_dim = config["slot_label_num"]
        self.dropout_rate = config["dropout_rate"]
        self.embedding_dim = config.get("embedding_dim")
        self.force_ratio = config.get("force_ratio")
        self.config = config
        self.ignore_index = config.get("ignore_index") if config.get("ignore_index") is not None else -100
        # If embedding_dim is not None, the output and input
        # of this structure are relevant.
        if self.embedding_dim is not None:
            self.embedding_layer = nn.Embedding(self.output_dim, self.embedding_dim)
            self.init_tensor = nn.Parameter(
                torch.randn(1, self.embedding_dim),
                requires_grad=True
            )

        # Make sure of the input dimension of the iterative LSTM.
        if self.embedding_dim is not None:
            lstm_input_dim = self.input_dim + self.embedding_dim
        else:
            lstm_input_dim = self.input_dim

        # Network parameter definition.
        self.dropout_layer = nn.Dropout(self.dropout_rate)
        self.lstm_layer = nn.LSTM(
            input_size=lstm_input_dim,
            hidden_size=self.hidden_dim,
            batch_first=True,
            bidirectional=self.config["bidirectional"],
            dropout=self.dropout_rate,
            num_layers=self.config["layer_num"]
        )
        self.linear_layer = nn.Linear(
            self.hidden_dim,
            self.output_dim
        )
        # self.loss_fn = CrossEntropyLoss(ignore_index=self.ignore_index)

    def forward(self, hidden: HiddenData, internal_interaction=None, **interaction_args):
        """ Forward process for decoder.

        :param internal_interaction:
        :param hidden:
        :return: distribution of prediction labels.
        """
        input_tensor = hidden.slot_hidden
        seq_lens = hidden.inputs.attention_mask.sum(-1).detach().cpu().tolist()
        output_tensor_list, sent_start_pos = [], 0
        input_tensor = pack_sequence(input_tensor, seq_lens)
        forced_input = None
        if self.training:
            if random.random() < self.force_ratio:
                if self.config["mode"] == "slot":
                    forced_slot = pack_sequence(hidden.inputs.slot, seq_lens)
                    temp_slot = []
                    for index, x in enumerate(forced_slot):
                        if index == 0:
                            temp_slot.append(x.reshape(1))
                        elif x == self.ignore_index:
                            temp_slot.append(temp_slot[-1])
                        else:
                            temp_slot.append(x.reshape(1))
                    forced_input = torch.cat(temp_slot, 0)
                if self.config["mode"] == "token-level-intent":
                    forced_intent = hidden.inputs.intent.unsqueeze(1).repeat(1, hidden.inputs.slot.shape[1])
                    forced_input = pack_sequence(forced_intent, seq_lens)
        if self.embedding_dim is None or forced_input is not None:

            for sent_i in range(0, len(seq_lens)):
                sent_end_pos = sent_start_pos + seq_lens[sent_i]

                # Segment input hidden tensors.
                seg_hiddens = input_tensor[sent_start_pos: sent_end_pos, :]

                if self.embedding_dim is not None and forced_input is not None:
                    if seq_lens[sent_i] > 1:
                        seg_forced_input = forced_input[sent_start_pos: sent_end_pos]

                        seg_forced_tensor = self.embedding_layer(seg_forced_input)[:-1]
                        seg_prev_tensor = torch.cat([self.init_tensor, seg_forced_tensor], dim=0)
                    else:
                        seg_prev_tensor = self.init_tensor

                    # Concatenate forced target tensor.
                    combined_input = torch.cat([seg_hiddens, seg_prev_tensor], dim=1)
                else:
                    combined_input = seg_hiddens
                dropout_input = self.dropout_layer(combined_input)
                lstm_out, _ = self.lstm_layer(dropout_input.view(1, seq_lens[sent_i], -1))
                if internal_interaction is not None:
                    interaction_args["sent_id"] = sent_i
                    lstm_out = internal_interaction(torch.transpose(lstm_out, 0, 1), **interaction_args)[:, 0]
                linear_out = self.linear_layer(lstm_out.view(seq_lens[sent_i], -1))

                output_tensor_list.append(linear_out)
                sent_start_pos = sent_end_pos
        else:
            for sent_i in range(0, len(seq_lens)):
                prev_tensor = self.init_tensor

                # It's necessary to remember the h and c states
                # when outputting a prediction at every single step.
                last_h, last_c = None, None

                sent_end_pos = sent_start_pos + seq_lens[sent_i]
                for word_i in range(sent_start_pos, sent_end_pos):
                    seg_input = input_tensor[[word_i], :]
                    combined_input = torch.cat([seg_input, prev_tensor], dim=1)
                    dropout_input = self.dropout_layer(combined_input).view(1, 1, -1)
                    if last_h is None and last_c is None:
                        lstm_out, (last_h, last_c) = self.lstm_layer(dropout_input)
                    else:
                        lstm_out, (last_h, last_c) = self.lstm_layer(dropout_input, (last_h, last_c))

                    if internal_interaction is not None:
                        interaction_args["sent_id"] = sent_i
                        lstm_out = internal_interaction(lstm_out, **interaction_args)[:, 0]

                    lstm_out = self.linear_layer(lstm_out.view(1, -1))
                    output_tensor_list.append(lstm_out)

                    _, index = lstm_out.topk(1, dim=1)
                    prev_tensor = self.embedding_layer(index).view(1, -1)
                sent_start_pos = sent_end_pos
        seq_unpacked = unpack_sequence(torch.cat(output_tensor_list, dim=0), seq_lens)
        # TODO: support softmax for all modes
        if self.config.get("use_multi"):
            pred_output = ClassifierOutputData(seq_unpacked)
        else:
            pred_output = ClassifierOutputData(F.log_softmax(seq_unpacked, dim=-1))
        return pred_output


class MLPClassifier(BaseClassifier):
    """
    Decoder structure based on MLP.
    """
    def __init__(self, **config):
        """ Construction function for Decoder.

        Args:
            config (dict):
                use_slot (bool): whether to classify slot labels.
                use_intent (bool): whether to classify intent labels.
                mlp (List):

                    - _model_target_: torch.nn.Linear

                      in_features (int): input feature dim

                      out_features (int): output feature dim

                    - _model_target_: torch.nn.LeakyReLU

                      negative_slope: 0.2

                    - ...
        """
        super(MLPClassifier, self).__init__(**config)
        self.config = config
        for i, x in enumerate(config["mlp"]):
            if isinstance(x.get("in_features"), str):
                config["mlp"][i]["in_features"] = self.config[x["in_features"][1:-1]]
            if isinstance(x.get("out_features"), str):
                config["mlp"][i]["out_features"] = self.config[x["out_features"][1:-1]]
        mlp = [instantiate(x) for x in config["mlp"]]
        self.seq = nn.Sequential(*mlp)


    def forward(self, hidden: HiddenData):
|
317 |
+
if self.config.get("use_intent"):
|
318 |
+
res = self.seq(hidden.intent_hidden)
|
319 |
+
else:
|
320 |
+
res = self.seq(hidden.slot_hidden)
|
321 |
+
return ClassifierOutputData(res)
|
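A note on the decoding loop above: it relies on pack_sequence / unpack_sequence from common.utils, which this commit view does not show. A minimal sketch of the convention the loop assumes (all valid, unpadded timesteps concatenated into one flat tensor, then split and re-padded) — the two helpers here are illustrative stand-ins, not the repository's actual implementations:

import torch

def pack_sequence_sketch(padded, seq_lens):
    # Concatenate only the valid (unpadded) timesteps of each sentence.
    return torch.cat([padded[i, :seq_lens[i]] for i in range(len(seq_lens))], dim=0)

def unpack_sequence_sketch(flat, seq_lens, padding_value=0.0):
    # Split the flat tensor back per sentence and re-pad to the max length.
    chunks = torch.split(flat, seq_lens, dim=0)
    return torch.nn.utils.rnn.pad_sequence(chunks, batch_first=True, padding_value=padding_value)

padded = torch.randn(2, 4, 8)    # batch of 2, max length 4, hidden size 8
seq_lens = [4, 2]                # the second sentence has only 2 real tokens
flat = pack_sequence_sketch(padded, seq_lens)
assert flat.shape == (6, 8)      # 4 + 2 valid timesteps
restored = unpack_sequence_sketch(flat, seq_lens)
assert restored.shape == (2, 4, 8)

This flat layout is why the forward pass tracks sent_start_pos / sent_end_pos cumulatively when it slices per-sentence segments.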
model/decoder/decoder_utils.py
ADDED
@@ -0,0 +1,155 @@
from typing import List

import torch
from torch import Tensor

from common import utils
from common.utils import OutputData, InputData


def argmax_for_seq_len(inputs, seq_lens, padding_value=-100):
    packed_inputs = utils.pack_sequence(inputs, seq_lens)
    outputs = torch.argmax(packed_inputs, dim=-1, keepdim=True)
    return utils.unpack_sequence(outputs, seq_lens, padding_value).squeeze(-1)


def decode(output: OutputData,
           target: InputData = None,
           pred_type="slot",
           multi_threshold=0.5,
           ignore_index=-100,
           return_list=True,
           return_sentence_level=True,
           use_multi=False,
           use_crf=False,
           CRF=None) -> List or Tensor:
    """ decode output logits

    Args:
        output (OutputData): output logits data
        target (InputData, optional): input data with attention mask. Defaults to None.
        pred_type (str, optional): prediction type in ["slot", "intent", "token-level-intent"]. Defaults to "slot".
        multi_threshold (float, optional): multi-intent decode threshold. Defaults to 0.5.
        ignore_index (int, optional): align and pad tokens with this ignore index. Defaults to -100.
        return_list (bool, optional): if True return a list, else return a torch Tensor. Defaults to True.
        return_sentence_level (bool, optional): if True decode sentence-level intent, else decode token-level intent. Defaults to True.
        use_multi (bool, optional): whether to decode multiple intents. Defaults to False.
        use_crf (bool, optional): whether to use CRF. Defaults to False.
        CRF (CRF, optional): CRF function. Defaults to None.

    Returns:
        List or Tensor: decoded sequence ids
    """
    if pred_type == "slot":
        inputs = output.slot_ids
    else:
        inputs = output.intent_ids

    if pred_type == "slot":
        if not use_multi:
            if use_crf:
                res = CRF.decode(inputs, mask=target.attention_mask)
            else:
                res = torch.argmax(inputs, dim=-1)
        else:
            raise NotImplementedError("Multi-slot prediction is not supported.")
    elif pred_type == "intent":
        if not use_multi:
            res = torch.argmax(inputs, dim=-1)
        else:
            res = (torch.sigmoid(inputs) > multi_threshold).nonzero()
            if return_list:
                res_index = res.detach().cpu().tolist()
                res_list = [[] for _ in range(len(target.seq_lens))]
                for item in res_index:
                    res_list[item[0]].append(item[1])
                return res_list
            else:
                return res
    elif pred_type == "token-level-intent":
        if not use_multi:
            res = torch.argmax(inputs, dim=-1)
            if not return_sentence_level:
                return res
            if return_list:
                res = res.detach().cpu().tolist()
            attention_mask = target.attention_mask
            for i in range(attention_mask.shape[0]):
                temp = []
                for j in range(attention_mask.shape[1]):
                    if attention_mask[i][j] == 1:
                        temp.append(res[i][j])
                    else:
                        break
                res[i] = temp
            return [max(it, key=lambda v: it.count(v)) for it in res]
        else:
            seq_lens = target.seq_lens

            if not return_sentence_level:
                token_res = torch.cat([
                    torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold
                    for i in range(len(seq_lens))],
                    dim=0)
                return utils.unpack_sequence(token_res, seq_lens, padding_value=ignore_index)

            intent_index_sum = torch.cat([
                torch.sum(torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold, dim=0).unsqueeze(0)
                for i in range(len(seq_lens))],
                dim=0)

            res = (intent_index_sum > torch.div(seq_lens, 2, rounding_mode='floor').unsqueeze(1)).nonzero()
            if return_list:
                res_index = res.detach().cpu().tolist()
                res_list = [[] for _ in range(len(seq_lens))]
                for item in res_index:
                    res_list[item[0]].append(item[1])
                return res_list
            else:
                return res
    else:
        raise NotImplementedError("Prediction modes other than ['slot', 'intent', 'token-level-intent'] are not supported.")
    if return_list:
        res = res.detach().cpu().tolist()
    return res


def compute_loss(pred: OutputData,
                 target: InputData,
                 criterion_type="slot",
                 use_crf=False,
                 ignore_index=-100,
                 loss_fn=None,
                 use_multi=False,
                 CRF=None):
    """ compute loss

    Args:
        pred (OutputData): output logits data
        target (InputData): input golden data
        criterion_type (str, optional): criterion type in ["slot", "intent", "token-level-intent"]. Defaults to "slot".
        ignore_index (int, optional): compute loss with this ignore index. Defaults to -100.
        loss_fn (optional): loss function. Defaults to None.
        use_crf (bool, optional): whether to use CRF. Defaults to False.
        CRF (CRF, optional): CRF function. Defaults to None.

    Returns:
        Tensor: loss result
    """
    if criterion_type == "slot":
        if use_crf:
            return -1 * CRF(pred.slot_ids, target.slot, target.get_slot_mask(ignore_index).byte())
        else:
            pred_slot = utils.pack_sequence(pred.slot_ids, target.seq_lens)
            target_slot = utils.pack_sequence(target.slot, target.seq_lens)
            return loss_fn(pred_slot, target_slot)
    elif criterion_type == "token-level-intent":
        # TODO: two decode functions
        intent_target = target.intent.unsqueeze(1)
        if not use_multi:
            intent_target = intent_target.repeat(1, pred.intent_ids.shape[1])
        else:
            intent_target = intent_target.repeat(1, pred.intent_ids.shape[1], 1)
        intent_pred = utils.pack_sequence(pred.intent_ids, target.seq_lens)
        intent_target = utils.pack_sequence(intent_target, target.seq_lens)
        return loss_fn(intent_pred, intent_target)
    else:
        return loss_fn(pred.intent_ids, target.intent)
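For pred_type="token-level-intent" in the single-intent branch, decode() reduces per-token predictions to one sentence label by majority vote (max(it, key=lambda v: it.count(v))). A toy illustration of that reduction, independent of the tensors above:

# Per-token intent ids for two sentences (already truncated by attention_mask).
token_level_preds = [[3, 3, 1, 3], [0, 2, 2]]

# The same voting rule as in decode(): the most frequent token-level id wins.
sentence_level = [max(it, key=lambda v: it.count(v)) for it in token_level_preds]
assert sentence_level == [3, 2]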
model/decoder/gl_gin_decoder.py
ADDED
@@ -0,0 +1,47 @@
import torch
import torch.nn.functional as F
from torch import nn

from common.utils import HiddenData, OutputData, InputData
from model.decoder import BaseDecoder
from model.decoder.interaction.gl_gin_interaction import LSTMEncoder


class IntentEncoder(nn.Module):
    def __init__(self, input_dim, dropout_rate):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.__intent_lstm = LSTMEncoder(
            input_dim,
            input_dim,
            dropout_rate
        )

    def forward(self, g_hiddens, seq_lens):
        intent_lstm_out = self.__intent_lstm(g_hiddens, seq_lens)
        return F.dropout(intent_lstm_out, p=self.dropout_rate, training=self.training)


class GLGINDecoder(BaseDecoder):
    def __init__(self, intent_classifier, slot_classifier, interaction=None, **config):
        super().__init__(intent_classifier, slot_classifier, interaction)
        self.config = config
        self.intent_encoder = IntentEncoder(self.intent_classifier.config["input_dim"], self.config["dropout_rate"])

    def forward(self, hidden: HiddenData, forced_slot=None, forced_intent=None, differentiable=None):
        seq_lens = hidden.inputs.attention_mask.sum(-1)
        intent_lstm_out = self.intent_encoder(hidden.slot_hidden, seq_lens)
        hidden.update_intent_hidden_state(intent_lstm_out)
        pred_intent = self.intent_classifier(hidden)
        intent_index = self.intent_classifier.decode(OutputData(pred_intent, None), hidden.inputs,
                                                     return_list=False,
                                                     return_sentence_level=True)
        slot_hidden = self.interaction(
            hidden,
            pred_intent=pred_intent,
            intent_index=intent_index,
        )
        pred_slot = self.slot_classifier(slot_hidden)
        num_intent = self.intent_classifier.config["intent_label_num"]
        pred_slot = pred_slot.classifier_output[:, num_intent:]
        return OutputData(pred_intent, F.log_softmax(pred_slot, dim=1))
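The slice pred_slot.classifier_output[:, num_intent:] reflects the node ordering used by the GL-GIN global graph (see GLGINInteraction further down in this commit): the first intent_label_num positions are intent nodes, followed by one node per token, so only the trailing positions carry slot logits. A toy index check under assumed sizes:

import torch

num_intent, seq_len, label_dim = 5, 7, 12
# Global-graph output: intent nodes first, then token nodes.
graph_out = torch.randn(2, num_intent + seq_len, label_dim)

token_part = graph_out[:, num_intent:]   # the same slice as in GLGINDecoder.forward
assert token_part.shape == (2, seq_len, label_dim)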
model/decoder/interaction/__init__.py
ADDED
@@ -0,0 +1,10 @@
from model.decoder.interaction.agif_interaction import AGIFInteraction
from model.decoder.interaction.base_interaction import BaseInteraction
from model.decoder.interaction.bi_model_interaction import BiModelInteraction, BiModelWithoutDecoderInteraction
from model.decoder.interaction.dca_net_interaction import DCANetInteraction
from model.decoder.interaction.gl_gin_interaction import GLGINInteraction
from model.decoder.interaction.slot_gated_interaction import SlotGatedInteraction
from model.decoder.interaction.stack_interaction import StackInteraction

__all__ = ["BaseInteraction", "BiModelInteraction", "BiModelWithoutDecoderInteraction", "DCANetInteraction",
           "StackInteraction", "SlotGatedInteraction", "AGIFInteraction", "GLGINInteraction"]
model/decoder/interaction/agif_interaction.py
ADDED
@@ -0,0 +1,132 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.decoder.interaction.base_interaction import BaseInteraction


class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        h = torch.matmul(input, self.W)
        B, N = h.size()[0], h.size()[1]

        a_input = torch.cat([h.repeat(1, 1, N).view(B, N * N, -1), h.repeat(1, N, 1)], dim=2).view(B, N, -1,
                                                                                                   2 * self.out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(3))

        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=2)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime


class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads, nlayers=2):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout
        self.nlayers = nlayers
        self.nheads = nheads
        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in
                           range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)
        if self.nlayers > 2:
            for i in range(self.nlayers - 2):
                for j in range(self.nheads):
                    self.add_module('attention_{}_{}'.format(i + 1, j),
                                    GraphAttentionLayer(nhid * nheads, nhid, dropout=dropout, alpha=alpha, concat=True))

        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        input = x
        x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
        if self.nlayers > 2:
            for i in range(self.nlayers - 2):
                temp = []
                x = F.dropout(x, self.dropout, training=self.training)
                cur_input = x
                for j in range(self.nheads):
                    temp.append(self.__getattr__('attention_{}_{}'.format(i + 1, j))(x, adj))
                x = torch.cat(temp, dim=2) + cur_input
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        return x + input


def normalize_adj(mx):
    """
    Row-normalize matrix D^{-1}A
    torch.diag_embed: https://github.com/pytorch/pytorch/pull/12447
    """
    mx = mx.float()
    rowsum = mx.sum(2)
    r_inv = torch.pow(rowsum, -1)
    r_inv[torch.isinf(r_inv)] = 0.
    r_mat_inv = torch.diag_embed(r_inv, 0)
    mx = r_mat_inv.matmul(mx)
    return mx


class AGIFInteraction(BaseInteraction):
    def __init__(self, **config):
        super().__init__(**config)
        self.intent_embedding = nn.Parameter(
            torch.FloatTensor(self.config["intent_label_num"], self.config["intent_embedding_dim"]))  # 191, 32
        nn.init.normal_(self.intent_embedding.data)
        self.adj = None
        self.graph = GAT(
            config["output_dim"],
            config["hidden_dim"],
            config["output_dim"],
            config["dropout_rate"],
            config["alpha"],
            config["num_heads"],
            config["num_layers"])

    def generate_adj_gat(self, index, batch, intent_label_num):
        intent_idx_ = [[torch.tensor(0)] for i in range(batch)]
        for item in index:
            intent_idx_[item[0]].append(item[1] + 1)
        intent_idx = intent_idx_
        self.adj = torch.cat([torch.eye(intent_label_num + 1).unsqueeze(0) for i in range(batch)])
        for i in range(batch):
            for j in intent_idx[i]:
                self.adj[i, j, intent_idx[i]] = 1.
        if self.config["row_normalized"]:
            self.adj = normalize_adj(self.adj)
        self.adj = self.adj.to(self.intent_embedding.device)

    def forward(self, encode_hidden, **interaction_args):
        if self.adj is None or interaction_args["sent_id"] == 0:
            self.generate_adj_gat(interaction_args["intent_index"], interaction_args["batch_size"], interaction_args["intent_label_num"])
        lstm_out = torch.cat((encode_hidden,
                              self.intent_embedding.unsqueeze(0).repeat(encode_hidden.shape[0], 1, 1)), dim=1)
        return self.graph(lstm_out, self.adj[interaction_args["sent_id"]])
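A small worked example of generate_adj_gat above: each sentence keeps node 0 as a shared root, every predicted intent id is shifted by +1, and the selected nodes are fully connected on top of an identity matrix. Reproducing the logic on a toy prediction (two decoded intents for one sentence; plain ints stand in for the tensors):

import torch

intent_label_num = 3
batch = 1
index = [(0, 0), (0, 2)]                  # sentence 0 predicted intents 0 and 2

intent_idx = [[0] for _ in range(batch)]  # node 0 acts as a shared root node
for sent_id, intent_id in index:
    intent_idx[sent_id].append(intent_id + 1)

adj = torch.eye(intent_label_num + 1).unsqueeze(0).repeat(batch, 1, 1)
for i in range(batch):
    for j in intent_idx[i]:
        adj[i, j, intent_idx[i]] = 1.     # fully connect root + predicted intents

# Nodes 0, 1 and 3 form a clique; node 2 keeps only its self-loop.
assert adj[0, 2].tolist() == [0.0, 0.0, 1.0, 0.0]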
model/decoder/interaction/base_interaction.py
ADDED
@@ -0,0 +1,9 @@
from torch import nn


class BaseInteraction(nn.Module):
    def __init__(self, **config):
        super().__init__()
        self.config = config

    def forward(self, hidden1, hidden2):
        raise NotImplementedError("forward() is not implemented")
model/decoder/interaction/bi_model_interaction.py
ADDED
@@ -0,0 +1,74 @@
import torch
import torch.nn.functional as F
from torch import nn

from common.utils import HiddenData
from model.decoder.interaction.base_interaction import BaseInteraction


class BiModelInteraction(BaseInteraction):
    def __init__(self, **config):
        super().__init__(**config)
        self.intent_lstm = nn.LSTM(input_size=self.config["input_dim"], hidden_size=self.config["output_dim"],
                                   batch_first=True,
                                   num_layers=1)
        self.slot_lstm = nn.LSTM(input_size=self.config["input_dim"] + self.config["output_dim"],
                                 hidden_size=self.config["output_dim"], num_layers=1)

    def forward(self, encode_hidden: HiddenData, **kwargs):
        slot_hidden = encode_hidden.get_slot_hidden_state()
        intent_hidden_detached = encode_hidden.get_intent_hidden_state().clone().detach()
        seq_lens = encode_hidden.inputs.attention_mask.sum(-1)
        batch = slot_hidden.size(0)
        length = slot_hidden.size(1)
        dec_init_out = torch.zeros(batch, 1, self.config["output_dim"]).to(slot_hidden.device)
        hidden_state = (torch.zeros(1, 1, self.config["output_dim"]).to(slot_hidden.device),
                        torch.zeros(1, 1, self.config["output_dim"]).to(slot_hidden.device))
        slot_hidden = torch.cat((slot_hidden, intent_hidden_detached), dim=-1).transpose(1, 0)  # 50 x batch x feature_size
        slot_drop = F.dropout(slot_hidden, self.config["dropout_rate"])
        all_out = []
        for i in range(length):
            if i == 0:
                out, hidden_state = self.slot_lstm(torch.cat((slot_drop[i].unsqueeze(1), dec_init_out), dim=-1),
                                                   hidden_state)
            else:
                out, hidden_state = self.slot_lstm(torch.cat((slot_drop[i].unsqueeze(1), out), dim=-1), hidden_state)
            all_out.append(out)
        slot_output = torch.cat(all_out, dim=1)  # batch x 50 x feature_size

        intent_hidden = torch.cat((encode_hidden.get_intent_hidden_state(),
                                   encode_hidden.get_slot_hidden_state().clone().detach()),
                                  dim=-1)
        intent_drop = F.dropout(intent_hidden, self.config["dropout_rate"])
        intent_lstm_output, _ = self.intent_lstm(intent_drop)
        intent_output = F.dropout(intent_lstm_output, self.config["dropout_rate"])
        output_list = []
        for index, slen in enumerate(seq_lens):
            output_list.append(intent_output[index, slen - 1, :].unsqueeze(0))

        encode_hidden.update_intent_hidden_state(torch.cat(output_list, dim=0))
        encode_hidden.update_slot_hidden_state(slot_output)

        return encode_hidden


class BiModelWithoutDecoderInteraction(BaseInteraction):
    def forward(self, encode_hidden: HiddenData, **kwargs):
        slot_hidden = encode_hidden.get_slot_hidden_state()
        intent_hidden_detached = encode_hidden.get_intent_hidden_state().clone().detach()
        seq_lens = encode_hidden.inputs.attention_mask.sum(-1)
        slot_hidden = torch.cat((slot_hidden, intent_hidden_detached), dim=-1)  # 50 x batch x feature_size
        slot_output = F.dropout(slot_hidden, self.config["dropout_rate"])

        intent_hidden = torch.cat((encode_hidden.get_intent_hidden_state(),
                                   encode_hidden.get_slot_hidden_state().clone().detach()),
                                  dim=-1)
        intent_output = F.dropout(intent_hidden, self.config["dropout_rate"])
        output_list = []
        for index, slen in enumerate(seq_lens):
            output_list.append(intent_output[index, slen - 1, :].unsqueeze(0))

        encode_hidden.update_intent_hidden_state(torch.cat(output_list, dim=0))
        encode_hidden.update_slot_hidden_state(slot_output)

        return encode_hidden
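Both Bi-Model variants share hidden states across the two tasks through .clone().detach(), so each task reads the other's representation without back-propagating into it. A minimal demonstration of the effect:

import torch

x = torch.randn(3, requires_grad=True)
h = x * 2.0

shared = h.clone().detach()   # same values, but cut out of the autograd graph
loss = (h + shared).sum()
loss.backward()

# The gradient flows only through h, not through the detached copy.
assert torch.allclose(x.grad, torch.full_like(x, 2.0))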
model/decoder/interaction/dca_net_interaction.py
ADDED
@@ -0,0 +1,176 @@
import math

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import LayerNorm

from common.utils import HiddenData
from model.decoder.interaction import BaseInteraction


class DCANetInteraction(BaseInteraction):
    def __init__(self, **config):
        super().__init__(**config)
        self.I_S_Emb = Label_Attention()
        self.T_block1 = I_S_Block(self.config["input_dim"], self.config["attention_dropout"], self.config["num_attention_heads"])
        self.T_block2 = I_S_Block(self.config["input_dim"], self.config["attention_dropout"], self.config["num_attention_heads"])

    def forward(self, encode_hidden: HiddenData, **kwargs):
        mask = encode_hidden.inputs.attention_mask
        H = encode_hidden.slot_hidden
        H_I, H_S = self.I_S_Emb(H, H, kwargs["intent_emb"], kwargs["slot_emb"])
        H_I, H_S = self.T_block1(H_I + H, H_S + H, mask)
        H_I_1, H_S_1 = self.I_S_Emb(H_I, H_S, kwargs["intent_emb"], kwargs["slot_emb"])
        H_I, H_S = self.T_block2(H_I + H_I_1, H_S + H_S_1, mask)
        encode_hidden.update_intent_hidden_state(F.max_pool1d((H_I + H).transpose(1, 2), H_I.size(1)).squeeze(2))
        encode_hidden.update_slot_hidden_state(H_S + H)
        return encode_hidden


class Label_Attention(nn.Module):
    def __init__(self):
        super(Label_Attention, self).__init__()

    def forward(self, input_intent, input_slot, intent_emb, slot_emb):
        self.W_intent_emb = intent_emb.intent_classifier.weight
        self.W_slot_emb = slot_emb.slot_classifier.weight
        intent_score = torch.matmul(input_intent, self.W_intent_emb.t())
        slot_score = torch.matmul(input_slot, self.W_slot_emb.t())
        intent_probs = nn.Softmax(dim=-1)(intent_score)
        slot_probs = nn.Softmax(dim=-1)(slot_score)
        intent_res = torch.matmul(intent_probs, self.W_intent_emb)
        slot_res = torch.matmul(slot_probs, self.W_slot_emb)

        return intent_res, slot_res


class I_S_Block(nn.Module):
    def __init__(self, hidden_size, attention_dropout, num_attention_heads):
        super(I_S_Block, self).__init__()
        self.I_S_Attention = I_S_SelfAttention(hidden_size, 2 * hidden_size, hidden_size, attention_dropout, num_attention_heads)
        self.I_Out = SelfOutput(hidden_size, attention_dropout)
        self.S_Out = SelfOutput(hidden_size, attention_dropout)
        self.I_S_Feed_forward = Intermediate_I_S(hidden_size, hidden_size, attention_dropout)

    def forward(self, H_intent_input, H_slot_input, mask):
        H_slot, H_intent = self.I_S_Attention(H_intent_input, H_slot_input, mask)
        H_slot = self.S_Out(H_slot, H_slot_input)
        H_intent = self.I_Out(H_intent, H_intent_input)
        H_intent, H_slot = self.I_S_Feed_forward(H_intent, H_slot)

        return H_intent, H_slot


class I_S_SelfAttention(nn.Module):
    def __init__(self, input_size, hidden_size, out_size, attention_dropout, num_attention_heads):
        super(I_S_SelfAttention, self).__init__()

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / self.num_attention_heads)

        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.out_size = out_size
        self.query = nn.Linear(input_size, self.all_head_size)
        self.query_slot = nn.Linear(input_size, self.all_head_size)
        self.key = nn.Linear(input_size, self.all_head_size)
        self.key_slot = nn.Linear(input_size, self.all_head_size)
        self.value = nn.Linear(input_size, self.out_size)
        self.value_slot = nn.Linear(input_size, self.out_size)
        self.dropout = nn.Dropout(attention_dropout)

    def transpose_for_scores(self, x):
        last_dim = int(x.size()[-1] / self.num_attention_heads)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, last_dim)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, intent, slot, mask):
        extended_attention_mask = mask.unsqueeze(1).unsqueeze(2)

        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        attention_mask = (1.0 - extended_attention_mask) * -10000.0

        mixed_query_layer = self.query(intent)
        mixed_key_layer = self.key(slot)
        mixed_value_layer = self.value(slot)

        mixed_query_layer_slot = self.query_slot(slot)
        mixed_key_layer_slot = self.key_slot(intent)
        mixed_value_layer_slot = self.value_slot(intent)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        query_layer_slot = self.transpose_for_scores(mixed_query_layer_slot)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        key_layer_slot = self.transpose_for_scores(mixed_key_layer_slot)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        value_layer_slot = self.transpose_for_scores(mixed_value_layer_slot)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # attention_scores_slot = torch.matmul(query_slot, key_slot.transpose(1,0))
        attention_scores_slot = torch.matmul(query_layer_slot, key_layer_slot.transpose(-1, -2))
        attention_scores_slot = attention_scores_slot / math.sqrt(self.attention_head_size)
        attention_scores_intent = attention_scores + attention_mask

        attention_scores_slot = attention_scores_slot + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs_slot = nn.Softmax(dim=-1)(attention_scores_slot)
        attention_probs_intent = nn.Softmax(dim=-1)(attention_scores_intent)

        attention_probs_slot = self.dropout(attention_probs_slot)
        attention_probs_intent = self.dropout(attention_probs_intent)

        context_layer_slot = torch.matmul(attention_probs_slot, value_layer_slot)
        context_layer_intent = torch.matmul(attention_probs_intent, value_layer)

        context_layer = context_layer_slot.permute(0, 2, 1, 3).contiguous()
        context_layer_intent = context_layer_intent.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.out_size,)
        new_context_layer_shape_intent = context_layer_intent.size()[:-2] + (self.out_size,)

        context_layer = context_layer.view(*new_context_layer_shape)
        context_layer_intent = context_layer_intent.view(*new_context_layer_shape_intent)
        return context_layer, context_layer_intent


class SelfOutput(nn.Module):
    def __init__(self, hidden_size, hidden_dropout_prob):
        super(SelfOutput, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class Intermediate_I_S(nn.Module):
    def __init__(self, intermediate_size, hidden_size, attention_dropout):
        super(Intermediate_I_S, self).__init__()
        self.dense_in = nn.Linear(hidden_size * 6, intermediate_size)
        self.intermediate_act_fn = nn.ReLU()
        self.dense_out = nn.Linear(intermediate_size, hidden_size)
        self.LayerNorm_I = LayerNorm(hidden_size, eps=1e-12)
        self.LayerNorm_S = LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, hidden_states_I, hidden_states_S):
        hidden_states_in = torch.cat([hidden_states_I, hidden_states_S], dim=2)
        batch_size, max_length, hidden_size = hidden_states_in.size()
        h_pad = torch.zeros(batch_size, 1, hidden_size).to(hidden_states_I.device)
        h_left = torch.cat([h_pad, hidden_states_in[:, :max_length - 1, :]], dim=1)
        h_right = torch.cat([hidden_states_in[:, 1:, :], h_pad], dim=1)
        hidden_states_in = torch.cat([hidden_states_in, h_left, h_right], dim=2)

        hidden_states = self.dense_in(hidden_states_in)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.dense_out(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states_I_NEW = self.LayerNorm_I(hidden_states + hidden_states_I)
        hidden_states_S_NEW = self.LayerNorm_S(hidden_states + hidden_states_S)
        return hidden_states_I_NEW, hidden_states_S_NEW
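Intermediate_I_S.dense_in takes hidden_size * 6 input features because the intent and slot states are first concatenated (2h) and that pair is then concatenated with its left- and right-shifted copies (3 × 2h). A quick dimension check under assumed sizes:

import torch

batch, length, h = 2, 5, 8
H_I = torch.randn(batch, length, h)
H_S = torch.randn(batch, length, h)

pair = torch.cat([H_I, H_S], dim=2)                      # (batch, length, 2h)
pad = torch.zeros(batch, 1, 2 * h)
left = torch.cat([pad, pair[:, :length - 1, :]], dim=1)  # left-neighbor view
right = torch.cat([pair[:, 1:, :], pad], dim=1)          # right-neighbor view
windowed = torch.cat([pair, left, right], dim=2)
assert windowed.shape == (batch, length, 6 * h)          # matches dense_in's input dim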
model/decoder/interaction/gl_gin_interaction.py
ADDED
@@ -0,0 +1,227 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from common.utils import HiddenData, ClassifierOutputData
from model.decoder.interaction import BaseInteraction


class LSTMEncoder(nn.Module):
    """
    Encoder structure based on bidirectional LSTM.
    """

    def __init__(self, embedding_dim, hidden_dim, dropout_rate):
        super(LSTMEncoder, self).__init__()

        # Parameter recording.
        self.__embedding_dim = embedding_dim
        self.__hidden_dim = hidden_dim // 2
        self.__dropout_rate = dropout_rate

        # Network attributes.
        self.__dropout_layer = nn.Dropout(self.__dropout_rate)
        self.__lstm_layer = nn.LSTM(
            input_size=self.__embedding_dim,
            hidden_size=self.__hidden_dim,
            batch_first=True,
            bidirectional=True,
            dropout=self.__dropout_rate,
            num_layers=1
        )

    def forward(self, embedded_text, seq_lens):
        """ Forward process for LSTM Encoder.

        (batch_size, max_sent_len)
        -> (batch_size, max_sent_len, word_dim)
        -> (batch_size, max_sent_len, hidden_dim)

        :param embedded_text: padded and embedded input text.
        :param seq_lens: lengths of the original input text.
        :return: encoded word hidden vectors.
        """

        # Padded_text should be an instance of LongTensor.
        dropout_text = self.__dropout_layer(embedded_text)

        # Pack and pad process for input of variable length.
        packed_text = pack_padded_sequence(dropout_text, seq_lens.cpu(), batch_first=True, enforce_sorted=False)
        lstm_hiddens, (h_last, c_last) = self.__lstm_layer(packed_text)
        padded_hiddens, _ = pad_packed_sequence(lstm_hiddens, batch_first=True)

        return padded_hiddens


class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        h = torch.matmul(input, self.W)
        B, N = h.size()[0], h.size()[1]

        a_input = torch.cat([h.repeat(1, 1, N).view(B, N * N, -1), h.repeat(1, N, 1)], dim=2).view(B, N, -1,
                                                                                                   2 * self.out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(3))

        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=2)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime


class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads, nlayers=2):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout
        self.nlayers = nlayers
        self.nheads = nheads
        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in
                           range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)
        if self.nlayers > 2:
            for i in range(self.nlayers - 2):
                for j in range(self.nheads):
                    self.add_module('attention_{}_{}'.format(i + 1, j),
                                    GraphAttentionLayer(nhid * nheads, nhid, dropout=dropout, alpha=alpha, concat=True))

        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        input = x
        x = torch.cat([att(x, adj) for att in self.attentions], dim=2)
        if self.nlayers > 2:
            for i in range(self.nlayers - 2):
                temp = []
                x = F.dropout(x, self.dropout, training=self.training)
                cur_input = x
                for j in range(self.nheads):
                    temp.append(self.__getattr__('attention_{}_{}'.format(i + 1, j))(x, adj))
                x = torch.cat(temp, dim=2) + cur_input
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        return x + input


def normalize_adj(mx):
    """
    Row-normalize matrix D^{-1}A
    torch.diag_embed: https://github.com/pytorch/pytorch/pull/12447
    """
    mx = mx.float()
    rowsum = mx.sum(2)
    r_inv = torch.pow(rowsum, -1)
    r_inv[torch.isinf(r_inv)] = 0.
    r_mat_inv = torch.diag_embed(r_inv, 0)
    mx = r_mat_inv.matmul(mx)
    return mx


class GLGINInteraction(BaseInteraction):
    def __init__(self, **config):
        super().__init__(**config)
        self.intent_embedding = nn.Parameter(
            torch.FloatTensor(self.config["intent_label_num"], self.config["intent_embedding_dim"]))  # 191, 32
        nn.init.normal_(self.intent_embedding.data)
        self.adj = None
        self.__slot_lstm = LSTMEncoder(
            self.config["input_dim"] + self.config["intent_label_num"],
            config["output_dim"],
            config["dropout_rate"]
        )
        self.__slot_graph = GAT(
            config["output_dim"],
            config["hidden_dim"],
            config["output_dim"],
            config["dropout_rate"],
            config["alpha"],
            config["num_heads"],
            config["num_layers"])

        self.__global_graph = GAT(
            config["output_dim"],
            config["hidden_dim"],
            config["output_dim"],
            config["dropout_rate"],
            config["alpha"],
            config["num_heads"],
            config["num_layers"])

    def generate_global_adj_gat(self, seq_len, index, batch, window):
        global_intent_idx = [[] for i in range(batch)]
        global_slot_idx = [[] for i in range(batch)]
        for item in index:
            global_intent_idx[item[0]].append(item[1])

        for i, length in enumerate(seq_len):
            global_slot_idx[i].extend(
                list(range(self.config["intent_label_num"], self.config["intent_label_num"] + length)))

        adj = torch.cat([torch.eye(self.config["intent_label_num"] + max(seq_len)).unsqueeze(0) for i in range(batch)])
        for i in range(batch):
            for j in global_intent_idx[i]:
                adj[i, j, global_slot_idx[i]] = 1.
                adj[i, j, global_intent_idx[i]] = 1.
            for j in global_slot_idx[i]:
                adj[i, j, global_intent_idx[i]] = 1.

        for i in range(batch):
            for j in range(self.config["intent_label_num"], self.config["intent_label_num"] + seq_len[i]):
                adj[i, j, max(self.config["intent_label_num"], j - window):min(seq_len[i] + self.config["intent_label_num"], j + window + 1)] = 1.

        if self.config["row_normalized"]:
            adj = normalize_adj(adj)
        adj = adj.to(self.intent_embedding.device)
        return adj

    def generate_slot_adj_gat(self, seq_len, batch, window):
        slot_idx_ = [[] for i in range(batch)]
        adj = torch.cat([torch.eye(max(seq_len)).unsqueeze(0) for i in range(batch)])
        for i in range(batch):
            for j in range(seq_len[i]):
                adj[i, j, max(0, j - window):min(seq_len[i], j + window + 1)] = 1.
        if self.config["row_normalized"]:
            adj = normalize_adj(adj)
        adj = adj.to(self.intent_embedding.device)
        return adj

    def forward(self, encode_hidden: HiddenData, pred_intent: ClassifierOutputData = None, intent_index=None):
        seq_lens = encode_hidden.inputs.attention_mask.sum(-1)
        slot_lstm_out = self.__slot_lstm(torch.cat([encode_hidden.slot_hidden, pred_intent.classifier_output], dim=-1),
                                         seq_lens)
        global_adj = self.generate_global_adj_gat(seq_lens, intent_index, len(seq_lens),
                                                  self.config["slot_graph_window"])
        slot_adj = self.generate_slot_adj_gat(seq_lens, len(seq_lens), self.config["slot_graph_window"])
        batch = len(seq_lens)
        slot_graph_out = self.__slot_graph(slot_lstm_out, slot_adj)
        intent_in = self.intent_embedding.unsqueeze(0).repeat(batch, 1, 1)
        global_graph_in = torch.cat([intent_in, slot_graph_out], dim=1)
        encode_hidden.update_slot_hidden_state(self.__global_graph(global_graph_in, global_adj))
        return encode_hidden
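normalize_adj implements the row normalization D^{-1}A: every row of the (batched) adjacency is divided by its degree, and rows with no edges are left at zero instead of producing infinities. A toy check of that behavior:

import torch

adj = torch.tensor([[[1., 1., 0.],
                     [0., 1., 1.],
                     [0., 0., 0.]]])     # batched (1, 3, 3); the last row has degree 0

rowsum = adj.sum(2)
r_inv = torch.pow(rowsum, -1)
r_inv[torch.isinf(r_inv)] = 0.           # guard rows with no edges
normalized = torch.diag_embed(r_inv, 0).matmul(adj)

assert torch.allclose(normalized[0, 0], torch.tensor([0.5, 0.5, 0.0]))
assert torch.allclose(normalized[0, 2], torch.zeros(3))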
model/decoder/interaction/slot_gated_interaction.py
ADDED
@@ -0,0 +1,59 @@
import math

import einops
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import LayerNorm

from common.utils import HiddenData
from model.decoder.interaction import BaseInteraction


class SlotGatedInteraction(BaseInteraction):
    def __init__(self, **config):
        super().__init__(**config)
        self.intent_linear = nn.Linear(self.config["input_dim"], 1, bias=False)
        self.slot_linear1 = nn.Linear(self.config["input_dim"], 1, bias=False)
        self.slot_linear2 = nn.Linear(self.config["input_dim"], 1, bias=False)
        self.remove_slot_attn = self.config["remove_slot_attn"]
        self.slot_gate = SlotGate(**config)

    def forward(self, encode_hidden: HiddenData, **kwargs):
        input_hidden = encode_hidden.get_slot_hidden_state()

        seq_lens = encode_hidden.inputs.attention_mask.sum(-1)
        output_list = []
        for index, slen in enumerate(seq_lens):
            output_list.append(input_hidden[index, slen - 1, :].unsqueeze(0))
        intent_input = torch.cat(output_list, dim=0)
        e_I = torch.tanh(self.intent_linear(intent_input)).squeeze(1)
        alpha_I = einops.repeat(e_I, 'b -> b h', h=intent_input.shape[-1])
        c_I = alpha_I * intent_input
        intent_hidden = intent_input + c_I
        if not self.remove_slot_attn:
            # slot attention
            h_k = einops.repeat(self.slot_linear1(input_hidden), 'b l h -> b l c h', c=input_hidden.shape[1])
            h_i = einops.repeat(self.slot_linear2(input_hidden), 'b l h -> b l c h', c=input_hidden.shape[1]).transpose(1, 2)
            e_S = torch.tanh(h_k + h_i)
            alpha_S = torch.softmax(e_S, dim=2).squeeze(3)
            alpha_S = einops.repeat(alpha_S, 'b l1 l2 -> b l1 l2 h', h=input_hidden.shape[-1])
            map_input_hidden = einops.repeat(input_hidden, 'b l h -> b l c h', c=input_hidden.shape[1])
            c_S = torch.sum(alpha_S * map_input_hidden, dim=2)
        else:
            c_S = input_hidden
        slot_hidden = input_hidden + c_S * self.slot_gate(c_S, c_I)
        encode_hidden.update_intent_hidden_state(intent_hidden)
        encode_hidden.update_slot_hidden_state(slot_hidden)
        return encode_hidden


class SlotGate(nn.Module):
    def __init__(self, **config):
        super().__init__()
        self.linear = nn.Linear(config["input_dim"], config["output_dim"], bias=False)
        self.v = nn.Parameter(torch.rand(size=[1]))

    def forward(self, slot_context, intent_context):
        intent_gate = self.linear(intent_context)
        intent_gate = einops.repeat(intent_gate, 'b h -> b l h', l=slot_context.shape[1])
        return self.v * torch.tanh(slot_context + intent_gate)
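The gate computes g = v · tanh(c_S + W·c_I), and the slot representation becomes h + c_S ⊙ g, so a sentence-level intent summary modulates every slot context vector. A shape-level sketch of the gate (using unsqueeze/expand in place of einops.repeat, and assuming matching input and output dims):

import torch
from torch import nn

batch, length, dim = 2, 6, 16
slot_context = torch.randn(batch, length, dim)
intent_context = torch.randn(batch, dim)

linear = nn.Linear(dim, dim, bias=False)  # mirrors SlotGate.linear
v = torch.rand(1)                         # mirrors SlotGate.v

intent_gate = linear(intent_context).unsqueeze(1).expand(-1, length, -1)
gate = v * torch.tanh(slot_context + intent_gate)
assert gate.shape == (batch, length, dim)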
model/decoder/interaction/stack_interaction.py
ADDED
@@ -0,0 +1,36 @@
import os

import torch
from torch import nn

from common import utils
from common.utils import ClassifierOutputData, HiddenData
from model.decoder.interaction.base_interaction import BaseInteraction


class StackInteraction(BaseInteraction):
    def __init__(self, **config):
        super().__init__(**config)
        self.intent_embedding = nn.Embedding(
            self.config["intent_label_num"], self.config["intent_label_num"]
        )
        self.differentiable = config.get("differentiable")
        self.intent_embedding.weight.data = torch.eye(
            self.config["intent_label_num"])
        self.intent_embedding.weight.requires_grad = False

    def forward(self, intent_output: ClassifierOutputData, encode_hidden: HiddenData):
        if not self.differentiable:
            _, idx_intent = intent_output.classifier_output.topk(1, dim=-1)
            feed_intent = self.intent_embedding(idx_intent.squeeze(2))
        else:
            feed_intent = intent_output.classifier_output
        encode_hidden.update_slot_hidden_state(
            torch.cat([encode_hidden.get_slot_hidden_state(), feed_intent], dim=-1))
        return encode_hidden

    @staticmethod
    def from_configured(configure_name_or_file="stack-interaction", **input_config):
        return utils.from_configured(configure_name_or_file,
                                     model_class=StackInteraction,
                                     config_prefix="./config/decoder/interaction",
                                     **input_config)
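Setting the embedding weight to torch.eye(intent_label_num) with requires_grad = False turns the lookup into a frozen one-hot encoder, which is how the non-differentiable branch feeds hard intent decisions into the slot decoder. A quick check:

import torch
from torch import nn

intent_label_num = 4
emb = nn.Embedding(intent_label_num, intent_label_num)
emb.weight.data = torch.eye(intent_label_num)
emb.weight.requires_grad = False

ids = torch.tensor([[2, 0]])    # predicted intent ids per token
one_hot = emb(ids)
assert one_hot[0, 0].tolist() == [0.0, 0.0, 1.0, 0.0]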
model/encoder/__init__.py
ADDED
@@ -0,0 +1,5 @@
from model.encoder.pretrained_encoder import PretrainedEncoder
from model.encoder.non_pretrained_encoder import NonPretrainedEncoder
from model.encoder.base_encoder import BiEncoder
from model.encoder.auto_encoder import AutoEncoder

__all__ = ["PretrainedEncoder", "NonPretrainedEncoder", "AutoEncoder", "BiEncoder"]
model/encoder/auto_encoder.py
ADDED
@@ -0,0 +1,37 @@
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-26 17:46:10
Description:

'''
from common.utils import InputData
from model.encoder.base_encoder import BaseEncoder, BiEncoder
from model.encoder.pretrained_encoder import PretrainedEncoder
from model.encoder.non_pretrained_encoder import NonPretrainedEncoder


class AutoEncoder(BaseEncoder):

    def __init__(self, **config):
        """ automatically load an encoder by 'encoder_name'

        Args:
            config (dict):
                encoder_name (str): supports ["lstm", "self-attention-lstm", "bi-encoder"] and any pretrained model on Hugging Face
                **args (Any): other configuration items corresponding to each module.
        """
        super().__init__()
        self.config = config
        if config.get("encoder_name"):
            encoder_name = config.get("encoder_name").lower()
            if encoder_name in ["lstm", "self-attention-lstm"]:
                self.__encoder = NonPretrainedEncoder(**config)
            elif encoder_name == "bi-encoder":
                # Assumes separate "intent_encoder" and "slot_encoder" sub-configurations
                # for the two encoders.
                self.__encoder = BiEncoder(AutoEncoder(**config["intent_encoder"]),
                                           AutoEncoder(**config["slot_encoder"]))
            else:
                self.__encoder = PretrainedEncoder(**config)
        else:
            raise ValueError("There is no encoder name in config.")

    def forward(self, inputs: InputData):
        return self.__encoder(inputs)
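A hedged usage sketch of the dispatch above. The key names follow the NonPretrainedEncoder docstring later in this commit, but the values are placeholders, not a tested configuration:

from model.encoder.auto_encoder import AutoEncoder

# Hypothetical LSTM configuration; adjust dims and rates to the task at hand.
encoder = AutoEncoder(
    encoder_name="lstm",
    embedding={"embedding_dim": 128, "dropout_rate": 0.4, "vocab_size": 10000},
    lstm={"output_dim": 256, "bidirectional": True, "layer_num": 1, "dropout_rate": 0.4},
    return_with_input=True,
    return_sentence_level_hidden=False,
)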
model/encoder/base_encoder.py
ADDED
@@ -0,0 +1,41 @@
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-26 17:25:17
Description: Base encoder and bi encoder

'''
from torch import nn

from common.utils import InputData


class BaseEncoder(nn.Module):
    """Base class for all encoder modules.
    """
    def __init__(self, **config):
        super().__init__()
        self.config = config
        # Subclasses are expected to construct self.encoder.

    def forward(self, inputs: InputData):
        return self.encoder(inputs.input_ids)


class BiEncoder(nn.Module):
    """Bi-encoder that encodes intent and slot separately.
    """
    def __init__(self, intent_encoder: BaseEncoder, slot_encoder: BaseEncoder, **config):
        super().__init__()
        self.intent_encoder = intent_encoder
        self.slot_encoder = slot_encoder

    def forward(self, inputs: InputData):
        hidden_slot = self.slot_encoder(inputs)
        hidden_intent = self.intent_encoder(inputs)
        if not self.intent_encoder.config["return_sentence_level_hidden"]:
            hidden_slot.update_intent_hidden_state(hidden_intent.get_slot_hidden_state())
        else:
            hidden_slot.update_intent_hidden_state(hidden_intent.get_intent_hidden_state())
        return hidden_slot
model/encoder/non_pretrained_encoder.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-30 15:00:29
Description: non-pretrained encoder model

'''
import math

import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from common.utils import HiddenData, InputData
from model.encoder.base_encoder import BaseEncoder


class NonPretrainedEncoder(BaseEncoder):
    """
    Encoder structure based on a bidirectional LSTM and self-attention.
    """

    def __init__(self, **config):
        """ init non-pretrained encoder

        Args:
            config (dict):
                embedding (dict):
                    embedding_dim (int): embedding dimension.
                    dropout_rate (float): dropout rate.
                    load_embedding_name (str): None if no pretrained embedding is used, otherwise an embedding name like "glove.6B.300d.txt".
                    embedding_matrix (Tensor, Optional): embedding matrix tensor. Enabled if load_embedding_name is not None.
                    vocab_size (int, Optional): vocabulary size. Enabled if load_embedding_name is None.
                lstm (dict):
                    output_dim (int): lstm output dim.
                    bidirectional (bool): whether to use a BiLSTM or a unidirectional LSTM.
                    layer_num (int): number of layers.
                    dropout_rate (float): dropout rate.
                attention (dict, Optional):
                    dropout_rate (float): dropout rate.
                    hidden_dim (int): attention hidden dim.
                    output_dim (int): attention output dim.
                unflat_attention (dict, Optional): Enabled if attention is not None.
                    dropout_rate (float): dropout rate.
        """
        super(NonPretrainedEncoder, self).__init__()
        self.config = config
        # Embedding Initialization
        embed_config = config["embedding"]
        self.__embedding_dim = embed_config["embedding_dim"]
        if embed_config.get("load_embedding_name"):
            self.__embedding_layer = nn.Embedding.from_pretrained(embed_config["embedding_matrix"], padding_idx=0)
        else:
            self.__embedding_layer = nn.Embedding(
                embed_config["vocab_size"], self.__embedding_dim
            )
        self.__embedding_dropout_layer = nn.Dropout(embed_config["dropout_rate"])

        # LSTM Initialization
        lstm_config = config["lstm"]
        self.__hidden_size = lstm_config["output_dim"]
        self.__lstm_layer = nn.LSTM(
            input_size=self.__embedding_dim,
            hidden_size=lstm_config["output_dim"] // 2,
            batch_first=True,
            bidirectional=lstm_config["bidirectional"],
            dropout=lstm_config["dropout_rate"],
            num_layers=lstm_config["layer_num"]
        )
        if self.config.get("attention"):
            # Attention Initialization
            att_config = config["attention"]
            self.__attention_dropout_layer = nn.Dropout(att_config["dropout_rate"])
            self.__attention_layer = QKVAttention(
                self.__embedding_dim, self.__embedding_dim, self.__embedding_dim,
                att_config["hidden_dim"], att_config["output_dim"], att_config["dropout_rate"]
            )
            if self.config.get("unflat_attention"):
                unflat_att_config = config["unflat_attention"]
                self.__sentattention = UnflatSelfAttention(
                    lstm_config["output_dim"] + att_config["output_dim"],
                    unflat_att_config["dropout_rate"]
                )

    def forward(self, inputs: InputData):
        """ Forward process for Non-Pretrained Encoder.

        Args:
            inputs: padded input ids, masks.
        Returns:
            encoded hidden vectors.
        """

        # LSTM Encoder
        # The padded text should be an instance of LongTensor.
        embedded_text = self.__embedding_layer(inputs.input_ids)
        dropout_text = self.__embedding_dropout_layer(embedded_text)
        seq_lens = inputs.attention_mask.sum(-1).detach().cpu()
        # Pack and pad process for input of variable length.
        packed_text = pack_padded_sequence(dropout_text, seq_lens, batch_first=True, enforce_sorted=False)
        lstm_hiddens, (h_last, c_last) = self.__lstm_layer(packed_text)
        padded_hiddens, _ = pad_packed_sequence(lstm_hiddens, batch_first=True)

        if self.config.get("attention"):
            # Attention Encoder
            dropout_text = self.__attention_dropout_layer(embedded_text)
            attention_hiddens = self.__attention_layer(
                dropout_text, dropout_text, dropout_text, mask=inputs.attention_mask
            )

            # Attention + LSTM
            hiddens = torch.cat([attention_hiddens, padded_hiddens], dim=-1)
            hidden = HiddenData(None, hiddens)
            if self.config.get("return_with_input"):
                hidden.add_input(inputs)
            if self.config.get("return_sentence_level_hidden"):
                if self.config.get("unflat_attention"):
                    sentence = self.__sentattention(hiddens, seq_lens)
                else:
                    sentence = hiddens[:, 0, :]
                hidden.update_intent_hidden_state(sentence)
        else:
            sentence_hidden = None
            if self.config.get("return_sentence_level_hidden"):
                # Concatenate the final forward and backward hidden/cell states of the BiLSTM.
                sentence_hidden = torch.cat((h_last[-1], h_last[-2], c_last[-1], c_last[-2]), dim=-1)
            hidden = HiddenData(sentence_hidden, padded_hiddens)
            if self.config.get("return_with_input"):
                hidden.add_input(inputs)

        return hidden


class QKVAttention(nn.Module):
    """
    Attention mechanism based on the Query-Key-Value architecture.
    In particular, when query == key == value, it is self-attention.
    """

    def __init__(self, query_dim, key_dim, value_dim, hidden_dim, output_dim, dropout_rate):
        super(QKVAttention, self).__init__()

        # Record hyper-parameters.
        self.__query_dim = query_dim
        self.__key_dim = key_dim
        self.__value_dim = value_dim
        self.__hidden_dim = hidden_dim
        self.__output_dim = output_dim
        self.__dropout_rate = dropout_rate

        # Declare network structures.
        self.__query_layer = nn.Linear(self.__query_dim, self.__hidden_dim)
        self.__key_layer = nn.Linear(self.__key_dim, self.__hidden_dim)
        self.__value_layer = nn.Linear(self.__value_dim, self.__output_dim)
        self.__dropout_layer = nn.Dropout(p=self.__dropout_rate)

    def forward(self, input_query, input_key, input_value, mask=None):
        """ The forward propagation of attention.

        Here we require the first dimensions of the input key
        and value to be equal.

        Args:
            input_query: query tensor, (n, d_q)
            input_key: key tensor, (m, d_k)
            input_value: value tensor, (m, d_v)

        Returns:
            attention-based tensor, (n, d_h)
        """

        # Linear transforms to adjust dimensions.
        linear_query = self.__query_layer(input_query)
        linear_key = self.__key_layer(input_key)
        linear_value = self.__value_layer(input_value)

        score_tensor = torch.matmul(
            linear_query,
            linear_key.transpose(-2, -1)
        ) / math.sqrt(self.__hidden_dim)
        if mask is not None:
            # Mask padded key positions so padding tokens receive no attention weight.
            attn_mask = einops.repeat((mask == 0), "b l -> b h l", h=score_tensor.shape[-2])
            score_tensor = score_tensor.masked_fill_(attn_mask, -float(1e20))
        score_tensor = F.softmax(score_tensor, dim=-1)
        forced_tensor = torch.matmul(score_tensor, linear_value)
        forced_tensor = self.__dropout_layer(forced_tensor)

        return forced_tensor


class UnflatSelfAttention(nn.Module):
    """
    Scores each element of the sequence with a linear layer and uses the normalized scores to compute a context vector over the sequence.
    """

    def __init__(self, d_hid, dropout=0.):
        super().__init__()
        self.scorer = nn.Linear(d_hid, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inp, lens):
        batch_size, seq_len, d_feat = inp.size()
        inp = self.dropout(inp)
        scores = self.scorer(inp.contiguous().view(-1, d_feat)).view(batch_size, seq_len)
        max_len = max(lens)
        # Mask out positions beyond each sequence's true length before softmax.
        for i, l in enumerate(lens):
            if l < max_len:
                scores.data[i, l:] = -np.inf
        scores = F.softmax(scores, dim=1)
        context = scores.unsqueeze(2).expand_as(inp).mul(inp).sum(1)
        return context
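As a quick reference for the nested config parsed in __init__ above, a hypothetical instantiation with the self-attention branch enabled might look like this (all values are illustrative assumptions, not taken from this commit's config files):

# Hypothetical config sketch matching the documented keys above.
from model.encoder.non_pretrained_encoder import NonPretrainedEncoder

encoder = NonPretrainedEncoder(
    embedding={"embedding_dim": 128, "dropout_rate": 0.4, "vocab_size": 900},
    lstm={"output_dim": 256, "bidirectional": True, "layer_num": 2, "dropout_rate": 0.4},
    attention={"hidden_dim": 1024, "output_dim": 128, "dropout_rate": 0.4},
    unflat_attention={"dropout_rate": 0.4},
    return_with_input=True,
    return_sentence_level_hidden=True,
)
# Calling encoder(inputs) then returns a HiddenData whose token-level states are
# the concatenation of self-attention and BiLSTM outputs (128 + 256 dims), and
# whose sentence-level state is pooled by UnflatSelfAttention.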
model/encoder/pretrained_encoder.py
ADDED
@@ -0,0 +1,39 @@
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-26 17:18:01
Description: pretrained encoder model

'''
from transformers import AutoModel

from common.utils import InputData, HiddenData
from model.encoder.base_encoder import BaseEncoder


class PretrainedEncoder(BaseEncoder):
    def __init__(self, **config):
        """ init pretrained encoder

        Args:
            config (dict):
                encoder_name (str): pretrained model name on Hugging Face.
        """
        super().__init__(**config)
        self.encoder = AutoModel.from_pretrained(config["encoder_name"])

    def forward(self, inputs: InputData):
        output = self.encoder(**inputs.get_inputs())
        hidden = HiddenData(None, output.last_hidden_state)
        if self.config.get("return_with_input"):
            hidden.add_input(inputs)
        if self.config.get("return_sentence_level_hidden"):
            padding_side = self.config.get("padding_side")
            if hasattr(output, "pooler_output"):
                # Prefer the model's own pooled representation when available.
                hidden.update_intent_hidden_state(output.pooler_output)
            elif padding_side is not None and padding_side == "left":
                # Left padding: the last position holds the final real token.
                hidden.update_intent_hidden_state(output.last_hidden_state[:, -1])
            else:
                # Right padding (default): use the first token (e.g. [CLS]).
                hidden.update_intent_hidden_state(output.last_hidden_state[:, 0])
        return hidden
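A short, hypothetical usage sketch of the sentence-level logic above (the model name and config keys are illustrative; BERT-style checkpoints expose pooler_output, so that branch would be taken):

# Hypothetical usage sketch (values assumed, not from this commit).
from model.encoder.pretrained_encoder import PretrainedEncoder

encoder = PretrainedEncoder(
    encoder_name="bert-base-uncased",   # any Hugging Face checkpoint name
    return_with_input=True,
    return_sentence_level_hidden=True,  # BERT: takes the pooler_output branch
)
# For decoder-style models padded on the left, setting padding_side="left" makes
# the last position (the final real token) the sentence representation instead.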
model/open_slu_model.py
ADDED
@@ -0,0 +1,64 @@
'''
Author: Qiguang Chen
Date: 2023-01-11 10:39:26
LastEditors: Qiguang Chen
LastEditTime: 2023-01-26 17:18:22
Description: Root Model Module

'''
from torch import nn

from common.utils import OutputData, InputData
from model.decoder.base_decoder import BaseDecoder
from model.encoder.base_encoder import BaseEncoder


class OpenSLUModel(nn.Module):
    def __init__(self, encoder: BaseEncoder, decoder: BaseDecoder, **config):
        """Create the model automatically.

        Args:
            encoder (BaseEncoder): encoder created by config
            decoder (BaseDecoder): decoder created by config
            config (dict): any other args
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config

    def forward(self, inp: InputData) -> OutputData:
        """ model forward

        Args:
            inp (InputData): input ids and other information

        Returns:
            OutputData: pred logits
        """
        return self.decoder(self.encoder(inp))

    def decode(self, output: OutputData, target: InputData = None):
        """ decode output

        Args:
            output (OutputData): pred logits data
            target (InputData): golden data

        Returns: decoded ids
        """
        return self.decoder.decode(output, target)

    def compute_loss(self, pred: OutputData, target: InputData, compute_intent_loss=True, compute_slot_loss=True):
        """ compute loss

        Args:
            pred (OutputData): pred logits data
            target (InputData): golden data
            compute_intent_loss (bool, optional): whether to compute intent loss. Defaults to True.
            compute_slot_loss (bool, optional): whether to compute slot loss. Defaults to True.

        Returns: loss value
        """
        return self.decoder.compute_loss(pred, target, compute_intent_loss=compute_intent_loss,
                                         compute_slot_loss=compute_slot_loss)
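To make the encoder/decoder contract concrete, here is a hypothetical wiring sketch; the encoder and decoder objects are assumed to be built elsewhere from config, as in the rest of this repository:

# Hypothetical wiring sketch: any BaseEncoder/BaseDecoder pair forms one model.
from model.open_slu_model import OpenSLUModel

model = OpenSLUModel(encoder, decoder)
# pred = model(batch_inputs)                     # OutputData with intent/slot logits
# loss = model.compute_loss(pred, batch_inputs)  # delegates to decoder.compute_loss
# decoded = model.decode(pred, batch_inputs)     # decoded intent/slot ids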
save/stack/label.json
ADDED
@@ -0,0 +1 @@
{"intent": ["atis_flight", "atis_airfare", "atis_airline", "atis_ground_service", "atis_quantity", "atis_city", "atis_flight#atis_airfare", "atis_abbreviation", "atis_aircraft", "atis_distance", "atis_ground_fare", "atis_capacity", "atis_flight_time", "atis_meal", "atis_aircraft#atis_flight#atis_flight_no", "atis_flight_no", "atis_restriction", "atis_airport", "atis_airline#atis_flight_no", "atis_cheapest", "atis_ground_service#atis_ground_fare"], "slot": ["O", "B-fromloc.city_name", "B-toloc.city_name", "B-round_trip", "I-round_trip", "B-cost_relative", "B-fare_amount", "I-fare_amount", "B-arrive_date.month_name", "B-arrive_date.day_number", "I-fromloc.city_name", "B-stoploc.city_name", "B-arrive_time.time_relative", "B-arrive_time.time", "I-arrive_time.time", "B-toloc.state_code", "I-toloc.city_name", "I-stoploc.city_name", "B-meal_description", "B-depart_date.month_name", "B-depart_date.day_number", "B-airline_name", "I-airline_name", "B-depart_time.period_of_day", "B-depart_date.day_name", "B-toloc.state_name", "B-depart_time.time_relative", "B-depart_time.time", "B-toloc.airport_name", "I-toloc.airport_name", "B-depart_date.date_relative", "B-or", "B-airline_code", "B-class_type", "I-class_type", "I-cost_relative", "I-depart_time.time", "B-fromloc.airport_name", "I-fromloc.airport_name", "B-city_name", "B-flight_mod", "B-meal", "B-economy", "B-fare_basis_code", "I-depart_date.day_number", "B-depart_date.today_relative", "B-flight_stop", "B-airport_code", "B-fromloc.state_name", "I-fromloc.state_name", "I-city_name", "B-connect", "B-arrive_date.day_name", "B-fromloc.state_code", "B-arrive_date.today_relative", "B-depart_date.year", "B-depart_time.start_time", "I-depart_time.start_time", "B-depart_time.end_time", "I-depart_time.end_time", "B-arrive_time.start_time", "B-arrive_time.end_time", "I-arrive_time.end_time", "I-flight_mod", "B-flight_days", "B-mod", "B-flight_number", "I-toloc.state_name", "B-meal_code", "I-meal_code", "B-airport_name", "I-airport_name", "I-flight_stop", "B-transport_type", "I-transport_type", "B-state_code", "B-aircraft_code", "B-toloc.country_name", "I-arrive_date.day_number", "B-toloc.airport_code", "B-return_date.date_relative", "I-return_date.date_relative", "B-flight_time", "I-economy", "B-fromloc.airport_code", "B-arrive_time.period_of_day", "B-depart_time.period_mod", "I-flight_time", "B-return_date.day_name", "B-arrive_date.date_relative", "B-restriction_code", "I-restriction_code", "B-arrive_time.period_mod", "I-arrive_time.period_of_day", "B-period_of_day", "B-stoploc.state_code", "I-depart_date.today_relative", "I-fare_basis_code", "I-arrive_time.start_time", "B-time", "B-today_relative", "I-today_relative", "B-state_name", "B-days_code", "I-depart_time.period_of_day", "I-arrive_time.time_relative", "B-time_relative", "I-time", "B-return_date.month_name", "B-return_date.day_number", "I-depart_time.time_relative", "B-stoploc.airport_name", "B-day_name", "B-month_name", "B-day_number", "B-return_time.period_mod", "B-return_time.period_of_day", "B-return_date.today_relative", "I-return_date.today_relative", "I-meal_description"]}
save/stack/model.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9710de3d7d5c8a34fe55ef4dc36dc8a851863d1fb3bb14871d914a4e945c96ef
size 5793644
save/stack/outputs.jsonl
ADDED
The diff for this file is too large to render.
save/stack/tokenizer.json
ADDED
@@ -0,0 +1 @@
{"name": "word_tokenizer", "token_map": {"[PAD]": 0, "[UNK]": 1, "i": 2, "want": 3, "to": 4, "fly": 5, "from": 6, "baltimore": 7, "dallas": 8, "round": 9, "trip": 10, "fares": 11, "philadelphia": 12, "less": 13, "than": 14, "1000": 15, "dollars": 16, "denver": 17, "pittsburgh": 18, "show": 19, "me": 20, "the": 21, "flights": 22, "arriving": 23, "on": 24, "june": 25, "fourteenth": 26, "what": 27, "are": 28, "which": 29, "depart": 30, "san": 31, "francisco": 32, "washington": 33, "via": 34, "indianapolis": 35, "and": 36, "arrive": 37, "by": 38, "9": 39, "pm": 40, "airlines": 41, "boston": 42, "dc": 43, "other": 44, "cities": 45, "i'm": 46, "looking": 47, "for": 48, "a": 49, "flight": 50, "charlotte": 51, "las": 52, "vegas": 53, "that": 54, "stops": 55, "in": 56, "st.": 57, "louis": 58, "hopefully": 59, "dinner": 60, "how": 61, "can": 62, "find": 63, "out": 64, "okay": 65, "then": 66, "i'd": 67, "like": 68, "travel": 69, "atlanta": 70, "september": 71, "fourth": 72, "all": 73, "cincinnati": 74, "us": 75, "air": 76, "diego": 77, "afternoon": 78, "what's": 79, "available": 80, "tuesday": 81, "leave": 82, "phoenix": 83, "paul": 84, "minnesota": 85, "after": 86, "noon": 87, "american": 88, "chicago": 89, "los": 90, "angeles": 91, "morning": 92, "types": 93, "of": 94, "ground": 95, "transportation": 96, "there": 97, "airport": 98, "next": 99, "two": 100, "days": 101, "nashville": 102, "jose": 103, "or": 104, "tacoma": 105, "does": 106, "continental": 107, "milwaukee": 108, "many": 109, "twa": 110, "have": 111, "business": 112, "class": 113, "first": 114, "least": 115, "expensive": 116, "one": 117, "way": 118, "fare": 119, "booking": 120, "classes": 121, "wednesday": 122, "nineteenth": 123, "july": 124, "fifth": 125, "7": 126, "please": 127, "list": 128, "departing": 129, "general": 130, "mitchell": 131, "international": 132, "time": 133, "zone": 134, "is": 135, "serves": 136, "meal": 137, "seattle": 138, "salt": 139, "lake": 140, "city": 141, "you": 142, "with": 143, "economy": 144, "leaving": 145, "miami": 146, "cleveland": 147, "give": 148, "between": 149, "their": 150, "cost": 151, "code": 152, "y": 153, "mean": 154, "could": 155, "tell": 156, "leaves": 157, "united": 158, "airline": 159, "over": 160, "departures": 161, "arrivals": 162, "earliest": 163, "latest": 164, "return": 165, "within": 166, "same": 167, "day": 168, "orlando": 169, "either": 170, "evening": 171, "thursday": 172, "originating": 173, "going": 174, "order": 175, "snack": 176, "do": 177, "at": 178, "645": 179, "am": 180, "into": 181, "atlanta's": 182, "friday": 183, "qx": 184, "would": 185, "information": 186, "but": 187, "stopover": 188, "some": 189, "oakland": 190, "monday": 191, "know": 192, "type": 193, "aircraft": 194, "used": 195, "detroit": 196, "twenty": 197, "eighth": 198, "petersburg": 199, "take": 200, "begins": 201, "lands": 202, "fort": 203, "worth": 204, "stop": 205, "tomorrow": 206, "noontime": 207, "around": 208, "need": 209, "northwest": 210, "toronto": 211, "memphis": 212, "thirtieth": 213, "nonstop": 214, "houston": 215, "august": 216, "twentieth": 217, "ewr": 218, "seventh": 219, "newark": 220, "11": 221, "lowest": 222, "delta": 223, "has": 224, "go": 225, "any": 226, "jet": 227, "mco": 228, "new": 229, "jersey": 230, "ontario": 231, "saturday": 232, "york": 233, "long": 234, "it": 235, "get": 236, "prices": 237, "an": 238, "inexpensive": 239, "breakfast": 240, "direct": 241, "what're": 242, "sunday": 243, "colorado": 244, "see": 245, "serving": 246, "before": 247, "o'clock": 248, "january": 249, "1992": 
250, "10": 251, "2": 252, "445": 253, "515": 254, "6": 255, "8": 256, "coach": 257, "only": 258, "weekdays": 259, "la": 260, "3": 261, "my": 262, "choices": 263, "early": 264, "minneapolis": 265, "cheapest": 266, "flying": 267, "possible": 268, "daily": 269, "beach": 270, "stopping": 271, "kansas": 272, "night": 273, "serve": 274, "meals": 275, "heading": 276, "kind": 277, "once": 278, "mornings": 279, "ff": 280, "arrangements": 281, "served": 282, "canadian": 283, "california": 284, "about": 285, "530": 286, "kinds": 287, "requesting": 288, "landing": 289, "distance": 290, "downtown": 291, "love": 292, "field": 293, "this": 294, "tampa": 295, "florida": 296, "5": 297, "must": 298, "be": 299, "advertises": 300, "having": 301, "more": 302, "november": 303, "eleventh": 304, "services": 305, "qo": 306, "american's": 307, "last": 308, "using": 309, "dl": 310, "217": 311, "montreal": 312, "service": 313, "ticket": 314, "should": 315, "near": 316, "also": 317, "missouri": 318, "utah": 319, "interested": 320, "those": 321, "when": 322, "north": 323, "carolina": 324, "415": 325, "200": 326, "explain": 327, "codes": 328, "sd": 329, "d": 330, "trying": 331, "plane": 332, "flies": 333, "weekday": 334, "columbus": 335, "1991": 336, "carries": 337, "smallest": 338, "number": 339, "passengers": 340, "takeoffs": 341, "landings": 342, "book": 343, "shortest": 344, "both": 345, "nationair": 346, "823": 347, "guardia": 348, "as": 349, "well": 350, "sixteenth": 351, "rental": 352, "car": 353, "rates": 354, "through": 355, "boeing": 356, "757": 357, "limousine": 358, "listing": 359, "canada": 360, "much": 361, "71": 362, "airfare": 363, "12": 364, "third": 365, "seating": 366, "capacity": 367, "arrives": 368, "bwi": 369, "ninth": 370, "late": 371, "nevada": 372, "4": 373, "price": 374, "fifteenth": 375, "eighteenth": 376, "returning": 377, "following": 378, "capacities": 379, "planes": 380, "1145": 381, "use": 382, "burbank": 383, "may": 384, "america": 385, "west": 386, "now": 387, "eastern": 388, "825": 389, "555": 390, "area": 391, "schedule": 392, "dfw": 393, "these": 394, "connecting": 395, "make": 396, "connection": 397, "lunch": 398, "f": 399, "belong": 400, "most": 401, "tickets": 402, "logan": 403, "vicinity": 404, "210": 405, "wednesdays": 406, "thursdays": 407, "yes": 408, "will": 409, "continuing": 410, "1039": 411, "southwest": 412, "times": 413, "400": 414, "week": 415, "if": 416, "813": 417, "enroute": 418, "another": 419, "twelfth": 420, "turboprop": 421, "420": 422, "today": 423, "1": 424, "we're": 425, "westchester": 426, "county": 427, "various": 428, "airplanes": 429, "uses": 430, "yn": 431, "852": 432, "transport": 433, "display": 434, "under": 435, "500": 436, "airfares": 437, "back": 438, "hours": 439, "fn": 440, "options": 441, "december": 442, "second": 443, "april": 444, "ohio": 445, "departs": 446, "2153": 447, "schedules": 448, "who": 449, "restriction": 450, "ap": 451, "57": 452, "layover": 453, "abbreviation": 454, "stands": 455, "1291": 456, "324": 457, "again": 458, "offer": 459, "dc10": 460, "currently": 461, "represented": 462, "database": 463, "arizona": 464, "1505": 465, "sixth": 466, "3724": 467, "three": 468, "including": 469, "connections": 470, "numbers": 471, "six": 472, "1100": 473, "destination": 474, "838": 475, "no": 476, "h": 477, "traveling": 478, "ap57": 479, "far": 480, "lufthansa": 481, "abbreviations": 482, "such": 483, "aa": 484, "459": 485, "where": 486, "ua": 487, "281": 488, "your": 489, "texas": 490, "1500": 491, "bound": 492, "includes": 493, "right": 
494, "airports": 495, "eight": 496, "sixteen": 497, "trips": 498, "seventeenth": 499, "thrift": 500, "delta's": 501, "departure": 502, "listed": 503, "1055": 504, "405": 505, "midnight": 506, "hi": 507, "630": 508, "question": 509, "live": 510, "stand": 511, "ten": 512, "people": 513, "during": 514, "2100": 515, "gets": 516, "just": 517, "philly": 518, "21": 519, "airplane": 520, "1765": 521, "iah": 522, "737": 523, "midwest": 524, "express": 525, "s": 526, "designate": 527, "747": 528, "650": 529, "goes": 530, "reaches": 531, "seventeen": 532, "sorry": 533, "anywhere": 534, "provided": 535, "d10": 536, "toward": 537, "preferably": 538, "rate": 539, "difference": 540, "q": 541, "b": 542, "ac": 543, "tower": 544, "tenth": 545, "hp": 546, "4400": 547, "georgia": 548, "offers": 549, "fine": 550, "201": 551, "343": 552, "october": 553, "ea": 554, "jfk": 555, "name": 556, "arrange": 557, "largest": 558, "connect": 559, "operating": 560, "sundays": 561, "720": 562, "land": 563, "final": 564, "don't": 565, "stopovers": 566, "total": 567, "friday's": 568, "755": 569, "cheap": 570, "sfo": 571, "thirty": 572, "across": 573, "continent": 574, "makes": 575, "1220": 576, "co": 577, "1209": 578, "wanted": 579, "1850": 580, "without": 581, "listings": 582, "local": 583, "wish": 584, "bring": 585, "up": 586, "home": 587, "417": 588, "approximately": 589, "actually": 590, "1200": 591, "230": 592, "819": 593, "serviced": 594, "928": 595, "reservation": 596, "limousines": 597, "taxi": 598, "fit": 599, "72s": 600, "352": 601, "1133": 602, "43": 603, "define": 604, "directly": 605, "m80": 606, "close": 607, "restrictions": 608, "430": 609, "718": 610, "hou": 611, "costs": 612, "466": 613, "march": 614, "1026": 615, "1024": 616, "different": 617, "rentals": 618, "each": 619, "arrival": 620, "say": 621, "mealtime": 622, "932": 623, "1115": 624, "1245": 625, "include": 626, "whether": 627, "offered": 628, "130": 629, "alaska": 630, "296": 631, "they": 632, "106": 633, "york's": 634, "497766": 635, "itinerary": 636, "coming": 637, "month": 638, "bur": 639, "travels": 640, "pennsylvania": 641, "usa": 642, "1288": 643, "c": 644, "names": 645, "sure": 646, "meaning": 647, "ap80": 648, "269": 649, "reservations": 650, "d9s": 651, "sunday's": 652, "f28": 653, "934": 654, "earlier": 655, "1017": 656, "date": 657, "thank": 658, "oak": 659, "atl": 660, "cp": 661, "3357": 662, "1045": 663, "limo": 664, "845": 665, "sometime": 666, "1222": 667, "i'll": 668, "tennessee": 669, "0900": 670, "hello": 671, "let": 672, "repeat": 673, "provide": 674, "still": 675, "along": 676, "operation": 677, "year": 678, "one's": 679, "great": 680, "too": 681, "nighttime": 682, "1300": 683, "saturdays": 684, "416": 685, "four": 686, "257": 687, "minimum": 688, "intercontinental": 689, "february": 690, "spend": 691, "lastest": 692, "thing": 693, "originate": 694, "describe": 695, "concerning": 696, "sa": 697, "help": 698, "1700": 699, "225": 700, "1158": 701, "equipment": 702, "let's": 703, "wednesday's": 704, "quebec": 705, "highest": 706, "starting": 707, "taking": 708, "311": 709, "1230": 710, "able": 711, "put": 712, "later": 713, "takes": 714, "amount": 715, "qw": 716, "seven": 717, "maximum": 718, "yyz": 719, "it's": 720, "80": 721, "place": 722, "equal": 723, "while": 724, "train": 725, "look": 726, "815": 727, "takeoff": 728, "plan": 729, "2134": 730, "297": 731, "323": 732, "229": 733, "329": 734, "runs": 735, "730": 736, "closest": 737, "dulles": 738, "73s": 739, "so": 740, "economic": 741, "single": 742, "supper": 743, "110": 744, 
"calling": 745, "1205": 746, "55": 747, "michigan": 748, "proper": 749, "regarding": 750, "seats": 751, "19": 752, "m": 753, "midway": 754, "besides": 755, "reverse": 756, "1993": 757, "402": 758, "level": 759, "reaching": 760, "771": 761, "straight": 762, "located": 763, "305": 764, "repeating": 765, "indiana": 766, "connects": 767, "beginning": 768, "staying": 769, "town": 770, "cars": 771, "nonstops": 772, "300": 773, "345": 774, "dinnertime": 775, "sort": 776, "route": 777, "j31": 778, "tuesdays": 779, "212": 780, "705": 781, "red": 782, "eye": 783, "laying": 784, "friends": 785, "visit": 786, "here": 787, "them": 788, "lives": 789, "rent": 790, "279": 791, "137338": 792, "transcontinental": 793, "trans": 794, "world": 795, "1030": 796, "1130": 797, "come": 798, "727": 799, "1020": 800, "505": 801, "that's": 802, "163": 803, "ls": 804, "greatest": 805, "i've": 806, "got": 807, "somebody": 808, "else": 809, "wants": 810, "charges": 811, "734": 812, "carried": 813, "thirteenth": 814, "making": 815, "733": 816, "everywhere": 817, "prefer": 818, "run": 819, "non": 820, "315": 821, "746": 822, "companies": 823, "buy": 824, "very": 825, "270": 826, "locate": 827, "hartfield": 828, "start": 829, "98": 830, "inform": 831, "oh": 832, "82": 833, "139": 834, "1600": 835, "eleven": 836, "ord": 837, "mia": 838, "qualify": 839, "doesn't": 840, "mondays": 841, "catch": 842, "priced": 843, "bna": 844, "being": 845, "working": 846, "scenario": 847, "767": 848, "1940": 849, "150": 850, "100": 851, "afternoons": 852, "provides": 853, "723": 854, "1110": 855, "symbols": 856, "grounds": 857, "nw": 858, "539": 859, "soon": 860, "thereafter": 861, "scheduled": 862, "instead": 863, "810": 864, "lester": 865, "pearson": 866, "stapleton": 867, "615": 868, "twelve": 869, "bay": 870, "sounds": 871, "o'hare": 872, "ap68": 873, "fridays": 874, "try": 875, "fifteen": 876, "nights": 877, "determine": 878, "hold": 879, "lax": 880, "seat": 881, "k": 882, "planning": 883, "discount": 884, "summer": 885, "cover": 886, "271": 887, "tonight": 888, "off": 889, "124": 890, "thanks": 891, "longest": 892, "kindly": 893, "afterwards": 894, "overnight": 895, "1083": 896, "428": 897, "anything": 898, "1059": 899}}