peter szemraj committed
Commit e58a27a
1 Parent(s): d2d61ad

:art: format code to black

Files changed (1):
    ai_single_response.py (+18 -18)
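For context, black is an opinionated Python formatter, so a commit like this normally changes only whitespace and layout, not behavior. Below is a minimal sketch of the effect using black's Python API; the commit itself was presumably produced with the CLI (e.g. `black ai_single_response.py`), and the exact command, like the one-liner input string, is an illustrative assumption rather than the file's actual pre-commit text.

    # Minimal sketch, assuming the `black` package is installed; the input
    # string is an illustrative one-liner, not this file's real pre-commit code.
    import black

    src = 'ai = aitextgen( model="pszemraj/Ballpark-Trivia-L", to_gpu=use_gpu,)\n'
    print(black.format_str(src, mode=black.Mode()))
    # The magic trailing comma keeps the call exploded, one argument per line:
    # ai = aitextgen(
    #     model="pszemraj/Ballpark-Trivia-L",
    #     to_gpu=use_gpu,
    # )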
ai_single_response.py CHANGED
@@ -122,7 +122,7 @@ def query_gpt_model(
 ):
     """
     query_gpt_model - the main function that calls the model.
-
+
     Parameters:
     -----------
     prompt_msg (str): the prompt to be sent to the model
@@ -137,7 +137,7 @@ def query_gpt_model(
     use_gpu (bool, optional): use gpu. Defaults to False.
     """
     ai = aitextgen(
-        model="pszemraj/Ballpark-Trivia-L", # THIS WORKS
+        model="pszemraj/Ballpark-Trivia-L", # THIS WORKS
         # model="pszemraj/Ballpark-Trivia-XL", # does not seem to work
         to_gpu=use_gpu,
     )
@@ -155,22 +155,22 @@ def query_gpt_model(
     # call the model
     print("\n... generating...")
     this_result = ai.generate(
-        n=1,
-        top_k=kparam,
-        batch_size=batch_size,
-        # the prompt input counts for text length constraints
-        max_length=resp_length + pr_len,
-        min_length=resp_min + pr_len,
-        prompt=this_prompt,
-        temperature=temp,
-        top_p=top_p,
-        top_k=kparam,
-        do_sample=True,
-        return_as_list=True,
-        use_cache=True,
-        num_beams=beams,
-        no_repeat_ngram_size=2,
-    )
+        n=1,
+        top_k=kparam,
+        batch_size=batch_size,
+        # the prompt input counts for text length constraints
+        max_length=resp_length + pr_len,
+        min_length=resp_min + pr_len,
+        prompt=this_prompt,
+        temperature=temp,
+        top_p=top_p,
+        top_k=kparam,
+        do_sample=True,
+        return_as_list=True,
+        use_cache=True,
+        num_beams=beams,
+        no_repeat_ngram_size=2,
+    )
     if verbose:
         print("\n... generated:\n")
         pp.pprint(this_result) # for debugging
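For readers skimming the hunks: the reformatted call is aitextgen's generate() wrapper around a Hugging Face checkpoint. Here is a minimal, self-contained sketch of the same call, with illustrative constants standing in for the function's locals (kparam, temp, beams, resp_length + pr_len, etc.). Note that the hunk as rendered passes top_k twice, which would be a repeated-keyword SyntaxError in Python, so the sketch passes it once and omits batch_size.

    # Minimal sketch, assuming the aitextgen package; values are illustrative
    # stand-ins for the function's local variables, not the repo's defaults.
    from aitextgen import aitextgen

    ai = aitextgen(
        model="pszemraj/Ballpark-Trivia-L",
        to_gpu=False,  # use_gpu in the original
    )

    result = ai.generate(
        n=1,
        prompt="how many wings does a bird have?\n",  # hypothetical this_prompt
        max_length=128,  # resp_length + pr_len: the prompt counts toward limits
        min_length=4,  # resp_min + pr_len
        temperature=0.75,  # temp
        top_p=0.65,
        top_k=20,  # kparam, passed once here
        do_sample=True,
        return_as_list=True,
        use_cache=True,
        num_beams=1,  # beams
        no_repeat_ngram_size=2,
    )
    print(result[0])

The sampling setup (do_sample=True with top_p/top_k, plus no_repeat_ngram_size=2) trades determinism for variety while suppressing short verbatim loops, a reasonable fit for a trivia-chat model like this one.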