ChrisGeishauser commited on
Commit
2285042
1 Parent(s): 9b9ce11

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. config_saved.json +1 -0
  3. supervised.pol.mdl +3 -0
  4. train_INFO.log +351 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ supervised.pol.mdl filter=lfs diff=lfs merge=lfs -text
config_saved.json ADDED
@@ -0,0 +1 @@
 
1
+ {"args": {"seed": 1, "eval_freq": 2, "dataset_name": "multiwoz21", "model_path": "NO/seed1/save/supervised.pol.mdl"}, "config": {"batchsz": 64, "epoch": 40, "gamma": 0.99, "policy_lr": 5e-06, "supervised_lr": 1e-05, "entropy_weight": 0.01, "value_lr": 0.0001, "save_dir": "save", "log_dir": "log", "save_per_epoch": 5000, "hidden_size": 256, "load": "save/best", "logging_mode": "INFO", "use_cer": true, "memory_size": 5000, "behaviour_cloning_weight": 0.1, "supervised_weight": 0.0, "online_offline_ratio": 0.2, "smoothed_value_function": false, "use_reservoir_sampling": false, "seed": 0, "lambda": 1, "tau": 0.001, "policy_freq": 1, "print_per_batch": 400, "c": 1.0, "rho_bar": 1, "max_length": 10, "noisy_linear": false, "dataset_name": "multiwoz21", "data_percentage": 1.0, "dialogue_order": 0, "multiwoz_like": false, "regularization_weight": 0.0, "enc_input_dim": 128, "enc_nhead": 2, "enc_d_hid": 128, "enc_nlayers": 4, "enc_dropout": 0.1, "dec_input_dim": 128, "dec_nhead": 2, "dec_d_hid": 128, "dec_nlayers": 2, "dec_dropout": 0.0, "action_embedding_dim": 128, "domain_embedding_dim": 64, "value_embedding_dim": 12, "node_embedding_dim": 128, "roberta_path": "", "node_attention": true, "semantic_descriptions": true, "freeze_roberta": true, "use_pooled": false, "mean": true, "roberta_actions": true, "independent_descriptions": true, "random_matrix": false, "distance_metric": false, "verbose": false, "ignore_features": [], "domains_removed": ["hospital", "police", "train", "hotel", "attraction", "taxi"], "only_active_values": false, "permuted_data": false, "need_weights": false, "cls_dim": 128, "independent": true, "old_critic": false, "pos_weight": 5, "weight_decay": 1e-05}, "policy_config": null}
supervised.pol.mdl ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:847929fde204d26f279c7002ad7b8eb1108df943c24e207cd5bf01d2892f55ed
3
+ size 9331458
train_INFO.log ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Visible device: cuda
2
+ Seed used: 1
3
+ Batch size: 64
4
+ Epochs: 40
5
+ Learning rate: 1e-05
6
+ Entropy weight: 0.01
7
+ Regularization weight: 0.0
8
+ Only use multiwoz like domains: False
9
+ We use: 100.0% of the data
10
+ Dialogue order used: 0
11
+ Vectorizer: Data set used is multiwoz21
12
+ We filter state by active domains: True
13
+ Vectorizer: Data set used is multiwoz21
14
+ Embedding semantic descriptions: True
15
+ Embedded descriptions successfully. Size: torch.Size([338, 768])
16
+ Data set used for descriptions: multiwoz21
17
+ We use Roberta to embed actions.
18
+ Didnt load a model
19
+ Start training
20
+ Epoch: 0
21
+ Average actions: 1.957058072090149
22
+ Average target actions: 2.669339895248413
23
+ Precision: 0.13822525597269625
24
+ Recall: 0.10146667362597213
25
+ F1: 0.11702736056346508
26
+ <<dialog policy>> epoch 0: saved network to mdl
27
+ Best Precision: 0.13822525597269625
28
+ Best Recall: 0.10146667362597213
29
+ Best F1: 0.11702736056346508
30
+ Epoch: 1
31
+ Precision: 0.13822525597269625
32
+ Recall: 0.10146667362597213
33
+ F1: 0.11702736056346508
34
+ Best Precision: 0.13822525597269625
35
+ Best Recall: 0.10146667362597213
36
+ Best F1: 0.11702736056346508
37
+ Epoch: 2
38
+ Average actions: 2.0794308185577393
39
+ Average target actions: 2.6675729751586914
40
+ Precision: 0.22303363258743134
41
+ Recall: 0.1737564591053813
42
+ F1: 0.19533519143318176
43
+ <<dialog policy>> epoch 2: saved network to mdl
44
+ Best Precision: 0.22303363258743134
45
+ Best Recall: 0.1737564591053813
46
+ Best F1: 0.19533519143318176
47
+ Epoch: 3
48
+ Precision: 0.22303363258743134
49
+ Recall: 0.1737564591053813
50
+ F1: 0.19533519143318176
51
+ Best Precision: 0.22303363258743134
52
+ Best Recall: 0.1737564591053813
53
+ Best F1: 0.19533519143318176
54
+ Epoch: 4
55
+ Average actions: 2.0110926628112793
56
+ Average target actions: 2.665806293487549
57
+ Precision: 0.26409084614319345
58
+ Recall: 0.19907093272091445
59
+ F1: 0.22701705306389688
60
+ <<dialog policy>> epoch 4: saved network to mdl
61
+ Best Precision: 0.26409084614319345
62
+ Best Recall: 0.19907093272091445
63
+ Best F1: 0.22701705306389688
64
+ Epoch: 5
65
+ Precision: 0.26409084614319345
66
+ Recall: 0.19907093272091445
67
+ F1: 0.22701705306389688
68
+ Best Precision: 0.26409084614319345
69
+ Best Recall: 0.19907093272091445
70
+ Best F1: 0.22701705306389688
71
+ Epoch: 6
72
+ Average actions: 1.9673057794570923
73
+ Average target actions: 2.667219877243042
74
+ Precision: 0.2910210146465719
75
+ Recall: 0.21467717521791324
76
+ F1: 0.2470863871200288
77
+ <<dialog policy>> epoch 6: saved network to mdl
78
+ Best Precision: 0.2910210146465719
79
+ Best Recall: 0.21467717521791324
80
+ Best F1: 0.2470863871200288
81
+ Epoch: 7
82
+ Precision: 0.2910210146465719
83
+ Recall: 0.21467717521791324
84
+ F1: 0.2470863871200288
85
+ Best Precision: 0.2910210146465719
86
+ Best Recall: 0.21467717521791324
87
+ Best F1: 0.2470863871200288
88
+ Epoch: 8
89
+ Average actions: 1.8258512020111084
90
+ Average target actions: 2.667926549911499
91
+ Precision: 0.30450038138825325
92
+ Recall: 0.20836160551176994
93
+ F1: 0.24742012457776819
94
+ <<dialog policy>> epoch 8: saved network to mdl
95
+ Best Precision: 0.30450038138825325
96
+ Best Recall: 0.21467717521791324
97
+ Best F1: 0.24742012457776819
98
+ Epoch: 9
99
+ Precision: 0.30450038138825325
100
+ Recall: 0.20836160551176994
101
+ F1: 0.24742012457776819
102
+ Best Precision: 0.30450038138825325
103
+ Best Recall: 0.21467717521791324
104
+ Best F1: 0.24742012457776819
105
+ Epoch: 10
106
+ Average actions: 1.7796674966812134
107
+ Average target actions: 2.66333270072937
108
+ Precision: 0.3297132588483475
109
+ Recall: 0.2202620178506185
110
+ F1: 0.2640966268227048
111
+ <<dialog policy>> epoch 10: saved network to mdl
112
+ Best Precision: 0.3297132588483475
113
+ Best Recall: 0.2202620178506185
114
+ Best F1: 0.2640966268227048
115
+ Epoch: 11
116
+ Precision: 0.3297132588483475
117
+ Recall: 0.2202620178506185
118
+ F1: 0.2640966268227048
119
+ Best Precision: 0.3297132588483475
120
+ Best Recall: 0.2202620178506185
121
+ Best F1: 0.2640966268227048
122
+ Epoch: 12
123
+ Average actions: 1.8398014307022095
124
+ Average target actions: 2.67004656791687
125
+ Precision: 0.34064769975786924
126
+ Recall: 0.23498094890129964
127
+ F1: 0.27811583011583013
128
+ <<dialog policy>> epoch 12: saved network to mdl
129
+ Best Precision: 0.34064769975786924
130
+ Best Recall: 0.23498094890129964
131
+ Best F1: 0.27811583011583013
132
+ Epoch: 13
133
+ Precision: 0.34064769975786924
134
+ Recall: 0.23498094890129964
135
+ F1: 0.27811583011583013
136
+ Best Precision: 0.34064769975786924
137
+ Best Recall: 0.23498094890129964
138
+ Best F1: 0.27811583011583013
139
+ Epoch: 14
140
+ Average actions: 1.7070426940917969
141
+ Average target actions: 2.667219877243042
142
+ Precision: 0.35462034091835903
143
+ Recall: 0.22694295109348087
144
+ F1: 0.2767663908338638
145
+ Best Precision: 0.35462034091835903
146
+ Best Recall: 0.23498094890129964
147
+ Best F1: 0.27811583011583013
148
+ Epoch: 15
149
+ Precision: 0.35462034091835903
150
+ Recall: 0.22694295109348087
151
+ F1: 0.2767663908338638
152
+ Best Precision: 0.35462034091835903
153
+ Best Recall: 0.23498094890129964
154
+ Best F1: 0.27811583011583013
155
+ Epoch: 16
156
+ Average actions: 1.6812468767166138
157
+ Average target actions: 2.6643927097320557
158
+ Precision: 0.34859650575474044
159
+ Recall: 0.21974006994101988
160
+ F1: 0.2695607632219234
161
+ Best Precision: 0.35462034091835903
162
+ Best Recall: 0.23498094890129964
163
+ Best F1: 0.27811583011583013
164
+ Epoch: 17
165
+ Precision: 0.34859650575474044
166
+ Recall: 0.21974006994101988
167
+ F1: 0.2695607632219234
168
+ Best Precision: 0.35462034091835903
169
+ Best Recall: 0.23498094890129964
170
+ Best F1: 0.27811583011583013
171
+ Epoch: 18
172
+ Average actions: 1.675270438194275
173
+ Average target actions: 2.6640396118164062
174
+ Precision: 0.35976419794088343
175
+ Recall: 0.22616002922908293
176
+ F1: 0.27772970547703746
177
+ Best Precision: 0.35976419794088343
178
+ Best Recall: 0.23498094890129964
179
+ Best F1: 0.27811583011583013
180
+ Epoch: 19
181
+ Precision: 0.35976419794088343
182
+ Recall: 0.22616002922908293
183
+ F1: 0.27772970547703746
184
+ Best Precision: 0.35976419794088343
185
+ Best Recall: 0.23498094890129964
186
+ Best F1: 0.27811583011583013
187
+ Epoch: 20
188
+ Average actions: 1.5666790008544922
189
+ Average target actions: 2.6647462844848633
190
+ Precision: 0.3769442716203004
191
+ Recall: 0.2213581084607756
192
+ F1: 0.27892140743176586
193
+ <<dialog policy>> epoch 20: saved network to mdl
194
+ Best Precision: 0.3769442716203004
195
+ Best Recall: 0.23498094890129964
196
+ Best F1: 0.27892140743176586
197
+ Epoch: 21
198
+ Precision: 0.3769442716203004
199
+ Recall: 0.2213581084607756
200
+ F1: 0.27892140743176586
201
+ Best Precision: 0.3769442716203004
202
+ Best Recall: 0.23498094890129964
203
+ Best F1: 0.27892140743176586
204
+ Epoch: 22
205
+ Average actions: 1.6693706512451172
206
+ Average target actions: 2.6661596298217773
207
+ Precision: 0.3716379382130069
208
+ Recall: 0.23294535205386502
209
+ F1: 0.2863834702258727
210
+ <<dialog policy>> epoch 22: saved network to mdl
211
+ Best Precision: 0.3769442716203004
212
+ Best Recall: 0.23498094890129964
213
+ Best F1: 0.2863834702258727
214
+ Epoch: 23
215
+ Precision: 0.3716379382130069
216
+ Recall: 0.23294535205386502
217
+ F1: 0.2863834702258727
218
+ Best Precision: 0.3769442716203004
219
+ Best Recall: 0.23498094890129964
220
+ Best F1: 0.2863834702258727
221
+ Epoch: 24
222
+ Average actions: 1.6701388359069824
223
+ Average target actions: 2.6643927097320557
224
+ Precision: 0.3714618714618715
225
+ Recall: 0.23289315726290516
226
+ F1: 0.2862917455327067
227
+ Best Precision: 0.3769442716203004
228
+ Best Recall: 0.23498094890129964
229
+ Best F1: 0.2863834702258727
230
+ Epoch: 25
231
+ Precision: 0.3714618714618715
232
+ Recall: 0.23289315726290516
233
+ F1: 0.2862917455327067
234
+ Best Precision: 0.3769442716203004
235
+ Best Recall: 0.23498094890129964
236
+ Best F1: 0.2863834702258727
237
+ Epoch: 26
238
+ Average actions: 1.6909722089767456
239
+ Average target actions: 2.665099620819092
240
+ Precision: 0.3781160016454134
241
+ Recall: 0.2398872592515267
242
+ F1: 0.2935428242958421
243
+ <<dialog policy>> epoch 26: saved network to mdl
244
+ Best Precision: 0.3781160016454134
245
+ Best Recall: 0.2398872592515267
246
+ Best F1: 0.2935428242958421
247
+ Epoch: 27
248
+ Precision: 0.3781160016454134
249
+ Recall: 0.2398872592515267
250
+ F1: 0.2935428242958421
251
+ Best Precision: 0.3781160016454134
252
+ Best Recall: 0.2398872592515267
253
+ Best F1: 0.2935428242958421
254
+ Epoch: 28
255
+ Average actions: 1.8047566413879395
256
+ Average target actions: 2.6643927097320557
257
+ Precision: 0.3654779326811985
258
+ Recall: 0.24766428310454616
259
+ F1: 0.29525231783958683
260
+ <<dialog policy>> epoch 28: saved network to mdl
261
+ Best Precision: 0.3781160016454134
262
+ Best Recall: 0.24766428310454616
263
+ Best F1: 0.29525231783958683
264
+ Epoch: 29
265
+ Precision: 0.3654779326811985
266
+ Recall: 0.24766428310454616
267
+ F1: 0.29525231783958683
268
+ Best Precision: 0.3781160016454134
269
+ Best Recall: 0.24766428310454616
270
+ Best F1: 0.29525231783958683
271
+ Epoch: 30
272
+ Average actions: 1.680601716041565
273
+ Average target actions: 2.6640396118164062
274
+ Precision: 0.37665562913907286
275
+ Recall: 0.23748629886737305
276
+ F1: 0.2913025384935497
277
+ Best Precision: 0.3781160016454134
278
+ Best Recall: 0.24766428310454616
279
+ Best F1: 0.29525231783958683
280
+ Epoch: 31
281
+ Precision: 0.37665562913907286
282
+ Recall: 0.23748629886737305
283
+ F1: 0.2913025384935497
284
+ Best Precision: 0.3781160016454134
285
+ Best Recall: 0.24766428310454616
286
+ Best F1: 0.29525231783958683
287
+ Epoch: 32
288
+ Average actions: 1.7778853178024292
289
+ Average target actions: 2.667219877243042
290
+ Precision: 0.3660120491354354
291
+ Recall: 0.2441672321102354
292
+ F1: 0.2929242329367564
293
+ Best Precision: 0.3781160016454134
294
+ Best Recall: 0.24766428310454616
295
+ Best F1: 0.29525231783958683
296
+ Epoch: 33
297
+ Precision: 0.3660120491354354
298
+ Recall: 0.2441672321102354
299
+ F1: 0.2929242329367564
300
+ Best Precision: 0.3781160016454134
301
+ Best Recall: 0.24766428310454616
302
+ Best F1: 0.29525231783958683
303
+ Epoch: 34
304
+ Average actions: 1.726846694946289
305
+ Average target actions: 2.66333270072937
306
+ Precision: 0.3723121526938874
307
+ Recall: 0.24129651860744297
308
+ F1: 0.29281732961743095
309
+ Best Precision: 0.3781160016454134
310
+ Best Recall: 0.24766428310454616
311
+ Best F1: 0.29525231783958683
312
+ Epoch: 35
313
+ Precision: 0.3723121526938874
314
+ Recall: 0.24129651860744297
315
+ F1: 0.29281732961743095
316
+ Best Precision: 0.3781160016454134
317
+ Best Recall: 0.24766428310454616
318
+ Best F1: 0.29525231783958683
319
+ Epoch: 36
320
+ Average actions: 1.8067078590393066
321
+ Average target actions: 2.6675729751586914
322
+ Precision: 0.37099753694581283
323
+ Recall: 0.2515788924265358
324
+ F1: 0.29983515287238344
325
+ <<dialog policy>> epoch 36: saved network to mdl
326
+ Best Precision: 0.3781160016454134
327
+ Best Recall: 0.2515788924265358
328
+ Best F1: 0.29983515287238344
329
+ Epoch: 37
330
+ Precision: 0.37099753694581283
331
+ Recall: 0.2515788924265358
332
+ F1: 0.29983515287238344
333
+ Best Precision: 0.3781160016454134
334
+ Best Recall: 0.2515788924265358
335
+ Best F1: 0.29983515287238344
336
+ Epoch: 38
337
+ Average actions: 1.7964909076690674
338
+ Average target actions: 2.6647462844848633
339
+ Precision: 0.36536823356307596
340
+ Recall: 0.2462550237486299
341
+ F1: 0.2942130207034173
342
+ Best Precision: 0.3781160016454134
343
+ Best Recall: 0.2515788924265358
344
+ Best F1: 0.29983515287238344
345
+ Epoch: 39
346
+ Precision: 0.36536823356307596
347
+ Recall: 0.2462550237486299
348
+ F1: 0.2942130207034173
349
+ Best Precision: 0.3781160016454134
350
+ Best Recall: 0.2515788924265358
351
+ Best F1: 0.29983515287238344