"""Gradio demo for the isek-ai LightNovel-Intro-RetNet-400M causal language model.

Loads the tokenizer and RetNet model once at module import, then builds a
simple two-button UI: "Generate" runs generation from the input box, and
"Continue" feeds the previous output back in as the new prompt.
"""

import torch
import gradio as gr
from transformers import AutoTokenizer

from RetNet.retnet.modeling_retnet import RetNetForCausalLM

MODEL_NAME = "isek-ai/LightNovel-Intro-RetNet-400M"
DEFAULT_INPUT_TEXT = "目が覚めると、"
EXAMPLE_INPUTS = [
    DEFAULT_INPUT_TEXT,
    "冒険者ギルドには",
    "真っ白い部屋の中、そこには",
    "20XX年、",
    "「なんだって!?」",
    "どうやらトラックにはねられ、俺は",
]

# Loaded once at import time; eval() disables dropout etc. for inference.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = RetNetForCausalLM.from_pretrained(MODEL_NAME)
model.eval()


@torch.no_grad()
def generate(
    input_text,
    max_new_tokens=128,
    do_sample=True,
    temperature=1.0,
    top_p=0.95,
    top_k=20,
    no_repeat_ngram_size=3,
    repetition_penalty=1.2,
    num_beams=1,
):
    """Generate a continuation of ``input_text`` with the RetNet model.

    Args:
        input_text: Prompt string; a blank/whitespace-only prompt yields "".
        max_new_tokens: Maximum number of tokens to append to the prompt.
        do_sample: Sample from the distribution instead of greedy decoding.
        temperature: Softmax temperature used when sampling.
        top_p: Nucleus-sampling probability mass cutoff.
        top_k: Top-k sampling cutoff.
        no_repeat_ngram_size: Forbid repeating n-grams of this size.
        repetition_penalty: Penalty factor applied to already-seen tokens.
        num_beams: Beam-search width (1 = no beam search).

    Returns:
        The decoded text (prompt + continuation) with special tokens removed.
    """
    # Guard clause: nothing to generate from an empty prompt.
    if not input_text.strip():
        return ""
    # add_special_tokens=False so the prompt is continued verbatim, without
    # an injected BOS/EOS token in the middle of the text.
    inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
    generated = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
        eos_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]


def continue_generate(
    input_text,
    *args,
):
    """Continue generation from ``input_text`` (the previous output).

    Returns a 2-tuple ``(input_text, generated)`` so the UI can move the old
    output into the input box and show the new continuation as output.
    """
    return input_text, generate(input_text, *args)


with gr.Blocks() as demo:
    gr.Markdown(
        """\
# LightNovel-Intro-RetNet-400M-Demo

ライトノベルの冒頭だけを学習した 400M パラメータの RetNet モデルのデモです。

モデル: https://huggingface.co/isek-ai/LightNovel-Intro-RetNet-400M

### 参考:

- https://github.com/syncdoth/RetNet
"""
    )
    input_text = gr.Textbox(
        label="Input text",
        value=DEFAULT_INPUT_TEXT,
        placeholder="私の名前は...",
        lines=2,
    )
    output_text = gr.Textbox(
        label="Output text",
        value="",
        placeholder="ここに出力が表示されます...",
        lines=8,
        interactive=False,
    )
    with gr.Row():
        generate_btn = gr.Button("Generate ✒️", variant="primary")
        continue_btn = gr.Button("Continue ➡️", variant="secondary")
        clear_btn = gr.ClearButton(
            value="Clear 🧹",
            components=[input_text, output_text],
        )
    # Sampling knobs; defaults mirror generate()'s behavior except
    # max_tokens, which the UI caps lower (64) than the function default.
    with gr.Accordion("Advanced settings", open=False):
        max_tokens = gr.Slider(
            label="Max tokens",
            minimum=8,
            maximum=512,
            value=64,
            step=4,
        )
        do_sample = gr.Checkbox(
            label="Do sample",
            value=True,
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0,
            maximum=2,
            value=1,
            step=0.05,
        )
        top_p = gr.Slider(
            label="Top p",
            minimum=0,
            maximum=1,
            value=0.95,
            step=0.05,
        )
        top_k = gr.Slider(
            label="Top k",
            minimum=0,
            maximum=100,
            value=20,
            step=1,
        )
        no_repeat_ngram_size = gr.Slider(
            label="No repeat ngram size",
            minimum=0,
            maximum=10,
            value=3,
            step=1,
        )
        repetition_penalty = gr.Slider(
            label="Repetition penalty",
            minimum=0,
            maximum=2,
            value=1.2,
            step=0.1,
        )
        num_beams = gr.Slider(
            label="Num beams",
            minimum=1,
            maximum=8,
            value=1,
            step=1,
        )
    gr.Examples(
        examples=EXAMPLE_INPUTS,
        inputs=input_text,
    )
    generate_btn.click(
        fn=generate,
        inputs=[
            input_text,
            max_tokens,
            do_sample,
            temperature,
            top_p,
            top_k,
            no_repeat_ngram_size,
            repetition_penalty,
            num_beams,
        ],
        outputs=output_text,
        queue=False,
    )
    # "Continue" reads from output_text: the previous generation becomes
    # the new prompt, and continue_generate echoes it back into input_text.
    continue_btn.click(
        fn=continue_generate,
        inputs=[
            output_text,
            max_tokens,
            do_sample,
            temperature,
            top_p,
            top_k,
            no_repeat_ngram_size,
            repetition_penalty,
            num_beams,
        ],
        outputs=[input_text, output_text],
        queue=False,
    )

demo.launch()