Nick088 committed
Commit a81c6ef
1 parent: 244f082

Update app.py

Files changed (1):
  1. app.py +10 -1
app.py CHANGED
@@ -21,6 +21,7 @@ def generate(
     temperature,
     top_p,
     top_k,
+    seed_checkbox,
     seed,
     model_path="roborovski/superprompt-v1",
     dtype="fp16",
@@ -40,7 +41,9 @@ def generate(
     input_text = f"{prompt}, {history}"
     input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
 
-    torch.manual_seed(seed)
+    if seed_checkbox:
+        torch.manual_seed(seed)
+
     outputs = model.generate(
         input_ids,
         max_new_tokens=max_new_tokens,
@@ -100,6 +103,11 @@ additional_inputs = [
         label="Top K",
         info="Higher k means more diverse outputs by considering a range of tokens",
     ),
+    gr.Checkbox(
+        value=False,
+        label="Use Random Seed",
+        info="Check to use a random seed for the generation process",
+    ),
     gr.Number(
         value=42,
         interactive=True,
@@ -123,6 +131,7 @@ examples = [
     None,
     None,
     None,
+    False,
     None,
     "roborovski/superprompt-v1",
     "fp16",