alistairmcleay commited on
Commit
9d33cfe
1 Parent(s): 94fe073

Fixing paths

Browse files
app.py CHANGED
@@ -15,20 +15,20 @@ from scripts.UBAR_code.interaction.UBAR_interact import bcolors
15
 
16
 
17
  # Initialise agents
18
- UBAR_checkpoint_path = "cambridge-masters-project/epoch50_trloss0.59_gpt2"
19
- user_model_checkpoint_path = "cambridge-masters-project/MultiWOZ-full_checkpoint_step340k"
20
 
21
  sys_model = self_play_sys_model = UBAR_interact.UbarSystemModel(
22
- "UBAR_sys_model", UBAR_checkpoint_path, "cambridge-masters-project/scripts/UBAR_code/interaction/config.yaml"
23
  )
24
  user_model = self_play_user_model = multiwoz_interact.NeuralAgent(
25
- "user", user_model_checkpoint_path, "cambridge-masters-project/scripts/user_model_code/interaction/config.yaml"
26
  )
27
 
28
 
29
  # Get goals
30
  n_goals = 100
31
- goals_path = "cambridge-masters-project/data/raw/UBAR/multi-woz/data.json"
32
  print("Loading goals...")
33
  goals = multiwoz_interact.read_multiWOZ_20_goals(goals_path, n_goals)
34
 
 
15
 
16
 
17
  # Initialise agents
18
+ UBAR_checkpoint_path = "epoch50_trloss0.59_gpt2"
19
+ user_model_checkpoint_path = "MultiWOZ-full_checkpoint_step340k"
20
 
21
  sys_model = self_play_sys_model = UBAR_interact.UbarSystemModel(
22
+ "UBAR_sys_model", UBAR_checkpoint_path, "scripts/UBAR_code/interaction/config.yaml"
23
  )
24
  user_model = self_play_user_model = multiwoz_interact.NeuralAgent(
25
+ "user", user_model_checkpoint_path, "scripts/user_model_code/interaction/config.yaml"
26
  )
27
 
28
 
29
  # Get goals
30
  n_goals = 100
31
+ goals_path = "data/raw/UBAR/multi-woz/data.json"
32
  print("Loading goals...")
33
  goals = multiwoz_interact.read_multiWOZ_20_goals(goals_path, n_goals)
34
 
scripts/UBAR_code/interaction/config.yaml CHANGED
@@ -3,7 +3,7 @@ model:
3
  goal_update:
4
  finish_inform: "loose" # loose or strict
5
 
6
- schema_path: "cambridge-masters-project/scripts/user_model_code/interaction/schema.json"
7
 
8
  decode:
9
  dec_max_len: 1024
@@ -14,10 +14,10 @@ decode:
14
  use_all_previous_context: False
15
 
16
  dbs_path:
17
- "attraction": "cambridge-masters-project/data/preprocessed/UBAR/db_processed/attraction_db_processed.json"
18
- "hospital": "cambridge-masters-project/data/preprocessed/UBAR/db_processed/hospital_db_processed.json"
19
- "hotel": "cambridge-masters-project/data/preprocessed/UBAR/db_processed/hotel_db_processed.json"
20
- "police": "cambridge-masters-project/data/preprocessed/UBAR/db_processed/police_db_processed.json"
21
- "restaurant": "cambridge-masters-project/data/preprocessed/UBAR/db_processed/restaurant_db_processed.json"
22
- "taxi": "cambridge-masters-project/data/preprocessed/UBAR/db_processed/taxi_db_processed.json"
23
- "train": "cambridge-masters-project/data/preprocessed/UBAR/db_processed/train_db_processed.json"
 
3
  goal_update:
4
  finish_inform: "loose" # loose or strict
5
 
6
+ schema_path: "scripts/user_model_code/interaction/schema.json"
7
 
8
  decode:
9
  dec_max_len: 1024
 
14
  use_all_previous_context: False
15
 
16
  dbs_path:
17
+ "attraction": "data/preprocessed/UBAR/db_processed/attraction_db_processed.json"
18
+ "hospital": "data/preprocessed/UBAR/db_processed/hospital_db_processed.json"
19
+ "hotel": "data/preprocessed/UBAR/db_processed/hotel_db_processed.json"
20
+ "police": "data/preprocessed/UBAR/db_processed/police_db_processed.json"
21
+ "restaurant": "data/preprocessed/UBAR/db_processed/restaurant_db_processed.json"
22
+ "taxi": "data/preprocessed/UBAR/db_processed/taxi_db_processed.json"
23
+ "train": "data/preprocessed/UBAR/db_processed/train_db_processed.json"
scripts/user_model_code/interaction/config.yaml CHANGED
@@ -3,7 +3,7 @@ model:
3
  goal_update:
4
  finish_inform: "loose" # loose or strict
5
 
6
- schema_path: "cambridge-masters-project/scripts/user_model_code/interaction/schema.json"
7
 
8
  decode:
9
  dec_max_len: 1024
 
3
  goal_update:
4
  finish_inform: "loose" # loose or strict
5
 
6
+ schema_path: "scripts/user_model_code/interaction/schema.json"
7
 
8
  decode:
9
  dec_max_len: 1024
src/crazyneuraluser/UBAR_code/config.py CHANGED
@@ -10,29 +10,25 @@ class _Config:
10
  def _multiwoz_ubar_init(self):
11
  self.gpt_path = "distilgpt2"
12
 
13
- self.vocab_path_train = "cambridge-masters-project/data/preprocessed/UBAR/multi-woz-processed/vocab"
14
  self.vocab_path_eval = None
15
- self.data_path = "cambridge-masters-project/data/preprocessed/UBAR/multi-woz-processed/"
16
  self.data_file = "data_for_ubar.json"
17
- self.dev_list = "cambridge-masters-project/data/raw/UBAR/multi-woz/valListFile.json"
18
- self.test_list = "cambridge-masters-project/data/raw/UBAR/multi-woz/testListFile.json"
19
  self.dbs = {
20
- "attraction": "cambridge-masters-project/data/preprocessed/UBAR/db_processed/attraction_db_processed.json",
21
- "hospital": "cambridge-masters-project/data/preprocessed/UBAR/db_processed/hospital_db_processed.json",
22
- "hotel": "cambridge-masters-project/data/preprocessed/UBAR/db_processed/hotel_db_processed.json",
23
- "police": "cambridge-masters-project/data/preprocessed/UBAR/db_processed/police_db_processed.json",
24
- "restaurant": "cambridge-masters-project/data/preprocessed/UBAR/db_processed/restaurant_db_processed.json",
25
- "taxi": "cambridge-masters-project/data/preprocessed/UBAR/db_processed/taxi_db_processed.json",
26
- "train": "cambridge-masters-project/data/preprocessed/UBAR/db_processed/train_db_processed.json",
27
  }
28
- self.glove_path = "cambridge-masters-project/data/glove/glove.6B.50d.txt"
29
- self.domain_file_path = "cambridge-masters-project/data/preprocessed/UBAR/multi-woz-processed/domain_files.json"
30
- self.slot_value_set_path = (
31
- "cambridge-masters-project/data/preprocessed/UBAR/db_processed/value_set_processed.json"
32
- )
33
- self.multi_acts_path = (
34
- "cambridge-masters-project/data/preprocessed/UBAR/multi-woz-processed/multi_act_mapping_train.json"
35
- )
36
  self.exp_path = "to be generated"
37
  self.log_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
38
 
@@ -140,11 +136,11 @@ class _Config:
140
 
141
  def _init_logging_handler(self, mode):
142
  stderr_handler = logging.StreamHandler()
143
- if not os.path.exists("./log"):
144
- os.mkdir("./log")
145
  if self.save_log and self.mode == "train":
146
  file_handler = logging.FileHandler(
147
- "./log/log_{}_{}_{}_{}_sd{}.txt".format(
148
  self.log_time,
149
  mode,
150
  "-".join(self.exp_domains),
 
10
  def _multiwoz_ubar_init(self):
11
  self.gpt_path = "distilgpt2"
12
 
13
+ self.vocab_path_train = "data/preprocessed/UBAR/multi-woz-processed/vocab"
14
  self.vocab_path_eval = None
15
+ self.data_path = "data/preprocessed/UBAR/multi-woz-processed/"
16
  self.data_file = "data_for_ubar.json"
17
+ self.dev_list = "data/raw/UBAR/multi-woz/valListFile.json"
18
+ self.test_list = "data/raw/UBAR/multi-woz/testListFile.json"
19
  self.dbs = {
20
+ "attraction": "data/preprocessed/UBAR/db_processed/attraction_db_processed.json",
21
+ "hospital": "data/preprocessed/UBAR/db_processed/hospital_db_processed.json",
22
+ "hotel": "data/preprocessed/UBAR/db_processed/hotel_db_processed.json",
23
+ "police": "data/preprocessed/UBAR/db_processed/police_db_processed.json",
24
+ "restaurant": "data/preprocessed/UBAR/db_processed/restaurant_db_processed.json",
25
+ "taxi": "data/preprocessed/UBAR/db_processed/taxi_db_processed.json",
26
+ "train": "data/preprocessed/UBAR/db_processed/train_db_processed.json",
27
  }
28
+ self.glove_path = "data/glove/glove.6B.50d.txt"
29
+ self.domain_file_path = "data/preprocessed/UBAR/multi-woz-processed/domain_files.json"
30
+ self.slot_value_set_path = "data/preprocessed/UBAR/db_processed/value_set_processed.json"
31
+ self.multi_acts_path = "data/preprocessed/UBAR/multi-woz-processed/multi_act_mapping_train.json"
 
 
 
 
32
  self.exp_path = "to be generated"
33
  self.log_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
34
 
 
136
 
137
  def _init_logging_handler(self, mode):
138
  stderr_handler = logging.StreamHandler()
139
+ if not os.path.exists("log"):
140
+ os.mkdir("log")
141
  if self.save_log and self.mode == "train":
142
  file_handler = logging.FileHandler(
143
+ "log/log_{}_{}_{}_{}_sd{}.txt".format(
144
  self.log_time,
145
  mode,
146
  "-".join(self.exp_domains),