|
from threading import Thread |
|
|
|
import torch |
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer |
|
|
|
import gradio as gr |
|
|
|
# Hugging Face Hub repository of the model served by this demo.
MODEL_NAME = "isek-ai/SDPrompt-RetNet-300M"

# Prompt pre-filled in the input box; also offered as the first example.
DEFAULT_INPUT_TEXT = "1girl,"

# Prompts offered as clickable examples in the UI.
EXAMPLE_INPUTS = [
    DEFAULT_INPUT_TEXT,
    "oil painting of",
    "high quality photo of",
]
|
|
|
# Load the tokenizer and model once, at import time, so every request reuses
# them. trust_remote_code is passed for the model only -- presumably because
# its RetNet architecture is defined in the Hub repo rather than in
# transformers itself (TODO confirm).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# nn.Module.eval() returns the module itself, so it can be chained here to put
# the model in evaluation (inference) mode.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, trust_remote_code=True
).eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _to_int(value):
    """Return ``value`` converted to ``int``; ``None`` is passed through unchanged."""
    return None if value is None else int(value)


@torch.no_grad()
def generate(
    input_text,
    max_new_tokens=128,
    do_sample=True,
    temperature=1.0,
    top_p=0.95,
    top_k=20,
    repetition_penalty=1.2,
    num_beams=1,
):
    """Generate text continuing ``input_text`` with the module-level model.

    Args:
        input_text: Prompt to continue. ``None`` or whitespace-only input
            short-circuits and returns ``""`` without running the model.
        max_new_tokens: Maximum number of tokens to generate.
        do_sample: Sample instead of decoding deterministically.
        temperature: Sampling temperature. A non-positive value cannot be used
            for sampling, so it switches to deterministic decoding instead.
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k filtering size used when sampling.
        repetition_penalty: Penalty on repeated tokens; a non-positive value
            is invalid and is replaced by 1.0 (no penalty).
        num_beams: Number of beams (1 disables beam search).

    Returns:
        The first decoded output sequence (prompt plus continuation) with
        special tokens removed.
    """
    if input_text is None or not input_text.strip():
        return ""

    # Gradio sliders hand back plain numbers that may be floats; transformers
    # needs ints here (e.g. the top-k warper rejects a non-int top_k).
    max_new_tokens = _to_int(max_new_tokens)
    top_k = _to_int(top_k)
    num_beams = _to_int(num_beams)

    # transformers raises for temperature <= 0 when sampling. Treat it as a
    # request for deterministic decoding and reset temperature to the neutral
    # 1.0 so no sampling-only flag is left set.
    if temperature is not None and temperature <= 0:
        do_sample = False
        temperature = 1.0
    # The repetition penalty must be strictly positive; 1.0 means "no penalty".
    if repetition_penalty is not None and repetition_penalty <= 0:
        repetition_penalty = 1.0

    # "<s>" is prepended as text and add_special_tokens is disabled, so the
    # tokenizer does not insert its own special tokens on top of it.
    input_ids = tokenizer(
        f"<s>{input_text}", return_tensors="pt", add_special_tokens=False
    )["input_ids"]

    generated = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
    )

    # A single prompt is encoded per call, so return the first decoded sequence.
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
|
|
|
|
|
def continue_generate(
    input_text,
    *args,
):
    """Extend existing text by feeding it back to ``generate`` as the prompt.

    Args:
        input_text: Text to continue (the UI passes the current output box).
        *args: Generation settings, forwarded positionally to ``generate``.

    Returns:
        A ``(input_text, continuation)`` tuple: the text that was continued,
        unchanged, followed by ``generate``'s result for it.
    """
    continuation = generate(input_text, *args)
    return input_text, continuation
|
|
|
|
|
# Gradio UI: description, prompt/output boxes, action buttons, generation
# settings and examples. Components render in the order they are created.
with gr.Blocks() as demo:
    gr.Markdown(
        """\
# SDPrompt-RetNet-300M-Demo
A RetNet model trained with Stable Diffusion prompts and Danbooru tags.

Model: https://huggingface.co/isek-ai/SDPrompt-RetNet-300M

### Reference:
- https://github.com/syncdoth/RetNet
"""
    )

    input_text = gr.Textbox(
        label="Input text",
        value=DEFAULT_INPUT_TEXT,
        placeholder="beautiful photo of ...",
        lines=2,
    )
    output_text = gr.Textbox(
        label="Output text",
        value="",
        placeholder="Output will appear here...",
        lines=8,
        interactive=False,
    )

    with gr.Row():
        generate_btn = gr.Button("Generate ✒️", variant="primary")
        continue_btn = gr.Button("Continue ➡️", variant="secondary")
        clear_btn = gr.ClearButton(
            value="Clear 🧹",
            components=[input_text, output_text],
        )

    with gr.Accordion("Advanced settings", open=False):
        max_tokens = gr.Slider(
            label="Max tokens",
            minimum=8,
            maximum=512,
            value=75,
            step=4,
        )
        do_sample = gr.Checkbox(
            label="Do sample",
            value=True,
        )
        # transformers rejects temperature <= 0 when sampling, so the slider
        # starts at the first positive step rather than at 0.
        temperature = gr.Slider(
            label="Temperature",
            minimum=0.05,
            maximum=1,
            value=0.9,
            step=0.05,
        )
        top_p = gr.Slider(
            label="Top p",
            minimum=0,
            maximum=1,
            value=0.95,
            step=0.05,
        )
        top_k = gr.Slider(
            label="Top k",
            minimum=0,
            maximum=100,
            value=50,
            step=1,
        )
        # The repetition penalty must be strictly positive (1 = no penalty),
        # so 0 is excluded from the range.
        repetition_penalty = gr.Slider(
            label="Repetition penalty",
            minimum=0.1,
            maximum=2,
            value=1,
            step=0.1,
        )
        num_beams = gr.Slider(
            label="Num beams",
            minimum=1,
            maximum=10,
            value=1,
            step=1,
        )

    gr.Examples(
        examples=EXAMPLE_INPUTS,
        inputs=input_text,
    )

    # Passed positionally after the prompt, so this order must match
    # generate()'s parameters following input_text.
    generation_settings = [
        max_tokens,
        do_sample,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        num_beams,
    ]

    generate_btn.click(
        fn=generate,
        inputs=[input_text, *generation_settings],
        outputs=output_text,
        queue=False,
    )
    # "Continue" sends the current output back as the prompt: the previous
    # output is moved into the input box and the continuation replaces it.
    continue_btn.click(
        fn=continue_generate,
        inputs=[output_text, *generation_settings],
        outputs=[input_text, output_text],
        queue=False,
    )

demo.queue()
demo.launch(
    debug=True,
    show_error=True,
)
|
|