Spaces:
Sleeping
Sleeping
| from langchain.chains.base import Chain | |
| from langchain.schema import BaseRetriever | |
| from langchain.llms import BaseLLM | |
| from langchain.prompts import PromptTemplate | |
| from pydantic import Field | |
| from typing import Dict, Any | |
class MyCustomMemoryRetrievalChain(Chain):
    """Custom chain that answers a question using retrieved documents plus memory.

    The chain takes two inputs, ``question`` and ``memory``:

    1. ``question`` is sent to ``retriever``.
    2. The ``page_content`` of every returned document is joined with
       newlines into a single ``context`` string.
    3. ``prompt`` is formatted with ``question``, ``memory`` and ``context``.
    4. The formatted prompt is passed to ``llm``.

    The LLM's output is returned under ``output_key``.
    """

    llm: BaseLLM = Field(...)
    retriever: BaseRetriever = Field(...)
    # Template placeholders are filled from: question, memory, context.
    prompt: PromptTemplate = Field(...)
    # Key under which the LLM answer is returned from the chain.
    output_key: str = "result"

    # Chain declares input_keys / output_keys as properties and iterates over
    # them (e.g. when validating inputs/outputs). As plain methods they would
    # evaluate to bound-method objects, which are not iterable.
    @property
    def input_keys(self) -> list:
        """Names of the inputs this chain requires."""
        return ["question", "memory"]

    @property
    def output_keys(self) -> list:
        """Names of the outputs this chain produces (just ``output_key``)."""
        return [self.output_key]

    def _call(self, inputs: Dict[str, Any], run_manager=None) -> Dict[str, Any]:
        """Retrieve context for the question, fill the prompt, and query the LLM.

        Args:
            inputs: Mapping with ``"question"`` (sent to the retriever and the
                prompt) and ``"memory"`` (a plain value passed straight into
                the prompt).
            run_manager: Optional callback manager supplied by the Chain
                machinery; its child callbacks are forwarded to the LLM call.

        Returns:
            ``{self.output_key: <LLM output>}``.
        """
        question = inputs["question"]
        memory = inputs["memory"]

        # NOTE(review): callbacks are not forwarded to the retriever because
        # older BaseRetriever.get_relevant_documents signatures take only the
        # query; forward run_manager.get_child() here too if the installed
        # LangChain version supports it.
        docs = self.retriever.get_relevant_documents(question)
        context = "\n".join(doc.page_content for doc in docs)

        final_prompt = self.prompt.format(
            question=question,
            memory=memory,
            context=context,
        )

        # Propagate callbacks so tracing/streaming handlers see the LLM call.
        callbacks = run_manager.get_child() if run_manager is not None else None
        answer = self.llm(final_prompt, callbacks=callbacks)
        return {self.output_key: answer}