# Page header captured with the file: Navya-Sree, "Create utils/rag_system.py", ed60c1b (verified)
import os
from typing import List, Dict
import chromadb
from chromadb.utils import embedding_functions
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import TextLoader
class RAGSystem:
    """Retrieval-Augmented Generation helper that supplies documentation context.

    Documents are stored in a persistent Chroma collection and embedded with a
    SentenceTransformer model. When the collection is empty at construction
    time, a small set of built-in Python tips is inserted so that searches
    have something to return.
    """

    # Seed snippets inserted when the collection is found empty.
    _DEFAULT_DOCS = (
        {
            "id": "1",
            "text": "Python functions are defined using the def keyword. Example: def hello(): return 'Hello'",
            "metadata": {"source": "python_basics"},
        },
        {
            "id": "2",
            "text": "Use type hints for better code documentation. Example: def add(a: int, b: int) -> int:",
            "metadata": {"source": "best_practices"},
        },
        {
            "id": "3",
            "text": "Always handle exceptions with try-except blocks to prevent crashes.",
            "metadata": {"source": "error_handling"},
        },
        {
            "id": "4",
            "text": "Use list comprehensions for concise list creation: [x*2 for x in range(10)]",
            "metadata": {"source": "python_tips"},
        },
        {
            "id": "5",
            "text": "Document your code with docstrings. Use triple quotes for multi-line documentation.",
            "metadata": {"source": "documentation"},
        },
    )

    def __init__(
        self,
        collection_name="python_docs",
        persist_path="./chroma_db",
        model_name="all-MiniLM-L6-v2",
    ):
        """Open (or create) the collection and seed it if it is empty.

        Args:
            collection_name: Name of the Chroma collection to use.
            persist_path: Directory handed to ``chromadb.PersistentClient``.
            model_name: SentenceTransformer model used to embed documents.
        """
        self.client = chromadb.PersistentClient(path=persist_path)

        # Use sentence transformers for embeddings.
        self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=model_name
        )

        # Get or create the collection.
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            embedding_function=self.embedding_function,
        )

        # Load default documents only when the collection is empty.
        if self.collection.count() == 0:
            self._load_default_documents()

    def _load_default_documents(self):
        """Insert the built-in Python documentation snippets into the collection."""
        seeds = self._DEFAULT_DOCS
        self.collection.add(
            documents=[doc["text"] for doc in seeds],
            metadatas=[dict(doc["metadata"]) for doc in seeds],
            ids=[doc["id"] for doc in seeds],
        )

    def add_document(self, text: str, source: str = "user") -> str:
        """Add a new document to the knowledge base.

        Args:
            text: Document body to embed and store.
            source: Label stored in the document's ``source`` metadata field.

        Returns:
            The ID under which the document was stored.
        """
        import uuid

        # A count-based ID ("doc_<count + 1>") is reused once any document has
        # been deleted, or when two writers add at the same time. Depending on
        # the Chroma version a duplicate ID is presumably either rejected or
        # silently dropped, so a random ID is used instead.
        doc_id = f"doc_{uuid.uuid4().hex}"
        self.collection.add(
            documents=[text],
            metadatas=[{"source": source}],
            ids=[doc_id],
        )
        return doc_id

    def search(self, query: str, n_results: int = 3) -> List[Dict]:
        """Search for the documents most relevant to ``query``.

        Args:
            query: Search query.
            n_results: Maximum number of results to return.

        Returns:
            A list of dicts with ``text``, ``metadata`` and ``distance`` keys,
            in the order the collection returned them. The list is empty when
            the collection holds no documents or ``n_results`` is not positive.
        """
        # Never ask the index for more entries than it holds; some Chroma
        # versions are reported to warn or raise in that case (confirm against
        # the pinned version).
        limit = min(n_results, self.collection.count())
        if limit <= 0:
            return []

        results = self.collection.query(
            query_texts=[query],
            n_results=limit,
        )
        if not results["documents"]:
            return []

        # Index [0] selects the result list of the single query text sent above.
        return [
            {"text": text, "metadata": metadata, "distance": distance}
            for text, metadata, distance in zip(
                results["documents"][0],
                results["metadatas"][0],
                results["distances"][0],
            )
        ]

    def get_context(self, query: str, max_docs: int = 2) -> str:
        """Build a context string for a coding query.

        Args:
            query: Coding task or question.
            max_docs: Maximum number of documents to include in the context.

        Returns:
            A numbered, newline-separated list of the top documents headed by
            "Relevant documentation:", or an empty string when nothing matched.
        """
        # Request only as many documents as will actually be shown.
        relevant_docs = self.search(query, n_results=max_docs)[:max_docs]
        if not relevant_docs:
            return ""

        context_parts = ["Relevant documentation:"]
        context_parts.extend(
            f"{rank}. {doc['text']}"
            for rank, doc in enumerate(relevant_docs, start=1)
        )
        return "\n".join(context_parts)