Update app.py
app.py CHANGED
@@ -129,8 +129,6 @@ def load_model(model_path):
     model.to(device)
     return model
 
-# Don't load the model here
-# model = load_model('gpt_model.pth')
 enc = tiktoken.get_encoding('gpt2')
 
 # Update the generate_text function
@@ -166,6 +164,14 @@ async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
     if len(generated) == max_length:
         yield "... (output truncated due to length)"
 
+# Add the gradio_generate function
+@spaces.GPU(duration=60)
+async def gradio_generate(prompt, max_length, temperature, top_k):
+    output = ""
+    async for token in generate_text(prompt, max_length, temperature, top_k):
+        output += token
+        yield output
+
 # # Your existing imports and model code here...
 
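For context, here is a minimal sketch of how the streaming generator added in this commit might be exposed as a Gradio app on Spaces. Only gradio_generate, its parameter defaults (max_length=432, temperature=0.8, top_k=40), and the @spaces.GPU decorator come from the diff above; the gr.Interface wiring, component labels, and slider ranges are assumptions for illustration, not part of this commit.

import gradio as gr

# Assumed UI wiring (not from the commit): Gradio streams every partial
# `output` string yielded by the async generator into the output textbox.
demo = gr.Interface(
    fn=gradio_generate,  # async generator defined in the diff above
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(16, 432, value=432, step=1, label="Max length"),
        gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(1, 100, value=40, step=1, label="Top-k"),
    ],
    outputs=gr.Textbox(label="Generated text"),
)

if __name__ == "__main__":
    demo.queue().launch()  # queue() enables streaming of generator output

Note the design choice in the diff: gradio_generate yields the accumulated output string on every token rather than each token alone, so the UI can simply replace the textbox contents on each update and the user sees the text grow in place.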