import locale
import os
from operator import itemgetter
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from langchain import LLMChain, PromptTemplate
from langchain.chains import ConversationalRetrievalChain, RetrievalQA
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.schema import format_document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableBranch, RunnableLambda, RunnablePassthrough


class EndpointHandler:
    """Conversational retrieval-augmented QA over a fixed set of web pages.

    Construction loads an instruction-tuned causal LM, indexes the pages in
    ``_URLS`` into an in-memory Chroma vector store, and assembles an LCEL
    chain that:

    1. reads the conversation so far from ``self.memory``,
    2. asks the LLM to rewrite the incoming question as a standalone question,
    3. retrieves chunks for that standalone question, and
    4. asks the LLM to answer using only the retrieved chunks.

    Each call records the question/answer pair back into ``self.memory`` so
    that follow-up questions can refer to earlier turns.
    """

    _MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.1"
    _EMBEDDING_MODEL = "DMetaSoul/Dmeta-embedding"
    _URLS = (
        "https://www.wenweipo.com/epaper/view/newsDetail/1582436861224292352.html",
        "https://www.thinkhk.com/article/2023-03/24/59874.html",
    )

    # Prompt that turns (history, follow-up) into a self-contained question.
    _CONDENSE_TEMPLATE = (
        "[INST] Given the following conversation and a follow up question, "
        "rephrase the follow up question to be a standalone question, "
        "in its original language.\n"
        "Chat History:\n"
        "{chat_history}\n"
        "Follow Up Input: {question}\n"
        "Standalone question: [/INST]"
    )

    # Prompt that answers strictly from the retrieved context.
    _ANSWER_TEMPLATE = (
        "[INST] Answer the question based only on the following context:\n"
        "{context}\n"
        "\n"
        "Question: {question} [/INST]\n"
    )

    def __init__(self, path=""):
        """Load the models, index the documents and build the answering chain.

        Args:
            path: Accepted to match the handler constructor signature; this
                implementation does not read it.
        """
        # setdefault: enable LangSmith tracing unless the deployment has
        # already configured LANGCHAIN_TRACING_V2 (e.g. set it to "false").
        os.environ.setdefault("LANGCHAIN_TRACING_V2", "true")

        llm = self._load_llm(self._MODEL_ID)
        retriever = self._build_retriever(list(self._URLS))
        # NOTE(review): contextual compression (LLMChainExtractor +
        # ContextualCompressionRetriever) is imported by this file but is not
        # part of the chain; retrieval uses the plain vector-store retriever.

        # A single memory object per handler instance: every request served by
        # this instance reads and extends the same conversation history.
        # NOTE(review): confirm that sharing history across callers is intended.
        self.memory = ConversationBufferMemory(
            return_messages=True, output_key="answer", input_key="question"
        )
        # Must run after self.memory exists: the chain binds its loader method.
        self.final_chain = self._build_chain(llm, retriever)

    @staticmethod
    def _load_llm(model_id):
        """Load *model_id* (8-bit weights on CUDA) and wrap it as a LangChain LLM."""
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map={"": "cuda"},  # put the whole model on the CUDA device
            torch_dtype=torch.float16,
            load_in_8bit=True,
        )
        model.eval()

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Reuse EOS as the padding token; presumably the tokenizer defines no
        # pad token of its own — TODO confirm.
        tokenizer.pad_token = tokenizer.eos_token

        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=1024,
        )
        return HuggingFacePipeline(pipeline=generator)

    @classmethod
    def _build_retriever(cls, urls):
        """Download *urls*, chunk and embed the pages, and return a vector-store retriever."""
        embedding_function = HuggingFaceBgeEmbeddings(
            model_name=cls._EMBEDDING_MODEL,
            model_kwargs={"device": "cuda"},
            encode_kwargs={"normalize_embeddings": True},
        )
        pages = WebBaseLoader(urls).load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=16)
        chunks = splitter.split_documents(pages)
        vectorstore = Chroma.from_documents(documents=chunks, embedding=embedding_function)
        return vectorstore.as_retriever()

    @staticmethod
    def _combine_documents(docs, document_prompt=None, document_separator="\n\n"):
        """Render each document with *document_prompt* and join them into one string.

        Args:
            docs: Documents returned by the retriever.
            document_prompt: Template applied to each document; defaults to the
                bare ``{page_content}``.
            document_separator: Text placed between rendered documents.
        """
        if document_prompt is None:
            document_prompt = PromptTemplate.from_template(template="{page_content}")
        return document_separator.join(
            format_document(doc, document_prompt) for doc in docs
        )

    def _build_chain(self, llm, retriever):
        """Wire memory, question condensing, retrieval and answering into one runnable.

        The returned chain takes ``{"question": ...}`` and produces a dict with
        ``"answer"`` (the LLM output) and ``"docs"`` (the retrieved documents).
        """
        condense_question_prompt = PromptTemplate.from_template(self._CONDENSE_TEMPLATE)
        # A plain PromptTemplate, not ChatPromptTemplate: the model is wrapped as
        # a text-completion LLM, and a chat prompt would be flattened to
        # "Human: [INST] ..." before reaching Mistral's instruction format.
        answer_prompt = PromptTemplate.from_template(self._ANSWER_TEMPLATE)

        # Add the stored messages to the input under "chat_history".
        loaded_memory = RunnablePassthrough.assign(
            chat_history=RunnableLambda(self.memory.load_memory_variables)
            | itemgetter("history"),
        )
        # Rewrite the follow-up into a question that stands on its own.
        standalone_question = {
            "standalone_question": {
                "question": lambda x: x["question"],
                "chat_history": lambda x: get_buffer_string(x["chat_history"]),
            }
            | condense_question_prompt
            | llm
            | StrOutputParser(),
        }
        # Retrieve with the standalone question and carry it forward as "question".
        retrieved_documents = {
            "docs": itemgetter("standalone_question") | retriever,
            "question": lambda x: x["standalone_question"],
        }
        final_inputs = {
            "context": lambda x: self._combine_documents(x["docs"]),
            "question": itemgetter("question"),
        }
        answer = {
            "answer": final_inputs | answer_prompt | llm,
            "docs": itemgetter("docs"),
        }
        return loaded_memory | standalone_question | retrieved_documents | answer

    def __call__(self, data: Dict[str, Any]) -> str:
        """Answer one question and record the exchange in conversation memory.

        Args:
            data: Request payload. The question is taken from ``data["inputs"]``;
                if that key is absent the payload itself is used. An optional
                ``"date"`` key is accepted and discarded. Both keys are popped,
                so ``data`` is mutated.

        Returns:
            The answer produced by the chain's ``"answer"`` step.
        """
        inputs = data.pop("inputs", data)
        data.pop("date", None)  # accepted for compatibility; not used yet

        question = {"question": inputs}
        result = self.final_chain.invoke(question)
        answer = result["answer"]

        # Persist the turn so the next call's condense step can see it. Reading
        # the memory (load_memory_variables) does not write anything; without
        # this the history stays empty and every question is a first turn.
        self.memory.save_context(question, {"answer": answer})
        return answer