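"""Day 7: Retriever + RAG baseline (retrieval only).

Loads an existing Chroma or FAISS vector store, embeds the query with a
SentenceTransformer model, retrieves the top_k most similar chunks, and prints
the formatted context that the Day 8 summarizer consumes.
"""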
import os
import argparse
import pickle
from typing import List

from sentence_transformers import SentenceTransformer

from langchain_community.vectorstores import Chroma, FAISS
from langchain_core.documents import Document

import faiss


def load_embedder(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Load a SentenceTransformer model and return it along with a plain embedding function."""
    model = SentenceTransformer(model_name)

    def embed_f(texts: List[str]) -> List[List[float]]:
        # normalize_embeddings=True makes cosine similarity equivalent to a dot product.
        vecs = model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
        return vecs.tolist()

    return model, embed_f
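# Quick sanity check (illustrative; the query string is an arbitrary example):
#
#   model, embed_f = load_embedder()
#   vec = embed_f(["follow-up plan for hypertension"])[0]
#   len(vec)  # 384 for all-MiniLM-L6-v2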
def load_chroma(persist_dir: str, collection: str, embed_f):
    """Open an already-persisted Chroma collection for retrieval."""
    from langchain_core.embeddings import Embeddings

    class STEmbeddings(Embeddings):
        """Adapter exposing embed_f through LangChain's Embeddings interface."""

        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            return embed_f(texts)

        def embed_query(self, text: str) -> List[float]:
            return embed_f([text])[0]

    embeddings = STEmbeddings()
    vectordb = Chroma(
        collection_name=collection,
        persist_directory=persist_dir,
        embedding_function=embeddings,
    )
    return vectordb


def load_faiss_langchain(persist_dir: str, embed_f):
    """Rebuild a LangChain FAISS store from a raw index.faiss plus pickled metadata."""
    from langchain_core.embeddings import Embeddings

    class STEmbeddings(Embeddings):
        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            return embed_f(texts)

        def embed_query(self, text: str) -> List[float]:
            return embed_f([text])[0]

    embeddings = STEmbeddings()

    index_path = os.path.join(persist_dir, "index.faiss")
    meta_path = os.path.join(persist_dir, "meta.pkl")
    if not (os.path.exists(index_path) and os.path.exists(meta_path)):
        raise FileNotFoundError(f"FAISS files not found in {persist_dir}")

    index = faiss.read_index(index_path)
    with open(meta_path, "rb") as f:
        meta = pickle.load(f)

    # meta is a list of dicts with "id", "text", and "meta" keys. FAISS.from_texts
    # re-embeds the texts just to rebuild the docstore; the freshly built index is
    # then swapped for the persisted one, which assumes both were created from the
    # documents in the same order.
    texts = [m["text"] for m in meta]
    metadatas = [m["meta"] | {"id": m["id"]} for m in meta]
    vectordb = FAISS.from_texts(texts=texts, embedding=embeddings, metadatas=metadatas)
    vectordb.index = index
    return vectordb
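# A minimal sketch of an alternative FAISS loader that avoids re-embedding every chunk.
# It is not wired into main(); it assumes the same meta.pkl layout as above and uses
# langchain_community's FAISS constructor and InMemoryDocstore directly. The distance
# strategy of the constructed store must still match how the persisted index was built.
def load_faiss_langchain_no_reembed(persist_dir: str, embed_f):
    from langchain_core.embeddings import Embeddings
    from langchain_community.docstore.in_memory import InMemoryDocstore

    class STEmbeddings(Embeddings):
        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            return embed_f(texts)

        def embed_query(self, text: str) -> List[float]:
            return embed_f([text])[0]

    index = faiss.read_index(os.path.join(persist_dir, "index.faiss"))
    with open(os.path.join(persist_dir, "meta.pkl"), "rb") as f:
        meta = pickle.load(f)

    # Build the docstore straight from the saved metadata instead of calling from_texts,
    # so the persisted vectors are reused as-is and nothing is embedded at load time.
    docstore = InMemoryDocstore({
        str(i): Document(page_content=m["text"], metadata=m["meta"] | {"id": m["id"]})
        for i, m in enumerate(meta)
    })
    return FAISS(
        embedding_function=STEmbeddings(),
        index=index,
        docstore=docstore,
        index_to_docstore_id={i: str(i) for i in range(len(meta))},
    )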
def retrieve(vdb, query: str, top_k: int = 5):
    """Return the top_k chunks most similar to the query."""
    retriever = vdb.as_retriever(search_kwargs={"k": top_k})
    docs: List[Document] = retriever.invoke(query)
    return docs


def format_context(docs: List[Document]) -> str:
    """Number each retrieved chunk and prefix it with its note/section metadata."""
    parts = []
    for i, d in enumerate(docs, 1):
        md = d.metadata or {}
        parts.append(
            f"[{i}] note_id={md.get('note_id')} section={md.get('section')} "
            f"chunk_idx={md.get('chunk_index')}\n{d.page_content}"
        )
    return "\n\n---\n\n".join(parts)
def main():
    parser = argparse.ArgumentParser(description="Day 7: Retriever + RAG baseline (retrieval only).")
    parser.add_argument("--db_type", choices=["chroma", "faiss"], default="chroma")
    parser.add_argument("--persist_dir", default="./data/vector_store")
    parser.add_argument("--collection", default="notes")
    parser.add_argument("--model_name", default="sentence-transformers/all-MiniLM-L6-v2")
    parser.add_argument("--query", required=True)
    parser.add_argument("--top_k", type=int, default=5)
    args = parser.parse_args()

    # The vector store was built in an earlier step; this script only reads from it,
    # so refuse to run (rather than delete anything) when the store is missing.
    if not os.path.exists(args.persist_dir):
        raise FileNotFoundError(f"Vector store not found at {args.persist_dir}")

    _, embed_f = load_embedder(args.model_name)

    if args.db_type == "chroma":
        vectordb = load_chroma(args.persist_dir, args.collection, embed_f)
    else:
        vectordb = load_faiss_langchain(args.persist_dir, embed_f)

    docs = retrieve(vectordb, args.query, args.top_k)
    context = format_context(docs)
    print("\n=== Retrieved Context (to feed Day 8 summarizer) ===\n")
    print(context)


if __name__ == "__main__":
    main()
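# Example invocations (the script filename is an assumption; the query is arbitrary):
#
#   python day7_retriever.py --db_type chroma --persist_dir ./data/vector_store \
#       --collection notes --query "follow-up plan for hypertension" --top_k 5
#
#   python day7_retriever.py --db_type faiss --persist_dir ./data/vector_store \
#       --query "follow-up plan for hypertension"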