# Author: alistairmcleay
# Commit 2db6fc4: "Fixing goals bug in self-play module"
import random
import gradio as gr
import sys
import traceback
import pandas as pd
import gradio as gr
import json
import yaml
# from tqdm import tqdm
from scripts.UBAR_code.interaction import UBAR_interact
from scripts.user_model_code.interaction import multiwoz_interact
from scripts.UBAR_code.interaction.UBAR_interact import bcolors
# Initialise the dialogue agents.
# NOTE: each Self-Play name below is an alias of the corresponding tab's model,
# i.e. `self_play_sys_model is sys_model` and `self_play_user_model is user_model`.
# Any init_session() call made for one tab therefore acts on the same object the
# other tab uses. Presumably this avoids loading each checkpoint twice; confirm
# before splitting them into separate instances.
UBAR_checkpoint_path = "epoch50_trloss0.59_gpt2"
user_model_checkpoint_path = "MultiWOZ-full_checkpoint_step340k"

sys_model = UBAR_interact.UbarSystemModel(
    "UBAR_sys_model",
    UBAR_checkpoint_path,
    "scripts/UBAR_code/interaction/config.yaml",
)
self_play_sys_model = sys_model

user_model = multiwoz_interact.NeuralAgent(
    "user",
    user_model_checkpoint_path,
    "scripts/user_model_code/interaction/config.yaml",
)
self_play_user_model = user_model
# Load the goals used to seed the user simulator.
n_goals = 100  # number of goals requested from the goal reader
goals_path = "data/raw/UBAR/multi-woz/data.json"
print("Loading goals...")
goals = multiwoz_interact.read_multiWOZ_20_goals(goals_path, n_goals)
# Fail with a clear message instead of an obscure indexing error further down.
if not goals:
    raise RuntimeError(f"No goals were loaded from {goals_path!r}")

# Pick a random starting goal for the User Simulator tab. The user can draw a
# new one with "Generate New Goal". Indices are drawn from the goals actually
# loaded (assumed to be an int-indexable sequence) rather than the requested
# count, so a short goal file cannot cause an IndexError.
curr_goal_idx = random.randrange(len(goals))
current_goal = goals[curr_goal_idx]
user_model.init_session(ini_goal=current_goal)

# Do the same initialisation, with its own random draw, for the Self-Play tab.
curr_sp_goal_idx = random.randrange(len(goals))
current_sp_goal = goals[curr_sp_goal_idx]
self_play_user_model.init_session(ini_goal=current_sp_goal)
# Per-tab conversation transcripts. Each is a list of (left, right) message
# pairs that the handlers below append to and return to the tab's Chatbot.
# They live at module level, so every browser session sees the same lists.
ds_history, us_history, self_play_history = [], [], []
def reset_ds_state():
    """Start a fresh Dialogue System conversation.

    Empties the module-level ``ds_history`` in place, so the list object itself
    is kept. Then re-initialises ``sys_model``'s session.

    Returns:
        The (now empty) ``ds_history`` list, for display in the Chatbot.
    """
    del ds_history[:]
    sys_model.init_session()
    return ds_history
def reset_us_state():
    """Start a fresh User Simulator conversation with the current goal.

    Empties the module-level ``us_history`` in place. Then re-initialises
    ``user_model`` with ``current_goal``, which is read at call time, so a goal
    chosen by ``change_goal`` is picked up.

    Returns:
        The (now empty) ``us_history`` list, for display in the Chatbot.
    """
    del us_history[:]
    user_model.init_session(ini_goal=current_goal)
    return us_history
def reset_self_play_state():
    """Start a fresh Self-Play conversation with the current Self-Play goal.

    Does three things:
    - empties the module-level ``self_play_history`` in place;
    - re-initialises the Self-Play system model's session;
    - re-seeds the Self-Play user simulator with ``current_sp_goal``, which is
      read at call time.

    Returns:
        The (now empty) ``self_play_history`` list, for display in the Chatbot.
    """
    del self_play_history[:]
    self_play_sys_model.init_session()
    self_play_user_model.init_session(ini_goal=current_sp_goal)
    return self_play_history
def change_goal():
    """Draw a new random goal for the User Simulator tab and restart it.

    Updates the module-level ``curr_goal_idx`` and ``current_goal``. Then calls
    ``reset_us_state()``, which clears the transcript and re-initialises
    ``user_model`` with the new goal.

    NOTE: ``user_model`` is the same object as ``self_play_user_model``, so this
    re-initialisation also affects the agent used by the Self-Play tab.

    Returns:
        A ``(goal_yaml, history)`` pair: the new goal rendered as block-style
        YAML, and the (now empty) User Simulator transcript.
    """
    global curr_goal_idx, current_goal
    # Bound the draw by the goals actually loaded, not the requested count.
    curr_goal_idx = random.randrange(len(goals))
    current_goal = goals[curr_goal_idx]
    # Use a distinct local name so we do not shadow the module-level us_history.
    history = reset_us_state()
    return yaml.dump(current_goal, default_flow_style=False), history
def change_sp_goal():
    """Draw a new random goal for the Self-Play tab and restart it.

    Updates the module-level ``curr_sp_goal_idx`` and ``current_sp_goal``. Then
    calls ``reset_self_play_state()``, which clears the Self-Play transcript and
    re-initialises both Self-Play agents.

    NOTE: the Self-Play agents are the same objects as the other tabs' agents,
    so this re-initialisation also affects those tabs.

    Returns:
        A ``(goal_yaml, history)`` pair: the new goal rendered as block-style
        YAML, and the (now empty) Self-Play transcript.
    """
    global curr_sp_goal_idx, current_sp_goal
    # Bound the draw by the goals actually loaded, not the requested count.
    curr_sp_goal_idx = random.randrange(len(goals))
    current_sp_goal = goals[curr_sp_goal_idx]
    # Use a distinct local name so we do not shadow the module-level list.
    history = reset_self_play_state()
    return yaml.dump(current_sp_goal, default_flow_style=False), history
def ds_chatbot(user_utt):
    """Send one user message to the Dialogue System and record the exchange.

    Args:
        user_utt: The message typed by the human user.

    Returns:
        The Dialogue System transcript (list of ``(user, system)`` pairs),
        including the exchange just produced.
    """
    # One transcript entry per completed turn, so its length is the turn index.
    turn_id = len(ds_history)
    sys_response = sys_model.response(user_utt, turn_id)
    # Capitalise the first character without indexing, so an empty response
    # does not raise IndexError.
    sys_response = sys_response[:1].upper() + sys_response[1:]
    ds_history.append((user_utt, sys_response))
    return ds_history
def us_chatbot(sys_response):
    """Send one Dialogue System message to the User Simulator.

    The human plays the dialogue system here. The simulator's reply is
    recorded. When ``user_model.is_terminated()`` is true, ``change_goal()``
    draws a new goal and restarts the conversation for the next message.

    Args:
        sys_response: The system-side message typed by the human.

    Returns:
        The User Simulator transcript (list of ``(system, user)`` pairs),
        including the turn just produced, even if that turn ended the dialogue.
    """
    user_utt = user_model.response(sys_response)
    us_history.append((sys_response, user_utt))
    # Snapshot before change_goal() empties us_history in place. Otherwise the
    # simulator's final utterance would never be displayed.
    transcript = list(us_history)
    if user_model.is_terminated():
        change_goal()
    return transcript
def self_play():
    """Advance the Self-Play conversation by one turn.

    The turn runs in two steps:
    1. The Self-Play user simulator responds to the previous system reply, or
       to an empty string on the first turn.
    2. The Self-Play dialogue system responds to that user utterance.

    When ``self_play_user_model.is_terminated()`` is true, a new Self-Play goal
    is drawn and the Self-Play conversation restarted via ``change_sp_goal()``.

    Returns:
        The Self-Play transcript (list of ``(user, system)`` pairs), including
        the turn just produced, even if that turn ended the dialogue.
    """
    sys_response = self_play_history[-1][1] if self_play_history else ""
    user_utt = self_play_user_model.response(sys_response)
    # One transcript entry per completed turn, so its length is the turn index.
    turn_id = len(self_play_history)
    sys_response = self_play_sys_model.response(user_utt, turn_id)
    # Capitalise the first character without indexing, so an empty response
    # does not raise IndexError.
    sys_response = sys_response[:1].upper() + sys_response[1:]
    self_play_history.append((user_utt, sys_response))
    # Snapshot before change_sp_goal() empties self_play_history in place.
    transcript = list(self_play_history)
    # Termination must restart the Self-Play goal. It must not restart the
    # User Simulator tab's goal and history.
    if self_play_user_model.is_terminated():
        change_sp_goal()
    return transcript
# Put every tab into a clean initial state once, when this module is loaded.
# These calls are not re-run when a browser refreshes the page, so
# conversation state persists across refreshes.
for _reset_tab in (reset_ds_state, reset_us_state, reset_self_play_state):
    _reset_tab()
del _reset_tab
# Build the demo UI: one tab per interaction mode, wired to the handlers above,
# then launch the app.
with gr.Blocks() as block:
    gr.Markdown("# 💬 Jointly Optimized Task-Oriented Dialogue System And User Simulator 💬")
    gr.Markdown(
        "Created by [Alistair McLeay](https://alistairmcleay.com) for the "
        "[Masters in Machine Learning & Machine Intelligence at Cambridge University](https://www.mlmi.eng.cam.ac.uk/). <br/> "
        "Thank you to [Professor Bill Byrne](https://sites.google.com/view/bill-byrne/home) for his supervision and guidance. <br/> "
        "Thank you to [Andy Tseng](https://github.com/andy194673) and [Alex Coca](https://github.com/alexcoca) who provided code and guidance."
    )
    gr.Markdown(
        "Both Systems are trained on the [MultiWOZ dataset](https://github.com/budzianowski/multiwoz). <br/> "
        "Supported domains are: <br> "
        "1. 🚆 Train, 2. 🏨 Hotel, 3. 🚕 Taxi, 4. 🚓 Police, 5. 🍣 Restaurant, 6. 🗿 Attraction, 7. 🏥 Hospital."
    )
    gr.Markdown(
        "**Please note:** <br> "
        "1. These systems are in development and are full of funny little bugs, as is this app. <br> "
        "2. If you refresh this page the conversation state will persist. "
        "To reset a conversation you need to click 'Reset Conversation' below."
    )
    with gr.Tabs():
        # Human plays the user; the model is the dialogue system.
        with gr.TabItem("Dialogue System"):
            gr.Markdown(
                "This bot is a Task-Oriented Dialogue System. <br> "
                "You are the user. Go ahead and try to book a train, or a hotel etc."
            )
            with gr.Row():
                ds_input_text = gr.Textbox(
                    label="User Message", placeholder="I'd like to book a train from Cambridge to London"
                )
                ds_response = gr.Chatbot(label="Dialogue System Response")
            ds_button = gr.Button("Submit Message")
            reset_ds_button = gr.Button("Reset Conversation")
        # Human plays the dialogue system; the model is the user simulator.
        with gr.TabItem("User Simulator"):
            gr.Markdown(
                "This bot is a User Simulator. <br> "
                "You are the Task-Oriented Dialogue System. Your job is to help the user with their requests. <br> "
                "If you want the User Simulator to have a different goal press 'Generate New Goal'."
            )
            with gr.Row():
                us_input_text = gr.Textbox(
                    label="Dialogue System Message", placeholder="How can I help you today?"
                )
                us_response = gr.Chatbot(label="User Simulator Response")
            us_button = gr.Button("Submit Message")
            reset_us_button = gr.Button("Reset Conversation")
            new_goal_button = gr.Button("Generate New Goal")
            current_goal_yaml = gr.Textbox(label="New Goal (YAML)")
        # Both roles are played by the models; each click runs one turn.
        with gr.TabItem("Self-Play"):
            gr.Markdown(
                "In this case both the User Simulator and the Task-Oriented Dialogue System are agents. <br> "
                "Get them to interact by pressing 'Run Next Step'. <br> "
                "If you want the User Simulator to have a different goal press 'Generate New Goal'."
            )
            self_play_response = gr.Chatbot(label="Self-Play Output")
            self_play_button = gr.Button("Run Next Step")
            reset_self_play_button = gr.Button("Reset Conversation")
            new_sp_goal_button = gr.Button("Generate New Goal")
            current_sp_goal_yaml = gr.Textbox(label="New Goal (YAML)")
    gr.Markdown("Want to get in touch? [Email me](mailto:am@alistairmcleay.com)")

    # Event wiring: handler, input component(s), output component(s).
    ds_button.click(ds_chatbot, ds_input_text, ds_response)
    us_button.click(us_chatbot, us_input_text, us_response)
    self_play_button.click(self_play, None, self_play_response)
    new_goal_button.click(change_goal, None, [current_goal_yaml, us_response])
    new_sp_goal_button.click(change_sp_goal, None, [current_sp_goal_yaml, self_play_response])
    reset_ds_button.click(reset_ds_state, None, ds_response)
    reset_us_button.click(reset_us_state, None, us_response)
    reset_self_play_button.click(reset_self_play_state, None, self_play_response)

block.launch()