# SLAPaper's picture
# Update app.py
# 73943f4 verified
import functools as ft
import gradio as gr
import torch
import transformers
from transformers import T5ForConditionalGeneration, T5Tokenizer
# Hugging Face Hub repository that provides BOTH the tokenizer and the model
# weights; kept in a single constant so the two can never be loaded from
# different repositories by accident.
SUPERPROMPT_REPO_ID = "roborovski/superprompt-v1"

# Loaded once at import time (from_pretrained may download and cache the files
# on first run) and shared by every call to super_prompt.
tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(SUPERPROMPT_REPO_ID)
model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
    SUPERPROMPT_REPO_ID
)
@ft.lru_cache(maxsize=1024)
def super_prompt(text: str, seed: int, max_new_tokens: int, prompt: str) -> str:
    """Expand ``text`` into a more detailed prompt using the SuperPrompt T5 model.

    Results are memoised with ``lru_cache``. Because the RNGs are re-seeded from
    ``seed`` on every call, identical arguments are expected to reproduce the
    same sampled output, so a cache hit returns what a fresh run would.

    Args:
        text: The user's prompt to expand.
        seed: Seed handed to ``transformers.set_seed`` before sampling.
        max_new_tokens: Maximum number of tokens to generate; values ``<= 0``
            fall back to 150.
        prompt: Optional instruction prepended to ``text``. An empty,
            whitespace-only or ``None`` value selects the default
            "Expand the following prompt to add more detail:" instruction.

    Returns:
        The decoded generation with special tokens removed.
    """
    # set_seed also seeds numpy, whose seeding rejects float values; cast in
    # case the UI delivers the slider value as a float.
    transformers.set_seed(int(seed))
    if max_new_tokens <= 0:
        max_new_tokens = 150

    # A blank custom prompt carries no instruction, so use the default one.
    if prompt and prompt.strip():
        input_text = f"{prompt} {text}"
    else:
        input_text = f"Expand the following prompt to add more detail: {text}"

    # Forward input_ids together with the attention_mask the tokenizer builds.
    inputs = tokenizer(input_text, return_tensors="pt")
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            # max_length would cap the whole decoder sequence (including the
            # decoder start token); max_new_tokens counts only generated
            # tokens, matching this parameter's name and the UI label.
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Gradio UI. Interface passes the input components to super_prompt
# positionally, so their order must match (text, seed, max_new_tokens, prompt).
demo = gr.Interface(
    fn=super_prompt,
    inputs=[
        gr.Textbox(label="input text"),
        # Upper bound matches the 0..2**32-1 range accepted by numpy's seeding,
        # which transformers.set_seed also drives.
        gr.Slider(label="seed", minimum=0, maximum=2**32 - 1, step=1),
        # 0 (the slider minimum) makes super_prompt use its default length.
        gr.Slider(label="max_new_tokens", minimum=0, maximum=375, step=1),
        gr.Textbox(label="custom prompt", placeholder="leave empty to use default"),
    ],
    outputs=[gr.Textbox(label="output", lines=6)],
)

# Start the web server only when executed as a script, so importing this
# module (e.g. from tests or tooling) does not launch the app.
if __name__ == "__main__":
    demo.launch()