pszemraj's picture
Update app.py
6e289c3
raw history blame
No virus
9.36 kB
import argparse
import pprint as pp
import logging
import time
import gradio as gr
import torch
from transformers import pipeline
from utils import make_mailto_form, postprocess, clear, make_email_link
# Root logging: timestamped, INFO-and-above records for the whole app.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# True when torch can see a CUDA device; selects the pipeline device index.
use_gpu = torch.cuda.is_available()
def generate_text(
    prompt: str,
    gen_length=64,
    penalty_alpha=0.6,
    top_k=6,
    length_penalty=1.0,
    # perma params (not set by user)
    abs_max_length=512,
    verbose=False,
):
    """
    generate_text - generate text using the text generation pipeline

    :param str prompt: the prompt to use for the text generation pipeline
    :param int gen_length: the number of new tokens to generate after the prompt
    :param float penalty_alpha: the penalty alpha for the text generation pipeline (contrastive search)
    :param int top_k: the top k for the text generation pipeline (contrastive search)
    :param float length_penalty: length penalty forwarded to the pipeline call
    :param int abs_max_length: prompt length (in tokens) above which a warning is
        logged; the prompt is NOT truncated
    :param bool verbose: log the prompt, parameters, generated text and timing
    :return tuple[str, str]: (mailto form built by make_mailto_form from the email,
        the post-processed email text)
    """
    # `generator` is the module-level pipeline assigned elsewhere in this file;
    # it is only read here, so no `global` statement is required.
    # Coerce numeric UI values in case they arrive as floats: top_k in
    # particular must be an int for the top-k selection inside generation.
    gen_length = int(gen_length)
    top_k = int(top_k)
    if verbose:
        logging.info(f"Generating text from prompt:\n\n{prompt}")
        logging.info(
            f"params:\tmax_length={gen_length}, penalty_alpha={penalty_alpha}, "
            f"top_k={top_k}, length_penalty={length_penalty}"
        )
    st = time.perf_counter()

    input_len = len(generator.tokenizer(prompt)["input_ids"])
    if input_len > abs_max_length:
        # Only a warning: the prompt is still passed through unchanged.
        logging.warning(
            f"Input too long {input_len} > {abs_max_length}, may cause errors"
        )
    result = generator(
        prompt,
        # max_length / min_length include the prompt tokens ("old API" for
        # generation), so offset both by the prompt length.
        max_length=gen_length + input_len,
        min_length=input_len + 4,  # require at least a few new tokens
        penalty_alpha=penalty_alpha,
        top_k=top_k,
        length_penalty=length_penalty,
    )
    # NOTE(review): presumably this is prompt + continuation (pipeline default
    # return_full_text=True) — confirm postprocess expects the full text.
    response = result[0]["generated_text"]
    rt = time.perf_counter() - st
    if verbose:
        logging.info(f"Generated text: {response}")
        logging.info(f"Generation time: {rt:.2f}s")

    formatted_email = postprocess(response)
    return make_mailto_form(body=formatted_email), formatted_email
def load_emailgen_model(model_tag: str):
    """
    load_emailgen_model - load a text generation pipeline and make it the active model

    The new pipeline replaces the module-level ``generator`` that
    ``generate_text`` reads, so later generations use this model.

    Args:
        model_tag (str): the huggingface model tag to load

    Returns:
        None: the pipeline is stored in the global ``generator`` rather than
        returned (the "Load Model" button is wired with ``outputs=[]``).
    """
    global generator
    generator = pipeline(
        "text-generation",
        model_tag,
        device=0 if use_gpu else -1,  # 0 = first CUDA device, -1 = CPU
    )
    logging.info(f"Model loaded: {model_tag}")
def get_parser():
    """
    get_parser - build the command-line argument parser for the demo

    Returns:
        argparse.ArgumentParser: parser accepting --model, --max_length,
        --penalty_alpha, --top_k and --verbose
    """
    parser = argparse.ArgumentParser(
        description="Text Generation demo for postbot",
    )
    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="postbot/distilgpt2-emailgen-V2",
        help="Pass a different huggingface model tag to use a custom model",
    )
    parser.add_argument(
        "-l",
        "--max_length",
        required=False,
        type=int,
        default=40,
        help="default max length of the generated text",
    )
    parser.add_argument(
        "-a",
        "--penalty_alpha",
        type=float,
        default=0.6,
        help="The penalty alpha for the text generation pipeline (contrastive search) - default 0.6",
    )
    parser.add_argument(
        "-k",
        "--top_k",
        type=int,
        default=6,
        help="The top k for the text generation pipeline (contrastive search) - default 6",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        required=False,
        action="store_true",
        help="Verbose output",
    )
    return parser
# Text pre-filled into the "Email Prompt" box; it stops mid-sentence ("... I")
# so the model has something to continue.
default_prompt = "\nHello,\nFollowing up on last week's bubblegum shipment, I"

# Hugging Face model tags offered in the "Choose a model" dropdown.
available_models = [
    "postbot/distilgpt2-emailgen-V2",
    "postbot/distilgpt2-emailgen",
    "postbot/gpt2-medium-emailgen",
    "postbot/pythia-160m-hq-emails",
]
if __name__ == "__main__":
    logging.info("\n\n\nStarting new instance of app.py")
    parser = get_parser()
    args = parser.parse_args()
    logging.info(f"received args:\t{args}")
    model_tag = args.model
    verbose = args.verbose
    max_length = args.max_length
    top_k = args.top_k
    alpha = args.penalty_alpha
    # Validate with parser.error rather than `assert`: asserts are stripped
    # when Python runs with -O, while parser.error always prints usage and exits.
    if top_k <= 0:
        parser.error("top_k must be greater than 0")
    if not 0.0 <= alpha <= 1.0:
        parser.error("penalty_alpha must be between 0 and 1")

    logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
    # Same loader the "Load Model" button calls; it assigns the module-level
    # `generator` that generate_text reads.
    load_emailgen_model(model_tag)

    def generate_from_ui(prompt, n_tokens, alpha_value, k_value, len_penalty):
        """Forward the UI values to generate_text, honoring the --verbose flag."""
        return generate_text(
            prompt,
            gen_length=n_tokens,
            penalty_alpha=alpha_value,
            top_k=k_value,
            length_penalty=len_penalty,
            verbose=verbose,
        )

    demo = gr.Blocks()
    logging.info("launching interface...")
    with demo:
        gr.Markdown("# Auto-Complete Emails - Demo")
        gr.Markdown(
            "Enter part of an email, and a text-gen model will complete it! See details below. "
        )
        gr.Markdown("---")
        with gr.Column():
            gr.Markdown("## Generate Text")
            gr.Markdown("Edit the prompt and parameters and press **Generate**!")
            prompt_text = gr.Textbox(
                lines=4,
                label="Email Prompt",
                value=default_prompt,
            )
            with gr.Row():
                clear_button = gr.Button(
                    value="Clear Prompt",
                )
                num_gen_tokens = gr.Slider(
                    label="Generation Tokens",
                    value=max_length,
                    maximum=96,
                    minimum=16,
                    step=8,
                )
                generate_button = gr.Button(
                    value="Generate!",
                    variant="primary",
                )
            gr.Markdown("---")
            gr.Markdown("### Results")
            # put a large HTML placeholder here
            generated_email = gr.Textbox(
                label="Generated Text",
                placeholder="This is where the generated text will appear",
                interactive=False,
            )
            email_mailto_button = gr.HTML(
                "<i>a clickable email button will appear here</i>"
            )
            gr.Markdown("---")
            gr.Markdown("## Advanced Options")
            gr.Markdown(
                "This demo generates text via the new [contrastive search](https://huggingface.co/blog/introducing-csearch). See the csearch blog post for details on the parameters or [here](https://huggingface.co/blog/how-to-generate), for general decoding."
            )
            with gr.Row():
                model_name = gr.Dropdown(
                    choices=available_models,
                    label="Choose a model",
                    value=model_tag,
                )
                load_model_button = gr.Button(
                    "Load Model",
                    variant="secondary",
                )
            with gr.Row():
                contrastive_top_k = gr.Radio(
                    choices=[2, 4, 6, 8],
                    label="Top K",
                    value=top_k,
                )
                penalty_alpha = gr.Slider(
                    label="Penalty Alpha",
                    value=alpha,
                    maximum=1.0,
                    minimum=0.0,
                    step=0.1,
                )
                length_penalty = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    label="Length Penalty",
                    value=1.0,
                    step=0.1,
                )
        gr.Markdown("---")
        with gr.Column():
            gr.Markdown("## About")
            gr.Markdown(
                "[This model](https://huggingface.co/postbot/distilgpt2-emailgen) is a fine-tuned version of distilgpt2 on a dataset of 100k emails sourced from the internet, including the classic `aeslc` dataset.\n\nCheck out the model card for details on notebook & command line usage."
            )
            gr.Markdown(
                "The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails from scratch; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements **before** accepting/sending something."
            )
            gr.Markdown("---")
        # Event listeners must be registered inside the Blocks context.
        clear_button.click(
            fn=clear,
            inputs=[prompt_text],
            outputs=[prompt_text],
        )
        generate_button.click(
            fn=generate_from_ui,
            # order must match generate_from_ui's positional parameters
            inputs=[
                prompt_text,
                num_gen_tokens,
                penalty_alpha,
                contrastive_top_k,
                length_penalty,
            ],
            outputs=[email_mailto_button, generated_email],
        )
        load_model_button.click(
            fn=load_emailgen_model,
            inputs=[model_name],
            outputs=[],
        )
    # NOTE(review): `enable_queue` is a launch() argument only in older Gradio
    # releases (newer ones use demo.queue()); confirm against the pinned version.
    demo.launch(
        enable_queue=True,
        share=True,  # for local testing
    )