# SuperPrompt-v1 / app.py
import gradio as gr
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
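
# Load the SuperPrompt T5 checkpoint at the requested precision.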
def load_model(model_path, dtype):
if dtype == "fp32":
torch_dtype = torch.float32
elif dtype == "fp16":
torch_dtype = torch.float16
else:
raise ValueError("Invalid dtype. Only 'fp32' or 'fp16' are supported.")
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch_dtype)
return model
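
# Called by gr.ChatInterface: receives the user message and chat history,
# then the additional inputs in the order they are declared below.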
def generate(
    prompt,
    history,
    max_new_tokens,
    repetition_penalty,
    temperature,
    top_p,
    top_k,
    seed,
    dtype="fp16",
    model_path="roborovski/superprompt-v1",
):
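    # The tokenizer and model are reloaded on every call; caching them at
    # module level would speed up repeated generations. The tokenizer comes
    # from the base flan-t5-small checkpoint.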
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
model = load_model(model_path, dtype)
if torch.cuda.is_available():
device = "cuda"
print("Using GPU")
else:
device = "cpu"
print("Using CPU")
model.to(device)
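    # Fold the chat history into the prompt text (T5 takes a single string).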
input_text = f"{prompt}, {history}"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
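    # Fix the RNG seed so identical inputs reproduce the same completion.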
torch.manual_seed(seed)
outputs = model.generate(
input_ids,
max_new_tokens=max_new_tokens,
repetition_penalty=repetition_penalty,
do_sample=True,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
    # Strip <pad>/</s> markers from the decoded text.
    better_prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)
return better_prompt
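
# Standalone usage sketch (illustrative only; not executed by this app):
#   generate("A storefront with 'Text to Image' written on it.", [],
#            max_new_tokens=512, repetition_penalty=1.2, temperature=0.5,
#            top_p=1.0, top_k=50, seed=42)

# UI controls exposed under "Additional Inputs"; their order must match the
# parameters of generate() after (prompt, history).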
additional_inputs = [
gr.Slider(
value=512,
minimum=250,
maximum=512,
step=1,
interactive=True,
label="Max New Tokens",
info="The maximum numbers of new tokens, controls how long is the output",
),
gr.Slider(
value=1.2,
minimum=0,
maximum=2,
step=0.05,
interactive=True,
label="Repetition Penalty",
info="Penalize repeated tokens, making the AI repeat less itself",
),
gr.Slider(
value=0.5,
minimum=0,
maximum=1,
step=0.05,
interactive=True,
label="Temperature",
info="Higher values produce more diverse outputs",
),
gr.Slider(
value=1,
minimum=0,
        maximum=1,
step=0.05,
interactive=True,
label="Top P",
info="Higher values sample more low-probability tokens",
),
gr.Slider(
value=1,
minimum=1,
maximum=100,
step=1,
interactive=True,
label="Top K",
info="Higher k means more diverse outputs by considering a range of tokens",
),
gr.Number(
value=42,
interactive=True,
label="Seed",
info="A starting point to initiate the generation process",
),
gr.Radio(
choices=["fp32", "fp16"],
value="fp16",
label="Model Precision",
info="Select the precision of the model: fp32 or fp16",
),
]
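
# One example row: the message followed by one value per additional input
# (None keeps that control's default).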
examples = [
    [
        "Expand the following prompt to add more detail: A storefront with 'Text to Image' written on it.",
        None,
        None,
        None,
        None,
        None,
        None,
        "fp16",
    ]
]
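
# Wire everything into a chat UI; launch() serves it (e.g. on HF Spaces).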
gr.ChatInterface(
fn=generate,
chatbot=gr.Chatbot(
show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"
),
additional_inputs=additional_inputs,
title="SuperPrompt-v1",
description="Make your prompts more detailed!",
examples=examples,
concurrency_limit=20,
).launch(show_api=False)