peter szemraj commited on
Commit
88b1e11
1 Parent(s): 91d1162
Files changed (1) hide show
  1. app.py +25 -14
app.py CHANGED
@@ -11,6 +11,7 @@ logging.basicConfig(
11
 
12
  use_gpu = torch.cuda.is_available()
13
 
 
14
  def generate_text(
15
  prompt: str,
16
  gen_length=64,
@@ -40,7 +41,7 @@ def generate_text(
40
  st = time.perf_counter()
41
 
42
  input_tokens = generator.tokenizer(prompt)
43
- input_len = len(input_tokens['input_ids'])
44
  if input_len > abs_max_length:
45
  logging.info(f"Input too long {input_len} > {abs_max_length}, may cause errors")
46
  result = generator(
@@ -55,9 +56,8 @@ def generate_text(
55
  early_stopping=True,
56
  # tokenizer
57
  truncation=True,
58
-
59
- ) # generate
60
- response = result[0]['generated_text']
61
  rt = time.perf_counter() - st
62
  if verbose:
63
  logging.info(f"Generated text: {response}")
@@ -74,12 +74,12 @@ def get_parser():
74
  )
75
 
76
  parser.add_argument(
77
- '-m',
78
- '--model',
79
  required=False,
80
  type=str,
81
  default="postbot/distilgpt2-emailgen",
82
- help='Pass an different huggingface model tag to use a custom model',
83
  )
84
 
85
  parser.add_argument(
@@ -91,6 +91,7 @@ def get_parser():
91
  )
92
  return parser
93
 
 
94
  default_prompt = """
95
  Hello,
96
 
@@ -109,7 +110,6 @@ if __name__ == "__main__":
109
  device=0 if use_gpu else -1,
110
  )
111
 
112
-
113
  demo = gr.Blocks()
114
 
115
  logging.info("launching interface...")
@@ -119,7 +119,9 @@ if __name__ == "__main__":
119
  gr.Markdown(
120
  "Enter part of an email, and the model will autocomplete it for you!"
121
  )
122
- gr.Markdown('The model used is [postbot/distilgpt2-emailgen](https://huggingface.co/postbot/distilgpt2-emailgen)')
 
 
123
  gr.Markdown("---")
124
 
125
  with gr.Column():
@@ -151,10 +153,11 @@ if __name__ == "__main__":
151
  value=2,
152
  )
153
  length_penalty = gr.Slider(
154
- minimum=0.5, maximum=1.0, label="length penalty", default=0.8, step=0.05
155
  )
156
  generated_email = gr.Textbox(
157
- label="Generated Result", placeholder="The completed email will appear here"
 
158
  )
159
 
160
  generate_button = gr.Button(
@@ -168,16 +171,24 @@ if __name__ == "__main__":
168
  gr.Markdown(
169
  "This model is a fine-tuned version of distilgpt2 on a dataset of 50k emails sourced from the internet, including the classic `aeslc` dataset."
170
  )
171
- gr.Markdown("The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements before accepting/sending something.")
 
 
172
  gr.Markdown("---")
173
 
174
  generate_button.click(
175
  fn=generate_text,
176
- inputs=[prompt_text, num_gen_tokens, num_beams, no_repeat_ngram_size, length_penalty],
 
 
 
 
 
 
177
  outputs=[generated_email],
178
  )
179
 
180
  demo.launch(
181
  enable_queue=True,
182
- share=True, # for local testing
183
  )
 
11
 
12
  use_gpu = torch.cuda.is_available()
13
 
14
+
15
  def generate_text(
16
  prompt: str,
17
  gen_length=64,
 
41
  st = time.perf_counter()
42
 
43
  input_tokens = generator.tokenizer(prompt)
44
+ input_len = len(input_tokens["input_ids"])
45
  if input_len > abs_max_length:
46
  logging.info(f"Input too long {input_len} > {abs_max_length}, may cause errors")
47
  result = generator(
 
56
  early_stopping=True,
57
  # tokenizer
58
  truncation=True,
59
+ ) # generate
60
+ response = result[0]["generated_text"]
 
61
  rt = time.perf_counter() - st
62
  if verbose:
63
  logging.info(f"Generated text: {response}")
 
74
  )
75
 
76
  parser.add_argument(
77
+ "-m",
78
+ "--model",
79
  required=False,
80
  type=str,
81
  default="postbot/distilgpt2-emailgen",
82
+ help="Pass an different huggingface model tag to use a custom model",
83
  )
84
 
85
  parser.add_argument(
 
91
  )
92
  return parser
93
 
94
+
95
  default_prompt = """
96
  Hello,
97
 
 
110
  device=0 if use_gpu else -1,
111
  )
112
 
 
113
  demo = gr.Blocks()
114
 
115
  logging.info("launching interface...")
 
119
  gr.Markdown(
120
  "Enter part of an email, and the model will autocomplete it for you!"
121
  )
122
+ gr.Markdown(
123
+ "The model used is [postbot/distilgpt2-emailgen](https://huggingface.co/postbot/distilgpt2-emailgen)"
124
+ )
125
  gr.Markdown("---")
126
 
127
  with gr.Column():
 
153
  value=2,
154
  )
155
  length_penalty = gr.Slider(
156
+ minimum=0.5, maximum=1.0, label="length penalty", default=0.8, step=0.05
157
  )
158
  generated_email = gr.Textbox(
159
+ label="Generated Result",
160
+ placeholder="The completed email will appear here",
161
  )
162
 
163
  generate_button = gr.Button(
 
171
  gr.Markdown(
172
  "This model is a fine-tuned version of distilgpt2 on a dataset of 50k emails sourced from the internet, including the classic `aeslc` dataset."
173
  )
174
+ gr.Markdown(
175
+ "The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements before accepting/sending something."
176
+ )
177
  gr.Markdown("---")
178
 
179
  generate_button.click(
180
  fn=generate_text,
181
+ inputs=[
182
+ prompt_text,
183
+ num_gen_tokens,
184
+ num_beams,
185
+ no_repeat_ngram_size,
186
+ length_penalty,
187
+ ],
188
  outputs=[generated_email],
189
  )
190
 
191
  demo.launch(
192
  enable_queue=True,
193
+ share=True, # for local testing
194
  )