"""Gradio demo for SuperPrompt (roborovski/superprompt-v1), a T5 model that
expands short text-to-image prompts into more detailed ones."""

import functools as ft

import gradio as gr
import torch
import transformers
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
    "roborovski/superprompt-v1"
)


# Output is deterministic for a given (text, seed, max_new_tokens, prompt),
# so repeated requests with identical inputs can be served from memory.
@ft.lru_cache(maxsize=1024)
def super_prompt(text: str, seed: int, max_new_tokens: int, prompt: str) -> str:
    # Gradio sliders deliver floats; set_seed and the token budget need ints.
    seed = int(seed)
    max_new_tokens = int(max_new_tokens)
    transformers.set_seed(seed)
    if max_new_tokens <= 0:
        max_new_tokens = 150  # fall back to a sensible default budget
    with torch.inference_mode():
        # Prepend the custom instruction if one was given, else the default.
        if prompt:
            input_text = f"{prompt} {text}"
        else:
            input_text = f"Expand the following prompt to add more detail: {text}"
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids
        outputs = model.generate(
            input_ids,
            # max_new_tokens counts only generated tokens; the original passed
            # this value as max_length, which also counts the prompt tokens
            # and so could silently truncate or suppress the output.
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)


demo = gr.Interface(
    fn=super_prompt,
    inputs=[
        gr.Textbox(label="input text"),
        gr.Slider(label="seed", minimum=0, maximum=2**32 - 1, step=1),
        gr.Slider(label="max_new_tokens", minimum=0, maximum=375, step=1),
        gr.Textbox(label="custom prompt", placeholder="leave empty to use default"),
    ],
    outputs=[gr.Textbox(label="output", lines=6)],
)
demo.launch()
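
# Example (a minimal sketch): super_prompt can also be called directly, e.g.
# from a Python shell after importing this module, or from a line placed above
# demo.launch() (which blocks). The sample text and parameter values below are
# illustrative, not part of the model's API:
#
#     expanded = super_prompt(
#         text="a cat sitting on a windowsill",
#         seed=42,
#         max_new_tokens=150,
#         prompt="",  # empty string selects the default expansion instruction
#     )
#     print(expanded)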