Upload folder using huggingface_hub
Browse files- __pycache__/predict.cpython-311.pyc +0 -0
- model_pole_position.pt +2 -2
- model_pong_direct.pt +2 -2
- predict.py +6 -7
__pycache__/predict.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/predict.cpython-311.pyc and b/__pycache__/predict.cpython-311.pyc differ
|
|
|
model_pole_position.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8e0affcef8e533a29037751e27948a3eb0f2fda2792ce2b3dfc876cadb09e281
|
| 3 |
+
size 2971526
|
model_pong_direct.pt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:381e3abbf7c5c308e985f53a916379140be588383760fa117009775b6bc79281
|
| 3 |
+
size 1262522
|
predict.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""
|
| 2 |
import sys
|
| 3 |
import os
|
| 4 |
import numpy as np
|
|
@@ -65,13 +65,12 @@ def load_model(model_dir: str):
|
|
| 65 |
pong.eval()
|
| 66 |
ens.models["pong"] = pong
|
| 67 |
|
| 68 |
-
# Pong direct (
|
| 69 |
pong_direct = UNet(in_channels=24, out_channels=24,
|
| 70 |
enc_channels=(32, 64, 128), bottleneck_channels=128,
|
| 71 |
upsample_mode="bilinear").to(DEVICE)
|
| 72 |
-
sd =
|
| 73 |
-
|
| 74 |
-
pong_direct.load_state_dict({k: v.float() for k, v in sd.items()})
|
| 75 |
pong_direct.eval()
|
| 76 |
ens.pong_direct = pong_direct
|
| 77 |
|
|
@@ -94,9 +93,9 @@ def load_model(model_dir: str):
|
|
| 94 |
sonic_direct.eval()
|
| 95 |
ens.sonic_direct = sonic_direct
|
| 96 |
|
| 97 |
-
# PP
|
| 98 |
pp = UNet(in_channels=24, out_channels=24,
|
| 99 |
-
enc_channels=(
|
| 100 |
upsample_mode="bilinear").to(DEVICE)
|
| 101 |
sd = torch.load(os.path.join(model_dir, "model_pole_position.pt"),
|
| 102 |
map_location=DEVICE, weights_only=True)
|
|
|
|
| 1 |
+
"""Full PP swap: Pong direct int8, full PP model, Sonic AR fp16 + direct int8."""
|
| 2 |
import sys
|
| 3 |
import os
|
| 4 |
import numpy as np
|
|
|
|
| 65 |
pong.eval()
|
| 66 |
ens.models["pong"] = pong
|
| 67 |
|
| 68 |
+
# Pong direct (int8 quantized, 24 outputs)
|
| 69 |
pong_direct = UNet(in_channels=24, out_channels=24,
|
| 70 |
enc_channels=(32, 64, 128), bottleneck_channels=128,
|
| 71 |
upsample_mode="bilinear").to(DEVICE)
|
| 72 |
+
sd = load_int8_state_dict(os.path.join(model_dir, "model_pong_direct.pt"), DEVICE)
|
| 73 |
+
pong_direct.load_state_dict(sd)
|
|
|
|
| 74 |
pong_direct.eval()
|
| 75 |
ens.pong_direct = pong_direct
|
| 76 |
|
|
|
|
| 93 |
sonic_direct.eval()
|
| 94 |
ens.sonic_direct = sonic_direct
|
| 95 |
|
| 96 |
+
# PP full direct (fp16, 24 outputs)
|
| 97 |
pp = UNet(in_channels=24, out_channels=24,
|
| 98 |
+
enc_channels=(32, 64, 128), bottleneck_channels=192,
|
| 99 |
upsample_mode="bilinear").to(DEVICE)
|
| 100 |
sd = torch.load(os.path.join(model_dir, "model_pole_position.pt"),
|
| 101 |
map_location=DEVICE, weights_only=True)
|