import copy
import json
import os
import re
import zipfile  # noqa: F401  (referenced by the commented-out archive loading below)
from collections import OrderedDict

import spacy
from tqdm import tqdm

from crazyneuraluser.UBAR_code import ontology, utils
from crazyneuraluser.UBAR_code.clean_dataset import clean_slot_values, clean_text
from crazyneuraluser.UBAR_code.config import global_config as cfg
from crazyneuraluser.UBAR_code.db_ops import MultiWozDB


# value_set.json, all the domain[slot] values in the datasets
def get_db_values(value_set_path):
    processed = {}
    bspn_word = []
    nlp = spacy.load("en_core_web_sm")

    with open(value_set_path, "r") as f:  # read value set file in lower case
        value_set = json.loads(f.read().lower())

    with open("data/raw/UBAR/db/ontology.json", "r") as f:  # read ontology in lower case, all the domain-slot values
        otlg = json.loads(f.read().lower())

    # add all informable slots to bspn_word, create list holders for values
    for domain, slots in value_set.items():
        processed[domain] = {}
        bspn_word.append("[" + domain + "]")
        for slot, values in slots.items():
            s_p = ontology.normlize_slot_names.get(slot, slot)
            if s_p in ontology.informable_slots[domain]:
                bspn_word.append(s_p)
                processed[domain][s_p] = []

    # add all words of values of informable slots to bspn_word
    for domain, slots in value_set.items():
        for slot, values in slots.items():
            s_p = ontology.normlize_slot_names.get(slot, slot)
            if s_p in ontology.informable_slots[domain]:
                for v in values:
                    _, v_p = clean_slot_values(domain, slot, v)
                    v_p = " ".join([token.text for token in nlp(v_p)]).strip()
                    processed[domain][s_p].append(v_p)
                    for x in v_p.split():
                        if x not in bspn_word:
                            bspn_word.append(x)

    # split domain-slots into domains and slots
    for domain_slot, values in otlg.items():
        domain, slot = domain_slot.split("-")
        if domain == "bus":
            domain = "taxi"
        if slot == "price range":
            slot = "pricerange"
        if slot == "book stay":
            slot = "stay"
        if slot == "book day":
            slot = "day"
        if slot == "book people":
            slot = "people"
        if slot == "book time":
            slot = "time"
        if slot == "arrive by":
            slot = "arrive"
        if slot == "leave at":
            slot = "leave"
        if slot == "leaveat":
            slot = "leave"
        # add all slots and words of values if not already in processed and bspn_word
        if slot not in processed[domain]:
            processed[domain][slot] = []
            bspn_word.append(slot)
        for v in values:
            _, v_p = clean_slot_values(domain, slot, v)
            v_p = " ".join([token.text for token in nlp(v_p)]).strip()
            if v_p not in processed[domain][slot]:
                processed[domain][slot].append(v_p)
                for x in v_p.split():
                    if x not in bspn_word:
                        bspn_word.append(x)

    with open(value_set_path.replace(".json", "_processed.json"), "w") as f:
        json.dump(processed, f, indent=2)  # save value_set_processed.json
    with open("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/bspn_word_collection.json", "w") as f:
        json.dump(bspn_word, f, indent=2)  # save bspn_word
    print("DB value set processed!")
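
# For orientation, a rough sketch of the two files get_db_values() writes, with
# made-up illustrative values (the real contents come from value_set.json and
# ontology.json):
#
#   value_set_processed.json:
#       {"hotel": {"pricerange": ["cheap", "moderate", ...],
#                  "area": ["centre", "north", ...], ...}, ...}
#   bspn_word_collection.json:
#       ["[hotel]", "pricerange", "area", "cheap", "moderate", "centre", ...]
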
") def preprocess_db(db_paths): # apply clean_slot_values to all dbs dbs = {} nlp = spacy.load("en_core_web_sm") for domain in ontology.all_domains: with open(db_paths[domain], "r") as f: # for every db_domain, read json file dbs[domain] = json.loads(f.read().lower()) # entry has information about slots of said domain for idx, entry in enumerate(dbs[domain]): new_entry = copy.deepcopy(entry) for key, value in entry.items(): # key = slot if type(value) is not str: continue del new_entry[key] key, value = clean_slot_values(domain, key, value) tokenize_and_back = " ".join([token.text for token in nlp(value)]).strip() new_entry[key] = tokenize_and_back dbs[domain][idx] = new_entry with open(db_paths[domain].replace(".json", "_processed.json"), "w") as f: json.dump(dbs[domain], f, indent=2) print("[%s] DB processed! " % domain) class DataPreprocessor(object): def __init__(self): self.nlp = spacy.load("en_core_web_sm") self.db = MultiWozDB(cfg.dbs) # load all processed dbs data_path = "data/preprocessed/UBAR/gen_usr_utt_experiment_data_with_span_full.json" # archive = zipfile.ZipFile(data_path + ".zip", "r") # self.convlab_data = json.loads(archive.open(data_path.split("/")[-1], "r").read().lower()) self.convlab_data = json.loads(open(data_path, "r").read().lower()) self.delex_sg_valdict_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/delex_single_valdict.json" self.delex_mt_valdict_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/delex_multi_valdict.json" self.ambiguous_val_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/ambiguous_values.json" self.delex_refs_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/reference_no.json" self.delex_refs = json.loads(open(self.delex_refs_path, "r").read()) if not os.path.exists(self.delex_sg_valdict_path): ( self.delex_sg_valdict, self.delex_mt_valdict, self.ambiguous_vals, ) = self.get_delex_valdict() else: self.delex_sg_valdict = json.loads(open(self.delex_sg_valdict_path, "r").read()) self.delex_mt_valdict = json.loads(open(self.delex_mt_valdict_path, "r").read()) self.ambiguous_vals = json.loads(open(self.ambiguous_val_path, "r").read()) self.vocab = utils.Vocab(cfg.vocab_size) def delex_by_annotation(self, dial_turn): u = dial_turn["text"].split() span = dial_turn["span_info"] for s in span: slot = s[1] if slot == "open": continue if ontology.da_abbr_to_slot_name.get(slot): slot = ontology.da_abbr_to_slot_name[slot] for idx in range(s[3], s[4] + 1): u[idx] = "" try: u[s[3]] = "[value_" + slot + "]" except Exception: u[5] = "[value_" + slot + "]" u_delex = " ".join([t for t in u if t != ""]) u_delex = u_delex.replace("[value_address] , [value_address] , [value_address]", "[value_address]") u_delex = u_delex.replace("[value_address] , [value_address]", "[value_address]") u_delex = u_delex.replace("[value_name] [value_name]", "[value_name]") u_delex = u_delex.replace("[value_name]([value_phone] )", "[value_name] ( [value_phone] )") return u_delex def delex_by_valdict(self, text): text = clean_text(text) text = re.sub(r"\d{5}\s?\d{5,7}", "[value_phone]", text) text = re.sub(r"\d[\s-]stars?", "[value_stars]", text) text = re.sub(r"\$\d+|\$?\d+.?(\d+)?\s(pounds?|gbps?)", "[value_price]", text) text = re.sub(r"tr[\d]{4}", "[value_id]", text) text = re.sub( r"([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. 
        for value, slot in self.delex_mt_valdict.items():
            text = text.replace(value, "[value_%s]" % slot)

        for value, slot in self.delex_sg_valdict.items():
            tokens = text.split()
            for idx, tk in enumerate(tokens):
                if tk == value:
                    tokens[idx] = "[value_%s]" % slot
            text = " ".join(tokens)

        for ambg_ent in self.ambiguous_vals:
            # "ely" is a place, but also appears inside words like "moderately",
            # hence the leading-space match below
            start_idx = text.find(" " + ambg_ent)
            if start_idx == -1:
                continue
            front_words = text[:start_idx].split()
            ent_type = "time" if ":" in ambg_ent else "place"
            for fw in front_words[::-1]:
                if fw in [
                    "arrive",
                    "arrives",
                    "arrived",
                    "arriving",
                    "arrival",
                    "destination",
                    "there",
                    "reach",
                    "to",
                    "by",
                    "before",
                ]:
                    slot = "[value_arrive]" if ent_type == "time" else "[value_destination]"
                    text = re.sub(" " + ambg_ent, " " + slot, text)
                elif fw in [
                    "leave",
                    "leaves",
                    "leaving",
                    "depart",
                    "departs",
                    "departing",
                    "departure",
                    "from",
                    "after",
                    "pulls",
                ]:
                    slot = "[value_leave]" if ent_type == "time" else "[value_departure]"
                    text = re.sub(" " + ambg_ent, " " + slot, text)

        text = text.replace("[value_car] [value_car]", "[value_car]")
        return text

    def get_delex_valdict(self):
        skip_entry_type = {
            "taxi": ["taxi_phone"],
            "police": ["id"],
            "hospital": ["id"],
            "hotel": [
                "id",
                "location",
                "internet",
                "parking",
                "takesbookings",
                "stars",
                "price",
                "n",
                "postcode",
                "phone",
            ],
            "attraction": [
                "id",
                "location",
                "pricerange",
                "price",
                "openhours",
                "postcode",
                "phone",
            ],
            "train": ["price", "id"],
            "restaurant": [
                "id",
                "location",
                "introduction",
                "signature",
                "type",
                "postcode",
                "phone",
            ],
        }
        entity_value_to_slot = {}
        ambiguous_entities = []
        for domain, db_data in self.db.dbs.items():
            print("Processing entity values in [%s]" % domain)
            if domain != "taxi":
                for db_entry in db_data:
                    for slot, value in db_entry.items():
                        if slot not in skip_entry_type[domain]:
                            if type(value) is not str:
                                raise TypeError("non-string value for slot '%s' in domain '%s' should be rechecked" % (slot, domain))
                            else:
                                slot, value = clean_slot_values(domain, slot, value)
                                value = " ".join([token.text for token in self.nlp(value)]).strip()
                                if value in entity_value_to_slot and entity_value_to_slot[value] != slot:
                                    # print(value, ": ", entity_value_to_slot[value], slot)
                                    ambiguous_entities.append(value)
                                entity_value_to_slot[value] = slot
            else:  # taxi db specific
                db_entry = db_data[0]
                for slot, ent_list in db_entry.items():
                    if slot not in skip_entry_type[domain]:
                        for ent in ent_list:
                            entity_value_to_slot[ent] = "car"
        ambiguous_entities = set(ambiguous_entities)
        ambiguous_entities.remove("cambridge")
        ambiguous_entities = list(ambiguous_entities)
        # departure or destination? arrive time or leave time?
        for amb_ent in ambiguous_entities:
            entity_value_to_slot.pop(amb_ent)
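        # Values that map to more than one slot were dropped just above: e.g. a string
        # like "ely" or "10:15" (illustrative) could be a departure or a destination,
        # a leave time or an arrive time, so delex_by_valdict resolves them from
        # surrounding context instead of from this dictionary.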

        entity_value_to_slot["parkside"] = "address"
        entity_value_to_slot["parkside, cambridge"] = "address"
        entity_value_to_slot["cambridge belfry"] = "name"
        entity_value_to_slot["hills road"] = "address"
        entity_value_to_slot["hills rd"] = "address"
        # lower-cased so it can actually match the lower-cased corpus text
        entity_value_to_slot["parkside police station"] = "name"

        single_token_values = {}
        multi_token_values = {}
        for val, slt in entity_value_to_slot.items():
            if val in ["cambridge"]:
                continue
            if len(val.split()) > 1:
                multi_token_values[val] = slt
            else:
                single_token_values[val] = slt

        with open(self.delex_sg_valdict_path, "w") as f:
            single_token_values = OrderedDict(
                sorted(single_token_values.items(), key=lambda kv: len(kv[0]), reverse=True)
            )
            json.dump(single_token_values, f, indent=2)
            print("single delex value dict saved!")
        with open(self.delex_mt_valdict_path, "w") as f:
            multi_token_values = OrderedDict(
                sorted(multi_token_values.items(), key=lambda kv: len(kv[0]), reverse=True)
            )
            json.dump(multi_token_values, f, indent=2)
            print("multi delex value dict saved!")
        with open(self.ambiguous_val_path, "w") as f:
            json.dump(ambiguous_entities, f, indent=2)
            print("ambiguous value dict saved!")

        return single_token_values, multi_token_values, ambiguous_entities

    def preprocess_main(self, save_path=None, is_test=False):
        """Tokenize, delexicalize and annotate every dialogue for UBAR training."""
        data = {}
        count = 0
        self.unique_da = {}
        ordered_sysact_dict = {}
        for fn, raw_dial in tqdm(list(self.convlab_data.items())):
            count += 1
            # if count == 100:
            #     break

            # for every dialog, keep track of the goal, domains and requests
            compressed_goal = {}
            dial_domains, dial_reqs = [], []
            for dom, g in raw_dial["goal"].items():
                if dom != "topic" and dom != "message" and g:
                    if g.get("reqt"):  # requested info, e.g. postcode/address/phone
                        # normalize request slots
                        for i, req_slot in enumerate(g["reqt"]):
                            if ontology.normlize_slot_names.get(req_slot):
                                g["reqt"][i] = ontology.normlize_slot_names[req_slot]
                            dial_reqs.append(g["reqt"][i])
                    compressed_goal[dom] = g
                    if dom in ontology.all_domains:
                        dial_domains.append(dom)
            dial_reqs = list(set(dial_reqs))

            dial = {"goal": compressed_goal, "log": []}
            single_turn = {}
            constraint_dict = OrderedDict()
            prev_constraint_dict = {}
            prev_turn_domain = ["general"]
            ordered_sysact_dict[fn] = {}

            for turn_num, dial_turn in enumerate(raw_dial["log"]):
                # a user turn only has text; a sys turn has text,
                # belief state (metadata), dialog_act and span_info
                dial_state = dial_turn["metadata"]
                if not dial_state:  # user turn
                    # delexicalize user utterance, either by annotation or by val_dict
                    u = " ".join(clean_text(dial_turn["text"]).split())
                    # NOTE: Commenting out delexicalisation because it is not used and
                    # breaks when I use generated user dialogues for some reason
                    # if dial_turn["span_info"]:
                    #     u_delex = clean_text(self.delex_by_annotation(dial_turn))
                    # else:
                    #     u_delex = self.delex_by_valdict(dial_turn["text"])
                    single_turn["user"] = u
                    # single_turn["user_delex"] = u_delex
                else:  # system turn
                    # delexicalize system response, either by annotation or by val_dict
                    if dial_turn["span_info"]:
                        s_delex = clean_text(self.delex_by_annotation(dial_turn))
                    else:
                        if not dial_turn["text"]:
                            print(fn)
                        s_delex = self.delex_by_valdict(dial_turn["text"])
                    single_turn["resp"] = s_delex

                    # get belief state (semi=informable, book=requestable) into constraint_dict
                    for domain in dial_domains:
                        if not constraint_dict.get(domain):
                            constraint_dict[domain] = OrderedDict()
                        info_sv = dial_state[domain]["semi"]
                        for s, v in info_sv.items():
                            s, v = clean_slot_values(domain, s, v)
                            if len(v.split()) > 1:
                                v = " ".join([token.text for token in self.nlp(v)]).strip()
                            if v != "":
                                constraint_dict[domain][s] = v
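                        # "semi" above holds the informable slots, "book" below the booking
                        # slots; both feed constraint_dict, which for one turn might look
                        # like (illustrative): {"hotel": {"pricerange": "cheap", "people": "2"}}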
!= "": constraint_dict[domain][s] = v book_sv = dial_state[domain]["book"] for s, v in book_sv.items(): if s == "booked": continue s, v = clean_slot_values(domain, s, v) if len(v.split()) > 1: v = " ".join([token.text for token in self.nlp(v)]).strip() if v != "": constraint_dict[domain][s] = v constraints = [] # list in format of [domain] slot value cons_delex = [] turn_dom_bs = [] for domain, info_slots in constraint_dict.items(): if info_slots: constraints.append("[" + domain + "]") cons_delex.append("[" + domain + "]") for slot, value in info_slots.items(): constraints.append(slot) constraints.extend(value.split()) cons_delex.append(slot) if domain not in prev_constraint_dict: turn_dom_bs.append(domain) elif prev_constraint_dict[domain] != constraint_dict[domain]: turn_dom_bs.append(domain) sys_act_dict = {} turn_dom_da = set() for act in dial_turn["dialog_act"]: d, a = act.split("-") # split domain-act turn_dom_da.add(d) turn_dom_da = list(turn_dom_da) if len(turn_dom_da) != 1 and "general" in turn_dom_da: turn_dom_da.remove("general") if len(turn_dom_da) != 1 and "booking" in turn_dom_da: turn_dom_da.remove("booking") # get turn domain turn_domain = turn_dom_bs for dom in turn_dom_da: if dom != "booking" and dom not in turn_domain: turn_domain.append(dom) if not turn_domain: turn_domain = prev_turn_domain if len(turn_domain) == 2 and "general" in turn_domain: turn_domain.remove("general") if len(turn_domain) == 2: if len(prev_turn_domain) == 1 and prev_turn_domain[0] == turn_domain[1]: turn_domain = turn_domain[::-1] # get system action for dom in turn_domain: sys_act_dict[dom] = {} add_to_last_collect = [] booking_act_map = {"inform": "offerbook", "book": "offerbooked"} for act, params in dial_turn["dialog_act"].items(): if act == "general-greet": continue d, a = act.split("-") if d == "general" and d not in sys_act_dict: sys_act_dict[d] = {} if d == "booking": d = turn_domain[0] a = booking_act_map.get(a, a) add_p = [] for param in params: p = param[0] if p == "none": continue elif ontology.da_abbr_to_slot_name.get(p): p = ontology.da_abbr_to_slot_name[p] if p not in add_p: add_p.append(p) add_to_last = True if a in ["request", "reqmore", "bye", "offerbook"] else False if add_to_last: add_to_last_collect.append((d, a, add_p)) else: sys_act_dict[d][a] = add_p for d, a, add_p in add_to_last_collect: sys_act_dict[d][a] = add_p for d in copy.copy(sys_act_dict): acts = sys_act_dict[d] if not acts: del sys_act_dict[d] if "inform" in acts and "offerbooked" in acts: for s in sys_act_dict[d]["inform"]: sys_act_dict[d]["offerbooked"].append(s) del sys_act_dict[d]["inform"] ordered_sysact_dict[fn][len(dial["log"])] = sys_act_dict sys_act = [] if "general-greet" in dial_turn["dialog_act"]: sys_act.extend(["[general]", "[greet]"]) for d, acts in sys_act_dict.items(): sys_act += ["[" + d + "]"] for a, slots in acts.items(): self.unique_da[d + "-" + a] = 1 sys_act += ["[" + a + "]"] sys_act += slots # get db pointers matnums = self.db.get_match_num(constraint_dict) match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1] match = matnums[match_dom] dbvec = self.db.addDBPointer(match_dom, match) bkvec = self.db.addBookingPointer(dial_turn["dialog_act"]) # 4 database pointer for domains, 2 for booking single_turn["pointer"] = ",".join([str(d) for d in dbvec + bkvec]) single_turn["match"] = str(match) single_turn["constraint"] = " ".join(constraints) single_turn["cons_delex"] = " ".join(cons_delex) single_turn["sys_act"] = " ".join(sys_act) single_turn["turn_num"] = len(dial["log"]) 
single_turn["turn_domain"] = " ".join(["[" + d + "]" for d in turn_domain]) prev_turn_domain = copy.deepcopy(turn_domain) prev_constraint_dict = copy.deepcopy(constraint_dict) if "user" in single_turn: dial["log"].append(single_turn) for t in single_turn["user"].split() + single_turn["resp"].split() + constraints + sys_act: self.vocab.add_word(t) # NOTE: Commenting out delexicalisation because it is not used and # breaks when I use generated user dialogues for some reason # for t in single_turn["user_delex"].split(): # if "[" in t and "]" in t and not t.startswith("[") and not t.endswith("]"): # single_turn["user_delex"].replace(t, t[t.index("[") : t.index("]") + 1]) # elif not self.vocab.has_word(t): # self.vocab.add_word(t) single_turn = {} data[fn] = dial # pprint(dial) # if count == 20: # break self.vocab.construct() self.vocab.save_vocab("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/vocab") with open("data/interim/gen_usr_utts/multi-woz-analysis/dialog_acts.json", "w") as f: json.dump(ordered_sysact_dict, f, indent=2) with open("data/interim/gen_usr_utts/multi-woz-analysis/dialog_act_type.json", "w") as f: json.dump(self.unique_da, f, indent=2) return data if __name__ == "__main__": db_paths = { "attraction": "data/raw/UBAR/db/attraction_db.json", "hospital": "data/raw/UBAR/db/hospital_db.json", "hotel": "data/raw/UBAR/db/hotel_db.json", "police": "data/raw/UBAR/db/police_db.json", "restaurant": "data/raw/UBAR/db/restaurant_db.json", "taxi": "data/raw/UBAR/db/taxi_db.json", "train": "data/raw/UBAR/db/train_db.json", } get_db_values("data/raw/UBAR/db/value_set.json") preprocess_db(db_paths) dh = DataPreprocessor() data = dh.preprocess_main() if not os.path.exists("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed"): os.mkdir("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed") with open("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/data_for_ubar.json", "w") as f: json.dump(data, f, indent=2)