Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -29,7 +29,10 @@ enc = tiktoken.Encoding(
|
|
29 |
# Load model from checkpoint
|
30 |
model_save_path = 'tuned_ckpt.pt'
|
31 |
if os.path.exists(model_save_path):
|
32 |
-
|
|
|
|
|
|
|
33 |
else:
|
34 |
raise FileNotFoundError(f"Model file {model_save_path} not found")
|
35 |
|
|
|
29 |
# Load model from checkpoint
|
30 |
model_save_path = 'tuned_ckpt.pt'
|
31 |
if os.path.exists(model_save_path):
|
32 |
+
checkpoint = torch.load(model_save_path, map_location=device)
|
33 |
+
gptconf = GPTConfig(**checkpoint['model_args'])
|
34 |
+
model = GPT(gptconf)
|
35 |
+
model.load_state_dict(checkpoint['model'])
|
36 |
else:
|
37 |
raise FileNotFoundError(f"Model file {model_save_path} not found")
|
38 |
|