import json import os import sys from tqdm import tqdm from crazyneuraluser.user_model_code.analysis_multiwoz import DATA_SPLIT, collect_data from crazyneuraluser.user_model_code.utils_multiwoz import ( get_act_natural_language, get_original_act_set, load_schema, normalise_intent, normalise_slot, normalise_value, ) from crazyneuraluser.user_model_code.utils_sgd import ( add_str, compare_slot_values_in_state, conv_special_token, dict2list, get_special_tokens, wrap_element, ) """ pre-process script for MultiWOZ v2.2 """ class DialMetaData: def __init__(self, dial_id, dial_meta, dial_act, unify_act): self.dial_id = dial_id self.unify_act = unify_act self.turn_meta_list, self.scenario = self.parse( dial_meta, dial_act ) # None for system turn self.linearise_turns() def parse(self, dial_meta, dial_act): global n, act_intent, non_intent assert len(dial_meta["turns"]) == len(dial_act) turn_meta_list = [] scenario = [] sys_turn = None # dummy sys turn for first usr turn prev_intent = "" prev_usr_turn, prev_usr_turn_meta = ( None, None, ) # dummpy for tracing goal change at first turn for turn_id, turn in enumerate(dial_meta["turns"]): assert turn_id == int(turn["turn_id"]) if turn["speaker"] == "SYSTEM": sys_turn = turn turn_meta_list.append(None) continue # init turn meta turn_meta = TurnMetaData( prev_intent, sys_turn, turn, self.dial_id, self.unify_act ) # get goal change label turn_meta.get_goal_change_label(prev_usr_turn, prev_usr_turn_meta) # update previous goal for prev_turn_meta in reversed(turn_meta_list): if prev_turn_meta is None: continue prev_turn_meta.accumulate_constraints(turn_meta) # TODO: check goal # record task (intent) in scenario if turn_meta.usr_intent not in scenario: scenario.append(turn_meta.usr_intent) turn_meta_list.append(turn_meta) prev_intent = turn_meta.usr_intent prev_usr_turn, prev_usr_turn_meta = turn, turn_meta assert len(turn_meta_list) == len(dial_meta["turns"]) return turn_meta_list, scenario def linearise_turns(self): # linearise necessary meterials for turn_meta in self.turn_meta_list: if turn_meta is None: continue turn_meta._linearise(self.scenario, SERVICE2META) class TurnMetaData: def __init__(self, prev_intent, sys_turn, usr_turn, dial_id, unify_act): self.dial_id = dial_id self.unify_act = unify_act self.original_act_set = get_original_act_set() # act set w/o domain information self.sys_turn, self.usr_turn = sys_turn, usr_turn # turn id self.sys_turn_id, self.usr_turn_id = self._get_turn_id(sys_turn, usr_turn) # intent self.usr_intent = normalise_intent(self._get_intent(usr_turn, prev_intent)) if remove_book_intent: self.usr_intent = self.usr_intent.replace("book", "find") assert self.usr_intent in INTENTS # or self.usr_intent == "temp temp" self.service = self.usr_intent.split()[1] # utterances self.utt = {} self.utt["sys"], self.utt["usr"] = self._get_utt(sys_turn), self._get_utt( usr_turn ) # act self.act2sv = {} self.act2sv["sys"], _ = self._parse_action(self.sys_turn_id, self.sys_turn) self.act2sv["usr"], self.usr_constraints = self._parse_action( self.usr_turn_id, self.usr_turn ) # task boundary self._get_new_task_label(prev_intent) # req_alts self._get_req_alts_label() def _get_turn_id(self, sys_turn, usr_turn): usr_turn_id = int(usr_turn["turn_id"]) # 0, 2, 4 ... sys_turn_id = int(sys_turn["turn_id"]) if sys_turn is not None else -1 assert sys_turn_id == (usr_turn_id - 1) return sys_turn_id, usr_turn_id def _get_utt(self, turn): if turn is None: return "" return turn["utterance"] def accumulate_constraints(self, new_turn_meta): """ Add slot, slot-value pairs from a given following turn This function forms the user goal by accumulating constraints backward """ # only accumulate constraints with the same task/intent if new_turn_meta.usr_intent != self.usr_intent: return if ( new_turn_meta.goal_change ): # if goal changes at a new turn, these constraints should not be put in previous turns return # only accumulate constraints without goal change # if the value of a slot is changed (goal change) in a new turn, # this slot-value pair is not part of initial goal and should not be added into the goal of previous turns new_constraints = new_turn_meta.usr_constraints self.usr_constraints["requestable"] = self.usr_constraints["requestable"].union( new_constraints["requestable"] ) for slot, value_list in new_constraints["informable"].items(): if slot not in self.usr_constraints["informable"]: self.usr_constraints["informable"][slot] = value_list def get_goal_change_label(self, prev_usr_turn, prev_turn_meta): """check if goal changed (value of slot changes) between two turn states""" # first usr turn if prev_usr_turn is None: assert self.usr_turn_id == 0 self.goal_change = False return # last usr turn if "GOODBYE" in self.act2sv["usr"] or "THANK_YOU" in self.act2sv["usr"]: self.goal_change = False return assert self.usr_turn_id != 0 assert prev_usr_turn["speaker"] == "USER" # new task if self.usr_intent != prev_turn_meta.usr_intent: self.goal_change = False return # compare two states to obtain goal change flag curr_state, prev_state = None, None for frame in self.usr_turn["frames"]: if frame["service"] == self.service: curr_state = frame["state"]["slot_values"] for frame in prev_usr_turn["frames"]: if frame["service"] == prev_turn_meta.service: prev_state = frame["state"]["slot_values"] # check if slot value has changed at current turn (new slot is not counted) assert curr_state is not None and prev_state is not None self.goal_change = compare_slot_values_in_state(curr_state, prev_state) def _get_domain_from_act(self, dialogue_act): """ parse the raw dialouge act annotation to get domain info number of doamin can be more than 1 for multi-domain turns """ domains = set() book_flag = False for dact, sv_pairs in dialogue_act.items(): assert "-" in dact domain, _ = dact.split("-") if domain not in ["Booking", "general"]: domains.add(domain) for slot, value in sv_pairs: if "book" in slot: # e.g., bookday book_flag = True return domains, book_flag def _get_intent(self, usr_turn, prev_intent): intents = [] for frame in usr_turn["frames"]: # service = frame["service"] intent = frame["state"]["active_intent"] if intent != "NONE": intents.append(intent) if len(intents) == 1: intent = intents[0] if intent == "find_taxi": intent = "book_taxi" return intent # tackle 51.5k out of 71.5k user turns # if above fails (e.g., due to wrong label), leverage usr act to help determine main intent/service # possible domains in da: {'Hospital', 'Taxi', 'Train', 'Police', 'Restaurant', 'Booking', 'general', # 'Attraction', 'Hotel'} usr_act = data_act[self.dial_id][str(self.usr_turn_id)]["dialog_act"] domains, book_flag = self._get_domain_from_act(usr_act) if len(domains) == 1: domain = list(domains)[0].lower() if book_flag and domain in ["restaurant", "hotel", "train"]: intent = "book_{}".format(domain) elif domain == "taxi": intent = "book_{}".format(domain) else: intent = "find_{}".format(domain) return intent # tackle 58.1k out of 71.5k user turns if "Taxi" in domains: return "book_taxi" # tackle 58.8k out of 71.5k user turns if ( self.usr_turn_id == 0 ): # wrong label at first turn, no previous intent to use, only 136 user turns here utt = usr_turn["utterance"] if ( "restaurant" in utt or "Restaurant" in utt or "eat" in utt or "din" in utt ): return "find_restaurant" elif ( "hotel" in utt or "room" in utt or "house" in utt or "stay" in utt or "live" in utt ): return "find_hotel" else: return "find_attraction" # tackle 58.9k out of 71.5k user turns else: # not first turn, leverage sys act to help decide intent sys_act = data_act[self.dial_id][str(self.sys_turn_id)]["dialog_act"] sys_domains, _ = self._get_domain_from_act(sys_act) if len(sys_domains) == 1: domain = list(sys_domains)[0].lower() if book_flag and domain in ["restaurant", "hotel", "train"]: intent = "book_{}".format(domain) elif domain == "taxi": intent = "book_{}".format(domain) else: intent = "find_{}".format(domain) return intent # tackle 67.3k out of 71.5k user turns # two cases left enter here # 1. turns with only general act, e.g., bye # 2. turns have multiple intents (very few) # both will be handled using previous intent assert prev_intent != "" intent = "_".join( prev_intent.split() ) # as prev_intent has been normalised already return intent def _parse_action(self, turn_id, turn): """parse the `dialog_act` field in `dialog_acts.json` Returns: act2sv: act to slot value pairs, {act=sv}; sv: slot to value list, {slot=[v1, v2]} """ act2sv = dict() constraints = {"informable": dict(), "requestable": set()} if turn is None: return None, constraints # get da from data_act dialogue_act = data_act[self.dial_id][str(turn_id)]["dialog_act"] # domains = set() for dact, svs in dialogue_act.items(): assert "-" in dact if self.unify_act: # will use only act part without domain info domain, act = dact.split( "-" ) # split `domain-act`, e.g., `hotel-inform` -> hotel, inform else: # keep original mwoz act act = dact # use act with domain info if self.unify_act: # unify act: `Booking-Inform` with no args is equivalent to `OfferBook` in train domain if dact == "Booking-Inform" and svs == [["none", "none"]]: act = "OfferBook" # deal with act if self.unify_act: assert act in self.original_act_set if turn["speaker"] == "USER": assert act in ["Inform", "Request", "bye", "thank", "greet"] act = get_act_natural_language(act) if act not in act2sv: act2sv[act] = dict() # iterate slot value pairs for slot, value in svs: slot = normalise_slot(slot) value = normalise_value(value) # act to slot value pairs # NOTE: same slot might appear more than once per turn, e.g., when the system informs two hotels with # their addresses so a value list is stored for each slot if slot not in act2sv[act]: act2sv[act][slot] = [] act2sv[act][slot].append(value) # collect constraints if act in ["REQUEST", "Request", "request"]: constraints["requestable"].add(slot) else: if slot != "Empty": if ( slot not in constraints["informable"] ): # NOTE: same reason as act, value list per slot constraints["informable"][slot] = [] constraints["informable"][slot].append(value) return act2sv, constraints def _linearise(self, scenario, service2meta): self.linear_act = {} self.linear_act["sys"] = self._linearise_act(self.act2sv["sys"]) self.linear_act["usr"] = self._linearise_act(self.act2sv["usr"]) self.linear_goal = self._linearise_goal( self.usr_constraints, scenario, service2meta ) def _linearise_goal(self, constraints, scenario, service2meta): """ linearise goal representation which consists of several parts: scenario, task (intent), task description, constraints with informable and requestable e.g., task1 task2 .. current task task description slot1 value1 .. slot1 slot2 .. """ res = "" # scenario assert isinstance(scenario, list) and len(scenario) > 0 scenario = " ".join( [wrap_element("INTENT", intent) for intent in scenario] ) # treat intent as nl scenario_wrap = wrap_element("SCENARIO", scenario) res = add_str(res, scenario_wrap) # task name intent = self.usr_intent assert intent in scenario intent_wrap = wrap_element("TASK", intent) res = add_str(res, intent_wrap) # task description description = service2meta[self.service]["intents"][intent]["description"] description_warp = wrap_element("DESC", description) res = add_str(res, description_warp) # informable informable = dict2list( constraints["informable"] ) # sorted sv pair list [slot=value] res = add_str(res, "") for sv_pair in informable: slot, value = sv_pair.split("=") if value in ["True", "False", "Empty"]: value = conv_special_token(value, SPECIAL_TOKENS) if slot in ["Empty"]: slot = conv_special_token(slot, SPECIAL_TOKENS) # slot slot_wrap = wrap_element("SLOT", slot) res = add_str(res, slot_wrap) # value value_wrap = wrap_element("VALUE", value) res = add_str(res, value_wrap) res = add_str(res, "") # requestable requestable = sorted( list(constraints["requestable"]) ) # sorted slot list [slot] res = add_str(res, "") for slot in requestable: slot_wrap = wrap_element("SLOT", slot) res = add_str(res, slot_wrap) res = add_str(res, "") return res[1:] # remove first space def _linearise_act(self, act2sv): """ NOTE: 1) split slot/value if "_"; 2) special tokens of acts; 3) empty slot or empty value NOTE: filer too many values (e.g., 10 movie names) but make sure the one the user chose is present Return: ordered (slots sorted within act, acts sorted) linearised act sequence, e.g., area Cambridge ... e.g., _Empty_ _Empty_ """ res = "" if act2sv is None: return res for act in sorted(act2sv.keys()): # sort act sv = act2sv[act] # dict{slot: value_list} act_wrap = wrap_element("ACT", act) res = add_str(res, act_wrap) sorted_sv = dict2list( sv ) # sorted sv list, [s1=v1, s2=v2], note slot can repeat for sv_pair in sorted_sv: slot, value = sv_pair.split("=") if value in ["True", "False", "Empty"]: value = conv_special_token(value, SPECIAL_TOKENS) if slot in ["Empty"]: slot = conv_special_token(slot, SPECIAL_TOKENS) # slot slot_wrap = wrap_element("SLOT", slot) res = add_str(res, slot_wrap) # value value_wrap = wrap_element("VALUE", value) res = add_str(res, value_wrap) return res[1:] # remove first space def _get_new_task_label(self, prev_intent): """ get a binary label indicating if a turn starts a new task (intent) in dialogue """ assert prev_intent != "NONE" and self.usr_intent != "NONE" if self.usr_intent != prev_intent: self.start_new_task = True else: self.start_new_task = False def _get_req_alts_label(self): self.req_alts = False # no request alternative in mwoz def collect_examples(dial_id, dial_meta, examples): num = 0 examples[dial_id] = {} for turn_meta in dial_meta.turn_meta_list: if turn_meta is None: # sys turn continue example_id = "{}-{}".format(dial_id, num) example = { "utterances": turn_meta.utt, "actions": turn_meta.linear_act, "goal": turn_meta.linear_goal, "service": turn_meta.service, "intent": turn_meta.usr_intent, "goal_change": turn_meta.goal_change, "start_new_task": turn_meta.start_new_task, "req_alts": turn_meta.req_alts, } examples[dial_id][example_id] = example num += 1 def prepare_data_seq(unify_act, out_data_path): for split in DATA_SPLIT: examples = {} for dial_num, dial_id in enumerate(tqdm(sorted(data[split].keys()))): dial = data[split][dial_id] dial_act = data_act[dial_id] dial_meta = DialMetaData(dial_id, dial, dial_act, unify_act) collect_examples(dial_id, dial_meta, examples) with open("{}/{}.json".format(out_data_path, split), "w") as f: json.dump(examples, f, sort_keys=True, indent=4) print("Done process {} {} dialogues".format(split, len(examples))) if __name__ == "__main__": if len(sys.argv) == 1: print("Wrong argument!") print("usage: python utils/preprocess_multiwoz.py multiwoz2.2-data-path") sys.exit(1) # Set data path data_path = sys.argv[1] out_data_path = "./data/preprocessed/user_model" os.makedirs(out_data_path, exist_ok=True) # Control flags unify_act = True remove_book_intent = True # Load data and material as global var SERVICE2META, INTENTS, SLOTS = load_schema(os.path.join(data_path, "schema.json")) SPECIAL_TOKENS = get_special_tokens() data, data_act = collect_data(data_path, remove_dial_switch=False) prepare_data_seq(unify_act, out_data_path)