# Source: commit b16a132 "Added dialogue system code" (author: alistairmcleay).
import copy
import json
import os
import re
import zipfile
from collections import OrderedDict
import spacy
from tqdm import tqdm
from crazyneuraluser.UBAR_code import ontology, utils
from crazyneuraluser.UBAR_code.clean_dataset import clean_slot_values, clean_text
from crazyneuraluser.UBAR_code.config import global_config as cfg
from crazyneuraluser.UBAR_code.db_ops import MultiWozDB
# value_set.json, all the domain[slot] values in datasets
def get_db_values(value_set_path):
    """Collect the known values of every domain slot and the belief-span vocabulary.

    Reads the value set at ``value_set_path`` and the ontology at
    ``data/raw/UBAR/db/ontology.json`` (both lower-cased before parsing),
    cleans and tokenises every value, and writes two files:

    * ``<value_set_path minus .json>_processed.json`` -- ``{domain: {slot: [values]}}``
    * ``data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/bspn_word_collection.json``
      -- the list of domain tags, slot names and value tokens, in first-seen order.

    Args:
        value_set_path: path of the ``value_set.json`` file.
    """
    # Ontology slot names that are renamed before use.
    slot_aliases = {
        "price range": "pricerange",
        "book stay": "stay",
        "book day": "day",
        "book people": "people",
        "book time": "time",
        "arrive by": "arrive",
        "leave at": "leave",
        "leaveat": "leave",
    }

    processed = {}
    bspn_word = []
    # Mirror of ``bspn_word`` used only for O(1) membership tests; the list
    # itself keeps the exact content and order (including repeated slot names).
    seen_words = set()

    def _append(word):
        bspn_word.append(word)
        seen_words.add(word)

    def _tokenize(value):
        return " ".join([token.text for token in nlp(value)]).strip()

    nlp = spacy.load("en_core_web_sm")

    with open(value_set_path, "r") as f:  # read value set file in lower
        value_set = json.loads(f.read().lower())

    with open("data/raw/UBAR/db/ontology.json", "r") as f:  # read ontology in lower, all the domain-slot values
        otlg = json.loads(f.read().lower())

    # Pass 1: register every domain tag and informable slot, and create the
    # (still empty) value list of each informable slot.
    for domain, slots in value_set.items():
        processed[domain] = {}
        _append("[" + domain + "]")
        for slot in slots:
            s_p = ontology.normlize_slot_names.get(slot, slot)
            if s_p in ontology.informable_slots[domain]:
                _append(s_p)
                processed[domain][s_p] = []

    # Pass 2: add the cleaned values of the informable slots and their tokens.
    for domain, slots in value_set.items():
        for slot, values in slots.items():
            s_p = ontology.normlize_slot_names.get(slot, slot)
            if s_p not in ontology.informable_slots[domain]:
                continue
            for v in values:
                _, v_p = clean_slot_values(domain, slot, v)
                v_p = _tokenize(v_p)
                processed[domain][s_p].append(v_p)
                for x in v_p.split():
                    if x not in seen_words:
                        _append(x)

    # Pass 3: merge in the ontology, whose keys have the form "domain-slot".
    for domain_slot, values in otlg.items():
        domain, slot = domain_slot.split("-")
        if domain == "bus":
            domain = "taxi"
        slot = slot_aliases.get(slot, slot)

        # add all slots and words of values if not already in processed and bspn_word
        if slot not in processed[domain]:
            processed[domain][slot] = []
            _append(slot)

        for v in values:
            _, v_p = clean_slot_values(domain, slot, v)
            v_p = _tokenize(v_p)
            if v_p not in processed[domain][slot]:
                processed[domain][slot].append(v_p)
                for x in v_p.split():
                    if x not in seen_words:
                        _append(x)

    with open(value_set_path.replace(".json", "_processed.json"), "w") as f:
        json.dump(processed, f, indent=2)  # save processed.json
    with open("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/bspn_word_collection.json", "w") as f:
        json.dump(bspn_word, f, indent=2)  # save bspn_word

    print("DB value set processed! ")
def preprocess_db(db_paths):
    """Apply ``clean_slot_values`` and spaCy tokenisation to every domain database.

    For each domain in ``ontology.all_domains`` the JSON file at
    ``db_paths[domain]`` is read (lower-cased), every string-valued field of
    every entry is cleaned and re-tokenised, and the result is written next to
    the input as ``*_processed.json``. Non-string fields are left untouched.

    Args:
        db_paths: mapping from domain name to the path of its raw DB json file.
    """
    tokenizer = spacy.load("en_core_web_sm")
    dbs = {}

    def retokenize(raw):
        # Tokenise and join back with single spaces.
        return " ".join([tok.text for tok in tokenizer(raw)]).strip()

    for domain in ontology.all_domains:
        source_path = db_paths[domain]
        with open(source_path, "r") as src:  # for every db_domain, read json file
            records = json.loads(src.read().lower())
        dbs[domain] = records

        # each record holds the slot -> value information of one entity
        for position, record in enumerate(records):
            cleaned = copy.deepcopy(record)
            for slot_name, slot_value in record.items():
                if isinstance(slot_value, str):
                    # Drop the raw field and re-insert it under its cleaned name.
                    cleaned.pop(slot_name)
                    new_name, new_value = clean_slot_values(domain, slot_name, slot_value)
                    cleaned[new_name] = retokenize(new_value)
            records[position] = cleaned

        target_path = source_path.replace(".json", "_processed.json")
        with open(target_path, "w") as dst:
            json.dump(dbs[domain], dst, indent=2)
        print("[%s] DB processed! " % domain)
class DataPreprocessor(object):
    """Convert the raw dialogue dump into per-turn records for UBAR.

    On construction the raw dialogues, the processed databases and the
    delexicalisation dictionaries are loaded (the dictionaries are built and
    saved first if the single-token dictionary file does not exist yet).
    ``preprocess_main`` then walks every dialogue and produces, per system
    turn, the delexicalised response, belief-state constraints, system acts
    and database pointers.
    """

    # Words that, when found before an ambiguous value, mark it as an
    # arrival time / destination.
    _ARRIVE_CUES = frozenset(
        [
            "arrive",
            "arrives",
            "arrived",
            "arriving",
            "arrival",
            "destination",
            "there",
            "reach",
            "to",
            "by",
            "before",
        ]
    )
    # Words that, when found before an ambiguous value, mark it as a
    # leaving time / departure place.
    _LEAVE_CUES = frozenset(
        [
            "leave",
            "leaves",
            "leaving",
            "depart",
            "departs",
            "departing",
            "departure",
            "from",
            "after",
            "pulls",
        ]
    )

    def __init__(self):
        self.nlp = spacy.load("en_core_web_sm")
        self.db = MultiWozDB(cfg.dbs)  # load all processed dbs
        data_path = "data/preprocessed/UBAR/gen_usr_utt_experiment_data_with_span_full.json"
        self.convlab_data = self._read_json(data_path, lowercase=True)
        self.delex_sg_valdict_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/delex_single_valdict.json"
        self.delex_mt_valdict_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/delex_multi_valdict.json"
        self.ambiguous_val_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/ambiguous_values.json"
        self.delex_refs_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/reference_no.json"
        self.delex_refs = self._read_json(self.delex_refs_path)
        # Only the single-token dictionary is probed; the three files are
        # always written together by get_delex_valdict().
        if not os.path.exists(self.delex_sg_valdict_path):
            (
                self.delex_sg_valdict,
                self.delex_mt_valdict,
                self.ambiguous_vals,
            ) = self.get_delex_valdict()
        else:
            self.delex_sg_valdict = self._read_json(self.delex_sg_valdict_path)
            self.delex_mt_valdict = self._read_json(self.delex_mt_valdict_path)
            self.ambiguous_vals = self._read_json(self.ambiguous_val_path)

        self.vocab = utils.Vocab(cfg.vocab_size)

    @staticmethod
    def _read_json(path, lowercase=False):
        """Parse the JSON file at ``path``, optionally lower-casing its text first.

        The file is closed before parsing starts.
        """
        with open(path, "r") as f:
            raw = f.read()
        if lowercase:
            raw = raw.lower()
        return json.loads(raw)

    def delex_by_annotation(self, dial_turn):
        """Delexicalise a turn using its span annotations.

        Args:
            dial_turn: dict with ``"text"`` (the utterance) and ``"span_info"``
                (a list of spans; index 1 of a span is the slot abbreviation,
                indices 3 and 4 are the inclusive first/last token positions
                in the whitespace-split text).

        Returns:
            The utterance with each annotated span replaced by a single
            ``[value_<slot>]`` placeholder.
        """
        u = dial_turn["text"].split()
        span = dial_turn["span_info"]
        for s in span:
            slot = s[1]
            if slot == "open":
                continue
            if ontology.da_abbr_to_slot_name.get(slot):
                slot = ontology.da_abbr_to_slot_name[slot]
            # Blank every token of the span, then put the placeholder on the first one.
            for idx in range(s[3], s[4] + 1):
                u[idx] = ""
            try:
                u[s[3]] = "[value_" + slot + "]"
            except IndexError:
                # Only reachable when the span is empty (s[4] < s[3]) and s[3]
                # lies past the end of the utterance; keep the original fallback.
                u[5] = "[value_" + slot + "]"
        u_delex = " ".join([t for t in u if t != ""])
        # Collapse placeholders that were produced by adjacent spans of the same slot.
        u_delex = u_delex.replace("[value_address] , [value_address] , [value_address]", "[value_address]")
        u_delex = u_delex.replace("[value_address] , [value_address]", "[value_address]")
        u_delex = u_delex.replace("[value_name] [value_name]", "[value_name]")
        u_delex = u_delex.replace("[value_name]([value_phone] )", "[value_name] ( [value_phone] )")
        return u_delex

    def delex_by_valdict(self, text):
        """Delexicalise ``text`` with regular expressions and the value dictionaries.

        Used for turns without span annotations. Phone numbers, star ratings,
        prices, train ids and postcodes are matched by pattern; database
        values are matched through ``self.delex_mt_valdict`` (multi-token,
        substring replacement) and ``self.delex_sg_valdict`` (single-token,
        whole-token replacement); values listed in ``self.ambiguous_vals`` are
        resolved from the cue words that precede them.

        Returns:
            The delexicalised text.
        """
        text = clean_text(text)

        text = re.sub(r"\d{5}\s?\d{5,7}", "[value_phone]", text)
        text = re.sub(r"\d[\s-]stars?", "[value_stars]", text)
        text = re.sub(r"\$\d+|\$?\d+.?(\d+)?\s(pounds?|gbps?)", "[value_price]", text)
        text = re.sub(r"tr[\d]{4}", "[value_id]", text)
        text = re.sub(
            r"([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})",
            "[value_postcode]",
            text,
        )

        for value, slot in self.delex_mt_valdict.items():
            text = text.replace(value, "[value_%s]" % slot)

        # Single pass over the tokens with a dictionary lookup, instead of
        # re-splitting the text once per dictionary entry. The result is the
        # same as long as no dictionary key is itself a "[value_...]" placeholder.
        sg_valdict = self.delex_sg_valdict
        if sg_valdict:
            text = " ".join(
                "[value_%s]" % sg_valdict[tk] if tk in sg_valdict else tk for tk in text.split()
            )

        for ambg_ent in self.ambiguous_vals:
            # ely is a place, but appears in words like moderately
            needle = " " + ambg_ent
            start_idx = text.find(needle)
            if start_idx == -1:
                continue
            front_words = text[:start_idx].split()
            is_time = ":" in ambg_ent

            # Walk backwards from the mention to the nearest cue word.
            for fw in reversed(front_words):
                if fw in self._ARRIVE_CUES:
                    slot = "[value_arrive]" if is_time else "[value_destination]"
                elif fw in self._LEAVE_CUES:
                    slot = "[value_leave]" if is_time else "[value_departure]"
                else:
                    continue
                # Literal replacement: the value comes from the DB and must not
                # be interpreted as a regular expression. Every occurrence is
                # replaced at once, so no further cue word can change the text.
                text = text.replace(needle, " " + slot)
                break

        text = text.replace("[value_car] [value_car]", "[value_car]")
        return text

    def get_delex_valdict(
        self,
    ):
        """Build and save the value -> slot dictionaries used for delexicalisation.

        Every string value of ``self.db.dbs`` (except the fields listed in
        ``skip_entry_type``) is cleaned, tokenised and mapped to its slot.
        Values seen under more than one slot are moved to the ambiguous list,
        except "cambridge", which is excluded from all three outputs. The taxi
        database is handled separately: all its listed entities map to "car".

        The results are written to ``self.delex_sg_valdict_path``,
        ``self.delex_mt_valdict_path`` and ``self.ambiguous_val_path``.

        Returns:
            ``(single_token_values, multi_token_values, ambiguous_entities)``;
            the two dictionaries are ordered longest value first, the
            ambiguous list is sorted.

        Raises:
            TypeError: if a non-skipped database field is not a string.
        """
        skip_entry_type = {
            "taxi": ["taxi_phone"],
            "police": ["id"],
            "hospital": ["id"],
            "hotel": [
                "id",
                "location",
                "internet",
                "parking",
                "takesbookings",
                "stars",
                "price",
                "n",
                "postcode",
                "phone",
            ],
            "attraction": [
                "id",
                "location",
                "pricerange",
                "price",
                "openhours",
                "postcode",
                "phone",
            ],
            "train": ["price", "id"],
            "restaurant": [
                "id",
                "location",
                "introduction",
                "signature",
                "type",
                "postcode",
                "phone",
            ],
        }
        entity_value_to_slot = {}
        ambiguous_entities = []
        for domain, db_data in self.db.dbs.items():
            print("Processing entity values in [%s]" % domain)
            if domain != "taxi":
                for db_entry in db_data:
                    for slot, value in db_entry.items():
                        if slot in skip_entry_type[domain]:
                            continue
                        if not isinstance(value, str):
                            raise TypeError("value '%s' in domain '%s' should be rechecked" % (slot, domain))
                        slot, value = clean_slot_values(domain, slot, value)
                        value = " ".join([token.text for token in self.nlp(value)]).strip()
                        if value in entity_value_to_slot and entity_value_to_slot[value] != slot:
                            ambiguous_entities.append(value)
                        entity_value_to_slot[value] = slot
            else:  # taxi db specific
                db_entry = db_data[0]
                for slot, ent_list in db_entry.items():
                    if slot not in skip_entry_type[domain]:
                        for ent in ent_list:
                            entity_value_to_slot[ent] = "car"

        # "cambridge" is never treated as ambiguous. discard() rather than
        # remove(): it may simply not be in the set. Sorting makes the saved
        # file and the replacement order reproducible across runs.
        ambiguous_entities = set(ambiguous_entities)
        ambiguous_entities.discard("cambridge")
        ambiguous_entities = sorted(ambiguous_entities)
        for amb_ent in ambiguous_entities:  # departure or destination? arrive time or leave time?
            entity_value_to_slot.pop(amb_ent)

        # Hand-written overrides.
        entity_value_to_slot["parkside"] = "address"
        entity_value_to_slot["parkside, cambridge"] = "address"
        entity_value_to_slot["cambridge belfry"] = "name"
        entity_value_to_slot["hills road"] = "address"
        entity_value_to_slot["hills rd"] = "address"
        # NOTE(review): this key is capitalised while the raw data is lower-cased
        # on load, so it presumably never matches -- confirm before changing.
        entity_value_to_slot["Parkside Police Station"] = "name"

        single_token_values = {}
        multi_token_values = {}
        for val, slt in entity_value_to_slot.items():
            if val in ["cambridge"]:
                continue
            if len(val.split()) > 1:
                multi_token_values[val] = slt
            else:
                single_token_values[val] = slt

        # Longest values first, so longer values are replaced before any
        # shorter value they contain.
        with open(self.delex_sg_valdict_path, "w") as f:
            single_token_values = OrderedDict(
                sorted(single_token_values.items(), key=lambda kv: len(kv[0]), reverse=True)
            )
            json.dump(single_token_values, f, indent=2)
            print("single delex value dict saved!")
        with open(self.delex_mt_valdict_path, "w") as f:
            multi_token_values = OrderedDict(
                sorted(multi_token_values.items(), key=lambda kv: len(kv[0]), reverse=True)
            )
            json.dump(multi_token_values, f, indent=2)
            print("multi delex value dict saved!")
        with open(self.ambiguous_val_path, "w") as f:
            json.dump(ambiguous_entities, f, indent=2)
            print("ambiguous value dict saved!")

        return single_token_values, multi_token_values, ambiguous_entities

    def preprocess_main(self, save_path=None, is_test=False):
        """Convert every dialogue in ``self.convlab_data`` to the UBAR turn format.

        For each dialogue the goal is kept (with request slot names
        normalised) and each user/system exchange becomes one entry of
        ``"log"`` holding the user text, delexicalised response, belief-state
        constraints, system acts, DB/booking pointers, match count, turn
        number and turn domain. All tokens are added to ``self.vocab``.

        Side effects: sets ``self.unique_da``, constructs and saves the
        vocabulary, and writes ``dialog_acts.json`` and
        ``dialog_act_type.json`` under
        ``data/interim/gen_usr_utts/multi-woz-analysis/``.

        Args:
            save_path: unused.
            is_test: unused.

        Returns:
            dict mapping dialogue file name to ``{"goal": ..., "log": [...]}``.
        """
        data = {}
        self.unique_da = {}
        ordered_sysact_dict = {}
        for fn, raw_dial in tqdm(list(self.convlab_data.items())):
            compressed_goal = {}  # for every dialog, keep track the goal, domains, requests
            dial_domains = []
            for dom, g in raw_dial["goal"].items():
                if dom != "topic" and dom != "message" and g:
                    if g.get("reqt"):  # request info. eg. postcode/address/phone
                        # normalize request slots (in place, so the saved goal carries them)
                        for i, req_slot in enumerate(g["reqt"]):
                            if ontology.normlize_slot_names.get(req_slot):
                                g["reqt"][i] = ontology.normlize_slot_names[req_slot]
                    compressed_goal[dom] = g
                    if dom in ontology.all_domains:
                        dial_domains.append(dom)

            dial = {"goal": compressed_goal, "log": []}
            single_turn = {}
            constraint_dict = OrderedDict()
            prev_constraint_dict = {}
            prev_turn_domain = ["general"]
            ordered_sysact_dict[fn] = {}

            for dial_turn in raw_dial["log"]:
                # for user turn, have text
                # sys turn: text, belief states(metadata), dialog_act, span_info
                dial_state = dial_turn["metadata"]
                if not dial_state:  # user
                    # NOTE: the user utterance is only cleaned, not delexicalised;
                    # delexicalisation is not used and breaks on generated user dialogues.
                    u = " ".join(clean_text(dial_turn["text"]).split())
                    single_turn["user"] = u
                else:  # system
                    # delexicalize system response, either by annotation or by val_dict
                    if dial_turn["span_info"]:
                        s_delex = clean_text(self.delex_by_annotation(dial_turn))
                    else:
                        if not dial_turn["text"]:
                            print(fn)
                        s_delex = self.delex_by_valdict(dial_turn["text"])
                    single_turn["resp"] = s_delex

                    # get belief state, semi=informable/book=requestable, put into constraint_dict
                    for domain in dial_domains:
                        if not constraint_dict.get(domain):
                            constraint_dict[domain] = OrderedDict()
                        info_sv = dial_state[domain]["semi"]
                        for s, v in info_sv.items():
                            s, v = clean_slot_values(domain, s, v)
                            if len(v.split()) > 1:
                                v = " ".join([token.text for token in self.nlp(v)]).strip()
                            if v != "":
                                constraint_dict[domain][s] = v
                        book_sv = dial_state[domain]["book"]
                        for s, v in book_sv.items():
                            if s == "booked":
                                continue
                            s, v = clean_slot_values(domain, s, v)
                            if len(v.split()) > 1:
                                v = " ".join([token.text for token in self.nlp(v)]).strip()
                            if v != "":
                                constraint_dict[domain][s] = v

                    constraints = []  # list in format of [domain] slot value
                    cons_delex = []
                    turn_dom_bs = []  # domains whose belief state changed this turn
                    for domain, info_slots in constraint_dict.items():
                        if info_slots:
                            constraints.append("[" + domain + "]")
                            cons_delex.append("[" + domain + "]")
                            for slot, value in info_slots.items():
                                constraints.append(slot)
                                constraints.extend(value.split())
                                cons_delex.append(slot)
                            if domain not in prev_constraint_dict:
                                turn_dom_bs.append(domain)
                            elif prev_constraint_dict[domain] != constraint_dict[domain]:
                                turn_dom_bs.append(domain)

                    sys_act_dict = {}
                    turn_dom_da = set()  # domains named in this turn's dialogue acts
                    for act in dial_turn["dialog_act"]:
                        d, a = act.split("-")  # split domain-act
                        turn_dom_da.add(d)
                    turn_dom_da = list(turn_dom_da)
                    if len(turn_dom_da) != 1 and "general" in turn_dom_da:
                        turn_dom_da.remove("general")
                    if len(turn_dom_da) != 1 and "booking" in turn_dom_da:
                        turn_dom_da.remove("booking")

                    # get turn domain
                    turn_domain = turn_dom_bs
                    for dom in turn_dom_da:
                        if dom != "booking" and dom not in turn_domain:
                            turn_domain.append(dom)
                    if not turn_domain:
                        turn_domain = prev_turn_domain
                    if len(turn_domain) == 2 and "general" in turn_domain:
                        turn_domain.remove("general")
                    if len(turn_domain) == 2:
                        # When the previous turn's domain comes second, swap the pair.
                        if len(prev_turn_domain) == 1 and prev_turn_domain[0] == turn_domain[1]:
                            turn_domain = turn_domain[::-1]

                    # get system action
                    for dom in turn_domain:
                        sys_act_dict[dom] = {}
                    add_to_last_collect = []
                    booking_act_map = {"inform": "offerbook", "book": "offerbooked"}
                    for act, params in dial_turn["dialog_act"].items():
                        if act == "general-greet":
                            continue
                        d, a = act.split("-")
                        if d == "general" and d not in sys_act_dict:
                            sys_act_dict[d] = {}
                        if d == "booking":
                            # Booking acts are attributed to the first turn domain.
                            d = turn_domain[0]
                            a = booking_act_map.get(a, a)
                        add_p = []
                        for param in params:
                            p = param[0]
                            if p == "none":
                                continue
                            elif ontology.da_abbr_to_slot_name.get(p):
                                p = ontology.da_abbr_to_slot_name[p]
                            if p not in add_p:
                                add_p.append(p)
                        # These acts are inserted after all others of the turn.
                        add_to_last = a in ["request", "reqmore", "bye", "offerbook"]
                        if add_to_last:
                            add_to_last_collect.append((d, a, add_p))
                        else:
                            sys_act_dict[d][a] = add_p
                    for d, a, add_p in add_to_last_collect:
                        sys_act_dict[d][a] = add_p

                    # Iterate over a copy of the keys because entries are deleted.
                    for d in copy.copy(sys_act_dict):
                        acts = sys_act_dict[d]
                        if not acts:
                            del sys_act_dict[d]
                        if "inform" in acts and "offerbooked" in acts:
                            # Fold the informed slots into the offerbooked act.
                            for s in sys_act_dict[d]["inform"]:
                                sys_act_dict[d]["offerbooked"].append(s)
                            del sys_act_dict[d]["inform"]

                    ordered_sysact_dict[fn][len(dial["log"])] = sys_act_dict

                    sys_act = []
                    if "general-greet" in dial_turn["dialog_act"]:
                        sys_act.extend(["[general]", "[greet]"])
                    for d, acts in sys_act_dict.items():
                        sys_act += ["[" + d + "]"]
                        for a, slots in acts.items():
                            self.unique_da[d + "-" + a] = 1
                            sys_act += ["[" + a + "]"]
                            sys_act += slots

                    # get db pointers
                    matnums = self.db.get_match_num(constraint_dict)
                    match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1]
                    match = matnums[match_dom]
                    dbvec = self.db.addDBPointer(match_dom, match)
                    bkvec = self.db.addBookingPointer(dial_turn["dialog_act"])

                    # 4 database pointer for domains, 2 for booking
                    single_turn["pointer"] = ",".join([str(d) for d in dbvec + bkvec])
                    single_turn["match"] = str(match)
                    single_turn["constraint"] = " ".join(constraints)
                    single_turn["cons_delex"] = " ".join(cons_delex)
                    single_turn["sys_act"] = " ".join(sys_act)
                    single_turn["turn_num"] = len(dial["log"])
                    single_turn["turn_domain"] = " ".join(["[" + d + "]" for d in turn_domain])

                    prev_turn_domain = copy.deepcopy(turn_domain)
                    prev_constraint_dict = copy.deepcopy(constraint_dict)

                    # A system turn that was not preceded by a user turn is dropped.
                    if "user" in single_turn:
                        dial["log"].append(single_turn)
                        for t in single_turn["user"].split() + single_turn["resp"].split() + constraints + sys_act:
                            self.vocab.add_word(t)

                    single_turn = {}

            data[fn] = dial

        self.vocab.construct()
        self.vocab.save_vocab("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/vocab")
        with open("data/interim/gen_usr_utts/multi-woz-analysis/dialog_acts.json", "w") as f:
            json.dump(ordered_sysact_dict, f, indent=2)
        with open("data/interim/gen_usr_utts/multi-woz-analysis/dialog_act_type.json", "w") as f:
            json.dump(self.unique_da, f, indent=2)
        return data
if __name__ == "__main__":
    # Create the output directories before anything runs: get_db_values(),
    # DataPreprocessor() and preprocess_main() all write into them, so
    # creating them only at the end (as before) was too late.
    for out_dir in (
        "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed",
        "data/interim/gen_usr_utts/multi-woz-analysis",
    ):
        os.makedirs(out_dir, exist_ok=True)

    db_paths = {
        "attraction": "data/raw/UBAR/db/attraction_db.json",
        "hospital": "data/raw/UBAR/db/hospital_db.json",
        "hotel": "data/raw/UBAR/db/hotel_db.json",
        "police": "data/raw/UBAR/db/police_db.json",
        "restaurant": "data/raw/UBAR/db/restaurant_db.json",
        "taxi": "data/raw/UBAR/db/taxi_db.json",
        "train": "data/raw/UBAR/db/train_db.json",
    }
    # Step 1: value set + ontology -> processed value lists and bspn vocabulary.
    get_db_values("data/raw/UBAR/db/value_set.json")
    # Step 2: clean and tokenise every domain database.
    preprocess_db(db_paths)
    # Step 3: convert the dialogues and save them.
    dh = DataPreprocessor()
    data = dh.preprocess_main()
    with open("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/data_for_ubar.json", "w") as f:
        json.dump(data, f, indent=2)