"""Build a retrieval-augmented generation (RAG) chain over a vector store."""

import logging

# NOTE: RunnableParallel was previously imported from the legacy
# `langchain.schema.runnable` path; it is the same class re-exported from
# `langchain_core.runnables`, which this module already depends on.
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel
from langchain_huggingface import HuggingFaceEndpoint

logger = logging.getLogger(__name__)


def get_chain(
    vectordb,
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    task="text-generation",
    max_new_tokens=512,
    top_k=30,
    temperature=0.1,
    repetition_penalty=1.03,
    search_type="mmr",
    k=3,
    fetch_k=5,
    template="""Use the following sentences of context to answer the question at the end. If you don't know the answer, that is if the answer is not in the context, then just say that you don't know, don't try to make up an answer. Always say "Thanks for asking!" at the end of the answer. {context} Question: {question} Helpful Answer:""",
):
    """Assemble a RAG chain: retrieval | prompt | llm | StrOutputParser.

    The returned runnable expects an input dict with a ``"question"`` key;
    the retriever fills ``"context"`` from the vector store before the
    prompt is formatted and sent to the HuggingFace endpoint.

    Args:
        vectordb: Vector store exposing ``as_retriever`` (e.g. Chroma/FAISS).
        repo_id: HuggingFace Hub model repository to query.
        task: HuggingFace inference task name.
        max_new_tokens: Generation length cap for the endpoint.
        top_k: Top-k sampling parameter passed to the endpoint.
        temperature: Sampling temperature passed to the endpoint.
        repetition_penalty: Repetition penalty passed to the endpoint.
        search_type: Retriever search strategy (default ``"mmr"``).
        k: Number of documents the retriever returns.
        fetch_k: Candidate pool size fetched before MMR re-ranking.
        template: Prompt template with ``{context}`` and ``{question}`` slots.

    Returns:
        A composed LangChain runnable producing a plain string answer.
    """
    search_kwargs = {"k": k, "fetch_k": fetch_k}
    # Lazy %-style args: interpolation is skipped when INFO is disabled.
    logger.info(
        "Setting up vectordb retriever with search_type=%s and search_kwargs=%s",
        search_type,
        search_kwargs,
    )
    retriever = vectordb.as_retriever(
        search_type=search_type, search_kwargs=search_kwargs
    )

    logger.info("Setting up retrieval runnable")
    # Fan out the input: fetch documents for "context" while passing the
    # original question through unchanged.
    retrieval = RunnableParallel(
        {
            "context": RunnableLambda(lambda x: retriever.invoke(x["question"])),
            "question": RunnableLambda(lambda x: x["question"]),
        }
    )

    logger.info("Setting up prompt from the template:\n%s", template)
    prompt = PromptTemplate(
        input_variables=["context", "question"], template=template
    )

    logger.info(
        "Instantiating llm with repo_id=%s, task=%s, max_new_tokens=%s, "
        "top_k=%s, temperature=%s and repetition_penalty=%s",
        repo_id,
        task,
        max_new_tokens,
        top_k,
        temperature,
        repetition_penalty,
    )
    llm = HuggingFaceEndpoint(
        repo_id=repo_id,
        task=task,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )

    logger.info(
        "Instantiating and returning chain = retrieval | prompt | llm | StrOutputParser()"
    )
    return retrieval | prompt | llm | StrOutputParser()