cambridge-masters-project / scripts /simulate_interaction.py
alistairmcleay's picture
Put models on model hub
2936a70
raw
history blame
7.06 kB
import sys
import traceback
import pandas as pd
# from tqdm import tqdm
from UBAR_code.interaction import UBAR_interact
from user_model_code.interaction import multiwoz_interact
from UBAR_code.interaction.UBAR_interact import bcolors
# from tqdm import tqdm
from scripts.UBAR_code.interaction import UBAR_interact
from scripts.user_model_code.interaction import multiwoz_interact
from scripts.UBAR_code.interaction.UBAR_interact import bcolors
def instantiate_agents():
    """Create the system agent and the user agent used by the simulation.

    Both agents are constructed from hard-coded checkpoint and config paths.

    Returns:
        A ``(sys_model, user_model)`` tuple: a ``UBAR_interact.UbarSystemModel``
        followed by a ``multiwoz_interact.NeuralAgent``.
    """
    # Each constructor receives: an agent name, a checkpoint path, a config path.
    system_agent = UBAR_interact.UbarSystemModel(
        "UBAR_sys_model",
        "cambridge-masters-project/epoch50_trloss0.59_gpt2",
        "cambridge-masters-project/scripts/UBAR_code/interaction/config.yaml",
    )
    user_agent = multiwoz_interact.NeuralAgent(
        "user",
        "cambridge-masters-project/MultiWOZ-full_checkpoint_step340k",
        "cambridge-masters-project/scripts/user_model_code/interaction/config.yaml",
    )
    return system_agent, user_agent
def read_multiwoz_data():
    """Load the raw MultiWOZ 2.0 ``data.json`` file into a DataFrame.

    The path is hard-coded relative to the current working directory.

    Returns:
        The ``pandas.DataFrame`` produced by ``pd.read_json`` for that file.
    """
    return pd.read_json("cambridge-masters-project/data/raw/UBAR/multi-woz/data.json")
def _read_dialogue_id_file(path):
    """Return the set of dialogue IDs stored in the file at *path*.

    NOTE(review): the on-disk format is not visible from this script. The
    MultiWOZ list files are presumably newline-delimited file names despite
    their ``.json`` extension, so that is the fallback; a real JSON array is
    accepted as well. Confirm against the actual data files.
    """
    import json  # local import so the file's import block needs no change

    with open(path, encoding="utf-8") as f:
        raw_text = f.read()

    try:
        parsed = json.loads(raw_text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        parsed = None

    if isinstance(parsed, list):
        return {str(item).strip() for item in parsed}
    # Plain-text format: one dialogue ID per line, blank lines ignored.
    return {line.strip() for line in raw_text.splitlines() if line.strip()}


def load_test_val_lists(
    val_list_file="cambridge-masters-project/data/raw/UBAR/multi-woz/valListFile.json",
    test_list_file="cambridge-masters-project/data/raw/UBAR/multi-woz/testListFile.json",
):
    """Load the validation and test dialogue ID lists.

    Previously this function only assigned the two paths and implicitly
    returned ``None``, which cannot be unpacked into two values by a caller.

    Args:
        val_list_file: Path to the validation list file.
        test_list_file: Path to the test list file.

    Returns:
        A ``(val_list, test_list)`` tuple of sets of dialogue IDs. Sets are
        used because the IDs are only needed for membership tests.

    Raises:
        OSError: If either file cannot be opened.
    """
    return _read_dialogue_id_file(val_list_file), _read_dialogue_id_file(test_list_file)
def main(
    write_to_file=False, ground_truth_system_responses=False, train_only=True, n_dialogues="all", log_successes=False
):
    """Simulate dialogues between the user model and a system, one per goal.

    For every goal read from the raw MultiWOZ 2.0 file the user model is run
    for up to 50 turns against either the UBAR system model or the
    ground-truth system responses stored in the raw data. A dialogue counts
    as successful when ``user_model.is_terminated()`` becomes true.

    Args:
        write_to_file: If true, write a header to the user-utterance output
            file and append the user utterances of each successful dialogue.
        ground_truth_system_responses: If true, take system responses from
            the raw data log instead of calling the system model.
        train_only: If true, skip dialogues whose file name appears in the
            validation or test lists.
        n_dialogues: Number of goals to read, or ``"all"`` to use the number
            of columns in the raw data.
        log_successes: If true, overwrite the progress file with the running
            ``successful / generated`` tally every 100 dialogues.
    """
    sys_model, user_model = instantiate_agents()

    # TODO: move hardcoded vars into config file
    raw_mwoz_20_path = "cambridge-masters-project/data/raw/UBAR/multi-woz/data.json"
    user_utterances_out_path = "cambridge-masters-project/data/preprocessed/UBAR/user_utterances_from_simulator.txt"
    logging_successes_path = "cambridge-masters-project/data/preprocessed/UBAR/logging_successes"
    max_turns = 50  # one turn = one user utterance AND one system response

    sys_model.print_intermediary_info = False
    user_model.print_intermediary_info = False

    # Load the raw data once; it is needed for the dialogue count, the
    # dialogue file names and (optionally) the ground-truth system responses.
    print("Loading data...")
    df_mwoz_data = read_multiwoz_data()
    if n_dialogues == "all":
        n_dialogues = len(df_mwoz_data.columns)

    print("Loading goals...")
    goals = multiwoz_interact.read_multiWOZ_20_goals(raw_mwoz_20_path, n_dialogues)

    # Write column headers
    # NOTE(review): the last column is labelled "System Response" but the rows
    # written below contain user utterances. Left unchanged in case downstream
    # readers depend on the header; confirm and rename if safe.
    if write_to_file:
        with open(user_utterances_out_path, "w") as f:
            f.write("Dialogue #\tDialogue ID\tTurn #\tSystem Response\n")

    val_list, test_list = load_test_val_lists()

    successful_dialogues = 0
    total_dialogues_generated = 0  # train dialogues only

    for dialogue_idx, (goal, dialogue_filename) in enumerate(zip(goals, df_mwoz_data.columns)):
        # Overwrite the progress file with the running tally every 100 dialogues
        if log_successes and dialogue_idx % 100 == 0:
            with open(logging_successes_path, "w") as f:
                f.write(str(successful_dialogues) + " / " + str(total_dialogues_generated))

        if train_only and (dialogue_filename in val_list or dialogue_filename in test_list):
            continue

        curr_dialogue_user_utterances_formatted = []
        total_dialogues_generated += 1
        print("Dialogue: {}".format(dialogue_filename))

        # There are occasionally exceptions thrown from one of the agents, usually the user
        # In this case we simply continue to the next dialogue
        try:
            # Reset state after each dialogue
            sys_model.init_session()
            user_model.init_session(ini_goal=goal)
            sys_response = ""

            for turn_idx in range(max_turns):
                # Raw-data log entries alternate user / system, so a turn
                # maps to two consecutive log indices.
                usr_response_raw_data_idx = turn_idx * 2
                sys_response_raw_data_idx = turn_idx * 2 + 1

                user_utterance = user_model.response(sys_response)
                print(bcolors.OKBLUE + "User: " + bcolors.ENDC + user_utterance)

                if write_to_file:
                    # Keep each utterance on a single line of the TSV output
                    user_utterance = user_utterance.replace("\n", " ")
                    curr_dialogue_user_utterances_formatted.append(
                        "\t".join(
                            [str(dialogue_idx), dialogue_filename, str(usr_response_raw_data_idx), user_utterance]
                        )
                        + "\n"
                    )

                if user_model.is_terminated():
                    successful_dialogues += 1
                    print(bcolors.OKCYAN + "Dialogue terminated successfully!" + bcolors.ENDC)
                    print(bcolors.OKCYAN + "---" * 30 + bcolors.ENDC + "\n")
                    if write_to_file:
                        # Write whole dialogue to file
                        with open(user_utterances_out_path, "a") as f:
                            f.writelines(curr_dialogue_user_utterances_formatted)
                    break

                # Next turn materials
                if ground_truth_system_responses:
                    ground_truth_log = df_mwoz_data.iloc[:, dialogue_idx].log
                    # If we are at the end of the ground truth dialogues
                    if len(ground_truth_log) <= sys_response_raw_data_idx:
                        print(bcolors.RED + "Dialogue terminated unsuccessfully!" + bcolors.ENDC)
                        print(bcolors.RED + "---" * 30 + bcolors.ENDC + "\n")
                        break
                    sys_response = ground_truth_log[sys_response_raw_data_idx]["text"]
                else:
                    sys_response = sys_model.response(user_utterance, turn_idx)

                # Slicing (rather than indexing [0]) keeps an empty system
                # response from raising IndexError and aborting the dialogue.
                capitalised_sys_response = sys_response[:1].upper() + sys_response[1:]
                print(bcolors.GREEN + "System: " + bcolors.ENDC + capitalised_sys_response)

        except Exception:
            # Deliberate best-effort: report the failure and move on to the
            # next dialogue rather than stopping the whole run.
            print(bcolors.RED + "*" * 30 + bcolors.ENDC)
            print(bcolors.RED + "Error in dialogue {}".format(dialogue_filename) + bcolors.ENDC)
            print(bcolors.RED + "*" * 30 + bcolors.ENDC)
            traceback.print_exc()

    # Rate is relative to the dialogues actually simulated (skipped val/test
    # dialogues are excluded), matching the periodic progress log above.
    if total_dialogues_generated:
        success_rate = successful_dialogues / total_dialogues_generated
    else:
        success_rate = 0.0

    print("Successful dialogues: {}".format(successful_dialogues))
    print("Total dialogues: {}".format(n_dialogues))
    print("Dialogues generated: {}".format(total_dialogues_generated))
    print("% Successful Dialogues: {}".format(success_rate))
def parse_bool_arg(value):
    """Interpret a command-line string as a boolean.

    ``"false"``, ``"f"``, ``"0"``, ``"no"``, ``"n"`` and the empty string
    (case-insensitive, surrounding whitespace ignored) are false; any other
    value is true. ``"False"`` and ``"True"`` keep their previous meaning.
    """
    return str(value).strip().lower() not in ("false", "f", "0", "no", "n", "")


if __name__ == "__main__":
    # TODO: move parameters to config file
    # The first CLI argument selects ground-truth system responses; when it is
    # omitted, fall back to the default used by main() instead of crashing.
    ground_truth_system_responses = parse_bool_arg(sys.argv[1]) if len(sys.argv) > 1 else False

    main(write_to_file=False, ground_truth_system_responses=ground_truth_system_responses)