peter szemraj commited on
Commit
d2d61ad
1 Parent(s): 2979040

improve hyperparams

Browse files
Files changed (2) hide show
  1. ai_single_response.py +18 -13
  2. app.py +1 -0
ai_single_response.py CHANGED
@@ -115,8 +115,10 @@ def query_gpt_model(
115
  kparam=150,
116
  temp=0.75,
117
  top_p=0.65,
 
118
  verbose=False,
119
  use_gpu=False,
 
120
  ):
121
  """
122
  query_gpt_model - the main function that calls the model.
@@ -153,19 +155,22 @@ def query_gpt_model(
153
  # call the model
154
  print("\n... generating...")
155
  this_result = ai.generate(
156
- n=1,
157
- top_k=kparam,
158
- batch_size=128,
159
- # the prompt input counts for text length constraints
160
- max_length=resp_length + pr_len,
161
- min_length=resp_min + pr_len,
162
- prompt=this_prompt,
163
- temperature=temp,
164
- top_p=top_p,
165
- do_sample=True,
166
- return_as_list=True,
167
- use_cache=True,
168
- )
 
 
 
169
  if verbose:
170
  print("\n... generated:\n")
171
  pp.pprint(this_result) # for debugging
 
115
  kparam=150,
116
  temp=0.75,
117
  top_p=0.65,
118
+ batch_size=64,
119
  verbose=False,
120
  use_gpu=False,
121
+ beams=2,
122
  ):
123
  """
124
  query_gpt_model - the main function that calls the model.
 
155
  # call the model
156
  print("\n... generating...")
157
  this_result = ai.generate(
158
+ n=1,
159
+ top_k=kparam,
160
+ batch_size=batch_size,
161
+ # the prompt input counts for text length constraints
162
+ max_length=resp_length + pr_len,
163
+ min_length=resp_min + pr_len,
164
+ prompt=this_prompt,
165
+ temperature=temp,
166
+ top_p=top_p,
167
+ top_k=kparam,
168
+ do_sample=True,
169
+ return_as_list=True,
170
+ use_cache=True,
171
+ num_beams=beams,
172
+ no_repeat_ngram_size=2,
173
+ )
174
  if verbose:
175
  print("\n... generated:\n")
176
  pp.pprint(this_result) # for debugging
app.py CHANGED
@@ -73,6 +73,7 @@ def ask_gpt(message: str):
73
  kparam=150,
74
  temp=0.75,
75
  top_p=0.65,
 
76
  )
77
  if basic_sc:
78
  cln_resp = symspeller(resp["out_text"], sym_checker=schnellspell)
 
73
  kparam=150,
74
  temp=0.75,
75
  top_p=0.65,
76
+ beams=4,
77
  )
78
  if basic_sc:
79
  cln_resp = symspeller(resp["out_text"], sym_checker=schnellspell)