ChrisGeishauser commited on
Commit
ea585ab
1 Parent(s): c3ec217

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. config_saved.json +1 -0
  3. supervised.pol.mdl +3 -0
  4. train_INFO.log +341 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ supervised.pol.mdl filter=lfs diff=lfs merge=lfs -text
config_saved.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"args": {"seed": 0, "eval_freq": 2, "dataset_name": "multiwoz21", "model_path": "onvlab/policy/vtrace_DPT/supervised/experiments/sgd/save/supervised.pol.mdl"}, "config": {"batchsz": 64, "epoch": 40, "gamma": 0.99, "policy_lr": 5e-05, "supervised_lr": 1e-05, "entropy_weight": 0.01, "value_lr": 0.0001, "save_dir": "save", "log_dir": "log", "save_per_epoch": 5000, "hidden_size": 256, "load": "save/best", "logging_mode": "INFO", "use_cer": true, "memory_size": 5000, "behaviour_cloning_weight": 0.1, "supervised_weight": 0.0, "online_offline_ratio": 0.2, "smoothed_value_function": false, "use_reservoir_sampling": false, "seed": 0, "lambda": 1, "tau": 0.001, "policy_freq": 2, "print_per_batch": 400, "c": 1.0, "rho_bar": 1, "max_length": 10, "noisy_linear": false, "dataset_name": "multiwoz21", "data_percentage": 0.01, "multiwoz_like": false, "regularization_weight": 0.0, "enc_input_dim": 128, "enc_nhead": 2, "enc_d_hid": 128, "enc_nlayers": 4, "enc_dropout": 0.1, "dec_input_dim": 128, "dec_nhead": 2, "dec_d_hid": 128, "dec_nlayers": 2, "dec_dropout": 0.0, "action_embedding_dim": 128, "domain_embedding_dim": 64, "value_embedding_dim": 12, "node_embedding_dim": 128, "roberta_path": "", "node_attention": true, "semantic_descriptions": true, "freeze_roberta": true, "use_pooled": false, "mean": true, "roberta_actions": true, "independent_descriptions": true, "random_matrix": false, "distance_metric": false, "verbose": false, "ignore_features": [], "domains_removed": ["hospital", "police", "train", "hotel", "attraction", "taxi"], "only_active_values": false, "permuted_data": false, "need_weights": false, "cls_dim": 128, "independent": true, "old_critic": false, "pos_weight": 5, "weight_decay": 1e-05}, "policy_config": null}
supervised.pol.mdl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6992b768e91941f87a8f8275c88e5e3accf998bf4a1f7fbb5eb0bb337bd7fa6f
3
+ size 9331458
train_INFO.log ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Visible device: cuda
2
+ Seed used: 0
3
+ Batch size: 64
4
+ Epochs: 40
5
+ Learning rate: 1e-05
6
+ Entropy weight: 0.01
7
+ Regularization weight: 0.0
8
+ Only use multiwoz like domains: False
9
+ Vectorizer: Data set used is multiwoz21
10
+ We filter state by active domains: True
11
+ Vectorizer: Data set used is multiwoz21
12
+ Embedding semantic descriptions: True
13
+ Embedded descriptions successfully. Size: torch.Size([338, 768])
14
+ Data set used for descriptions: multiwoz21
15
+ We use Roberta to embed actions.
16
+ Didnt load a model
17
+ Start training
18
+ Epoch: 0
19
+ Precision: 0
20
+ Recall: 0
21
+ F1: 0
22
+ Best Precision: 0.0
23
+ Best Recall: 0.0
24
+ Best F1: 0.0
25
+ Epoch: 1
26
+ Precision: 0
27
+ Recall: 0
28
+ F1: 0
29
+ Best Precision: 0.0
30
+ Best Recall: 0.0
31
+ Best F1: 0.0
32
+ Epoch: 2
33
+ Average actions: 2.4348959922790527
34
+ Average target actions: 2.28125
35
+ Precision: 0.043010752688172046
36
+ Recall: 0.0425531914893617
37
+ F1: 0.04278074866310161
38
+ <<dialog policy>> epoch 2: saved network to mdl
39
+ Best Precision: 0.043010752688172046
40
+ Best Recall: 0.0425531914893617
41
+ Best F1: 0.04278074866310161
42
+ Epoch: 3
43
+ Precision: 0.043010752688172046
44
+ Recall: 0.0425531914893617
45
+ F1: 0.04278074866310161
46
+ Best Precision: 0.043010752688172046
47
+ Best Recall: 0.0425531914893617
48
+ Best F1: 0.04278074866310161
49
+ Epoch: 4
50
+ Average actions: 2.4114584922790527
51
+ Average target actions: 2.7890625
52
+ Precision: 0.07058823529411765
53
+ Recall: 0.06382978723404255
54
+ F1: 0.06703910614525138
55
+ <<dialog policy>> epoch 4: saved network to mdl
56
+ Best Precision: 0.07058823529411765
57
+ Best Recall: 0.06382978723404255
58
+ Best F1: 0.06703910614525138
59
+ Epoch: 5
60
+ Precision: 0.07058823529411765
61
+ Recall: 0.06382978723404255
62
+ F1: 0.06703910614525138
63
+ Best Precision: 0.07058823529411765
64
+ Best Recall: 0.06382978723404255
65
+ Best F1: 0.06703910614525138
66
+ Epoch: 6
67
+ Average actions: 2.1536459922790527
68
+ Average target actions: 2.5859375
69
+ Precision: 0.049079754601226995
70
+ Recall: 0.0425531914893617
71
+ F1: 0.045584045584045586
72
+ Best Precision: 0.07058823529411765
73
+ Best Recall: 0.06382978723404255
74
+ Best F1: 0.06703910614525138
75
+ Epoch: 7
76
+ Precision: 0.049079754601226995
77
+ Recall: 0.0425531914893617
78
+ F1: 0.045584045584045586
79
+ Best Precision: 0.07058823529411765
80
+ Best Recall: 0.06382978723404255
81
+ Best F1: 0.06703910614525138
82
+ Epoch: 8
83
+ Average actions: 2.15625
84
+ Average target actions: 2.5520834922790527
85
+ Precision: 0.07547169811320754
86
+ Recall: 0.06382978723404255
87
+ F1: 0.06916426512968299
88
+ <<dialog policy>> epoch 8: saved network to mdl
89
+ Best Precision: 0.07547169811320754
90
+ Best Recall: 0.06382978723404255
91
+ Best F1: 0.06916426512968299
92
+ Epoch: 9
93
+ Precision: 0.07547169811320754
94
+ Recall: 0.06382978723404255
95
+ F1: 0.06916426512968299
96
+ Best Precision: 0.07547169811320754
97
+ Best Recall: 0.06382978723404255
98
+ Best F1: 0.06916426512968299
99
+ Epoch: 10
100
+ Average actions: 2.0572915077209473
101
+ Average target actions: 2.3489584922790527
102
+ Precision: 0.04516129032258064
103
+ Recall: 0.03723404255319149
104
+ F1: 0.04081632653061224
105
+ Best Precision: 0.07547169811320754
106
+ Best Recall: 0.06382978723404255
107
+ Best F1: 0.06916426512968299
108
+ Epoch: 11
109
+ Precision: 0.04516129032258064
110
+ Recall: 0.03723404255319149
111
+ F1: 0.04081632653061224
112
+ Best Precision: 0.07547169811320754
113
+ Best Recall: 0.06382978723404255
114
+ Best F1: 0.06916426512968299
115
+ Epoch: 12
116
+ Average actions: 1.984375
117
+ Average target actions: 2.5520834922790527
118
+ Precision: 0.08666666666666667
119
+ Recall: 0.06914893617021277
120
+ F1: 0.07692307692307691
121
+ <<dialog policy>> epoch 12: saved network to mdl
122
+ Best Precision: 0.08666666666666667
123
+ Best Recall: 0.06914893617021277
124
+ Best F1: 0.07692307692307691
125
+ Epoch: 13
126
+ Precision: 0.08666666666666667
127
+ Recall: 0.06914893617021277
128
+ F1: 0.07692307692307691
129
+ Best Precision: 0.08666666666666667
130
+ Best Recall: 0.06914893617021277
131
+ Best F1: 0.07692307692307691
132
+ Epoch: 14
133
+ Average actions: 2.0416665077209473
134
+ Average target actions: 2.3828125
135
+ Precision: 0.05228758169934641
136
+ Recall: 0.0425531914893617
137
+ F1: 0.046920821114369494
138
+ Best Precision: 0.08666666666666667
139
+ Best Recall: 0.06914893617021277
140
+ Best F1: 0.07692307692307691
141
+ Epoch: 15
142
+ Precision: 0.05228758169934641
143
+ Recall: 0.0425531914893617
144
+ F1: 0.046920821114369494
145
+ Best Precision: 0.08666666666666667
146
+ Best Recall: 0.06914893617021277
147
+ Best F1: 0.07692307692307691
148
+ Epoch: 16
149
+ Average actions: 2.1666665077209473
150
+ Average target actions: 2.2135417461395264
151
+ Precision: 0.1346153846153846
152
+ Recall: 0.11170212765957446
153
+ F1: 0.12209302325581395
154
+ <<dialog policy>> epoch 16: saved network to mdl
155
+ Best Precision: 0.1346153846153846
156
+ Best Recall: 0.11170212765957446
157
+ Best F1: 0.12209302325581395
158
+ Epoch: 17
159
+ Precision: 0.1346153846153846
160
+ Recall: 0.11170212765957446
161
+ F1: 0.12209302325581395
162
+ Best Precision: 0.1346153846153846
163
+ Best Recall: 0.11170212765957446
164
+ Best F1: 0.12209302325581395
165
+ Epoch: 18
166
+ Average actions: 1.7734375
167
+ Average target actions: 2.5520834922790527
168
+ Precision: 0.0661764705882353
169
+ Recall: 0.047872340425531915
170
+ F1: 0.05555555555555556
171
+ Best Precision: 0.1346153846153846
172
+ Best Recall: 0.11170212765957446
173
+ Best F1: 0.12209302325581395
174
+ Epoch: 19
175
+ Precision: 0.0661764705882353
176
+ Recall: 0.047872340425531915
177
+ F1: 0.05555555555555556
178
+ Best Precision: 0.1346153846153846
179
+ Best Recall: 0.11170212765957446
180
+ Best F1: 0.12209302325581395
181
+ Epoch: 20
182
+ Average actions: 2.1328125
183
+ Average target actions: 2.6197917461395264
184
+ Precision: 0.1346153846153846
185
+ Recall: 0.11170212765957446
186
+ F1: 0.12209302325581395
187
+ Best Precision: 0.1346153846153846
188
+ Best Recall: 0.11170212765957446
189
+ Best F1: 0.12209302325581395
190
+ Epoch: 21
191
+ Precision: 0.1346153846153846
192
+ Recall: 0.11170212765957446
193
+ F1: 0.12209302325581395
194
+ Best Precision: 0.1346153846153846
195
+ Best Recall: 0.11170212765957446
196
+ Best F1: 0.12209302325581395
197
+ Epoch: 22
198
+ Average actions: 1.9296875
199
+ Average target actions: 2.1119792461395264
200
+ Precision: 0.08391608391608392
201
+ Recall: 0.06382978723404255
202
+ F1: 0.07250755287009063
203
+ Best Precision: 0.1346153846153846
204
+ Best Recall: 0.11170212765957446
205
+ Best F1: 0.12209302325581395
206
+ Epoch: 23
207
+ Precision: 0.08391608391608392
208
+ Recall: 0.06382978723404255
209
+ F1: 0.07250755287009063
210
+ Best Precision: 0.1346153846153846
211
+ Best Recall: 0.11170212765957446
212
+ Best F1: 0.12209302325581395
213
+ Epoch: 24
214
+ Average actions: 2.2213540077209473
215
+ Average target actions: 2.3151042461395264
216
+ Precision: 0.09815950920245399
217
+ Recall: 0.0851063829787234
218
+ F1: 0.09116809116809117
219
+ Best Precision: 0.1346153846153846
220
+ Best Recall: 0.11170212765957446
221
+ Best F1: 0.12209302325581395
222
+ Epoch: 25
223
+ Precision: 0.09815950920245399
224
+ Recall: 0.0851063829787234
225
+ F1: 0.09116809116809117
226
+ Best Precision: 0.1346153846153846
227
+ Best Recall: 0.11170212765957446
228
+ Best F1: 0.12209302325581395
229
+ Epoch: 26
230
+ Average actions: 2.1171875
231
+ Average target actions: 2.7890625
232
+ Precision: 0.12987012987012986
233
+ Recall: 0.10638297872340426
234
+ F1: 0.11695906432748537
235
+ Best Precision: 0.1346153846153846
236
+ Best Recall: 0.11170212765957446
237
+ Best F1: 0.12209302325581395
238
+ Epoch: 27
239
+ Precision: 0.12987012987012986
240
+ Recall: 0.10638297872340426
241
+ F1: 0.11695906432748537
242
+ Best Precision: 0.1346153846153846
243
+ Best Recall: 0.11170212765957446
244
+ Best F1: 0.12209302325581395
245
+ Epoch: 28
246
+ Average actions: 1.7734375
247
+ Average target actions: 2.484375
248
+ Precision: 0.08823529411764706
249
+ Recall: 0.06382978723404255
250
+ F1: 0.07407407407407407
251
+ Best Precision: 0.1346153846153846
252
+ Best Recall: 0.11170212765957446
253
+ Best F1: 0.12209302325581395
254
+ Epoch: 29
255
+ Precision: 0.08823529411764706
256
+ Recall: 0.06382978723404255
257
+ F1: 0.07407407407407407
258
+ Best Precision: 0.1346153846153846
259
+ Best Recall: 0.11170212765957446
260
+ Best F1: 0.12209302325581395
261
+ Epoch: 30
262
+ Average actions: 2.1822915077209473
263
+ Average target actions: 2.3489584922790527
264
+ Precision: 0.10126582278481013
265
+ Recall: 0.0851063829787234
266
+ F1: 0.09248554913294797
267
+ Best Precision: 0.1346153846153846
268
+ Best Recall: 0.11170212765957446
269
+ Best F1: 0.12209302325581395
270
+ Epoch: 31
271
+ Precision: 0.10126582278481013
272
+ Recall: 0.0851063829787234
273
+ F1: 0.09248554913294797
274
+ Best Precision: 0.1346153846153846
275
+ Best Recall: 0.11170212765957446
276
+ Best F1: 0.12209302325581395
277
+ Epoch: 32
278
+ Average actions: 2.0442707538604736
279
+ Average target actions: 2.6197917461395264
280
+ Precision: 0.12345679012345678
281
+ Recall: 0.10638297872340426
282
+ F1: 0.11428571428571428
283
+ Best Precision: 0.1346153846153846
284
+ Best Recall: 0.11170212765957446
285
+ Best F1: 0.12209302325581395
286
+ Epoch: 33
287
+ Precision: 0.12345679012345678
288
+ Recall: 0.10638297872340426
289
+ F1: 0.11428571428571428
290
+ Best Precision: 0.1346153846153846
291
+ Best Recall: 0.11170212765957446
292
+ Best F1: 0.12209302325581395
293
+ Epoch: 34
294
+ Average actions: 1.8307292461395264
295
+ Average target actions: 2.5859375
296
+ Precision: 0.11510791366906475
297
+ Recall: 0.0851063829787234
298
+ F1: 0.09785932721712538
299
+ Best Precision: 0.1346153846153846
300
+ Best Recall: 0.11170212765957446
301
+ Best F1: 0.12209302325581395
302
+ Epoch: 35
303
+ Precision: 0.11510791366906475
304
+ Recall: 0.0851063829787234
305
+ F1: 0.09785932721712538
306
+ Best Precision: 0.1346153846153846
307
+ Best Recall: 0.11170212765957446
308
+ Best F1: 0.12209302325581395
309
+ Epoch: 36
310
+ Average actions: 2.2838540077209473
311
+ Average target actions: 2.3489584922790527
312
+ Precision: 0.1286549707602339
313
+ Recall: 0.11702127659574468
314
+ F1: 0.12256267409470752
315
+ <<dialog policy>> epoch 36: saved network to mdl
316
+ Best Precision: 0.1346153846153846
317
+ Best Recall: 0.11702127659574468
318
+ Best F1: 0.12256267409470752
319
+ Epoch: 37
320
+ Precision: 0.1286549707602339
321
+ Recall: 0.11702127659574468
322
+ F1: 0.12256267409470752
323
+ Best Precision: 0.1346153846153846
324
+ Best Recall: 0.11702127659574468
325
+ Best F1: 0.12256267409470752
326
+ Epoch: 38
327
+ Average actions: 1.9479167461395264
328
+ Average target actions: 2.7552084922790527
329
+ Precision: 0.12337662337662338
330
+ Recall: 0.10106382978723404
331
+ F1: 0.1111111111111111
332
+ Best Precision: 0.1346153846153846
333
+ Best Recall: 0.11702127659574468
334
+ Best F1: 0.12256267409470752
335
+ Epoch: 39
336
+ Precision: 0.12337662337662338
337
+ Recall: 0.10106382978723404
338
+ F1: 0.1111111111111111
339
+ Best Precision: 0.1346153846153846
340
+ Best Recall: 0.11702127659574468
341
+ Best F1: 0.12256267409470752