|
import getpass
import locale
import os
from typing import Dict, List, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain import PromptTemplate, LLMChain
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableBranch
from langchain_core.runnables import RunnablePassthrough

from llm_for_langchain import LLM
|
|
|
class EndpointHandler:
    """Retrieval-augmented question-answering endpoint handler.

    On construction the handler loads a fixed list of web pages, splits them
    into chunks, indexes the chunks in a Chroma vector store and assembles a
    LangChain runnable that (1) retrieves context for the incoming message
    and (2) answers the message from that context. Calling the instance runs
    that chain on a single user message.
    """

    def __init__(self, path=""):
        """Build the retrieval + answering chain.

        Args:
            path: Model name or path forwarded to ``LLM`` as
                ``model_name_or_path``.
        """
        # Turn on LangChain tracing. Reuse an API key that is already present
        # in the environment and only fall back to an interactive prompt when
        # none is configured, so a pre-configured deployment never blocks.
        os.environ["LANGCHAIN_TRACING_V2"] = "true"
        if not os.environ.get("LANGCHAIN_API_KEY"):
            os.environ["LANGCHAIN_API_KEY"] = getpass.getpass()

        chat = LLM(model_name_or_path=path, bit4=False)

        embedding_function = HuggingFaceBgeEmbeddings(
            model_name="DMetaSoul/Dmeta-embedding",
            model_kwargs={'device': 'cuda'},
            encode_kwargs={'normalize_embeddings': True},
        )

        # Pages that make up the knowledge base.
        urls = [
            "https://hk.on.cc/hk/bkn/cnt/news/20221019/bkn-20221019040039334-1019_00822_001.html",
            "https://www.hk01.com/%E7%A4%BE%E6%9C%83%E6%96%B0%E8%81%9E/822848/%E5%89%B5%E7%A7%91%E7%B2%BE%E8%8B%B1-%E5%87%BA%E6%88%B02022%E4%B8%96%E7%95%8C%E6%8A%80%E8%83%BD%E5%A4%A7%E8%B3%BD%E7%89%B9%E5%88%A5%E8%B3%BD",
            "https://www.wenweipo.com/epaper/view/newsDetail/1582436861224292352.html",
            "https://www.thinkhk.com/article/2023-03/24/59874.html",
        ]
        data = WebBaseLoader(urls).load()

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=16)
        all_splits = text_splitter.split_documents(data)

        vectorstore = Chroma.from_documents(documents=all_splits, embedding=embedding_function)
        base_retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

        # Wrap the vector-store retriever so the model extracts only the
        # relevant parts of each retrieved chunk. The compressor must use the
        # local ``chat`` model; no ``self.llm`` attribute is ever assigned.
        compressor = LLMChainExtractor.from_llm(chat)
        retriever = ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=base_retriever,
        )

        system_template = (
            "\n"
            "Answer the user's questions based on the below context.\n"
            "If the context doesn't contain any relevant information to the "
            "question, don't make something up and just say \"I don't know\":\n"
            "\n"
            "<context>\n"
            "{context}\n"
            "</context>\n"
        )
        question_answering_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )

        # Prompt used only to condense a multi-message conversation into one
        # search query. It needs just ``messages``; the answering prompt cannot
        # be used at this stage because its ``{context}`` is not yet available.
        query_transform_prompt = ChatPromptTemplate.from_messages(
            [
                MessagesPlaceholder(variable_name="messages"),
                (
                    "user",
                    "Given the above conversation, generate a search query to "
                    "look up in order to get information relevant to the "
                    "conversation. Only respond with the query, nothing else.",
                ),
            ]
        )

        query_transforming_retriever_chain = RunnableBranch(
            (
                # Exactly one message: use its text as the search query as-is.
                lambda x: len(x.get("messages", [])) == 1,
                (lambda x: x["messages"][-1].content) | retriever,
            ),
            # Otherwise: let the model rewrite the conversation into a query,
            # then retrieve with that query.
            query_transform_prompt | chat | StrOutputParser() | retriever,
        ).with_config(run_name="chat_retriever_chain")

        document_chain = create_stuff_documents_chain(chat, question_answering_prompt)

        # Final pipeline: the input dict gains a ``context`` key (retrieved
        # documents) and then an ``answer`` key (model response).
        self.conversational_retrieval_chain = RunnablePassthrough.assign(
            context=query_transforming_retriever_chain,
        ).assign(
            answer=document_chain,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Answer one user message.

        Args:
            data: Request payload. The value under ``"inputs"`` is removed
                from ``data`` and used as the message content; if the key is
                absent, ``data`` itself is used.

        Returns:
            The chain output dict: the ``"messages"`` that were passed in plus
            the ``"context"`` and ``"answer"`` keys added by the chain.
        """
        inputs = data.pop("inputs", data)
        output = self.conversational_retrieval_chain.invoke(
            {"messages": [HumanMessage(content=inputs)]}
        )
        print(output['answer'])

        return output
|
|
|
|
|
|