amatiger commited on
Commit
59d97af
1 Parent(s): e6d934e

Upload 6 files

Browse files
scripts/create_downstream_dataset.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# ------------------------------------------------------------------
# Build the downstream-task datasets by running
# downstream_tasks_converter.py once per corpus.
#
# Only the MultiWOZ conversion is enabled. Uncomment the other
# sections after downloading the corresponding corpus.
# MULTIWOZ_PATH may be overridden from the environment.
# ------------------------------------------------------------------

VERSION=0.1.0
SUBJECT=DialoGLMDownstreamDataset
USAGE="Usage: "


# Please follow ParlAI to download the WoW and WoI datasets.
# WOW_PATH=/home/bapeng/anaconda3/envs/parlai/lib/python3.8/site-packages/data/wizard_of_wikipedia
# python downstream_tasks_converter.py WoWConverter "${WOW_PATH}"

# WOI_PATH=/home/bapeng/anaconda3/envs/parlai/lib/python3.8/site-packages/data/wizard_of_internet
# python downstream_tasks_converter.py WoIConverter "${WOI_PATH}"

# Please follow https://github.com/stanfordnlp/coqa-baselines to prepare seq2seq-train-h2 and seq2seq-dev-h2.
# COQA_PATH=/home/bapeng/experiment/cqa/coqa-baselines/data
# python downstream_tasks_converter.py CoQAConverter "${COQA_PATH}"

# Please clone https://github.com/wenhuchen/HDSA-Dialog to download the data.
MULTIWOZ_PATH=${MULTIWOZ_PATH:-/home/bapeng/experiment/HDSA-Dialog/data}
# Quoted so that a path containing spaces is passed as a single argument.
python downstream_tasks_converter.py MultiWOZConverter "${MULTIWOZ_PATH}"
scripts/create_grounded_dataset.sh ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# ------------------------------------------------------------------
# Build the grounded corpus by running grounded_converter.py on the
# MS MARCO, SGD, DSTC7 and UnifiedQA data locations below.
#
# Every *_PATH variable may be overridden from the environment.
# ------------------------------------------------------------------

VERSION=0.1.0
SUBJECT=DialoGLMGroundedDataset
USAGE="Usage: "


# Please follow https://microsoft.github.io/msmarco/ to download the MS MARCO dataset.
MSMARCO_PATH=${MSMARCO_PATH:-/home/bapeng/experiment/DialoGLM/data/dummy_data/msmarco}

# Please follow https://github.com/google-research-datasets/dstc8-schema-guided-dialogue
SGD_PATH=${SGD_PATH:-/home/bapeng/experiment/dstc8-schema-guided-dialogue}

# Please follow https://github.com/mgalley/DSTC7-End-to-End-Conversation-Modeling
DSTC7_PATH=${DSTC7_PATH:-/home/bapeng/experiment/DialoGLM/data/dummy_data/dstc7/dstc7_h100.tsv}

# Please follow the instructions on https://github.com/allenai/unifiedqa to download the dataset.
UNIFIED_QA_PATH=${UNIFIED_QA_PATH:-/home/bapeng/experiment/DialoGLM/data/dummy_data/unifedqa}

# Quoted so that paths containing spaces stay single arguments.
python grounded_converter.py "${MSMARCO_PATH}" "${SGD_PATH}" "${DSTC7_PATH}" "${UNIFIED_QA_PATH}"
scripts/create_reddit.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright (c) Microsoft Corporation.
4
+ # Licensed under the MIT license.
5
+
6
+ import jsonlines
7
+ import fire
8
+
9
+
10
+ def _norm_text(text):
11
+ w, *toks = text.strip().split()
12
+ try:
13
+ w = float(w)
14
+ except Exception:
15
+ toks = [w] + toks
16
+ w = 1.0
17
+ return w, ' '.join(toks)
18
+
19
+
20
def _get_inputs_from_text(text):
    """Parse one ``sources<TAB>target`` line into weights and utterances.

    The source side is split on ``' EOS '`` and each turn, as well as the
    target, is passed through ``_norm_text``.  The target is appended only
    when its weight is non-zero.

    Args:
        text: One input line containing exactly one tab separator.

    Returns:
        A ``(weights, inputs)`` pair of equally long lists.
    """
    source_side, target_side = text.strip().split('\t')

    parsed_sources = [_norm_text(turn) for turn in source_side.split(' EOS ')]
    weights = [weight for weight, _ in parsed_sources]
    inputs = [utterance for _, utterance in parsed_sources]

    target_weight, target_text = _norm_text(target_side)
    if target_weight != 0:
        weights.append(target_weight)
        inputs.append(target_text)

    return weights, inputs
33
+
34
+
35
def process(reddit_path):
    """Convert a weighted Reddit dialogue file into jsonlines data.

    Two files are written, relative to the current working directory:

    * ``../data/reddit_session_level.jsonl`` holds one ``{'text': ...}``
      record per kept line, its utterances joined by ``' EOS '``.  A line is
      skipped when any weight returned by ``_get_inputs_from_text`` is
      ``0.0``.
    * ``../data/reddit.jsonl`` holds one ``Context``/``Knowledge``/
      ``Response`` record for every utterance that has a non-empty history;
      ``Knowledge`` is always the empty string.

    A running count is printed every 10000 records in both passes.

    Args:
        reddit_path: Path of the UTF-8 input file, one session per line.
    """
    session_path = '../data/reddit_session_level.jsonl'
    turn_path = '../data/reddit.jsonl'

    # Pass 1: filter sessions.  The writer is closed by the ``with`` block so
    # everything is flushed to disk before the file is read back below.
    with open(reddit_path, "r", encoding="utf-8") as reader, \
            jsonlines.open(session_path, 'w') as writer:
        for line_no, line in enumerate(reader, start=1):
            if line_no % 10000 == 0:
                print(line_no)
            weights, inputs = _get_inputs_from_text(line)
            if 0.0 in weights:
                continue
            writer.write({'text': ' EOS '.join(inputs)})

    # Pass 2: expand every session into (history, response) pairs.
    with open(session_path, "r", encoding="utf-8") as reader, \
            jsonlines.open(turn_path, mode='w') as writer:
        for session_no, item in enumerate(jsonlines.Reader(reader), start=1):
            if session_no % 10000 == 0:
                print(session_no)
            # Splitting and re-joining on the bare 'EOS' token keeps the
            # surrounding spaces of the original ' EOS ' separator intact.
            context = item['text'].split('EOS')

            # A dedicated index keeps the session counter above untouched.
            for turn_idx in range(len(context) - 1):
                history = 'EOS'.join(context[:turn_idx + 1])
                if len(history) == 0:
                    continue
                writer.write({
                    'Context': history,
                    'Knowledge': '',
                    'Response': context[turn_idx + 1].strip(),
                })
73
+
74
+
75
def main():
    """Command-line entry point: hand ``process`` over to python-fire."""
    cli_target = process
    fire.Fire(cli_target)


if __name__ == '__main__':
    main()
scripts/create_reddit_dataset.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# ------------------------------------------------------------------
# Build the Reddit dataset by running create_reddit.py on the
# dialogue file given by REDDIT_PATH.
#
# REDDIT_PATH may be overridden from the environment.
# ------------------------------------------------------------------

VERSION=0.1.0
SUBJECT=DialoGLMRedditDataset
USAGE="Usage: "


# Reddit dialogue file to convert; the default points at the bundled dummy data.
REDDIT_PATH=${REDDIT_PATH:-../data/dummy_data/reddit/dialogpt.t1000.txt}


# Quoted so that a path containing spaces is passed as a single argument.
python create_reddit.py "${REDDIT_PATH}"
scripts/downstream_tasks_converter.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright (c) Microsoft Corporation.
4
+ # Licensed under the MIT license.
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import jsonlines
9
+ import json
10
+ import copy
11
+ import random
12
+ import fire
13
+
14
+
15
class Converter(ABC):
    """Abstract base class shared by all dataset converters.

    Concrete subclasses implement ``process``; ``convert`` wraps that call
    with a start and a finish message.
    """

    def __init__(self, filepath) -> None:
        """Store ``filepath``, the input location used by ``process``."""
        super().__init__()
        self.filepath = filepath

    def convert(self):
        """Run the conversion: ``start``, then ``process``, then ``end``."""
        self.start()
        self.process()
        self.end()

    def start(self):
        """Print the message announcing that processing begins."""
        label = self.__class__.__name__
        print(f'Start processing {label} at {self.filepath}')

    def end(self):
        """Print the message announcing that processing is finished."""
        label = self.__class__.__name__
        print(f'Finish processing {label} at {self.filepath}')

    @abstractmethod
    def process(self):
        """Perform the actual conversion; subclasses must override this."""
43
+
44
+
45
class WoWConverter(Converter):
    """Convert Wizard of Wikipedia dialogues to Context/Knowledge/Response records.

    ``self.filepath`` is a directory containing ``train.json``,
    ``valid_random_split.json`` and ``test_random_split.json``.  Results are
    written to ``../data/wow/wow_{train,valid,test}.jsonl``.
    """

    @staticmethod
    def _load_json(path):
        """Parse one JSON file and close its handle afterwards."""
        with open(path) as handle:
            return json.load(handle)

    @staticmethod
    def _build_examples(turns, history, strip_history):
        """Create one example for every Wizard turn of a dialogue.

        Args:
            turns: List of turn dicts with ``speaker`` and ``text`` keys and
                an optional ``checked_sentence`` mapping.
            history: Utterances preceding the first turn; extended in place
                with every turn of the dialogue.
            strip_history: Whether a turn's text is stripped before it is
                appended to ``history``.

        Returns:
            A list of ``Context``/``Knowledge``/``Response`` dicts.
        """
        examples = []
        for turn in turns:
            speaker = turn['speaker']
            text = turn['text']
            if 'Wizard' in speaker:
                try:
                    # Knowledge is the first value of the mapping, if any.
                    knowledge = next(iter(turn['checked_sentence'].values()))
                except (KeyError, AttributeError, StopIteration):
                    knowledge = ''
                examples.append({
                    'Context': ' EOS '.join(history),
                    'Knowledge': knowledge,
                    'Response': text.strip(),
                })
            history.append(text.strip() if strip_history else text)
        return examples

    def process(self):
        """Write the sampled training subset and the full valid/test splits."""
        train_data = self._load_json(f'{self.filepath}/train.json')

        topic_data = {}
        for dialog in train_data:
            chosen_topic = dialog['chosen_topic']
            if chosen_topic not in topic_data:
                # NOTE(review): the first dialogue of each topic only creates
                # the bucket and is not stored in it.  Kept unchanged so the
                # sampled training subset stays the same - confirm whether
                # this is intended.
                topic_data[chosen_topic] = []
            else:
                topic_data[chosen_topic].append(
                    (dialog['persona'], dialog['dialog']))

        # Topics ordered by number of stored dialogues, largest first.
        topic_data_sorted = sorted(
            topic_data.items(), key=lambda k: -len(k[1]))

        # Training subset: every second topic among ranks 1..99, first stored
        # dialogue only.  The persona is not part of the training context.
        examples = []
        for _topic, dialogs in topic_data_sorted[1:100:2]:
            for _persona, dialog in dialogs[:1]:
                examples.extend(self._build_examples(dialog, [], True))

        with jsonlines.open('../data/wow/wow_train.jsonl', mode='w') as writer:
            for example in examples:
                writer.write(example)

        for split in ['valid', 'test']:
            data = self._load_json(f'{self.filepath}/{split}_random_split.json')
            examples = []
            for dialog in data:
                # Here the persona opens the context and history entries are
                # appended without stripping.
                history = [dialog['persona']]
                examples.extend(
                    self._build_examples(dialog['dialog'], history, False))

            with jsonlines.open(f'../data/wow/wow_{split}.jsonl', mode='w') as writer:
                for example in examples:
                    writer.write(example)

        return super().process()
128
+
129
+
130
class WoIConverter(Converter):
    """Convert Wizard of the Internet dialogues to Context/Knowledge/Response records.

    ``self.filepath`` is a directory containing ``train.jsonl``,
    ``valid.jsonl`` and ``test.jsonl``.  Results are written to
    ``../data/woi/woi_{split}.jsonl``; only a random ~0.6% of the training
    examples is kept.
    """

    def process(self):
        """Read every split, build the examples and write them out."""
        for split in ['train', 'valid', 'test']:
            examples = []
            # The context manager closes the input once the split is read.
            with jsonlines.open(f'{self.filepath}/{split}.jsonl') as reader:
                for dialog in reader:
                    # Each record is a mapping; its first value is the dialogue.
                    data = list(dialog.values())[0]
                    persona = data['apprentice_persona']
                    history = [persona.replace('\n', ' ')]

                    for turn in data['dialog_history']:
                        action = turn['action']
                        if 'SearchAgent' in action:
                            continue

                        if action == 'Wizard => Apprentice':
                            contents = []
                            for content_ in turn['context']['contents']:
                                contents.extend(content_['content'])

                            selected = []
                            for selected_ in turn['context']['selected_contents']:
                                selected.extend(selected_)

                            # The first selection flag is skipped before the
                            # flags are paired with the flattened contents.
                            # NOTE(review): presumably that flag marks "nothing
                            # selected" - confirm against the dataset format.
                            knowledge = [
                                c for c, s in zip(contents, selected[1:]) if s]

                            text = turn['text'].strip()
                            examples.append({
                                'Context': ' EOS '.join(history),
                                'Knowledge': ' '.join(knowledge),
                                'Response': text,
                            })
                        else:
                            text = turn['text'].strip()
                        history.append(text)

            with jsonlines.open(f'../data/woi/woi_{split}.jsonl', mode='w') as writer:
                for example in examples:
                    # One random draw per training example; valid/test are
                    # written in full.  NOTE(review): no seed is set here, so
                    # the training subset differs between runs.
                    if split != 'train' or random.random() < 0.006:
                        writer.write(example)

        return super().process()
185
+
186
+
187
class CoQAConverter(Converter):
    """Convert CoQA seq2seq files to Context/Knowledge/Response records.

    For ``train`` and ``dev`` the files ``seq2seq-{split}-h2-src.txt`` and
    ``seq2seq-{split}-h2-tgt.txt`` are read from ``self.filepath``.  Results
    are written to ``../data/coqa/coqa_{train,valid}.jsonl``; only a random
    ~0.6% of the training examples is kept.
    """

    def process(self):
        """Pair source and target lines and write one record per pair."""
        for split in ['train', 'dev']:
            src_path = f'{self.filepath}/seq2seq-{split}-h2-src.txt'
            tgt_path = f'{self.filepath}/seq2seq-{split}-h2-tgt.txt'

            # Both files are opened up front and closed when the block ends.
            with open(src_path) as source, open(tgt_path) as target:
                sources = []
                for line in source:
                    line = line.strip()
                    if line != '':
                        # A source line is "<story>||<question>".
                        story, question = line.split('||')
                        sources.append((story, question))

                targets = [
                    line.strip() for line in target if line.strip() != '']

            # Blank lines are dropped from each file independently before
            # the remaining lines are paired by position.
            examples = [
                {'Context': question, 'Response': response, 'Knowledge': story}
                for (story, question), response in zip(sources, targets)
            ]

            out_split = 'valid' if split == 'dev' else split
            with jsonlines.open(f'../data/coqa/coqa_{out_split}.jsonl', mode='w') as writer:
                for example in examples:
                    # One random draw per training example; dev is written in
                    # full.  NOTE(review): no seed is set here, so the
                    # training subset differs between runs.
                    if out_split != 'train' or random.random() < 0.006:
                        writer.write(example)

        return super().process()
222
+
223
+
224
class MultiWOZConverter(Converter):
    """Convert MultiWOZ dialogues to Context/Knowledge/Response records.

    ``self.filepath`` is a directory containing ``train.json``, ``val.json``
    and ``test.json``.  Results are written to
    ``../data/multiwoz/multiwoz_{train,valid,test}.jsonl``; only a random
    ~0.6% of the training examples is kept.
    """

    @staticmethod
    def _describe_kb(count):
        """Turn a turn's ``KB`` value into ``'kb zero'`` ... ``'kb more than two'``."""
        # The count itself is compared (not the text being built), so the
        # values 1 and 2 get their own wording.
        if count == 0:
            return 'kb zero'
        if count == 1:
            return 'kb one'
        if count == 2:
            return 'kb two'
        return 'kb more than two'

    def process(self):
        """Build one example per turn and write every split."""
        for split in ['train', 'val', 'test']:
            with open(f'{self.filepath}/{split}.json') as handle:
                data = json.load(handle)

            examples = []
            for dialog in data:
                history = []
                for turn in dialog['info']:
                    history.append(turn['user_orig'])

                    # 'BS' maps a domain to (slot, value) pairs; rendered as
                    # "domain slot = value ; ..." with domains joined by ' | '.
                    bs_str = ' | '.join(
                        domain + ' ' + ' ; '.join(
                            state[0] + ' = ' + state[1] for state in states)
                        for domain, states in turn['BS'].items()
                    )
                    db_str = self._describe_kb(turn['KB'])
                    act_seq = ' '.join(turn['act'].keys())

                    examples.append({
                        'Context': ' EOS '.join(history),
                        'Knowledge': bs_str + ' | ' + db_str,
                        'Response': act_seq + ' | ' + turn['sys'].strip(),
                    })
                    history.append(turn['sys'].strip())

            out_split = 'valid' if split == 'val' else split
            with jsonlines.open(f'../data/multiwoz/multiwoz_{out_split}.jsonl', mode='w') as writer:
                for example in examples:
                    # One random draw per training example; the other splits
                    # are written in full.  NOTE(review): no seed is set
                    # here, so the training subset differs between runs.
                    if out_split != 'train' or random.random() < 0.006:
                        writer.write(example)

        return super().process()
277
+
278
+
279
def convert(class_name, file_path):
    """Run the converter class named ``class_name`` on ``file_path``.

    The name is looked up among this module's globals rather than evaluated
    as code, so a command-line argument cannot execute arbitrary
    expressions.

    Args:
        class_name: Name of a concrete ``Converter`` subclass defined in this
            module, e.g. ``'MultiWOZConverter'``.
        file_path: Input location handed to the converter's constructor.

    Raises:
        ValueError: If ``class_name`` does not name a ``Converter`` subclass.
    """
    candidate = globals().get(class_name) if isinstance(class_name, str) else None
    is_converter = (
        isinstance(candidate, type)
        and issubclass(candidate, Converter)
        and candidate is not Converter
    )
    if not is_converter:
        raise ValueError(f'Unknown converter: {class_name!r}')
    candidate(file_path).convert()
281
+
282
+
283
def main():
    """Command-line entry point: hand ``convert`` over to python-fire."""
    cli_target = convert
    fire.Fire(cli_target)


if __name__ == '__main__':
    main()
scripts/grounded_converter.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright (c) Microsoft Corporation.
4
+ # Licensed under the MIT license.
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import jsonlines
9
+ import json
10
+ import copy
11
+ import glob
12
+ import random
13
+ import fire
14
+
15
+
16
class Converter(ABC):
    """Abstract base class shared by all dataset converters.

    Concrete subclasses implement ``process``; ``convert`` wraps that call
    with a start and a finish message.
    """

    def __init__(self, filepath) -> None:
        """Store ``filepath``, the input location used by ``process``."""
        super().__init__()
        self.filepath = filepath

    def convert(self):
        """Run the conversion: ``start``, then ``process``, then ``end``."""
        self.start()
        self.process()
        self.end()

    def start(self):
        """Print the message announcing that processing begins."""
        label = self.__class__.__name__
        print(f'Start processing {label} at {self.filepath}')

    def end(self):
        """Print the message announcing that processing is finished."""
        label = self.__class__.__name__
        print(f'Finish processing {label} at {self.filepath}')

    @abstractmethod
    def process(self):
        """Perform the actual conversion; subclasses must override this."""
44
+
45
+
46
class DSTC7Converter(Converter):
    """Converter class for DSTC7 grounded response generation.

    ``self.filepath`` is a tab-separated file with six columns per line, of
    which the last three (facts, context, response) are used.  Results are
    written to ``../data/dstc7.jsonl``.
    """

    def process(self):
        """Build one example per usable line and write them out."""
        examples = []
        # The context manager closes the input file after reading.
        with open(self.filepath) as convs:
            for conv in convs:
                _, _c_id, _score, facts, context, response = conv.split('\t')
                # Lines whose context is only the START marker are skipped.
                if context.strip() == 'START':
                    continue
                context = context.replace('START EOS TIL ', '')
                knowledge = facts.replace(
                    ' < p > ', '').replace(' < /p > ', '').strip()
                examples.append({
                    'Context': context.strip(),
                    'Knowledge': knowledge,
                    'Response': response.strip(),
                })

        with jsonlines.open('../data/dstc7.jsonl', mode='w') as writer:
            for example in examples:
                writer.write(example)

        return
73
+
74
+
75
class MSMARCOConverter(Converter):
    """Converter class for MS MARCO.

    ``self.filepath`` is a JSON file with ``query``, ``answers`` and
    ``passages`` mappings that share the same ids.  Results are written to
    ``../data/msmarco.jsonl``.
    """

    def process(self):
        """Build one example per query id and write them out."""
        # The context manager closes the input file once it is parsed.
        with open(self.filepath) as handle:
            train_data = json.load(handle)

        examples = []
        for ids in train_data['query'].keys():
            query = train_data['query'][ids]
            answer = train_data['answers'][ids]
            passage = train_data['passages'][ids]
            # Only the passages flagged as selected form the knowledge.
            knowledge = [p['passage_text'] for p in passage if p['is_selected']]
            examples.append({
                'Context': query.strip(),
                'Knowledge': ' '.join(knowledge),
                'Response': ' '.join(answer).strip(),
            })

        with jsonlines.open('../data/msmarco.jsonl', mode='w') as writer:
            for example in examples:
                writer.write(example)

        return
100
+
101
+
102
class UnifiedQAConverter(Converter):
    """Converter class for UnifiedQA-style TSV files.

    Every file matching ``{self.filepath}/*/*`` whose path contains
    ``train.tsv`` or ``test.tsv`` is read.  Results are written to
    ``../data/unifiedqa.jsonl``.
    """

    def process(self):
        """Build one example per well-formed line and write them out."""
        examples = []
        for fname in glob.glob(f'{self.filepath}/*/*'):
            if 'train.tsv' not in fname and 'test.tsv' not in fname:
                continue
            with open(fname) as data:
                for line in data:
                    line = line.strip()
                    try:
                        question, answer = line.split('\t')
                        # The separator is the two characters backslash + n.
                        question, story = question.split('\\n')
                    except ValueError:
                        # Lines without exactly one tab and one separator
                        # are skipped.
                        continue
                    examples.append({
                        'Context': question,
                        'Response': answer,
                        'Knowledge': story,
                    })

        # The context manager flushes and closes the output file.
        with jsonlines.open('../data/unifiedqa.jsonl', mode='w') as train_writer:
            for example in examples:
                train_writer.write(example)

        return
129
+
130
+
131
class SGDConverter(Converter):
    """Converter class for the SGD dataset.

    For ``train``, ``dev`` and ``test`` every
    ``{self.filepath}/{split}/dialogues_*.json`` file is read.  One example
    is produced per non-user turn; all examples are written to
    ``../data/sgd.jsonl``.
    """

    @staticmethod
    def _load_json(path):
        """Parse one JSON file and close its handle afterwards."""
        with open(path) as handle:
            return json.load(handle)

    @staticmethod
    def _delexicalize(utterance, slots):
        """Replace each annotated slot span of ``utterance`` by ``[slot_name]``.

        Args:
            utterance: The original utterance text.
            slots: Span annotations with ``start``, ``exclusive_end`` and
                ``slot`` keys, addressing ``utterance``.

        Returns:
            The utterance with every span substituted.
        """
        # Spans are replaced from right to left so the offsets of the spans
        # still to be processed stay valid.  Substituting by position works
        # for any number of slots and leaves other text untouched.
        for slot_info in sorted(slots, key=lambda s: s['start'], reverse=True):
            start, end = slot_info['start'], slot_info['exclusive_end']
            utterance = f"{utterance[:start]}[{slot_info['slot']}]{utterance[end:]}"
        return utterance

    def process(self):
        """Build the examples of all splits and write them to one file."""
        examples = []
        for split in ['train', 'dev', 'test']:
            for file in glob.glob(f'{self.filepath}/{split}/dialogues_*.json'):
                for dialogue in self._load_json(file):
                    history = []
                    for idx, turn in enumerate(dialogue['turns']):
                        if idx == 0:
                            assert turn['speaker'] == 'USER'
                        # Only the first frame of a turn is used.
                        frame = turn['frames'][0]
                        service = frame['service'].split('_')[0].lower()

                        if turn['speaker'] == 'USER':
                            history.append(turn['utterance'])
                            # Render "slot = first value", joined by ' ; '.
                            slot_values_str = ' ; '.join(
                                f'{slot} = {values[0]}'
                                for slot, values
                                in frame['state']['slot_values'].items()
                            )
                        else:
                            # Reuses the state text of the preceding user turn.
                            slot_values_str = f'belief : {service} {slot_values_str}'
                            reply = self._delexicalize(
                                turn['utterance'], frame['slots'])
                            examples.append({
                                'Context': ' EOS '.join(history),
                                'Knowledge': slot_values_str,
                                'Response': reply,
                            })
                            history.append(reply)

        # The context manager flushes and closes the output file.
        with jsonlines.open('../data/sgd.jsonl', mode='w') as train_writer:
            for example in examples:
                train_writer.write(example)

        return
206
+
207
+
208
def merge_and_split():
    """Merge the four converted corpora and split them into train/valid.

    Reads ``dstc7.jsonl``, ``msmarco.jsonl``, ``sgd.jsonl`` and
    ``unifiedqa.jsonl`` from ``../data`` in that order.  With the random
    seed fixed to 2021, each record goes to
    ``../data/grounded_data_valid.jsonl`` with probability 0.01 and to
    ``../data/grounded_data_train.jsonl`` otherwise.
    """
    examples = []
    for corpus in ('dstc7', 'msmarco', 'sgd', 'unifiedqa'):
        with open(f'../data/{corpus}.jsonl', "r", encoding="utf-8") as reader:
            examples.extend(jsonlines.Reader(reader))

    random.seed(2021)
    # Both writers are closed by the ``with`` block so the output is flushed.
    with jsonlines.open('../data/grounded_data_train.jsonl', mode='w') as train_writer, \
            jsonlines.open('../data/grounded_data_valid.jsonl', mode='w') as valid_writer:
        for example in examples:
            if random.random() < 0.01:
                valid_writer.write(example)
            else:
                train_writer.write(example)

    print('Done!')
243
+
244
+
245
def process(msmarco_path, sgd_path, dstc7_path, unified_qa_path):
    """Run the four grounded-data converters one after another.

    Args:
        msmarco_path: Directory containing ``train_v2.1.json``.
        sgd_path: Location passed to ``SGDConverter``.
        dstc7_path: Location passed to ``DSTC7Converter``.
        unified_qa_path: Location passed to ``UnifiedQAConverter``.
    """
    jobs = (
        (MSMARCOConverter, f'{msmarco_path}/train_v2.1.json'),
        (SGDConverter, f'{sgd_path}'),
        (DSTC7Converter, f'{dstc7_path}'),
        (UnifiedQAConverter, unified_qa_path),
    )
    for converter_cls, location in jobs:
        converter_cls(location).convert()
255
+
256
+
257
+ def main():
258
+ fire.Fire(process)
259
+ # merge generated data and split it into train and valid
260
+ merge_and_split()
261
+
262
+
263
+ if __name__ == '__main__':
264
+ main()