import os
from operator import itemgetter
from typing import Any, Dict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.llms import HuggingFacePipeline
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.schema import format_document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough


class EndpointHandler:
    def __init__(self, path=""):
        # Configure LangChain tracing (LANGCHAIN_API_KEY must also be set)
        os.environ["LANGCHAIN_TRACING_V2"] = "true"
        # os.environ["LANGCHAIN_API_KEY"] =
        # Create the LLM: Mistral-7B-Instruct loaded with 8-bit quantization
        # (requires the bitsandbytes package) on a single CUDA device
        model_id = "mistralai/Mistral-7B-Instruct-v0.1"
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map={"": "cuda"},
            torch_dtype=torch.float16,
            load_in_8bit=True,
        )
        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Mistral's tokenizer has no pad token by default; reuse EOS
        tokenizer.pad_token = tokenizer.eos_token
        # Wrap the model in a text-generation pipeline and expose it to
        # LangChain as an LLM
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024)
        chat = HuggingFacePipeline(pipeline=pipe)
# Create Text-Embedding Model
embedding_function = HuggingFaceBgeEmbeddings(
model_name="DMetaSoul/Dmeta-embedding",
model_kwargs={'device': 'cuda'},
encode_kwargs={'normalize_embeddings': True}
)
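        # Note: with normalize_embeddings=True the embeddings are unit-length,
        # so Chroma's default L2 search ranks documents like cosine similarity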
        # Build the vector store: load two web articles, split them into
        # overlapping chunks, and index the chunks in Chroma
        urls = [
            "https://www.wenweipo.com/epaper/view/newsDetail/1582436861224292352.html",
            "https://www.thinkhk.com/article/2023-03/24/59874.html",
        ]
        loader = WebBaseLoader(urls)
        data = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=16)
        all_splits = text_splitter.split_documents(data)
        vectorstore = Chroma.from_documents(documents=all_splits, embedding=embedding_function)
        retriever = vectorstore.as_retriever()
        # Optional: an LLM-backed compressor that trims retrieved documents
        # down to query-relevant passages; swap compression_retriever in for
        # `retriever` in retrieved_documents below to enable it
        compressor = LLMChainExtractor.from_llm(chat)
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=retriever
        )
_template = """[INST] Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question: [/INST]"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
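        # For example (hypothetical turn): given prior chat history and a
        # follow-up like "What about the second measure?", this prompt asks
        # the model to rewrite it into a self-contained question, in the
        # original language, before it is sent to the retriever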
template = """[INST] Answer the question based only on the following context:
{context}
Question: {question} [/INST]
"""
ANSWER_PROMPT = ChatPromptTemplate.from_template(template)
self.memory = ConversationBufferMemory(
return_messages=True, output_key="answer", input_key="question"
)
        # First we add a step to load memory
        # This adds a "chat_history" key to the input object
loaded_memory = RunnablePassthrough.assign(
chat_history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history"),
)
# Now we calculate the standalone question
standalone_question = {
"standalone_question": {
"question": lambda x: x["question"],
"chat_history": lambda x: get_buffer_string(x["chat_history"]),
}
| CONDENSE_QUESTION_PROMPT
| chat
| StrOutputParser(),
}
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
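        # Helper: render each retrieved document with DEFAULT_DOCUMENT_PROMPT
        # and join them into one context string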
def _combine_documents(
docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
):
doc_strings = [format_document(doc, document_prompt) for doc in docs]
return document_separator.join(doc_strings)
# Now we retrieve the documents
retrieved_documents = {
"docs": itemgetter("standalone_question") | retriever,
"question": lambda x: x["standalone_question"],
}
# Now we construct the inputs for the final prompt
final_inputs = {
"context": lambda x: _combine_documents(x["docs"]),
"question": itemgetter("question"),
}
        # And finally, the step that produces the answer (and passes the docs through)
answer = {
"answer": final_inputs | ANSWER_PROMPT | chat,
"docs": itemgetter("docs"),
}
# And now we put it all together!
self.final_chain = loaded_memory | standalone_question | retrieved_documents | answer
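        # Data-flow sketch (key names as defined above): the caller's
        # {"question": ...} gains "chat_history" from loaded_memory, is
        # condensed into {"standalone_question": ...}, which retrieves "docs",
        # and the chain finally returns {"answer": ..., "docs": [...]}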
    def __call__(self, data: Dict[str, Any]) -> str:
        # Get inputs
        inputs = data.pop("inputs", data)
        date = data.pop("date", None)  # currently unused
        result = self.final_chain.invoke({"question": inputs})
        answer = result["answer"]
        # The memory does not save automatically, so persist this turn
        # explicitly so follow-up questions can be condensed against it
        self.memory.save_context({"question": inputs}, {"answer": answer})
        return answer
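

# Minimal local smoke test (a sketch; on Hugging Face Inference Endpoints the
# platform instantiates EndpointHandler and posts JSON payloads to it, so the
# payload shape below is an assumption based on __call__):
if __name__ == "__main__":
    handler = EndpointHandler()
    # Hypothetical question about the indexed articles
    print(handler({"inputs": "What are the key points of the indexed articles?"}))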