saintyboy commited on
Commit
48efaf5
1 Parent(s): 542e671

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -78,6 +78,12 @@ def ensure_complete_output(output, context, max_length, temperature, top_k, top_
78
 
79
  # Text generation function for Gradio interface
80
  def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
 
 
 
 
 
 
81
  with torch.no_grad():
82
  with ctx:
83
  start_ids = encode(prompt)
 
78
 
79
  # Text generation function for Gradio interface
80
  def generate_text(prompt, num_samples, max_new_tokens, temperature, top_k, top_p, repetition_penalty, eor_token_id):
81
+ # Add input validation
82
+ if num_samples is None:
83
+ num_samples = 1
84
+ elif not isinstance(num_samples, int):
85
+ raise ValueError("Number of samples must be an integer")
86
+
87
  with torch.no_grad():
88
  with ctx:
89
  start_ids = encode(prompt)