Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
# Markdown header rendered at the top of the demo page.
DESCRIPTION = """# Nekomata-14B Instruction"""

if torch.cuda.is_available():
    model_id = "rinna/nekomata-14b-instruction"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # 8-bit bitsandbytes quantization. Passing `load_in_8bit=True` straight to
    # from_pretrained is deprecated; the same setting goes through a
    # quantization config object instead.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        trust_remote_code=True,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        low_cpu_mem_usage=True,
    )
else:
    # No GPU: the model is not loaded at all, so warn visitors on the page.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
# Prompts longer than this many tokens are rejected before generation.
MAX_INPUT_TOKENS = 2048

# Instruction / input / response prompt for rinna/nekomata-14b-instruction.
# The leading newline and the blank line before each "###" header follow the
# prompt format shown on the model card.
PROMPT_TEMPLATE = """
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。

### 指示:
{instruction}

### 入力:
{input}

### 応答:
"""


def create_prompt(instruction: str, input_text: str, prompt_template: str = PROMPT_TEMPLATE) -> str:
    """Fill *prompt_template* with the user's instruction and input text.

    Args:
        instruction: Text substituted for the ``{instruction}`` field.
        input_text: Text substituted for the ``{input}`` field.
        prompt_template: A ``str.format`` template. It may reference only the
            ``{instruction}`` and ``{input}`` fields; literal braces must be
            doubled (``{{`` / ``}}``).

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If the template contains unknown or positional fields,
            invalid field expressions, or unbalanced braces.
    """
    try:
        return prompt_template.format(instruction=instruction, input=input_text)
    except (KeyError, IndexError, AttributeError, TypeError, ValueError) as err:
        # The template is editable from the UI, so turn the various str.format
        # failures into one actionable error.
        raise ValueError(
            "Invalid prompt template: use only the {instruction} and {input} "
            "placeholders and write literal braces as {{ and }}."
        ) from err
def run(
    instruction: str,
    input_text: str,
    prompt_template: str = PROMPT_TEMPLATE,
    max_new_tokens: int = 256,
    temperature: float = 0.5,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Generate a response for *instruction* / *input_text*, streaming partial output.

    The prompt is built from *prompt_template*, tokenized, and handed to
    ``model.generate`` on a background thread; decoded text is read back through
    a ``TextIteratorStreamer``.

    Args:
        instruction: Task description inserted into the ``{instruction}`` field.
        input_text: Context inserted into the ``{input}`` field.
        prompt_template: Template containing ``{instruction}`` and ``{input}``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.
        top_p: Nucleus-sampling threshold (unused for greedy decoding).
        repetition_penalty: Must be > 0; 1.0 means no penalty.

    Yields:
        The accumulated response text after each streamed chunk.

    Raises:
        gr.Error: If the template is invalid, the prompt exceeds
            ``MAX_INPUT_TOKENS`` tokens, or *repetition_penalty* is not positive.
        Exception: Any error raised by ``model.generate`` is re-raised here
            rather than surfacing later as a streamer timeout.
    """
    # UI sliders may deliver ints (e.g. 2 instead of 2.0); transformers' logits
    # processors type-check for float.
    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    if repetition_penalty <= 0:
        raise gr.Error("Repetition penalty must be greater than 0.")

    try:
        prompt = create_prompt(instruction, input_text, prompt_template)
    except ValueError as err:
        # Show template mistakes to the UI user instead of a generic error.
        raise gr.Error(str(err)) from err
    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    if input_ids.shape[-1] > MAX_INPUT_TOKENS:
        raise gr.Error(f"Input exceeds maximum number of tokens ({MAX_INPUT_TOKENS})")

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    # Sampling with temperature 0 is rejected by transformers; treat it as greedy.
    do_sample = temperature > 0
    generate_kwargs = dict(
        input_ids=input_ids.to(model.device),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        generate_kwargs.update(temperature=temperature, top_p=top_p)

    errors: list[Exception] = []

    def _generate() -> None:
        """Run generation; on failure record the error and unblock the streamer."""
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # re-raised on the consuming side below
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    thread.join()
    if errors:
        raise errors[0]
def process_example(instruction: str, input_text: str) -> Iterator[str]:
    """Stream the model's answer for one ``gr.Examples`` row.

    Only the instruction and input are supplied; every other generation option
    keeps ``run``'s default value.
    """
    stream = run(instruction=instruction, input_text=input_text)
    yield from stream
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    # Visible only when the SHOW_DUPLICATE_BUTTON environment variable is "1".
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Row():
        with gr.Column():
            instruction = gr.Textbox(label="Instruction", lines=5)
            input_text = gr.Textbox(label="Input", lines=5)
            run_button = gr.Button()
            with gr.Accordion(label="Advanced Options", open=False):
                prompt_template = gr.Textbox(label="Prompt Template", lines=10, value=PROMPT_TEMPLATE)
                max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=1024, step=1, value=256)
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.01, value=0.5)
                top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.95)
                # transformers requires a strictly positive repetition penalty,
                # so the slider must not be able to reach 0.
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty", minimum=0.01, maximum=2.0, step=0.01, value=1.0
                )
        with gr.Column():
            output = gr.Textbox(label="Output", lines=10)

    # Stream run()'s partial outputs into the Output box; also exposed as the "run" API endpoint.
    run_button.click(
        fn=run,
        inputs=[instruction, input_text, prompt_template, max_new_tokens, temperature, top_p, repetition_penalty],
        outputs=output,
        api_name="run",
    )
    # Example rows fill only Instruction/Input. They are cached only when
    # CACHE_EXAMPLES=1 and are not exposed as an API endpoint.
    gr.Examples(
        examples=[
            [
                "次の日本語を英語に翻訳してください。",
                "大規模言語モデル(だいきぼげんごモデル、英: large language model、LLM)は、多数のパラメータ(数千万から数十億)を持つ人工ニューラルネットワークで構成されるコンピュータ言語モデルで、膨大なラベルなしテキストを使用して自己教師あり学習または半教師あり学習によって訓練が行われる。",
            ],
            ["以下のトピックに関する詳細な情報を提供してください。", "夢オチとは何かについて教えてください。"],
            ["以下のトピックに関する詳細な情報を提供してください。", "暴れん坊将軍について教えてください。"],
        ],
        inputs=[instruction, input_text],
        outputs=output,
        fn=process_example,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
        api_name=False,
    )
if __name__ == "__main__":
    # Enable the request queue (at most 20 pending requests), then serve the app.
    app = demo.queue(max_size=20)
    app.launch()