ojaffe commited on
Commit
ed16acf
·
verified ·
1 Parent(s): d7c2bdd

Upload folder using huggingface_hub

Browse files
__pycache__/predict.cpython-311.pyc CHANGED
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
 
model_pong_direct.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ab8070ddcde00333d7b52c89a0da9a61eece1e67c46163cd011ce4cd3c422f0c
3
- size 2436712
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6ada51c7d09e003a2bea134bfe7be0e762756f10cebaab73c904135c5a4e33cf
3
+ size 2437546
predict.py CHANGED
@@ -167,7 +167,7 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
167
 
168
  predicted = torch.zeros_like(direct_pred)
169
  for step in range(PRED_FRAMES):
170
- ar_weight = 0.8
171
  direct_weight = 1.0 - ar_weight
172
  predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
173
 
@@ -227,7 +227,7 @@ def predict_next_frame(ens, context_frames: np.ndarray) -> np.ndarray:
227
 
228
  predicted = torch.zeros_like(direct_pred)
229
  for step in range(PRED_FRAMES):
230
- ar_weight = 0.5
231
  direct_weight = 1.0 - ar_weight
232
  predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
233
 
 
167
 
168
  predicted = torch.zeros_like(direct_pred)
169
  for step in range(PRED_FRAMES):
170
+ ar_weight = 0.85 - (step / (PRED_FRAMES - 1)) * 0.2
171
  direct_weight = 1.0 - ar_weight
172
  predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
173
 
 
227
 
228
  predicted = torch.zeros_like(direct_pred)
229
  for step in range(PRED_FRAMES):
230
+ ar_weight = 0.7 - (step / (PRED_FRAMES - 1)) * 0.4
231
  direct_weight = 1.0 - ar_weight
232
  predicted[:, step] = ar_weight * ar_pred[:, step] + direct_weight * direct_pred[:, step]
233
 
train.log CHANGED
@@ -1,77 +1,52 @@
1
- [2026-04-12 05:14:21] Starting augmented Pong AR training for 2026-04-12-110000-pong-augmented
2
- [2026-04-12 05:14:21] Device: cuda
3
- [2026-04-12 05:14:21] Pong params: 1,198,531 (2.3 MB fp16)
4
- [2026-04-12 05:14:21] Phase 1: single-step + hflip aug, 100 epochs, lr=0.0003
5
- [2026-04-12 05:14:22] pong train: 8432 sequences (seq_len=9)
6
- [2026-04-12 05:14:22] pong val: 992 sequences (seq_len=9)
7
- [2026-04-12 05:14:31] P1 E1/100 | T:0.127401(S:0.8196) V:0.110707(S:0.8434) LR:3.00e-04
8
- [2026-04-12 05:14:38] P1 E2/100 | T:0.101264(S:0.8567) V:0.099352(S:0.8593) LR:3.00e-04
9
- [2026-04-12 05:14:46] P1 E3/100 | T:0.093089(S:0.8683) V:0.101696(S:0.8561) LR:2.99e-04
10
- [2026-04-12 05:15:01] P1 E5/100 | T:0.082154(S:0.8838) V:0.084095(S:0.8810) LR:2.98e-04
11
- [2026-04-12 05:15:07] P1 E6/100 | T:0.077151(S:0.8908) V:0.080748(S:0.8857) LR:2.97e-04
12
- [2026-04-12 05:15:22] P1 E8/100 | T:0.066223(S:0.9063) V:0.075570(S:0.8930) LR:2.95e-04
13
- [2026-04-12 05:15:37] P1 E10/100 | T:0.059430(S:0.9160) V:0.077257(S:0.8907) LR:2.93e-04
14
- [2026-04-12 05:15:44] P1 E11/100 | T:0.055401(S:0.9217) V:0.071990(S:0.8981) LR:2.91e-04
15
- [2026-04-12 05:15:51] P1 E12/100 | T:0.052138(S:0.9263) V:0.069664(S:0.9014) LR:2.90e-04
16
- [2026-04-12 05:15:59] P1 E13/100 | T:0.049025(S:0.9307) V:0.069328(S:0.9019) LR:2.88e-04
17
- [2026-04-12 05:16:06] P1 E14/100 | T:0.047183(S:0.9333) V:0.063733(S:0.9098) LR:2.86e-04
18
- [2026-04-12 05:16:12] P1 E15/100 | T:0.044935(S:0.9365) V:0.063349(S:0.9104) LR:2.84e-04
19
- [2026-04-12 05:16:27] P1 E17/100 | T:0.042284(S:0.9403) V:0.061114(S:0.9135) LR:2.79e-04
20
- [2026-04-12 05:16:35] P1 E18/100 | T:0.040933(S:0.9422) V:0.058447(S:0.9173) LR:2.77e-04
21
- [2026-04-12 05:16:49] P1 E20/100 | T:0.038315(S:0.9459) V:0.061833(S:0.9125) LR:2.71e-04
22
- [2026-04-12 05:16:57] P1 E21/100 | T:0.037595(S:0.9469) V:0.057463(S:0.9187) LR:2.69e-04
23
- [2026-04-12 05:17:04] P1 E22/100 | T:0.036860(S:0.9479) V:0.055401(S:0.9216) LR:2.66e-04
24
- [2026-04-12 05:17:18] P1 E24/100 | T:0.033665(S:0.9525) V:0.055152(S:0.9220) LR:2.59e-04
25
- [2026-04-12 05:17:25] P1 E25/100 | T:0.032287(S:0.9544) V:0.054026(S:0.9235) LR:2.56e-04
26
- [2026-04-12 05:17:40] P1 E27/100 | T:0.030628(S:0.9568) V:0.053318(S:0.9245) LR:2.49e-04
27
- [2026-04-12 05:17:47] P1 E28/100 | T:0.029062(S:0.9590) V:0.051204(S:0.9275) LR:2.46e-04
28
- [2026-04-12 05:18:02] P1 E30/100 | T:0.027471(S:0.9612) V:0.049118(S:0.9305) LR:2.38e-04
29
- [2026-04-12 05:18:31] P1 E34/100 | T:0.024918(S:0.9649) V:0.048185(S:0.9318) LR:2.23e-04
30
- [2026-04-12 05:18:53] P1 E37/100 | T:0.022570(S:0.9682) V:0.046684(S:0.9339) LR:2.10e-04
31
- [2026-04-12 05:19:14] P1 E40/100 | T:0.020725(S:0.9708) V:0.045685(S:0.9353) LR:1.97e-04
32
- [2026-04-12 05:19:29] P1 E42/100 | T:0.019517(S:0.9725) V:0.045561(S:0.9355) LR:1.88e-04
33
- [2026-04-12 05:19:51] P1 E45/100 | T:0.018041(S:0.9746) V:0.045091(S:0.9362) LR:1.74e-04
34
- [2026-04-12 05:20:05] P1 E47/100 | T:0.017219(S:0.9757) V:0.044821(S:0.9365) LR:1.65e-04
35
- [2026-04-12 05:20:13] P1 E48/100 | T:0.016920(S:0.9762) V:0.044686(S:0.9367) LR:1.60e-04
36
- [2026-04-12 05:20:28] P1 E50/100 | T:0.016031(S:0.9774) V:0.045932(S:0.9350) LR:1.50e-04
37
- [2026-04-12 05:20:43] P1 E52/100 | T:0.015575(S:0.9781) V:0.044620(S:0.9368) LR:1.41e-04
38
- [2026-04-12 05:20:58] P1 E54/100 | T:0.014855(S:0.9791) V:0.044487(S:0.9370) LR:1.32e-04
39
- [2026-04-12 05:21:04] P1 E55/100 | T:0.014467(S:0.9796) V:0.043267(S:0.9387) LR:1.27e-04
40
- [2026-04-12 05:21:11] P1 E56/100 | T:0.013987(S:0.9803) V:0.043114(S:0.9389) LR:1.22e-04
41
- [2026-04-12 05:21:41] P1 E60/100 | T:0.012814(S:0.9820) V:0.042679(S:0.9396) LR:1.04e-04
42
- [2026-04-12 05:21:55] P1 E62/100 | T:0.012138(S:0.9829) V:0.041535(S:0.9412) LR:9.55e-05
43
- [2026-04-12 05:22:09] P1 E64/100 | T:0.011468(S:0.9839) V:0.041223(S:0.9416) LR:8.68e-05
44
- [2026-04-12 05:22:42] P1 E69/100 | T:0.010313(S:0.9855) V:0.040714(S:0.9423) LR:6.65e-05
45
- [2026-04-12 05:22:49] P1 E70/100 | T:0.010118(S:0.9858) V:0.041125(S:0.9418) LR:6.26e-05
46
- [2026-04-12 05:23:22] P1 E75/100 | T:0.009125(S:0.9872) V:0.040657(S:0.9424) LR:4.48e-05
47
- [2026-04-12 05:23:52] P1 E79/100 | T:0.008602(S:0.9879) V:0.040464(S:0.9427) LR:3.24e-05
48
- [2026-04-12 05:23:59] P1 E80/100 | T:0.008478(S:0.9881) V:0.040606(S:0.9425) LR:2.96e-05
49
- [2026-04-12 05:24:19] P1 E83/100 | T:0.008091(S:0.9886) V:0.040422(S:0.9427) LR:2.18e-05
50
- [2026-04-12 05:24:34] P1 E85/100 | T:0.008020(S:0.9887) V:0.040420(S:0.9427) LR:1.73e-05
51
- [2026-04-12 05:24:49] P1 E87/100 | T:0.007915(S:0.9889) V:0.040373(S:0.9428) LR:1.33e-05
52
- [2026-04-12 05:24:56] P1 E88/100 | T:0.007870(S:0.9890) V:0.040319(S:0.9429) LR:1.15e-05
53
- [2026-04-12 05:25:10] P1 E90/100 | T:0.007718(S:0.9892) V:0.040199(S:0.9431) LR:8.32e-06
54
- [2026-04-12 05:25:17] P1 E91/100 | T:0.007692(S:0.9892) V:0.040185(S:0.9431) LR:6.94e-06
55
- [2026-04-12 05:25:24] P1 E92/100 | T:0.007651(S:0.9893) V:0.040164(S:0.9431) LR:5.70e-06
56
- [2026-04-12 05:25:32] P1 E93/100 | T:0.007631(S:0.9893) V:0.040117(S:0.9432) LR:4.60e-06
57
- [2026-04-12 05:26:21] P1 E100/100 | T:0.007589(S:0.9893) V:0.040109(S:0.9432) LR:1.00e-06
58
- [2026-04-12 05:26:21] Phase 1 done. Best val loss: 0.040109
59
- [2026-04-12 05:26:21] Phase 2: 2-step AR + hflip aug, 50 epochs, lr=5e-05
60
- [2026-04-12 05:26:22] pong train: 8398 sequences (seq_len=10)
61
- [2026-04-12 05:26:22] pong val: 988 sequences (seq_len=10)
62
- [2026-04-12 05:26:44] P2 E1/50 | T:0.014117(S:0.9801) V:0.056007(S:0.9206) LR:5.00e-05
63
- [2026-04-12 05:27:05] P2 E2/50 | T:0.013481(S:0.9810) V:0.054888(S:0.9222) LR:4.98e-05
64
- [2026-04-12 05:27:26] P2 E3/50 | T:0.013067(S:0.9816) V:0.054390(S:0.9229) LR:4.96e-05
65
- [2026-04-12 05:28:52] P2 E7/50 | T:0.011823(S:0.9834) V:0.053736(S:0.9238) LR:4.77e-05
66
- [2026-04-12 05:30:01] P2 E10/50 | T:0.011506(S:0.9838) V:0.054305(S:0.9230) LR:4.53e-05
67
- [2026-04-12 05:33:34] P2 E20/50 | T:0.010052(S:0.9859) V:0.053792(S:0.9237) LR:3.31e-05
68
- [2026-04-12 05:34:17] P2 E22/50 | T:0.009761(S:0.9863) V:0.053448(S:0.9242) LR:3.01e-05
69
- [2026-04-12 05:34:39] P2 E23/50 | T:0.009488(S:0.9867) V:0.053285(S:0.9245) LR:2.86e-05
70
- [2026-04-12 05:35:22] P2 E25/50 | T:0.009336(S:0.9869) V:0.052888(S:0.9250) LR:2.55e-05
71
- [2026-04-12 05:37:11] P2 E30/50 | T:0.008903(S:0.9875) V:0.053298(S:0.9244) LR:1.79e-05
72
- [2026-04-12 05:40:56] P2 E40/50 | T:0.008178(S:0.9885) V:0.053040(S:0.9248) LR:5.68e-06
73
- [2026-04-12 05:42:25] P2 E44/50 | T:0.008011(S:0.9888) V:0.052805(S:0.9251) LR:2.72e-06
74
- [2026-04-12 05:44:41] P2 E50/50 | T:0.007865(S:0.9890) V:0.052993(S:0.9249) LR:1.00e-06
75
- [2026-04-12 05:44:41] Phase 2 done. Best val loss: 0.052805
76
- [2026-04-12 05:44:41] Pong model: 2.3 MB
77
- [2026-04-12 05:44:41] Training complete!
 
1
+ [2026-04-12 05:59:24] Starting Pong direct 8-frame training for 2026-04-12-133000-pong-direct-v2
2
+ [2026-04-12 05:59:24] Device: cuda
3
+ [2026-04-12 05:59:25] Pong direct: 1,199,224 params (2.3 MB fp16)
4
+ [2026-04-12 05:59:25] pong train: 8194 seqs (len=16)
5
+ [2026-04-12 05:59:26] pong val: 964 seqs (len=16)
6
+ [2026-04-12 05:59:47] E1/150 | T:0.232601(S:0.6492) V:0.206121(S:0.6850) LR:3.00e-04
7
+ [2026-04-12 06:00:07] E2/150 | T:0.173923(S:0.7357) V:0.175883(S:0.7317) LR:3.00e-04
8
+ [2026-04-12 06:00:28] E3/150 | T:0.138439(S:0.7891) V:0.157651(S:0.7599) LR:3.00e-04
9
+ [2026-04-12 06:00:50] E4/150 | T:0.117363(S:0.8211) V:0.140205(S:0.7864) LR:2.99e-04
10
+ [2026-04-12 06:01:09] E5/150 | T:0.103357(S:0.8425) V:0.136901(S:0.7912) LR:2.99e-04
11
+ [2026-04-12 06:01:30] E6/150 | T:0.093950(S:0.8569) V:0.126921(S:0.8063) LR:2.99e-04
12
+ [2026-04-12 06:01:51] E7/150 | T:0.086157(S:0.8691) V:0.121485(S:0.8144) LR:2.98e-04
13
+ [2026-04-12 06:02:53] E10/150 | T:0.071323(S:0.8921) V:0.119263(S:0.8183) LR:2.97e-04
14
+ [2026-04-12 06:03:13] E11/150 | T:0.067633(S:0.8979) V:0.113872(S:0.8260) LR:2.96e-04
15
+ [2026-04-12 06:03:34] E12/150 | T:0.064540(S:0.9027) V:0.112058(S:0.8288) LR:2.95e-04
16
+ [2026-04-12 06:04:38] E15/150 | T:0.057899(S:0.9133) V:0.109369(S:0.8330) LR:2.93e-04
17
+ [2026-04-12 06:05:00] E16/150 | T:0.054467(S:0.9185) V:0.108353(S:0.8339) LR:2.92e-04
18
+ [2026-04-12 06:05:20] E17/150 | T:0.053424(S:0.9202) V:0.106683(S:0.8371) LR:2.91e-04
19
+ [2026-04-12 06:06:19] E20/150 | T:0.049003(S:0.9271) V:0.106042(S:0.8383) LR:2.87e-04
20
+ [2026-04-12 06:07:00] E22/150 | T:0.046678(S:0.9307) V:0.102919(S:0.8428) LR:2.84e-04
21
+ [2026-04-12 06:08:19] E26/150 | T:0.042844(S:0.9367) V:0.102117(S:0.8444) LR:2.78e-04
22
+ [2026-04-12 06:08:39] E27/150 | T:0.042220(S:0.9377) V:0.102082(S:0.8444) LR:2.77e-04
23
+ [2026-04-12 06:09:38] E30/150 | T:0.039701(S:0.9416) V:0.101520(S:0.8453) LR:2.71e-04
24
+ [2026-04-12 06:09:58] E31/150 | T:0.038995(S:0.9427) V:0.100438(S:0.8467) LR:2.70e-04
25
+ [2026-04-12 06:10:16] E32/150 | T:0.039062(S:0.9426) V:0.100245(S:0.8473) LR:2.68e-04
26
+ [2026-04-12 06:10:36] E33/150 | T:0.038450(S:0.9436) V:0.099826(S:0.8476) LR:2.66e-04
27
+ [2026-04-12 06:11:13] E35/150 | T:0.037610(S:0.9449) V:0.099788(S:0.8479) LR:2.62e-04
28
+ [2026-04-12 06:11:33] E36/150 | T:0.036675(S:0.9463) V:0.098966(S:0.8496) LR:2.59e-04
29
+ [2026-04-12 06:12:52] E40/150 | T:0.034903(S:0.9490) V:0.099786(S:0.8485) LR:2.51e-04
30
+ [2026-04-12 06:13:50] E43/150 | T:0.033734(S:0.9509) V:0.098349(S:0.8506) LR:2.43e-04
31
+ [2026-04-12 06:14:10] E44/150 | T:0.033505(S:0.9512) V:0.098003(S:0.8510) LR:2.41e-04
32
+ [2026-04-12 06:14:50] E46/150 | T:0.033140(S:0.9518) V:0.097767(S:0.8514) LR:2.36e-04
33
+ [2026-04-12 06:15:49] E49/150 | T:0.032213(S:0.9533) V:0.097305(S:0.8522) LR:2.28e-04
34
+ [2026-04-12 06:16:09] E50/150 | T:0.031634(S:0.9541) V:0.097047(S:0.8524) LR:2.25e-04
35
+ [2026-04-12 06:17:28] E54/150 | T:0.030719(S:0.9556) V:0.096493(S:0.8531) LR:2.14e-04
36
+ [2026-04-12 06:18:27] E57/150 | T:0.030218(S:0.9563) V:0.095355(S:0.8549) LR:2.06e-04
37
+ [2026-04-12 06:18:47] E58/150 | T:0.029896(S:0.9568) V:0.094789(S:0.8558) LR:2.03e-04
38
+ [2026-04-12 06:19:27] E60/150 | T:0.029534(S:0.9574) V:0.096519(S:0.8532) LR:1.97e-04
39
+ [2026-04-12 06:22:43] E70/150 | T:0.027838(S:0.9600) V:0.095992(S:0.8544) LR:1.66e-04
40
+ [2026-04-12 06:25:55] E80/150 | T:0.026637(S:0.9619) V:0.095754(S:0.8547) LR:1.35e-04
41
+ [2026-04-12 06:28:12] E87/150 | T:0.025836(S:0.9631) V:0.094571(S:0.8563) LR:1.13e-04
42
+ [2026-04-12 06:28:52] E89/150 | T:0.025732(S:0.9633) V:0.094485(S:0.8564) LR:1.07e-04
43
+ [2026-04-12 06:29:11] E90/150 | T:0.025655(S:0.9634) V:0.094835(S:0.8561) LR:1.04e-04
44
+ [2026-04-12 06:32:25] E100/150 | T:0.024915(S:0.9645) V:0.094847(S:0.8561) LR:7.57e-05
45
+ [2026-04-12 06:35:38] E110/150 | T:0.024432(S:0.9653) V:0.094407(S:0.8567) LR:5.05e-05
46
+ [2026-04-12 06:38:57] E120/150 | T:0.024090(S:0.9658) V:0.094761(S:0.8562) LR:2.96e-05
47
+ [2026-04-12 06:42:12] E130/150 | T:0.023891(S:0.9661) V:0.094652(S:0.8564) LR:1.39e-05
48
+ [2026-04-12 06:45:25] E140/150 | T:0.023785(S:0.9663) V:0.094651(S:0.8564) LR:4.27e-06
49
+ [2026-04-12 06:48:40] E150/150 | T:0.023744(S:0.9663) V:0.094663(S:0.8564) LR:1.00e-06
50
+ [2026-04-12 06:48:40] Done. Best val loss: 0.094407
51
+ [2026-04-12 06:48:40] Model size: 2.3 MB
52
+ [2026-04-12 06:48:40] Training complete!