import torch from transformers import AutoTokenizer, AutoModelForCausalLM from flask import Flask, request, jsonify, render_template_string import time # Flaskアプリケーションの設定 app = Flask(__name__) # デバイスの設定 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # トークナイザーとモデルの読み込み tokenizer = AutoTokenizer.from_pretrained("inu-ai/alpaca-guanaco-japanese-gpt-1b", use_fast=False) model = AutoModelForCausalLM.from_pretrained("inu-ai/alpaca-guanaco-japanese-gpt-1b").to(device) # 定数 MAX_ASSISTANT_LENGTH = 100 MAX_INPUT_LENGTH = 1024 INPUT_PROMPT = r'\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n入力:\n{input}\n[SEP]\n応答:\n' NO_INPUT_PROMPT = r'\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n応答:\n' # HTMLテンプレート HTML_TEMPLATE = """ Chat Interface

Chat Interface

""" def prepare_input(role_instruction, conversation_history, new_conversation): """入力テキストを整形する関数""" instruction = "".join([f"{text}\n" for text in role_instruction]) instruction += "\n".join(conversation_history) input_text = f"User:{new_conversation}" return INPUT_PROMPT.format(instruction=instruction, input=input_text) def format_output(output): """生成された出力を整形する関数""" return output.lstrip("").rstrip("").replace("[SEP]", "").replace("\\n", "\n") def trim_conversation_history(conversation_history, max_length): """会話履歴を最大長に収めるために調整する関数""" while len(conversation_history) > 2 and sum([len(tokenizer.encode(text, add_special_tokens=False)) for text in conversation_history]) + max_length > MAX_INPUT_LENGTH: conversation_history.pop(0) conversation_history.pop(0) return conversation_history def generate_response(role_instruction, conversation_history, new_conversation): """新しい会話に対する応答を生成する関数""" conversation_history = trim_conversation_history(conversation_history, MAX_ASSISTANT_LENGTH) input_text = prepare_input(role_instruction, conversation_history, new_conversation) token_ids = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt") with torch.no_grad(): output_ids = model.generate( token_ids.to(model.device), min_length=len(token_ids[0]), max_length=min(MAX_INPUT_LENGTH, len(token_ids[0]) + MAX_ASSISTANT_LENGTH), temperature=0.7, do_sample=True, pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, bad_words_ids=[[tokenizer.unk_token_id]] ) output = tokenizer.decode(output_ids.tolist()[0]) formatted_output_all = format_output(output) response = f"Assistant:{formatted_output_all.split('応答:')[-1].strip()}" conversation_history.append(f"User:{new_conversation}".replace("\n", "\\n")) conversation_history.append(response.replace("\n", "\\n")) return formatted_output_all, response def get_default_role_instruction(): return [ "User:睡眠に悩む高校生です", "Assistant:では、お手伝いしましょう。!" ] def get_default_conversation_history(): return [ "User: こんにちは、今日は一日を有効に使いたいのですが、何かアドバイスはありますか?", "Assistant: こんにちは!一日の計画を立てることはとても重要です。朝、昼、晩それぞれの時間帯でやるべきことをリストにまとめると良いですよ。", "User: なるほど、具体的にはどんな内容をリストに入れればいいですか?", "Assistant: 朝は、起床後のルーチンや朝食の準備、健康的な運動を入れると良いですね。昼は、仕事や勉強の計画、休憩時間、昼食の準備が考えられます。晩は、夕食の準備や家事、リラックスタイム、就寝前のルーチンなどが含まれます。" ] @app.route('/') def home(): """ホームページをレンダリング""" return render_template_string(HTML_TEMPLATE) @app.route('/generate', methods=['POST']) def generate(): """Flaskエンドポイント: /generate""" data = request.json role_instruction = data.get('role_instruction', []) conversation_history = data.get('conversation_history', []) new_conversation = data.get('new_conversation', "") if not role_instruction or not new_conversation: return jsonify({"error": "role_instruction and new_conversation are required fields"}), 400 formatted_output_all, response = generate_response(role_instruction, conversation_history, new_conversation) return jsonify({"response": response, "conversation_history": conversation_history}) if __name__ == '__main__': app.run(debug=True, host="0.0.0.0", port=7860)