import os
import pickle
from typing import Any, Dict, List, Optional

from huggingface_hub import InferenceClient
from pydantic import BaseModel, Field

from langchain.chains import ConversationalRetrievalChain
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain.retrievers import BM25Retriever, ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import DocumentCompressorPipeline, EmbeddingsFilter
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS

import config
from memory import memory3

# Model endpoints: Zephyr answers the user, Mistral rewrites follow-ups into standalone questions.
chat_model_name = "HuggingFaceH4/zephyr-7b-alpha"
reform_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
hf_token = os.getenv("apiToken")

# Generation settings for the chat model and the shorter, more conservative question reformulator.
kwargs = {"max_new_tokens": 500, "temperature": 0.9, "top_p": 0.95, "repetition_penalty": 1.0, "do_sample": True}
reform_kwargs = {"max_new_tokens": 50, "temperature": 0.5, "top_p": 0.9, "repetition_penalty": 1.0, "do_sample": True}


class KwArgsModel(BaseModel):
    kwargs: Dict[str, Any] = Field(default_factory=dict)


class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain LLM wrapper around the Hugging Face Hub InferenceClient."""

    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            kwargs=kwargs or {},
            inference_client=inference_client,
        )

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Stream tokens from the Inference API and join them into a single string.
        response_gen = self.inference_client.text_generation(
            prompt, **self.kwargs, stream=True, return_full_text=False
        )
        return "".join(response_gen)

    @property
    def _llm_type(self) -> str:
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        return {"model_name": self.model_name}


chat_llm = CustomInferenceClient(model_name=chat_model_name, hf_token=hf_token, kwargs=kwargs)
reform_llm = CustomInferenceClient(model_name=reform_model_name, hf_token=hf_token, kwargs=reform_kwargs)

# Answer prompt: fills retrieved context, the user question and the chat history into the default template.
prompt_template = config.DEFAULT_CHAT_TEMPLATE
PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question", "chat_history"],
)
chain_type_kwargs = {"prompt": PROMPT}

# Dense retrieval: FAISS index queried with the default HuggingFace sentence embeddings.
embeddings = HuggingFaceEmbeddings()
vectorstore = FAISS.load_local("cima_faiss_index", embeddings)
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

# Contextual compression: split retrieved documents into sentence-sized chunks, drop near-duplicates,
# then keep only chunks that are sufficiently similar to the query.
splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". \n")
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.5)
pipeline_compressor = DocumentCompressorPipeline(
    transformers=[splitter, redundant_filter, relevant_filter]
)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline_compressor, base_retriever=retriever
)

# Sparse retrieval: BM25 over the pickled texts, wrapped in the same compression pipeline.
with open("docs_data.pkl", "rb") as file:
    docs = pickle.load(file)
bm25_retriever = BM25Retriever.from_texts(docs)
bm25_retriever.k = 2
bm25_compression_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline_compressor, base_retriever=bm25_retriever
)

# Hybrid retrieval: combine the dense and sparse retrievers with equal weights.
ensemble_retriever = EnsembleRetriever(
    retrievers=[compression_retriever, bm25_compression_retriever], weights=[0.5, 0.5]
)

# Prompt used by the condense-question LLM to turn a follow-up into a standalone message.
custom_template = """Given the following conversation and a follow-up message, rephrase the follow-up user message to be a standalone message. If the follow-up message is not a question, keep it unchanged. [/INST]
Chat History:
{chat_history}
Follow-up user message: {question}
Rewritten user message:"""
CUSTOM_QUESTION_PROMPT = PromptTemplate.from_template(custom_template)

# Conversational RAG chain: Zephyr answers over the "stuff"ed ensemble-retrieved context,
# Mistral condenses follow-up messages, and memory3 keeps the chat history.
chat_chain = ConversationalRetrievalChain.from_llm(
    llm=chat_llm,
    chain_type="stuff",
    retriever=ensemble_retriever,
    combine_docs_chain_kwargs=chain_type_kwargs,
    return_source_documents=True,
    get_chat_history=lambda h: h,
    condense_question_prompt=CUSTOM_QUESTION_PROMPT,
    memory=memory3,
    condense_question_llm=reform_llm,
)