import random
import gradio as gr
import yaml
from scripts.UBAR_code.interaction import UBAR_interact
from scripts.user_model_code.interaction import multiwoz_interact
# Initialise agents
UBAR_checkpoint_path = "epoch50_trloss0.59_gpt2"
user_model_checkpoint_path = "MultiWOZ-full_checkpoint_step340k"
sys_model = self_play_sys_model = UBAR_interact.UbarSystemModel(
"UBAR_sys_model", UBAR_checkpoint_path, "scripts/UBAR_code/interaction/config.yaml"
)
user_model = self_play_user_model = multiwoz_interact.NeuralAgent(
"user", user_model_checkpoint_path, "scripts/user_model_code/interaction/config.yaml"
)
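# Note: the Dialogue System and Self-Play tabs share the same underlying model instances,
# so re-initialising a session in one tab also resets the model state used by the other.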
# Get goals
n_goals = 100
goals_path = "data/raw/UBAR/multi-woz/data.json"
print("Loading goals...")
goals = multiwoz_interact.read_multiWOZ_20_goals(goals_path, n_goals)
# Initialise the user simulator tab with a randomly chosen goal (the user can generate a new one)
curr_goal_idx = random.randint(0, n_goals - 1)
current_goal = goals[curr_goal_idx]
user_model.init_session(ini_goal=current_goal)
# Do the same initialisation but for the self-play tab
curr_sp_goal_idx = random.randint(0, n_goals - 1)
current_sp_goal = goals[curr_sp_goal_idx]
self_play_user_model.init_session(ini_goal=current_sp_goal)
# Get the responses for each agent and track conversation history
ds_history = []
us_history = []
self_play_history = []
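# Reset helpers: clear a tab's chat history and re-initialise the corresponding agent session(s)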
def reset_ds_state():
ds_history.clear()
sys_model.init_session()
return ds_history
def reset_us_state():
us_history.clear()
user_model.init_session(ini_goal=current_goal)
return us_history
def reset_self_play_state():
self_play_history.clear()
self_play_sys_model.init_session()
    self_play_user_model.init_session(ini_goal=current_sp_goal)
return self_play_history
def change_goal():
    # Pick a new random goal for the user simulator tab and reset its conversation
    global curr_goal_idx, current_goal
    curr_goal_idx = random.randint(0, n_goals - 1)
    current_goal = goals[curr_goal_idx]
    us_history = reset_us_state()
    current_goal_yaml = yaml.dump(current_goal, default_flow_style=False)
    return current_goal_yaml, us_history
def change_sp_goal():
    # Pick a new random goal for the self-play tab and reset its conversation
    global curr_sp_goal_idx, current_sp_goal
    curr_sp_goal_idx = random.randint(0, n_goals - 1)
    current_sp_goal = goals[curr_sp_goal_idx]
    self_play_history = reset_self_play_state()
    current_sp_goal_yaml = yaml.dump(current_sp_goal, default_flow_style=False)
    return current_sp_goal_yaml, self_play_history
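# Chat callbacks: each takes the latest message, queries the corresponding agent,
# and returns the updated conversation history for the Chatbot component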
def ds_chatbot(user_utt):
turn_id = len(ds_history)
sys_response = sys_model.response(user_utt, turn_id)
ds_history.append((user_utt, sys_response))
return ds_history
def us_chatbot(sys_response):
user_utt = user_model.response(sys_response)
us_history.append((sys_response, user_utt))
if user_model.is_terminated():
change_goal()
return us_history
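# One self-play step: feed the previous system response to the user simulator,
# then feed the simulated user utterance back to the dialogue system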
def self_play():
if len(self_play_history) == 0:
sys_response = ""
else:
sys_response = self_play_history[-1][1]
user_utt = self_play_user_model.response(sys_response)
turn_id = len(self_play_history)
sys_response = self_play_sys_model.response(user_utt, turn_id)
self_play_history.append((user_utt, sys_response))
    if self_play_user_model.is_terminated():
        change_sp_goal()
return self_play_history
# Initialise demo render
block = gr.Blocks()
with block:
gr.Markdown("#πŸ’¬ Jointly Optimized Task-Oriented Dialogue System And User Simulator")
gr.Markdown(
"Created by [Alistair McLeay](https://alistairmcleay.com) for the [Masters in Machine Learning & Machine Intelligence at Cambridge University](https://www.mlmi.eng.cam.ac.uk/). <br/>\
Thank you to [Professor Bill Byrne](https://sites.google.com/view/bill-byrne/home) for his supervision and guidance. <br/> \
Thank you to [Andy Tseng](https://github.com/andy194673) and [Alex Coca](https://github.com/alexcoca) who provided code and guidance."
)
gr.Markdown(
"*Both Systems are trained on the [MultiWOZ dataset](https://github.com/budzianowski/multiwoz). <br/> \
Supported domains are 1. πŸš† Train, 2. 🏨 Hotel, 3. πŸš• Taxi, 4. πŸš“ Police, 5. 🏣 Restaurant, 6. πŸ—Ώ Attraction, 7. πŸ₯ Hospital.*"
)
with gr.Tabs():
with gr.TabItem("Dialogue System"):
gr.Markdown(
"This bot is a Task-Oriented Dialogue Systen. \nYou are the user. Go ahead and try to book a train, or a hotel etc."
)
with gr.Row():
ds_input_text = gr.inputs.Textbox(
label="User Message", placeholder="I'd like to book a train from Cambridge to London"
)
ds_response = gr.outputs.Chatbot(label="Dialogue System Response")
ds_button = gr.Button("Submit Message")
reset_ds_button = gr.Button("Reset Conversation")
with gr.TabItem("User Simulator"):
gr.Markdown(
"This bot is a User Simulator. \nYou are the Task-Oriented Dialogue System. Your job is to help the user with their requests. \nIf you want the User Simulator to have a different goal press 'Generate New Goal'"
)
with gr.Row():
us_input_text = gr.inputs.Textbox(
label="Dialogue System Message", placeholder="How can I help you today?"
)
us_response = gr.outputs.Chatbot(label="User Simulator Response")
us_button = gr.Button("Submit Message")
reset_us_button = gr.Button("Reset Conversation")
new_goal_button = gr.Button("Generate New Goal")
current_goal_yaml = gr.outputs.Textbox(label="New Goal (YAML)")
with gr.TabItem("Self-Play"):
gr.Markdown(
"In this case both the User Simulator and the Task-Oriented Dialogue System are agents. \nGet them to interact by pressing 'Run Next Step' \nIf you want the User Simulator to have a different goal press 'Generate New Goal'"
)
self_play_response = gr.outputs.Chatbot(label="Self-Play Output")
self_play_button = gr.Button("Run Next Step")
reset_self_play_button = gr.Button("Reset Conversation")
new_sp_goal_button = gr.Button("Generate New Goal")
current_sp_goal_yaml = gr.outputs.Textbox(label="New Goal (YAML)")
gr.Markdown("## System Architecture Overview")
gr.Markdown(
"![System Architecture](https://huggingface.co/spaces/alistairmcleay/cambridge-masters-project/tree/main/system_architecture.png)"
)
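    # Wire each button to its callback: (function, input component(s), output component(s))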
ds_button.click(ds_chatbot, ds_input_text, ds_response)
us_button.click(us_chatbot, us_input_text, us_response)
self_play_button.click(self_play, None, self_play_response)
new_goal_button.click(change_goal, None, [current_goal_yaml, us_response])
new_sp_goal_button.click(change_sp_goal, None, [current_sp_goal_yaml, self_play_response])
reset_ds_button.click(reset_ds_state, None, ds_response)
reset_us_button.click(reset_us_state, None, us_response)
reset_self_play_button.click(reset_self_play_state, None, self_play_response)
block.launch(share=True)