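"""Conversational retrieval chain over the CIMA document index.

Zephyr-7B answers questions using context fetched by an ensemble of a
compressed FAISS similarity retriever and a compressed BM25 retriever;
Mistral-7B-Instruct condenses follow-up messages into standalone questions.
"""
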
import os
import pickle
from typing import Any, Dict, List, Optional

from huggingface_hub import InferenceClient
from langchain.chains import ConversationalRetrievalChain
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain.retrievers import BM25Retriever, ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import DocumentCompressorPipeline, EmbeddingsFilter
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from pydantic import BaseModel, Field

import config
from memory import memory3
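
# Models served through the Hugging Face Inference API; the access token is
# read from the "apiToken" environment variable.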
chat_model_name = "HuggingFaceH4/zephyr-7b-alpha"
reform_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
hf_token = os.getenv("apiToken")
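
# Generation settings: a long, sampled budget for chat; a short, more
# deterministic budget for question reformulation.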
kwargs = {"max_new_tokens": 500, "temperature": 0.9, "top_p": 0.95, "repetition_penalty": 1.0, "do_sample": True}
reform_kwargs = {"max_new_tokens": 50, "temperature": 0.5, "top_p": 0.9, "repetition_penalty": 1.0, "do_sample": True}
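

# Minimal LangChain adapter over huggingface_hub.InferenceClient, so the
# hosted models can be plugged into the chain below as ordinary LLMs.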
class KwArgsModel(BaseModel):
    """Mixin holding arbitrary text-generation parameters as a validated dict."""

    kwargs: Dict[str, Any] = Field(default_factory=dict)


class CustomInferenceClient(LLM, KwArgsModel):
    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            kwargs=kwargs or {},  # the field is a plain dict, so guard against None
            inference_client=inference_client,
        )

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Stream tokens from the Inference API and join them into one response.
        response_gen = self.inference_client.text_generation(
            prompt, **self.kwargs, stream=True, return_full_text=False
        )
        return "".join(response_gen)

    @property
    def _llm_type(self) -> str:
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        return {"model_name": self.model_name}

chat_llm = CustomInferenceClient(model_name=chat_model_name, hf_token=hf_token, kwargs=kwargs)
reform_llm = CustomInferenceClient(model_name=reform_model_name, hf_token=hf_token, kwargs=reform_kwargs)
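
# The chat prompt template comes from config and must expose {context},
# {question}, and {chat_history} slots.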
prompt_template = config.DEFAULT_CHAT_TEMPLATE
PROMPT = PromptTemplate(
template=prompt_template, input_variables=["context", "question", "chat_history"]
)
chain_type_kwargs = {"prompt": PROMPT}
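
# Dense retrieval: local FAISS index queried with HuggingFace sentence
# embeddings, returning the 5 most similar chunks.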
embeddings = HuggingFaceEmbeddings()
vectorstore = FAISS.load_local("cima_faiss_index", embeddings)
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
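
# Compression pipeline for retrieved documents: split into ~300-character
# pieces, drop near-duplicate pieces, then keep only those whose embedding
# similarity to the query clears 0.5.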
splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". ")
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.5)
pipeline_compressor = DocumentCompressorPipeline(
transformers=[splitter, redundant_filter, relevant_filter]
)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline_compressor, base_retriever=retriever
)
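
# Sparse retrieval: BM25 over the pickled raw texts (top 2), wrapped in the
# same compression pipeline, then fused 50/50 with the dense retriever.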
with open("docs_data.pkl", "rb") as file:
docs = pickle.load(file)
bm25_retriever = BM25Retriever.from_texts(docs)
bm25_retriever.k = 2
bm25_compression_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline_compressor, base_retriever=bm25_retriever
)
ensemble_retriever = EnsembleRetriever(
    retrievers=[compression_retriever, bm25_compression_retriever], weights=[0.5, 0.5]
)
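
# Condense prompt for the Mistral reformulator, using the [INST] ... [/INST]
# instruction format that Mistral-7B-Instruct expects.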
custom_template = """[INST] Given the following conversation and a follow-up message, rephrase the follow-up user message to be a standalone message. If the follow-up message is not a question, keep it unchanged.
Chat History:
{chat_history}
Follow-up user message: {question} [/INST]
Rewritten user message:"""
CUSTOM_QUESTION_PROMPT = PromptTemplate.from_template(custom_template)
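
# Full pipeline: Mistral condenses the follow-up into a standalone question,
# the ensemble retriever fetches context, and Zephyr answers via the "stuff"
# combine-docs strategy, returning source documents alongside the answer.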
chat_chain = ConversationalRetrievalChain.from_llm(
    llm=chat_llm,
    chain_type="stuff",
    retriever=ensemble_retriever,
    combine_docs_chain_kwargs=chain_type_kwargs,
    return_source_documents=True,
    get_chat_history=lambda h: h,
    condense_question_prompt=CUSTOM_QUESTION_PROMPT,
    memory=memory3,
    condense_question_llm=reform_llm,
)
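
# Usage sketch (an illustration under assumptions, not part of the module: it
# presumes memory3 is a chat memory wired for this chain and that
# cima_faiss_index/ and docs_data.pkl sit next to this file):
#
#     result = chat_chain({"question": "What topics do the documents cover?"})
#     print(result["answer"])
#     for doc in result["source_documents"]:
#         print(doc.page_content[:100])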