# app.py — LightNovel-Intro-RetNet-400M demo (author: p1atdev, commit ed17f60)
import torch
from RetNet.retnet.modeling_retnet import RetNetForCausalLM
from transformers import AutoTokenizer
import gradio as gr
# Hugging Face Hub repository ID of the model served by this demo.
MODEL_NAME = "isek-ai/LightNovel-Intro-RetNet-400M"
# Prompt pre-filled in the input textbox at startup.
DEFAULT_INPUT_TEXT = "目が覚めると、"
# Clickable example prompts shown in the UI (typical light-novel openers).
EXAMPLE_INPUTS = [
    DEFAULT_INPUT_TEXT,
    "冒険者ギルドには",
    "真っ白い部屋の中、そこには",
    "20XX年、",
    "「なんだって!?」",
    "どうやらトラックにはねられ、俺は",
]
# Load the tokenizer and model weights from the Hub once at import time
# (downloads on first run). Inference-only demo, so switch the model to
# eval mode to disable train-time behavior such as dropout.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = RetNetForCausalLM.from_pretrained(MODEL_NAME)
model.eval()
@torch.no_grad()  # inference only: skip autograd bookkeeping
def generate(
    input_text,
    max_new_tokens=128,
    do_sample=True,
    temperature=1.0,
    top_p=0.95,
    top_k=20,
    no_repeat_ngram_size=3,
    repetition_penalty=1.2,
    num_beams=1,
):
    """Generate a continuation of ``input_text`` with the RetNet model.

    Args:
        input_text: Prompt string; blank/whitespace-only input short-circuits.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Sample from the distribution instead of greedy decoding.
        temperature: Softmax temperature for sampling.
        top_p: Nucleus-sampling probability mass.
        top_k: Keep only the k most likely tokens when sampling.
        no_repeat_ngram_size: Forbid repeating n-grams of this size (0 = off).
        repetition_penalty: Values > 1.0 discourage repeating earlier tokens.
        num_beams: Beam-search width (1 = no beam search).

    Returns:
        The decoded text (prompt included), or "" for blank input.
    """
    # Guard: nothing to condition on — return empty instead of sampling noise.
    if not input_text.strip():
        return ""
    # add_special_tokens=False keeps the prompt exactly as typed (no BOS/EOS).
    inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
    generated = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
        eos_token_id=tokenizer.eos_token_id,  # stop early at end-of-sequence
    )
    # Batch size is always 1 here, so take the first decoded sequence.
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
def continue_generate(
    input_text,
    *gen_args,
):
    """Continue generation, using the previous output as the new prompt.

    Returns a ``(prompt, generation)`` pair so the UI can move the old
    output into the input box and show the fresh continuation beside it.
    """
    continuation = generate(input_text, *gen_args)
    return input_text, continuation
# ---- Gradio UI (components render in declaration order) ----
with gr.Blocks() as demo:
    # Header and model description, rendered as Markdown.
    gr.Markdown(
        """\
# LightNovel-Intro-RetNet-400M-Demo
ライトノベルの冒頭だけを学習した 400M パラメータの RetNet モデルのデモです。
モデル: https://huggingface.co/isek-ai/LightNovel-Intro-RetNet-400M
### 参考:
- https://github.com/syncdoth/RetNet
"""
    )
    # Prompt entry box, pre-filled with the default opener.
    input_text = gr.Textbox(
        label="Input text",
        value=DEFAULT_INPUT_TEXT,
        placeholder="私の名前は...",
        lines=2,
    )
    # Read-only box that receives the generated text.
    output_text = gr.Textbox(
        label="Output text",
        value="",
        placeholder="ここに出力が表示されます...",
        lines=8,
        interactive=False,
    )
    # Action buttons: generate from the prompt, continue from the last
    # output, or clear both textboxes.
    with gr.Row():
        generate_btn = gr.Button("Generate ✒️", variant="primary")
        continue_btn = gr.Button("Continue ➡️", variant="secondary")
        clear_btn = gr.ClearButton(
            value="Clear 🧹",
            components=[input_text, output_text],
        )
    # Decoding-parameter sliders, collapsed by default. Their values are
    # passed positionally to generate(), so their order here must match
    # the parameter order of generate() after input_text.
    with gr.Accordion("Advanced settings", open=False):
        max_tokens = gr.Slider(
            label="Max tokens",
            minimum=8,
            maximum=512,
            value=64,
            step=4,
        )
        do_sample = gr.Checkbox(
            label="Do sample",
            value=True,
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0,
            maximum=2,
            value=1,
            step=0.05,
        )
        top_p = gr.Slider(
            label="Top p",
            minimum=0,
            maximum=1,
            value=0.95,
            step=0.05,
        )
        top_k = gr.Slider(
            label="Top k",
            minimum=0,
            maximum=100,
            value=20,
            step=1,
        )
        no_repeat_ngram_size = gr.Slider(
            label="No repeat ngram size",
            minimum=0,
            maximum=10,
            value=3,
            step=1,
        )
        repetition_penalty = gr.Slider(
            label="Repetition penalty",
            minimum=0,
            maximum=2,
            value=1.2,
            step=0.1,
        )
        num_beams = gr.Slider(
            label="Num beams",
            minimum=1,
            maximum=8,
            value=1,
            step=1,
        )
    # One-click example prompts that populate the input textbox.
    gr.Examples(
        examples=EXAMPLE_INPUTS,
        inputs=input_text,
    )
    # Generate: prompt + settings in, generated text out.
    # queue=False runs the handler synchronously (no request queue).
    generate_btn.click(
        fn=generate,
        inputs=[
            input_text,
            max_tokens,
            do_sample,
            temperature,
            top_p,
            top_k,
            no_repeat_ngram_size,
            repetition_penalty,
            num_beams,
        ],
        outputs=output_text,
        queue=False,
    )
    # Continue: note the first input is output_text — the previous output
    # becomes the new prompt, and both textboxes are updated.
    continue_btn.click(
        fn=continue_generate,
        inputs=[
            output_text,
            max_tokens,
            do_sample,
            temperature,
            top_p,
            top_k,
            no_repeat_ngram_size,
            repetition_penalty,
            num_beams,
        ],
        outputs=[input_text, output_text],
        queue=False,
    )
# Start the Gradio server (blocks until shutdown).
demo.launch()