import os
import sys
import shutil
import datetime
import docarray
sys.path.append('../..')
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.vectorstores import DocArrayInMemorySearch, Chroma
from langchain.document_loaders import TextLoader, GitLoader
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory, ConversationBufferWindowMemory
from langchain.chat_models import ChatOllama
from langchain.embeddings import OllamaEmbeddings
from langchain.prompts import PromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate, ChatPromptTemplate
# Function to load the data from GitHub using LangChain's GitLoader.
# url: repository clone URL, branch: branch to check out, file_filter: comma-separated list of file extensions to keep (e.g. ".py,.md")
def loader(url: str, branch: str, file_filter: str):
    repo_path = "./github_repo"
    if os.path.exists(repo_path):
        shutil.rmtree(repo_path)  # Remove any previous clone so the repo is always fresh

    git_loader = GitLoader(
        clone_url=url,
        repo_path="./github_repo/",
        branch=branch,
        file_filter=lambda file_path: file_path.endswith(tuple(file_filter.split(',')))  # Only files matching the filter are loaded, although the whole repo is cloned
    )
    data = git_loader.load()
    return data
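
# A hedged usage sketch for loader(); the repo URL, branch, and filter below are
# illustrative placeholders, not values used elsewhere in this app:
#   data = loader("https://github.com/<owner>/<repo>", "main", ".py,.md")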
# Function to split the loaded documents into chunks using RecursiveCharacterTextSplitter
def split_data(data):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=150,
        length_function=len,   # Function to measure the length of chunks while splitting
        add_start_index=True   # Include the starting position of each chunk in metadata
    )
    chunks = splitter.split_documents(data)
    return chunks
# Function to ingest the chunks into an in-memory DocArray vector store using Ollama embeddings
def ingest_chunks(chunks):
    embedding = OllamaEmbeddings(
        base_url='https://thewise-ollama-server.hf.space',
        model="nomic-embed-text",
    )
    vector_store = DocArrayInMemorySearch.from_documents(chunks, embedding)

    # The cloned repo is no longer needed once its contents are embedded
    repo_path = "./github_repo"
    if os.path.exists(repo_path):
        shutil.rmtree(repo_path)

    return vector_store
# Retrieval function that builds the conversational chain answering user questions from the vector store
def retrieval(vector_store, k):
    # Creating LLM
    llm = ChatOllama(
        base_url='https://thewise-ollama-server.hf.space',
        model="codellama:34b")

    # Define the system message template.
    # CHAT HISTORY is added to the system template explicitly: normally the chat history is only used to
    # condense the human question with its background, while the system template goes straight to the LLM chain.
    # Adding it here lets the model see previous turns (e.g. to answer "what was my previous question?"),
    # since the history is then sent to the LLM along with the context and the question.
    system_template = """You're a code summarisation assistant. Given the following extracted parts of a long document as "CONTEXT" create a final answer.
    If you don't know the answer, just say that you don't know. Don't try to make up an answer.
    Only if asked to create a "DIAGRAM" for code, use "MERMAID SYNTAX LANGUAGE" in your answer from "CONTEXT" and "CHAT HISTORY", with a short explanation of the diagram.
    CONTEXT: {context}
    =======
    CHAT HISTORY: {chat_history}
    =======
    FINAL ANSWER:"""

    human_template = """{question}"""
    # ai_template = """
    # FINAL ANSWER:"""

    # Create the chat prompt templates
    messages = [
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template(human_template)
        # AIMessagePromptTemplate.from_template(ai_template)
    ]
    PROMPT = ChatPromptTemplate.from_messages(messages)

    # Creating memory: a sliding window that keeps only the last k=5 exchanges
    # memory = ConversationBufferMemory(
    #     memory_key="chat_history",
    #     input_key="question",
    #     output_key="answer",
    #     return_messages=True)
    memory = ConversationBufferWindowMemory(
        memory_key="chat_history",
        input_key="question",
        output_key="answer",
        return_messages=True,
        k=5)

    # Creating the retriever; this could also be a contextual compression retriever
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": k})  # search_type can be "similarity" or "mmr"

    chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",  # chain_type can be "refine", "stuff", or "map_reduce"
        retriever=retriever,
        memory=memory,
        return_source_documents=True,  # the output then contains "answer" and "source_documents", so the memory needs explicit input and output keys
        combine_docs_chain_kwargs=dict({"prompt": PROMPT})
    )
    return chain
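
# Hedged example of how the returned chain is typically invoked (the question text is
# illustrative; vector_store is assumed to come from ingest_chunks above):
#   chain = retrieval(vector_store, k=10)
#   result = chain({"question": "What does this repository do?"})
#   result["answer"]            # the model's reply
#   result["source_documents"]  # retrieved chunks, present because return_source_documents=True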
# Class wiring all the above components together into the QA system
class ConversationalResponse:
    def __init__(self, url, branch, file_filter):
        self.url = url
        self.branch = branch
        self.file_filter = file_filter
        self.data = loader(self.url, self.branch, self.file_filter)
        self.chunks = split_data(self.data)
        self.vector_store = ingest_chunks(self.chunks)
        self.chain_type = "stuff"
        self.k = 10
        self.chain = retrieval(self.vector_store, self.k)

    def __call__(self, question):
        result = self.chain(question)
        return result['answer']
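
# Minimal end-to-end sketch, kept as a comment so nothing runs on import.
# The repo URL and filter are hypothetical placeholders, and running it requires
# network access to the Ollama server configured above:
# if __name__ == "__main__":
#     qa = ConversationalResponse(
#         url="https://github.com/<owner>/<repo>",
#         branch="main",
#         file_filter=".py,.md",
#     )
#     print(qa("Give a short summary of this repository."))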