Spaces:
Runtime error
Runtime error
import sys | |
import traceback | |
import pandas as pd | |
# from tqdm import tqdm | |
from UBAR_code.interaction import UBAR_interact | |
from user_model_code.interaction import multiwoz_interact | |
from UBAR_code.interaction.UBAR_interact import bcolors | |
# from tqdm import tqdm | |
from scripts.UBAR_code.interaction import UBAR_interact | |
from scripts.user_model_code.interaction import multiwoz_interact | |
from scripts.UBAR_code.interaction.UBAR_interact import bcolors | |
def instantiate_agents(): | |
UBAR_checkpoint_path = "cambridge-masters-project/epoch50_trloss0.59_gpt2" | |
user_model_checkpoint_path = "cambridge-masters-project/MultiWOZ-full_checkpoint_step340k" | |
sys_model = UBAR_interact.UbarSystemModel( | |
"UBAR_sys_model", UBAR_checkpoint_path, "cambridge-masters-project/scripts/UBAR_code/interaction/config.yaml" | |
) | |
user_model = multiwoz_interact.NeuralAgent( | |
"user", user_model_checkpoint_path, "cambridge-masters-project/scripts/user_model_code/interaction/config.yaml" | |
) | |
return sys_model, user_model | |
def read_multiwoz_data(): | |
""" | |
Read the multiwoz 2.0 raw data from the .json file | |
""" | |
raw_mwoz_20_path = "cambridge-masters-project/data/raw/UBAR/multi-woz/data.json" | |
df_raw_mwoz = pd.read_json(raw_mwoz_20_path) | |
return df_raw_mwoz | |
def load_test_val_lists(): | |
val_list_file = "cambridge-masters-project/data/raw/UBAR/multi-woz/valListFile.json" | |
test_list_file = "cambridge-masters-project/data/raw/UBAR/multi-woz/testListFile.json" | |
def main( | |
write_to_file=False, ground_truth_system_responses=False, train_only=True, n_dialogues="all", log_successes=False | |
): | |
sys_model, user_model = instantiate_agents() | |
# TODO: move hardcoded vars into config file | |
raw_mwoz_20_path = "cambridge-masters-project/data/raw/UBAR/multi-woz/data.json" | |
user_utterances_out_path = "cambridge-masters-project/data/preprocessed/UBAR/user_utterances_from_simulator.txt" | |
logging_successes_path = "cambridge-masters-project/data/preprocessed/UBAR/logging_successes" | |
sys_model.print_intermediary_info = False | |
user_model.print_intermediary_info = False | |
df_raw_mwoz = pd.read_json(raw_mwoz_20_path) | |
if n_dialogues == "all": | |
n_dialogues = len(df_raw_mwoz.columns) | |
curr_dialogue_user_utterances_formatted = [] | |
print("Loading goals...") | |
goals = multiwoz_interact.read_multiWOZ_20_goals(raw_mwoz_20_path, n_dialogues) | |
# Write column headers | |
if write_to_file: | |
with open(user_utterances_out_path, "w") as f: | |
f.write("Dialogue #\tDialogue ID\tTurn #\tSystem Response\n") | |
print("Loading data...") | |
df_mwoz_data = read_multiwoz_data() | |
val_list, test_list = load_test_val_lists() | |
successful_dialogues = 0 | |
total_dialogues_generated = 0 # train dialogues only | |
for dialogue_idx, (goal, dialogue_filename) in enumerate(zip(goals, df_mwoz_data.columns)): | |
if log_successes: | |
# log successful_dialogues to logging_successes_path every 100 dialogues | |
if dialogue_idx % 100 == 0: | |
with open(logging_successes_path, "w") as f: | |
f.write(str(successful_dialogues) + " / " + str(total_dialogues_generated)) | |
curr_dialogue_user_utterances_formatted = [] | |
if train_only: | |
if dialogue_filename in val_list or dialogue_filename in test_list: | |
continue | |
total_dialogues_generated += 1 | |
print("Dialogue: {}".format(dialogue_filename)) | |
# There are occasionally exceptions thrown from one of the agents, usually the user | |
# In this case we simply continue to the next dialogue | |
try: | |
# Reset state after each dialogue | |
sys_model.init_session() | |
user_model.init_session(ini_goal=goal) | |
sys_response = "" | |
for turn_idx in range(50): | |
# Turn idx in this case represents the turn as one user utterance AND one system response | |
usr_response_raw_data_idx = turn_idx * 2 | |
sys_response_raw_data_idx = turn_idx * 2 + 1 | |
user_utterance = user_model.response(sys_response) | |
print(bcolors.OKBLUE + "User: " + bcolors.ENDC + user_utterance) | |
if write_to_file: | |
user_utterance = user_utterance.replace("\n", " ") | |
curr_dialogue_user_utterances_formatted.append( | |
str(dialogue_idx) | |
+ "\t" | |
+ dialogue_filename | |
+ "\t" | |
+ str(usr_response_raw_data_idx) | |
+ "\t" | |
+ user_utterance | |
+ "\n" | |
) | |
if user_model.is_terminated(): | |
successful_dialogues += 1 | |
print(bcolors.OKCYAN + "Dialogue terminated successfully!" + bcolors.ENDC) | |
print(bcolors.OKCYAN + "---" * 30 + bcolors.ENDC + "\n") | |
if write_to_file: | |
# Write whole dialogue to file | |
with open(user_utterances_out_path, "a") as f: | |
for line in curr_dialogue_user_utterances_formatted: | |
f.write(line) | |
break | |
# Next turn materials | |
if ground_truth_system_responses: | |
# If we are at the end of the ground truth dialogues | |
if len(df_mwoz_data.iloc[:, dialogue_idx].log) <= sys_response_raw_data_idx: | |
print(bcolors.RED + "Dialogue terminated unsuccessfully!" + bcolors.ENDC) | |
print(bcolors.RED + "---" * 30 + bcolors.ENDC + "\n") | |
break | |
sys_response = df_mwoz_data.iloc[:, dialogue_idx].log[sys_response_raw_data_idx]["text"] | |
else: | |
sys_response = sys_model.response(user_utterance, turn_idx) | |
capitalised_sys_response = sys_response[0].upper() + sys_response[1:] | |
print(bcolors.GREEN + "System: " + bcolors.ENDC + capitalised_sys_response) | |
except Exception: | |
print(bcolors.RED + "*" * 30 + bcolors.ENDC) | |
print(bcolors.RED + "Error in dialogue {}".format(dialogue_filename) + bcolors.ENDC) | |
print(bcolors.RED + "*" * 30 + bcolors.ENDC) | |
traceback.print_exc() | |
continue | |
print("Successful dialogues: {}".format(successful_dialogues)) | |
print("Total dialogues: {}".format(n_dialogues)) | |
print("% Successful Dialopues: {}".format(successful_dialogues / n_dialogues)) | |
if __name__ == "__main__": | |
# TODO: move parameters to config file | |
# Fix the hacky mess below | |
ground_truth_system_responses = sys.argv[1] | |
if ground_truth_system_responses == "False": | |
ground_truth_system_responses = False | |
else: | |
ground_truth_system_responses = True | |
main(write_to_file=False, ground_truth_system_responses=ground_truth_system_responses) | |