Spaces:
Running
Running
# -------------------------------------- | |
# Libraries | |
# -------------------------------------- | |
import os | |
import time | |
import gc # メモリ解放 | |
import re # 正規表現で文章をクリーンアップ | |
# HuggingFace | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# OpenAI | |
import openai | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.chat_models import ChatOpenAI | |
# LangChain | |
from langchain.llms import HuggingFacePipeline | |
from transformers import pipeline | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.chains import VectorDBQA | |
from langchain.vectorstores import Chroma | |
from langchain import PromptTemplate, ConversationChain | |
from langchain.chains.question_answering import load_qa_chain # QA Chat | |
from langchain.document_loaders import SeleniumURLLoader # URL取得 | |
from langchain.docstore.document import Document # テキストをドキュメント化 | |
# from langchain.memory import ConversationBufferWindowMemory # チャット履歴 | |
from langchain.memory import ConversationSummaryBufferMemory # チャット履歴 | |
from typing import Any | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
# Gradio | |
import gradio as gr | |
# PyPdf | |
from pypdf import PdfReader | |
# test | |
import langchain # (debug=Trueにするため) | |
# --------------------------------------
# Per-user session state kept in gr.State
# (Reference) https://blog.shikoan.com/gradio-state/
# --------------------------------------
class SessionState:
    """Models, chains, vector DB and chat history belonging to one browser session."""

    def __init__(self):
        # Hugging Face objects (stay None when an OpenAI model is used)
        self.tokenizer = None
        self.pipe = None
        self.model = None
        # LangChain objects
        self.llm = None
        self.embeddings = None
        self.current_model = ""  # id of the loaded LLM, "" when none is loaded
        self.current_embedding = ""  # id of the loaded embedding model, "" when none
        self.db = None  # vector DB
        self.memory = None  # LangChain chat memory (used by both QA and conversation chains)
        self.qa_chain = None  # chain created with load_qa_chain
        self.conversation_chain = None  # ConversationChain
        self.embedded_urls = []  # URLs already added to self.db
        # App
        self.dialogue = []  # recent (user, bot) turns for display
        self.chat_history = []  # reset by clear_memory; no reader is visible in this file

    def cache_clear(self):
        """Release cached GPU memory (when CUDA is available) and run the garbage collector."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # GPU memory
        gc.collect()  # CPU memory

    def clear_memory(self, llm=False, embd=False, db=False):
        """Drop loaded objects so their memory can be reclaimed.

        Args:
            llm: drop the LLM and everything built on it (tokenizer, pipeline, memory,
                QA chain, conversation chain). Also drops the embeddings, like ``embd``.
            embd: drop the embedding model and the QA chain.
            db: delete the vector DB collection and forget which URLs were embedded.
        """
        if db and self.db is not None:
            self.db.delete_collection()
            self.db = None
            self.embedded_urls = []
        if llm or embd:
            self.embeddings = None
            self.current_embedding = ""
            self.qa_chain = None
        if llm:
            self.llm = None
            self.pipe = None
            self.model = None
            self.current_model = ""
            self.tokenizer = None
            self.memory = None
            # The conversation chain holds the old LLM and memory; conversation_prep only
            # builds a new one when this is None, so it must be dropped on a model switch.
            self.conversation_chain = None
            self.chat_history = []  # TODO(review): confirm this attribute is still needed
        self.cache_clear()
# --------------------------------------
# Custom TextSplitter that cuts text into chunks that fit the LLM's token budget.
# (Reference) https://www.sato-susumu.com/entry/2023/04/30/131338
# -> extra punctuation separators were added on top of the reference version.
# --------------------------------------
class JPTextSplitter(RecursiveCharacterTextSplitter):
    """RecursiveCharacterTextSplitter preconfigured with Japanese-aware separators.

    Separators are listed from coarsest (blank line) to finest (single character).
    """

    def __init__(self, **kwargs: Any):
        # NOTE(review): "!" and "?" each appear twice; presumably one copy was meant to
        # be the full-width variant — confirm before changing.
        jp_separators = [
            "\n\n",
            "\n",
            "。",
            "!",
            "?",
            ")",
            "、",
            ".",
            "!",
            "?",
            ",",
            " ",
            "",
        ]
        super().__init__(separators=jp_separators, **kwargs)


# Chunking parameters
chunk_size = 512  # maximum number of characters per chunk
chunk_overlap = 35  # maximum number of characters shared by neighbouring chunks
text_splitter = JPTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
# --------------------------------------
# Translate the latest exchange in memory with DeepL to cut tokens (OpenAI model only)
# --------------------------------------
DEEPL_API_ENDPOINT = "https://api-free.deepl.com/v2/translate"
# Read from the environment so the key does not have to be hard-coded; the placeholder
# default keeps the previous value when the variable is not set.
DEEPL_API_KEY = os.environ.get("DEEPL_API_KEY", "YOUR_DEEPL_API_KEY")


def deepl_memory(ss: SessionState) -> (SessionState):
    """Replace the last user/AI message pair in ss.memory with its English translation.

    Only acts when ss.current_model is "gpt-3.5-turbo" and memory holds at least two
    messages. The Japanese pair is sent to DeepL (JA -> EN) and the last two messages of
    ``ss.memory.chat_memory.messages`` are replaced by the translations.

    Returns:
        The same ss.
    Raises:
        urllib.error.HTTPError: DeepL answered with a non-2xx status.
        urllib.error.URLError: the endpoint could not be reached.
    """
    import json
    import urllib.parse
    import urllib.request

    if ss.current_model != "gpt-3.5-turbo":
        return ss
    messages = ss.memory.chat_memory.messages
    if len(messages) < 2:
        return ss
    # The latest exchange is stored as two separate messages: user, then AI.
    user_message = messages[-2].content
    ai_message = messages[-1].content

    # doseq=True sends one "text" field per list element, as the DeepL API expects.
    body = urllib.parse.urlencode(
        {
            "text": [user_message, ai_message],
            "target_lang": "EN",
            "source_lang": "JA",
        },
        doseq=True,
    ).encode("utf-8")
    request = urllib.request.Request(
        DEEPL_API_ENDPOINT,
        data=body,
        headers={
            "Authorization": f"DeepL-Auth-Key {DEEPL_API_KEY}",
            "Content-Type": "application/x-www-form-urlencoded",
        },
        method="POST",
    )
    # urlopen raises HTTPError for error statuses (the equivalent of raise_for_status()).
    with urllib.request.urlopen(request, timeout=30) as response:
        translations = json.load(response)["translations"]

    # Drop the Japanese pair and store the English translations instead.
    ss.memory.chat_memory.messages = messages[:-2]
    ss.memory.chat_memory.add_user_message(translations[0]["text"])
    ss.memory.chat_memory.add_ai_message(translations[1]["text"])
    return ss
# --------------------------------------
# LangChain custom prompts
# llama tokenizer
# https://belladoreai.github.io/llama-tokenizer-js/example-demo/build/
# OpenAI tokenizer
# https://platform.openai.com/tokenizer
# --------------------------------------


def _single_line(text: str) -> str:
    """Join a triple-quoted paragraph into one line, keeping a space where each break was.

    Plain ``.replace("\\n", "")`` would glue the last word of one line to the first word
    of the next (e.g. "does not" + "make" -> "does notmake").
    """
    return " ".join(line.strip() for line in text.strip().splitlines())


# --------------------------------------
# Conversation Chain Template
# --------------------------------------
# Approximate tokens: OpenAI 104 / Llama 105 (Japanese version: OpenAI 191 / Llama 162)
sys_chat_message = _single_line("""
The following is a conversation between an AI concierge and a customer.
The AI understands what the customer wants to know from the conversation history and the latest question,
and gives many specific details in Japanese. If the AI does not know the answer to a question, it does not
make up an answer and says "誠に申し訳ございませんが、その点についてはわかりかねます".
""")
chat_common_format = """
===
Question: {query}
===
Conversation History:
{chat_history}
===
日本語の回答:"""
chat_template_std = f"{sys_chat_message}{chat_common_format}"
# NOTE(review): Llama-2 chat format; the literal "<s>" may duplicate the BOS token the
# tokenizer adds itself — confirm against the ELYZA model card.
chat_template_llama2 = f"<s>[INST] <<SYS>>{sys_chat_message}<</SYS>>{chat_common_format}[/INST]"

# --------------------------------------
# QA Chain Template
# --------------------------------------
# Approximate tokens: OpenAI 113 / Llama 111 (Japanese version: OpenAI 256 / Llama 225)
sys_qa_message = _single_line("""
You are an AI concierge who carefully answers questions from customers based on references.
You understand what the customer wants to know from the "Conversation History" and "Question",
and give a specific answer in Japanese using sentences extracted from the following references.
If you do not know the answer, do not make up an answer and reply,
"誠に申し訳ございませんが、その点についてはわかりかねます".
""")
qa_common_format = """
===
Question:
{query}
===
References:
{context}
===
Conversation History:
{chat_history}
===
日本語の回答:"""
qa_template_std = f"{sys_qa_message}{qa_common_format}"
qa_template_llama2 = f"<s>[INST] <<SYS>>{sys_qa_message}<</SYS>>{qa_common_format}[/INST]"

# --------------------------------------
# Summary prompt for ConversationSummaryBufferMemory
# Source -> https://github.com/langchain-ai/langchain/blob/894c272a562471aadc1eb48e4a2992923533dea0/langchain/memory/prompt.py#L26-L49
# --------------------------------------
# Tokens: OpenAI 212 / Llama 214 (Japanese version: OpenAI 397 / Llama 297)
conversation_summary_template = """
Using the example as a guide, compose a summary in English that gives an overview of the conversation by summarizing the "current summary" and the "new conversation".
===
Example
[Current Summary] Customer asks AI what it thinks about Artificial Intelligence, AI says Artificial Intelligence is a good tool.
[New Conversation]
Human: なぜ人工知能が良いツールだと思いますか?
AI: 人工知能は「人間の可能性を最大限に引き出すことを助ける」からです。
[New Summary] Customer asks what you think about Artificial Intelligence, and AI responds that it is a good force that helps humans reach their full potential.
===
[Current Summary] {summary}
[New Conversation]
{new_lines}
[New Summary]
""".strip()
# Model loading
def load_models(
    ss: SessionState,
    model_id: str,
    embedding_id: str,
    openai_api_key: str,
    load_in_8bit: bool,
    verbose: bool,
    temperature: float,
    min_length: int,
    max_new_tokens: int,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    num_return_sequences: int,
) -> (SessionState, str):
    """Load the selected LLM and embedding model into the session.

    Args:
        ss: session state to populate.
        model_id: "gpt-3.5-turbo" or a Hugging Face causal-LM id.
        embedding_id: "text-embedding-ada-002", "None", or a Hugging Face embedding id.
        openai_api_key: textbox value; not used here (the key is read from the
            OPENAI_API_KEY environment variable).
        load_in_8bit ... num_return_sequences: loading/generation options; the settings
            UI labels say which backend (OpenAI / HF) uses each.
    Returns:
        (ss, status_message) — always a 2-tuple, matching the two Gradio outputs.
    """
    # --------------------------------------
    # OpenAI API key check
    # --------------------------------------
    uses_openai = model_id == "gpt-3.5-turbo" or embedding_id == "text-embedding-ada-002"
    # .get(): the variable may not exist at all, and [] would raise KeyError.
    if uses_openai and not os.environ.get("OPENAI_API_KEY"):
        status_message = "❌ OpenAI API KEY を設定してください"
        return ss, status_message

    # --------------------------------------
    # LLM
    # --------------------------------------
    # Drop the previous LLM, everything built on it, and the vector DB.
    ss.clear_memory(llm=True, db=True)
    # conversation_prep only builds a chain when this is None; never reuse the old LLM.
    ss.conversation_chain = None

    if model_id == "gpt-3.5-turbo":
        # OpenAI model
        ss.llm = ChatOpenAI(
            model_name=model_id,
            temperature=temperature,
            verbose=verbose,
            max_tokens=max_new_tokens,
        )
    else:
        # Hugging Face model
        if model_id == "rinna/bilingual-gpt-neox-4b-instruction-sft":
            ss.tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
        else:
            ss.tokenizer = AutoTokenizer.from_pretrained(model_id)
        ss.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        # transformers rejects sampling with temperature == 0, and greedy decoding only
        # supports a single returned sequence, so temperature 0 means plain greedy search.
        if temperature > 0:
            decoding_kwargs = dict(
                do_sample=True,
                top_k=top_k,
                top_p=top_p,
                num_return_sequences=num_return_sequences,
                temperature=temperature,
            )
        else:
            decoding_kwargs = dict(do_sample=False)
        ss.pipe = pipeline(
            "text-generation",
            model=ss.model,
            tokenizer=ss.tokenizer,
            min_length=min_length,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            **decoding_kwargs,
        )
        ss.llm = HuggingFacePipeline(pipeline=ss.pipe)

    # --------------------------------------
    # Embedding model
    # --------------------------------------
    # clear_memory(llm=True) above also drops the embeddings, so this normally reloads;
    # it is skipped only when the requested embedding model is still loaded.
    if ss.current_embedding != embedding_id:
        # Reset embeddings and vector DB
        ss.clear_memory(embd=True, db=True)
        if embedding_id == "None":
            pass  # run without an embedding model
        elif embedding_id == "text-embedding-ada-002":
            ss.embeddings = OpenAIEmbeddings()
        else:
            ss.embeddings = HuggingFaceEmbeddings(model_name=embedding_id)

    # --------------------------------------
    # Remember the current model names on the SessionState object
    # --------------------------------------
    ss.current_model = model_id
    ss.current_embedding = embedding_id
    status_message = "✅ LLM: " + ss.current_model + ", embeddings: " + ss.current_embedding
    return ss, status_message
def conversation_prep(ss: SessionState) -> SessionState:
    """Build ss.conversation_chain (chat without retrieval) if it does not exist yet.

    The chat memory is shared with the QA chain and is created here only when absent.

    Returns:
        The same ss with ss.conversation_chain (and possibly ss.memory) set.
    """
    if ss.conversation_chain is not None:
        return ss

    human_prefix = "Human: "
    ai_prefix = "AI: "
    chat_template = chat_template_std
    if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
        # rinna: newlines are written as "<NL>" and memory uses its speaker prefixes
        # (see the official model page)
        chat_template = chat_template.replace("\n", "<NL>")
        human_prefix = "ユーザー: "
        ai_prefix = "システム: "
    elif ss.current_model.startswith("elyza/ELYZA-japanese-Llama-2-7b"):
        chat_template = chat_template_llama2
    chat_prompt = PromptTemplate(input_variables=['query', 'chat_history'], template=chat_template)

    if ss.memory is None:
        conversation_summary_prompt = PromptTemplate(
            input_variables=['summary', 'new_lines'],
            template=conversation_summary_template,
        )
        # NOTE(review): with return_messages=True the history is injected into the string
        # prompt as message objects, and human_prefix/ai_prefix only affect summarisation —
        # confirm this is intended.
        ss.memory = ConversationSummaryBufferMemory(
            llm=ss.llm,
            memory_key="chat_history",
            input_key="query",
            output_key="output_text",
            return_messages=True,
            human_prefix=human_prefix,
            ai_prefix=ai_prefix,
            max_token_limit=512,
            prompt=conversation_summary_prompt,
        )

    ss.conversation_chain = ConversationChain(
        llm=ss.llm,
        prompt=chat_prompt,
        memory=ss.memory,
        # ConversationChain defaults to "input"/"response"; the prompt uses {query} and the
        # shared memory reads "query" / writes "output_text", so the keys must match.
        input_key="query",
        output_key="output_text",
    )
    return ss
def initialize_db(ss: SessionState) -> SessionState:
    """Create an empty Chroma vector store for this session and store it in ss.db.

    Returns:
        The same ss with ss.db set.
    """
    import uuid

    # NOTE(review): in-memory Chroma clients created in one process appear to share storage
    # (chromadb >= 0.4), so a fixed collection name would let every browser session read —
    # and clear_db delete — the same collection. A random suffix keeps sessions separate.
    collection_name = f"user_reference_{uuid.uuid4().hex}"
    ss.db = Chroma(
        collection_name=collection_name,
        embedding_function=ss.embeddings,
    )
    return ss
def _clean_page_content(content: str, is_pdf: bool) -> str:
    """Normalise one document's text before chunking.

    All line breaks and ASCII spaces are removed, and tabs / ideographic spaces become a
    single ASCII space. For PDFs, line breaks after sentence-ending punctuation or closing
    brackets (and breaks followed by an indent) are protected with placeholders first and
    restored at the end, because PDF extraction inserts many spurious line breaks.
    """
    content = content.strip()
    if is_pdf:
        # Applied in order; the first matching rule wins for a given line ending.
        pdf_replacement_sets = (
            ('\n ', '**PLACEHOLDER+SPACE**'),
            ('\n\u3000', '**PLACEHOLDER+SPACE**'),
            ('.\n', '。**PLACEHOLDER**'),
            (',\n', '。**PLACEHOLDER**'),
            ('?\n', '。**PLACEHOLDER**'),
            ('!\n', '。**PLACEHOLDER**'),
            ('。\n', '。**PLACEHOLDER**'),
            # Closing brackets keep their own character (they were previously turned into
            # "!" / "?", which corrupted the text).
            (')\n', ')**PLACEHOLDER**'),
            (']\n', ']**PLACEHOLDER**'),
            ('】\n', '】**PLACEHOLDER**'),
        )
        for original, replacement in pdf_replacement_sets:
            content = content.replace(original, replacement)
    # Remove unneeded characters and spaces
    for removed in ("\n", "\r", " "):
        content = content.replace(removed, "")
    # Tabs and ideographic spaces become a single space
    for spaced in ("\t", "\u3000"):
        content = content.replace(spaced, " ")
    if is_pdf:
        # Restore the legitimate PDF line breaks
        content = content.replace('**PLACEHOLDER**', '\n').replace('**PLACEHOLDER+SPACE**', '\n ')
    return content


def embedding_process(ss: SessionState, ref_documents: list) -> SessionState:
    """Clean, chunk and embed reference documents into ss.db; build ss.qa_chain if needed.

    Args:
        ss: session state; ss.llm and ss.embeddings are used.
        ref_documents: LangChain Documents with a "source" metadata entry. Their
            page_content is rewritten in place.
    Returns:
        The same ss with ss.db populated and, on first use, ss.qa_chain / ss.memory set.
    """
    # --------------------------------------
    # Clean up document text
    # --------------------------------------
    for doc in ref_documents:
        is_pdf = ".pdf" in doc.metadata.get('source', "")
        doc.page_content = _clean_page_content(doc.page_content, is_pdf)

    # Split into chunks
    texts = text_splitter.split_documents(ref_documents)

    # --------------------------------------
    # multilingual-e5 expects "passage: " / "query: " prefixes (with the space), matching
    # the "query: " prefix used in qa_predict.
    # https://hironsan.hatenablog.com/entry/2023/07/05/073150
    # --------------------------------------
    if ss.current_embedding == "intfloat/multilingual-e5-large":
        for text in texts:
            text.page_content = "passage: " + text.page_content

    # Initialise the vector DB
    if ss.db is None:
        ss = initialize_db(ss)
    # Embed into the DB (nothing to embed when every source yielded empty text)
    if texts:
        ss.db.add_documents(documents=texts, embedding=ss.embeddings)

    # --------------------------------------
    # QA chain
    # --------------------------------------
    if ss.qa_chain is None:
        human_prefix = "Human: "
        ai_prefix = "AI: "
        qa_template = qa_template_std
        if ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft":
            # rinna: newlines are written as "<NL>" and memory uses its speaker prefixes
            # (see the official model page)
            qa_template = qa_template.replace("\n", "<NL>")
            human_prefix = "ユーザー: "
            ai_prefix = "システム: "
        elif ss.current_model.startswith("elyza/ELYZA-japanese-Llama-2-7b"):
            qa_template = qa_template_llama2
        qa_prompt = PromptTemplate(input_variables=['context', 'query', 'chat_history'], template=qa_template)
        if ss.memory is None:
            conversation_summary_prompt = PromptTemplate(
                input_variables=['summary', 'new_lines'],
                template=conversation_summary_template,
            )
            ss.memory = ConversationSummaryBufferMemory(
                llm=ss.llm,
                memory_key="chat_history",
                input_key="query",
                output_key="output_text",
                return_messages=True,
                human_prefix=human_prefix,
                ai_prefix=ai_prefix,
                max_token_limit=512,
                prompt=conversation_summary_prompt,
            )
        ss.qa_chain = load_qa_chain(ss.llm, chain_type="stuff", memory=ss.memory, prompt=qa_prompt)
    return ss
def _read_pdf_documents(fileobj: list, header_lim: int, footer_lim: int) -> list:
    """Extract each uploaded PDF's text, keeping only text between the footer and header.

    Args:
        fileobj: uploaded files from gr.File; each item has a ``.name`` path (a plain
            path string is accepted too).
        header_lim, footer_lim: y coordinates in PDF points; text whose text-matrix y
            (``tm[5]``) lies strictly between them is kept.
    Returns:
        One Document per PDF: pages joined with "\\n", file basename as "source".
    https://pypdf.readthedocs.io/en/stable/user/extract-text.html
    """
    parts = []

    def visitor_body(text, cm, tm, font_dict, font_size):
        if footer_lim < tm[5] < header_lim:
            parts.append(text)

    documents = []
    for item in fileobj:
        pdf_path = getattr(item, "name", item)
        reader = PdfReader(pdf_path)
        pages = []
        for page in reader.pages:
            parts.clear()
            page.extract_text(visitor_text=visitor_body)
            pages.append("".join(parts))
        documents.append(
            Document(page_content="\n".join(pages), metadata={"source": os.path.basename(pdf_path)})
        )
    return documents


def embed_ref(ss: SessionState, urls: str, fileobj: list, header_lim: int, footer_lim: int) -> (SessionState, str):
    """Load reference URLs and/or PDFs and embed them into the session's vector DB.

    Args:
        ss: session state; an LLM must already be configured.
        urls: newline-separated URLs; lines without "://" are ignored.
        fileobj: uploaded PDF files, or None.
        header_lim, footer_lim: PDF y-range (points) to keep; see _read_pdf_documents.
    Returns:
        (ss, status_message) reporting the URL and PDF results.
    """
    # embedding_process builds the QA chain from ss.llm, so a model must be set first.
    if ss.llm is None:
        return ss, "❌ 先に「1. LLM設定」でモデルを設定してください"

    url_flag = "-"
    pdf_flag = "-"

    # --------------------------------------
    # URLs: one per line, keep URL-looking lines, dedupe in input order,
    # and skip URLs already embedded in this session.
    # --------------------------------------
    candidates = dict.fromkeys(line.strip() for line in (urls or "").split("\n"))
    candidates = [url for url in candidates if url and "://" in url]
    if candidates:
        new_urls = [url for url in candidates if url not in ss.embedded_urls]
        if new_urls:
            ref_documents = SeleniumURLLoader(urls=new_urls).load()
            ss = embedding_process(ss, ref_documents)
            # Record only after loading so a failed run can be retried.
            ss.embedded_urls.extend(new_urls)
        url_flag = "✅ 登録済"

    # --------------------------------------
    # PDFs with header/footer removed
    # --------------------------------------
    if fileobj:
        ref_documents = _read_pdf_documents(fileobj, header_lim, footer_lim)
        ss = embedding_process(ss, ref_documents)
        pdf_flag = "✅ 登録済"

    # Turns on LangChain debug output process-wide
    langchain.debug = True
    status_message = "URL: " + url_flag + " / PDF: " + pdf_flag
    return ss, status_message
def clear_db(ss: SessionState) -> (SessionState, str):
    """Delete the session's reference data (vector DB collection and embedded-URL list).

    Returns:
        (ss, status_message). With no DB registered, nothing is changed and an error
        message is returned instead of raising.
    """
    # The old `except NameError` never matched: a missing DB is None, which raises
    # AttributeError, and the deleted DB was left on ss for the next registration.
    if ss.db is None:
        return ss, "❌ 参照データが登録されていません。"
    # Deletes the collection, resets ss.db to None and clears ss.embedded_urls.
    ss.clear_memory(db=True)
    return ss, "✅ 参照データを削除しました。"
# ----------------------------------------------------------------------------
# Chat flow:
#   query input -> user() -> bot() -> show_response() -> Chatbot display
#   user() also refreshes the Chatbot right away; bot() calls
#   qa_predict / chat_predict to produce the answer.
# ----------------------------------------------------------------------------
def user(ss: SessionState, query) -> (SessionState, list):
    """Record the user's query as a new turn whose bot reply is still pending.

    When more than 10 turns are stored, the oldest one is dropped first.

    Returns:
        (ss, ss.dialogue); the list is what the Chatbot displays.
    """
    turns = ss.dialogue
    if len(turns) > 10:
        del turns[0]
    # None = bot reply not generated yet
    ss.dialogue = [*turns, (query, None)]
    return ss, ss.dialogue
def bot(ss: SessionState, query, qa_flag) -> (SessionState, str):
    """Produce the bot reply for the latest turn in ss.dialogue.

    In QA mode the answer comes from qa_predict (retrieval over registered references);
    otherwise from the conversation chain. When a prerequisite is missing, an explanatory
    message becomes the reply instead.

    Returns:
        (ss, "") — the empty string clears the query textbox.
    """
    if ss.llm is None:
        response = "LLMが設定されていません。設定画面で任意のモデルを選択してください。"
    elif qa_flag and ss.embeddings is None:
        response = "Embeddingモデルが設定されていません。設定画面で任意のモデルを選択してください。"
    elif qa_flag and (ss.db is None or ss.qa_chain is None):
        # Without registered references qa_predict would call similarity_search on None.
        response = "参照データが登録されていません。「2. References」タブでURLまたはPDFを登録してください。"
    elif qa_flag:
        # QA model
        return qa_predict(ss, query), ""
    else:
        # Chat model
        ss = conversation_prep(ss)
        return chat_predict(ss, query), ""
    ss.dialogue[-1] = (ss.dialogue[-1][0], response)
    return ss, ""
def chat_predict(ss: SessionState, query) -> SessionState:
    """Ask the conversation chain for a reply and store it as the latest turn's answer."""
    answer = ss.conversation_chain.predict(query=query)
    asked = ss.dialogue[-1][0]
    ss.dialogue[-1] = asked, answer
    return ss
def _clean_answer(text: str) -> str:
    """Tidy raw LLM output: restore newlines, trim dots/whitespace and a leading "回答:".

    The old ``.strip("回答:")`` removed any of the characters 回, 答, ':' from both ends,
    which also cut real answer text (e.g. "答えです" -> "えです").
    """
    text = text.replace("<NL>", "\n").strip().strip(".").strip()
    for prefix in ("回答:", "回答:"):
        if text.startswith(prefix):
            text = text[len(prefix):]
            break
    return text.strip()


def _drop_last_exchange(memory) -> None:
    """Remove the trailing human/AI message pair that the chain's memory just saved."""
    messages = memory.chat_memory.messages
    if len(messages) >= 2 and messages[-2].type == "human" and messages[-1].type == "ai":
        memory.chat_memory.messages = messages[:-2]


def qa_predict(ss: SessionState, query) -> SessionState:
    """Answer `query` from the vector DB and record the exchange.

    Retrieves the 2 most similar chunks, runs the QA chain (up to 3 attempts while the
    answer is empty), appends the unique sources to the answer, and stores the result
    both in ss.memory and as the latest turn of ss.dialogue.

    Returns:
        The same ss.
    """
    is_rinna = ss.current_model == "rinna/bilingual-gpt-neox-4b-instruction-sft"
    query = query.strip()
    if is_rinna:
        # rinna: newlines in the query are written as "<NL>"
        query = query.replace("\n", "<NL>")

    # multilingual-e5 expects the "query: " prefix on search queries
    if ss.current_embedding == "intfloat/multilingual-e5-large":
        db_query_str = "query: " + query
    else:
        db_query_str = query

    # Retrieve related chunks and their sources (unique, in retrieval order)
    docs = ss.db.similarity_search(db_query_str, k=2)
    unique_sources = list(dict.fromkeys(doc.metadata['source'] for doc in docs if 'source' in doc.metadata))
    if unique_sources:
        sources = "\n\n[Sources]\n" + "\n".join(f" - {source}" for source in unique_sources)
    else:
        sources = ""

    if is_rinna:
        # rinna: newlines in the retrieved text are written as "<NL>"
        for doc in docs:
            doc.page_content = doc.page_content.strip().replace("\n", "<NL>")

    # Generate the answer (up to 3 attempts)
    for _ in range(3):
        result = ss.qa_chain({"input_documents": docs, "query": query})
        answer = _clean_answer(result["output_text"])
        # The chain's memory saved (query, raw output) as two messages; remove both so the
        # cleaned answer (or nothing, on retry) is stored without duplicating the user turn.
        _drop_last_exchange(ss.memory)
        if answer:
            response = answer + sources
            ss.memory.chat_memory.add_user_message(query)
            ss.memory.chat_memory.add_ai_message(response)
            ss.dialogue[-1] = (ss.dialogue[-1][0], response)
            return ss

    # Still empty after 3 attempts
    response = "3回試行しましたが、情報を生成できませんでした。"
    if docs:
        response += "参考文献の抽出には成功していますので、言語モデルを変えてお試しください。"
    ss.memory.chat_memory.add_user_message(query.replace("<NL>", "\n"))
    ss.memory.chat_memory.add_ai_message(response)
    ss.dialogue[-1] = (ss.dialogue[-1][0], response)
    return ss
# Stream the reply into the chat window one character at a time
def show_response(ss: SessionState) -> str:
    """Yield the chat history repeatedly while the latest bot reply is revealed.

    The history is copied as [user, bot] lists; the last bot message starts empty and
    grows by one character per yield, with a 0.05 s pause per character. A missing
    (None) reply is shown as a fixed fallback message.
    """
    frames = [list(turn) for turn in ss.dialogue]
    last_turn = frames[-1]
    full_reply = last_turn[1]
    last_turn[1] = ""
    if full_reply is None:
        full_reply = "回答を生成できませんでした。"
    for end in range(1, len(full_reply) + 1):
        last_turn[1] = full_reply[:end]
        time.sleep(0.05)
        yield frames
with gr.Blocks() as demo:
    # Per-user session state (reset on page reload)
    ss = gr.State(SessionState())

    # --------------------------------------
    # Set / clear the OpenAI API key
    # NOTE(review): os.environ is process-wide, so a key set here is also visible to every
    # other session served by this process — confirm this is acceptable for a shared Space.
    # --------------------------------------
    def openai_api_setfn(openai_api_key) -> str:
        """Validate the entered key, store it in OPENAI_API_KEY and return a status message."""
        if openai_api_key == "kikagaku":
            demo_key = os.getenv("kikagaku_demo")
            # os.environ only accepts str; assigning None (variable unset) raises TypeError.
            if not demo_key:
                os.environ["OPENAI_API_KEY"] = ""
                return "❌ DEMO用のAPIキーが設定されていません"
            os.environ["OPENAI_API_KEY"] = demo_key
            status_message = "✅ キカガク専用DEMOへようこそ!APIキーを設定しました"
            return status_message
        elif not openai_api_key or not openai_api_key.startswith("sk-") or len(openai_api_key) < 50:
            os.environ["OPENAI_API_KEY"] = ""
            status_message = "❌ 有効なAPIキーを入力してください"
            return status_message
        else:
            os.environ["OPENAI_API_KEY"] = openai_api_key
            status_message = "✅ APIキーを設定しました"
            return status_message

    def openai_api_clsfn(ss) -> (str, str):
        """Clear OPENAI_API_KEY; return (status message, "" to empty the key textbox)."""
        os.environ["OPENAI_API_KEY"] = ""
        status_message = "✅ APIキーの削除が完了しました"
        return status_message, ""

    # --------------------------------------
    # "Continue the answer" helper (not wired to any event in this file)
    # --------------------------------------
    def continue_pred():
        query = "回答を続けてください"
        return query

    with gr.Tabs():
        # --------------------------------------
        # Setting Tab
        # --------------------------------------
        with gr.TabItem("1. LLM設定"):
            with gr.Row():
                model_id = gr.Dropdown(
                    choices=[
                        'elyza/ELYZA-japanese-Llama-2-7b-fast-instruct',
                        'rinna/bilingual-gpt-neox-4b-instruction-sft',
                        'gpt-3.5-turbo',
                    ],
                    value="elyza/ELYZA-japanese-Llama-2-7b-fast-instruct",
                    label='LLM model',
                    interactive=True,
                )
            with gr.Row():
                embedding_id = gr.Dropdown(
                    choices=[
                        'intfloat/multilingual-e5-large',
                        'sonoisa/sentence-bert-base-ja-mean-tokens-v2',
                        'oshizo/sbert-jsnli-luke-japanese-base-lite',
                        'text-embedding-ada-002',
                        # "None"
                    ],
                    value="sonoisa/sentence-bert-base-ja-mean-tokens-v2",
                    label = 'Embedding model',
                    interactive=True,
                )
            with gr.Row():
                with gr.Column(scale=19):
                    openai_api_key = gr.Textbox(label="OpenAI API Key (Optional)", interactive=True, type="password", value="", placeholder="Your OpenAI API Key for OpenAI models.", max_lines=1)
                with gr.Column(scale=1):
                    openai_api_set = gr.Button(value="Set API KEY", size="sm")
                    openai_api_cls = gr.Button(value="Delete API KEY", size="sm")
            # Advanced settings (collapsed)
            with gr.Accordion(label="Advanced Setting", open=False):
                with gr.Row():
                    with gr.Column():
                        load_in_8bit = gr.Checkbox(label="8bit Quantize (HF)", value=True, interactive=True)
                        verbose = gr.Checkbox(label="Verbose (OpenAI, HF)", value=True, interactive=True)
                    with gr.Column():
                        temperature = gr.Slider(label='Temperature (OpenAI, HF)', minimum=0.0, maximum=1.0, step=0.1, value=0.2, interactive=True)
                    with gr.Column():
                        min_length = gr.Slider(label="min_length (HF)", minimum=1, maximum=100, step=1, value=10, interactive=True)
                    with gr.Column():
                        max_new_tokens = gr.Slider(label="max_tokens(OpenAI), max_new_tokens(HF)", minimum=1, maximum=1024, step=1, value=256, interactive=True)
                    with gr.Column():
                        top_k = gr.Slider(label='top_k (HF)', minimum=1, maximum=100, step=1, value=40, interactive=True)
                    with gr.Column():
                        top_p = gr.Slider(label='top_p (HF)', minimum=0.01, maximum=0.99, step=0.01, value=0.92, interactive=True)
                    with gr.Column():
                        repetition_penalty = gr.Slider(label='repetition_penalty (HF)', minimum=0.5, maximum=2, step=0.1, value=1.2, interactive=True)
                    with gr.Column():
                        num_return_sequences = gr.Slider(label='num_return_sequences (HF)', minimum=1, maximum=20, step=1, value=3, interactive=True)
            with gr.Row():
                with gr.Column(scale=2):
                    config_btn = gr.Button(value="Configure")
                with gr.Column(scale=13):
                    status_cfg = gr.Textbox(show_label=False, interactive=False, value="モデルを設定してください", container=False, max_lines=1)
            # Button actions
            openai_api_set.click(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full")
            openai_api_cls.click(openai_api_clsfn, inputs=[openai_api_key], outputs=[status_cfg, openai_api_key], show_progress="full")
            openai_api_key.submit(openai_api_setfn, inputs=[openai_api_key], outputs=[status_cfg], show_progress="full")
            config_btn.click(
                fn = load_models,
                inputs = [ss, model_id, embedding_id, openai_api_key, load_in_8bit, verbose, temperature,
                          min_length, max_new_tokens, top_k, top_p, repetition_penalty, num_return_sequences],
                outputs = [ss, status_cfg],
                queue = True,
                show_progress = "full"
            )
        # --------------------------------------
        # Reference Tab
        # --------------------------------------
        with gr.TabItem("2. References"):
            urls = gr.TextArea(
                max_lines = 60,
                show_label=False,
                info = "List any reference URLs for Q&A retrieval.",
                placeholder = "https://blog.kikagaku.co.jp/deep-learning-transformer\nhttps://note.com/elyza/n/na405acaca130",
                interactive=True,
            )
            with gr.Row():
                pdf_paths = gr.File(label="PDFs", height=150, min_width=60, scale=7, file_types=[".pdf"], file_count="multiple", interactive=True)
                header_lim = gr.Number(label="Header (pt)", step=1, value=792, precision=0, min_width=70, scale=1, interactive=True)
                footer_lim = gr.Number(label="Footer (pt)", step=1, value=0, precision=0, min_width=70, scale=1, interactive=True)
                pdf_ref = gr.Textbox(show_label=False, value="A4 Size:\n(下)0-792pt(上)\n *28.35pt/cm", container=False, scale=1, interactive=False)
            with gr.Row():
                ref_set_btn = gr.Button(value="コンテンツ登録", scale=1)
                ref_clear_btn = gr.Button(value="登録データ削除", scale=1)
                status_ref = gr.Textbox(show_label=False, interactive=False, value="参照データ未登録", container=False, max_lines=1, scale=18)
            ref_set_btn.click(fn=embed_ref, inputs=[ss, urls, pdf_paths, header_lim, footer_lim], outputs=[ss, status_ref], queue=True, show_progress="full")
            ref_clear_btn.click(fn=clear_db, inputs=[ss], outputs=[ss, status_ref], show_progress="full")
        # --------------------------------------
        # Chatbot Tab
        # --------------------------------------
        with gr.TabItem("3. Q&A Chat"):
            chat_history = gr.Chatbot([], elem_id="chatbot").style(height=600, color_map=('green', 'gray'))
            with gr.Row():
                with gr.Column(scale=95):
                    query = gr.Textbox(
                        show_label=False,
                        placeholder="Send a message with [Shift]+[Enter] key.",
                        lines=4,
                        container=False,
                        autofocus=True,
                        interactive=True,
                    )
                with gr.Column(scale=5):
                    qa_flag = gr.Checkbox(label="QA mode", value=True, min_width=60, interactive=False)
                    query_send_btn = gr.Button(value="▶")
            # gr.Examples(["機械学習について説明してください"], inputs=[query])
            # user() shows the query, bot() generates the reply and clears the textbox,
            # show_response() streams the reply into the Chatbot.
            query.submit(user, [ss, query], [ss, chat_history]).then(bot, [ss, query, qa_flag], [ss, query]).then(show_response, [ss], [chat_history])
            query_send_btn.click(user, [ss, query], [ss, chat_history]).then(bot, [ss, query, qa_flag], [ss, query]).then(show_response, [ss], [chat_history])
if __name__ == "__main__":
    # Enable Gradio's request queue (concurrency_count=5), then start the server
    # in debug mode and open a browser tab.
    app = demo
    app.queue(concurrency_count=5)
    app.launch(inbrowser=True, debug=True)