"""Gradio demo that auto-completes the rest of an email with a text-generation model."""
import argparse
import functools
import logging
import time

import gradio as gr
import torch
from transformers import pipeline

from utils import postprocess, clear

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

use_gpu = torch.cuda.is_available()


def generate_text(
    prompt: str,
    gen_length=64,
    num_beams=4,
    no_repeat_ngram_size=2,
    length_penalty=1.0,
    # perma params (not set by user)
    repetition_penalty=3.5,
    abs_max_length=512,
    verbose=False,
    text_pipeline=None,
):
    """
    generate_text - continue a prompt with beam search using a text-generation pipeline

    Args:
        prompt (str): the text to continue
        gen_length (int, optional): number of tokens to generate beyond the
            prompt. Defaults to 64.
        num_beams (int, optional): beam width for beam search. Defaults to 4.
        no_repeat_ngram_size (int, optional): size of n-grams that may not
            repeat in the output. Defaults to 2.
        length_penalty (float, optional): length penalty forwarded to beam
            search. Defaults to 1.0.
        repetition_penalty (float, optional): repetition penalty forwarded to
            generation. Defaults to 3.5.
        abs_max_length (int, optional): prompt length (in tokens) above which a
            warning is logged. The prompt is still passed on, with
            ``truncation=True``. Defaults to 512.
        verbose (bool, optional): log the prompt, the parameters, the output and
            the generation time. Defaults to False.
        text_pipeline (optional): the transformers "text-generation" pipeline to
            use. Defaults to None, which falls back to the module-level
            ``generator`` created under ``__main__``.

    Returns:
        str: the generated text (prompt + continuation) after ``utils.postprocess``

    Raises:
        RuntimeError: if no pipeline is given and no module-level ``generator`` exists.
    """
    if text_pipeline is None:
        # backward compatibility: the script stores its pipeline in a global
        text_pipeline = globals().get("generator")
    if text_pipeline is None:
        raise RuntimeError(
            "No text-generation pipeline available: pass `text_pipeline` "
            "or create the module-level `generator` first."
        )

    # UI widgets may deliver numeric values as floats; the length is used in
    # integer arithmetic for max_length below.
    gen_length = int(gen_length)

    if verbose:
        logging.info(f"Generating text from prompt:\n\n{prompt}")
        logging.info(
            f"params:\tmax_length={gen_length}, num_beams={num_beams}, "
            f"no_repeat_ngram_size={no_repeat_ngram_size}, length_penalty={length_penalty}, "
            f"repetition_penalty={repetition_penalty}, abs_max_length={abs_max_length}"
        )
    st = time.perf_counter()

    input_tokens = text_pipeline.tokenizer(prompt)
    input_len = len(input_tokens["input_ids"])
    if input_len > abs_max_length:
        logging.warning(f"Input too long {input_len} > {abs_max_length}, may cause errors")

    result = text_pipeline(
        prompt,
        # max_length/min_length count the prompt tokens too, so offset by input_len
        max_length=gen_length + input_len,
        min_length=input_len + 4,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        length_penalty=length_penalty,
        do_sample=False,  # deterministic beam search, no sampling
        early_stopping=True,
        # tokenizer
        truncation=True,
    )
    # generate
    response = result[0]["generated_text"]
    rt = time.perf_counter() - st
    if verbose:
        logging.info(f"Generated text: {response}")
        logging.info(f"Generation time: {rt:.2f}s")

    return postprocess(response)


def get_parser():
    """
    get_parser - a helper function for the argparse module

    Returns:
        argparse.ArgumentParser: parser with ``-m/--model`` (model tag) and
        ``-v/--verbose`` options
    """
    parser = argparse.ArgumentParser(
        description="Text Generation demo for postbot",
    )

    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="postbot/distilgpt2-emailgen",
        help="Pass a different huggingface model tag to use a custom model",
    )

    parser.add_argument(
        "-v",
        "--verbose",
        required=False,
        action="store_true",
        help="Verbose output",
    )
    return parser


default_prompt = """
Hello,

Following up on last's bubblegum shipment, I"""

if __name__ == "__main__":
    logging.info("\n\n\nStarting new instance of app.py")
    args = get_parser().parse_args()
    logging.info(f"received args:\t{args}")
    model_tag = args.model
    verbose = args.verbose
    logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
    generator = pipeline(
        "text-generation",
        model_tag,
        device=0 if use_gpu else -1,
    )

    demo = gr.Blocks()

    logging.info("launching interface...")

    with demo:
        gr.Markdown("# Auto-Complete Emails - Demo")
        gr.Markdown(
            "Enter part of an email, and a text-gen model will complete it! See details below. "
        )
        gr.Markdown("---")

        with gr.Column():
            gr.Markdown("## Generate Text")
            gr.Markdown("Edit the prompt and parameters and press **Generate**!")
            prompt_text = gr.Textbox(
                lines=4,
                label="Email Prompt",
                value=default_prompt,
            )

            with gr.Row():
                clear_button = gr.Button(
                    value="Clear Prompt",
                )
                num_gen_tokens = gr.Slider(
                    label="Generation Tokens",
                    value=64,
                    maximum=128,
                    minimum=32,
                    step=16,
                )

            generated_email = gr.Textbox(
                label="Generated Result",
                placeholder="The completed email will appear here",
            )

            generate_button = gr.Button(
                value="Generate!",
                variant="primary",
            )
            gr.Markdown("## Advanced Options")
            gr.Markdown(
                "This demo generates text via beam search. See details about these parameters [here](https://huggingface.co/blog/how-to-generate), otherwise they should be fine as-is."
            )
            num_beams = gr.Radio(
                choices=[4, 8, 16],
                label="Number of Beams",
                value=4,
            )
            no_repeat_ngram_size = gr.Radio(
                choices=[1, 2, 3, 4],
                label="no repeat ngram size",
                value=2,
            )
            length_penalty = gr.Slider(
                minimum=0.5, maximum=1.0, label="length penalty", value=0.8, step=0.1
            )
            gr.Markdown("---")

        with gr.Column():
            gr.Markdown("## About")
            gr.Markdown(
                "[This model](https://huggingface.co/postbot/distilgpt2-emailgen) is a fine-tuned version of distilgpt2 on a dataset of 50k emails sourced from the internet, including the classic `aeslc` dataset.\n\nCheck out the model card for details on notebook & command line usage."
            )
            gr.Markdown(
                "The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails from scratch; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements **before** accepting/sending something."
            )
            gr.Markdown("---")

        # Gradio buttons register handlers via .click(fn, inputs, outputs);
        # `utils.clear` presumably maps the prompt text to an empty value — verify.
        clear_button.click(
            fn=clear,
            inputs=[prompt_text],
            outputs=[prompt_text],
        )
        # Bind the CLI --verbose flag and the loaded pipeline; Gradio supplies
        # the five UI values positionally.
        generate_button.click(
            fn=functools.partial(
                generate_text, verbose=verbose, text_pipeline=generator
            ),
            inputs=[
                prompt_text,
                num_gen_tokens,
                num_beams,
                no_repeat_ngram_size,
                length_penalty,
            ],
            outputs=[generated_email],
        )

    # NOTE(review): `enable_queue` was removed from launch() in newer Gradio
    # releases (use demo.queue() instead) — confirm against the pinned version.
    demo.launch(
        enable_queue=True,
        share=True,  # creates a public share link (useful for local testing)
    )