# vectordb_rerank.py
import faiss
import numpy as np
import os
from chromadb import PersistentClient
from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
from retriever.reranker import rerank_documents
# chroma vector config v2
embedding_models = [
"upskyy/bge-m3-korean",
"jhgan/ko-sbert-sts",
"BM-K/KoSimCSE-roberta",
"BM-K/KoSimCSE-v2-multitask",
"snunlp/KR-SBERT-V40K-klueNLI-augSTS",
"beomi/KcELECTRA-small-v2022",
]
# law_db config v2
CHROMA_PATH = os.path.abspath("data/index/law_db")
COLLECTION_NAME = "law_all"
EMBEDDING_MODEL_NAME = embedding_models[0]  # select the model to use
# 1. Load the embedding model (v2)
# embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
# 2. Set up the embedding function
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=EMBEDDING_MODEL_NAME)
# 3. Load the Chroma client and collection
client = PersistentClient(path=CHROMA_PATH)
collection = client.get_collection(name=COLLECTION_NAME, embedding_function=embedding_fn)
# 4. Search function
def search_documents(query: str, top_k: int = 5):
print(f"\nπŸ” 검색어: '{query}'")
results = collection.query(
query_texts=[query],
n_results=top_k,
include=["documents", "metadatas", "distances"]
)
for i, (doc, meta, dist) in enumerate(zip(
results['documents'][0],
results['metadatas'][0],
results['distances'][0]
)):
print(f"\nπŸ“„ κ²°κ³Ό {i+1} (μœ μ‚¬λ„: {1 - dist:.2f})")
print(f"λ¬Έμ„œ: {doc[:150]}...")
print("메타데이터:")
print(meta)
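

# --- Usage sketch (not part of the original file) ---
# A minimal example of wiring the vector search into the reranker imported above.
# It assumes retriever.reranker.rerank_documents takes the query string and a list
# of document texts and returns them in reranked order; the real signature in
# retriever/reranker.py may differ, so adjust accordingly.
if __name__ == "__main__":
    query = "example query"  # hypothetical placeholder query
    results = search_documents(query, top_k=10)  # vector search + printed summary
    candidates = results["documents"][0]         # top-k document texts from Chroma
    # Assumed reranker call: rerank_documents(query, candidates) -> reordered documents.
    reranked = rerank_documents(query, candidates)
    for rank, doc in enumerate(reranked[:5], start=1):
        print(f"[{rank}] {doc[:100]}...")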