import locale
import logging
import os
from operator import itemgetter
from typing import Dict, List, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain import PromptTemplate, LLMChain
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_core.messages import HumanMessage
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableBranch
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

from llm_for_langchain import LLM


class EndpointHandler:
    """Conversational retrieval-augmented question answering handler.

    On construction the handler loads a fixed list of web pages, splits them
    into chunks (``chunk_size=1000``, ``chunk_overlap=16``), embeds them into
    a Chroma vector store and assembles a LangChain runnable that, per call:

    1. loads the conversation history from ``self.memory``;
    2. asks the LLM to rewrite the incoming question as a standalone question
       (in the question's original language);
    3. retrieves up to 4 chunks for that standalone question and passes them
       through an LLM-based extractor (``LLMChainExtractor``);
    4. answers the question using only the retrieved context.

    Each call then records the question/answer pair in ``self.memory``.

    NOTE(review): ``self.memory`` is a single buffer shared by every call on
    this instance, so concurrent users would see each other's history.
    Confirm the deployment serves one conversation, or key memory per session.
    """

    _logger = logging.getLogger(__name__)

    # Pages indexed into the vector store at start-up.
    _SOURCE_URLS = (
        "https://hk.on.cc/hk/bkn/cnt/news/20221019/bkn-20221019040039334-1019_00822_001.html",
        "https://www.hk01.com/%E7%A4%BE%E6%9C%83%E6%96%B0%E8%81%9E/822848/%E5%89%B5%E7%A7%91%E7%B2%BE%E8%8B%B1-%E5%87%BA%E6%88%B02022%E4%B8%96%E7%95%8C%E6%8A%80%E8%83%BD%E5%A4%A7%E8%B3%BD%E7%89%B9%E5%88%A5%E8%B3%BD",
        "https://www.wenweipo.com/epaper/view/newsDetail/1582436861224292352.html",
        "https://www.thinkhk.com/article/2023-03/24/59874.html",
    )

    _EMBEDDING_MODEL = "DMetaSoul/Dmeta-embedding"

    # Rewrites a follow-up question into a self-contained one using the history.
    _CONDENSE_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""

    # Final answering prompt; restricts the model to the retrieved context.
    _ANSWER_TEMPLATE = """Answer the question based only on the following context:
{context}

Question: {question}
"""

    def __init__(self, path: str = ""):
        """Load the model, index the source pages and build the QA chain.

        Args:
            path: Forwarded to ``LLM`` as ``model_name_or_path``.
        """
        self._configure_tracing()

        chat = LLM(model_name_or_path=path, bit4=False)
        self.llm = chat

        retriever = self._build_retriever(chat)

        self.memory = ConversationBufferMemory(
            return_messages=True, output_key="answer", input_key="question"
        )
        self.final_chain = self._build_chain(chat, retriever, self.memory)

    @staticmethod
    def _configure_tracing() -> None:
        """Enable LangSmith tracing only when an API key is supplied.

        The key must come from the environment (e.g. a deployment secret);
        it must never be committed to source control.
        """
        if os.environ.get("LANGCHAIN_API_KEY"):
            # setdefault lets the deployment explicitly turn tracing off.
            os.environ.setdefault("LANGCHAIN_TRACING_V2", "true")

    @classmethod
    def _build_retriever(cls, llm):
        """Index ``_SOURCE_URLS`` and return a compressing retriever over them."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        embedding_function = HuggingFaceBgeEmbeddings(
            model_name=cls._EMBEDDING_MODEL,
            model_kwargs={"device": device},
            encode_kwargs={"normalize_embeddings": True},
        )

        pages = WebBaseLoader(list(cls._SOURCE_URLS)).load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=16)
        chunks = splitter.split_documents(pages)

        vectorstore = Chroma.from_documents(documents=chunks, embedding=embedding_function)
        base_retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

        # The extractor asks the LLM to keep only the query-relevant parts of
        # each retrieved chunk (one LLM call per chunk).
        compressor = LLMChainExtractor.from_llm(llm)
        return ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=base_retriever
        )

    @classmethod
    def _build_chain(cls, llm, retriever, memory):
        """Compose memory loading, question condensing, retrieval and answering.

        The returned runnable takes ``{"question": str}`` and produces
        ``{"answer": <llm output>, "docs": <retrieved documents>}``.
        """
        condense_prompt = PromptTemplate.from_template(cls._CONDENSE_TEMPLATE)
        answer_prompt = ChatPromptTemplate.from_template(cls._ANSWER_TEMPLATE)

        # Adds "chat_history" (a list of messages, because return_messages=True)
        # alongside the incoming keys.
        loaded_memory = RunnablePassthrough.assign(
            chat_history=RunnableLambda(memory.load_memory_variables) | itemgetter("history"),
        )

        standalone_question = {
            "standalone_question": {
                "question": lambda x: x["question"],
                "chat_history": lambda x: get_buffer_string(x["chat_history"]),
            }
            | condense_prompt
            # NOTE(review): question rephrasing is usually run at temperature 0;
            # configure that on the LLM wrapper if it supports it.
            | llm
            | StrOutputParser(),
        }

        retrieved_documents = {
            "docs": itemgetter("standalone_question") | retriever,
            "question": itemgetter("standalone_question"),
        }

        final_inputs = {
            "context": lambda x: cls._combine_documents(x["docs"]),
            "question": itemgetter("question"),
        }

        answer = {
            "answer": final_inputs | answer_prompt | llm,
            "docs": itemgetter("docs"),
        }

        return loaded_memory | standalone_question | retrieved_documents | answer

    @staticmethod
    def _combine_documents(docs, separator: str = "\n\n") -> str:
        """Join the ``page_content`` of each document with *separator*."""
        return separator.join(doc.page_content for doc in docs)

    @staticmethod
    def _extract_question_inputs(data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the chain input dict for a request payload.

        Accepts ``{"inputs": {"question": ...}}``, ``{"inputs": "<question>"}``
        or a bare ``{"question": ...}``. The caller's dict is not modified.

        Raises:
            ValueError: if no question can be found in the payload.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            inputs = {"question": inputs}
        if "question" not in inputs:
            raise ValueError("request must provide a 'question' (or a string 'inputs')")
        return dict(inputs)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Answer one question and record the exchange in conversation memory.

        Args:
            data: Request payload; see ``_extract_question_inputs``.

        Returns:
            The chain output, with ``"answer"`` (the LLM output) and ``"docs"``
            (the retrieved, compressed documents).
        """
        inputs = self._extract_question_inputs(data)
        output = self.final_chain.invoke(inputs)

        answer = output["answer"]
        # Chat models return a message with `.content`; plain LLMs return a str.
        answer_text = getattr(answer, "content", answer)
        self._logger.info("answer: %s", answer_text)

        self.memory.save_context({"question": inputs["question"]}, {"answer": answer_text})

        # NOTE(review): "docs" holds Document objects (and "answer" may be a
        # message object); confirm the serving layer can serialize them.
        return output