pszemraj commited on
Commit
01bf561
2 Parent(s): a0d66c9 16157b3

Merge branch 'main' of https://huggingface.co/spaces/postbot/autocomplete-emails

Browse files
Files changed (2) hide show
  1. app.py +33 -8
  2. utils.py +1 -0
app.py CHANGED
@@ -23,6 +23,7 @@ def generate_text(
23
  # perma params (not set by user)
24
  repetition_penalty=3.5,
25
  abs_max_length=512,
 
26
  verbose=False,
27
  ):
28
  """
@@ -55,6 +56,7 @@ def generate_text(
55
  max_length=gen_length + input_len,
56
  min_length=input_len + 4,
57
  num_beams=num_beams,
 
58
  repetition_penalty=repetition_penalty,
59
  no_repeat_ngram_size=no_repeat_ngram_size,
60
  length_penalty=length_penalty,
@@ -70,7 +72,8 @@ def generate_text(
70
  formatted_email = postprocess(response)
71
  return make_mailto_form(body=formatted_email)
72
 
73
- def load_emailgen_model(model_tag:str):
 
74
  """
75
  load_emailgen_model - load a text generation pipeline for email generation
76
 
@@ -87,6 +90,7 @@ def load_emailgen_model(model_tag:str):
87
  device=0 if use_gpu else -1,
88
  )
89
 
 
90
  def get_parser():
91
  """
92
  get_parser - a helper function for the argparse module
@@ -111,6 +115,21 @@ def get_parser():
111
  action="store_true",
112
  help="Verbose output",
113
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  return parser
115
 
116
 
@@ -119,7 +138,11 @@ Hello,
119
 
120
  Following up on last week's bubblegum shipment, I"""
121
 
122
- available_models = ['postbot/distilgpt2-emailgen-V2', 'postbot/distilgpt2-emailgen', 'postbot/gpt2-medium-emailgen']
 
 
 
 
123
 
124
  if __name__ == "__main__":
125
  logging.info("\n\n\nStarting new instance of app.py")
@@ -189,16 +212,18 @@ if __name__ == "__main__":
189
  value=model_tag,
190
  )
191
  load_model_button = gr.Button(
192
- 'Load Model',
193
- variant='secondary',
194
  )
195
 
196
  with gr.Row():
197
  num_beams = gr.Radio(
198
- choices=[4, 8, 12, 16],
199
- label="Number of Beams",
200
- value=4,
201
- )
 
 
202
  no_repeat_ngram_size = gr.Radio(
203
  choices=[1, 2, 3, 4],
204
  label="no repeat ngram size",
 
23
  # perma params (not set by user)
24
  repetition_penalty=3.5,
25
  abs_max_length=512,
26
+ num_beam_groups=2,
27
  verbose=False,
28
  ):
29
  """
 
56
  max_length=gen_length + input_len,
57
  min_length=input_len + 4,
58
  num_beams=num_beams,
59
+ num_beam_groups=num_beam_groups,
60
  repetition_penalty=repetition_penalty,
61
  no_repeat_ngram_size=no_repeat_ngram_size,
62
  length_penalty=length_penalty,
 
72
  formatted_email = postprocess(response)
73
  return make_mailto_form(body=formatted_email)
74
 
75
+
76
+ def load_emailgen_model(model_tag: str):
77
  """
78
  load_emailgen_model - load a text generation pipeline for email generation
79
 
 
90
  device=0 if use_gpu else -1,
91
  )
92
 
93
+
94
  def get_parser():
95
  """
96
  get_parser - a helper function for the argparse module
 
115
  action="store_true",
116
  help="Verbose output",
117
  )
118
+
119
+ parser.add_argument(
120
+ "-nb",
121
+ "--num_beams",
122
+ type=int,
123
+ default=4,
124
+ help="Number of beams for beam search. 1 means no beam search.",
125
+ )
126
+
127
+ parser.add_argument(
128
+ "--num_beam_groups",
129
+ type=int,
130
+ default=2,
131
+ help="Number of groups to divide nbest candidates into in order to ensure diversity among different groups of beams that yield the best n results. 1 means no group beam search.",
132
+ )
133
  return parser
134
 
135
 
 
138
 
139
  Following up on last week's bubblegum shipment, I"""
140
 
141
+ available_models = [
142
+ "postbot/distilgpt2-emailgen-V2",
143
+ "postbot/distilgpt2-emailgen",
144
+ "postbot/gpt2-medium-emailgen",
145
+ ]
146
 
147
  if __name__ == "__main__":
148
  logging.info("\n\n\nStarting new instance of app.py")
 
212
  value=model_tag,
213
  )
214
  load_model_button = gr.Button(
215
+ "Load Model",
216
+ variant="secondary",
217
  )
218
 
219
  with gr.Row():
220
  num_beams = gr.Radio(
221
+ choices=[4, 8, 12, 16],
222
+ label="Number of Beams",
223
+ value=4,
224
+ )
225
+ with gr.Row():
226
+
227
  no_repeat_ngram_size = gr.Radio(
228
  choices=[1, 2, 3, 4],
229
  label="no repeat ngram size",
utils.py CHANGED
@@ -4,6 +4,7 @@
4
  import logging
5
  import re
6
 
 
7
  def postprocess(text: str):
8
  """
9
  postprocess - remove common values in scraped dataset
 
4
  import logging
5
  import re
6
 
7
+
8
  def postprocess(text: str):
9
  """
10
  postprocess - remove common values in scraped dataset