Spaces:
Sleeping
Sleeping
File size: 5,080 Bytes
9b79ef5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
#!/usr/bin/env python
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
DESCRIPTION = """# Nekomata-14B Instruction"""

# The tokenizer and model are only created when a CUDA device is visible.
# On CPU-only hosts they stay undefined and the page header carries a warning instead.
if torch.cuda.is_available():
    model_id = "rinna/nekomata-14b-instruction"
    # trust_remote_code=True permits executing code published in the model repo.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # 8-bit weights, automatic device placement, and reduced CPU RAM while loading.
    # NOTE(review): newer transformers releases prefer
    # quantization_config=BitsAndBytesConfig(load_in_8bit=True) over the bare
    # load_in_8bit kwarg — confirm against the pinned transformers version.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        trust_remote_code=True,
        load_in_8bit=True,
        low_cpu_mem_usage=True,
    )
else:
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
# Prompts that tokenize to more than this many tokens are rejected before generation.
MAX_INPUT_TOKENS = 2048

# Instruction prompt format for rinna/nekomata-14b-instruction, mirroring the
# model card. The empty lines separating the preamble and the
# "### 指示:" / "### 入力:" / "### 応答:" sections are part of that format, so keep
# them when editing. `{instruction}` and `{input}` are filled with str.format.
PROMPT_TEMPLATE = """
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。

### 指示:
{instruction}

### 入力:
{input}

### 応答:
"""
def create_prompt(instruction: str, input_text: str, prompt_template: str = PROMPT_TEMPLATE) -> str:
    """Fill the ``{instruction}`` and ``{input}`` fields of ``prompt_template``.

    Args:
        instruction: Value substituted for ``{instruction}``.
        input_text: Value substituted for ``{input}``.
        prompt_template: ``str.format``-style template; defaults to the
            module-level ``PROMPT_TEMPLATE``.

    Returns:
        The formatted prompt.

    Raises:
        KeyError, IndexError, ValueError: Propagated from ``str.format`` when
            the template has unknown, positional, or malformed fields.
    """
    fields = {"instruction": instruction, "input": input_text}
    return prompt_template.format(**fields)
@spaces.GPU
@torch.inference_mode()
def run(
    instruction: str,
    input_text: str,
    prompt_template: str = PROMPT_TEMPLATE,
    max_new_tokens: int = 256,
    temperature: float = 0.5,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Stream the model's response to ``instruction`` / ``input_text``.

    Builds the prompt from ``prompt_template`` and runs ``model.generate`` on a
    worker thread. After each chunk arrives from the streamer, yields the
    accumulated decoded text.

    Args:
        instruction: Text substituted for ``{instruction}`` in the template.
        input_text: Text substituted for ``{input}`` in the template.
        prompt_template: ``str.format`` template (user-editable in the UI).
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        top_p: Nucleus-sampling threshold (only used when sampling).
        repetition_penalty: Repetition penalty; must be ``> 0``.

    Yields:
        The full response generated so far (not only the newest chunk).

    Raises:
        gr.Error: If the template cannot be formatted, the prompt exceeds
            ``MAX_INPUT_TOKENS`` tokens, ``repetition_penalty`` is not
            positive, or ``model.generate`` fails.
    """
    try:
        prompt = create_prompt(instruction, input_text, prompt_template)
    except (KeyError, IndexError, ValueError) as err:
        # The template comes from a UI textbox; stray or unknown `{...}` fields
        # would otherwise surface as an opaque server error.
        raise gr.Error(f"Invalid prompt template: {err!r}") from err

    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    if input_ids.shape[-1] > MAX_INPUT_TOKENS:
        raise gr.Error(f"Input exceeds maximum number of tokens ({MAX_INPUT_TOKENS})")

    # Sliders can deliver ints. transformers' logits processors require real
    # floats for temperature / repetition_penalty and an int token budget.
    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    repetition_penalty = float(repetition_penalty)
    if repetition_penalty <= 0:
        raise gr.Error("Repetition penalty must be greater than 0.")

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids.to(model.device),
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.pad_token_id,
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Sampling with temperature 0 is rejected by transformers. The UI slider
        # allows 0, so treat it as a request for greedy decoding.
        generate_kwargs["do_sample"] = False

    errors = []

    def _generate() -> None:
        # Runs on the worker thread. On failure, record the exception and close
        # the stream so the consumer loop below stops waiting for tokens
        # instead of hitting the streamer timeout.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # re-raised on the consumer side below
            errors.append(exc)
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    if errors:
        raise gr.Error(f"Generation failed: {errors[0]!r}") from errors[0]
def process_example(instruction: str, input_text: str) -> Iterator[str]:
    """Stream a response for one ``gr.Examples`` row.

    Only the instruction and input are provided; every other generation
    setting falls back to the keyword defaults of ``run``.
    """
    response_stream = run(instruction, input_text)
    yield from response_stream
# Gradio UI: prompt inputs and advanced generation settings in the left column,
# streamed output in the right column, followed by clickable examples.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    # Visible only when the SHOW_DUPLICATE_BUTTON environment variable is "1".
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Row():
        with gr.Column():
            instruction = gr.Textbox(label="Instruction", lines=5)
            input_text = gr.Textbox(label="Input", lines=5)
            run_button = gr.Button()
            # Initial values below match the keyword defaults of run().
            with gr.Accordion(label="Advanced Options", open=False):
                prompt_template = gr.Textbox(label="Prompt Template", lines=10, value=PROMPT_TEMPLATE)
                max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=1024, step=1, value=256)
                # NOTE(review): minimum=0.0 here and for the repetition penalty
                # allows values transformers rejects when sampling (temperature 0
                # with do_sample=True, penalty 0) — confirm run() handles them.
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.01, value=0.5)
                top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.95)
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty", minimum=0.0, maximum=2.0, step=0.01, value=1.0
                )
        with gr.Column():
            output = gr.Textbox(label="Output", lines=10)
    # The order of `inputs` matches run()'s positional parameters.
    run_button.click(
        fn=run,
        inputs=[instruction, input_text, prompt_template, max_new_tokens, temperature, top_p, repetition_penalty],
        outputs=output,
        api_name="run",
    )
    # Examples fill only instruction/input; process_example runs them with run()'s
    # default settings. Results are cached only when CACHE_EXAMPLES is "1", and
    # the examples handler is not exposed as an API endpoint (api_name=False).
    gr.Examples(
        examples=[
            [
                "次の日本語を英語に翻訳してください。",
                "大規模言語モデル(だいきぼげんごモデル、英: large language model、LLM)は、多数のパラメータ(数千万から数十億)を持つ人工ニューラルネットワークで構成されるコンピュータ言語モデルで、膨大なラベルなしテキストを使用して自己教師あり学習または半教師あり学習によって訓練が行われる。",
            ],
            ["以下のトピックに関する詳細な情報を提供してください。", "夢オチとは何かについて教えてください。"],
            ["以下のトピックに関する詳細な情報を提供してください。", "暴れん坊将軍について教えてください。"],
        ],
        inputs=[instruction, input_text],
        outputs=output,
        fn=process_example,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
        api_name=False,
    )
if __name__ == "__main__":
    # Enable Gradio's request queue (capped at 20 pending events), then serve the app.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()
|