Spaces:
Runtime error
Runtime error
changed app.py
Browse files
app.py
CHANGED
|
@@ -9,12 +9,12 @@ def get_model():
|
|
| 9 |
"""Load the trained GPT model."""
|
| 10 |
model = GPT(GPTConfig())
|
| 11 |
# Load from the Hugging Face Hub instead of local file
|
| 12 |
-
model_path = '
|
| 13 |
-
model.load_state_dict(torch.hub.load_state_dict_from_url(f'https://huggingface.co/{model_path}/resolve/main/
|
| 14 |
model.eval()
|
| 15 |
return model
|
| 16 |
|
| 17 |
-
def generate_text(prompt, max_tokens=500, temperature=0.
|
| 18 |
"""Generate text based on the prompt."""
|
| 19 |
# Encode the prompt
|
| 20 |
enc = tiktoken.get_encoding('gpt2')
|
|
|
|
| 9 |
"""Load the trained GPT model."""
|
| 10 |
model = GPT(GPTConfig())
|
| 11 |
# Load from the Hugging Face Hub instead of local file
|
| 12 |
+
model_path = 'mathminakshi/custom_gpt2'
|
| 13 |
+
model.load_state_dict(torch.hub.load_state_dict_from_url(f'https://huggingface.co/{model_path}/resolve/main/best_model.pth', map_location='cpu')['model_state_dict'])
|
| 14 |
model.eval()
|
| 15 |
return model
|
| 16 |
|
| 17 |
+
def generate_text(prompt, max_tokens=500, temperature=0.3, top_k=40):
|
| 18 |
"""Generate text based on the prompt."""
|
| 19 |
# Encode the prompt
|
| 20 |
enc = tiktoken.get_encoding('gpt2')
|