"""Gradio demo for the p1atdev/LightNovel-Intro-RetNet-400M model.

Loads the RetNet causal LM and its tokenizer once at import time and serves a
small UI that can generate from a prompt or continue the previous output.
"""

import torch
import gradio as gr
from transformers import AutoTokenizer

from RetNet.retnet.modeling_retnet import RetNetForCausalLM

MODEL_NAME = "p1atdev/LightNovel-Intro-RetNet-400M"

DEFAULT_INPUT_TEXT = "目が覚めると、"
EXAMPLE_INPUTS = [
    DEFAULT_INPUT_TEXT,
    "冒険者ギルドには",
    "真っ白い部屋の中、そこには",
    "20XX年、",
    "「なんだって!?」",
    "どうやらトラックにはねられ、俺は",
]

# Loaded once at import time; every request reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = RetNetForCausalLM.from_pretrained(MODEL_NAME)
model.eval()


@torch.no_grad()
def generate(
    input_text,
    max_new_tokens=128,
    do_sample=True,
    temperature=1.0,
    top_p=0.95,
    top_k=20,
):
    """Generate text from ``input_text`` with the RetNet model.

    Args:
        input_text: Prompt text. Empty, whitespace-only, or ``None`` input
            returns ``""`` without running the model.
        max_new_tokens: Maximum number of tokens to generate.
        do_sample: Sample instead of decoding greedily.
        temperature: Sampling temperature. Values ``<= 0`` switch to greedy
            decoding, since 0 cannot be used as a sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling.

    Returns:
        The decoded first sequence returned by ``model.custom_generate``,
        with special tokens removed.
    """
    if not input_text or not input_text.strip():
        return ""

    # The UI slider allows temperature 0, which is not a valid sampling
    # temperature; treat it as a request for greedy decoding and pass a
    # neutral temperature so no division by zero can occur downstream.
    if temperature <= 0:
        do_sample = False
        temperature = 1.0

    inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
    generated = model.custom_generate(
        **inputs,
        parallel_compute_prompt=True,
        # Gradio sliders may deliver floats; token counts must be integers.
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=int(top_k),
    )
    # Drop special tokens (e.g. an end-of-sequence marker, if the model emits
    # one) so they are neither displayed nor fed back in by "Continue".
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]


def continue_generate(
    input_text,
    *args,
):
    """Generate a continuation of ``input_text``.

    Args:
        input_text: Text to continue.
        *args: Forwarded positionally to ``generate`` (max_new_tokens,
            do_sample, temperature, top_p, top_k).

    Returns:
        ``(input_text, generated_text)`` so a Gradio event can write the text
        being continued into the prompt box and the new text into the output.
    """
    return input_text, generate(input_text, *args)


def _continue_from_output(last_output, prompt, *args):
    """Continue from the last output, or from the prompt if nothing was generated yet."""
    # NOTE(review): this assumes the decoded output holds the full text so far
    # (prompt + continuation) — confirm custom_generate returns prompt tokens.
    source = last_output if last_output and last_output.strip() else prompt
    return continue_generate(source, *args)


with gr.Blocks() as demo:
    gr.Markdown(
        """\
# LightNovel-Intro-RetNet-400M-Demo

ライトノベルの冒頭だけを学習した 400M パラメータの RetNet モデルのデモです。

### 参考:

- https://github.com/syncdoth/RetNet
"""
    )

    input_text = gr.Textbox(
        value=DEFAULT_INPUT_TEXT,
        placeholder="私の名前は...",
        lines=2,
    )
    output_text = gr.Textbox(
        value="",
        placeholder="ここに出力が表示されます...",
        lines=8,
        interactive=False,
    )

    with gr.Row():
        generate_btn = gr.Button("Generate ✒️", variant="primary")
        continue_btn = gr.Button("Continue ➡️", variant="secondary")
        clear_btn = gr.ClearButton(
            value="Clear 🧹",
            components=[input_text, output_text],
        )

    with gr.Accordion("Advanced settings", open=False):
        max_tokens = gr.Slider(
            label="Max tokens",
            minimum=8,
            maximum=512,
            value=128,
            step=4,
        )
        do_sample = gr.Checkbox(
            label="Do sample",
            value=True,
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0,
            maximum=2,
            value=1,
            step=0.05,
        )
        top_p = gr.Slider(
            label="Top p",
            minimum=0,
            maximum=1,
            value=0.95,
            step=0.05,
        )
        top_k = gr.Slider(
            label="Top k",
            minimum=0,
            maximum=100,
            value=20,
            step=1,
        )

    gr.Examples(
        examples=EXAMPLE_INPUTS,
        inputs=input_text,
    )

    # Generation settings shared by both buttons, in ``generate``'s argument order.
    generation_settings = [max_tokens, do_sample, temperature, top_p, top_k]

    generate_btn.click(
        fn=generate,
        inputs=[input_text, *generation_settings],
        outputs=output_text,
        queue=False,
    )
    # "Continue" takes the current output as the new prompt, writes it into
    # the prompt box, and shows the continuation in the output box.
    continue_btn.click(
        fn=_continue_from_output,
        inputs=[output_text, input_text, *generation_settings],
        outputs=[input_text, output_text],
        queue=False,
    )


if __name__ == "__main__":
    demo.launch()