|
import functools

import gradio as gr
from transformers import pipeline
|
|
|
|
|
@functools.lru_cache(maxsize=1)
def load_model(model_name):
    """Return a text-generation pipeline for ``model_name``.

    The most recently built pipeline is memoized so that consecutive calls
    with the same model name reuse it instead of constructing a new pipeline
    each time. Only one entry is kept (``maxsize=1``); asking for a different
    model evicts the previous pipeline from the cache.

    Args:
        model_name: Model identifier forwarded to ``pipeline`` as ``model``.
            It is used as the cache key, so it must be hashable.

    Returns:
        The object returned by ``pipeline('text-generation', ...)``.
    """
    return pipeline('text-generation', model=model_name)
|
|
|
|
|
def generate_text(
    model_name,
    prompt,
    custom_prompt,
    temperature,
    max_length,
    top_p,
    beam_size,
    frequency_penalty=None,
    presence_penalty=None,
):
    """Generate a completion for a prompt with the selected model.

    Args:
        model_name: Model identifier passed to ``load_model``.
        prompt: Pre-written prompt; used when no custom prompt is given.
        custom_prompt: Free-text prompt; when non-blank it overrides
            ``prompt``.
        temperature: Sampling temperature. Exactly ``0`` selects
            deterministic (non-sampling) decoding.
        max_length: Passed to the pipeline as ``max_length`` (coerced to
            ``int``).
        top_p: Nucleus-sampling threshold; only sent when sampling.
        beam_size: Passed to the pipeline as ``num_beams`` (coerced to
            ``int``).
        frequency_penalty: Accepted for signature compatibility; currently
            ignored. Optional so the function can be called with only the
            first seven arguments.
        presence_penalty: Accepted for signature compatibility; currently
            ignored. Optional for the same reason.

    Returns:
        The ``'generated_text'`` field of the first result returned by the
        pipeline.

    Raises:
        gr.Error: If neither a pre-written nor a custom prompt was supplied.
    """
    # A non-blank custom prompt takes precedence over the pre-written one.
    has_custom = bool(custom_prompt and custom_prompt.strip())
    text = custom_prompt if has_custom else prompt
    if not text:
        # Fail with a readable message instead of an error deep inside the
        # pipeline when nothing was selected or typed.
        raise gr.Error("Please select a prompt or enter a custom prompt.")

    generation_kwargs = {
        # Coerce in case the caller hands over float-valued numbers.
        "max_length": int(max_length),
        "num_beams": int(beam_size),
        "truncation": True,
    }
    if temperature == 0:
        # Temperature 0 means "no randomness": disable sampling outright
        # rather than sampling with a tiny temperature, and leave out the
        # sampling-only parameters.
        generation_kwargs["do_sample"] = False
    else:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = float(temperature)
        generation_kwargs["top_p"] = float(top_p)

    generator = load_model(model_name)
    results = generator(text, **generation_kwargs)
    return results[0]['generated_text']
|
|
|
|
|
# Pre-written prompts offered in the "Prompt" dropdown of the UI.
prompts = [
    "Write a tagline for an ice cream shop",
    "Describe World War II",
    "Write a short story about a robot",
    "Explain the concept of gravity",
]
|
|
|
|
|
# Web UI definition. The input components are listed in the order in which
# generate_text receives them: model, prompt, custom prompt, temperature,
# maximum length, top p, beam size. The result is shown as plain text.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        # Model checkpoint selector.
        gr.Radio(
            choices=["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"],
            label="Model",
            value="gpt2",
            info="Choose the size of the model to use.",
        ),
        # Pre-written prompt picker, filled from the module-level list.
        gr.Dropdown(
            choices=prompts,
            label="Prompt",
            info="Select a pre-written prompt.",
        ),
        # Free-text prompt.
        gr.Textbox(
            label="Custom Prompt",
            placeholder="Or write your own prompt here",
            lines=5,
        ),
        gr.Slider(
            minimum=0.0,
            maximum=2.0,
            step=0.01,
            value=1.0,
            label="Temperature",
            info=(
                "Controls randomness: Higher values make the output more random, "
                "while lower values make the output more deterministic and repetitive."
            ),
        ),
        gr.Slider(
            minimum=1,
            maximum=256,
            value=16,
            label="Maximum Length",
            info=(
                "The maximum number of tokens to generate shared between the prompt "
                "and the completion."
            ),
        ),
        gr.Slider(
            minimum=0.0,
            maximum=1.0,
            step=0.01,
            value=1.0,
            label="Top P",
            info=(
                "Controls diversity via nucleus sampling: 0.5 means half of all "
                "likelihood-weighted options are considered."
            ),
        ),
        gr.Slider(
            minimum=1,
            maximum=10,
            value=1,
            step=1,
            label="Beam Size",
            info="Number of beams to use for beam search. 1 means Greedy decoding.",
        ),
    ],
    outputs=["text"],
    title="GPT-2 playground Mockup",
    description="Adjust the sliders and enter a prompt to generate text using GPT-2.",
)
|
|
|
# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. to reuse `demo` or `generate_text`) has no
# side effect of launching the app.
if __name__ == "__main__":
    demo.launch()