peter szemraj committed
Commit
81d65e8
1 Parent(s): d9f8cf2

:tada: init

Files changed (2)
  1. .gitignore +21 -0
  2. app.py +182 -0
.gitignore ADDED
@@ -0,0 +1,21 @@
+
+ # ignore gradio db files / sys files
+ *__pycache__*
+ *__pycache__/
+
+ # data
+
+ *.txt
+ *.pkl
+ *flagged/
+
+ # ignore log files
+ *.log
+ *logs/
+
+ # scratch
+ *scratch/
+ *scratch*
+
+ # notebooks
+ *notebooks/
app.py ADDED
@@ -0,0 +1,182 @@
+ import argparse
+ import logging
+ import time
+ import gradio as gr
+ import torch
+ from transformers import pipeline
+
+ logging.basicConfig(
+     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+ )
+
+ use_gpu = torch.cuda.is_available()
+
+ def generate_text(
+     prompt: str,
+     gen_length=64,
+     num_beams=4,
+     no_repeat_ngram_size=2,
+     length_penalty=1.0,
+     # perma params (not set by user)
+     repetition_penalty=3.5,
+     abs_max_length=512,
+     verbose=False,
+ ):
+     """
+     generate_text - generate text from a prompt using a text generation pipeline
+
+     Args:
+         prompt (str): the prompt to generate text from
+         gen_length (int, optional): the number of new tokens to generate. Defaults to 64.
+         num_beams (int, optional): the number of beams for beam search. Defaults to 4.
+         no_repeat_ngram_size (int, optional): the size of n-grams that may not repeat. Defaults to 2.
+         length_penalty (float, optional): the length penalty for beam search. Defaults to 1.0.
+         verbose (bool, optional): the verbosity of the output. Defaults to False.
+
+     Returns:
+         str: the generated text
+     """
+     global generator
+     logging.info(f"Generating text from prompt: {prompt}")
+     st = time.perf_counter()
+
+     input_tokens = generator.tokenizer(prompt)
+     input_len = len(input_tokens["input_ids"])
+     if input_len > abs_max_length:
+         logging.warning(f"Input too long {input_len} > {abs_max_length}, may cause errors")
+     result = generator(
+         prompt,
+         max_length=gen_length + input_len,
+         min_length=input_len + 4,
+         num_beams=num_beams,
+         repetition_penalty=repetition_penalty,
+         no_repeat_ngram_size=no_repeat_ngram_size,
+         length_penalty=length_penalty,
+         do_sample=False,
+         early_stopping=True,
+         truncation=True,  # tokenizer kwarg: truncate over-length inputs
+     )  # generate
+     response = result[0]["generated_text"]
+     rt = time.perf_counter() - st
+     if verbose:
+         logging.info(f"Generated text: {response}")
+     logging.info(f"Generation time: {rt:.2f}s")
+     return response
+
+
+ def get_parser():
+     """
+     get_parser - a helper function for the argparse module
+     """
+     parser = argparse.ArgumentParser(
+         description="Text Generation demo for postbot",
+     )
+
+     parser.add_argument(
+         "-m",
+         "--model",
+         required=False,
+         type=str,
+         default="postbot/distilgpt2-emailgen",
+         help="Pass a different huggingface model tag to use a custom model",
+     )
+
+     parser.add_argument(
+         "-v",
+         "--verbose",
+         required=False,
+         action="store_true",
+         help="Verbose output",
+     )
+     return parser
+
+
+ default_prompt = """
+ Hello,
+
+ Following up on the bubblegum shipment."""
+
+ if __name__ == "__main__":
+     logging.info("\n\n\nStarting new instance of app.py")
+     args = get_parser().parse_args()
+     logging.info(f"received args:\t{args}")
+     model_tag = args.model
+     verbose = args.verbose
+     logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
+     generator = pipeline(
+         "text-generation",
+         model_tag,
+         device=0 if use_gpu else -1,
+     )
+
+     demo = gr.Blocks()
+
+     logging.info("launching interface...")
+
+     with demo:
+         gr.Markdown("# Autocompleting Emails with Textgen - Demo")
+         gr.Markdown(
+             "Enter part of an email, and the model will autocomplete it for you!"
+         )
+         gr.Markdown(
+             "The model used is [postbot/distilgpt2-emailgen](https://huggingface.co/postbot/distilgpt2-emailgen)"
+         )
+         gr.Markdown("---")
+
+         with gr.Column():
+             gr.Markdown("## Generate Text")
+             gr.Markdown(
+                 "Enter/edit the prompt and adjust the parameters as needed. Then press the Generate button!"
+             )
+             prompt_text = gr.Textbox(
+                 lines=4,
+                 label="Email Prompt",
+                 value=default_prompt,
+             )
+             num_gen_tokens = gr.Slider(
+                 label="Generation Tokens",
+                 value=64,
+                 minimum=32,
+                 maximum=128,
+                 step=16,
+             )
+             num_beams = gr.Radio(
+                 choices=[4, 8, 16],
+                 label="num beams",
+                 value=4,
+             )
+             no_repeat_ngram_size = gr.Radio(
+                 choices=[1, 2, 3, 4],
+                 label="no repeat ngram size",
+                 value=2,
+             )
+             length_penalty = gr.Slider(
+                 minimum=0.5, maximum=1.0, label="length penalty", value=0.8, step=0.05
+             )
+             generated_email = gr.Textbox(
+                 label="Generated Result", placeholder="The completed email will appear here"
+             )
+
+             generate_button = gr.Button(
+                 "Generate!",
+             )
+             gr.Markdown("---")
+
+         with gr.Column():
+             gr.Markdown("## About")
+             gr.Markdown(
+                 "This model is a fine-tuned version of distilgpt2 on a dataset of 50k emails sourced from the internet, including the classic `aeslc` dataset."
+             )
+             gr.Markdown(
+                 "The intended use of this model is to provide suggestions that _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails; at least **some input** is required to guide the model's direction.\n\nPlease check the model's suggestions for (a) false claims and (b) negation statements before accepting/sending anything."
+             )
+             gr.Markdown("---")
+
+         generate_button.click(
+             fn=generate_text,
+             inputs=[
+                 prompt_text,
+                 num_gen_tokens,
+                 num_beams,
+                 no_repeat_ngram_size,
+                 length_penalty,
+             ],
+             outputs=[generated_email],
+         )
+
+     demo.launch(
+         enable_queue=True,
+     )
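
For reference, the demo is launched locally with `python app.py` (optionally `-m <model_tag>` to swap in another checkpoint and `-v` for verbose logging). Below is a minimal sketch of exercising the same generation step outside the Gradio UI, mirroring the beam-search defaults hard-coded in `generate_text`; it assumes `transformers` and `torch` are installed, and the prompt string is illustrative only.

# minimal sketch: run the autocomplete step without the Gradio UI
# (model tag and generation kwargs taken from app.py's defaults)
from transformers import pipeline

generator = pipeline(
    "text-generation",
    "postbot/distilgpt2-emailgen",
    device=-1,  # CPU; use device=0 for the first GPU
)

prompt = "Hello,\n\nFollowing up on the bubblegum shipment."
input_len = len(generator.tokenizer(prompt)["input_ids"])

result = generator(
    prompt,
    max_length=64 + input_len,  # gen_length + prompt length, as in generate_text
    min_length=input_len + 4,
    num_beams=4,
    repetition_penalty=3.5,
    no_repeat_ngram_size=2,
    length_penalty=1.0,
    do_sample=False,
    early_stopping=True,
    truncation=True,
)
print(result[0]["generated_text"])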