alistairmcleay's picture
Added dialogue system code
b16a132
raw
history blame contribute delete
No virus
7.55 kB
import copy
import json
import os
import re
import zipfile
from collections import OrderedDict
from crazyneuraluser.UBAR_code.ontology import all_domains
# 2.0
data_path = "data/preprocessed/UBAR/gen_usr_utt_experiment_data.json"
save_path = "data/interim/gen_usr_utts/multi-woz-analysis/"
save_path_exp = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/"
# 2.1
# data_path = 'data/raw/UBAR/MultiWOZ_2.1/'
# save_path = 'data/interim/multi-woz-2.1-analysis/'
# save_path_exp = 'data/preprocessed/multi-woz-2.1-processed/'
data_file = "data.json"
domains = all_domains
# all_domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital']
def analysis():
compressed_raw_data = {}
goal_of_dials = {}
req_slots = {}
info_slots = {}
dom_count = {}
dom_fnlist = {}
all_domain_specific_slots = set()
for domain in domains:
req_slots[domain] = []
info_slots[domain] = []
# archive = zipfile.ZipFile(data_path + data_file + ".zip", "r")
# data = archive.open(data_file, "r").read().decode("utf-8").lower()
data = open(data_path, "r").read().lower()
ref_nos = list(set(re.findall(r"\"reference\"\: \"(\w+)\"", data)))
data = json.loads(data)
for fn, dial in data.items():
goals = dial["goal"]
logs = dial["log"]
# get compressed_raw_data and goal_of_dials
compressed_raw_data[fn] = {"goal": {}, "log": []}
goal_of_dials[fn] = {}
for dom, goal in goals.items(): # get goal of domains that are in demmand
if dom != "topic" and dom != "message" and goal:
compressed_raw_data[fn]["goal"][dom] = goal
goal_of_dials[fn][dom] = goal
for turn in logs:
if not turn["metadata"]: # user's turn
compressed_raw_data[fn]["log"].append({"text": turn["text"]})
else: # system's turn
meta = turn["metadata"]
turn_dict = {"text": turn["text"], "metadata": {}}
for (
dom,
book_semi,
) in meta.items(): # for every domain, sys updates "book" and "semi"
book, semi = book_semi["book"], book_semi["semi"]
record = False
for (
slot,
value,
) in book.items(): # record indicates non-empty-book domain
if value not in ["", []]:
record = True
if record:
turn_dict["metadata"][dom] = {}
turn_dict["metadata"][dom]["book"] = book # add that domain's book
record = False
for (
slot,
value,
) in semi.items(): # here record indicates non-empty-semi domain
if value not in ["", []]:
record = True
break
if record:
for s, v in copy.deepcopy(semi).items():
if v == "not mentioned":
del semi[s]
if not turn_dict["metadata"].get(dom):
turn_dict["metadata"][dom] = {}
turn_dict["metadata"][dom]["semi"] = semi # add that domain's semi
compressed_raw_data[fn]["log"].append(turn_dict) # add to log the compressed turn_dict
# get domain statistics
dial_type = (
"multi" if "mul" in fn or "MUL" in fn else "single"
) # determine the dialog's type: sinle or multi
if fn in ["pmul2756.json", "pmul4958.json", "pmul3599.json"]:
dial_type = "single"
dial_domains = [dom for dom in domains if goals[dom]] # domains that are in demmand
dom_str = ""
for dom in dial_domains:
if not dom_count.get(dom + "_" + dial_type): # count each domain type, with single or multi considered
dom_count[dom + "_" + dial_type] = 1
else:
dom_count[dom + "_" + dial_type] += 1
if not dom_fnlist.get(dom + "_" + dial_type): # keep track the file number of each domain type
dom_fnlist[dom + "_" + dial_type] = [fn]
else:
dom_fnlist[dom + "_" + dial_type].append(fn)
dom_str += "%s_" % dom
dom_str = dom_str[:-1] # substract the last char in dom_str
if dial_type == "multi": # count multi-domains
if not dom_count.get(dom_str):
dom_count[dom_str] = 1
else:
dom_count[dom_str] += 1
if not dom_fnlist.get(dom_str):
dom_fnlist[dom_str] = [fn]
else:
dom_fnlist[dom_str].append(fn)
######
# get informable and requestable slots statistics
for domain in domains:
info_ss = goals[domain].get("info", {})
book_ss = goals[domain].get("book", {})
req_ss = goals[domain].get("reqt", {})
for info_s in info_ss:
all_domain_specific_slots.add(domain + "-" + info_s)
if info_s not in info_slots[domain]:
info_slots[domain] += [info_s]
for book_s in book_ss:
if "book_" + book_s not in info_slots[domain] and book_s not in [
"invalid",
"pre_invalid",
]:
all_domain_specific_slots.add(domain + "-" + book_s)
info_slots[domain] += ["book_" + book_s]
for req_s in req_ss:
if req_s not in req_slots[domain]:
req_slots[domain] += [req_s]
# result statistics
if not os.path.exists(save_path):
os.mkdir(save_path)
if not os.path.exists(save_path_exp):
os.mkdir(save_path_exp)
with open(save_path + "req_slots.json", "w") as sf:
json.dump(req_slots, sf, indent=2)
with open(save_path + "info_slots.json", "w") as sf:
json.dump(info_slots, sf, indent=2)
with open(save_path + "all_domain_specific_info_slots.json", "w") as sf:
json.dump(list(all_domain_specific_slots), sf, indent=2)
print("slot num:", len(list(all_domain_specific_slots)))
with open(save_path + "goal_of_each_dials.json", "w") as sf:
json.dump(goal_of_dials, sf, indent=2)
with open(save_path + "compressed_data.json", "w") as sf:
json.dump(compressed_raw_data, sf, indent=2)
with open(save_path + "domain_count.json", "w") as sf:
single_count = [d for d in dom_count.items() if "single" in d[0]]
multi_count = [d for d in dom_count.items() if "multi" in d[0]]
other_count = [d for d in dom_count.items() if "multi" not in d[0] and "single" not in d[0]]
dom_count_od = OrderedDict(single_count + multi_count + other_count)
json.dump(dom_count_od, sf, indent=2)
with open(save_path_exp + "reference_no.json", "w") as sf:
json.dump(ref_nos, sf, indent=2)
with open(save_path_exp + "domain_files.json", "w") as sf:
json.dump(dom_fnlist, sf, indent=2)
if __name__ == "__main__":
analysis()