import os
import getpass
from typing import Any, Dict

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableBranch, RunnablePassthrough

from llm_for_langchain import LLM


class EndpointHandler:
    def __init__(self, path=""):
        # Configure LangChain tracing (LangSmith). Note that getpass is
        # interactive; for a non-interactive deployment, set LANGCHAIN_API_KEY
        # in the environment instead.
        os.environ["LANGCHAIN_TRACING_V2"] = "true"
        os.environ["LANGCHAIN_API_KEY"] = getpass.getpass()

        # Create the chat LLM
        chat = LLM(model_name_or_path=path, bit4=False)

        # Create the text-embedding model
        embedding_function = HuggingFaceBgeEmbeddings(
            model_name="DMetaSoul/Dmeta-embedding",
            model_kwargs={"device": "cuda"},
            encode_kwargs={"normalize_embeddings": True},
        )

        # Load source pages and build the vector DB
        urls = [
            "https://hk.on.cc/hk/bkn/cnt/news/20221019/bkn-20221019040039334-1019_00822_001.html",
            "https://www.hk01.com/%E7%A4%BE%E6%9C%83%E6%96%B0%E8%81%9E/822848/%E5%89%B5%E7%A7%91%E7%B2%BE%E8%8B%B1-%E5%87%BA%E6%88%B02022%E4%B8%96%E7%95%8C%E6%8A%80%E8%83%BD%E5%A4%A7%E8%B3%BD%E7%89%B9%E5%88%A5%E8%B3%BD",
            "https://www.wenweipo.com/epaper/view/newsDetail/1582436861224292352.html",
            "https://www.thinkhk.com/article/2023-03/24/59874.html",
        ]
        loader = WebBaseLoader(urls)
        data = loader.load()

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=16)
        all_splits = text_splitter.split_documents(data)

        vectorstore = Chroma.from_documents(documents=all_splits, embedding=embedding_function)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

        # Compress retrieved documents down to the passages relevant to the query
        compressor = LLMChainExtractor.from_llm(chat)
        retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=retriever
        )

        SYSTEM_TEMPLATE = """Answer the user's questions based on the below context.
If the context doesn't contain any relevant information to the question, don't make something up and just say "I don't know":

{context}
"""

        question_answering_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", SYSTEM_TEMPLATE),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )

        # Prompt used to rewrite a follow-up question into a standalone search
        # query (the question-answering prompt can't be reused here, since its
        # {context} variable isn't available before retrieval)
        query_transform_prompt = ChatPromptTemplate.from_messages(
            [
                MessagesPlaceholder(variable_name="messages"),
                (
                    "user",
                    "Given the above conversation, generate a search query to "
                    "look up in order to get information relevant to the "
                    "conversation. Only respond with the query, nothing else.",
                ),
            ]
        )

        # Wrap the retriever
        query_transforming_retriever_chain = RunnableBranch(
            (
                # If there is only one message, pass its content straight to the retriever
                lambda x: len(x.get("messages", [])) == 1,
                (lambda x: x["messages"][-1].content) | retriever,
            ),
            # Otherwise, have the LLM transform the conversation into a query, then retrieve
            query_transform_prompt | chat | StrOutputParser() | retriever,
        ).with_config(run_name="chat_retriever_chain")

        document_chain = create_stuff_documents_chain(chat, question_answering_prompt)

        self.conversational_retrieval_chain = RunnablePassthrough.assign(
            context=query_transforming_retriever_chain,
        ).assign(
            answer=document_chain,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Extract the user's question from the request payload
        inputs = data.pop("inputs", data)

        output = self.conversational_retrieval_chain.invoke(
            {"messages": [HumanMessage(content=inputs)]}
        )

        print(output["answer"])
        return output
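

# ---------------------------------------------------------------------------
# Usage sketch, not part of the deployed handler: Hugging Face Inference
# Endpoints instantiates EndpointHandler once and then calls it per request
# with a payload like {"inputs": "..."}. The checkpoint path and the sample
# question below are hypothetical placeholders for local smoke-testing.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical local checkpoint path
    result = handler(
        {
            "inputs": (
                "Which Hong Kong competitors took part in the WorldSkills "
                "Competition 2022 Special Edition?"
            )
        }
    )
    # The chain output carries the original "messages", the retrieved
    # "context" documents, and the generated "answer".
    assert "answer" in result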