Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| from datetime import datetime, timedelta | |
| from typing import Dict, List, Optional | |
| from openai import OpenAI | |
| from src.code_parser import CodeParser | |
| from src.database import Repository, get_db_session, init_db, resolve_database_url | |
| from src.embeddings import EmbeddingGenerator | |
| from src.hybrid_search import HybridSearchEngine | |
| from src.repo_fetcher import RepoFetcher | |
| from src.vector_store import QdrantVectorStore | |
class SessionCancelledError(RuntimeError):
    """Raised to abort an in-flight indexing run whose owning session ended."""
class CodebaseRAGSystem:
    """End-to-end RAG pipeline over GitHub repositories.

    Clones a repo, parses it into code chunks, embeds them, and answers
    questions via hybrid retrieval (BM25 + vectors + reranking) and an LLM.
    """

    def __init__(
        self,
        database_url: Optional[str] = None,
        repo_dir: Optional[str] = None,
        index_path: Optional[str] = None,
    ):
        """Wire up storage, parsing, embedding, search, and LLM components.

        Args:
            database_url: SQLAlchemy URL; falls back to the DATABASE_URL env
                var, then a local SQLite file.
            repo_dir: Base directory where repositories are cloned.
            index_path: On-disk location for the vector index.
        """
        # Resolve DB URL: explicit arg > env var > local SQLite fallback.
        self.database_url = database_url or os.getenv(
            "DATABASE_URL", "sqlite:///./codebase_rag.db"
        )
        self.database_url = resolve_database_url(self.database_url)
        init_db(self.database_url)  # create tables if they do not exist yet
        print(f"[database] Using database_url={self.database_url}", flush=True)
        self.repo_fetcher = RepoFetcher(base_dir=repo_dir)
        self.parser = CodeParser()
        self.embedder = EmbeddingGenerator()
        # Vector index dimension must match the embedder's output size.
        self.vector_store = QdrantVectorStore(
            embedding_dim=self.embedder.get_embedding_dim(),
            index_path=index_path or "./data/faiss/codebase_index",
            persist=False,
        )
        self.hybrid_search = HybridSearchEngine(
            reranker_model=os.getenv(
                "RERANKER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2"
            )
        )
        # APP_ENV takes precedence over the legacy ENVIRONMENT variable.
        self.app_env = os.getenv("APP_ENV", os.getenv("ENVIRONMENT", "local")).lower()
        self.llm_provider = os.getenv("LLM_PROVIDER", "vertex_ai").lower()
        self.llm_client = None
        self.llm_model = ""
        self._configure_llm()
        # Sessions expire after this many minutes without activity.
        self.session_ttl_minutes = int(os.getenv("SESSION_TTL_MINUTES", "120"))
        # Per-process, in-memory state keyed by repository id.
        self.indexing_progress: Dict[int, dict] = {}
        self.repo_chunks: Dict[int, List[dict]] = {}
        self.cancelled_repo_ids = set()
        # Start from a clean slate: drop stale rows/indexes from a prior run.
        self.rebuild_indexes()
| def rebuild_indexes(self): | |
| session = get_db_session(self.database_url) | |
| try: | |
| self.vector_store.clear() | |
| self.repo_chunks.clear() | |
| self.indexing_progress.clear() | |
| self.cancelled_repo_ids.clear() | |
| repos = session.query(Repository).all() | |
| self._delete_repositories(session, repos, track_cancellation=False) | |
| self.cancelled_repo_ids.clear() | |
| session.commit() | |
| finally: | |
| session.close() | |
    def create_or_reset_repository(self, github_url: str, session_key: str) -> Repository:
        """Upsert the Repository row for (session_key, github_url), reset to 'queued'.

        Any previously indexed data for the repo (vectors, BM25 index, cached
        chunks) is discarded so a fresh indexing run can start cleanly.
        """
        info = self.repo_fetcher.parse_github_url(github_url)
        # Registry key scopes the repo row to this browser session.
        registry_key = self._build_registry_key(session_key, github_url)
        session = get_db_session(self.database_url)
        try:
            self._cleanup_expired_sessions(session)
            repo = session.query(Repository).filter_by(github_url=registry_key).first()
            if repo is None:
                # First index request for this URL in this session: create row.
                repo = Repository(
                    github_url=registry_key,
                    source_url=github_url,
                    session_key=session_key,
                    session_expires_at=self._session_expiry(),
                    owner=info["owner"],
                    name=info["repo"],
                    branch=info["branch"],
                    status="queued",
                )
                session.add(repo)
                session.flush()  # assign repo.id before using it below
                self.cancelled_repo_ids.discard(repo.id)
            else:
                # Re-index request: reset all bookkeeping fields.
                repo.source_url = github_url
                repo.session_key = session_key
                repo.session_expires_at = self._session_expiry()
                repo.owner = info["owner"]
                repo.name = info["repo"]
                repo.branch = info["branch"]
                repo.status = "queued"
                repo.error_message = None
                repo.file_count = 0
                repo.chunk_count = 0
                repo.indexed_at = None
                self.cancelled_repo_ids.discard(repo.id)
                # Drop stale search artifacts from the previous index run.
                self.hybrid_search.remove_repository(repo.id)
                self.vector_store.remove_repository(repo.id)
                self.repo_chunks.pop(repo.id, None)
            session.commit()
            session.refresh(repo)
            return repo
        finally:
            session.close()
    def index_repository(self, repo_id: int):
        """Clone, parse, embed, and index the repository identified by *repo_id*.

        Progress is published through ``self.indexing_progress``; cancellation
        is checked at each phase boundary and surfaces as
        :class:`SessionCancelledError`. On failure the partial index is rolled
        back and the repo row is deleted (if cancelled) or marked ``failed``.

        Raises:
            ValueError: if the repository row does not exist.
        """
        session = get_db_session(self.database_url)
        try:
            self._cleanup_expired_sessions(session)
            repo = session.query(Repository).filter_by(id=repo_id).first()
            if repo is None:
                raise ValueError("Repository not found")
            self._ensure_repo_not_cancelled(repo.id)
            print(f"[indexing] Starting repository index repo_id={repo.id}", flush=True)
            repo.status = "indexing"
            repo.error_message = None
            repo.session_expires_at = self._session_expiry()
            session.commit()
            # --- Phase 1: clone ------------------------------------------------
            self._set_progress(repo.id, phase="cloning", message="Cloning repository")
            clone_info = self.repo_fetcher.clone_repository(repo.source_url or repo.github_url)
            self._ensure_repo_not_cancelled(repo.id)
            repo.local_path = None  # the clone is temporary; never persisted
            repo.branch = clone_info["branch"]
            print(
                f"[indexing] Repository cloned repo_id={repo.id} branch={repo.branch} "
                f"path={clone_info['local_path']}",
                flush=True,
            )
            # --- Phase 2: parse source files into chunks -----------------------
            source_files = list(self.repo_fetcher.iter_source_files(clone_info["local_path"]))
            total_files = len(source_files)
            print(
                f"[indexing] Found {total_files} source files for repo_id={repo.id}",
                flush=True,
            )
            self._set_progress(
                repo.id,
                phase="parsing",
                message=f"Scanning {total_files} source files",
                total_files=total_files,
                processed_files=0,
                discovered_chunks=0,
            )
            chunk_payloads = []
            file_count = 0
            for index, file_path in enumerate(source_files, start=1):
                file_chunks = self.parser.chunk_file(str(file_path), clone_info["local_path"])
                if not file_chunks:
                    # Report progress even for files that yielded no chunks.
                    self._set_progress(
                        repo.id,
                        phase="parsing",
                        message=f"Parsed {index}/{total_files} files",
                        total_files=total_files,
                        processed_files=index,
                        discovered_chunks=len(chunk_payloads),
                    )
                    continue
                file_count += 1
                chunk_payloads.extend(file_chunks)
                self._set_progress(
                    repo.id,
                    phase="parsing",
                    message=f"Parsed {index}/{total_files} files",
                    total_files=total_files,
                    processed_files=index,
                    discovered_chunks=len(chunk_payloads),
                )
            # --- Phase 3: embed chunks -----------------------------------------
            searchable_texts = [chunk["searchable_text"] for chunk in chunk_payloads]
            print(
                f"[indexing] Parsed repo_id={repo.id} files={file_count} chunks={len(searchable_texts)}",
                flush=True,
            )
            self._set_progress(
                repo.id,
                phase="embedding",
                message=f"Embedding {len(searchable_texts)} chunks",
                total_files=total_files,
                processed_files=total_files,
                discovered_chunks=len(chunk_payloads),
                total_chunks=len(chunk_payloads),
                embedded_chunks=0,
            )
            embeddings = self.embedder.embed_batch(
                searchable_texts,
                # The embedder reports batch progress back into the UI state.
                progress_callback=lambda completed, total: self._set_progress(
                    repo.id,
                    phase="embedding",
                    message=f"Embedding chunks ({completed}/{total})",
                    total_files=total_files,
                    processed_files=total_files,
                    discovered_chunks=len(chunk_payloads),
                    total_chunks=total,
                    embedded_chunks=completed,
                ),
            )
            self._ensure_repo_not_cancelled(repo.id)
            # --- Phase 4: upload vectors ---------------------------------------
            vector_metadata = []
            for chunk in chunk_payloads:
                vector_metadata.append(
                    {
                        "repository_id": repo.id,
                        "file_path": chunk["file_path"],
                        "language": chunk["language"],
                        "symbol_name": chunk["symbol_name"],
                        "symbol_type": chunk["symbol_type"],
                        "line_start": chunk["line_start"],
                        "line_end": chunk["line_end"],
                        "signature": chunk["signature"],
                        "content": chunk["content"],
                    }
                )
            embedding_ids = self.vector_store.add_embeddings(embeddings, vector_metadata)
            print(
                f"[indexing] Uploaded {len(embedding_ids)} embeddings to vector store for repo_id={repo.id}",
                flush=True,
            )
            # --- Phase 5: persist chunk rows and finalize ----------------------
            self._set_progress(
                repo.id,
                phase="saving",
                message="Saving chunks and search indexes",
                total_files=total_files,
                processed_files=total_files,
                discovered_chunks=len(chunk_payloads),
            )
            created_rows = []
            for chunk, embedding_id in zip(chunk_payloads, embedding_ids):
                row = {
                    **chunk,
                    "id": embedding_id,
                    "repository_id": repo.id,
                    "embedding_id": embedding_id,
                }
                created_rows.append(row)
            repo.status = "indexed"
            repo.file_count = file_count
            repo.chunk_count = len(created_rows)
            repo.indexed_at = datetime.utcnow()
            repo.session_expires_at = self._session_expiry()
            # Guard against the repo being deleted or cancelled while we worked.
            self._ensure_repo_still_exists(session, repo.id)
            self._ensure_repo_not_cancelled(repo.id)
            session.commit()
            serialized = [self._serialize_chunk(chunk) for chunk in created_rows]
            self.repo_chunks[repo.id] = serialized
            self.vector_store.save()
            self.indexing_progress.pop(repo.id, None)
            self.cancelled_repo_ids.discard(repo.id)
            self.repo_fetcher.cleanup_repository(clone_info["local_path"])
            print(f"[indexing] Repository index complete repo_id={repo.id}", flush=True)
        except Exception as exc:
            print(f"[indexing] Repository index failed repo_id={repo_id} error={exc}", flush=True)
            session.rollback()
            # Undo any partially written index state.
            self.vector_store.remove_repository(repo_id)
            self.repo_chunks.pop(repo_id, None)
            self.hybrid_search.remove_repository(repo_id)
            repo = session.query(Repository).filter_by(id=repo_id).first()
            if repo:
                if repo_id in self.cancelled_repo_ids:
                    # Session ended mid-index: drop the row entirely.
                    session.delete(repo)
                else:
                    repo.status = "failed"
                    repo.error_message = str(exc)
                session.commit()
            try:
                # Best-effort removal of the temporary clone, if it was created.
                if "clone_info" in locals():
                    self.repo_fetcher.cleanup_repository(clone_info["local_path"])
            except Exception:
                pass
            self.indexing_progress.pop(repo_id, None)
            # Cancellation is an expected outcome, not an error.
            if isinstance(exc, SessionCancelledError):
                return
            raise
        finally:
            session.close()
| def list_repositories(self) -> List[dict]: | |
| raise NotImplementedError | |
| def list_repositories_for_session(self, session_key: str) -> List[dict]: | |
| session = get_db_session(self.database_url) | |
| try: | |
| self._cleanup_expired_sessions(session) | |
| repos = ( | |
| session.query(Repository) | |
| .filter_by(session_key=session_key) | |
| .order_by(Repository.updated_at.desc()) | |
| .all() | |
| ) | |
| self._touch_session(session, session_key) | |
| return [self._serialize_repo(repo) for repo in repos] | |
| finally: | |
| session.close() | |
| def get_repository(self, repo_id: int) -> Optional[dict]: | |
| raise NotImplementedError | |
| def get_repository_for_session(self, repo_id: int, session_key: str) -> Optional[dict]: | |
| session = get_db_session(self.database_url) | |
| try: | |
| self._cleanup_expired_sessions(session) | |
| repo = ( | |
| session.query(Repository) | |
| .filter_by(id=repo_id, session_key=session_key) | |
| .first() | |
| ) | |
| self._touch_session(session, session_key) | |
| return self._serialize_repo(repo) if repo else None | |
| finally: | |
| session.close() | |
    def answer_question(
        self,
        repo_id: int,
        session_key: str,
        question: str,
        top_k: int = 8,
        history: Optional[List[object]] = None,
    ) -> dict:
        """Run the full retrieval + generation pipeline for one question.

        Args:
            repo_id: Target repository id (must be indexed, owned by session).
            session_key: Session scope for access control and TTL refresh.
            question: The user's question.
            top_k: Number of sources ultimately handed to the LLM.
            history: Optional prior chat turns used for follow-up handling.

        Returns:
            The answer payload produced by ``_generate_answer``.

        Raises:
            ValueError: unknown repo, repo not indexed, or expired chunk cache.
        """
        session = get_db_session(self.database_url)
        try:
            self._cleanup_expired_sessions(session)
            repo = (
                session.query(Repository)
                .filter_by(id=repo_id, session_key=session_key)
                .first()
            )
            if repo is None:
                raise ValueError("Repository not found")
            if repo.status != "indexed":
                raise ValueError("Repository is not ready for questions yet")
            # BM25 runs over the in-process chunk cache, which may be gone
            # after a restart even though the DB row says "indexed".
            if repo_id not in self.repo_chunks:
                raise ValueError("Session cache expired. Re-index the repository and try again.")
            self._touch_session(session, session_key)
            normalized_history = self._normalize_history(history or [])
            question_intent = self._question_intent(question)
            # Retrieve deeper for question types that tend to span many files.
            search_depth = top_k * 4 if question_intent in {"api", "implementation", "cross_file", "setup"} else top_k * 2
            retrieval_query = self._build_retrieval_query(question, normalized_history)
            query_embedding = self.embedder.embed_text(retrieval_query)
            # Dense (vector) retrieval, restricted to this repository.
            semantic_hits = []
            for score, meta in self.vector_store.search(query_embedding, k=search_depth, repo_filter=repo_id):
                serialized = dict(meta)
                serialized["semantic_score"] = score
                semantic_hits.append(serialized)
            # Sparse (BM25) retrieval over the cached chunks.
            lexical_hits = self.hybrid_search.bm25_search(
                self.repo_chunks[repo_id],
                retrieval_query,
                top_k=search_depth,
            )
            # Fuse both result lists, rerank, then apply heuristic ordering
            # and per-file caps before generation.
            semantic_hits = self.hybrid_search.normalize_semantic_results(semantic_hits)
            fused = self.hybrid_search.reciprocal_rank_fusion(lexical_hits, semantic_hits, top_k=search_depth)
            rerank_query = retrieval_query if question_intent in {"api", "implementation", "cross_file", "setup"} else question
            reranked = self.hybrid_search.rerank(rerank_query, fused, top_k=search_depth)
            reranked = self._prioritize_results(question, retrieval_query, reranked, top_k=top_k)
            reranked = self._select_answer_sources(question, reranked, top_k=top_k)
            answer = self._generate_answer(repo, question, reranked, normalized_history)
            return answer
        finally:
            session.close()
| def end_session(self, session_key: str): | |
| session = get_db_session(self.database_url) | |
| try: | |
| repos = session.query(Repository).filter_by(session_key=session_key).all() | |
| self._delete_repositories(session, repos) | |
| session.commit() | |
| finally: | |
| session.close() | |
    def _generate_answer(
        self,
        repo: Repository,
        question: str,
        sources: List[dict],
        history: Optional[List[dict]] = None,
    ) -> dict:
        """Compose the final answer payload from ranked *sources* via the LLM.

        Returns a dict with ``answer`` (markdown), ``confidence``, ``summary``,
        ``citations``, ``sources``, and the serialized ``repo``.
        """
        # Without evidence we refuse to answer rather than hallucinate.
        if not sources:
            return {
                "answer": "I could not find enough grounded evidence in the indexed codebase to answer that confidently.",
                "confidence": "low",
                "sources": [],
                "repo": self._serialize_repo(repo),
            }
        context_blocks = []
        slim_sources = []
        for index, source in enumerate(sources, start=1):
            # Prompt context: one labelled block per source, snippet capped
            # at 2500 chars to bound prompt size.
            context_blocks.append(
                "\n".join(
                    [
                        f"[Source {index}]",
                        f"File: {source['file_path']}",
                        f"Symbol: {source['symbol_name']}",
                        f"Lines: {source['line_start']}-{source['line_end']}",
                        source["content"][:2500],
                    ]
                )
            )
            # API payload: metadata plus every retrieval score, rounded.
            slim_sources.append(
                {
                    "file_path": source["file_path"],
                    "language": source["language"],
                    "symbol_name": source["symbol_name"],
                    "symbol_type": source["symbol_type"],
                    "line_start": source["line_start"],
                    "line_end": source["line_end"],
                    "signature": source["signature"],
                    "snippet": source["content"],
                    "semantic_score": round(float(source.get("semantic_score", 0.0)), 4),
                    "bm25_score": round(float(source.get("bm25_score", 0.0)), 4),
                    "rrf_score": round(float(source.get("rrf_score", 0.0)), 4),
                    "rerank_score": round(float(source.get("rerank_score", 0.0)), 4),
                }
            )
        wants_repo_overview = self._is_repo_overview_question(question)
        question_intent = self._question_intent(question)
        system_prompt = """
You are answering questions as a knowledgeable teammate who has carefully read this repository.
Rules:
1. Use only the supplied repository context.
2. Answer conversationally and directly, as if the repo is explaining itself to the user.
3. Do not say "Based on the provided context", "The repository is about", or similar throat-clearing phrases.
4. Be concrete about files, functions, and behavior.
5. If evidence is partial, clearly separate what is certain from what is inferred.
6. Respond in Markdown, not JSON.
7. Keep the answer complete. Do not stop mid-sentence.
8. Use short sections or bullets only when they genuinely help readability.
9. Do not leave unfinished headings, dangling bullets, or trailing markdown markers like #, ##, or ###.
10. Do not include inline citation markers like [Source 1] in the prose. The UI already shows sources separately.
11. Do not make claims that are not directly supported by the supplied sources.
12. Prefer the most canonical source files for API and implementation questions, such as package exports, core modules, and session/query code, over tutorial prose when they disagree in specificity.
13. Keep the answer tight. Lead with the direct answer, then add only the most important supporting detail.
"""
        # Append intent-specific rule extensions.
        if wants_repo_overview:
            system_prompt += """
14. For repository overview questions, lead with a direct one or two sentence summary of what the repo does.
15. Prioritize README and top-level documentation when they are present, then use code to support the explanation.
16. Mention the main workflow, core stack, and any important product constraints the user would care about.
17. Keep the answer polished and self-contained, like the overview a real user expects when they ask what a repo is about.
"""
        elif question_intent in {"api", "implementation", "cross_file", "error_handling", "setup"}:
            system_prompt += """
14. For API, implementation, setup, and cross-file questions, prefer the smallest correct answer that is directly supported by code.
15. If a detail comes only from docs or examples and not from the canonical implementation, say that clearly instead of presenting it as core behavior.
16. When describing exports or code paths, name the file first and keep the explanation precise.
17. Default to one short paragraph plus at most 3 short bullets. Avoid long explanatory walkthroughs unless the question explicitly asks for depth.
"""
        joined_context = "\n\n".join(context_blocks)
        user_prompt = f"""
Repository: {repo.owner}/{repo.name}
Question: {question}
Recent conversation:
{self._format_history(history or [])}
Context:
{joined_context}
"""
        answer_text, finish_reason = self._generate_markdown_response(system_prompt, user_prompt)
        # First repair pass: ask the model to rewrite a truncated draft.
        if self._looks_incomplete(answer_text, finish_reason):
            repair_prompt = f"""
The draft answer below appears to be cut off or incomplete.
Rewrite it into a complete final answer using the same repository context and rules.
Draft answer:
{answer_text}
"""
            answer_text, finish_reason = self._generate_markdown_response(
                system_prompt,
                f"{user_prompt.strip()}\n\n{repair_prompt.strip()}",
            )
        # Second repair pass: fall back to a deliberately short answer.
        if self._looks_incomplete(answer_text, finish_reason):
            short_prompt = f"""
Answer the question again, but keep it concise and complete.
Use 2 short paragraphs or 4-6 bullets max.
Do not leave the answer unfinished.
"""
            answer_text, _ = self._generate_markdown_response(
                system_prompt,
                f"{user_prompt.strip()}\n\n{short_prompt.strip()}",
            )
        answer_text = self._finalize_answer(answer_text)
        confidence = self._estimate_confidence(sources)
        # Whitespace-collapsed teaser derived from the first 160 chars.
        summary = " ".join(answer_text.split())[:160] if answer_text else ""
        # Citations reference at most the top four sources by position.
        citations = [
            {
                "source": index,
                "reason": f"Relevant context from {source['file_path']}",
            }
            for index, source in enumerate(sources[: min(len(sources), 4)], start=1)
        ]
        return {
            "answer": answer_text,
            "confidence": confidence,
            "summary": summary,
            "citations": citations,
            "sources": slim_sources,
            "repo": self._serialize_repo(repo),
        }
| def _configure_llm(self): | |
| if self.llm_provider == "groq": | |
| self.llm_client = OpenAI( | |
| api_key=os.getenv("GROQ_API_KEY"), | |
| base_url=os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1"), | |
| ) | |
| self.llm_model = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile") | |
| return | |
| if self.llm_provider == "vertex_ai": | |
| try: | |
| from google import genai | |
| except ImportError as exc: | |
| raise RuntimeError( | |
| "Vertex AI LLM support requires the `google-genai` package. " | |
| "Install server dependencies before running local or eval queries." | |
| ) from exc | |
| project = os.getenv("GOOGLE_CLOUD_PROJECT") | |
| location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1") | |
| if not project: | |
| raise RuntimeError( | |
| "GOOGLE_CLOUD_PROJECT must be set when using Vertex AI Gemini." | |
| ) | |
| self.llm_client = genai.Client( | |
| vertexai=True, | |
| project=project, | |
| location=location, | |
| ) | |
| self.llm_model = os.getenv("VERTEX_LLM_MODEL", "gemini-2.5-pro") | |
| return | |
| raise RuntimeError(f"Unsupported LLM provider: {self.llm_provider}") | |
    def _generate_markdown_response(self, system_prompt: str, user_prompt: str) -> tuple[str, str]:
        """Call the configured LLM and return ``(markdown_text, finish_reason)``.

        Raises:
            RuntimeError: if Vertex AI returns an empty response.
        """
        if self.llm_provider == "groq":
            response = self.llm_client.chat.completions.create(
                model=self.llm_model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.1,
                max_tokens=1600,
            )
            content = response.choices[0].message.content
            # finish_reason may be absent; fall back to an empty string.
            finish_reason = getattr(response.choices[0], "finish_reason", "") or ""
            return self._normalize_markdown_answer(content), str(finish_reason)
        # Vertex AI Gemini path: roles are folded into one combined prompt.
        response = self.llm_client.models.generate_content(
            model=self.llm_model,
            contents=f"{system_prompt.strip()}\n\n{user_prompt.strip()}",
            config={
                "temperature": 0.1,
                "max_output_tokens": 2200,
            },
        )
        if not getattr(response, "text", None):
            raise RuntimeError("Vertex AI Gemini returned an empty response.")
        finish_reason = ""
        candidates = getattr(response, "candidates", None) or []
        if candidates:
            finish_reason = str(getattr(candidates[0], "finish_reason", "") or "")
        return self._normalize_markdown_answer(response.text), finish_reason
| def _normalize_markdown_answer(raw_text: str) -> str: | |
| cleaned = (raw_text or "").strip() | |
| cleaned = re.sub(r"^```(?:markdown|md)?\s*|\s*```$", "", cleaned, flags=re.IGNORECASE) | |
| cleaned = re.sub(r"\s*\[(?:Source\s+\d+(?:\s*,\s*Source\s+\d+)*)\]", "", cleaned, flags=re.IGNORECASE) | |
| cleaned = re.sub( | |
| r"^(?:based on the provided context[,:\s-]*|from the provided context[,:\s-]*)", | |
| "", | |
| cleaned, | |
| flags=re.IGNORECASE, | |
| ).strip() | |
| cleaned = re.sub( | |
| r"\n(?:#{1,6}|[-*])\s*$", | |
| "", | |
| cleaned, | |
| flags=re.MULTILINE, | |
| ).strip() | |
| cleaned = re.sub(r"(?:\n\s*){3,}", "\n\n", cleaned) | |
| cleaned = cleaned.strip() | |
| if not cleaned: | |
| return "I found relevant code context, but the model returned an empty response." | |
| return cleaned | |
| def _finalize_answer(answer_text: str) -> str: | |
| cleaned = (answer_text or "").strip() | |
| if not cleaned: | |
| return "I found relevant code context, but the model returned an empty response." | |
| # If the tail still looks truncated, trim back to the last complete sentence or list item | |
| if CodebaseRAGSystem._looks_incomplete(cleaned): | |
| sentence_match = re.search(r"(?s)^.*[.!?](?:['\"\)`\]]+)?", cleaned) | |
| if sentence_match: | |
| trimmed = sentence_match.group(0).strip() | |
| if len(trimmed.split()) >= 12: | |
| return trimmed | |
| lines = cleaned.splitlines() | |
| while lines and CodebaseRAGSystem._looks_incomplete(lines[-1]): | |
| lines.pop() | |
| candidate = "\n".join(line for line in lines if line.strip()).strip() | |
| if candidate: | |
| return candidate | |
| return cleaned | |
| def _looks_incomplete(answer_text: str, finish_reason: str = "") -> bool: | |
| cleaned = (answer_text or "").strip() | |
| if not cleaned: | |
| return True | |
| finish_reason = (finish_reason or "").strip().lower() | |
| if finish_reason and finish_reason not in {"stop", "stopsequence", "finish_reason_unspecified"}: | |
| return True | |
| if cleaned.endswith(("#", "-", "*", ":", "(", "[", "/", "`")): | |
| return True | |
| if cleaned.endswith(("[source", "[source 1", "[source 2", "[source 3", "[source 4")): | |
| return True | |
| if cleaned.count("```") % 2 != 0: | |
| return True | |
| if cleaned.count("(") > cleaned.count(")"): | |
| return True | |
| if cleaned.count("[") > cleaned.count("]"): | |
| return True | |
| tokens = re.findall(r"\b[\w'-]+\b", cleaned.lower()) | |
| if not tokens: | |
| return True | |
| if tokens[-1] in {"a", "an", "the", "to", "for", "with", "of", "in", "on", "from", "about"}: | |
| return True | |
| if len(tokens) >= 20 and cleaned[-1] not in {".", "!", "?", "\"", "'", "`"}: | |
| return True | |
| return False | |
| def _estimate_confidence(sources: List[dict]) -> str: | |
| if not sources: | |
| return "low" | |
| top = sources[0] | |
| rerank = float(top.get("rerank_score", 0.0)) | |
| semantic = float(top.get("semantic_score", 0.0)) | |
| if len(sources) >= 3 and (rerank >= 0.2 or semantic >= 0.75): | |
| return "high" | |
| if rerank >= 0.05 or semantic >= 0.45: | |
| return "medium" | |
| return "low" | |
| def _serialize_repo(self, repo: Repository) -> dict: | |
| payload = { | |
| "id": repo.id, | |
| "github_url": repo.source_url or repo.github_url, | |
| "owner": repo.owner, | |
| "name": repo.name, | |
| "branch": repo.branch, | |
| "local_path": repo.local_path, | |
| "status": repo.status, | |
| "error_message": repo.error_message, | |
| "file_count": repo.file_count, | |
| "chunk_count": repo.chunk_count, | |
| "indexed_at": repo.indexed_at.isoformat() if repo.indexed_at else None, | |
| "created_at": repo.created_at.isoformat() if repo.created_at else None, | |
| "updated_at": repo.updated_at.isoformat() if repo.updated_at else None, | |
| } | |
| progress = self.indexing_progress.get(repo.id) | |
| if progress: | |
| payload["progress"] = progress | |
| return payload | |
| def _set_progress(self, repo_id: int, **progress): | |
| self.indexing_progress[repo_id] = { | |
| **self.indexing_progress.get(repo_id, {}), | |
| **progress, | |
| "updated_at": datetime.utcnow().isoformat(), | |
| } | |
| def _touch_session(self, session, session_key: str): | |
| expiry = self._session_expiry() | |
| repos = session.query(Repository).filter_by(session_key=session_key).all() | |
| for repo in repos: | |
| repo.session_expires_at = expiry | |
| session.commit() | |
| def _cleanup_expired_sessions(self, session): | |
| now = datetime.utcnow() | |
| expired = ( | |
| session.query(Repository) | |
| .filter(Repository.session_expires_at.is_not(None)) | |
| .filter(Repository.session_expires_at < now) | |
| .all() | |
| ) | |
| if not expired: | |
| return | |
| self._delete_repositories(session, expired) | |
| session.commit() | |
| def _delete_repositories( | |
| self, | |
| session, | |
| repos: List[Repository], | |
| track_cancellation: bool = True, | |
| ): | |
| repo_ids = [repo.id for repo in repos] | |
| for repo_id in repo_ids: | |
| if track_cancellation: | |
| self.cancelled_repo_ids.add(repo_id) | |
| self.hybrid_search.remove_repository(repo_id) | |
| self.vector_store.remove_repository(repo_id) | |
| self.repo_chunks.pop(repo_id, None) | |
| self.indexing_progress.pop(repo_id, None) | |
| for repo in repos: | |
| session.delete(repo) | |
| def _ensure_repo_not_cancelled(self, repo_id: int): | |
| if repo_id in self.cancelled_repo_ids: | |
| raise SessionCancelledError("Session ended before indexing completed.") | |
| def _build_retrieval_query(self, question: str, history: List[dict]) -> str: | |
| normalized = " ".join(question.strip().split()) | |
| if self._is_repo_overview_question(normalized): | |
| return "\n".join( | |
| [ | |
| normalized, | |
| "repository overview purpose main workflow architecture README features stack", | |
| ] | |
| ) | |
| if not history: | |
| return normalized | |
| recent_user = [ | |
| turn["content"].strip() | |
| for turn in reversed(history) | |
| if turn.get("role") == "user" and turn.get("content", "").strip() | |
| ] | |
| recent_assistant = [ | |
| turn["content"].strip() | |
| for turn in reversed(history) | |
| if turn.get("role") == "assistant" and turn.get("content", "").strip() | |
| and self._is_substantive_assistant_message(turn.get("content", "")) | |
| ] | |
| is_follow_up = ( | |
| len(normalized.split()) <= 6 | |
| or bool(re.fullmatch(r"(give|show|where|which|how|what)(?:\s+.+)?", normalized.lower())) | |
| or any(token in normalized.lower() for token in {"code", "snippet", "implementation"}) | |
| ) | |
| if not is_follow_up or not recent_user: | |
| return self._expand_query_for_intent(normalized) | |
| parts = [self._expand_query_for_intent(normalized)] | |
| if recent_user: | |
| parts.append(f"Follow-up to: {recent_user[0]}") | |
| if recent_assistant: | |
| parts.append(f"Previous answer: {recent_assistant[0][:300]}") | |
| return "\n".join(parts) | |
    def _prioritize_results(
        self,
        question: str,
        retrieval_query: str,
        results: List[dict],
        top_k: int,
    ) -> List[dict]:
        """Re-order reranked *results* with intent-aware heuristics.

        Docs/overview questions keep the heuristic ordering as-is; code-focused
        questions prefer code chunks and allow only a small tail of doc chunks.
        """
        combined_query = f"{question} {retrieval_query}".lower()
        wants_code = any(
            token in combined_query
            for token in {"code", "snippet", "implementation", "function", "class", "import"}
        )
        wants_docs = self._is_documentation_query(combined_query)
        wants_repo_overview = self._is_repo_overview_question(question) or self._is_repo_overview_question(
            retrieval_query
        )
        question_intent = self._question_intent(question)

        def sort_key(item: dict):
            # Tuple compared left-to-right; earlier elements dominate
            # (canonical path > doc priority > intent flags > raw scores).
            is_doc = self._is_doc_source(item)
            return (
                self._canonical_path_priority(item, question),
                self._doc_priority(item),
                1 if wants_repo_overview and is_doc else 0,
                1 if (wants_docs and is_doc) or (not wants_docs and not is_doc) else 0,
                1 if wants_code and not is_doc else 0,
                1 if question_intent in {"api", "implementation", "cross_file", "error_handling", "setup"} and not is_doc else 0,
                float(item.get("rerank_score", 0.0)),
                float(item.get("semantic_score", 0.0)),
                float(item.get("bm25_score", 0.0)),
            )
        ranked = sorted(results, key=sort_key, reverse=True)
        if wants_docs or wants_repo_overview:
            return ranked[:top_k]
        # Code question: take code chunks first, then backfill with at most a
        # short tail of documentation chunks.
        selected = []
        doc_items = []
        for item in ranked:
            if self._is_doc_source(item):
                doc_items.append(item)
                continue
            selected.append(item)
            if len(selected) == top_k:
                return selected
        selected.extend(doc_items[: max(1, top_k - len(selected))])
        return selected[:top_k]
| def _select_answer_sources( | |
| self, | |
| question: str, | |
| results: List[dict], | |
| top_k: int, | |
| ) -> List[dict]: | |
| if not results: | |
| return [] | |
| intent = self._question_intent(question) | |
| max_per_file = 2 if intent in {"overview", "docs"} else 1 | |
| selected = [] | |
| file_counts = {} | |
| for item in results: | |
| file_path = item.get("file_path", "") | |
| count = file_counts.get(file_path, 0) | |
| if count >= max_per_file: | |
| continue | |
| selected.append(item) | |
| file_counts[file_path] = count + 1 | |
| if len(selected) == top_k: | |
| break | |
| if len(selected) < top_k: | |
| for item in results: | |
| if item in selected: | |
| continue | |
| selected.append(item) | |
| if len(selected) == top_k: | |
| break | |
| return selected | |
| def _is_documentation_query(query: str) -> bool: | |
| return any( | |
| token in query | |
| for token in { | |
| "readme", | |
| "docs", | |
| "documentation", | |
| "setup", | |
| "install", | |
| "installation", | |
| "usage", | |
| "overview", | |
| "what is this repo", | |
| "what is the repository about", | |
| "what is the repo about", | |
| "what does the repo do", | |
| "what does this repo do", | |
| "repo summary", | |
| "repository summary", | |
| "project summary", | |
| "feature", | |
| "features", | |
| "architecture", | |
| } | |
| ) | |
| def _question_intent(question: str) -> str: | |
| normalized = " ".join((question or "").lower().split()) | |
| if not normalized: | |
| return "general" | |
| if CodebaseRAGSystem._is_repo_overview_question(normalized): | |
| return "overview" | |
| if any(token in normalized for token in {"error", "invalid", "conflict", "raises", "guard against"}): | |
| return "error_handling" | |
| if any(token in normalized for token in {"how are", "how does", "flow", "across files", "code path"}): | |
| return "cross_file" | |
| if any(token in normalized for token in {"export", "expose", "import", "public api"}): | |
| return "api" | |
| if any(token in normalized for token in {"create", "setup", "install", "configuration", "metadata", "table"}): | |
| return "setup" | |
| if any(token in normalized for token in {"function", "method", "class", "implementation", "does ", "what is special"}): | |
| return "implementation" | |
| if CodebaseRAGSystem._is_documentation_query(normalized): | |
| return "docs" | |
| return "general" | |
| def _expand_query_for_intent(self, question: str) -> str: | |
| normalized = " ".join((question or "").split()) | |
| lowered = normalized.lower() | |
| hints = [] | |
| if any(token in lowered for token in {"export", "expose", "import"}): | |
| hints.extend(["package exports", "__init__.py", "public api", "re-export"]) | |
| if "how is select exposed to users in sqlmodel" in lowered: | |
| hints.extend( | |
| [ | |
| "sqlmodel/__init__.py", | |
| "sqlmodel/sql/expression.py", | |
| "select re-export", | |
| "top-level select import", | |
| ] | |
| ) | |
| if "select" in lowered: | |
| hints.extend( | |
| [ | |
| "select", | |
| "expression", | |
| "query builder", | |
| "public api", | |
| "sqlmodel/sql/expression.py", | |
| "sqlmodel/__init__.py", | |
| "re-export", | |
| "top-level import", | |
| ] | |
| ) | |
| if "session.exec" in lowered or ("session" in lowered and "exec" in lowered): | |
| hints.extend(["session exec", "orm/session.py", "asyncio/session.py"]) | |
| if "relationship" in lowered: | |
| hints.extend(["relationship", "Relationship", "main.py"]) | |
| if "field" in lowered: | |
| hints.extend(["Field", "FieldInfo", "main.py"]) | |
| if "create_engine" in lowered: | |
| hints.extend(["create_engine", "__init__.py", "re-export"]) | |
| if "create_all" in lowered or "metadata" in lowered: | |
| hints.extend( | |
| [ | |
| "metadata create_all", | |
| "table creation", | |
| "engine", | |
| "SQLModel.metadata", | |
| "README.md", | |
| "sqlmodel/main.py", | |
| "docs_src", | |
| ] | |
| ) | |
| if "__init__" in lowered or "exports" in lowered: | |
| hints.extend(["sqlmodel/__init__.py", "package exports", "public api"]) | |
| if not hints: | |
| return normalized | |
| return "\n".join([normalized, " ".join(hints)]) | |
| def _is_repo_overview_question(question: str) -> bool: | |
| normalized = " ".join((question or "").lower().split()) | |
| return any( | |
| phrase in normalized | |
| for phrase in { | |
| "what is the repo about", | |
| "what is this repo about", | |
| "what does the repo do", | |
| "what does this repo do", | |
| "what is the repository about", | |
| "what does the repository do", | |
| "what is this project about", | |
| "what does this project do", | |
| "repo summary", | |
| "repository summary", | |
| "project summary", | |
| "summarize the repo", | |
| "summarize this repo", | |
| "repo overview", | |
| "repository overview", | |
| "project overview", | |
| } | |
| ) | |
| def _is_doc_source(item: dict) -> bool: | |
| file_path = (item.get("file_path") or "").lower() | |
| language = (item.get("language") or "").lower() | |
| return language == "text" or file_path.endswith(".md") or "/readme" in file_path | |
| def _doc_priority(item: dict) -> int: | |
| file_path = (item.get("file_path") or "").lower() | |
| if file_path in {"readme.md", "readme"}: | |
| return 3 | |
| if file_path.startswith("docs/") or "/docs/" in file_path: | |
| return 2 | |
| if file_path.endswith(".md"): | |
| return 1 | |
| return 0 | |
| def _canonical_path_priority(self, item: dict, question: str) -> int: | |
| file_path = (item.get("file_path") or "").lower() | |
| normalized = " ".join((question or "").lower().split()) | |
| score = 0 | |
| if file_path == "sqlmodel/__init__.py": | |
| score += 4 if any(token in normalized for token in {"export", "expose", "import", "create_engine", "select"}) else 0 | |
| if file_path == "sqlmodel/sql/expression.py": | |
| score += 5 if "select" in normalized else 0 | |
| if file_path == "sqlmodel/sql/_expression_select_gen.py": | |
| score += 2 if "select" in normalized else 0 | |
| if file_path == "sqlmodel/sql/_expression_select_cls.py": | |
| score += 2 if "select" in normalized else 0 | |
| if file_path == "readme.md": | |
| score += 4 if any(token in normalized for token in {"metadata", "create_all", "workflow", "readme"}) else 0 | |
| if file_path.startswith("docs_src/"): | |
| score += 3 if any(token in normalized for token in {"metadata", "create_all", "table", "workflow"}) else 0 | |
| if file_path == "sqlmodel/main.py": | |
| score += 3 if any(token in normalized for token in {"field", "relationship", "metadata", "table", "sqlmodel"}) else 0 | |
| if "__init__.py" in file_path: | |
| score += 2 if any(token in normalized for token in {"export", "expose", "import", "public api"}) else 0 | |
| if any(token in normalized for token in {"select", "expression"}): | |
| if "expression" in file_path or "_expression_select" in file_path: | |
| score += 3 | |
| if normalized == "how is select exposed to users in sqlmodel?": | |
| if file_path == "sqlmodel/__init__.py": | |
| score += 6 | |
| if file_path == "sqlmodel/sql/expression.py": | |
| score += 6 | |
| if "session" in normalized: | |
| if file_path.endswith("session.py") or "/session.py" in file_path: | |
| score += 3 | |
| if "relationship" in normalized and file_path.endswith("main.py"): | |
| score += 2 | |
| if "field" in normalized and file_path.endswith("main.py"): | |
| score += 2 | |
| if any(token in normalized for token in {"create_engine", "export", "expose"}) and "__init__.py" in file_path: | |
| score += 2 | |
| if any(token in normalized for token in {"metadata", "create_all", "table"}) and ( | |
| "docs_src/" in file_path or file_path.endswith("main.py") or file_path == "readme.md" | |
| ): | |
| score += 2 | |
| if self._is_doc_source(item) and self._question_intent(question) in { | |
| "api", | |
| "implementation", | |
| "cross_file", | |
| "error_handling", | |
| "setup", | |
| }: | |
| score -= 1 | |
| return score | |
| def _is_substantive_assistant_message(content: str) -> bool: | |
| normalized = " ".join((content or "").strip().lower().split()) | |
| if len(normalized) < 24: | |
| return False | |
| if normalized in { | |
| "hey, what question do you have for me today?", | |
| "ask a question", | |
| }: | |
| return False | |
| return True | |
| def _normalize_history(history: List[object]) -> List[dict]: | |
| normalized = [] | |
| for turn in history: | |
| if isinstance(turn, dict): | |
| role = turn.get("role") | |
| content = turn.get("content") | |
| else: | |
| role = getattr(turn, "role", None) | |
| content = getattr(turn, "content", None) | |
| if not role or not content: | |
| continue | |
| normalized.append( | |
| { | |
| "role": str(role), | |
| "content": str(content), | |
| } | |
| ) | |
| return normalized | |
| def _format_history(history: List[dict]) -> str: | |
| if not history: | |
| return "None" | |
| lines = [] | |
| for turn in history[-4:]: | |
| role = turn.get("role", "user").capitalize() | |
| content = " ".join(turn.get("content", "").split()) | |
| if content: | |
| lines.append(f"{role}: {content[:400]}") | |
| return "\n".join(lines) if lines else "None" | |
def _ensure_repo_still_exists(session, repo_id: int):
    """Raise RuntimeError if the repository row vanished (e.g. deleted mid-index)."""
    row = session.query(Repository.id).filter_by(id=repo_id).first()
    if row is None:
        raise RuntimeError("Repository was removed before indexing completed.")
| def _session_expiry(self) -> datetime: | |
| return datetime.utcnow() + timedelta(minutes=self.session_ttl_minutes) | |
| def _build_registry_key(session_key: str, github_url: str) -> str: | |
| return f"{session_key}::{github_url}" | |
| def _serialize_chunk(chunk: dict) -> dict: | |
| return { | |
| "id": chunk["id"], | |
| "file_path": chunk["file_path"], | |
| "language": chunk["language"], | |
| "symbol_name": chunk["symbol_name"], | |
| "symbol_type": chunk["symbol_type"], | |
| "line_start": chunk["line_start"], | |
| "line_end": chunk["line_end"], | |
| "signature": chunk["signature"], | |
| "content": chunk["content"], | |
| "searchable_text": chunk["searchable_text"], | |
| "metadata_json": chunk.get("metadata_json") or {}, | |
| } | |