Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -138,6 +138,9 @@ try:
|
|
| 138 |
)
|
| 139 |
state_dict = load_file(model_path, device=device)
|
| 140 |
model.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
|
|
| 141 |
model_loaded = True
|
| 142 |
print(f"✅ Model loaded successfully from SafeTensors: {repo_id}")
|
| 143 |
except Exception as e:
|
|
@@ -148,7 +151,9 @@ try:
|
|
| 148 |
filename="model_checkpoint_final.pt",
|
| 149 |
cache_dir=None
|
| 150 |
)
|
| 151 |
-
|
|
|
|
|
|
|
| 152 |
|
| 153 |
# Handle different checkpoint formats
|
| 154 |
if 'model_state_dict' in checkpoint:
|
|
@@ -168,7 +173,8 @@ try:
|
|
| 168 |
filename="model_checkpoint_final.pt",
|
| 169 |
cache_dir=None
|
| 170 |
)
|
| 171 |
-
|
|
|
|
| 172 |
|
| 173 |
# Handle different checkpoint formats
|
| 174 |
if 'model_state_dict' in checkpoint:
|
|
@@ -185,7 +191,8 @@ try:
|
|
| 185 |
print(f"⚠️ Could not load from Hub ({e}), trying local file...")
|
| 186 |
try:
|
| 187 |
# Fallback to local file
|
| 188 |
-
|
|
|
|
| 189 |
if 'model_state_dict' in checkpoint:
|
| 190 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 191 |
elif 'state_dict' in checkpoint:
|
|
@@ -369,5 +376,6 @@ with gr.Blocks(title="GPT-2 124M Shakespeare Model") as demo:
|
|
| 369 |
""")
|
| 370 |
|
| 371 |
if __name__ == "__main__":
|
| 372 |
-
|
|
|
|
| 373 |
|
|
|
|
| 138 |
)
|
| 139 |
state_dict = load_file(model_path, device=device)
|
| 140 |
model.load_state_dict(state_dict)
|
| 141 |
+
# Restore weight sharing (broken during SafeTensors conversion)
|
| 142 |
+
# lm_head.weight and transformer.wte.weight should share memory
|
| 143 |
+
model.transformer.wte.weight = model.lm_head.weight
|
| 144 |
model_loaded = True
|
| 145 |
print(f"✅ Model loaded successfully from SafeTensors: {repo_id}")
|
| 146 |
except Exception as e:
|
|
|
|
| 151 |
filename="model_checkpoint_final.pt",
|
| 152 |
cache_dir=None
|
| 153 |
)
|
| 154 |
+
# PyTorch 2.6+ changed torch.load's default to weights_only=True; a checkpoint that pickles custom classes must be loaded with weights_only=False
|
| 155 |
+
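# weights_only=False lets the pickle run arbitrary code, so this is acceptable only because the checkpoint is our own trained model from a trusted repo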
|
| 156 |
+
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
|
| 157 |
|
| 158 |
# Handle different checkpoint formats
|
| 159 |
if 'model_state_dict' in checkpoint:
|
|
|
|
| 173 |
filename="model_checkpoint_final.pt",
|
| 174 |
cache_dir=None
|
| 175 |
)
|
| 176 |
+
# PyTorch 2.6+ changed torch.load's default to weights_only=True; a checkpoint that pickles custom classes must be loaded with weights_only=False
|
| 177 |
+
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
|
| 178 |
|
| 179 |
# Handle different checkpoint formats
|
| 180 |
if 'model_state_dict' in checkpoint:
|
|
|
|
| 191 |
print(f"⚠️ Could not load from Hub ({e}), trying local file...")
|
| 192 |
try:
|
| 193 |
# Fallback to local file
|
| 194 |
+
# PyTorch 2.6+ changed torch.load's default to weights_only=True; a checkpoint that pickles custom classes must be loaded with weights_only=False
|
| 195 |
+
checkpoint = torch.load('model_checkpoint_final.pt', map_location=device, weights_only=False)
|
| 196 |
if 'model_state_dict' in checkpoint:
|
| 197 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 198 |
elif 'state_dict' in checkpoint:
|
|
|
|
| 376 |
""")
|
| 377 |
|
| 378 |
if __name__ == "__main__":
|
| 379 |
+
# Don't use share=True on Hugging Face Spaces: the Space already serves the app publicly, and Gradio share links are not supported there
|
| 380 |
+
demo.launch()
|
| 381 |
|