# Source provenance: repo "0-Parth-D", commit dcac338 ("Initial Commit").
from langchain_core.prompts import PromptTemplate
from langchain_chroma import Chroma
from langchain_community.llms import Ollama
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_huggingface import HuggingFaceEmbeddings
import sys
def load_vectorstore():
    """Open the persisted Chroma collection used by the RAG assistant.

    Returns a Chroma vector store backed by the local ``chroma_db``
    directory, embedding queries with the all-MiniLM-L6-v2 model.
    """
    embedder = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    return Chroma(
        collection_name="rag_code_assistant",
        persist_directory="chroma_db",
        embedding_function=embedder,
    )
def load_prompt():
    """Build the QA prompt that injects retrieved context and the user question."""
    segments = [
        "You are an expert Python coding assistant. Use the following documentation ",
        "excerpts to answer the user's question accurately. \n\n",
        "IMPORTANT INSTRUCTIONS:\n",
        "- Prioritize standard Python concepts over C-extensions or advanced typing unless specifically asked.\n",
        "- If multiple types of answers are in the context, synthesize them into a complete answer.\n",
        "- If the answer is not in the context, say 'I don't know'.\n\n",
        "Context:\n",
        "{context}\n\n",
        "Question: {question}\n\n",
        "Answer:",
    ]
    return PromptTemplate(
        template="".join(segments),
        input_variables=["context", "question"],
    )
def load_llm():
    """Return the local Ollama llama3 model; low temperature keeps answers focused."""
    return Ollama(model="llama3", temperature=0.1)
def format_docs(docs):
    """Concatenate the documents' text, blank-line separated, for prompt context."""
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)
def load_retriever(vectorstore):
    """Wrap the vector store in an MMR retriever.

    Maximal marginal relevance fetches the top 20 candidate chunks and
    passes the 4 most diverse of them on, cutting down redundant context.
    """
    mmr_kwargs = {"k": 4, "fetch_k": 20}
    return vectorstore.as_retriever(search_type="mmr", search_kwargs=mmr_kwargs)
def load_rag_chain(retriever, prompt, llm):
    """Compose retrieval, prompting, generation and parsing into one LCEL chain."""
    # The dict is coerced to a RunnableParallel by LCEL's `|` operator:
    # context comes from the retriever, the question passes through untouched.
    inputs = {"context": retriever | format_docs, "question": RunnablePassthrough()}
    rag_chain = inputs | prompt | llm | StrOutputParser()
    return rag_chain
def load_answer(rag_chain, retriever, question):
    """Stream the chain's answer to stdout and fetch the supporting documents.

    Returns a tuple of (full answer text, source documents).
    """
    print("=== ASSISTANT'S ANSWER ===")
    pieces = []
    # .stream() yields the answer token by token as it is generated,
    # so the user sees output immediately instead of waiting on .invoke().
    for token in rag_chain.stream(question):
        print(token, end="", flush=True)
        pieces.append(token)
    print("\n")
    # Retrieval itself is quick, so a plain .invoke() suffices for sources.
    source_docs = retriever.invoke(question)
    return "".join(pieces), source_docs
# Build the full RAG pipeline at import time so other modules can simply
# `import` this file and use `rag_chain` directly.
# NOTE(review): this loads the embedding model, Chroma store and LLM as a
# side effect of import; the __main__ block below rebuilds and reassigns
# the same module globals — confirm importers rely on these before removing.
vectorstore = load_vectorstore()
prompt = load_prompt()
llm = load_llm()
retriever = load_retriever(vectorstore)
rag_chain = load_rag_chain(retriever, prompt, llm)
if __name__ == "__main__":
    # Reuse the pipeline already built at module level above. Previously this
    # block rebuilt vectorstore/prompt/llm/retriever/rag_chain from scratch,
    # redundantly repeating the slow embedding-model and vector-store loads.
    question = sys.argv[1] if len(sys.argv) > 1 else "What is Python?"
    print(f"\nUser Question: {question}\nThinking...\n")
    # load_answer streams the answer to stdout and returns the sources used.
    answer, source_docs = load_answer(rag_chain, retriever, question)
    print("\n=== SOURCES USED ===")
    for idx, doc in enumerate(source_docs, start=1):
        source_file = doc.metadata.get("source", "Unknown file")
        print(f"{idx}. {source_file}")