# kaktuspassion's picture
# initial commit
# 81c89b2 verified
import functools

import gradio as gr
from transformers import pipeline
# Define the model loading function
@functools.lru_cache(maxsize=1)
def load_model(model_name):
    """Return a text-generation pipeline for ``model_name``.

    The pipeline for the most recently requested model is cached, so repeated
    generations with the same model reuse it instead of rebuilding the
    pipeline (and reloading the weights) on every request. ``maxsize=1``
    bounds memory: switching models evicts the previously cached pipeline
    rather than keeping several GPT-2 variants loaded at once.

    Args:
        model_name: Hugging Face model identifier, e.g. ``"gpt2"``.

    Returns:
        The ``transformers`` text-generation pipeline for that model.
    """
    return pipeline("text-generation", model=model_name)
# Define the text generation function
def generate_text(
    model_name,
    prompt,
    custom_prompt,
    temperature,
    max_length,
    top_p,
    beam_size,
    frequency_penalty=0.0,
    presence_penalty=0.0,
):
    """Generate text for the selected (or custom) prompt.

    Args:
        model_name: Hugging Face model identifier passed to ``load_model``.
        prompt: Pre-written prompt from the dropdown; may be ``None``.
        custom_prompt: Free-form prompt; when non-empty it overrides ``prompt``.
        temperature: Sampling temperature. ``0`` disables sampling, so
            decoding is greedy (or beam search when ``beam_size > 1``).
        max_length: Maximum total length in tokens (prompt + completion).
        top_p: Nucleus-sampling probability mass; only used when sampling.
        beam_size: Number of beams; ``1`` disables beam search.
        frequency_penalty: Accepted for signature compatibility but
            currently ignored. It defaults so the 7-input UI can call this.
        presence_penalty: Accepted for signature compatibility but
            currently ignored. It defaults so the 7-input UI can call this.

    Returns:
        The ``generated_text`` field of the first sequence returned by the
        pipeline.

    Raises:
        gr.Error: If neither a pre-written nor a custom prompt was given.
    """
    if custom_prompt:
        prompt = custom_prompt
    if not prompt:
        # Surface a readable message in the UI instead of a pipeline crash.
        raise gr.Error("Please select a prompt or write a custom one.")

    # Temperature 0 means "deterministic": turn sampling off explicitly so the
    # model's default generation config cannot silently re-enable it.
    do_sample = temperature > 0
    gen_kwargs = {
        # Slider values may arrive as floats; generate() expects ints here.
        "max_length": int(max_length),
        "num_beams": int(beam_size),
        "do_sample": do_sample,
        "truncation": True,
    }
    if do_sample:
        # temperature/top_p only affect sampling; passing them alongside
        # do_sample=False merely triggers transformers warnings.
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)

    generator = load_model(model_name)
    outputs = generator(prompt, **gen_kwargs)
    return outputs[0]["generated_text"]
# Pre-written prompts offered in the "Prompt" dropdown.
prompts = [
    "Write a tagline for an ice cream shop",
    "Describe World War II",
    "Write a short story about a robot",
    "Explain the concept of gravity",
]
# Interface: the input components map positionally onto the first seven
# parameters of generate_text (model, prompt, custom prompt, temperature,
# max length, top-p, beam size).
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Radio(choices=["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"], label="Model", value="gpt2", info="Choose the size of the model to use."),
        # Default to the first prompt so generate_text never receives None
        # when the user neither picks a prompt nor writes a custom one.
        gr.Dropdown(choices=prompts, value=prompts[0], label="Prompt", info="Select a pre-written prompt."),
        gr.Textbox(label="Custom Prompt", placeholder="Or write your own prompt here", lines=5),
        gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=1.0, label="Temperature", info="Controls randomness: Higher values make the output more random, while lower values make the output more deterministic and repetitive."),
        # Token counts are integral; make the step explicit.
        gr.Slider(minimum=1, maximum=256, value=16, step=1, label="Maximum Length", info="The maximum number of tokens to generate shared between the prompt and the completion."),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=1.0, label="Top P", info="Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered."),
        gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Beam Size", info="Number of beams to use for beam search. 1 means Greedy decoding."),
    ],
    outputs=["text"],
    title="GPT-2 playground Mockup",
    description="Adjust the sliders and enter a prompt to generate text using GPT-2.",
)

# Only start the server when run as a script, so importing this module
# (e.g. by tooling that just needs `demo`) does not block in launch().
if __name__ == "__main__":
    demo.launch()