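# Residue from the hosting site's file viewer (not part of the program),
# preserved here as comments so that the module parses:
#   alistairmcleay's picture
#   Added dialogue system code
#   b16a132
#   raw history blame
#   No virus
#   7.55 kB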
import copy
import json
import os
import re
import zipfile
from collections import OrderedDict
from crazyneuraluser.UBAR_code.ontology import all_domains
# --- Active configuration (labelled "2.0" by the original author; presumably MultiWOZ 2.0 - confirm) ---
# data_path:     JSON file with the dialogues that analysis() reads.
# save_path:     directory for the analysis/statistics JSON files.
# save_path_exp: directory for the files exported for later processing steps
#                (reference numbers, per-domain file lists).
# Keep the trailing "/" on both directories: output file names are appended to them.
data_path = "data/preprocessed/UBAR/gen_usr_utt_experiment_data.json"
save_path = "data/interim/gen_usr_utts/multi-woz-analysis/"
save_path_exp = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/"
# --- Alternative configuration for MultiWOZ 2.1 (disabled) ---
# data_path = 'data/raw/UBAR/MultiWOZ_2.1/'
# save_path = 'data/interim/multi-woz-2.1-analysis/'
# save_path_exp = 'data/preprocessed/multi-woz-2.1-processed/'
# Base name of the raw data file; no active code in the visible part of this file reads it.
data_file = "data.json"
# Domains iterated by analysis(); taken unchanged from the project ontology.
domains = all_domains
# Per the original author's note, all_domains is expected to be:
# all_domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital']
# Dialogues whose file name contains "mul" but which the original author
# hard-coded as single-domain.
_SINGLE_DOMAIN_OVERRIDES = ("pmul2756.json", "pmul4958.json", "pmul3599.json")


def _has_filled_slot(slots):
    """Return True if at least one value in *slots* is neither "" nor []."""
    return any(value not in ("", []) for value in slots.values())


def _compress_system_turn(turn):
    """Return a system turn reduced to its text and non-empty belief state.

    For every domain in ``turn["metadata"]``:
      * "book" is kept as-is when any of its values is filled;
      * "semi" is kept, without its "not mentioned" slots, when any of its
        values is filled ("not mentioned" itself counts as filled, so a domain
        may end up with an empty "semi" dict).
    Domains with neither are dropped.
    """
    kept = {}
    for dom, state in turn["metadata"].items():
        book, semi = state["book"], state["semi"]
        entry = {}
        if _has_filled_slot(book):
            entry["book"] = book
        if _has_filled_slot(semi):
            entry["semi"] = {s: v for s, v in semi.items() if v != "not mentioned"}
        if entry:
            kept[dom] = entry
    return {"text": turn["text"], "metadata": kept}


def _dump_json(path, obj):
    """Write *obj* to *path* as JSON indented by two spaces."""
    with open(path, "w", encoding="utf-8") as sink:
        json.dump(obj, sink, indent=2)


def analysis(input_path=None, analysis_dir=None, export_dir=None, domain_names=None):
    """Compute dataset statistics for a MultiWOZ-style dialogue JSON file.

    The whole input is lower-cased before parsing, so every key and value in
    the outputs (including dialogue file names) is lower case.

    Args:
        input_path: JSON file mapping dialogue file name -> {"goal", "log"}.
            Defaults to the module-level ``data_path``.
        analysis_dir: Directory for the statistics files. Defaults to the
            module-level ``save_path``.
        export_dir: Directory for ``reference_no.json`` and
            ``domain_files.json``. Defaults to the module-level
            ``save_path_exp``.
        domain_names: Domains to collect statistics for. Defaults to the
            module-level ``domains``.

    Files written to *analysis_dir*:
        req_slots.json, info_slots.json, all_domain_specific_info_slots.json,
        goal_of_each_dials.json, compressed_data.json, domain_count.json.
    Files written to *export_dir*:
        reference_no.json, domain_files.json.

    Both directories are created, including missing parents, if needed.
    Also prints the number of distinct domain-specific slots.

    Returns:
        None.

    Raises:
        OSError: If the input cannot be read or an output cannot be written.
        json.JSONDecodeError: If the input is not valid JSON.
        KeyError: If a dialogue lacks "goal"/"log", or a turn lacks
            "text"/"metadata"/"book"/"semi".
    """
    # Fall back to the module-level configuration for anything not supplied.
    if input_path is None:
        input_path = data_path
    if analysis_dir is None:
        analysis_dir = save_path
    if export_dir is None:
        export_dir = save_path_exp
    if domain_names is None:
        domain_names = domains

    compressed_raw_data = {}
    goal_of_dials = {}
    req_slots = {domain: [] for domain in domain_names}
    info_slots = {domain: [] for domain in domain_names}
    dom_count = {}
    dom_fnlist = {}
    all_domain_specific_slots = set()

    with open(input_path, "r", encoding="utf-8") as src:
        raw_text = src.read().lower()
    # Reference numbers are scraped from the raw text, before parsing.
    # Sorted so that the output file is identical from run to run.
    ref_nos = sorted(set(re.findall(r"\"reference\"\: \"(\w+)\"", raw_text)))
    dialogues = json.loads(raw_text)

    for fn, dial in dialogues.items():
        goals = dial["goal"]

        # --- compressed dialogue and its goal ---------------------------------
        # Keep only the goals of domains that are actually requested.
        active_goals = {
            dom: goal for dom, goal in goals.items() if dom not in ("topic", "message") and goal
        }
        goal_of_dials[fn] = active_goals
        compressed_raw_data[fn] = {
            "goal": dict(active_goals),
            "log": [
                # A turn with empty metadata is a user turn; otherwise a system turn.
                _compress_system_turn(turn) if turn["metadata"] else {"text": turn["text"]}
                for turn in dial["log"]
            ],
        }

        # --- domain statistics ------------------------------------------------
        is_multi = ("mul" in fn or "MUL" in fn) and fn not in _SINGLE_DOMAIN_OVERRIDES
        dial_type = "multi" if is_multi else "single"
        # A domain missing from the goal is treated like an empty goal.
        dial_domains = [dom for dom in domain_names if goals.get(dom)]
        for dom in dial_domains:
            key = dom + "_" + dial_type
            dom_count[key] = dom_count.get(key, 0) + 1
            dom_fnlist.setdefault(key, []).append(fn)
        if dial_type == "multi":
            # Also count the exact combination of domains, e.g. "hotel_taxi".
            combo = "_".join(dial_domains)
            dom_count[combo] = dom_count.get(combo, 0) + 1
            dom_fnlist.setdefault(combo, []).append(fn)

        # --- informable / requestable slot statistics ---------------------------
        for domain in domain_names:
            goal = goals.get(domain) or {}
            info_ss = goal.get("info", {})
            book_ss = goal.get("book", {})
            req_ss = goal.get("reqt", {})
            for info_s in info_ss:
                all_domain_specific_slots.add(domain + "-" + info_s)
                if info_s not in info_slots[domain]:
                    info_slots[domain].append(info_s)
            for book_s in book_ss:
                # "invalid"/"pre_invalid" are bookkeeping flags, not slots.
                if book_s in ("invalid", "pre_invalid"):
                    continue
                if "book_" + book_s not in info_slots[domain]:
                    all_domain_specific_slots.add(domain + "-" + book_s)
                    info_slots[domain].append("book_" + book_s)
            for req_s in req_ss:
                if req_s not in req_slots[domain]:
                    req_slots[domain].append(req_s)

    # --- write results --------------------------------------------------------
    # makedirs (not mkdir): the configured output paths are several levels deep.
    os.makedirs(analysis_dir, exist_ok=True)
    os.makedirs(export_dir, exist_ok=True)

    _dump_json(os.path.join(analysis_dir, "req_slots.json"), req_slots)
    _dump_json(os.path.join(analysis_dir, "info_slots.json"), info_slots)
    _dump_json(
        os.path.join(analysis_dir, "all_domain_specific_info_slots.json"),
        sorted(all_domain_specific_slots),
    )
    print("slot num:", len(all_domain_specific_slots))
    _dump_json(os.path.join(analysis_dir, "goal_of_each_dials.json"), goal_of_dials)
    _dump_json(os.path.join(analysis_dir, "compressed_data.json"), compressed_raw_data)

    # Order the counts: per-domain single, per-domain multi, then combinations.
    single_count = [d for d in dom_count.items() if "single" in d[0]]
    multi_count = [d for d in dom_count.items() if "multi" in d[0]]
    other_count = [d for d in dom_count.items() if "multi" not in d[0] and "single" not in d[0]]
    dom_count_od = OrderedDict(single_count + multi_count + other_count)
    _dump_json(os.path.join(analysis_dir, "domain_count.json"), dom_count_od)

    _dump_json(os.path.join(export_dir, "reference_no.json"), ref_nos)
    _dump_json(os.path.join(export_dir, "domain_files.json"), dom_fnlist)
# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    analysis()