import os
import time
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import gradio as gr
from threading import Thread

# Model registry: display name -> Hugging Face repo id
MODELS = {
    "Borea-Phi-3.5-mini-Jp": "AXCXEPT/Borea-Phi-3.5-mini-Instruct-Jp",
    "EZO-Common-9B": "HODACHI/EZO-Common-9B-gemma-2-it",
    "Phi-3.5-mini": "microsoft/Phi-3.5-mini-instruct",
}
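# Other chat-tuned causal LMs from the Hub can be added here by repo id; note that the
# 9B gemma-2 model relies on the 4-bit quantization configured below to fit on a single GPU.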

# Optional token for gated or private repos; forwarded to from_pretrained in load_model.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Title and placeholder for the UI (this demo is presented in Japanese)
TITLE = "<h1><center>Borea/EZO デモアプリ</center></h1>"

PLACEHOLDER = """
<center>
<p>こんにちは、私はAIアシスタントです。何でも質問してください。</p>
</center>
"""

CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
"""

device = "cuda" if torch.cuda.is_available() else "cpu"

# Quantize weights to 4-bit NF4 with double quantization to reduce GPU memory use.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)
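# Rough estimate: NF4 stores weights at ~0.5 bytes per parameter, so the 9B model
# takes on the order of 5 GB of GPU memory versus roughly 18 GB in bf16.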

model = None
tokenizer = None

def load_model(model_name):
    """Load the selected model and tokenizer, replacing any previously loaded pair."""
    global model, tokenizer
    model_path = MODELS[model_name]
    # Release the previous model first so switching models does not exhaust GPU memory.
    if model is not None:
        model = None
        torch.cuda.empty_cache()
    tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        quantization_config=quantization_config,
        token=HF_TOKEN,
    )

# On ZeroGPU Spaces, @spaces.GPU() attaches a GPU to the process for the duration of each call.
@spaces.GPU()
def stream_chat(
    message: str, 
    history: list,
    system_prompt: str,
    temperature: float = 0.8, 
    max_new_tokens: int = 1024, 
    top_p: float = 1.0, 
    top_k: int = 20, 
    repetition_penalty: float = 1.2,
    model_name: str = "Borea-Phi-3.5-mini-Jp"
):
    global model, tokenizer

    # (Re)load when nothing is cached yet or the user switched models in the dropdown.
    if model is None or tokenizer is None or model.name_or_path != MODELS[model_name]:
        load_model(model_name)

    print(f'message: {message}')
    print(f'history: {history}')

    conversation = [
        {"role": "system", "content": system_prompt}
    ]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt}, 
            {"role": "assistant", "content": answer},
        ])

    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
    
    # Yields decoded tokens as they are produced; the prompt and special tokens are excluded.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    
    generate_kwargs = dict(
        input_ids=input_ids, 
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,  # temperature 0 means greedy decoding
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )

    # Run generation in a background thread and stream tokens from it below.
    # model.generate already runs without gradients, and a torch.no_grad() block here
    # would not apply to the worker thread anyway (grad mode is thread-local).
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
        
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
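# The placeholder argument on gr.Chatbot assumes a recent Gradio 4.x release.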

with gr.Blocks(css=CSS, theme='ParityError/Interstellar') as demo:
    gr.HTML(TITLE)
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs=[
            gr.Textbox(
                value="あなたは親切なアシスタントです。",
                label="システムプロンプト",
            ),
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="温度 (Temperature)",
            ),
            gr.Slider(
                minimum=128,
                maximum=8192,
                step=1,
                value=1024,
                label="最大新規トークン数",
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
            ),
            gr.Slider(
                minimum=1.0,
                maximum=2.0,
                step=0.1,
                value=1.2,
                label="繰り返しペナルティ",
            ),
            gr.Dropdown(
                choices=list(MODELS.keys()),
                value="Borea-Phi-3.5-mini-Jp",
                label="モデル選択",
            ),
        ],
        examples=[
            ["語彙の勉強を手伝ってください。空欄を埋めるための文章を書いてください。私は正しい選択肢を選びます。"],
            ["子供のアート作品でできる5つの創造的なことを教えてください。捨てたくはないのですが、散らかってしまいます。"],
            ["ローマ帝国についてのランダムな面白い事実を教えてください。"],
            ["ウェブサイトの固定ヘッダーのCSSとJavaScriptのコードスニペットを見せてください。"],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()
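    # launch() serves on 127.0.0.1:7860 by default; Spaces handles public hosting automatically.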