import json
import os
import ast
from typing import List, Dict, Tuple, Optional
import uuid
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import Language
from langchain_core.embeddings import Embeddings
import statistics
import tiktoken
from tqdm import tqdm
from langfuse import Langfuse
from langchain_community.embeddings import HuggingFaceEmbeddings
import re
from mllm_tools.utils import _prepare_text_inputs
from task_generator import get_prompt_detect_plugins
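# This module provides two pieces used by the RAG pipeline: CodeAwareTextSplitter,
# which chunks Python and Markdown documentation along structural boundaries, and
# EnhancedRAGVectorStore, which indexes those chunks into Chroma collections
# (one core store plus one store per Manim plugin) and retrieves them by query type.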
class CodeAwareTextSplitter:
"""Enhanced text splitter that understands code structure."""
def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
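# Note: chunk_size/chunk_overlap are only used by the plain-text fallback in
# split_python_file (when ast parsing fails); the structure-aware paths ignore them.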
def split_python_file(self, content: str, metadata: dict) -> List[Document]:
"""Split Python files preserving code structure."""
documents = []
try:
tree = ast.parse(content)
# Extract classes and functions with their docstrings
for node in ast.walk(tree):
if isinstance(node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
# Get the source code segment
start_line = node.lineno
end_line = getattr(node, 'end_lineno', start_line + 20)
lines = content.split('\n')
code_segment = '\n'.join(lines[start_line-1:end_line])
# Extract docstring
docstring = ast.get_docstring(node) or ""
# Create enhanced content
enhanced_content = f"""
Type: {"Class" if isinstance(node, ast.ClassDef) else "Function"}
Name: {node.name}
Docstring: {docstring}
Code:
```python
{code_segment}
```
""".strip()
# Enhanced metadata
enhanced_metadata = {
**metadata,
'type': 'class' if isinstance(node, ast.ClassDef) else 'function',
'name': node.name,
'start_line': start_line,
'end_line': end_line,
'has_docstring': bool(docstring),
'docstring': docstring[:200] + "..." if len(docstring) > 200 else docstring
}
documents.append(Document(
page_content=enhanced_content,
metadata=enhanced_metadata
))
# Also create chunks for imports and module-level code
imports_and_constants = self._extract_imports_and_constants(content)
if imports_and_constants:
documents.append(Document(
page_content=f"Module-level imports and constants:\n\n{imports_and_constants}",
metadata={**metadata, 'type': 'module_level', 'name': 'imports_constants'}
))
except SyntaxError:
# Fallback to regular text splitting for invalid Python
splitter = RecursiveCharacterTextSplitter.from_language(
language=Language.PYTHON,
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap
)
documents = splitter.split_documents([Document(page_content=content, metadata=metadata)])
return documents
def split_markdown_file(self, content: str, metadata: dict) -> List[Document]:
"""Split Markdown files preserving structure."""
documents = []
# Split by headers while preserving hierarchy
sections = self._split_by_headers(content)
for section in sections:
# Extract code blocks
code_blocks = self._extract_code_blocks(section['content'])
# Create document for text content
text_content = self._remove_code_blocks(section['content'])
if text_content.strip():
enhanced_metadata = {
**metadata,
'type': 'markdown_section',
'header': section['header'],
'level': section['level'],
'has_code_blocks': len(code_blocks) > 0
}
documents.append(Document(
page_content=f"Header: {section['header']}\n\n{text_content}",
metadata=enhanced_metadata
))
# Create separate documents for code blocks
for i, code_block in enumerate(code_blocks):
enhanced_metadata = {
**metadata,
'type': 'code_block',
'language': code_block['language'],
'in_section': section['header'],
'block_index': i
}
documents.append(Document(
page_content=f"Code example in '{section['header']}':\n\n```{code_block['language']}\n{code_block['code']}\n```",
metadata=enhanced_metadata
))
return documents
def _extract_imports_and_constants(self, content: str) -> str:
"""Extract imports and module-level constants."""
lines = content.split('\n')
relevant_lines = []
for line in lines:
stripped = line.strip()
if (stripped.startswith('import ') or
stripped.startswith('from ') or
(stripped and not stripped.startswith('def ') and
not stripped.startswith('class ') and
not stripped.startswith('#') and
'=' in stripped and stripped.split('=')[0].strip().isupper())):
relevant_lines.append(line)
return '\n'.join(relevant_lines)
def _split_by_headers(self, content: str) -> List[Dict]:
"""Split markdown content by headers."""
sections = []
lines = content.split('\n')
current_section = {'header': 'Introduction', 'level': 0, 'content': ''}
for line in lines:
header_match = re.match(r'^(#{1,6})\s+(.+)$', line)
if header_match:
# Save previous section
if current_section['content'].strip():
sections.append(current_section)
# Start new section
level = len(header_match.group(1))
header = header_match.group(2)
current_section = {'header': header, 'level': level, 'content': ''}
else:
current_section['content'] += line + '\n'
# Add last section
if current_section['content'].strip():
sections.append(current_section)
return sections
def _extract_code_blocks(self, content: str) -> List[Dict]:
"""Extract code blocks from markdown content."""
code_blocks = []
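# Matches fenced blocks of the form ```lang\ncode\n```; the language tag is
# optional and falls back to 'text' below.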
pattern = r'```(\w+)?\n(.*?)\n```'
for match in re.finditer(pattern, content, re.DOTALL):
language = match.group(1) or 'text'
code = match.group(2)
code_blocks.append({'language': language, 'code': code})
return code_blocks
def _remove_code_blocks(self, content: str) -> str:
"""Remove code blocks from content."""
pattern = r'```\w*\n.*?\n```'
return re.sub(pattern, '', content, flags=re.DOTALL)
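# Illustrative sketch of how the splitter behaves (hypothetical input, not part of the pipeline):
#   splitter = CodeAwareTextSplitter()
#   docs = splitter.split_markdown_file(
#       "# Intro\nSome text\n```python\nprint('hi')\n```\n", {"source": "example.md"})
#   # -> one 'markdown_section' Document plus one 'code_block' Document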
class EnhancedRAGVectorStore:
"""Enhanced RAG vector store with improved code understanding."""
def __init__(self,
chroma_db_path: str = "chroma_db",
manim_docs_path: str = "rag/manim_docs",
embedding_model: str = "hf:ibm-granite/granite-embedding-30m-english",
trace_id: str = None,
session_id: str = None,
use_langfuse: bool = True,
helper_model = None):
self.chroma_db_path = chroma_db_path
self.manim_docs_path = manim_docs_path
self.embedding_model = embedding_model
self.trace_id = trace_id
self.session_id = session_id
self.use_langfuse = use_langfuse
self.helper_model = helper_model
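# The tokenizer is only used to report token-length statistics of indexed chunks;
# "gpt-4" maps to the cl100k_base encoding in tiktoken.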
self.enc = tiktoken.encoding_for_model("gpt-4")
self.plugin_stores = {}
self.code_splitter = CodeAwareTextSplitter()
self.vector_store = self._load_or_create_vector_store()
def _load_or_create_vector_store(self):
"""Enhanced vector store creation with better document processing."""
print("Creating enhanced vector store with code-aware processing...")
core_path = os.path.join(self.chroma_db_path, "manim_core_enhanced")
if os.path.exists(core_path):
print("Loading existing enhanced ChromaDB...")
self.core_vector_store = Chroma(
collection_name="manim_core_enhanced",
persist_directory=core_path,
embedding_function=self._get_embedding_function()
)
else:
print("Creating new enhanced ChromaDB...")
self.core_vector_store = self._create_enhanced_core_store()
# Process plugins with enhanced splitting
plugin_docs_path = os.path.join(self.manim_docs_path, "plugin_docs")
if os.path.exists(plugin_docs_path):
for plugin_name in os.listdir(plugin_docs_path):
plugin_store_path = os.path.join(self.chroma_db_path, f"manim_plugin_{plugin_name}_enhanced")
if os.path.exists(plugin_store_path):
print(f"Loading existing enhanced plugin store: {plugin_name}")
self.plugin_stores[plugin_name] = Chroma(
collection_name=f"manim_plugin_{plugin_name}_enhanced",
persist_directory=plugin_store_path,
embedding_function=self._get_embedding_function()
)
else:
print(f"Creating new enhanced plugin store: {plugin_name}")
plugin_path = os.path.join(plugin_docs_path, plugin_name)
if os.path.isdir(plugin_path):
plugin_store = Chroma(
collection_name=f"manim_plugin_{plugin_name}_enhanced",
embedding_function=self._get_embedding_function(),
persist_directory=plugin_store_path
)
plugin_docs = self._process_documentation_folder_enhanced(plugin_path)
if plugin_docs:
self._add_documents_to_store(plugin_store, plugin_docs, plugin_name)
self.plugin_stores[plugin_name] = plugin_store
return self.core_vector_store
def _get_embedding_function(self) -> Embeddings:
"""Enhanced embedding function with better model selection."""
if self.embedding_model.startswith('hf:'):
model_name = self.embedding_model[3:]
print(f"Using HuggingFaceEmbeddings with model: {model_name}")
# The configured model is used as-is; only print a hint when it is not code-specific
if 'code' not in model_name.lower():
print("Consider using a code-specific embedding model like 'microsoft/codebert-base'")
return HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs={'device': 'cpu'},
encode_kwargs={'normalize_embeddings': True}
)
else:
raise ValueError("Only HuggingFace embeddings are supported in this configuration.")
def _create_enhanced_core_store(self):
"""Create enhanced core store with better document processing."""
core_vector_store = Chroma(
collection_name="manim_core_enhanced",
embedding_function=self._get_embedding_function(),
persist_directory=os.path.join(self.chroma_db_path, "manim_core_enhanced")
)
core_docs = self._process_documentation_folder_enhanced(
os.path.join(self.manim_docs_path, "manim_core")
)
if core_docs:
self._add_documents_to_store(core_vector_store, core_docs, "manim_core_enhanced")
return core_vector_store
def _process_documentation_folder_enhanced(self, folder_path: str) -> List[Document]:
"""Enhanced document processing with code-aware splitting."""
all_docs = []
for root, _, files in os.walk(folder_path):
for file in files:
if file.endswith(('.md', '.py')):
file_path = os.path.join(root, file)
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
base_metadata = {
'source': file_path,
'filename': file,
'file_type': 'python' if file.endswith('.py') else 'markdown',
'relative_path': os.path.relpath(file_path, folder_path)
}
if file.endswith('.py'):
docs = self.code_splitter.split_python_file(content, base_metadata)
else: # .md files
docs = self.code_splitter.split_markdown_file(content, base_metadata)
# Add source prefix to content
for doc in docs:
doc.page_content = f"Source: {file_path}\nType: {doc.metadata.get('type', 'unknown')}\n\n{doc.page_content}"
all_docs.extend(docs)
except Exception as e:
print(f"Error loading file {file_path}: {e}")
print(f"Processed {len(all_docs)} enhanced document chunks from {folder_path}")
return all_docs
def _add_documents_to_store(self, vector_store: Chroma, documents: List[Document], store_name: str):
"""Enhanced document addition with better batching."""
print(f"Adding {len(documents)} enhanced documents to {store_name} store")
# Group documents by type for better organization
doc_types = {}
for doc in documents:
doc_type = doc.metadata.get('type', 'unknown')
if doc_type not in doc_types:
doc_types[doc_type] = []
doc_types[doc_type].append(doc)
print(f"Document types distribution: {dict((k, len(v)) for k, v in doc_types.items())}")
# Calculate token statistics
token_lengths = [len(self.enc.encode(doc.page_content)) for doc in documents]
print(f"Token length statistics for {store_name}: "
f"Min: {min(token_lengths)}, Max: {max(token_lengths)}, "
f"Mean: {sum(token_lengths) / len(token_lengths):.1f}, "
f"Median: {statistics.median(token_lengths):.1f}")
batch_size = 10
for i in tqdm(range(0, len(documents), batch_size), desc=f"Processing {store_name} enhanced batches"):
batch_docs = documents[i:i + batch_size]
batch_ids = [str(uuid.uuid4()) for _ in batch_docs]
vector_store.add_documents(documents=batch_docs, ids=batch_ids)
vector_store.persist()
def find_relevant_docs(self, queries: List[Dict], k: int = 5, trace_id: str = None, topic: str = None, scene_number: int = None) -> str:
"""Find relevant documents - compatibility method that calls the enhanced version."""
return self.find_relevant_docs_enhanced(queries, k, trace_id, topic, scene_number)
def find_relevant_docs_enhanced(self, queries: List[Dict], k: int = 5, trace_id: str = None, topic: str = None, scene_number: int = None) -> str:
"""Enhanced document retrieval with type-aware search."""
# Separate queries by intent
code_queries = [q for q in queries if any(keyword in q["query"].lower()
for keyword in ["function", "class", "method", "import", "code", "implementation"])]
concept_queries = [q for q in queries if q not in code_queries]
all_results = []
# Search with different strategies for different query types
for query in code_queries:
results = self._search_with_filters(
query["query"],
k=k,
filter_metadata={'type': ['function', 'class', 'code_block']},
boost_code=True
)
all_results.extend(results)
for query in concept_queries:
results = self._search_with_filters(
query["query"],
k=k,
filter_metadata={'type': ['markdown_section', 'module_level']},
boost_code=False
)
all_results.extend(results)
# Remove duplicates and format results
unique_results = self._remove_duplicates(all_results)
return self._format_results(unique_results)
def _search_with_filters(self, query: str, k: int, filter_metadata: Dict = None, boost_code: bool = False) -> List[Dict]:
"""Search with metadata filters and result boosting."""
# Simplified retrieval: filter_metadata is accepted but not yet applied, and only the
# core store is searched; proper metadata filtering would be implemented here
core_results = self.core_vector_store.similarity_search_with_relevance_scores(
query=query, k=k, score_threshold=0.3
)
formatted_results = []
for result in core_results:
doc, score = result
# Boost scores for code-related results if needed
if boost_code and doc.metadata.get('type') in ['function', 'class', 'code_block']:
score *= 1.2
formatted_results.append({
"query": query,
"source": doc.metadata['source'],
"content": doc.page_content,
"score": score,
"type": doc.metadata.get('type', 'unknown'),
"metadata": doc.metadata
})
return formatted_results
def _remove_duplicates(self, results: List[Dict]) -> List[Dict]:
"""Remove duplicate results based on content similarity."""
unique_results = []
seen_content = set()
for result in sorted(results, key=lambda x: x['score'], reverse=True):
content_hash = hash(result['content'][:200]) # Hash first 200 chars
if content_hash not in seen_content:
unique_results.append(result)
seen_content.add(content_hash)
return unique_results[:10] # Return top 10 unique results
def _format_results(self, results: List[Dict]) -> str:
"""Format results with enhanced presentation."""
if not results:
return "No relevant documentation found."
formatted = "## Relevant Documentation\n\n"
# Group by type
by_type = {}
for result in results:
result_type = result['type']
if result_type not in by_type:
by_type[result_type] = []
by_type[result_type].append(result)
for result_type, type_results in by_type.items():
formatted += f"### {result_type.replace('_', ' ').title()} Documentation\n\n"
for result in type_results:
formatted += f"**Source:** {result['source']}\n"
formatted += f"**Relevance Score:** {result['score']:.3f}\n"
formatted += f"**Content:**\n```\n{result['content'][:500]}...\n```\n\n"
return formatted
# Backward-compatibility alias so existing imports of RAGVectorStore keep working
RAGVectorStore = EnhancedRAGVectorStore
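# Minimal usage sketch (illustrative only). The paths below are the constructor's
# defaults; whether they exist in your checkout is an assumption of this example.
if __name__ == "__main__":
    store = EnhancedRAGVectorStore(
        chroma_db_path="chroma_db",            # local Chroma persistence directory (assumed)
        manim_docs_path="rag/manim_docs",      # expects manim_core/ and plugin_docs/ inside (assumed)
        embedding_model="hf:ibm-granite/granite-embedding-30m-english",
    )
    # Queries are dicts with a "query" key; code-flavored queries are routed to
    # function/class/code_block chunks, the rest to markdown sections.
    results = store.find_relevant_docs(
        queries=[{"query": "How do I animate a Circle with a class-based Scene?"}],
        k=5,
    )
    print(results)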