Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -78,6 +78,12 @@ def ensure_complete_output(output, context, max_length, temperature, top_k, top_
|
|
78 |
|
79 |
# Text generation function for Gradio interface
|
80 |
def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
with torch.no_grad():
|
82 |
with ctx:
|
83 |
start_ids = encode(prompt)
|
|
|
78 |
|
79 |
# Text generation function for Gradio interface
|
80 |
def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
|
81 |
+
# Add input validation
|
82 |
+
if num_samples is None:
|
83 |
+
num_samples = 1
|
84 |
+
elif not isinstance(num_samples, int):
|
85 |
+
raise ValueError("Number of samples must be an integer")
|
86 |
+
|
87 |
with torch.no_grad():
|
88 |
with ctx:
|
89 |
start_ids = encode(prompt)
|