doropiza commited on
Commit
e1d42ff
·
1 Parent(s): ec0c074
Files changed (2) hide show
  1. app.py +311 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ChatGPT Clone - 日本語対応チャットボット
3
+ Hugging Face Spaces (ZeroGPU) 対応版
4
+
5
+ 使用モデル:
6
+ - elyza/Llama-3-ELYZA-JP-8B
7
+ - Fugaku-LLM/Fugaku-LLM-13B
8
+ - openai/gpt-oss-20b
9
+ """
10
+
11
+ import gradio as gr
12
+ import torch
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+ import os
15
+ from typing import List, Tuple
16
+
17
+ # Hugging Face token from environment variable
18
+ HF_TOKEN = os.getenv("HF_TOKEN")
19
+
20
+ # トークンのチェック
21
+ if not HF_TOKEN:
22
+ print("警告: HF_TOKENが設定されていません。プライベートモデルへのアクセスが制限される場合があります。")
23
+
24
+ # Check if running on ZeroGPU
25
+ try:
26
+ import spaces
27
+ IS_ZEROGPU = True
28
+ print("ZeroGPU環境を検出しました。")
29
+ except ImportError:
30
+ IS_ZEROGPU = False
31
+ print("通常のGPU/CPU環境で実行しています。")
32
+
33
+ class ChatBot:
34
+ def __init__(self):
35
+ self.model = None
36
+ self.tokenizer = None
37
+ self.current_model = None
38
+
39
+ def load_model(self, model_name: str):
40
+ """モデルとトークナイザーをロード"""
41
+ if self.current_model == model_name and self.model is not None:
42
+ return
43
+
44
+ try:
45
+ # メモリクリア
46
+ if self.model is not None:
47
+ del self.model
48
+ del self.tokenizer
49
+ if torch.cuda.is_available():
50
+ torch.cuda.empty_cache()
51
+ torch.cuda.synchronize()
52
+
53
+ # トークナイザーロード
54
+ self.tokenizer = AutoTokenizer.from_pretrained(
55
+ model_name,
56
+ token=HF_TOKEN,
57
+ trust_remote_code=True,
58
+ padding_side="left"
59
+ )
60
+
61
+ # パッドトークンの設定
62
+ if self.tokenizer.pad_token is None:
63
+ self.tokenizer.pad_token = self.tokenizer.eos_token
64
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
65
+
66
+ # モデルロード(ZeroGPU対応)
67
+ self.model = AutoModelForCausalLM.from_pretrained(
68
+ model_name,
69
+ token=HF_TOKEN,
70
+ torch_dtype=torch.float16,
71
+ low_cpu_mem_usage=True,
72
+ trust_remote_code=True,
73
+ load_in_8bit=False, # ZeroGPU環境では8bit量子化は使わない
74
+ device_map=None # ZeroGPU環境では自動マッピングしない
75
+ )
76
+
77
+ self.current_model = model_name
78
+ print(f"モデル {model_name} のロードが完了しました。")
79
+
80
+ except Exception as e:
81
+ print(f"モデルのロード中にエラーが発生しました: {str(e)}")
82
+ raise
83
+
84
+ def _generate_response_gpu(self, message: str, history: List[Tuple[str, str]], model_name: str,
85
+ temperature: float = 0.7, max_tokens: int = 512) -> str:
86
+ """GPU上で応答を生成する実際の処理"""
87
+ # モデルロード
88
+ self.load_model(model_name)
89
+
90
+ # GPUに移動
91
+ self.model.to('cuda')
92
+
93
+ # プロンプト構築
94
+ prompt = self._build_prompt(message, history)
95
+
96
+ # トークナイズ
97
+ inputs = self.tokenizer.encode(prompt, return_tensors="pt").to('cuda')
98
+
99
+ # 生成
100
+ with torch.no_grad():
101
+ outputs = self.model.generate(
102
+ inputs,
103
+ max_new_tokens=max_tokens,
104
+ temperature=temperature,
105
+ do_sample=True,
106
+ top_p=0.95,
107
+ top_k=50,
108
+ repetition_penalty=1.1,
109
+ pad_token_id=self.tokenizer.pad_token_id,
110
+ eos_token_id=self.tokenizer.eos_token_id
111
+ )
112
+
113
+ # デコード
114
+ response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
115
+
116
+ # CPUに戻す(メモリ節約)
117
+ self.model.to('cpu')
118
+ torch.cuda.empty_cache()
119
+ torch.cuda.synchronize()
120
+
121
+ return response.strip()
122
+
123
+ def generate_response(self, message: str, history: List[Tuple[str, str]], model_name: str,
124
+ temperature: float = 0.7, max_tokens: int = 512) -> str:
125
+ """メッセージに対する応答を生成"""
126
+ if IS_ZEROGPU:
127
+ # ZeroGPU環境の場合
128
+ return self._generate_response_gpu(message, history, model_name, temperature, max_tokens)
129
+ else:
130
+ # 通常環境の場合
131
+ self.load_model(model_name)
132
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
133
+
134
+ if device == 'cuda':
135
+ self.model.to(device)
136
+
137
+ prompt = self._build_prompt(message, history)
138
+ inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(device)
139
+
140
+ with torch.no_grad():
141
+ outputs = self.model.generate(
142
+ inputs,
143
+ max_new_tokens=max_tokens,
144
+ temperature=temperature,
145
+ do_sample=True,
146
+ top_p=0.95,
147
+ top_k=50,
148
+ repetition_penalty=1.1,
149
+ pad_token_id=self.tokenizer.pad_token_id,
150
+ eos_token_id=self.tokenizer.eos_token_id
151
+ )
152
+
153
+ response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
154
+ return response.strip()
155
+
156
+ def _build_prompt(self, message: str, history: List[Tuple[str, str]]) -> str:
157
+ """会話履歴からプロンプトを構築"""
158
+ prompt = ""
159
+
160
+ # 履歴を追加(最新3件のみ使用 - メモリ効率のため)
161
+ for user_msg, assistant_msg in history[-3:]:
162
+ prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n\n"
163
+
164
+ # 現在のメッセージを追加
165
+ prompt += f"User: {message}\nAssistant: "
166
+
167
+ return prompt
168
+
169
+ # ChatBotインスタンス
170
+ chatbot = ChatBot()
171
+
172
+ # ZeroGPU環境の場合、GPUデコレータを適用
173
+ if IS_ZEROGPU:
174
+ chatbot._generate_response_gpu = spaces.GPU(duration=120)(chatbot._generate_response_gpu)
175
+
176
+ def respond(message: str, history: List[Tuple[str, str]], model_name: str,
177
+ temperature: float, max_tokens: int) -> Tuple[List[Tuple[str, str]], str]:
178
+ """Gradioのコールバック関数"""
179
+ if not message:
180
+ return history, ""
181
+
182
+ try:
183
+ # 応答生成
184
+ response = chatbot.generate_response(message, history, model_name, temperature, max_tokens)
185
+
186
+ # 履歴に追加
187
+ history.append((message, response))
188
+
189
+ return history, ""
190
+ except RuntimeError as e:
191
+ if "out of memory" in str(e).lower():
192
+ error_msg = "メモリ不足エラー: より小さいモデルを使用するか、最大トークン数を減らしてください。"
193
+ else:
194
+ error_msg = f"実行時エラー: {str(e)}"
195
+ history.append((message, error_msg))
196
+ return history, ""
197
+ except Exception as e:
198
+ error_msg = f"エラーが発生しました: {str(e)}"
199
+ history.append((message, error_msg))
200
+ return history, ""
201
+
202
+ def clear_chat() -> Tuple[List, str]:
203
+ """チャット履歴をクリア"""
204
+ return [], ""
205
+
206
+ # Gradio UI
207
+ with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as app:
208
+ gr.Markdown("# 🤖 ChatGPT Clone")
209
+ gr.Markdown("""
210
+ 日本語対応のLLMを使用したチャットボットです。
211
+
212
+ **使用可能モデル:**
213
+ - [elyza/Llama-3-ELYZA-JP-8B](https://huggingface.co/elyza/Llama-3-ELYZA-JP-8B)
214
+ - [Fugaku-LLM/Fugaku-LLM-13B](https://huggingface.co/Fugaku-LLM/Fugaku-LLM-13B)
215
+ - [openai/gpt-oss-20b](https://huggingface.co/openai/gpt-oss-20b) - OpenAIの最新オープンウェイト推論モデル
216
+ """)
217
+
218
+ with gr.Row():
219
+ with gr.Column(scale=3):
220
+ chatbot_ui = gr.Chatbot(
221
+ label="Chat",
222
+ height=500,
223
+ show_label=False,
224
+ container=True
225
+ )
226
+
227
+ with gr.Row():
228
+ msg_input = gr.Textbox(
229
+ label="メッセージを入力",
230
+ placeholder="ここにメッセージを入力してください...",
231
+ lines=2,
232
+ scale=4,
233
+ show_label=False
234
+ )
235
+ send_btn = gr.Button("送信", variant="primary", scale=1)
236
+
237
+ with gr.Row():
238
+ clear_btn = gr.Button("🗑️ 新しい会話", variant="secondary")
239
+
240
+ with gr.Column(scale=1):
241
+ model_select = gr.Dropdown(
242
+ choices=[
243
+ "elyza/Llama-3-ELYZA-JP-8B",
244
+ "Fugaku-LLM/Fugaku-LLM-13B",
245
+ "openai/gpt-oss-20b",
246
+ ],
247
+ value="elyza/Llama-3-ELYZA-JP-8B",
248
+ label="モデル選択",
249
+ interactive=True
250
+ )
251
+
252
+ temperature = gr.Slider(
253
+ minimum=0.1,
254
+ maximum=1.0,
255
+ value=0.7,
256
+ step=0.1,
257
+ label="Temperature",
258
+ info="生成の創造性を調整"
259
+ )
260
+
261
+ max_tokens = gr.Slider(
262
+ minimum=64,
263
+ maximum=512,
264
+ value=256,
265
+ step=64,
266
+ label="最大トークン数",
267
+ info="生成する最大トークン数"
268
+ )
269
+
270
+ gr.Markdown("""
271
+ ### 使い方
272
+ 1. モデルを選択
273
+ 2. メッセージを入力
274
+ 3. 送信ボタンをクリック
275
+
276
+ ### 注意事項
277
+ - 初回のモデル読み込みには時間がかかります
278
+ - ZeroGPU使用により高速推論が可能
279
+ - 1回の生成は120秒以内に完了します
280
+ - 大きなモデル使用時は、短めの応答になる場合があります
281
+ - gpt-oss-20bは推論専用モデルです
282
+ """)
283
+
284
+ # イベントハンドラ
285
+ msg_input.submit(
286
+ fn=respond,
287
+ inputs=[msg_input, chatbot_ui, model_select, temperature, max_tokens],
288
+ outputs=[chatbot_ui, msg_input]
289
+ )
290
+
291
+ send_btn.click(
292
+ fn=respond,
293
+ inputs=[msg_input, chatbot_ui, model_select, temperature, max_tokens],
294
+ outputs=[chatbot_ui, msg_input]
295
+ )
296
+
297
+ clear_btn.click(
298
+ fn=clear_chat,
299
+ outputs=[chatbot_ui, msg_input]
300
+ )
301
+
302
+ if __name__ == "__main__":
303
+ # Hugging Face Spaces環境かどうかを確認
304
+ is_hf_spaces = os.getenv("SPACE_ID") is not None
305
+
306
+ app.launch(
307
+ share=False,
308
+ show_error=True,
309
+ server_name="0.0.0.0" if is_hf_spaces else "127.0.0.1",
310
+ server_port=7860
311
+ )
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==4.19.2
2
+ transformers==4.38.2
3
+ torch==2.2.0
4
+ accelerate==0.27.2
5
+ sentencepiece==0.2.0
6
+ protobuf==4.25.3