File size: 5,080 Bytes
9b79ef5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python

import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DESCRIPTION = """# Nekomata-14B Instruction"""

# The tokenizer and model are only loaded when a CUDA device is present.
# On CPU-only machines `tokenizer` and `model` are never defined (so generation
# cannot run) and the page description gets a warning banner instead.
if torch.cuda.is_available():
    model_id = "rinna/nekomata-14b-instruction"

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # 8-bit weights, automatic device placement, reduced CPU RAM while loading.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        trust_remote_code=True,
        load_in_8bit=True,
        low_cpu_mem_usage=True,
    )
else:
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

# Prompts that tokenize to more than this many tokens are rejected by run()
# before generation starts.
MAX_INPUT_TOKENS = 2048

# Default instruction-following prompt (Japanese). The opening sentence reads
# roughly: "Below is a combination of an instruction describing a task and an
# input providing context. Write a response that appropriately fulfills the
# request." Section headers: 指示 = instruction, 入力 = input, 応答 = response.
# The {instruction} and {input} fields are filled in by create_prompt().
PROMPT_TEMPLATE = """
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。

### 指示:
{instruction}

### 入力:
{input}

### 応答:
"""


def create_prompt(instruction: str, input_text: str, prompt_template: str = PROMPT_TEMPLATE) -> str:
    """Fill *prompt_template* with an instruction and its input text.

    Substitution uses ``str.format``: the template's ``{instruction}`` and
    ``{input}`` fields are replaced, and any other literal braces in the
    template must be doubled (``{{`` / ``}}``). Braces inside the substituted
    values themselves are inserted verbatim.
    """
    fields = {"instruction": instruction, "input": input_text}
    return prompt_template.format(**fields)


@spaces.GPU
@torch.inference_mode()
def run(
    instruction: str,
    input_text: str,
    prompt_template: str = PROMPT_TEMPLATE,
    max_new_tokens: int = 256,
    temperature: float = 0.5,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Stream the model's response to an instruction/input pair.

    The prompt is built with ``create_prompt`` and ``model.generate`` runs in a
    background thread; this generator yields the full response text
    accumulated so far each time the streamer produces a new chunk.

    Args:
        instruction: Text substituted for ``{instruction}`` in the template.
        input_text: Text substituted for ``{input}`` in the template.
        prompt_template: ``str.format`` template with ``{instruction}`` and
            ``{input}`` fields.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; ``0`` (or less) selects greedy
            decoding instead of sampling.
        top_p: Nucleus-sampling threshold (not passed for greedy decoding).
        repetition_penalty: Must be strictly positive; ``1.0`` means no penalty.

    Yields:
        The response generated so far (the prompt itself is skipped).

    Raises:
        gr.Error: If ``repetition_penalty`` is not positive, the prompt is
            longer than ``MAX_INPUT_TOKENS`` tokens, or generation fails.
    """
    # Gradio sliders may deliver whole numbers as ints, while some transformers
    # logits processors validate their argument with isinstance(..., float).
    max_new_tokens = int(max_new_tokens)
    temperature = float(temperature)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    if repetition_penalty <= 0:
        raise gr.Error("Repetition penalty must be greater than 0")

    prompt = create_prompt(instruction, input_text, prompt_template)
    input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    if input_ids.shape[-1] > MAX_INPUT_TOKENS:
        raise gr.Error(f"Input exceeds maximum number of tokens ({MAX_INPUT_TOKENS})")

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids.to(model.device),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Sampling rejects a non-positive temperature, and the UI slider goes
        # down to 0.0, so treat it as a request for greedy decoding.
        generate_kwargs["do_sample"] = False

    errors: list[Exception] = []

    def _generate() -> None:
        # Runs in the worker thread. On failure, record the exception and end
        # the stream so the consumer loop below stops immediately instead of
        # waiting out the streamer timeout.
        try:
            model.generate(**generate_kwargs)
        except Exception as e:
            errors.append(e)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    thread.join()
    if errors:
        raise gr.Error(f"Generation failed: {errors[0]}") from errors[0]


def process_example(instruction: str, input_text: str) -> Iterator[str]:
    """Stream ``run`` output for one ``gr.Examples`` row.

    Only the instruction and input text are forwarded, so every other
    generation setting (prompt template, max new tokens, temperature, top-p,
    repetition penalty) takes ``run``'s default value.

    Yields:
        The partial responses produced by ``run``, unchanged.
    """
    # `yield from` keeps this a generator function (Gradio presumably relies on
    # that to stream example output) and forwards close() to the inner generator.
    yield from run(instruction, input_text)


# Build the Gradio UI. Components are created inside nested layout contexts
# (Blocks -> Row -> Column / Accordion), so the order of statements below
# determines the on-screen layout.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    # Only visible when the SHOW_DUPLICATE_BUTTON environment variable is "1".
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )

    with gr.Row():
        # Left column: user inputs, run button and generation settings.
        with gr.Column():
            instruction = gr.Textbox(label="Instruction", lines=5)
            input_text = gr.Textbox(label="Input", lines=5)
            run_button = gr.Button()

            # Collapsed by default. Each initial value matches the corresponding
            # keyword default of run().
            # NOTE(review): the Temperature and Repetition Penalty sliders allow
            # 0.0; confirm run() handles these, since transformers may reject
            # non-positive values for sampling / repetition penalty.
            with gr.Accordion(label="Advanced Options", open=False):
                prompt_template = gr.Textbox(label="Prompt Template", lines=10, value=PROMPT_TEMPLATE)
                max_new_tokens = gr.Slider(label="Max New Tokens", minimum=1, maximum=1024, step=1, value=256)
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=2.0, step=0.01, value=0.5)
                top_p = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.95)
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty", minimum=0.0, maximum=2.0, step=0.01, value=1.0
                )

        # Right column: the (streamed) model output.
        with gr.Column():
            output = gr.Textbox(label="Output", lines=10)

        # The inputs list is passed positionally, so its order must match run()'s
        # parameters: (instruction, input_text, prompt_template, max_new_tokens,
        # temperature, top_p, repetition_penalty).
        run_button.click(
            fn=run,
            inputs=[instruction, input_text, prompt_template, max_new_tokens, temperature, top_p, repetition_penalty],
            outputs=output,
            api_name="run",
        )

    # Example rows fill only the Instruction and Input boxes and are executed via
    # process_example, i.e. run() with its default settings. Example outputs are
    # cached only when the CACHE_EXAMPLES environment variable is "1".
    gr.Examples(
        examples=[
            # "Please translate the following Japanese into English." + a
            # Japanese encyclopedia-style definition of large language models.
            [
                "次の日本語を英語に翻訳してください。",
                "大規模言語モデル(だいきぼげんごモデル、英: large language model、LLM)は、多数のパラメータ(数千万から数十億)を持つ人工ニューラルネットワークで構成されるコンピュータ言語モデルで、膨大なラベルなしテキストを使用して自己教師あり学習または半教師あり学習によって訓練が行われる。",
            ],
            # "Please provide detailed information on the following topic." +
            # "Tell me what 'yume ochi' (an it-was-all-a-dream ending) is."
            ["以下のトピックに関する詳細な情報を提供してください。", "夢オチとは何かについて教えてください。"],
            # Same instruction + "Tell me about Abarenbō Shōgun."
            ["以下のトピックに関する詳細な情報を提供してください。", "暴れん坊将軍について教えてください。"],
        ],
        inputs=[instruction, input_text],
        outputs=output,
        fn=process_example,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
        api_name=False,
    )

if __name__ == "__main__":
    # Enable request queuing (at most 20 queued events), then start the server.
    queued_app = demo.queue(max_size=20)
    queued_app.launch()