# confluence_qa / confluence_qa.py
# gkrthk
# fix error
# eeab813
# raw
# history blame
# 2.97 kB
from typing import Optional

from langchain.chains import RetrievalQA
from langchain.document_loaders import ConfluenceLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain.vectorstores import Chroma
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
class ConfluenceQA:
    """Answer questions over Confluence pages with a retrieval QA chain.

    Each setup method fills one attribute that a later step reads:

    * ``init_embeddings`` sets ``self.embeddings`` (read by
      ``store_in_vector_db``).
    * ``define_model`` sets ``self.llm`` (read by ``retrieve_qa_chain``).
    * ``store_in_vector_db`` sets ``self.db`` (read by ``retrieve_qa_chain``).
    * ``retrieve_qa_chain`` sets ``self.qa`` (read by ``qa_bot``).

    All of these attributes start out as ``None``.
    """

    # Model identifiers handed to the Hugging Face loaders.
    EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
    LLM_MODEL_NAME = "google/flan-t5-large"

    # Prompt for the "stuff" chain. Built from explicit "\n"-joined pieces so
    # that source indentation can never leak into the text sent to the model.
    PROMPT_TEMPLATE = (
        "Use the following pieces of context to answer the question at the end. "
        "If you don't know the answer, just say that you don't know, "
        "don't try to make up an answer. "
        "Keep the answer as concise as possible.\n"
        "{context}\n"
        "Question: {question}\n"
        "Helpful Answer:"
    )

    def __init__(self, config: Optional[dict] = None) -> None:
        """Create an unconfigured instance.

        Args:
            config: Settings read by ``store_in_vector_db`` via ``.get``:
                ``persist_directory``, ``confluence_url``, ``username``,
                ``api_key``, ``space_key`` and ``include_attachment``.
                The dict is stored by reference, not copied. When omitted,
                a fresh empty dict is created for this instance.
        """
        self.db = None
        self.embeddings = None
        self.llm = None
        # A new dict per instance: a shared ``{}`` default would let one
        # instance's config mutations show up in every other instance.
        self.config = {} if config is None else config
        self.qa = None

    def init_embeddings(self) -> None:
        """Create the embedding model and store it on ``self.embeddings``."""
        self.embeddings = HuggingFaceEmbeddings(
            model_name=self.EMBEDDING_MODEL_NAME
        )

    def define_model(self) -> None:
        """Build the text2text-generation pipeline and store it on ``self.llm``."""
        tokenizer = AutoTokenizer.from_pretrained(self.LLM_MODEL_NAME)
        model = AutoModelForSeq2SeqLM.from_pretrained(self.LLM_MODEL_NAME)
        generator = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=1024,
        )
        self.llm = HuggingFacePipeline(
            pipeline=generator,
            model_kwargs={"temperature": 0, "max_length": 1024},
        )

    def store_in_vector_db(self) -> None:
        """Load Confluence pages, split them, and index them in Chroma.

        Reads its settings from ``self.config`` (missing keys fall back to
        ``None``, or ``False`` for ``include_attachment``) and stores the
        resulting vector store on ``self.db``. Uses ``self.embeddings`` as
        the embedding function.
        """
        settings = self.config
        persist_directory = settings.get("persist_directory", None)
        space_key = settings.get("space_key", None)
        include_attachment = settings.get("include_attachment", False)

        loader = ConfluenceLoader(
            url=settings.get("confluence_url", None),
            username=settings.get("username", None),
            api_key=settings.get("api_key", None),
        )
        pages = loader.load(
            include_attachments=include_attachment,
            limit=50,
            space_key=space_key,
        )

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=512, chunk_overlap=150
        )
        chunks = splitter.split_documents(pages)

        # ``persist_directory`` used to be read and then dropped, so the
        # setting had no effect. Forward it to Chroma; ``None`` is Chroma's
        # own default for this parameter, so unset configs behave as before.
        self.db = Chroma.from_documents(
            chunks,
            self.embeddings,
            persist_directory=persist_directory,
        )

    def retrieve_qa_chain(self) -> None:
        """Assemble the RetrievalQA chain and store it on ``self.qa``.

        Uses ``self.llm`` as the language model and ``self.db`` as the
        retriever source, so both must have been set beforehand.
        """
        qa_prompt = PromptTemplate(
            template=self.PROMPT_TEMPLATE,
            input_variables=["context", "question"],
        )
        self.qa = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.db.as_retriever(),
            chain_type_kwargs={"prompt": qa_prompt},
        )

    def qa_bot(self, query: str):
        """Run ``query`` through the QA chain and return its result.

        Args:
            query: The question to pass to ``self.qa.run``.

        Returns:
            Whatever ``self.qa.run(query)`` returns.
        """
        return self.qa.run(query)