VlaTal's picture
init
17dcbf0
raw
history blame
6.03 kB
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
import os
import pprint
from dotenv import load_dotenv
from typing import List, Tuple, Optional, Union
from loguru import logger as log
import tiktoken
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.text_splitter import Document
from langchain.output_parsers import PydanticOutputParser
from langchain_openai import AzureOpenAIEmbeddings
from langchain_openai import AzureChatOpenAI
from langchain_community.vectorstores import Chroma
from langchain.document_loaders.pdf import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from pydantic import BaseModel, Field
import streamlit as st
import logging
logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)
def _calc_tokens(splits: List[Document]) -> int:
tokens = 0
for doc in splits:
encoding = tiktoken.get_encoding('cl100k_base')
tokens += len(encoding.encode(doc.page_content))
return tokens
class LineList(BaseModel):
lines: List[str] = Field(description="Lines of text")
class LineListOutputParser(PydanticOutputParser):
def __init__(self) -> None:
super().__init__(pydantic_object=LineList)
def parse(self, text: str) -> LineList:
lines = text.strip().split("\n")
return LineList(lines=lines)
class Assistant:
def __init__(self):
load_dotenv()
self.db_dir = 'docs/chroma/'
self.embedding = AzureOpenAIEmbeddings(azure_deployment="ada_dev")
self.llm = AzureChatOpenAI(
azure_deployment="35_turbo",
model_name="gpt-35-turbo",
temperature=0
)
os.environ["AZURE_OPENAI_API_KEY"] = st.secrets["AZURE_OPENAI_API_KEY"]
os.environ["OPENAI_API_TYPE"] = st.secrets["OPENAI_API_TYPE"]
os.environ["OPENAI_API_VERSION"] = st.secrets["OPENAI_API_VERSION"]
os.environ["AZURE_OPENAI_ENDPOINT"] = st.secrets["AZURE_OPENAI_ENDPOINT"]
self.make_template()
def run(self):
st.title('Гаррі Поттер асистент')
instruction = st.text_input('Питання', '')
if st.button('Згенерувати відповідь'):
result, docs = self.stuff_search(instruction)
st.subheader('Відповідь')
st.text(result)
st.header('Знайдені чанки')
for doc in docs:
st.subheader(f'Сторінка {doc.metadata.get("page")}')
st.text(doc.page_content)
def make_template(self):
template = """Ти ШІ консультант. Твоя задача відповідати на запитання користувачів. Запитання будуть про книгу "Гаррі Поттер та філософський камінь". Додатково тобі будуть надані частини тексту з книги в якості контексту, з яких ти повинен надати відповідь. Ти повинен використовувати для відповіді лише наданий контекст і не додумувати нічого від себе. Якщо в частинах тексту немає відповідної інформації, щоб надати відповідь - вибачся та скажи, що не знаєш відповіді. ВАЖЛИВО відповідати виключно УКРАЇНСЬКОЮ мовою.
Контекст:
{context}
Запитання: {question}
Відповідь:"""
self.prompt = PromptTemplate.from_template(template)
def load_pdf(self, file_name: str) -> List[Document]:
log.info("Loading pdf")
loader = PyPDFLoader(f"files_to_load/{file_name}")
return loader.load()
def split_documents(self, pages: List[dict]) -> Union[List[Document], None]:
log.info("Splitting pdf")
text_splitter = CharacterTextSplitter(
separator="\n",
chunk_size=1000,
chunk_overlap=150,
length_function=len
)
return text_splitter.split_documents(pages)
def save_in_db(self, splits: List[Document]):
log.info("Saving chunks in db")
if len(splits) == 0:
log.warning(
"There are no splits to save in db. Please provide them in arguments or call the split_documents(headers_to_split, pages) method")
return None
vectordb = Chroma.from_documents(
documents=splits,
embedding=self.embedding,
persist_directory=self.db_dir
)
log.info(f"{vectordb._collection.count()} rows were saved")
log.info(f"{_calc_tokens(splits)} tokens were affected")
return True
def stuff_search(self, question: str):
vectordb = Chroma(persist_directory=self.db_dir,
embedding_function=self.embedding)
qa_chain = RetrievalQA.from_chain_type(
self.llm,
retriever=vectordb.as_retriever(),
return_source_documents=True,
chain_type_kwargs={"prompt": self.prompt}
)
result = qa_chain({"query": question})
log.info(f'Questing: {question}')
log.info(f'Result: {result["result"]}')
log.info("DOCUMENTS:")
for doc in result["source_documents"]:
log.info(doc)
return result["result"], result["source_documents"]
if __name__ == "__main__":
assistant = Assistant()
vectordb = Chroma(persist_directory="docs/chroma/",
embedding_function=assistant.embedding)
if(len(vectordb.get().get("documents")) == 0):
pdf = assistant.load_pdf("Harry_Potter.pdf")
splits = assistant.split_documents(pdf)
assistant.save_in_db(splits)
assistant.run()