Spaces:
Runtime error
Runtime error
alistairmcleay
commited on
Commit
•
9d33cfe
1
Parent(s):
94fe073
Fixing paths
Browse files
app.py
CHANGED
@@ -15,20 +15,20 @@ from scripts.UBAR_code.interaction.UBAR_interact import bcolors
|
|
15 |
|
16 |
|
17 |
# Initialise agents
|
18 |
-
UBAR_checkpoint_path = "
|
19 |
-
user_model_checkpoint_path = "
|
20 |
|
21 |
sys_model = self_play_sys_model = UBAR_interact.UbarSystemModel(
|
22 |
-
"UBAR_sys_model", UBAR_checkpoint_path, "
|
23 |
)
|
24 |
user_model = self_play_user_model = multiwoz_interact.NeuralAgent(
|
25 |
-
"user", user_model_checkpoint_path, "
|
26 |
)
|
27 |
|
28 |
|
29 |
# Get goals
|
30 |
n_goals = 100
|
31 |
-
goals_path = "
|
32 |
print("Loading goals...")
|
33 |
goals = multiwoz_interact.read_multiWOZ_20_goals(goals_path, n_goals)
|
34 |
|
|
|
15 |
|
16 |
|
17 |
# Initialise agents
|
18 |
+
UBAR_checkpoint_path = "epoch50_trloss0.59_gpt2"
|
19 |
+
user_model_checkpoint_path = "MultiWOZ-full_checkpoint_step340k"
|
20 |
|
21 |
sys_model = self_play_sys_model = UBAR_interact.UbarSystemModel(
|
22 |
+
"UBAR_sys_model", UBAR_checkpoint_path, "scripts/UBAR_code/interaction/config.yaml"
|
23 |
)
|
24 |
user_model = self_play_user_model = multiwoz_interact.NeuralAgent(
|
25 |
+
"user", user_model_checkpoint_path, "scripts/user_model_code/interaction/config.yaml"
|
26 |
)
|
27 |
|
28 |
|
29 |
# Get goals
|
30 |
n_goals = 100
|
31 |
+
goals_path = "data/raw/UBAR/multi-woz/data.json"
|
32 |
print("Loading goals...")
|
33 |
goals = multiwoz_interact.read_multiWOZ_20_goals(goals_path, n_goals)
|
34 |
|
scripts/UBAR_code/interaction/config.yaml
CHANGED
@@ -3,7 +3,7 @@ model:
|
|
3 |
goal_update:
|
4 |
finish_inform: "loose" # loose or strict
|
5 |
|
6 |
-
schema_path: "
|
7 |
|
8 |
decode:
|
9 |
dec_max_len: 1024
|
@@ -14,10 +14,10 @@ decode:
|
|
14 |
use_all_previous_context: False
|
15 |
|
16 |
dbs_path:
|
17 |
-
"attraction": "
|
18 |
-
"hospital": "
|
19 |
-
"hotel": "
|
20 |
-
"police": "
|
21 |
-
"restaurant": "
|
22 |
-
"taxi": "
|
23 |
-
"train": "
|
|
|
3 |
goal_update:
|
4 |
finish_inform: "loose" # loose or strict
|
5 |
|
6 |
+
schema_path: "scripts/user_model_code/interaction/schema.json"
|
7 |
|
8 |
decode:
|
9 |
dec_max_len: 1024
|
|
|
14 |
use_all_previous_context: False
|
15 |
|
16 |
dbs_path:
|
17 |
+
"attraction": "data/preprocessed/UBAR/db_processed/attraction_db_processed.json"
|
18 |
+
"hospital": "data/preprocessed/UBAR/db_processed/hospital_db_processed.json"
|
19 |
+
"hotel": "data/preprocessed/UBAR/db_processed/hotel_db_processed.json"
|
20 |
+
"police": "data/preprocessed/UBAR/db_processed/police_db_processed.json"
|
21 |
+
"restaurant": "data/preprocessed/UBAR/db_processed/restaurant_db_processed.json"
|
22 |
+
"taxi": "data/preprocessed/UBAR/db_processed/taxi_db_processed.json"
|
23 |
+
"train": "data/preprocessed/UBAR/db_processed/train_db_processed.json"
|
scripts/user_model_code/interaction/config.yaml
CHANGED
@@ -3,7 +3,7 @@ model:
|
|
3 |
goal_update:
|
4 |
finish_inform: "loose" # loose or strict
|
5 |
|
6 |
-
schema_path: "
|
7 |
|
8 |
decode:
|
9 |
dec_max_len: 1024
|
|
|
3 |
goal_update:
|
4 |
finish_inform: "loose" # loose or strict
|
5 |
|
6 |
+
schema_path: "scripts/user_model_code/interaction/schema.json"
|
7 |
|
8 |
decode:
|
9 |
dec_max_len: 1024
|
src/crazyneuraluser/UBAR_code/config.py
CHANGED
@@ -10,29 +10,25 @@ class _Config:
|
|
10 |
def _multiwoz_ubar_init(self):
|
11 |
self.gpt_path = "distilgpt2"
|
12 |
|
13 |
-
self.vocab_path_train = "
|
14 |
self.vocab_path_eval = None
|
15 |
-
self.data_path = "
|
16 |
self.data_file = "data_for_ubar.json"
|
17 |
-
self.dev_list = "
|
18 |
-
self.test_list = "
|
19 |
self.dbs = {
|
20 |
-
"attraction": "
|
21 |
-
"hospital": "
|
22 |
-
"hotel": "
|
23 |
-
"police": "
|
24 |
-
"restaurant": "
|
25 |
-
"taxi": "
|
26 |
-
"train": "
|
27 |
}
|
28 |
-
self.glove_path = "
|
29 |
-
self.domain_file_path = "
|
30 |
-
self.slot_value_set_path =
|
31 |
-
|
32 |
-
)
|
33 |
-
self.multi_acts_path = (
|
34 |
-
"cambridge-masters-project/data/preprocessed/UBAR/multi-woz-processed/multi_act_mapping_train.json"
|
35 |
-
)
|
36 |
self.exp_path = "to be generated"
|
37 |
self.log_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
|
38 |
|
@@ -140,11 +136,11 @@ class _Config:
|
|
140 |
|
141 |
def _init_logging_handler(self, mode):
|
142 |
stderr_handler = logging.StreamHandler()
|
143 |
-
if not os.path.exists("
|
144 |
-
os.mkdir("
|
145 |
if self.save_log and self.mode == "train":
|
146 |
file_handler = logging.FileHandler(
|
147 |
-
"
|
148 |
self.log_time,
|
149 |
mode,
|
150 |
"-".join(self.exp_domains),
|
|
|
10 |
def _multiwoz_ubar_init(self):
|
11 |
self.gpt_path = "distilgpt2"
|
12 |
|
13 |
+
self.vocab_path_train = "data/preprocessed/UBAR/multi-woz-processed/vocab"
|
14 |
self.vocab_path_eval = None
|
15 |
+
self.data_path = "data/preprocessed/UBAR/multi-woz-processed/"
|
16 |
self.data_file = "data_for_ubar.json"
|
17 |
+
self.dev_list = "data/raw/UBAR/multi-woz/valListFile.json"
|
18 |
+
self.test_list = "data/raw/UBAR/multi-woz/testListFile.json"
|
19 |
self.dbs = {
|
20 |
+
"attraction": "data/preprocessed/UBAR/db_processed/attraction_db_processed.json",
|
21 |
+
"hospital": "data/preprocessed/UBAR/db_processed/hospital_db_processed.json",
|
22 |
+
"hotel": "data/preprocessed/UBAR/db_processed/hotel_db_processed.json",
|
23 |
+
"police": "data/preprocessed/UBAR/db_processed/police_db_processed.json",
|
24 |
+
"restaurant": "data/preprocessed/UBAR/db_processed/restaurant_db_processed.json",
|
25 |
+
"taxi": "data/preprocessed/UBAR/db_processed/taxi_db_processed.json",
|
26 |
+
"train": "data/preprocessed/UBAR/db_processed/train_db_processed.json",
|
27 |
}
|
28 |
+
self.glove_path = "data/glove/glove.6B.50d.txt"
|
29 |
+
self.domain_file_path = "data/preprocessed/UBAR/multi-woz-processed/domain_files.json"
|
30 |
+
self.slot_value_set_path = "data/preprocessed/UBAR/db_processed/value_set_processed.json"
|
31 |
+
self.multi_acts_path = "data/preprocessed/UBAR/multi-woz-processed/multi_act_mapping_train.json"
|
|
|
|
|
|
|
|
|
32 |
self.exp_path = "to be generated"
|
33 |
self.log_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
|
34 |
|
|
|
136 |
|
137 |
def _init_logging_handler(self, mode):
|
138 |
stderr_handler = logging.StreamHandler()
|
139 |
+
if not os.path.exists("log"):
|
140 |
+
os.mkdir("log")
|
141 |
if self.save_log and self.mode == "train":
|
142 |
file_handler = logging.FileHandler(
|
143 |
+
"log/log_{}_{}_{}_{}_sd{}.txt".format(
|
144 |
self.log_time,
|
145 |
mode,
|
146 |
"-".join(self.exp_domains),
|