p1atdev committed
Commit 51d0877
1 Parent(s): 13f05ca

chore: use default generate function

Files changed (1): app.py (+33, -3)

app.py CHANGED
@@ -31,19 +31,31 @@ def generate(
     temperature=1.0,
     top_p=0.95,
     top_k=20,
+    no_repeat_ngram_size=3,
+    num_beams=1,
 ):
     if input_text.strip() == "":
         return ""
 
     inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
-    generated = model.custom_generate(
+    # generated = model.custom_generate(
+    #     **inputs,
+    #     parallel_compute_prompt=True,
+    #     max_new_tokens=max_new_tokens,
+    #     do_sample=do_sample,
+    #     temperature=temperature,
+    #     top_p=top_p,
+    #     top_k=top_k,
+    # )
+    generated = model.generate(
         **inputs,
-        parallel_compute_prompt=True,
         max_new_tokens=max_new_tokens,
         do_sample=do_sample,
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
+        no_repeat_ngram_size=no_repeat_ngram_size,
+        num_beams=num_beams,
     )
     return tokenizer.batch_decode(generated)[0]
 
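For context, model.custom_generate appears to be model-specific remote code (note its parallel_compute_prompt flag, which the stock API does not accept); the commit keeps that call as a comment and falls back to the standard transformers model.generate. A minimal standalone sketch of the equivalent call, using a placeholder checkpoint and prompt rather than whatever this Space actually loads:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint for illustration only; the Space loads its own model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello", return_tensors="pt", add_special_tokens=False)
generated = model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    temperature=1.0,
    top_p=0.95,
    top_k=20,
    no_repeat_ngram_size=3,  # never emit the same 3-token sequence twice
    num_beams=1,             # 1 keeps plain sampling; >1 switches to beam search
)
print(tokenizer.batch_decode(generated)[0])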
@@ -97,7 +109,7 @@ with gr.Blocks() as demo:
             label="Max tokens",
             minimum=8,
             maximum=512,
-            value=128,
+            value=64,
             step=4,
         )
         do_sample = gr.Checkbox(
@@ -125,6 +137,20 @@ with gr.Blocks() as demo:
             value=20,
             step=1,
         )
+        no_repeat_ngram_size = gr.Slider(
+            label="No repeat ngram size",
+            minimum=0,
+            maximum=10,
+            value=3,
+            step=1,
+        )
+        num_beams = gr.Slider(
+            label="Num beams",
+            minimum=1,
+            maximum=8,
+            value=1,
+            step=1,
+        )
 
     gr.Examples(
         examples=EXAMPLE_INPUTS,
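The two new sliders surface these decoding knobs in the UI; the remaining hunks thread them through to the event handlers. A hedged, self-contained sketch of the same wiring pattern (the layout and handler body are simplified; only the slider names, ranges, and defaults come from the diff):

import gradio as gr

# Stand-in for the app's generate(); the real decoding call is omitted.
def generate(input_text, no_repeat_ngram_size, num_beams):
    return f"{input_text} (ngram_block={no_repeat_ngram_size}, beams={num_beams})"

with gr.Blocks() as demo:
    input_text = gr.Textbox(label="Input")
    # Ranges mirror the diff: 0 disables the n-gram block, 1 beam means sampling.
    no_repeat_ngram_size = gr.Slider(
        label="No repeat ngram size", minimum=0, maximum=10, value=3, step=1
    )
    num_beams = gr.Slider(label="Num beams", minimum=1, maximum=8, value=1, step=1)
    output_text = gr.Textbox(label="Output")
    generate_btn = gr.Button("Generate")
    # Gradio binds inputs positionally, so this list must follow the order of
    # generate()'s parameters.
    generate_btn.click(
        generate,
        inputs=[input_text, no_repeat_ngram_size, num_beams],
        outputs=output_text,
    )

demo.launch()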
@@ -140,6 +166,8 @@ with gr.Blocks() as demo:
             temperature,
             top_p,
             top_k,
+            no_repeat_ngram_size,
+            num_beams,
         ],
         outputs=output_text,
         queue=False,
@@ -153,6 +181,8 @@ with gr.Blocks() as demo:
             temperature,
             top_p,
             top_k,
+            no_repeat_ngram_size,
+            num_beams,
         ],
         outputs=[input_text, output_text],
         queue=False,
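Because of that positional binding, both handlers append the two components at the same position the new parameters occupy in generate()'s signature, right after top_k; the second handler additionally rewrites input_text, hence its two outputs.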
 