Spaces:
Runtime error
Runtime error
peter szemraj
committed on
Commit
•
d2d61ad
1
Parent(s):
2979040
improve hyperparams
Browse files
- ai_single_response.py +18 -13
- app.py +1 -0
ai_single_response.py
CHANGED
@@ -115,8 +115,10 @@ def query_gpt_model(
|
|
115 |
kparam=150,
|
116 |
temp=0.75,
|
117 |
top_p=0.65,
|
|
|
118 |
verbose=False,
|
119 |
use_gpu=False,
|
|
|
120 |
):
|
121 |
"""
|
122 |
query_gpt_model - the main function that calls the model.
|
@@ -153,19 +155,22 @@ def query_gpt_model(
|
|
153 |
# call the model
|
154 |
print("\n... generating...")
|
155 |
this_result = ai.generate(
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
169 |
if verbose:
|
170 |
print("\n... generated:\n")
|
171 |
pp.pprint(this_result) # for debugging
|
|
|
115 |
kparam=150,
|
116 |
temp=0.75,
|
117 |
top_p=0.65,
|
118 |
+
batch_size=64,
|
119 |
verbose=False,
|
120 |
use_gpu=False,
|
121 |
+
beams=2,
|
122 |
):
|
123 |
"""
|
124 |
query_gpt_model - the main function that calls the model.
|
|
|
155 |
# call the model
|
156 |
print("\n... generating...")
|
157 |
this_result = ai.generate(
|
158 |
+
n=1,
|
159 |
+
top_k=kparam,
|
160 |
+
batch_size=batch_size,
|
161 |
+
# the prompt input counts for text length constraints
|
162 |
+
max_length=resp_length + pr_len,
|
163 |
+
min_length=resp_min + pr_len,
|
164 |
+
prompt=this_prompt,
|
165 |
+
temperature=temp,
|
166 |
+
top_p=top_p,
|
167 |
+
top_k=kparam,
|
168 |
+
do_sample=True,
|
169 |
+
return_as_list=True,
|
170 |
+
use_cache=True,
|
171 |
+
num_beams=beams,
|
172 |
+
no_repeat_ngram_size=2,
|
173 |
+
)
|
174 |
if verbose:
|
175 |
print("\n... generated:\n")
|
176 |
pp.pprint(this_result) # for debugging
|
app.py
CHANGED
@@ -73,6 +73,7 @@ def ask_gpt(message: str):
|
|
73 |
kparam=150,
|
74 |
temp=0.75,
|
75 |
top_p=0.65,
|
|
|
76 |
)
|
77 |
if basic_sc:
|
78 |
cln_resp = symspeller(resp["out_text"], sym_checker=schnellspell)
|
|
|
73 |
kparam=150,
|
74 |
temp=0.75,
|
75 |
top_p=0.65,
|
76 |
+
beams=4,
|
77 |
)
|
78 |
if basic_sc:
|
79 |
cln_resp = symspeller(resp["out_text"], sym_checker=schnellspell)
|