Spaces:
Runtime error
Runtime error
File size: 7,549 Bytes
b16a132 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import copy
import json
import os
import re
import zipfile
from collections import OrderedDict
from crazyneuraluser.UBAR_code.ontology import all_domains
# Active configuration. The original label here was "2.0" -- presumably the
# MultiWOZ 2.0 set-up; confirm against the data actually stored at data_path.
# data_path: the dialogue JSON file that analysis() opens, lower-cases and parses.
data_path = "data/preprocessed/UBAR/gen_usr_utt_experiment_data.json"
# save_path: directory that receives the slot, goal, compressed-data and
# domain-count JSON reports written by analysis().
save_path = "data/interim/gen_usr_utts/multi-woz-analysis/"
# save_path_exp: directory that receives reference_no.json and domain_files.json.
save_path_exp = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/"
# Alternative configuration labelled "2.1" (presumably MultiWOZ 2.1, judging by
# the paths below); currently disabled.
# data_path = 'data/raw/UBAR/MultiWOZ_2.1/'
# save_path = 'data/interim/multi-woz-2.1-analysis/'
# save_path_exp = 'data/preprocessed/multi-woz-2.1-processed/'
# Name of the JSON member inside a zipped corpus. NOTE(review): the active code
# of analysis() reads data_path directly and does not use this name.
data_file = "data.json"
# Domains for which analysis() collects slot and dialogue-count statistics.
domains = all_domains
# all_domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital']
def _is_filled(value):
    """Return True unless *value* is an empty string or an empty list."""
    return value not in ("", [])


def _compress_turn(turn):
    """Return the compact form of one log turn.

    A turn whose ``metadata`` is empty (a user turn) keeps only its text.
    Otherwise (a system turn) only the domains whose ``book`` or ``semi``
    part holds at least one non-empty value are kept, and ``semi`` slots
    valued ``"not mentioned"`` are dropped. The input turn is not modified.
    """
    if not turn["metadata"]:  # user's turn
        return {"text": turn["text"]}

    compact = {"text": turn["text"], "metadata": {}}
    for dom, book_semi in turn["metadata"].items():
        book, semi = book_semi["book"], book_semi["semi"]
        if any(_is_filled(value) for value in book.values()):
            compact["metadata"][dom] = {"book": book}
        if any(_is_filled(value) for value in semi.values()):
            mentioned = {s: v for s, v in semi.items() if v != "not mentioned"}
            compact["metadata"].setdefault(dom, {})["semi"] = mentioned
    return compact


def _write_json(obj, directory, file_name):
    """Serialize *obj* as indented JSON into ``directory/file_name``."""
    with open(os.path.join(directory, file_name), "w") as sf:
        json.dump(obj, sf, indent=2)


def analysis(data_file_path=None, analysis_dir=None, processed_dir=None, domain_names=None):
    """Collect goal, slot and domain statistics from a dialogue JSON file.

    The file content is lower-cased as a whole before it is parsed, so every
    key and value in the results (including dialogue file names) is lower-case.

    Args:
        data_file_path: JSON file mapping dialogue name -> ``{"goal", "log"}``.
            Defaults to the module-level ``data_path``.
        analysis_dir: Directory for ``req_slots.json``, ``info_slots.json``,
            ``all_domain_specific_info_slots.json``, ``goal_of_each_dials.json``,
            ``compressed_data.json`` and ``domain_count.json``.
            Defaults to the module-level ``save_path``.
        processed_dir: Directory for ``reference_no.json`` and
            ``domain_files.json``. Defaults to the module-level ``save_path_exp``.
        domain_names: Domains to gather statistics for.
            Defaults to the module-level ``domains``.

    Returns:
        None. All results are written to the two output directories, which
        are created (including missing parents) when they do not exist.
    """
    data_file_path = data_path if data_file_path is None else data_file_path
    analysis_dir = save_path if analysis_dir is None else analysis_dir
    processed_dir = save_path_exp if processed_dir is None else processed_dir
    domain_names = list(domains if domain_names is None else domain_names)

    compressed_raw_data = {}
    goal_of_dials = {}
    req_slots = {domain: [] for domain in domain_names}
    info_slots = {domain: [] for domain in domain_names}
    dom_count = {}
    dom_fnlist = {}
    all_domain_specific_slots = set()

    with open(data_file_path, "r", encoding="utf-8") as data_fh:
        raw_text = data_fh.read().lower()
    # Reference numbers are pulled from the raw text; sorted so that the
    # output file is identical from run to run.
    ref_nos = sorted(set(re.findall(r"\"reference\"\: \"(\w+)\"", raw_text)))
    data = json.loads(raw_text)

    for fn, dial in data.items():
        goals = dial["goal"]

        # Keep only the non-empty goals; "topic" and "message" are not domains.
        active_goals = {
            dom: goal for dom, goal in goals.items() if dom not in ("topic", "message") and goal
        }
        goal_of_dials[fn] = active_goals
        compressed_raw_data[fn] = {
            "goal": dict(active_goals),
            "log": [_compress_turn(turn) for turn in dial["log"]],
        }

        # Domain statistics, gathered once per dialogue.
        dial_type = "multi" if "mul" in fn or "MUL" in fn else "single"
        if fn in ("pmul2756.json", "pmul4958.json", "pmul3599.json"):
            dial_type = "single"  # these three files are counted as single-domain
        dial_domains = [dom for dom in domain_names if goals.get(dom)]
        stat_keys = [dom + "_" + dial_type for dom in dial_domains]
        if dial_type == "multi":
            # Multi-domain dialogues are also counted under their domain combination.
            stat_keys.append("_".join(dial_domains))
        for key in stat_keys:
            dom_count[key] = dom_count.get(key, 0) + 1
            dom_fnlist.setdefault(key, []).append(fn)

        # Informable and requestable slot statistics.
        for domain in domain_names:
            domain_goal = goals.get(domain) or {}
            for info_s in domain_goal.get("info", {}):
                all_domain_specific_slots.add(domain + "-" + info_s)
                if info_s not in info_slots[domain]:
                    info_slots[domain].append(info_s)
            for book_s in domain_goal.get("book", {}):
                if book_s in ("invalid", "pre_invalid"):
                    continue
                if "book_" + book_s not in info_slots[domain]:
                    all_domain_specific_slots.add(domain + "-" + book_s)
                    info_slots[domain].append("book_" + book_s)
            for req_s in domain_goal.get("reqt", {}):
                if req_s not in req_slots[domain]:
                    req_slots[domain].append(req_s)

    # Result statistics. makedirs is used because the target paths are nested.
    os.makedirs(analysis_dir, exist_ok=True)
    os.makedirs(processed_dir, exist_ok=True)

    _write_json(req_slots, analysis_dir, "req_slots.json")
    _write_json(info_slots, analysis_dir, "info_slots.json")
    _write_json(sorted(all_domain_specific_slots), analysis_dir, "all_domain_specific_info_slots.json")
    print("slot num:", len(all_domain_specific_slots))
    _write_json(goal_of_dials, analysis_dir, "goal_of_each_dials.json")
    _write_json(compressed_raw_data, analysis_dir, "compressed_data.json")

    # Order the counts: per-domain single, per-domain multi, then combinations.
    single_count = [item for item in dom_count.items() if "single" in item[0]]
    multi_count = [item for item in dom_count.items() if "multi" in item[0]]
    other_count = [
        item for item in dom_count.items() if "multi" not in item[0] and "single" not in item[0]
    ]
    _write_json(OrderedDict(single_count + multi_count + other_count), analysis_dir, "domain_count.json")

    _write_json(ref_nos, processed_dir, "reference_no.json")
    _write_json(dom_fnlist, processed_dir, "domain_files.json")
# Run the analysis only when this file is executed directly as a script.
if __name__ == "__main__":
    analysis()
|