"""Ask a question to the netspresso database.""" import json import sys import argparse from typing import List from langchain.chat_models import ChatOpenAI # for `gpt-3.5-turbo` & `gpt-4` from langchain.chains import RetrievalQAWithSourcesChain from langchain.prompts import ( ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, ) from langchain.schema import BaseRetriever, Document import gradio as gr from search_online import OnlineSearcher # DEFAULT_QUESTION = "모델 경량화 및 최적화와 관련하여 Netspresso bot에게 물어보세요.\n예를들어 \n\n- Why do I need to use Netspresso?\n- Summarize how to compress the model with netspresso.\n- Tell me what the pruning is.\n- What kinds of hardware can I use with this toolkit?\n- Can I use YOLOv8 with this tool? If so, tell me the examples." DEFAULT_QUESTION = "Ask the Netspresso bot about model lightweighting and optimization.\nFor example \n\n- Why do I need to use Netspresso?\n- Summarize how to compress the model with netspresso.\n- Tell me what the pruning is.\n- What kinds of hardware can I use with this toolkit?\n- Can I use YOLOv8 with this tool? If so, tell me the examples." TEMPERATURE = 0 # manual arguments (FIXME) args = argparse.Namespace args.index_type = "hybrid" args.index = ( "/root/indexes/docs-netspresso-ai/sparse,/root/indexes/docs-netspresso-ai/dense" ) if isinstance( args.index, tuple ): # black extension automatically convert long str to tuple assert len(args.index) == 1 args.index = args.index[0] args.encoder = "castorini/mdpr-question-nq" args.device = "cuda:0" args.alpha = 0.5 args.normalization = True args.lang_abbr = "en" args.K = 10 # initialize qabot print("initialize NP doc retrieval bot") RETRIEVER = OnlineSearcher(args) class LangChainCustomRetrieverWrapper(BaseRetriever): def __init__(self, args): super().__init__() # self.retriever = RETRIEVER # TODO. should be initialize from args # self.args = args print("Initialize LangChainCustomRetrieverWrapper, TODO: fix minor bug") def get_relevant_documents(self, query: str) -> List[Document]: """Get texts relevant for a query. Args: query: string to find relevant texts for Returns: List of relevant documents """ print(f"query = {query}") # retrieve # hits = self.retriever.search(query, self.args.K) hits = RETRIEVER.search( query, args.K ) # TODO: fix bug that BaseRetriever object cannot have extra field # extract docs results = [ { "contents": json.loads( # self.retriever.searcher.sparse_searcher.doc(hits[i].docid).raw() # TODO: fix bug that BaseRetriever object cannot have extra field RETRIEVER.searcher.sparse_searcher.doc(hits[i].docid).raw() )["contents"], "docid": hits[i].docid, } for i in range(len(hits)) ] # make result list of Document object return [ Document( page_content=result["contents"], metadata={"source": result["docid"]} ) for result in results ] async def aget_relevant_documents( self, query: str ) -> List[Document]: # abstractmethod raise NotImplementedError class RaLM: def __init__(self, args): self.args = args self.initialize_ralm() def initialize_ralm(self): # initialize custom retriever self.retriever = LangChainCustomRetrieverWrapper(self.args) # prompt for RaLM system_template = """Use the following pieces of context to answer the users question. Take note of the sources and include them in the answer in the format: "SOURCES: source1 source2", use "SOURCES" in capital letters regardless of the number of sources. Always try to generate answer from source. 

class RaLM:
    def __init__(self, args):
        self.args = args
        self.initialize_ralm()

    def initialize_ralm(self):
        # initialize custom retriever
        self.retriever = LangChainCustomRetrieverWrapper(self.args)

        # prompt for RaLM
        system_template = """Use the following pieces of context to answer the user's question.
Take note of the sources and include them in the answer in the format: "SOURCES: source1 source2"; use "SOURCES" in capital letters regardless of the number of sources.
Always try to generate the answer from the sources.
----------------
{summaries}"""
        messages = [
            SystemMessagePromptTemplate.from_template(system_template),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
        prompt = ChatPromptTemplate.from_messages(messages)
        chain_type_kwargs = {"prompt": prompt}
        llm = ChatOpenAI(model_name=self.args.model_name, temperature=TEMPERATURE)
        self.chain = RetrievalQAWithSourcesChain.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
            reduce_k_below_max_tokens=True,
            chain_type_kwargs=chain_type_kwargs,
        )

    def run_chain(self, question, force_korean=False):
        if force_korean:
            # append an instruction (in Korean): "answer in Korean, referring to the context"
            question = f"{question} 본문을 참고해서 한글로 대답해줘"
        result = self.chain({"question": question})

        # postprocess: clean the answer and split the sources string into a list
        result["answer"] = self.postprocess(result["answer"])
        if isinstance(result["sources"], str):
            result["sources"] = self.postprocess(result["sources"])
            result["sources"] = [src.strip() for src in result["sources"].split(",")]

        # print result
        self.print_result(result)
        return result

    def print_result(self, result):
        # print the result of langchain's RetrievalQAWithSourcesChain
        print(f"Answer: {result['answer']}")
        print("Sources:")
        print(result["sources"])
        assert isinstance(result["sources"], list)
        for source_title in result["sources"]:
            print(f"{source_title}:")
            if "source_documents" in result:
                # print the first retrieved document matching this source
                for doc in result["source_documents"]:
                    if doc.metadata["source"] == source_title:
                        print(doc.page_content)
                        break

    def postprocess(self, text):
        # strip a dangling bracket the model sometimes appends (cause unknown)
        if text.endswith((")", "(", "[", "]")):
            text = text[:-1]
        return text.strip()

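
# Example (sketch): using RaLM programmatically, without the Gradio UI below.
# Assumes OPENAI_API_KEY is set in the environment and that the module-level
# retriever arguments above have already been initialized:
#
#     app = RaLM(argparse.Namespace(model_name="gpt-3.5-turbo-16k-0613"))
#     result = app.run_chain(question="Tell me what pruning is.")
#     print(result["answer"])   # generated answer
#     print(result["sources"])  # list of docids cited by the model
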

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Ask a question to the Netspresso docs."
    )

    # General
    # parser.add_argument(
    #     "--question",
    #     type=str,
    #     default=None,
    #     required=True,
    #     help="The question to ask the database",
    # )
    parser.add_argument(
        "--model_name",
        type=str,
        default="gpt-3.5-turbo-16k-0613",
        help="model name for the OpenAI API",
    )

    # Retriever: fixed args for now
    """
    parser.add_argument(
        "--query_encoder_name_or_dir",
        type=str,
        default="princeton-nlp/densephrases-multi-query-multi",
        help="query encoder name registered in the huggingface model hub OR a custom query encoder checkpoint directory",
    )
    parser.add_argument(
        "--index_name",
        type=str,
        default="1048576_flat_OPQ96",
        help="index name appended to the index directory prefix",
    )
    """
    # parse into the existing namespace so the manual retriever args above are kept
    args = parser.parse_args(namespace=args)

    # to prevent collision with the DensePhrases native argparser
    sys.argv = [sys.argv[0]]

    # initialize class
    app = RaLM(args)

    def question_answer(question):
        result = app.run_chain(question=question, force_korean=False)
        answer = result["answer"]
        sources = "\n######################################################\n\n".join(
            f"Source {idx}\n{doc.page_content}"
            for idx, doc in enumerate(result["source_documents"])
        )
        return answer, sources

    # launch gradio
    gr.Interface(
        fn=question_answer,
        inputs=gr.inputs.Textbox(default=DEFAULT_QUESTION, label="Question"),
        outputs=[
            gr.outputs.Textbox(label="Bot response"),
            gr.outputs.Textbox(label="Search result used by bot"),
        ],
        title="Netspresso Q&A bot",
        theme="dark-grass",
        description="Ask the Netspresso bot about model lightweighting and optimization.",
        # detailed version (hidden for now):
        # description=(
        #     "Ask the Netspresso bot about model lightweighting and optimization.\n\n"
        #     "retriever: BM25 & mdpr-question-nq, generator: gpt-3.5-turbo-16k-0613 (API)"
        # ),
    ).launch(share=True)
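
# Usage (sketch; the script filename is assumed, not given in this repo):
#     OPENAI_API_KEY=... python ask_netspresso.py --model_name gpt-3.5-turbo-16k-0613
# With share=True, gradio prints a temporary public *.gradio.live link in
# addition to the local URL.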