Spaces:
Runtime error
Runtime error
File size: 7,059 Bytes
b16a132 2936a70 b16a132 2936a70 b16a132 2936a70 b16a132 2936a70 b16a132 2936a70 b16a132 2936a70 b16a132 2936a70 b16a132 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import sys
import traceback
import pandas as pd
# from tqdm import tqdm
from UBAR_code.interaction import UBAR_interact
from user_model_code.interaction import multiwoz_interact
from UBAR_code.interaction.UBAR_interact import bcolors
# from tqdm import tqdm
from scripts.UBAR_code.interaction import UBAR_interact
from scripts.user_model_code.interaction import multiwoz_interact
from scripts.UBAR_code.interaction.UBAR_interact import bcolors
def instantiate_agents():
    """Build the UBAR system model and the neural user model.

    Both agents are constructed from hard-coded checkpoint and config paths
    under ``cambridge-masters-project/``.

    Returns:
        A ``(sys_model, user_model)`` tuple holding a
        ``UBAR_interact.UbarSystemModel`` and a ``multiwoz_interact.NeuralAgent``.
    """
    system_agent = UBAR_interact.UbarSystemModel(
        "UBAR_sys_model",
        "cambridge-masters-project/epoch50_trloss0.59_gpt2",
        "cambridge-masters-project/scripts/UBAR_code/interaction/config.yaml",
    )
    user_agent = multiwoz_interact.NeuralAgent(
        "user",
        "cambridge-masters-project/MultiWOZ-full_checkpoint_step340k",
        "cambridge-masters-project/scripts/user_model_code/interaction/config.yaml",
    )
    return system_agent, user_agent
def read_multiwoz_data():
    """Load the raw MultiWOZ 2.0 ``data.json`` file into a DataFrame.

    The path is hard-coded relative to the current working directory.

    Returns:
        The ``pandas.DataFrame`` produced by ``pd.read_json`` for that file.
    """
    return pd.read_json("cambridge-masters-project/data/raw/UBAR/multi-woz/data.json")
def _read_dialogue_names(list_file_path):
    """Return the dialogue names stored in ``list_file_path`` as a list of strings."""
    import json

    with open(list_file_path, encoding="utf-8") as f:
        content = f.read()

    # NOTE(review): the on-disk format is assumed, not verified here. Despite
    # the ``.json`` extension these split files are commonly plain text with
    # one dialogue file name per line, so both layouts are accepted.
    try:
        parsed = json.loads(content)
    except ValueError:
        parsed = None
    if isinstance(parsed, list):
        return [str(name) for name in parsed]
    return [line.strip() for line in content.splitlines() if line.strip()]


def load_test_val_lists(
    val_list_file="cambridge-masters-project/data/raw/UBAR/multi-woz/valListFile.json",
    test_list_file="cambridge-masters-project/data/raw/UBAR/multi-woz/testListFile.json",
):
    """Load the names of the validation and test dialogues.

    Previously this function only assigned the two paths and implicitly
    returned ``None``, so unpacking its result into two names raised
    ``TypeError``.

    Args:
        val_list_file: Path to the validation split list.
        test_list_file: Path to the test split list.

    Returns:
        A ``(val_list, test_list)`` tuple of lists of dialogue names.

    Raises:
        OSError: If either file cannot be opened.
    """
    return _read_dialogue_names(val_list_file), _read_dialogue_names(test_list_file)
def main(
    write_to_file=False, ground_truth_system_responses=False, train_only=True, n_dialogues="all", log_successes=False
):
    """Run the user model against a system for each MultiWOZ goal and report the results.

    For every (goal, dialogue) pair the user model and the system exchange up
    to 50 turns. A dialogue counts as successful when
    ``user_model.is_terminated()`` becomes true after a user utterance.

    Args:
        write_to_file: When true, write a header line to the user-utterances
            output file and append the user utterances of every dialogue that
            terminates successfully.
        ground_truth_system_responses: When true, take each system response
            from the ``"text"`` field of the raw dialogue log instead of asking
            ``sys_model``; the dialogue stops once the log runs out.
        train_only: When true, skip dialogues named in the validation or test
            lists returned by ``load_test_val_lists``.
        n_dialogues: Number of goals to load, or ``"all"`` to use the number of
            columns in the raw data.
        log_successes: When true, overwrite the progress file with
            ``"<successful> / <generated>"`` every 100 dialogues.
    """
    sys_model, user_model = instantiate_agents()

    # TODO: move hardcoded vars into config file
    raw_mwoz_20_path = "cambridge-masters-project/data/raw/UBAR/multi-woz/data.json"
    user_utterances_out_path = "cambridge-masters-project/data/preprocessed/UBAR/user_utterances_from_simulator.txt"
    logging_successes_path = "cambridge-masters-project/data/preprocessed/UBAR/logging_successes"

    sys_model.print_intermediary_info = False
    user_model.print_intermediary_info = False

    # read_multiwoz_data() reads the same data.json as raw_mwoz_20_path, so the
    # file is parsed once here instead of twice.
    print("Loading data...")
    df_mwoz_data = read_multiwoz_data()
    if n_dialogues == "all":
        n_dialogues = len(df_mwoz_data.columns)

    print("Loading goals...")
    goals = multiwoz_interact.read_multiWOZ_20_goals(raw_mwoz_20_path, n_dialogues)

    # Write column headers
    if write_to_file:
        with open(user_utterances_out_path, "w") as f:
            # NOTE(review): the last column is labelled "System Response" but
            # the rows written below contain user utterances. Left unchanged
            # in case downstream readers depend on the header text.
            f.write("Dialogue #\tDialogue ID\tTurn #\tSystem Response\n")

    # One set gives O(1) membership tests instead of scanning two lists per dialogue.
    held_out_dialogues = frozenset()
    if train_only:
        val_list, test_list = load_test_val_lists()
        held_out_dialogues = frozenset(val_list) | frozenset(test_list)

    successful_dialogues = 0
    total_dialogues_generated = 0  # dialogues actually simulated (held-out ones are excluded)

    for dialogue_idx, (goal, dialogue_filename) in enumerate(zip(goals, df_mwoz_data.columns)):
        # Overwrite the progress file every 100 dialogues so it always holds the latest counts
        if log_successes and dialogue_idx % 100 == 0:
            with open(logging_successes_path, "w") as f:
                f.write(str(successful_dialogues) + " / " + str(total_dialogues_generated))

        if train_only and dialogue_filename in held_out_dialogues:
            continue

        total_dialogues_generated += 1
        curr_dialogue_user_utterances_formatted = []
        print("Dialogue: {}".format(dialogue_filename))

        # There are occasionally exceptions thrown from one of the agents, usually the user
        # In this case we simply continue to the next dialogue
        try:
            # Reset state after each dialogue
            sys_model.init_session()
            user_model.init_session(ini_goal=goal)
            sys_response = ""

            # Look the ground-truth log up once per dialogue rather than once per turn
            ground_truth_log = None
            if ground_truth_system_responses:
                ground_truth_log = df_mwoz_data.iloc[:, dialogue_idx]["log"]

            for turn_idx in range(50):
                # Turn idx in this case represents the turn as one user utterance AND one system response
                usr_response_raw_data_idx = turn_idx * 2
                sys_response_raw_data_idx = turn_idx * 2 + 1

                user_utterance = user_model.response(sys_response)
                print(bcolors.OKBLUE + "User: " + bcolors.ENDC + user_utterance)

                if write_to_file:
                    # NOTE(review): this also changes the utterance later passed
                    # to sys_model.response(), so runs with and without
                    # write_to_file may differ. Behavior kept as-is.
                    user_utterance = user_utterance.replace("\n", " ")
                    curr_dialogue_user_utterances_formatted.append(
                        str(dialogue_idx)
                        + "\t"
                        + dialogue_filename
                        + "\t"
                        + str(usr_response_raw_data_idx)
                        + "\t"
                        + user_utterance
                        + "\n"
                    )

                if user_model.is_terminated():
                    successful_dialogues += 1
                    print(bcolors.OKCYAN + "Dialogue terminated successfully!" + bcolors.ENDC)
                    print(bcolors.OKCYAN + "---" * 30 + bcolors.ENDC + "\n")
                    if write_to_file:
                        # Write whole dialogue to file
                        with open(user_utterances_out_path, "a") as f:
                            f.writelines(curr_dialogue_user_utterances_formatted)
                    break

                # Next turn materials
                if ground_truth_system_responses:
                    # If we are at the end of the ground truth dialogues
                    if len(ground_truth_log) <= sys_response_raw_data_idx:
                        print(bcolors.RED + "Dialogue terminated unsuccessfully!" + bcolors.ENDC)
                        print(bcolors.RED + "---" * 30 + bcolors.ENDC + "\n")
                        break
                    sys_response = ground_truth_log[sys_response_raw_data_idx]["text"]
                else:
                    sys_response = sys_model.response(user_utterance, turn_idx)

                # Slicing (rather than indexing [0]) keeps this safe for an empty response
                capitalised_sys_response = sys_response[:1].upper() + sys_response[1:]
                print(bcolors.GREEN + "System: " + bcolors.ENDC + capitalised_sys_response)

        except Exception:
            print(bcolors.RED + "*" * 30 + bcolors.ENDC)
            print(bcolors.RED + "Error in dialogue {}".format(dialogue_filename) + bcolors.ENDC)
            print(bcolors.RED + "*" * 30 + bcolors.ENDC)
            traceback.print_exc()

    # Report against the dialogues that were actually simulated, matching the
    # "<successful> / <generated>" progress log, and avoid dividing by zero.
    if total_dialogues_generated:
        success_percentage = 100.0 * successful_dialogues / total_dialogues_generated
    else:
        success_percentage = 0.0

    print("Successful dialogues: {}".format(successful_dialogues))
    print("Total dialogues: {}".format(total_dialogues_generated))
    print("% Successful Dialogues: {:.2f}".format(success_percentage))
def parse_ground_truth_flag(argv):
    """Decide from the command line whether ground-truth system responses are used.

    Args:
        argv: The argument vector, with the program name at index 0.

    Returns:
        ``False`` when no argument is given or when the first argument is
        ``"false"``, ``"0"``, ``"no"`` or empty (case-insensitive, surrounding
        whitespace ignored); ``True`` for any other value.
    """
    if len(argv) < 2:
        # No argument: fall back to main()'s own default instead of raising IndexError
        return False
    return argv[1].strip().lower() not in {"false", "0", "no", ""}


if __name__ == "__main__":
    # TODO: move parameters to config file
    main(write_to_file=False, ground_truth_system_responses=parse_ground_truth_flag(sys.argv))
|