p1atdev's picture
feat: gradio interface
f8eb38f
raw history blame
No virus
4.34 kB
from threading import Thread
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import gradio as gr
# Checkpoint identifier handed to both `from_pretrained` calls below.
MODEL_NAME = "isek-ai/SDPrompt-RetNet-300M"

# Initial content of the input textbox; also reused as the first example.
DEFAULT_INPUT_TEXT = "1girl,"
EXAMPLE_INPUTS = [
    DEFAULT_INPUT_TEXT,
    "oil painting of",
    "high quality photo of",
]

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# NOTE(security): trust_remote_code=True lets the modelling code shipped with
# the checkpoint repository run in this process; only keep it for repositories
# you trust.
# `eval()` returns the module itself, so the model is loaded and switched to
# inference mode in a single expression.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
).eval()
# streamer = TextStreamer(
# tokenizer,
# skip_prompt=False,
# skip_special_tokens=True,
# )
@torch.no_grad()
def generate(
    input_text,
    max_new_tokens=128,
    do_sample=True,
    temperature=1.0,
    top_p=0.95,
    top_k=20,
    repetition_penalty=1.2,
    num_beams=1,
):
    """Continue ``input_text`` with the module-level model and return the text.

    Args:
        input_text: Prompt to continue. ``None`` or whitespace-only input
            short-circuits to an empty string without touching the model.
        max_new_tokens: Upper bound on generated tokens (coerced to int, >= 1).
        do_sample: Whether to sample instead of decoding deterministically.
        temperature: Sampling temperature. A value <= 0 disables sampling,
            because transformers only accepts a strictly positive temperature.
        top_p: Nucleus-sampling threshold (only forwarded when sampling).
        top_k: Top-k cutoff (coerced to int; only forwarded when sampling).
        repetition_penalty: Repetition penalty. A value <= 0 falls back to
            1.0 (no penalty), since transformers requires a positive float.
        num_beams: Number of beams (coerced to int, >= 1).

    Returns:
        The first decoded sequence with special tokens skipped, or ``""``
        when there is no prompt.
    """
    # The value may be None rather than "" (presumably e.g. after the textbox
    # has been cleared in the UI); treat that like an empty prompt.
    if input_text is None or input_text.strip() == "":
        return ""

    # Numeric UI widgets can hand over whole numbers as int and counts as
    # float, while transformers type-checks several of these arguments, so
    # normalise the types up front.
    max_new_tokens = max(1, int(max_new_tokens))
    num_beams = max(1, int(num_beams))
    top_k = int(top_k)
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)

    # A non-positive temperature is rejected when sampling; interpret it as a
    # request for deterministic decoding instead of raising.
    sampling = bool(do_sample) and temperature > 0.0

    # A non-positive penalty is rejected outright; 1.0 means "no penalty".
    if repetition_penalty <= 0.0:
        repetition_penalty = 1.0

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": sampling,
        "repetition_penalty": repetition_penalty,
        "num_beams": num_beams,
    }
    if sampling:
        # These knobs only matter when sampling; leaving them out otherwise
        # avoids handing unused values to `generate`.
        generation_kwargs.update(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )

    # "<s>" is prepended by hand (presumably the model's BOS token), so the
    # tokenizer is told not to add special tokens of its own.
    inputs = tokenizer(
        f"<s>{input_text}", return_tensors="pt", add_special_tokens=False
    )["input_ids"]

    generated = model.generate(inputs, **generation_kwargs)

    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
def continue_generate(
    input_text,
    *args,
):
    """Generate from ``input_text`` and return it alongside the new text.

    Args:
        input_text: Text to continue from.
        *args: Extra positional arguments forwarded unchanged to ``generate``.

    Returns:
        A ``(input_text, generated_text)`` tuple: the text that was continued,
        followed by the result of ``generate``.
    """
    continuation = generate(input_text, *args)
    return input_text, continuation
# Declarative UI: prompt/result textboxes, action buttons, generation settings
# and the click handlers that connect them to `generate` / `continue_generate`.
# NOTE(review): the nesting of the Row/Accordion children below is
# reconstructed from statement order -- confirm against the intended layout.
with gr.Blocks() as demo:
    # Page header. The literal starts with a backslash so it has no leading
    # newline.
    gr.Markdown(
        """\
# SDPrompt-RetNet-300M-Demo
A RetNet model trained with Stable Diffusion prompts and Danbooru tags.
Model: https://huggingface.co/isek-ai/SDPrompt-RetNet-300M
### Reference:
- https://github.com/syncdoth/RetNet
"""
    )

    # Editable prompt box and read-only result box.
    input_text = gr.Textbox(
        value=DEFAULT_INPUT_TEXT,
        label="Input text",
        placeholder="beautiful photo of ...",
        lines=2,
    )
    output_text = gr.Textbox(
        value="",
        label="Output text",
        placeholder="Output will appear here...",
        lines=8,
        interactive=False,
    )

    with gr.Row():
        generate_btn = gr.Button("Generate ✒️", variant="primary")
        continue_btn = gr.Button("Continue ➡️", variant="secondary")
        # Configured to clear both textboxes.
        clear_btn = gr.ClearButton(
            components=[input_text, output_text],
            value="Clear 🧹",
        )

    with gr.Accordion("Advanced settings", open=False):
        max_tokens = gr.Slider(
            label="Max tokens", minimum=8, maximum=512, step=4, value=75
        )
        do_sample = gr.Checkbox(label="Do sample", value=True)
        temperature = gr.Slider(
            label="Temperature", minimum=0, maximum=1, step=0.05, value=0.9
        )
        top_p = gr.Slider(
            label="Top p", minimum=0, maximum=1, step=0.05, value=0.95
        )
        top_k = gr.Slider(
            label="Top k", minimum=0, maximum=100, step=1, value=50
        )
        repetition_penalty = gr.Slider(
            label="Repetition penalty", minimum=0, maximum=2, step=0.1, value=1
        )
        num_beams = gr.Slider(
            label="Num beams", minimum=1, maximum=10, step=1, value=4
        )

    gr.Examples(inputs=input_text, examples=EXAMPLE_INPUTS)

    # Both handlers take the same settings, listed in the positional order of
    # the `generate` parameters that follow the prompt.
    generation_controls = [
        max_tokens,
        do_sample,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        num_beams,
    ]

    # "Generate": read the prompt from the input box, write to the output box.
    generate_btn.click(
        fn=generate,
        inputs=[input_text, *generation_controls],
        outputs=output_text,
        queue=False,
    )

    # "Continue": read the current output as the prompt; the continued text is
    # written to the input box and the new generation to the output box.
    continue_btn.click(
        fn=continue_generate,
        inputs=[output_text, *generation_controls],
        outputs=[input_text, output_text],
        queue=False,
    )
# Enable the request queue, then start the app with debug mode and error
# display turned on. `queue()` returns the Blocks instance, so the two calls
# can be chained.
demo.queue().launch(debug=True, show_error=True)