# Spaces: (HuggingFace Spaces page status at scrape time: "Runtime error")
import gradio as gr
import torch
from transformers import AutoTokenizer

from RetNet.retnet.modeling_retnet import RetNetForCausalLM

# Hugging Face Hub id of the checkpoint served by this demo.
MODEL_NAME = "isek-ai/LightNovel-Intro-RetNet-400M"

# Prompt pre-filled in the input box when the page loads.
DEFAULT_INPUT_TEXT = "目が覚めると、"

# Clickable example prompts shown below the UI (light-novel opening lines).
EXAMPLE_INPUTS = [
    DEFAULT_INPUT_TEXT,
    "冒険者ギルドには",
    "真っ白い部屋の中、そこには",
    "20XX年、",
    "「なんだって!?」",
    "どうやらトラックにはねられ、俺は",
]

# Load tokenizer and model once at import time; the model is used for
# inference only, so switch it to eval mode (disables dropout etc.).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = RetNetForCausalLM.from_pretrained(MODEL_NAME)
model.eval()
def generate(
    input_text,
    max_new_tokens=128,
    do_sample=True,
    temperature=1.0,
    top_p=0.95,
    top_k=20,
    no_repeat_ngram_size=3,
    repetition_penalty=1.2,
    num_beams=1,
):
    """Generate a continuation of ``input_text`` with the RetNet model.

    Args:
        input_text: Prompt to continue. A whitespace-only prompt returns "".
        max_new_tokens: Maximum number of tokens to append to the prompt.
        do_sample: Sample from the output distribution instead of greedy decoding.
        temperature: Sampling temperature (higher = more random).
        top_p: Nucleus-sampling cumulative-probability cutoff.
        top_k: Keep only the ``k`` most likely tokens when sampling.
        no_repeat_ngram_size: Forbid repeating n-grams of this size (0 disables).
        repetition_penalty: Penalty applied to already-generated tokens.
        num_beams: Beam-search width (1 = no beam search).

    Returns:
        The decoded text (prompt plus continuation) with special tokens removed.
    """
    # Empty/whitespace prompt: nothing to condition on, bail out early.
    if input_text.strip() == "":
        return ""
    # add_special_tokens=False: the prompt is mid-text, so no BOS/EOS markers.
    inputs = tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
    # Inference only — skip autograd bookkeeping (model.eval() alone does not).
    with torch.inference_mode():
        generated = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            no_repeat_ngram_size=no_repeat_ngram_size,
            repetition_penalty=repetition_penalty,
            num_beams=num_beams,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Single prompt in, single sequence out: decode batch element 0.
    return tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
def continue_generate(
    input_text,
    *args,
):
    """Continue generation, feeding the previous output back in as the prompt.

    Args:
        input_text: Text to use as the new prompt (wired to the output box
            in the UI, so pressing "Continue" extends the last result).
        *args: Generation settings forwarded unchanged to :func:`generate`.

    Returns:
        A ``(prompt, continuation)`` pair: the prompt is echoed so the input
        box reflects what was actually used, and the second element is the
        newly generated text.
    """
    return input_text, generate(input_text, *args)
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(
        """\
# LightNovel-Intro-RetNet-400M-Demo
ライトノベルの冒頭だけを学習した 400M パラメータの RetNet モデルのデモです。
モデル: https://huggingface.co/p1atdev/LightNovel-Intro-RetNet-400M
### 参考:
- https://github.com/syncdoth/RetNet
"""
    )
    # Prompt box (editable) and result box (read-only).
    input_text = gr.Textbox(
        label="Input text",
        value=DEFAULT_INPUT_TEXT,
        placeholder="私の名前は...",
        lines=2,
    )
    output_text = gr.Textbox(
        label="Output text",
        value="",
        placeholder="ここに出力が表示されます...",
        lines=8,
        interactive=False,
    )
    with gr.Row():
        generate_btn = gr.Button("Generate ✒️", variant="primary")
        continue_btn = gr.Button("Continue ➡️", variant="secondary")
        clear_btn = gr.ClearButton(
            value="Clear 🧹",
            components=[input_text, output_text],
        )
    # Generation hyperparameters; defaults mirror generate()'s signature
    # except max tokens, which is capped lower for the demo.
    with gr.Accordion("Advanced settings", open=False):
        max_tokens = gr.Slider(
            label="Max tokens",
            minimum=8,
            maximum=512,
            value=64,
            step=4,
        )
        do_sample = gr.Checkbox(
            label="Do sample",
            value=True,
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0,
            maximum=2,
            value=1,
            step=0.05,
        )
        top_p = gr.Slider(
            label="Top p",
            minimum=0,
            maximum=1,
            value=0.95,
            step=0.05,
        )
        top_k = gr.Slider(
            label="Top k",
            minimum=0,
            maximum=100,
            value=20,
            step=1,
        )
        no_repeat_ngram_size = gr.Slider(
            label="No repeat ngram size",
            minimum=0,
            maximum=10,
            value=3,
            step=1,
        )
        repetition_penalty = gr.Slider(
            label="Repetition penalty",
            minimum=0,
            maximum=2,
            value=1.2,
            step=0.1,
        )
        num_beams = gr.Slider(
            label="Num beams",
            minimum=1,
            maximum=8,
            value=1,
            step=1,
        )
    gr.Examples(
        examples=EXAMPLE_INPUTS,
        inputs=input_text,
    )
    # "Generate" runs the prompt from the input box.
    generate_btn.click(
        fn=generate,
        inputs=[
            input_text,
            max_tokens,
            do_sample,
            temperature,
            top_p,
            top_k,
            no_repeat_ngram_size,
            repetition_penalty,
            num_beams,
        ],
        outputs=output_text,
        queue=False,
    )
    # "Continue" re-runs generation on the previous OUTPUT, updating both boxes.
    continue_btn.click(
        fn=continue_generate,
        inputs=[
            output_text,
            max_tokens,
            do_sample,
            temperature,
            top_p,
            top_k,
            no_repeat_ngram_size,
            repetition_penalty,
            num_beams,
        ],
        outputs=[input_text, output_text],
        queue=False,
    )

demo.launch()