p1atdev's picture
feat: create demo app
617aba5
raw
history blame
No virus
3.56 kB
import torch
from RetNet.retnet.modeling_retnet import RetNetForCausalLM
from transformers import AutoTokenizer
import gradio as gr
# Repository id passed to both tokenizer and model `from_pretrained` calls.
MODEL_NAME = "p1atdev/LightNovel-Intro-RetNet-400M"

# Prompt pre-filled in the input box ("When I woke up, ...").
DEFAULT_INPUT_TEXT = "目が覚めると、"

# Clickable example prompts shown in the UI; the default prompt comes first.
EXAMPLE_INPUTS = [DEFAULT_INPUT_TEXT] + [
    "冒険者ギルドには",  # "At the adventurers' guild, ..."
    "真っ白い部屋の中、そこには",  # "Inside a pure-white room, there ..."
    "20XX年、",  # "In the year 20XX, ..."
    "「なんだって!?」",  # "\"What did you say!?\""
    "どうやらトラックにはねられ、俺は",  # "Apparently I got hit by a truck, and I ..."
]
# Load the tokenizer and the model once, at import time; every request made
# through the UI reuses these same objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Inference only: eval() disables training-mode behaviour such as dropout.
# nn.Module.eval() returns the module itself, so it can be chained here
# (assumes RetNetForCausalLM is a torch nn.Module subclass).
model = RetNetForCausalLM.from_pretrained(MODEL_NAME).eval()
@torch.no_grad()
def generate(
    input_text,
    max_new_tokens=128,
    do_sample=True,
    temperature=1.0,
    top_p=0.95,
    top_k=20,
):
    """Generate a continuation of ``input_text`` with the RetNet model.

    Runs under ``torch.no_grad()`` because this is inference only.

    Args:
        input_text: Prompt text. ``None``, empty, or whitespace-only input
            short-circuits and returns ``""`` without touching the model.
        max_new_tokens: Maximum number of new tokens to generate. Coerced to
            ``int`` because UI sliders may deliver numeric values as floats.
        do_sample: Sample from the distribution when true, otherwise decode
            greedily.
        temperature: Sampling temperature. Temperature scaling conventionally
            divides the logits by this value, so a value ``<= 0`` (the UI
            slider allows 0) is not usable for sampling. In that case greedy
            decoding is used instead.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens to sample from. Coerced
            to ``int``.

    Returns:
        The first sequence returned by ``model.custom_generate``, decoded with
        ``tokenizer.batch_decode``. Special tokens are not stripped.
    """
    # Guard before calling .strip(): Gradio can hand us None for a cleared box.
    if not input_text or not input_text.strip():
        return ""

    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)
    if temperature <= 0:
        # Treat "temperature 0" as greedy decoding. Keep a neutral temperature
        # in case the underlying implementation still scales logits when not
        # sampling (NOTE(review): confirm against custom_generate).
        do_sample = False
        temperature = 1.0

    # Don't let the tokenizer add special tokens (e.g. BOS) to the raw prompt.
    inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
    generated = model.custom_generate(
        **inputs,
        # presumably encodes the whole prompt in RetNet's parallel form before
        # decoding token by token — TODO confirm against the RetNet repo
        parallel_compute_prompt=True,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )
    # A single prompt was encoded, so the first decoded sequence is the result.
    return tokenizer.batch_decode(generated)[0]
def continue_generate(input_text, *args):
    """Return ``(input_text, generate(input_text, *args))``.

    Shaped as a two-output event handler: the text itself is echoed back as
    the first element and the fresh generation is the second. ``args`` are
    forwarded positionally to ``generate`` (max tokens, sampling flag,
    temperature, top-p, top-k).
    """
    continuation = generate(input_text, *args)
    return (input_text, continuation)
def _continue_from_output(previous_output, prompt, *args):
    """Handler for the "Continue" button: generate more from the last output.

    The current output box becomes the new prompt, so repeated clicks keep
    extending the text. This assumes ``generate`` returns the prompt together
    with its continuation — confirm. If nothing has been generated yet, the
    prompt box is used instead, which matches the behaviour of "Generate".
    ``args`` are the sampling settings, forwarded unchanged.
    """
    if previous_output and previous_output.strip():
        source = previous_output
    else:
        source = prompt
    return continue_generate(source, *args)


with gr.Blocks() as demo:
    # Header (Japanese): "A demo of a 400M-parameter RetNet model trained only
    # on the openings of light novels." plus a link to the RetNet implementation.
    gr.Markdown(
        """\
# LightNovel-Intro-RetNet-400M-Demo
ライトノベルの冒頭だけを学習した 400M パラメータの RetNet モデルのデモです。
### 参考:
- https://github.com/syncdoth/RetNet
"""
    )
    # Prompt box (placeholder: "My name is...").
    input_text = gr.Textbox(
        value=DEFAULT_INPUT_TEXT,
        placeholder="私の名前は...",
        lines=2,
    )
    # Read-only result box (placeholder: "The output will appear here...").
    output_text = gr.Textbox(
        value="",
        placeholder="ここに出力が表示されます...",
        lines=8,
        interactive=False,
    )
    with gr.Row():
        generate_btn = gr.Button("Generate ✒️", variant="primary")
        continue_btn = gr.Button("Continue ➡️", variant="secondary")
        clear_btn = gr.ClearButton(
            value="Clear 🧹",
            components=[input_text, output_text],
        )
    with gr.Accordion("Advanced settings", open=False):
        max_tokens = gr.Slider(
            label="Max tokens",
            minimum=8,
            maximum=512,
            value=128,
            step=4,
        )
        do_sample = gr.Checkbox(
            label="Do sample",
            value=True,
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0,
            maximum=2,
            value=1,
            step=0.05,
        )
        top_p = gr.Slider(
            label="Top p",
            minimum=0,
            maximum=1,
            value=0.95,
            step=0.05,
        )
        top_k = gr.Slider(
            label="Top k",
            minimum=0,
            maximum=100,
            value=20,
            step=1,
        )
    gr.Examples(
        examples=EXAMPLE_INPUTS,
        inputs=input_text,
    )

    # Sampling settings shared by both buttons, in generate()'s positional order.
    sampling_inputs = [max_tokens, do_sample, temperature, top_p, top_k]

    generate_btn.click(
        fn=generate,
        inputs=[input_text, *sampling_inputs],
        outputs=output_text,
        queue=False,
    )
    # Continue from what was last generated (not from the prompt box, which
    # would make this button identical to "Generate"). The previous output is
    # moved into the prompt box and the new text replaces the output box.
    continue_btn.click(
        fn=_continue_from_output,
        inputs=[output_text, input_text, *sampling_inputs],
        outputs=[input_text, output_text],
        queue=False,
    )
demo.launch()