Spaces:
Runtime error
Runtime error
peter szemraj
commited on
Commit
•
e58a27a
1
Parent(s):
d2d61ad
:art: format code to black
Browse files- ai_single_response.py +18 -18
ai_single_response.py
CHANGED
@@ -122,7 +122,7 @@ def query_gpt_model(
|
|
122 |
):
|
123 |
"""
|
124 |
query_gpt_model - the main function that calls the model.
|
125 |
-
|
126 |
Parameters:
|
127 |
-----------
|
128 |
prompt_msg (str): the prompt to be sent to the model
|
@@ -137,7 +137,7 @@ def query_gpt_model(
|
|
137 |
use_gpu (bool, optional): use gpu. Defaults to False.
|
138 |
"""
|
139 |
ai = aitextgen(
|
140 |
-
model="pszemraj/Ballpark-Trivia-L",
|
141 |
# model="pszemraj/Ballpark-Trivia-XL", # does not seem to work
|
142 |
to_gpu=use_gpu,
|
143 |
)
|
@@ -155,22 +155,22 @@ def query_gpt_model(
|
|
155 |
# call the model
|
156 |
print("\n... generating...")
|
157 |
this_result = ai.generate(
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
if verbose:
|
175 |
print("\n... generated:\n")
|
176 |
pp.pprint(this_result) # for debugging
|
|
|
122 |
):
|
123 |
"""
|
124 |
query_gpt_model - the main function that calls the model.
|
125 |
+
|
126 |
Parameters:
|
127 |
-----------
|
128 |
prompt_msg (str): the prompt to be sent to the model
|
|
|
137 |
use_gpu (bool, optional): use gpu. Defaults to False.
|
138 |
"""
|
139 |
ai = aitextgen(
|
140 |
+
model="pszemraj/Ballpark-Trivia-L", # THIS WORKS
|
141 |
# model="pszemraj/Ballpark-Trivia-XL", # does not seem to work
|
142 |
to_gpu=use_gpu,
|
143 |
)
|
|
|
155 |
# call the model
|
156 |
print("\n... generating...")
|
157 |
this_result = ai.generate(
|
158 |
+
n=1,
|
159 |
+
top_k=kparam,
|
160 |
+
batch_size=batch_size,
|
161 |
+
# the prompt input counts for text length constraints
|
162 |
+
max_length=resp_length + pr_len,
|
163 |
+
min_length=resp_min + pr_len,
|
164 |
+
prompt=this_prompt,
|
165 |
+
temperature=temp,
|
166 |
+
top_p=top_p,
|
167 |
+
top_k=kparam,
|
168 |
+
do_sample=True,
|
169 |
+
return_as_list=True,
|
170 |
+
use_cache=True,
|
171 |
+
num_beams=beams,
|
172 |
+
no_repeat_ngram_size=2,
|
173 |
+
)
|
174 |
if verbose:
|
175 |
print("\n... generated:\n")
|
176 |
pp.pprint(this_result) # for debugging
|