File size: 7,804 Bytes
2936a70
6aeedda
b16a132
 
 
2936a70
 
 
 
6aeedda
b16a132
 
 
 
6aeedda
2936a70
9d33cfe
 
b16a132
2936a70
9d33cfe
2936a70
 
9d33cfe
2936a70
 
 
 
 
9d33cfe
2936a70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db90364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2db6fc4
db90364
 
 
2936a70
 
7771dff
2936a70
 
db90364
2936a70
db90364
2936a70
 
 
 
7771dff
2936a70
 
db90364
2936a70
db90364
2936a70
 
 
 
 
057d34c
2936a70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db90364
2936a70
 
db90364
057d34c
2936a70
 
 
 
 
 
 
 
 
d75062c
 
 
 
 
 
2936a70
 
 
 
057d34c
db90364
 
 
 
 
 
05c3251
 
 
db90364
8a1bda8
 
 
 
 
2936a70
 
 
384c151
 
2936a70
 
 
 
 
 
 
db90364
2936a70
 
 
384c151
 
 
2936a70
 
 
 
 
 
 
db90364
 
 
2936a70
 
 
384c151
 
 
2936a70
db90364
2936a70
db90364
 
 
 
b6db322
 
2936a70
 
 
db90364
 
 
 
 
2936a70
8a1bda8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import random
import gradio as gr
import sys
import traceback
import pandas as pd
import gradio as gr
import json

import yaml

# from tqdm import tqdm
from scripts.UBAR_code.interaction import UBAR_interact
from scripts.user_model_code.interaction import multiwoz_interact
from scripts.UBAR_code.interaction.UBAR_interact import bcolors

# Initialise agents.
# Checkpoint identifiers handed to the two model constructors below.
UBAR_checkpoint_path = "epoch50_trloss0.59_gpt2"
user_model_checkpoint_path = "MultiWOZ-full_checkpoint_step340k"

# NOTE(review): the chained assignment binds BOTH names to the SAME object, so
# the "Dialogue System" tab and the "Self-Play" tab share one system-model
# instance. Presumably done to avoid loading the checkpoint twice -- confirm
# that sharing session state between the tabs is acceptable.
sys_model = self_play_sys_model = UBAR_interact.UbarSystemModel(
    "UBAR_sys_model", UBAR_checkpoint_path, "scripts/UBAR_code/interaction/config.yaml"
)
# NOTE(review): same aliasing here -- `user_model` and `self_play_user_model`
# are one and the same NeuralAgent object.
user_model = self_play_user_model = multiwoz_interact.NeuralAgent(
    "user", user_model_checkpoint_path, "scripts/user_model_code/interaction/config.yaml"
)


# Get goals: load `n_goals` goals from the MultiWOZ data file.
n_goals = 100
goals_path = "data/raw/UBAR/multi-woz/data.json"
print("Loading goals...")
goals = multiwoz_interact.read_multiWOZ_20_goals(goals_path, n_goals)

# Initialise the user simulator with a randomly chosen goal for the
# "User Simulator" tab (the user can later swap it via change_goal()).
curr_goal_idx = random.randint(0, n_goals - 1)
current_goal = goals[curr_goal_idx]
user_model.init_session(ini_goal=current_goal)

# Do the same initialisation but for the self-play tab.
# NOTE(review): because the two user-model names alias one object, this second
# init_session() call is made on the same agent as the call above -- presumably
# it replaces the goal set there; verify against NeuralAgent.init_session.
curr_sp_goal_idx = random.randint(0, n_goals - 1)
current_sp_goal = goals[curr_sp_goal_idx]
self_play_user_model.init_session(ini_goal=current_sp_goal)

# Conversation histories, one per tab. Each is a list of 2-tuples appended to
# by the chatbot callbacks and emptied in place by the reset_* functions.
ds_history = []
us_history = []
self_play_history = []


def reset_ds_state():
    """Start a fresh conversation on the Dialogue System tab.

    Empties the module-level ``ds_history`` list in place and re-initialises
    the system model's session.

    Returns:
        The same (now empty) ``ds_history`` list object.
    """
    # Wipe in place so every holder of the list sees the empty history.
    del ds_history[:]
    sys_model.init_session()
    return ds_history


def reset_us_state():
    """Start a fresh conversation on the User Simulator tab.

    Empties the module-level ``us_history`` list in place and re-initialises
    the user model's session with the currently selected ``current_goal``.

    Returns:
        The same (now empty) ``us_history`` list object.
    """
    # Wipe in place so every holder of the list sees the empty history.
    del us_history[:]
    user_model.init_session(ini_goal=current_goal)
    return us_history


def reset_self_play_state():
    """Start a fresh conversation on the Self-Play tab.

    Empties the module-level ``self_play_history`` list in place, then
    re-initialises the system model's session followed by the user model's
    session (the latter with the currently selected ``current_sp_goal``).

    Returns:
        The same (now empty) ``self_play_history`` list object.
    """
    # Wipe in place so every holder of the list sees the empty history.
    del self_play_history[:]
    self_play_sys_model.init_session()
    self_play_user_model.init_session(ini_goal=current_sp_goal)
    return self_play_history


def change_goal():
    """Pick a new random goal for the User Simulator tab and restart its dialogue.

    Rebinds the module-level ``curr_goal_idx`` / ``current_goal`` and then
    resets the User Simulator state so the session is initialised with the
    newly chosen goal.

    Returns:
        A 2-tuple ``(goal_yaml, history)``: the new goal serialised as
        block-style YAML, and the (emptied) user-simulator history list.
    """
    global curr_goal_idx, current_goal

    curr_goal_idx = random.randint(0, n_goals - 1)
    current_goal = goals[curr_goal_idx]

    # Must run after the rebinding above: the reset reads `current_goal`.
    cleared_history = reset_us_state()
    goal_as_yaml = yaml.dump(current_goal, default_flow_style=False)
    return goal_as_yaml, cleared_history


def change_sp_goal():
    """Pick a new random goal for the Self-Play tab and restart its dialogue.

    Rebinds the module-level ``curr_sp_goal_idx`` / ``current_sp_goal`` and
    then resets the self-play state so the session is initialised with the
    newly chosen goal.

    Returns:
        A 2-tuple ``(goal_yaml, history)``: the new goal serialised as
        block-style YAML, and the (emptied) self-play history list.
    """
    global curr_sp_goal_idx, current_sp_goal

    curr_sp_goal_idx = random.randint(0, n_goals - 1)
    current_sp_goal = goals[curr_sp_goal_idx]

    # Must run after the rebinding above: the reset reads `current_sp_goal`.
    cleared_history = reset_self_play_state()
    goal_as_yaml = yaml.dump(current_sp_goal, default_flow_style=False)
    return goal_as_yaml, cleared_history


def ds_chatbot(user_utt):
    """Run one turn of the Dialogue System tab.

    Sends the user's utterance to the system model, capitalises the first
    character of the reply, and records the exchange.

    Args:
        user_utt: The message typed by the human user.

    Returns:
        The module-level ``ds_history`` list of ``(user_utt, sys_response)``
        tuples, including the turn just appended.
    """
    # The turn id is the number of exchanges completed so far.
    turn_id = len(ds_history)
    sys_response = sys_model.response(user_utt, turn_id)
    # Guard the capitalisation: indexing [0] on an empty reply would raise
    # IndexError and break the UI callback.
    if sys_response:
        sys_response = sys_response[0].upper() + sys_response[1:]
    ds_history.append((user_utt, sys_response))
    return ds_history


def us_chatbot(sys_response):
    """Run one turn of the User Simulator tab.

    Sends the human-written system message to the user simulator and records
    the exchange. When the simulator reports that the dialogue has terminated,
    a new goal is selected and the simulator state is reset for the next
    conversation.

    Args:
        sys_response: The message typed by the human acting as the system.

    Returns:
        A list of ``(sys_response, user_utt)`` tuples for display. While the
        dialogue is ongoing this is the module-level ``us_history`` itself; on
        the terminating turn it is a copy taken before the reset, so the final
        exchange is still shown.
    """
    user_utt = user_model.response(sys_response)
    us_history.append((sys_response, user_utt))

    if not user_model.is_terminated():
        return us_history

    # change_goal() empties us_history in place. Returning that list would
    # blank the chat window before the simulator's last utterance was ever
    # rendered, so hand back a snapshot of the finished dialogue instead.
    finished_dialogue = list(us_history)
    change_goal()
    return finished_dialogue


def self_play():
    """Advance the Self-Play dialogue by one user/system exchange.

    The user simulator replies to the system's previous response (an empty
    string on the first turn), the system model answers that utterance, and
    the pair is recorded. When the user simulator reports that the dialogue
    has terminated, a new self-play goal is selected and the self-play state
    is reset for the next conversation.

    Returns:
        A list of ``(user_utt, sys_response)`` tuples for display. While the
        dialogue is ongoing this is the module-level ``self_play_history``
        itself; on the terminating turn it is a copy taken before the reset,
        so the completed dialogue is still shown.
    """
    # Seed the simulator with the last system reply, or "" to open a dialogue.
    sys_response = self_play_history[-1][1] if self_play_history else ""

    user_utt = self_play_user_model.response(sys_response)

    turn_id = len(self_play_history)
    sys_response = self_play_sys_model.response(user_utt, turn_id)
    # Guard the capitalisation: indexing [0] on an empty reply would raise
    # IndexError and break the UI callback.
    if sys_response:
        sys_response = sys_response[0].upper() + sys_response[1:]

    self_play_history.append((user_utt, sys_response))

    if not self_play_user_model.is_terminated():
        return self_play_history

    # Reset THIS tab's goal and state. The previous code called change_goal(),
    # which wiped the User Simulator tab's history and never cleared the
    # self-play one. change_sp_goal() empties self_play_history in place, so
    # snapshot it first to keep showing the finished dialogue.
    finished_dialogue = list(self_play_history)
    change_sp_goal()
    return finished_dialogue


# Put all three conversations into a clean state once, at start-up.
# NOTE: these are plain module-level calls, so they run a single time when the
# script is executed -- they are NOT re-run when a browser tab is refreshed
# (the UI text below tells users that state persists across refreshes).
reset_ds_state()
reset_us_state()
reset_self_play_state()


# Build the Gradio UI: a header, three tabs (Dialogue System, User Simulator,
# Self-Play) and the button -> callback wiring, then launch the app.
# Fixes relative to the previous text: emoji that had been mis-decoded into
# mojibake are restored, and two user-facing typos are corrected.
# NOTE(review): `gr.inputs.*` / `gr.outputs.*` are kept as-is; they are the
# legacy component namespaces -- confirm against the pinned Gradio version
# before upgrading.
block = gr.Blocks()

with block:
    # --- Page header and general notes -------------------------------------
    gr.Markdown("# 💬 Jointly Optimized Task-Oriented Dialogue System And User Simulator 💬")
    gr.Markdown(
        "Created by [Alistair McLeay](https://alistairmcleay.com) for the [Masters in Machine Learning & Machine Intelligence at Cambridge University](https://www.mlmi.eng.cam.ac.uk/). <br/>\
        Thank you to [Professor Bill Byrne](https://sites.google.com/view/bill-byrne/home) for his supervision and guidance. <br/> \
        Thank you to [Andy Tseng](https://github.com/andy194673) and [Alex Coca](https://github.com/alexcoca) who provided code and guidance."
    )
    gr.Markdown(
        "Both Systems are trained on the [MultiWOZ dataset](https://github.com/budzianowski/multiwoz). <br/> \
        Supported domains are: <br> \
        1. 🚆 Train, 2. 🏨 Hotel, 3. 🚕 Taxi, 4. 🚓 Police, 5. 🏣 Restaurant, 6. 🗿 Attraction, 7. 🏥 Hospital."
    )
    gr.Markdown(
        "**Please note:** <br> \
        1. These systems are in development and are full of funny little bugs, as is this app. <br> \
        2. If you refresh this page the conversation state will persist. To reset a conversation you need to click 'Reset Conversation' below."
    )
    with gr.Tabs():
        # --- Tab 1: human user talks to the dialogue system ----------------
        with gr.TabItem("Dialogue System"):
            gr.Markdown(
                "This bot is a Task-Oriented Dialogue System. <br> \
                You are the user. Go ahead and try to book a train, or a hotel etc."
            )
            with gr.Row():
                ds_input_text = gr.inputs.Textbox(
                    label="User Message", placeholder="I'd like to book a train from Cambridge to London"
                )
                ds_response = gr.outputs.Chatbot(label="Dialogue System Response")
            ds_button = gr.Button("Submit Message")
            reset_ds_button = gr.Button("Reset Conversation")

        # --- Tab 2: human plays the system, the bot simulates the user -----
        with gr.TabItem("User Simulator"):
            gr.Markdown(
                "This bot is a User Simulator. <br> \
                You are the Task-Oriented Dialogue System. Your job is to help the user with their requests. <br> \
                If you want the User Simulator to have a different goal press 'Generate New Goal'."
            )
            with gr.Row():
                us_input_text = gr.inputs.Textbox(
                    label="Dialogue System Message", placeholder="How can I help you today?"
                )
                us_response = gr.outputs.Chatbot(label="User Simulator Response")
            us_button = gr.Button("Submit Message")
            reset_us_button = gr.Button("Reset Conversation")
            new_goal_button = gr.Button("Generate New Goal")
            current_goal_yaml = gr.outputs.Textbox(label="New Goal (YAML)")

        # --- Tab 3: both agents talk to each other -------------------------
        with gr.TabItem("Self-Play"):
            gr.Markdown(
                "In this case both the User Simulator and the Task-Oriented Dialogue System are agents. <br> \
                Get them to interact by pressing 'Run Next Step'. <br> \
                If you want the User Simulator to have a different goal press 'Generate New Goal'."
            )
            self_play_response = gr.outputs.Chatbot(label="Self-Play Output")
            self_play_button = gr.Button("Run Next Step")
            reset_self_play_button = gr.Button("Reset Conversation")
            new_sp_goal_button = gr.Button("Generate New Goal")
            current_sp_goal_yaml = gr.outputs.Textbox(label="New Goal (YAML)")

    gr.Markdown("Want to get in touch? [Email me](mailto:am@alistairmcleay.com)")

    # --- Event wiring: click(callback, inputs, outputs) --------------------
    # Chat callbacks return the history list rendered by the tab's Chatbot.
    ds_button.click(ds_chatbot, ds_input_text, ds_response)
    us_button.click(us_chatbot, us_input_text, us_response)
    self_play_button.click(self_play, None, self_play_response)
    # Goal changes return (goal YAML, emptied history) -> two output components.
    new_goal_button.click(change_goal, None, [current_goal_yaml, us_response])
    new_sp_goal_button.click(change_sp_goal, None, [current_sp_goal_yaml, self_play_response])
    # Resets return the emptied history, which clears the Chatbot display.
    reset_ds_button.click(reset_ds_state, None, ds_response)
    reset_us_button.click(reset_us_state, None, us_response)
    reset_self_play_button.click(reset_self_play_state, None, self_play_response)

block.launch()