from threading import Thread  # only needed for the (currently disabled) streaming path below
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import gradio as gr

MODEL_NAME = "isek-ai/SDPrompt-RetNet-300M"
DEFAULT_INPUT_TEXT = "1girl,"
EXAMPLE_INPUTS = [DEFAULT_INPUT_TEXT, "oil painting of", "high quality photo of"]

# Load the tokenizer and model once at startup. trust_remote_code is required
# because the RetNet architecture ships its own modeling code on the Hub.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
model.eval()

# Optional token streaming (disabled). Enable together with the `streamer=`
# argument in `model.generate()` below to emit tokens as they are produced.
# streamer = TextStreamer(
#     tokenizer,
#     skip_prompt=False,
#     skip_special_tokens=True,
# )


@torch.no_grad()
def generate(
    input_text,
    max_new_tokens=128,
    do_sample=True,
    temperature=1.0,
    top_p=0.95,
    top_k=20,
    # no_repeat_ngram_size=3,
    repetition_penalty=1.2,
    num_beams=1,
):
    """Generate a prompt continuation for `input_text` and return the full decoded text."""
    if input_text.strip() == "":
        return ""

    inputs = tokenizer(
        f"{input_text}",
        return_tensors="pt",
        add_special_tokens=False,
    )["input_ids"]

    generated = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        # no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
        # streamer=streamer,
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]


def continue_generate(
    input_text,
    *args,
):
    """Feed the previous output back in as the new prompt and keep generating."""
    return input_text, generate(input_text, *args)


with gr.Blocks() as demo:
    gr.Markdown(
        """\
# SDPrompt-RetNet-300M-Demo

A RetNet model trained on Stable Diffusion prompts and Danbooru tags.

Model: https://huggingface.co/isek-ai/SDPrompt-RetNet-300M

### Reference:

- https://github.com/syncdoth/RetNet
"""
    )

    input_text = gr.Textbox(
        label="Input text",
        value=DEFAULT_INPUT_TEXT,
        placeholder="beautiful photo of ...",
        lines=2,
    )
    output_text = gr.Textbox(
        label="Output text",
        value="",
        placeholder="Output will appear here...",
        lines=8,
        interactive=False,
    )

    with gr.Row():
        generate_btn = gr.Button("Generate ✒️", variant="primary")
        continue_btn = gr.Button("Continue ➡️", variant="secondary")
        clear_btn = gr.ClearButton(
            value="Clear 🧹",
            components=[input_text, output_text],
        )

    with gr.Accordion("Advanced settings", open=False):
        max_tokens = gr.Slider(
            label="Max tokens",
            minimum=8,
            maximum=512,
            value=75,
            step=4,
        )
        do_sample = gr.Checkbox(
            label="Do sample",
            value=True,
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0,
            maximum=1,
            value=0.9,
            step=0.05,
        )
        top_p = gr.Slider(
            label="Top p",
            minimum=0,
            maximum=1,
            value=0.95,
            step=0.05,
        )
        top_k = gr.Slider(
            label="Top k",
            minimum=0,
            maximum=100,
            value=50,
            step=1,
        )
        repetition_penalty = gr.Slider(
            label="Repetition penalty",
            minimum=0,
            maximum=2,
            value=1,
            step=0.1,
        )
        num_beams = gr.Slider(
            label="Num beams",
            minimum=1,
            maximum=10,
            value=1,
            step=1,
        )

    gr.Examples(
        examples=EXAMPLE_INPUTS,
        inputs=input_text,
    )

    generate_btn.click(
        fn=generate,
        inputs=[
            input_text,
            max_tokens,
            do_sample,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            num_beams,
        ],
        outputs=output_text,
        queue=False,
    )
    continue_btn.click(
        fn=continue_generate,
        inputs=[
            output_text,
            max_tokens,
            do_sample,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            num_beams,
        ],
        outputs=[input_text, output_text],
        queue=False,
    )

demo.queue()
demo.launch(
    debug=True,
    show_error=True,
)
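
# ---------------------------------------------------------------------------
# Optional sanity check (sketch, not part of the app): `generate()` can be
# called directly from a Python shell once the model above has loaded, e.g.
#
#   >>> generate("1girl,", 32, True, 0.9, 0.95, 50, 1.0, 1)
#   '1girl, ...'
#
# The arguments map positionally onto max_new_tokens, do_sample, temperature,
# top_p, top_k, repetition_penalty, and num_beams; the output shown here is
# illustrative only, since sampling is nondeterministic.
# ---------------------------------------------------------------------------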