Roblox22r pszemraj commited on
Commit
f590842
·
0 Parent(s):

Duplicate from postbot/autocomplete-emails

Browse files

Co-authored-by: Peter Szemraj <pszemraj@users.noreply.huggingface.co>

Files changed (6) hide show
  1. .gitattributes +31 -0
  2. .gitignore +21 -0
  3. README.md +19 -0
  4. app.py +297 -0
  5. requirements.txt +3 -0
  6. utils.py +165 -0
.gitattributes ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # ignore gradio db files# sys files
3
+ *__pycache__*
4
+ *__pycache__/
5
+
6
+ # data
7
+
8
+ *.txt
9
+ *.pkl
10
+ *flagged/
11
+
12
+ # ignore log files
13
+ *.log
14
+ *logs/
15
+
16
+ # scratch
17
+ *scratch/
18
+ *scratch*
19
+
20
+ # notebooks
21
+ *notebooks/
README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Autocomplete Emails
3
+ emoji: 📨
4
+ colorFrom: gray
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.1.4
8
+ app_file: app.py
9
+ pinned: true
10
+ license: apache-2.0
11
+ tags:
12
+ - email
13
+ - autocomplete
14
+ - text generation
15
+ - contrastive search
16
+ duplicated_from: postbot/autocomplete-emails
17
+ ---
18
+
19
+ Check out the configuration reference at <https://huggingface.co/docs/hub/spaces-config-reference>
app.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import pprint as pp
3
+ import logging
4
+ import time
5
+ import gradio as gr
6
+ import torch
7
+ from transformers import pipeline
8
+
9
+ from utils import make_mailto_form, postprocess, clear, make_email_link
10
+
11
+ logging.basicConfig(
12
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
13
+ )
14
+
15
+ use_gpu = torch.cuda.is_available()
16
+
17
+
18
+ def generate_text(
19
+ prompt: str,
20
+ gen_length=64,
21
+ penalty_alpha=0.6,
22
+ top_k=6,
23
+ no_repeat_ngram_size=2,
24
+ length_penalty=1.0,
25
+ # perma params (not set by user)
26
+ abs_max_length=512,
27
+ verbose=False,
28
+ ):
29
+ """
30
+ generate_text - generate text from a prompt using a text generation pipeline
31
+
32
+ Args:
33
+ prompt (str): the prompt to generate text from
34
+ model_input (_type_): the text generation pipeline
35
+ max_length (int, optional): the maximum length of the generated text. Defaults to 128.
36
+ method (str, optional): the generation method. Defaults to "Sampling".
37
+ verbose (bool, optional): the verbosity of the output. Defaults to False.
38
+
39
+ Returns:
40
+ str: the generated text
41
+ """
42
+ global generator
43
+ if verbose:
44
+ logging.info(f"Generating text from prompt:\n\n{prompt}")
45
+ logging.info(
46
+ pp.pformat(
47
+ f"params:\tmax_length={gen_length}, penalty_alpha={penalty_alpha}, top_k={top_k}, no_repeat_ngram_size={no_repeat_ngram_size}, length_penalty={length_penalty}"
48
+ )
49
+ )
50
+ st = time.perf_counter()
51
+
52
+ input_tokens = generator.tokenizer(prompt)
53
+ input_len = len(input_tokens["input_ids"])
54
+ if input_len > abs_max_length:
55
+ logging.info(f"Input too long {input_len} > {abs_max_length}, may cause errors")
56
+ result = generator(
57
+ prompt,
58
+ max_length=gen_length + input_len,
59
+ min_length=input_len + 4,
60
+ penalty_alpha=penalty_alpha,
61
+ top_k=top_k,
62
+ no_repeat_ngram_size=no_repeat_ngram_size,
63
+ length_penalty=length_penalty,
64
+ ) # generate
65
+ response = result[0]["generated_text"]
66
+ rt = time.perf_counter() - st
67
+ if verbose:
68
+ logging.info(f"Generated text: {response}")
69
+ rt_string = f"Generation time: {rt:.2f}s"
70
+ logging.info(rt_string)
71
+
72
+ formatted_email = postprocess(response)
73
+ return make_mailto_form(body=formatted_email), formatted_email
74
+
75
+
76
+ def load_emailgen_model(model_tag: str):
77
+ """
78
+ load_emailgen_model - load a text generation pipeline for email generation
79
+
80
+ Args:
81
+ model_tag (str): the huggingface model tag to load
82
+
83
+ Returns:
84
+ transformers.pipelines.TextGenerationPipeline: the text generation pipeline
85
+ """
86
+ global generator
87
+ generator = pipeline(
88
+ "text-generation",
89
+ model_tag,
90
+ device=0 if use_gpu else -1,
91
+ )
92
+
93
+
94
+ def get_parser():
95
+ """
96
+ get_parser - a helper function for the argparse module
97
+ """
98
+ parser = argparse.ArgumentParser(
99
+ description="Text Generation demo for postbot",
100
+ )
101
+
102
+ parser.add_argument(
103
+ "-m",
104
+ "--model",
105
+ required=False,
106
+ type=str,
107
+ default="postbot/distilgpt2-emailgen-V2",
108
+ help="Pass an different huggingface model tag to use a custom model",
109
+ )
110
+
111
+ parser.add_argument(
112
+ "-v",
113
+ "--verbose",
114
+ required=False,
115
+ action="store_true",
116
+ help="Verbose output",
117
+ )
118
+
119
+ parser.add_argument(
120
+ "-a",
121
+ "--penalty_alpha",
122
+ type=float,
123
+ default=0.6,
124
+ help="The penalty alpha for the text generation pipeline (contrastive search) - default 0.6",
125
+ )
126
+
127
+ parser.add_argument(
128
+ "-k",
129
+ "--top_k",
130
+ type=int,
131
+ default=6,
132
+ help="The top k for the text generation pipeline (contrastive search) - default 6",
133
+ )
134
+ return parser
135
+
136
+
137
+ default_prompt = """
138
+ Hello,
139
+
140
+ Following up on last week's bubblegum shipment, I"""
141
+
142
+ available_models = [
143
+ "postbot/distilgpt2-emailgen-V2",
144
+ "postbot/distilgpt2-emailgen",
145
+ "postbot/gpt2-medium-emailgen",
146
+ ]
147
+
148
+ if __name__ == "__main__":
149
+
150
+ logging.info("\n\n\nStarting new instance of app.py")
151
+ args = get_parser().parse_args()
152
+ logging.info(f"received args:\t{args}")
153
+ model_tag = args.model
154
+ verbose = args.verbose
155
+ top_k = args.top_k
156
+ alpha = args.penalty_alpha
157
+
158
+ assert top_k > 0, "top_k must be greater than 0"
159
+ assert alpha >= 0.0 and alpha <= 1.0, "penalty_alpha must be between 0 and 1"
160
+
161
+ logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
162
+ generator = pipeline(
163
+ "text-generation",
164
+ model_tag,
165
+ device=0 if use_gpu else -1,
166
+ )
167
+
168
+ demo = gr.Blocks()
169
+
170
+ logging.info("launching interface...")
171
+
172
+ with demo:
173
+ gr.Markdown("# Auto-Complete Emails - Demo")
174
+ gr.Markdown(
175
+ "Enter part of an email, and a text-gen model will complete it! See details below. "
176
+ )
177
+ gr.Markdown("---")
178
+
179
+ with gr.Column():
180
+
181
+ gr.Markdown("## Generate Text")
182
+ gr.Markdown("Edit the prompt and parameters and press **Generate**!")
183
+ prompt_text = gr.Textbox(
184
+ lines=4,
185
+ label="Email Prompt",
186
+ value=default_prompt,
187
+ )
188
+
189
+ with gr.Row():
190
+ clear_button = gr.Button(
191
+ value="Clear Prompt",
192
+ )
193
+ num_gen_tokens = gr.Slider(
194
+ label="Generation Tokens",
195
+ value=32,
196
+ maximum=96,
197
+ minimum=16,
198
+ step=8,
199
+ )
200
+
201
+ generate_button = gr.Button(
202
+ value="Generate!",
203
+ variant="primary",
204
+ )
205
+ gr.Markdown("---")
206
+ gr.Markdown("### Results")
207
+ # put a large HTML placeholder here
208
+ generated_email = gr.Textbox(
209
+ label="Generated Text",
210
+ placeholder="This is where the generated text will appear",
211
+ interactive=False,
212
+ )
213
+ email_mailto_button = gr.HTML(
214
+ "<i>a clickable email button will appear here</i>"
215
+ )
216
+
217
+ gr.Markdown("---")
218
+ gr.Markdown("## Advanced Options")
219
+ gr.Markdown(
220
+ "This demo generates text via the new [constrastive search](https://huggingface.co/blog/introducing-csearch). See details on the csearch blog post for the methods' parameters or [here](https://huggingface.co/blog/how-to-generate), for general decoding."
221
+ )
222
+ with gr.Row():
223
+ model_name = gr.Dropdown(
224
+ choices=available_models,
225
+ label="Choose a model",
226
+ value=model_tag,
227
+ )
228
+ load_model_button = gr.Button(
229
+ "Load Model",
230
+ variant="secondary",
231
+ )
232
+ no_repeat_ngram_size = gr.Radio(
233
+ choices=[1, 2, 3, 4],
234
+ label="no repeat ngram size",
235
+ value=3,
236
+ )
237
+ with gr.Row():
238
+ contrastive_top_k = gr.Radio(
239
+ choices=[2, 4, 6, 8],
240
+ label="Top K",
241
+ value=top_k,
242
+ )
243
+
244
+ penalty_alpha = gr.Slider(
245
+ label="Penalty Alpha",
246
+ value=alpha,
247
+ maximum=1.0,
248
+ minimum=0.0,
249
+ step=0.1,
250
+ )
251
+ length_penalty = gr.Slider(
252
+ minimum=0.5,
253
+ maximum=1.0,
254
+ label="Length Penalty",
255
+ value=1.0,
256
+ step=0.1,
257
+ )
258
+ gr.Markdown("---")
259
+
260
+ with gr.Column():
261
+
262
+ gr.Markdown("## About")
263
+ gr.Markdown(
264
+ "[This model](https://huggingface.co/postbot/distilgpt2-emailgen) is a fine-tuned version of distilgpt2 on a dataset of 100k emails sourced from the internet, including the classic `aeslc` dataset.\n\nCheck out the model card for details on notebook & command line usage."
265
+ )
266
+ gr.Markdown(
267
+ "The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails from scratch; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements **before** accepting/sending something."
268
+ )
269
+ gr.Markdown("---")
270
+
271
+ clear_button.click(
272
+ fn=clear,
273
+ inputs=[prompt_text],
274
+ outputs=[prompt_text],
275
+ )
276
+ generate_button.click(
277
+ fn=generate_text,
278
+ inputs=[
279
+ prompt_text,
280
+ num_gen_tokens,
281
+ penalty_alpha,
282
+ contrastive_top_k,
283
+ no_repeat_ngram_size,
284
+ length_penalty,
285
+ ],
286
+ outputs=[email_mailto_button, generated_email],
287
+ )
288
+
289
+ load_model_button.click(
290
+ fn=load_emailgen_model,
291
+ inputs=[model_name],
292
+ outputs=[],
293
+ )
294
+ demo.launch(
295
+ enable_queue=True,
296
+ share=True, # for local testing
297
+ )
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ torch
3
+ transformers>=4.24.0
utils.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ utils.py - Utility functions for the project.
3
+ """
4
+ import logging
5
+ import re
6
+
7
+
8
+ def postprocess(text: str):
9
+ """
10
+ postprocess - remove common values in scraped dataset
11
+
12
+ Args:
13
+ text (str): the text to postprocess
14
+ """
15
+
16
+ replacements = {
17
+ "ENA": "COMPANY",
18
+ "Enron": "COMPANY",
19
+ "Enron Corporation": "COMPANY",
20
+ "Sony Pictures Entertainment": "COMPANY",
21
+ "Columbia Pictures": "COMPANY",
22
+ "Sony": "COMPANY",
23
+ "Columbia": "COMPANY",
24
+ "Hillary": "Jane",
25
+ "Clinton": "Smith",
26
+ "Amy": "Jane",
27
+ "Sara": "Jane",
28
+ "Harambe": "Jane",
29
+ "Pascal": "PERSON",
30
+ }
31
+
32
+ # replace common values, also check lowercase
33
+ for k, v in replacements.items():
34
+ text = text.replace(k, v)
35
+ text = text.replace(k.lower(), v)
36
+
37
+ return text
38
+
39
+
40
+ def clear(text, verbose=False, **kwargs):
41
+ """for use with buttons"""
42
+ if verbose:
43
+ logging.info(f"Clearing text: {text}")
44
+ return ""
45
+
46
+
47
+ def make_email_link(
48
+ subject: str = "Email subject - This was generated by Postbot",
49
+ link_text: str = "click to open in your email client",
50
+ body: str = None,
51
+ tag_placeholder: str = "PLACEHOLDER",
52
+ ):
53
+ """
54
+ email_link - generate an email link
55
+
56
+ Args:
57
+ subject (str, optional): the subject of the email. Defaults to "Email subject - This was generated by Postbot".
58
+ link_text (str, optional): the text of the link. Defaults to "click to open in your email client".
59
+ body (str, optional): the body of the email. Defaults to None.
60
+ tag_placeholder (str, optional): the placeholder for the tag. Defaults to "PLACEHOLDER".
61
+
62
+ Returns:
63
+ str: the email link, in the form of an html link
64
+ """
65
+
66
+ if body is None:
67
+ body = "hmm - no body. replace me"
68
+
69
+ # strip brackets and other HTML-tag characters from body with regex
70
+ body = re.sub(r"<[^>]*>", tag_placeholder, body)
71
+
72
+ # replace all newline chars with a whitespace
73
+ body = body.replace("\n", " ")
74
+
75
+ nice_html_button = f"""<!DOCTYPE html>
76
+ <html>
77
+ <head>
78
+ <title>Generated Email</title>
79
+ <style>
80
+ body {{
81
+ font-family: sans-serif;
82
+ font-size: 1.2em;
83
+ }}
84
+ .button {{
85
+ background-color: #6CCEC6;
86
+ border: none;
87
+ color: white;
88
+ padding: 15px 32px;
89
+ text-align: center;
90
+ text-decoration: none;
91
+ display: inline-block;
92
+ font-size: 16px;
93
+ margin: 4px 2px;
94
+ cursor: pointer;
95
+ value: "Send Email";
96
+ }}
97
+ </style>
98
+ <button class="button" onclick="window.location.href='mailto:?subject={subject}&body={body}'">{link_text} value="Open in Email client"</button>
99
+ </html>"""
100
+
101
+ # return f'<a href="mailto:%20?subject={subject}&body={body}">{link_text}</a>'
102
+ return nice_html_button
103
+
104
+
105
+ def make_mailto_form(
106
+ body: str = None,
107
+ subject: str = "This email was generated by Postbot with AI!",
108
+ cc_email: str = "",
109
+ ):
110
+ """Returns a mailto link with the given parameters"""
111
+
112
+ if body is None:
113
+ body = "hmm - no body. Replace me or try rerunning the model."
114
+
115
+ template = f"""<!DOCTYPE html>
116
+ <html>
117
+ <head>
118
+ <title>Generated Email</title>
119
+ <style>
120
+ body {{
121
+ font-family: sans-serif;
122
+ font-size: 1.2em;
123
+ }}
124
+ .button {{
125
+ background-color: #6CCEC6;
126
+ border: none;
127
+ color: white;
128
+ padding: 15px 32px;
129
+ text-align: center;
130
+ text-decoration: none;
131
+ display: inline-block;
132
+ font-size: 16px;
133
+ margin: 4px 2px;
134
+ cursor: pointer;
135
+ value: "Send Email";
136
+ }}
137
+ </style>
138
+ </head>
139
+ <body>
140
+ <h1>Adjust and Open in your mail client:</h1>
141
+ <form action="mailto:" method="get" enctype="text/plain">
142
+
143
+ <div>
144
+ <label for="cc">CC Email:
145
+ <input type="text" name="cc" id="cc" value="{cc_email}"/>
146
+ </label>
147
+ </div>
148
+ <div>
149
+ <label for="subject">Subject:
150
+ <input type="text" name="subject" id="subject" value="{subject}"/>
151
+ </label>
152
+ </div>
153
+ <div>
154
+ <label>Email Body:</label>
155
+ <br />
156
+ <textarea name="body" id="body" rows="12" cols="35">{body}</textarea>
157
+ </div>
158
+ <div>
159
+ <input type="submit" name="submit" value="Open in Email App" class="button"/>
160
+ </div>
161
+ </form>
162
+ </body>
163
+ </html>"""
164
+
165
+ return template