chatdemo2 / app.py
oggata's picture
Update app.py
f95231e verified
raw
history blame
3.98 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
warnings.filterwarnings("ignore")
"""
デバッグ用のシンプルなSarashinaチャットボット
additional_inputsなしでテスト
"""
# モデルとトークナイザーの初期化
MODEL_NAME = "sbintuitions/sarashina2.2-3b-instruct-v0.1"
print("モデルを読み込み中...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True
)
print("モデルの読み込みが完了しました。")
def respond(message, history):
"""
シンプルなチャットボット応答関数
additional_inputsなし
"""
try:
# デバッグ情報を出力
print(f"DEBUG - message: {message} (type: {type(message)})")
print(f"DEBUG - history: {history} (type: {type(history)})")
# システムメッセージ(固定)
system_message = "あなたは親切で知識豊富な日本語アシスタントです。ユーザーの質問に丁寧に答えてください。"
# 会話履歴を含むプロンプトを構築
conversation = f"システム: {system_message}\n"
# 会話履歴を追加
if history and isinstance(history, list):
for item in history:
if isinstance(item, (list, tuple)) and len(item) >= 2:
user_msg, bot_msg = item[0], item[1]
if user_msg:
conversation += f"ユーザー: {user_msg}\n"
if bot_msg:
conversation += f"アシスタント: {bot_msg}\n"
# 現在のメッセージを追加
conversation += f"ユーザー: {message}\nアシスタント: "
# トークン化
inputs = tokenizer.encode(conversation, return_tensors="pt")
# GPU使用時はCUDAに移動
if torch.cuda.is_available():
inputs = inputs.cuda()
# 応答生成
with torch.no_grad():
outputs = model.generate(
inputs,
max_new_tokens=512,
temperature=0.7,
top_p=0.95,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
repetition_penalty=1.1
)
# 生成されたテキストをデコード
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
# 応答部分のみを抽出
full_response = generated[len(conversation):].strip()
# 不要な部分を除去
if "ユーザー:" in full_response:
full_response = full_response.split("ユーザー:")[0].strip()
# ストリーミング風の出力
for i in range(len(full_response)):
response = full_response[:i+1]
yield response
except Exception as e:
yield f"エラーが発生しました: {str(e)}"
"""
シンプルなChatInterface(additional_inputsなし)
"""
demo = gr.ChatInterface(
respond,
title="🤖 Sarashina Chatbot (Simple)",
description="Sarashina2.2-3b-instruct モデルを使用した日本語チャットボットです。(デバッグ用)",
theme=gr.themes.Soft(),
examples=[
"こんにちは!",
"日本について教えて",
"プログラミングの質問があります",
],
cache_examples=False,
)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_api=True,
debug=True
)