# Repository page residue (not code): Park-Hip-02's picture / initial commit / d9762cf
import os
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore, RetrievalMode
from qdrant_client import QdrantClient, models
import logging
import pickle
from pathlib import Path
# Layout of every record emitted through the root logger:
# timestamp — level — message. The separators are real em dashes; the
# previous literal contained a mis-decoded byte sequence instead.
LOG_FORMAT = '%(asctime)s — %(levelname)s — %(message)s'

logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
)

# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
def _load_documents(doc_path):
    """Return the object unpickled from *doc_path* (the pre-processed documents)."""
    # NOTE(security): unpickling can execute arbitrary code. This file must
    # only ever come from a trusted, project-controlled location.
    with open(doc_path, 'rb') as f:
        return pickle.load(f)


def _existing_vector_size(collection_info):
    """Return the vector size stored in *collection_info*, or None if not found.

    Two attribute layouts are probed, in this order: ``config.vectors.size``
    and ``config.params.vectors.size``. Anything without a usable ``size``
    attribute (for example a mapping of named vectors) yields None.
    """
    config = collection_info.config
    for holder in (config, getattr(config, 'params', None)):
        size = getattr(getattr(holder, 'vectors', None), 'size', None)
        if size is not None:
            return size
    return None


def get_vectorstore() -> QdrantVectorStore:
    """Connect to the ``legal_db`` Qdrant collection, creating it if needed.

    Connection settings are read from the ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` environment variables. Embeddings are produced on CPU
    by the ``BAAI/bge-large-en`` model.

    If the collection already exists, its vector size is validated against
    the embedding model and the store is attached to it. Otherwise the
    collection is created, populated from the pickled documents, and keyword
    payload indexes are added for the metadata fields.

    Returns:
        A ``QdrantVectorStore`` bound to the ``legal_db`` collection.

    Raises:
        ValueError: If the existing collection's vector size is unknown or
            differs from the embedding model's output dimension.
        OSError: If the collection has to be created and the pickled
            documents file cannot be opened.
    """
    base_dir = Path(__file__).resolve().parent.parent
    doc_path = base_dir / 'data' / 'processed_data' / 'criminal_code_of_vietnam.pkl'

    qdrant_api_key = os.getenv('QDRANT_API_KEY')
    qdrant_url = os.getenv('QDRANT_URL')
    collection_name = 'legal_db'

    client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)

    embeddings = HuggingFaceEmbeddings(
        model_name='BAAI/bge-large-en',
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
    )
    logger.info('Embedding created.')

    # Embed a throwaway query to discover the model's output dimension.
    vector_dim = len(embeddings.embed_query('A dummy to test embedding dimension'))

    if client.collection_exists(collection_name):
        logger.info('Collection exists. Connecting...')
        existing_dim = _existing_vector_size(client.get_collection(collection_name))
        logger.info('Existing dimension: %s', existing_dim)

        if existing_dim is None:
            # Same exception type as a plain mismatch, but say what actually
            # went wrong instead of reporting a size of "None".
            raise ValueError(
                f'Could not determine the vector size of existing collection '
                f'"{collection_name}" (it may use named vectors); '
                f'embedding model gives {vector_dim}'
            )
        if existing_dim != vector_dim:
            raise ValueError(
                f'Dimension mismatch: existing collection has {existing_dim}, but embedding model gives {vector_dim}'
            )

        return QdrantVectorStore.from_existing_collection(
            embedding=embeddings,
            collection_name=collection_name,
            prefer_grpc=False,
            url=qdrant_url,
            api_key=qdrant_api_key,
        )

    logger.info('Collection "%s" does not exist. Creating new collection...', collection_name)

    # The documents are only needed when populating a new collection, so the
    # pickle is read here rather than on every call.
    doc_list = _load_documents(doc_path)

    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(size=vector_dim, distance=models.Distance.COSINE),
    )
    db = QdrantVectorStore.from_documents(
        documents=doc_list,
        embedding=embeddings,
        url=qdrant_url,
        prefer_grpc=False,
        collection_name=collection_name,
        retrieval_mode=RetrievalMode.DENSE,
        api_key=qdrant_api_key,
    )
    logger.info('Qdrant Index created.')

    # Keyword payload indexes on the metadata fields.
    fields_to_index = {
        'metadata.article': "keyword",
        'metadata.chapter': "keyword",
        'metadata.id': "keyword",
        'metadata.source': "keyword",
        'metadata.title': "keyword",
    }
    for field, schema in fields_to_index.items():
        client.create_payload_index(
            collection_name=collection_name,
            field_name=field,
            field_schema=schema,
        )
    return db