import torch

from operator import itemgetter
from typing import Any, Dict

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from langchain.llms import HuggingFacePipeline
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.retrievers import ContextualCompressionRetriever
from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import format_document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

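# Custom EndpointHandler for a Hugging Face Inference Endpoint: builds a
# conversational RAG chain (LangChain LCEL + Chroma + BGE embeddings) over a
# handful of Hong Kong news articles, condenses chat history into standalone
# questions, and answers them from the retrieved context.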
class EndpointHandler():

    def __init__(self, path=""):
        # `path` is the local path of the model repository provided by the endpoint.
        model_id = path

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
            padding_side="left",
            add_eos_token=True,
            use_fast=False
        )
        tokenizer.pad_token = tokenizer.eos_token

        # Load the generator LM. device_map/torch_dtype are additions that assume
        # a GPU endpoint with `accelerate` installed; drop them to restore the
        # original CPU/default-precision load.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024)
        chat = HuggingFacePipeline(pipeline=pipe)

        # BGE embeddings for Chinese text, normalized so similarity search behaves well.
        embedding_function = HuggingFaceBgeEmbeddings(
            model_name="BAAI/bge-large-zh",
            model_kwargs={"device": "cuda"},
            encode_kwargs={"normalize_embeddings": True}
        )

        # Source articles (Hong Kong news pages) that form the knowledge base.
        urls = [
            "https://hk.on.cc/hk/bkn/cnt/news/20221019/bkn-20221019040039334-1019_00822_001.html",
            "https://www.hk01.com/%E7%A4%BE%E6%9C%83%E6%96%B0%E8%81%9E/822848/%E5%89%B5%E7%A7%91%E7%B2%BE%E8%8B%B1-%E5%87%BA%E6%88%B02022%E4%B8%96%E7%95%8C%E6%8A%80%E8%83%BD%E5%A4%A7%E8%B3%BD%E7%89%B9%E5%88%A5%E8%B3%BD",
            "https://www.wenweipo.com/epaper/view/newsDetail/1582436861224292352.html",
            "https://www.thinkhk.com/article/2023-03/24/59874.html"
        ]

        loader = WebBaseLoader(urls)
        data = loader.load()

        # Split the pages into ~1000-character chunks with a small overlap.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=16)
        all_splits = text_splitter.split_documents(data)

        # In-memory Chroma index over the chunks; retrieve the top 4 matches.
        vectorstore = Chroma.from_documents(documents=all_splits, embedding=embedding_function)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

        # Compress each retrieved chunk with an LLM extractor before it reaches the prompt.
        compressor = LLMChainExtractor.from_llm(chat)
        retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
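        # Note: LLMChainExtractor issues one extra generation call per retrieved
        # chunk on every query, so this compression step trades latency for a
        # tighter context window.
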
_template = """[INST] Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. |
|
Chat History: |
|
{chat_history} |
|
Follow Up Input: {question} |
|
Standalone question: [/INST]""" |
|
|
|
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template) |
|
|
|
template = """[INST] Answer the question based only on the following context: |
|
{context} |
|
|
|
Question: {question} [/INST] |
|
""" |
|
|
|
ANSWER_PROMPT = ChatPromptTemplate.from_template(template) |
|
|
|
        # Conversation memory; "question" is the input key and "answer" the output key.
        self.memory = ConversationBufferMemory(
            return_messages=True, output_key="answer", input_key="question"
        )

        # Step 1: pull the stored history into the chain input as "chat_history".
        loaded_memory = RunnablePassthrough.assign(
            chat_history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history"),
        )

        # Step 2: rewrite the follow-up question into a standalone question.
        standalone_question = {
            "standalone_question": {
                "question": lambda x: x["question"],
                "chat_history": lambda x: get_buffer_string(x["chat_history"]),
            }
            | CONDENSE_QUESTION_PROMPT
            | chat
            | StrOutputParser(),
        }

        DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

        def _combine_documents(
            docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
        ):
            # Render each retrieved document and join them into one context string.
            doc_strings = [format_document(doc, document_prompt) for doc in docs]
            return document_separator.join(doc_strings)

        # Step 3: retrieve (and compress) documents for the standalone question.
        retrieved_documents = {
            "docs": itemgetter("standalone_question") | retriever,
            "question": lambda x: x["standalone_question"],
        }

        # Step 4: assemble the final prompt inputs from the retrieved documents.
        final_inputs = {
            "context": lambda x: _combine_documents(x["docs"]),
            "question": itemgetter("question"),
        }

        # Step 5: generate the answer and pass the source documents through.
        answer = {
            "answer": final_inputs | ANSWER_PROMPT | chat,
            "docs": itemgetter("docs"),
        }

        self.final_chain = loaded_memory | standalone_question | retrieved_documents | answer
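        # final_chain data flow:
        #   {"question"} -> loaded_memory adds "chat_history"
        #   -> standalone_question condenses history + follow-up into one query
        #   -> retrieved_documents fetches and compresses the matching chunks
        #   -> answer returns {"answer": <generation>, "docs": <source documents>}
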
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # `data` follows the Inference Endpoints convention: the query text sits
        # under "inputs"; "date" is accepted but not used by the chain.
        inputs = data.pop("inputs", data)
        date = data.pop("date", None)

        result = self.final_chain.invoke({"question": inputs})
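        # Not part of the original handler: persist this turn so
        # ConversationBufferMemory actually accumulates chat history across calls,
        # and strip LangChain Document objects down to plain text so the response
        # stays JSON-serializable.
        self.memory.save_context({"question": inputs}, {"answer": result["answer"]})
        result["docs"] = [doc.page_content for doc in result["docs"]]
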
        return result
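

# Minimal local smoke test, a sketch only and not part of the endpoint contract.
# The model path is a placeholder assumption; in production the Inference
# Endpoints toolkit passes the downloaded repository path into EndpointHandler.
if __name__ == "__main__":
    handler = EndpointHandler(path="mistralai/Mistral-7B-Instruct-v0.1")
    print(handler({"inputs": "香港代表隊在2022世界技能大賽特別賽有甚麼成績？"}))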