# from time import time
import gradio as gr
import torch
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
# from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Qdrant
from openai.error import InvalidRequestError
from qdrant_client import QdrantClient
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from config import DB_CONFIG, DB_E5_CONFIG


def _get_config_and_embeddings(collection_name: str | None) -> tuple:
    """Return the Qdrant config and embedding model for the chosen collection."""
    if collection_name is None or collection_name == "E5":
        db_config = DB_E5_CONFIG
        model_name = "intfloat/multilingual-e5-large"
        model_kwargs = {"device": "cpu"}
        encode_kwargs = {"normalize_embeddings": False}
        embeddings = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs,
        )
    elif collection_name == "OpenAI":
        db_config = DB_CONFIG
        embeddings = OpenAIEmbeddings()
    else:
        raise ValueError("Unknown collection name")
    return db_config, embeddings


def _get_rinna_llm(temperature: float):
    """Load rinna/bilingual-gpt-neox-4b-instruction-ppo as a LangChain LLM."""
    model = "rinna/bilingual-gpt-neox-4b-instruction-ppo"
    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        temperature=temperature,
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    return llm


def _get_llm_model(
    model_name: str | None,
    temperature: float,
):
    """Map the UI model choice to an LLM instance; "rinna" is the default."""
    if model_name is None or model_name == "rinna":
        model = "rinna"
    elif model_name == "GPT-3.5":
        model = "gpt-3.5-turbo"
    elif model_name == "GPT-4":
        model = "gpt-4"
    else:
        raise ValueError("Unknown model name")
    if model.startswith("gpt"):
        llm = ChatOpenAI(model=model, temperature=temperature)
    else:
        llm = _get_rinna_llm(temperature)
    return llm


# prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
# {context}
# Question: {question}
# Answer in Japanese:"""
# PROMPT = PromptTemplate(
#     template=prompt_template, input_variables=["context", "question"]
# )


def get_retrieval_qa(
    collection_name: str | None,
    model_name: str | None,
    temperature: float,
    option: str | None,
) -> RetrievalQA:
    """Build a "stuff" RetrievalQA chain over the selected Qdrant collection."""
    db_config, embeddings = _get_config_and_embeddings(collection_name)
    db_url, db_api_key, db_collection_name = db_config
    client = QdrantClient(url=db_url, api_key=db_api_key)
    db = Qdrant(
        client=client, collection_name=db_collection_name, embeddings=embeddings
    )
    if option is None or option == "All":
        retriever = db.as_retriever()
    else:
        # Only retrieve documents whose metadata "category" matches the option.
        retriever = db.as_retriever(
            search_kwargs={
                "filter": {"category": option},
            }
        )
    llm = _get_llm_model(model_name, temperature)
    # chain_type_kwargs = {"prompt": PROMPT}
    result = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        # chain_type_kwargs=chain_type_kwargs,
    )
    return result
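
# Illustrative use of get_retrieval_qa outside the UI (a sketch; the query
# string is a made-up example, and the Qdrant credentials/collections in
# config.py must already be set up):
#
#   qa = get_retrieval_qa("E5", "rinna", 0.7, "All")
#   result = qa("NVDAとは何ですか?")       # "What is NVDA?"
#   print(result["result"])              # generated answer
#   print(result["source_documents"])    # retrieved Documents with .metadata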

def get_related_url(metadata):
    """Yield an HTML fragment for each unique source URL in the metadata."""
    urls = set()
    for m in metadata:
        # p = m['source']
        url = m["url"]
        if url in urls:
            continue
        urls.add(url)
        category = m["category"]
        # print(m)
        yield f'<p>URL: <a href="{url}">{url}</a> (category: {category})</p>'
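# For example (illustrative metadata, not real data):
#
#   list(get_related_url([
#       {"url": "https://example.com/a", "category": "ja-book"},
#       {"url": "https://example.com/a", "category": "ja-book"},  # duplicate URL, skipped
#   ]))
#   # -> ['<p>URL: <a href="https://example.com/a">https://example.com/a</a> (category: ja-book)</p>']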

def main(
    query: str, collection_name: str, model_name: str, option: str, temperature: float
):
    qa = get_retrieval_qa(collection_name, model_name, temperature, option)
    try:
        result = qa(query)
    except InvalidRequestError as e:
        # "No answer was found. Please try asking a different question."
        return "回答が見つかりませんでした。別な質問をしてみてください", str(e)
    else:
        metadata = [s.metadata for s in result["source_documents"]]
" + "\n".join(get_related_url(metadata)) + "
" return result["result"], html nvdajp_book_qa = gr.Interface( fn=main, inputs=[ gr.Textbox(label="query"), gr.Radio(["E5", "OpenAI"], label="Embedding", info="選択なしで「E5」を使用"), gr.Radio(["rinna", "GPT-3.5", "GPT-4"], label="Model", info="選択なしで「rinna」を使用"), gr.Radio( ["All", "ja-book", "ja-nvda-user-guide", "en-nvda-user-guide"], label="絞り込み", info="ドキュメント制限する?", ), gr.Slider(0, 2), ], outputs=[gr.Textbox(label="answer"), gr.outputs.HTML()], ) nvdajp_book_qa.launch()