File size: 7,059 Bytes
b16a132
 
 
 
 
 
 
 
 
 
2936a70
 
 
 
 
 
b16a132
 
2936a70
 
b16a132
 
2936a70
b16a132
 
 
2936a70
b16a132
 
 
 
 
 
 
 
 
2936a70
b16a132
 
 
 
 
2936a70
 
b16a132
 
 
 
 
 
 
 
2936a70
 
 
b16a132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import sys
import traceback
import pandas as pd

# from tqdm import tqdm
from UBAR_code.interaction import UBAR_interact
from user_model_code.interaction import multiwoz_interact
from UBAR_code.interaction.UBAR_interact import bcolors


# from tqdm import tqdm
from scripts.UBAR_code.interaction import UBAR_interact
from scripts.user_model_code.interaction import multiwoz_interact
from scripts.UBAR_code.interaction.UBAR_interact import bcolors


def instantiate_agents():
    """
    Construct the two agents used in the simulated dialogues.

    The system agent is a ``UBAR_interact.UbarSystemModel`` and the user agent is a
    ``multiwoz_interact.NeuralAgent``. Each is built from a checkpoint path and a YAML
    config path, both relative to the current working directory.

    Returns:
        A ``(sys_model, user_model)`` tuple.
    """
    ubar_config_path = "cambridge-masters-project/scripts/UBAR_code/interaction/config.yaml"
    user_config_path = "cambridge-masters-project/scripts/user_model_code/interaction/config.yaml"

    # System agent first, then user agent (same construction order as before).
    system_agent = UBAR_interact.UbarSystemModel(
        "UBAR_sys_model",
        "cambridge-masters-project/epoch50_trloss0.59_gpt2",
        ubar_config_path,
    )
    user_agent = multiwoz_interact.NeuralAgent(
        "user",
        "cambridge-masters-project/MultiWOZ-full_checkpoint_step340k",
        user_config_path,
    )
    return system_agent, user_agent


def read_multiwoz_data(raw_mwoz_20_path="cambridge-masters-project/data/raw/UBAR/multi-woz/data.json"):
    """
    Read the raw MultiWOZ 2.0 data from its ``data.json`` file.

    Args:
        raw_mwoz_20_path: path to ``data.json``. Defaults to the UBAR copy under
            ``cambridge-masters-project/data/raw``, so existing no-argument calls are unchanged.

    Returns:
        The DataFrame returned by ``pd.read_json``. With pandas' default orient, the
        top-level JSON keys become the columns. ``main`` iterates those columns as dialogue IDs.
    """
    return pd.read_json(raw_mwoz_20_path)


def load_test_val_lists(
    val_list_file="cambridge-masters-project/data/raw/UBAR/multi-woz/valListFile.json",
    test_list_file="cambridge-masters-project/data/raw/UBAR/multi-woz/testListFile.json",
):
    """
    Load the dialogue IDs of the MultiWOZ validation and test splits.

    The split files are read as plain text with one dialogue ID per line.
    NOTE(review): despite the ``.json`` extension, the MultiWOZ split files are expected
    to be line-based lists such as ``SNG0073.json``. Confirm against the copies under
    data/raw/UBAR.

    Args:
        val_list_file: path to the validation split file.
        test_list_file: path to the test split file.

    Returns:
        ``(val_list, test_list)``: lists of dialogue IDs in file order, with surrounding
        whitespace stripped and blank lines dropped.
    """

    def _read_ids(path):
        # One ID per line; skip blank lines such as a trailing newline at EOF.
        with open(path, encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            return [dialogue_id for dialogue_id in stripped if dialogue_id]

    return _read_ids(val_list_file), _read_ids(test_list_file)


def _write_success_log(logging_successes_path, successful_dialogues, total_dialogues_generated):
    """Overwrite ``logging_successes_path`` with the tally "<successful> / <generated>"."""
    with open(logging_successes_path, "w") as f:
        f.write(str(successful_dialogues) + " / " + str(total_dialogues_generated))


def main(
    write_to_file=False, ground_truth_system_responses=False, train_only=True, n_dialogues="all", log_successes=False
):
    """
    Simulate MultiWOZ dialogues with the neural user model and report how many it terminated.

    Each goal loaded from MultiWOZ 2.0 ``data.json`` seeds one dialogue of at most 50 turns.
    A turn is one user utterance plus one system response. A dialogue counts as successful
    when ``user_model.is_terminated()`` becomes true. An exception raised inside a dialogue
    is printed and the run moves on to the next dialogue.

    Args:
        write_to_file: if True, truncate ``user_utterances_out_path``, write a header row, and
            append the tab-separated user utterances of every successful dialogue.
        ground_truth_system_responses: if True, take the system turns from the dialogue's
            MultiWOZ log instead of the UBAR system model. The dialogue then ends
            unsuccessfully once the log has no system turn left.
        train_only: if True, skip dialogues listed in the validation / test split files.
        n_dialogues: number of goals to load, or "all" for every dialogue in ``data.json``.
        log_successes: if True, write the running "successful / generated" tally to
            ``logging_successes_path`` every 100 dialogues and once more at the end.
    """
    sys_model, user_model = instantiate_agents()

    # TODO: move hardcoded vars into config file
    raw_mwoz_20_path = "cambridge-masters-project/data/raw/UBAR/multi-woz/data.json"
    user_utterances_out_path = "cambridge-masters-project/data/preprocessed/UBAR/user_utterances_from_simulator.txt"
    logging_successes_path = "cambridge-masters-project/data/preprocessed/UBAR/logging_successes"
    sys_model.print_intermediary_info = False
    user_model.print_intermediary_info = False

    print("Loading data...")
    # Read data.json once. Its columns are iterated below as the dialogue IDs.
    df_mwoz_data = pd.read_json(raw_mwoz_20_path)
    if n_dialogues == "all":
        n_dialogues = len(df_mwoz_data.columns)

    print("Loading goals...")
    goals = multiwoz_interact.read_multiWOZ_20_goals(raw_mwoz_20_path, n_dialogues)

    # Write column headers.
    # NOTE(review): the last column actually holds *user* utterances despite the
    # "System Response" header. It is left unchanged in case downstream readers rely on it.
    if write_to_file:
        with open(user_utterances_out_path, "w") as f:
            f.write("Dialogue #\tDialogue ID\tTurn #\tSystem Response\n")

    # Held-out IDs are only needed when filtering. A set gives O(1) membership tests in the loop.
    held_out_dialogues = set()
    if train_only:
        val_list, test_list = load_test_val_lists()
        held_out_dialogues = set(val_list) | set(test_list)

    successful_dialogues = 0
    total_dialogues_generated = 0  # dialogues actually simulated (train only when train_only)
    for dialogue_idx, (goal, dialogue_filename) in enumerate(zip(goals, df_mwoz_data.columns)):
        # Log successful_dialogues to logging_successes_path every 100 dialogues.
        if log_successes and dialogue_idx % 100 == 0:
            _write_success_log(logging_successes_path, successful_dialogues, total_dialogues_generated)

        if train_only and dialogue_filename in held_out_dialogues:
            continue

        total_dialogues_generated += 1
        print("Dialogue: {}".format(dialogue_filename))
        curr_dialogue_user_utterances_formatted = []

        # There are occasionally exceptions thrown from one of the agents, usually the user.
        # In this case we simply continue to the next dialogue.
        try:
            # Reset state after each dialogue
            sys_model.init_session()
            user_model.init_session(ini_goal=goal)
            sys_response = ""

            for turn_idx in range(50):
                # Turn idx in this case represents the turn as one user utterance AND one system response
                usr_response_raw_data_idx = turn_idx * 2
                sys_response_raw_data_idx = turn_idx * 2 + 1

                user_utterance = user_model.response(sys_response)
                print(bcolors.OKBLUE + "User: " + bcolors.ENDC + user_utterance)

                if write_to_file:
                    # Flatten newlines in a copy only, so the system model still receives the
                    # utterance exactly as the user model produced it.
                    user_utterance_for_file = user_utterance.replace("\n", " ")
                    row = [str(dialogue_idx), dialogue_filename, str(usr_response_raw_data_idx), user_utterance_for_file]
                    curr_dialogue_user_utterances_formatted.append("\t".join(row) + "\n")

                if user_model.is_terminated():
                    successful_dialogues += 1
                    print(bcolors.OKCYAN + "Dialogue terminated successfully!" + bcolors.ENDC)
                    print(bcolors.OKCYAN + "---" * 30 + bcolors.ENDC + "\n")
                    if write_to_file:
                        # Write whole dialogue to file
                        with open(user_utterances_out_path, "a") as f:
                            f.writelines(curr_dialogue_user_utterances_formatted)
                    break

                # Next turn materials
                if ground_truth_system_responses:
                    dialogue_log = df_mwoz_data.iloc[:, dialogue_idx]["log"]
                    # If we are at the end of the ground truth dialogues
                    if len(dialogue_log) <= sys_response_raw_data_idx:
                        print(bcolors.RED + "Dialogue terminated unsuccessfully!" + bcolors.ENDC)
                        print(bcolors.RED + "---" * 30 + bcolors.ENDC + "\n")
                        break
                    sys_response = dialogue_log[sys_response_raw_data_idx]["text"]
                else:
                    sys_response = sys_model.response(user_utterance, turn_idx)

                # Capitalise for display only, in both branches. The user model gets the raw text.
                # Slicing (not indexing) keeps an empty response from raising IndexError.
                capitalised_sys_response = sys_response[:1].upper() + sys_response[1:]
                print(bcolors.GREEN + "System: " + bcolors.ENDC + capitalised_sys_response)

        except Exception:
            print(bcolors.RED + "*" * 30 + bcolors.ENDC)
            print(bcolors.RED + "Error in dialogue {}".format(dialogue_filename) + bcolors.ENDC)
            print(bcolors.RED + "*" * 30 + bcolors.ENDC)
            traceback.print_exc()

    if log_successes:
        # Record the final tally; the periodic log above only runs at multiples of 100.
        _write_success_log(logging_successes_path, successful_dialogues, total_dialogues_generated)

    # Report against the dialogues actually simulated. n_dialogues also counts held-out
    # dialogues skipped by train_only.
    print("Successful dialogues: {}".format(successful_dialogues))
    print("Total dialogues: {}".format(total_dialogues_generated))
    if total_dialogues_generated:
        print("% Successful Dialogues: {}".format(successful_dialogues / total_dialogues_generated))
    else:
        print("% Successful Dialogues: n/a (no dialogues generated)")


if __name__ == "__main__":
    # TODO: move parameters to config file
    # Fix the hacky mess below
    ground_truth_system_responses = sys.argv[1]
    if ground_truth_system_responses == "False":
        ground_truth_system_responses = False
    else:
        ground_truth_system_responses = True
    main(write_to_file=False, ground_truth_system_responses=ground_truth_system_responses)