# Author: peter szemraj
# Commit 74f30e9: ":art: format"
import argparse
import logging
import time
import gradio as gr
import torch
from transformers import pipeline
from utils import postprocess, clear
# Root logging: INFO level and above, each record prefixed with a timestamp and level.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# True when a CUDA device is visible; used to choose the pipeline device below.
use_gpu = torch.cuda.is_available()
def generate_text(
    prompt: str,
    gen_length=64,
    num_beams=4,
    no_repeat_ngram_size=2,
    length_penalty=1.0,
    # perma params (not set by user)
    repetition_penalty=3.5,
    abs_max_length=512,
    verbose=False,
):
    """
    generate_text - complete a prompt with the module-level text-generation pipeline

    Decoding is deterministic beam search (``do_sample=False``). The pipeline's
    ``generated_text`` is passed through ``postprocess`` before being returned.
    Relies on a module-level ``generator`` pipeline being defined before the call.

    Args:
        prompt (str): the prompt (partial email) to generate text from
        gen_length (int, optional): number of tokens to generate beyond the prompt. Defaults to 64.
        num_beams (int, optional): number of beams for beam search. Defaults to 4.
        no_repeat_ngram_size (int, optional): size of n-grams that may not repeat. Defaults to 2.
        length_penalty (float, optional): length penalty passed to beam search. Defaults to 1.0.
        repetition_penalty (float, optional): penalty for repeated tokens. Defaults to 3.5.
        abs_max_length (int, optional): prompt length (in tokens) above which a warning is logged. Defaults to 512.
        verbose (bool, optional): log the prompt, parameters, output and timing. Defaults to False.

    Returns:
        str: the postprocessed generated text
    """
    # UI widgets may hand numeric values over as floats; the length / count
    # arguments below must be integers for generation.
    gen_length = int(gen_length)
    num_beams = int(num_beams)
    no_repeat_ngram_size = int(no_repeat_ngram_size)
    length_penalty = float(length_penalty)

    if verbose:
        logging.info(f"Generating text from prompt:\n\n{prompt}")
        logging.info(
            f"params:\tmax_length={gen_length}, num_beams={num_beams}, no_repeat_ngram_size={no_repeat_ngram_size}, length_penalty={length_penalty}, repetition_penalty={repetition_penalty}, abs_max_length={abs_max_length}"
        )
    st = time.perf_counter()

    # `generator` is the module-level pipeline; it is only read here, so no
    # `global` declaration is required.
    input_len = len(generator.tokenizer(prompt)["input_ids"])
    if input_len > abs_max_length:
        # Not fatal, but likely to break generation -> warning, not info.
        logging.warning(
            f"Input too long {input_len} > {abs_max_length}, may cause errors"
        )

    result = generator(
        prompt,
        # Both limits are offset by the prompt length so that gen_length
        # controls only the newly generated tokens.
        max_length=gen_length + input_len,
        min_length=input_len + 4,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        length_penalty=length_penalty,
        do_sample=False,
        early_stopping=True,
        # tokenizer
        truncation=True,
    )  # generate
    response = result[0]["generated_text"]
    rt = time.perf_counter() - st
    if verbose:
        logging.info(f"Generated text: {response}")
        logging.info(f"Generation time: {rt:.2f}s")
    return postprocess(response)
def get_parser() -> argparse.ArgumentParser:
    """
    get_parser - build the command-line parser for the email auto-complete demo

    Returns:
        argparse.ArgumentParser: parser accepting ``-m/--model`` (Hugging Face
        model tag, default ``postbot/distilgpt2-emailgen``) and ``-v/--verbose``
        (boolean flag, default False).
    """
    parser = argparse.ArgumentParser(
        description="Text Generation demo for postbot",
    )
    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="postbot/distilgpt2-emailgen",
        help="Pass a different huggingface model tag to use a custom model",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        required=False,
        action="store_true",
        help="Verbose output",
    )
    return parser
# Example partial email pre-filled in the prompt textbox of the demo UI.
default_prompt = """
Hello,
Following up on last week's bubblegum shipment, I"""
if __name__ == "__main__":
    logging.info("\n\n\nStarting new instance of app.py")
    args = get_parser().parse_args()
    logging.info(f"received args:\t{args}")
    model_tag = args.model
    verbose = args.verbose
    logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
    # Bound at module scope so generate_text() can read it as a global.
    generator = pipeline(
        "text-generation",
        model_tag,
        device=0 if use_gpu else -1,
    )

    def generate_from_ui(prompt, gen_tokens, beams, ngram_size, len_penalty):
        """Forward UI inputs to generate_text, honouring the --verbose CLI flag."""
        return generate_text(
            prompt,
            gen_length=gen_tokens,
            num_beams=beams,
            no_repeat_ngram_size=ngram_size,
            length_penalty=len_penalty,
            verbose=verbose,
        )

    demo = gr.Blocks()
    logging.info("launching interface...")
    with demo:
        gr.Markdown("# Auto-Complete Emails - Demo")
        gr.Markdown(
            "Enter part of an email, and a text-gen model will complete it! See details below. "
        )
        gr.Markdown("---")
        with gr.Column():
            gr.Markdown("## Generate Text")
            gr.Markdown("Edit the prompt and parameters and press **Generate**!")
            prompt_text = gr.Textbox(
                lines=4,
                label="Email Prompt",
                value=default_prompt,
            )
            with gr.Row():
                clear_button = gr.Button(
                    value="Clear Prompt",
                )
                num_gen_tokens = gr.Slider(
                    label="Generation Tokens",
                    value=64,
                    maximum=128,
                    minimum=32,
                    step=16,
                )
            generated_email = gr.Textbox(
                label="Generated Result",
                placeholder="The completed email will appear here",
            )
            generate_button = gr.Button(
                value="Generate!",
                variant="primary",
            )
            gr.Markdown("## Advanced Options")
            gr.Markdown(
                "This demo generates text via beam search. See details about these parameters [here](https://huggingface.co/blog/how-to-generate), otherwise they should be fine as-is."
            )
            num_beams = gr.Radio(
                choices=[4, 8, 16],
                label="Number of Beams",
                value=4,
            )
            no_repeat_ngram_size = gr.Radio(
                choices=[1, 2, 3, 4],
                label="no repeat ngram size",
                value=2,
            )
            length_penalty = gr.Slider(
                minimum=0.5, maximum=1.0, label="length penalty", value=0.8, step=0.1
            )
            gr.Markdown("---")
        with gr.Column():
            gr.Markdown("## About")
            gr.Markdown(
                "[This model](https://huggingface.co/postbot/distilgpt2-emailgen) is a fine-tuned version of distilgpt2 on a dataset of 50k emails sourced from the internet, including the classic `aeslc` dataset.\n\nCheck out the model card for details on notebook & command line usage."
            )
            gr.Markdown(
                "The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails from scratch; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements **before** accepting/sending something."
            )
            gr.Markdown("---")

        # Gradio buttons register handlers via `.click(fn, inputs, outputs)`;
        # the former `on_click(input_text=..., output_text=...)` call is not
        # part of that API.
        # NOTE(review): assumes utils.clear takes the prompt text and returns
        # the value to put back in the textbox - confirm against utils.py.
        clear_button.click(
            fn=clear,
            inputs=[prompt_text],
            outputs=[prompt_text],
        )
        generate_button.click(
            fn=generate_from_ui,
            inputs=[
                prompt_text,
                num_gen_tokens,
                num_beams,
                no_repeat_ngram_size,
                length_penalty,
            ],
            outputs=[generated_email],
        )

    demo.launch(
        enable_queue=True,
        share=True,  # for local testing
    )