|
|
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
|
import jsonlines |
|
import json |
|
import copy |
|
import random |
|
import fire |
|
|
|
|
|
class Converter(ABC): |
|
|
|
def __init__(self, filepath) -> None: |
|
super().__init__() |
|
|
|
self.filepath = filepath |
|
|
|
def convert(self): |
|
""" |
|
Implement your convert logics in this function |
|
""" |
|
self.start() |
|
self.process() |
|
self.end() |
|
pass |
|
|
|
def start(self): |
|
print(f'Start processing {self.__class__.__name__} at {self.filepath}') |
|
|
|
def end(self): |
|
print( |
|
f'Finish processing {self.__class__.__name__} at {self.filepath}') |
|
|
|
@abstractmethod |
|
def process(self): |
|
""" |
|
Implement your convert logics in this function |
|
""" |
|
|
|
|
|
class WoWConverter(Converter): |
|
|
|
def process(self): |
|
|
|
train_data = json.load(open(f'{self.filepath}/train.json')) |
|
topic_data = {} |
|
for i in train_data: |
|
chosen_topic = i['chosen_topic'] |
|
if not chosen_topic in topic_data.keys(): |
|
topic_data[chosen_topic] = [] |
|
else: |
|
topic_data[chosen_topic].append((i['persona'], i['dialog'])) |
|
|
|
topic_data_sorted = sorted( |
|
topic_data.items(), key=lambda k: -len(k[1])) |
|
|
|
examples = [] |
|
for topic, dialogs in topic_data_sorted[1:100:2]: |
|
for persona, dialog in dialogs[:1]: |
|
history = [persona] |
|
history = [] |
|
example = {} |
|
checked_sentence = '' |
|
for i in dialog: |
|
speaker = i['speaker'] |
|
text = i['text'] |
|
if 'Wizard' in speaker: |
|
|
|
try: |
|
checked_sentence = next( |
|
iter(i['checked_sentence'].values())) |
|
except Exception: |
|
checked_sentence = '' |
|
response = text |
|
example['Context'] = ' EOS '.join(history) |
|
example['Knowledge'] = checked_sentence |
|
example['Response'] = response.strip() |
|
examples.append(copy.deepcopy(example)) |
|
example = {} |
|
else: |
|
text = text |
|
history.append(text.strip()) |
|
|
|
with jsonlines.open('../data/wow/wow_train.jsonl', mode='w') as writer: |
|
for i in examples: |
|
writer.write(i) |
|
|
|
for split in ['valid', 'test']: |
|
data = json.load( |
|
open(f'{self.filepath}/{split}_random_split.json')) |
|
examples = [] |
|
for dialog in data: |
|
history = [] |
|
example = {} |
|
checked_sentence = '' |
|
persona = dialog['persona'] |
|
history = [persona] |
|
for i in dialog['dialog']: |
|
speaker = i['speaker'] |
|
text = i['text'] |
|
if 'Wizard' in speaker: |
|
try: |
|
checked_sentence = next( |
|
iter(i['checked_sentence'].values())) |
|
except Exception: |
|
checked_sentence = '' |
|
|
|
text = text |
|
response = text |
|
example['Context'] = ' EOS '.join(history) |
|
example['Knowledge'] = checked_sentence |
|
example['Response'] = response.strip() |
|
examples.append(copy.deepcopy(example)) |
|
example = {} |
|
else: |
|
text = text |
|
history.append(text) |
|
|
|
with jsonlines.open(f'../data/wow/wow_{split}.jsonl', mode='w') as writer: |
|
for i in examples: |
|
writer.write(i) |
|
|
|
return super().process() |
|
|
|
|
|
class WoIConverter(Converter): |
|
|
|
def process(self): |
|
for split in ['train', 'valid', 'test']: |
|
reader = jsonlines.open(f'{self.filepath}/{split}.jsonl') |
|
examples = [] |
|
num_of_dialogs = 0 |
|
for dialog in reader: |
|
num_of_dialogs += 1 |
|
example = {} |
|
history = [] |
|
turn = '' |
|
data = list(dialog.values())[0] |
|
persona = data['apprentice_persona'] |
|
history = [persona.replace('\n', ' ')] |
|
|
|
for i in data['dialog_history']: |
|
if 'SearchAgent' in i['action']: |
|
continue |
|
|
|
else: |
|
if i['action'] == 'Wizard => Apprentice': |
|
|
|
contents = [] |
|
selected = [] |
|
|
|
for content_ in i['context']['contents']: |
|
contents.extend(content_['content']) |
|
|
|
for selected_ in i['context']['selected_contents']: |
|
selected.extend(selected_) |
|
|
|
knowledge = [] |
|
for c, s in zip(contents, selected[1:]): |
|
if s: |
|
knowledge.append(c) |
|
|
|
turn = i['text'].strip() |
|
example['Context'] = ' EOS '.join(history) |
|
example['Knowledge'] = ' '.join(knowledge) |
|
example['Response'] = turn.strip() |
|
examples.append(copy.deepcopy(example)) |
|
else: |
|
turn = i['text'].strip() |
|
history.append(turn) |
|
|
|
with jsonlines.open(f'../data/woi/woi_{split}.jsonl', mode='w') as writer: |
|
for i in examples: |
|
if split == 'train': |
|
if random.random() < 0.006: |
|
writer.write(i) |
|
else: |
|
writer.write(i) |
|
|
|
return super().process() |
|
|
|
|
|
class CoQAConverter(Converter): |
|
|
|
def process(self): |
|
|
|
for split in ['train', 'dev']: |
|
source = open(f'{self.filepath}/seq2seq-{split}-h2-src.txt') |
|
target = open(f'{self.filepath}/seq2seq-{split}-h2-tgt.txt') |
|
|
|
source_ = [] |
|
for line in source: |
|
if line.strip() != '': |
|
sotry, question = line.strip().split('||') |
|
source_.append((sotry, question)) |
|
|
|
target_ = [] |
|
for line in target: |
|
if line.strip() != '': |
|
target_.append(line.strip()) |
|
examples = [] |
|
for context, response in zip(source_, target_): |
|
story, question = context |
|
examples.append( |
|
{'Context': question, 'Response': response, 'Knowledge': story}) |
|
|
|
if split == 'dev': |
|
split = 'valid' |
|
with jsonlines.open(f'../data/coqa/coqa_{split}.jsonl', mode='w') as writer: |
|
for i in examples: |
|
if split == 'train': |
|
if random.random() < 0.006: |
|
writer.write(i) |
|
else: |
|
writer.write(i) |
|
|
|
return super().process() |
|
|
|
|
|
class MultiWOZConverter(Converter): |
|
|
|
def process(self): |
|
|
|
for split in ['train', 'val', 'test']: |
|
data = json.load(open(f'{self.filepath}/{split}.json')) |
|
examples = [] |
|
for i in data: |
|
name = i['file'].lower() |
|
history = [] |
|
for turn in i['info']: |
|
history.append(turn['user_orig']) |
|
bs = turn['BS'] |
|
bs_str = [] |
|
for domain, states in bs.items(): |
|
domain_str = [] |
|
for state in states: |
|
domain_str.append(state[0] + ' = ' + state[1]) |
|
domain_str = ' ; '.join(domain_str) |
|
bs_str.append(domain + ' ' + domain_str) |
|
bs_str = ' | '.join(bs_str) |
|
|
|
db_str = 'kb ' |
|
db = turn['KB'] |
|
if db == 0: |
|
db_str += 'zero' |
|
elif db_str == 1: |
|
db_str += 'one' |
|
elif db_str == 2: |
|
db_str += 'two' |
|
else: |
|
db_str += 'more than two' |
|
|
|
act_seq = ' '.join(turn['act'].keys()) |
|
example = {} |
|
example['Context'] = ' EOS '.join(history[:]) |
|
example['Knowledge'] = bs_str + ' | ' + db_str |
|
example['Response'] = act_seq + ' | ' + turn['sys'].strip() |
|
|
|
history.append(turn['sys'].strip()) |
|
examples.append(copy.copy(example)) |
|
|
|
if split == 'val': |
|
split = 'valid' |
|
with jsonlines.open(f'../data/multiwoz/multiwoz_{split}.jsonl', mode='w') as writer: |
|
for i in examples: |
|
if split == 'train': |
|
if random.random() < 0.006: |
|
writer.write(i) |
|
else: |
|
writer.write(i) |
|
|
|
return super().process() |
|
|
|
|
|
def convert(class_name, file_path): |
|
eval(class_name)(file_path).convert() |
|
|
|
|
|
def main(): |
|
fire.Fire(convert) |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|