alistairmcleay's picture
Added dialogue system code
b16a132
import json
import os
import sys
from tqdm import tqdm
from crazyneuraluser.user_model_code.analysis_multiwoz import DATA_SPLIT, collect_data
from crazyneuraluser.user_model_code.utils_multiwoz import (
get_act_natural_language,
get_original_act_set,
load_schema,
normalise_intent,
normalise_slot,
normalise_value,
)
from crazyneuraluser.user_model_code.utils_sgd import (
add_str,
compare_slot_values_in_state,
conv_special_token,
dict2list,
get_special_tokens,
wrap_element,
)
""" pre-process script for MultiWOZ v2.2 """
class DialMetaData:
def __init__(self, dial_id, dial_meta, dial_act, unify_act):
self.dial_id = dial_id
self.unify_act = unify_act
self.turn_meta_list, self.scenario = self.parse(
dial_meta, dial_act
) # None for system turn
self.linearise_turns()
def parse(self, dial_meta, dial_act):
global n, act_intent, non_intent
assert len(dial_meta["turns"]) == len(dial_act)
turn_meta_list = []
scenario = []
sys_turn = None # dummy sys turn for first usr turn
prev_intent = ""
prev_usr_turn, prev_usr_turn_meta = (
None,
None,
) # dummpy for tracing goal change at first turn
for turn_id, turn in enumerate(dial_meta["turns"]):
assert turn_id == int(turn["turn_id"])
if turn["speaker"] == "SYSTEM":
sys_turn = turn
turn_meta_list.append(None)
continue
# init turn meta
turn_meta = TurnMetaData(
prev_intent, sys_turn, turn, self.dial_id, self.unify_act
)
# get goal change label
turn_meta.get_goal_change_label(prev_usr_turn, prev_usr_turn_meta)
# update previous goal
for prev_turn_meta in reversed(turn_meta_list):
if prev_turn_meta is None:
continue
prev_turn_meta.accumulate_constraints(turn_meta) # TODO: check goal
# record task (intent) in scenario
if turn_meta.usr_intent not in scenario:
scenario.append(turn_meta.usr_intent)
turn_meta_list.append(turn_meta)
prev_intent = turn_meta.usr_intent
prev_usr_turn, prev_usr_turn_meta = turn, turn_meta
assert len(turn_meta_list) == len(dial_meta["turns"])
return turn_meta_list, scenario
def linearise_turns(self):
# linearise necessary meterials
for turn_meta in self.turn_meta_list:
if turn_meta is None:
continue
turn_meta._linearise(self.scenario, SERVICE2META)
class TurnMetaData:
def __init__(self, prev_intent, sys_turn, usr_turn, dial_id, unify_act):
self.dial_id = dial_id
self.unify_act = unify_act
self.original_act_set = get_original_act_set() # act set w/o domain information
self.sys_turn, self.usr_turn = sys_turn, usr_turn
# turn id
self.sys_turn_id, self.usr_turn_id = self._get_turn_id(sys_turn, usr_turn)
# intent
self.usr_intent = normalise_intent(self._get_intent(usr_turn, prev_intent))
if remove_book_intent:
self.usr_intent = self.usr_intent.replace("book", "find")
assert self.usr_intent in INTENTS # or self.usr_intent == "temp temp"
self.service = self.usr_intent.split()[1]
# utterances
self.utt = {}
self.utt["sys"], self.utt["usr"] = self._get_utt(sys_turn), self._get_utt(
usr_turn
)
# act
self.act2sv = {}
self.act2sv["sys"], _ = self._parse_action(self.sys_turn_id, self.sys_turn)
self.act2sv["usr"], self.usr_constraints = self._parse_action(
self.usr_turn_id, self.usr_turn
)
# task boundary
self._get_new_task_label(prev_intent)
# req_alts
self._get_req_alts_label()
def _get_turn_id(self, sys_turn, usr_turn):
usr_turn_id = int(usr_turn["turn_id"]) # 0, 2, 4 ...
sys_turn_id = int(sys_turn["turn_id"]) if sys_turn is not None else -1
assert sys_turn_id == (usr_turn_id - 1)
return sys_turn_id, usr_turn_id
def _get_utt(self, turn):
if turn is None:
return ""
return turn["utterance"]
def accumulate_constraints(self, new_turn_meta):
"""
Add slot, slot-value pairs from a given following turn
This function forms the user goal by accumulating constraints backward
"""
# only accumulate constraints with the same task/intent
if new_turn_meta.usr_intent != self.usr_intent:
return
if (
new_turn_meta.goal_change
): # if goal changes at a new turn, these constraints should not be put in previous turns
return
# only accumulate constraints without goal change
# if the value of a slot is changed (goal change) in a new turn,
# this slot-value pair is not part of initial goal and should not be added into the goal of previous turns
new_constraints = new_turn_meta.usr_constraints
self.usr_constraints["requestable"] = self.usr_constraints["requestable"].union(
new_constraints["requestable"]
)
for slot, value_list in new_constraints["informable"].items():
if slot not in self.usr_constraints["informable"]:
self.usr_constraints["informable"][slot] = value_list
def get_goal_change_label(self, prev_usr_turn, prev_turn_meta):
"""check if goal changed (value of slot changes) between two turn states"""
# first usr turn
if prev_usr_turn is None:
assert self.usr_turn_id == 0
self.goal_change = False
return
# last usr turn
if "GOODBYE" in self.act2sv["usr"] or "THANK_YOU" in self.act2sv["usr"]:
self.goal_change = False
return
assert self.usr_turn_id != 0
assert prev_usr_turn["speaker"] == "USER"
# new task
if self.usr_intent != prev_turn_meta.usr_intent:
self.goal_change = False
return
# compare two states to obtain goal change flag
curr_state, prev_state = None, None
for frame in self.usr_turn["frames"]:
if frame["service"] == self.service:
curr_state = frame["state"]["slot_values"]
for frame in prev_usr_turn["frames"]:
if frame["service"] == prev_turn_meta.service:
prev_state = frame["state"]["slot_values"]
# check if slot value has changed at current turn (new slot is not counted)
assert curr_state is not None and prev_state is not None
self.goal_change = compare_slot_values_in_state(curr_state, prev_state)
def _get_domain_from_act(self, dialogue_act):
"""
parse the raw dialouge act annotation to get domain info
number of doamin can be more than 1 for multi-domain turns
"""
domains = set()
book_flag = False
for dact, sv_pairs in dialogue_act.items():
assert "-" in dact
domain, _ = dact.split("-")
if domain not in ["Booking", "general"]:
domains.add(domain)
for slot, value in sv_pairs:
if "book" in slot: # e.g., bookday
book_flag = True
return domains, book_flag
def _get_intent(self, usr_turn, prev_intent):
intents = []
for frame in usr_turn["frames"]:
# service = frame["service"]
intent = frame["state"]["active_intent"]
if intent != "NONE":
intents.append(intent)
if len(intents) == 1:
intent = intents[0]
if intent == "find_taxi":
intent = "book_taxi"
return intent # tackle 51.5k out of 71.5k user turns
# if above fails (e.g., due to wrong label), leverage usr act to help determine main intent/service
# possible domains in da: {'Hospital', 'Taxi', 'Train', 'Police', 'Restaurant', 'Booking', 'general',
# 'Attraction', 'Hotel'}
usr_act = data_act[self.dial_id][str(self.usr_turn_id)]["dialog_act"]
domains, book_flag = self._get_domain_from_act(usr_act)
if len(domains) == 1:
domain = list(domains)[0].lower()
if book_flag and domain in ["restaurant", "hotel", "train"]:
intent = "book_{}".format(domain)
elif domain == "taxi":
intent = "book_{}".format(domain)
else:
intent = "find_{}".format(domain)
return intent # tackle 58.1k out of 71.5k user turns
if "Taxi" in domains:
return "book_taxi" # tackle 58.8k out of 71.5k user turns
if (
self.usr_turn_id == 0
): # wrong label at first turn, no previous intent to use, only 136 user turns here
utt = usr_turn["utterance"]
if (
"restaurant" in utt
or "Restaurant" in utt
or "eat" in utt
or "din" in utt
):
return "find_restaurant"
elif (
"hotel" in utt
or "room" in utt
or "house" in utt
or "stay" in utt
or "live" in utt
):
return "find_hotel"
else:
return "find_attraction" # tackle 58.9k out of 71.5k user turns
else: # not first turn, leverage sys act to help decide intent
sys_act = data_act[self.dial_id][str(self.sys_turn_id)]["dialog_act"]
sys_domains, _ = self._get_domain_from_act(sys_act)
if len(sys_domains) == 1:
domain = list(sys_domains)[0].lower()
if book_flag and domain in ["restaurant", "hotel", "train"]:
intent = "book_{}".format(domain)
elif domain == "taxi":
intent = "book_{}".format(domain)
else:
intent = "find_{}".format(domain)
return intent # tackle 67.3k out of 71.5k user turns
# two cases left enter here
# 1. turns with only general act, e.g., bye
# 2. turns have multiple intents (very few)
# both will be handled using previous intent
assert prev_intent != ""
intent = "_".join(
prev_intent.split()
) # as prev_intent has been normalised already
return intent
def _parse_action(self, turn_id, turn):
"""parse the `dialog_act` field in `dialog_acts.json`
Returns:
act2sv: act to slot value pairs, {act=sv}; sv: slot to value list, {slot=[v1, v2]}
"""
act2sv = dict()
constraints = {"informable": dict(), "requestable": set()}
if turn is None:
return None, constraints
# get da from data_act
dialogue_act = data_act[self.dial_id][str(turn_id)]["dialog_act"]
# domains = set()
for dact, svs in dialogue_act.items():
assert "-" in dact
if self.unify_act: # will use only act part without domain info
domain, act = dact.split(
"-"
) # split `domain-act`, e.g., `hotel-inform` -> hotel, inform
else: # keep original mwoz act
act = dact # use act with domain info
if self.unify_act:
# unify act: `Booking-Inform` with no args is equivalent to `OfferBook` in train domain
if dact == "Booking-Inform" and svs == [["none", "none"]]:
act = "OfferBook"
# deal with act
if self.unify_act:
assert act in self.original_act_set
if turn["speaker"] == "USER":
assert act in ["Inform", "Request", "bye", "thank", "greet"]
act = get_act_natural_language(act)
if act not in act2sv:
act2sv[act] = dict()
# iterate slot value pairs
for slot, value in svs:
slot = normalise_slot(slot)
value = normalise_value(value)
# act to slot value pairs
# NOTE: same slot might appear more than once per turn, e.g., when the system informs two hotels with
# their addresses so a value list is stored for each slot
if slot not in act2sv[act]:
act2sv[act][slot] = []
act2sv[act][slot].append(value)
# collect constraints
if act in ["REQUEST", "Request", "request"]:
constraints["requestable"].add(slot)
else:
if slot != "Empty":
if (
slot not in constraints["informable"]
): # NOTE: same reason as act, value list per slot
constraints["informable"][slot] = []
constraints["informable"][slot].append(value)
return act2sv, constraints
def _linearise(self, scenario, service2meta):
self.linear_act = {}
self.linear_act["sys"] = self._linearise_act(self.act2sv["sys"])
self.linear_act["usr"] = self._linearise_act(self.act2sv["usr"])
self.linear_goal = self._linearise_goal(
self.usr_constraints, scenario, service2meta
)
def _linearise_goal(self, constraints, scenario, service2meta):
"""
linearise goal representation which consists of several parts:
scenario, task (intent), task description, constraints with informable and requestable
e.g., <SCENARIO/> task1 task2 .. </SCENARIO>
<TASK/> current task </TASK> <DESC/> task description </DESC>
<INFORM/> <SLOT/> slot1 </SLOT> <VALUE> value1 </VALUE> .. </INFORM>
<REQUEST/> <SLOT> slot1 </SLOT> <SLOT> slot2 </SLOT> .. </REQUEST>
"""
res = ""
# scenario
assert isinstance(scenario, list) and len(scenario) > 0
scenario = " ".join(
[wrap_element("INTENT", intent) for intent in scenario]
) # treat intent as nl
scenario_wrap = wrap_element("SCENARIO", scenario)
res = add_str(res, scenario_wrap)
# task name
intent = self.usr_intent
assert intent in scenario
intent_wrap = wrap_element("TASK", intent)
res = add_str(res, intent_wrap)
# task description
description = service2meta[self.service]["intents"][intent]["description"]
description_warp = wrap_element("DESC", description)
res = add_str(res, description_warp)
# informable
informable = dict2list(
constraints["informable"]
) # sorted sv pair list [slot=value]
res = add_str(res, "<INFORM/>")
for sv_pair in informable:
slot, value = sv_pair.split("=")
if value in ["True", "False", "Empty"]:
value = conv_special_token(value, SPECIAL_TOKENS)
if slot in ["Empty"]:
slot = conv_special_token(slot, SPECIAL_TOKENS)
# slot
slot_wrap = wrap_element("SLOT", slot)
res = add_str(res, slot_wrap)
# value
value_wrap = wrap_element("VALUE", value)
res = add_str(res, value_wrap)
res = add_str(res, "</INFORM>")
# requestable
requestable = sorted(
list(constraints["requestable"])
) # sorted slot list [slot]
res = add_str(res, "<REQUEST/>")
for slot in requestable:
slot_wrap = wrap_element("SLOT", slot)
res = add_str(res, slot_wrap)
res = add_str(res, "</REQUEST>")
return res[1:] # remove first space
def _linearise_act(self, act2sv):
"""
NOTE: 1) split slot/value if "_"; 2) special tokens of acts; 3) empty slot or empty value
NOTE: filer too many values (e.g., 10 movie names) but make sure the one the user chose is present
Return: ordered (slots sorted within act, acts sorted) linearised act sequence,
e.g., <ACT/> <INFORM> </ACT> <SLOT/> area </SLOT> <VALUE/> Cambridge </VALUE> ...
e.g., <ACT/> <REQUEST> </ACT> <SLOT/> _Empty_ </SLOT> <VALUE/> _Empty_ </VALUE>
"""
res = ""
if act2sv is None:
return res
for act in sorted(act2sv.keys()): # sort act
sv = act2sv[act] # dict{slot: value_list}
act_wrap = wrap_element("ACT", act)
res = add_str(res, act_wrap)
sorted_sv = dict2list(
sv
) # sorted sv list, [s1=v1, s2=v2], note slot can repeat
for sv_pair in sorted_sv:
slot, value = sv_pair.split("=")
if value in ["True", "False", "Empty"]:
value = conv_special_token(value, SPECIAL_TOKENS)
if slot in ["Empty"]:
slot = conv_special_token(slot, SPECIAL_TOKENS)
# slot
slot_wrap = wrap_element("SLOT", slot)
res = add_str(res, slot_wrap)
# value
value_wrap = wrap_element("VALUE", value)
res = add_str(res, value_wrap)
return res[1:] # remove first space
def _get_new_task_label(self, prev_intent):
"""
get a binary label indicating if a turn starts a new task (intent) in dialogue
"""
assert prev_intent != "NONE" and self.usr_intent != "NONE"
if self.usr_intent != prev_intent:
self.start_new_task = True
else:
self.start_new_task = False
def _get_req_alts_label(self):
self.req_alts = False # no request alternative in mwoz
def collect_examples(dial_id, dial_meta, examples):
num = 0
examples[dial_id] = {}
for turn_meta in dial_meta.turn_meta_list:
if turn_meta is None: # sys turn
continue
example_id = "{}-{}".format(dial_id, num)
example = {
"utterances": turn_meta.utt,
"actions": turn_meta.linear_act,
"goal": turn_meta.linear_goal,
"service": turn_meta.service,
"intent": turn_meta.usr_intent,
"goal_change": turn_meta.goal_change,
"start_new_task": turn_meta.start_new_task,
"req_alts": turn_meta.req_alts,
}
examples[dial_id][example_id] = example
num += 1
def prepare_data_seq(unify_act, out_data_path):
for split in DATA_SPLIT:
examples = {}
for dial_num, dial_id in enumerate(tqdm(sorted(data[split].keys()))):
dial = data[split][dial_id]
dial_act = data_act[dial_id]
dial_meta = DialMetaData(dial_id, dial, dial_act, unify_act)
collect_examples(dial_id, dial_meta, examples)
with open("{}/{}.json".format(out_data_path, split), "w") as f:
json.dump(examples, f, sort_keys=True, indent=4)
print("Done process {} {} dialogues".format(split, len(examples)))
if __name__ == "__main__":
if len(sys.argv) == 1:
print("Wrong argument!")
print("usage: python utils/preprocess_multiwoz.py multiwoz2.2-data-path")
sys.exit(1)
# Set data path
data_path = sys.argv[1]
out_data_path = "./data/preprocessed/user_model"
os.makedirs(out_data_path, exist_ok=True)
# Control flags
unify_act = True
remove_book_intent = True
# Load data and material as global var
SERVICE2META, INTENTS, SLOTS = load_schema(os.path.join(data_path, "schema.json"))
SPECIAL_TOKENS = get_special_tokens()
data, data_act = collect_data(data_path, remove_dial_switch=False)
prepare_data_seq(unify_act, out_data_path)