from pymilvus import MilvusClient, DataType
import numpy as np
import concurrent.futures


class MilvusManager:
    """Manage a Milvus collection of ColBERT-style multi-vector embeddings.

    Each document is stored as many vector rows (one row per token/patch
    embedding) sharing a ``doc_id``.  Retrieval performs an ANN search over
    the individual vectors, then reranks candidate documents with a MaxSim
    (late-interaction) score computed from all of each document's vectors.
    """

    def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
        """Connect to Milvus and optionally (re)create the collection.

        Args:
            milvus_uri: URI handed straight to ``MilvusClient``.
            collection_name: Name of the target collection.
            create_collection: When truthy, drop and rebuild the collection
                and its vector index.
            dim: Dimensionality of the stored embeddings (default 128).
        """
        self.client = MilvusClient(uri=milvus_uri)
        self.collection_name = collection_name
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name)
        self.dim = dim

        if create_collection:
            self.create_collection()
            self.create_index()

    def create_collection(self):
        """Drop any existing collection and create a fresh one with the expected schema."""
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.drop_collection(collection_name=self.collection_name)
        schema = self.client.create_schema(
            auto_id=True,
            enable_dynamic_fields=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        # Position of this embedding row within its document.
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        # Only the first row of each document carries the filepath; see insert().
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)

        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        """(Re)build the vector index on the ``vector`` field.

        The collection is released first because Milvus requires a loaded
        collection to be released before its index can be modified.
        """
        self.client.release_collection(collection_name=self.collection_name)
        # NOTE(review): the index created below is named "vector_index", yet
        # this drop targets "vector" — a name that is never created here.
        # Kept as-is because on the create_collection() path the collection
        # is freshly rebuilt anyway; confirm intended index name against
        # your pymilvus version before relying on re-index behavior.
        self.client.drop_index(
            collection_name=self.collection_name, index_name="vector"
        )
        index_params = self.client.prepare_index_params()
        # FLAT is an exact (brute-force) index.  The HNSW-specific build
        # parameters previously passed here (M, efConstruction) do not
        # apply to FLAT and were ignored by Milvus, so they are removed.
        index_params.add_index(
            field_name="vector",
            index_name="vector_index",
            index_type="FLAT",
            metric_type="IP",
            params={},
        )
        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def create_scalar_index(self):
        """Build an inverted index on ``doc_id`` to speed up the rerank queries."""
        self.client.release_collection(collection_name=self.collection_name)
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="doc_id",
            index_name="int32_index",
            index_type="INVERTED",
        )
        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def search(self, data, topk):
        """ANN search followed by per-document MaxSim rerank.

        Args:
            data: Query embedding matrix (one row per query token/patch).
            topk: Maximum number of ``(score, doc_id)`` pairs to return.

        Returns:
            List of ``(score, doc_id)`` tuples sorted by score descending,
            at most ``topk`` long.
        """
        search_params = {"metric_type": "IP", "params": {}}
        # First pass: ANN search over individual vectors just to collect
        # candidate documents — only doc_id is needed from the hits.
        results = self.client.search(
            self.collection_name,
            data,
            limit=50,
            output_fields=["doc_id"],
            search_params=search_params,
        )
        doc_ids = {
            hit["entity"]["doc_id"] for per_query_hits in results for hit in per_query_hits
        }

        def rerank_single_doc(doc_id, data, client, collection_name):
            """Compute the MaxSim score of the query against one document."""
            # BUGFIX: the filter previously fetched doc_id AND doc_id + 1,
            # so a document's score was contaminated by its neighbor's
            # vectors (and the limit could truncate the target doc's own
            # rows).  Only the target document is queried now.
            doc_colbert_vecs = client.query(
                collection_name=collection_name,
                filter=f"doc_id in [{doc_id}]",
                output_fields=["vector"],
                limit=1000,
            )
            doc_vecs = np.vstack([row["vector"] for row in doc_colbert_vecs])
            # MaxSim: for each query vector take the best-matching doc
            # vector, then sum over query vectors.
            score = np.dot(data, doc_vecs.T).max(1).sum()
            return (score, doc_id)

        scores = []
        # One Milvus query per candidate doc; bound the pool by the number
        # of candidates (minimum 1 — ThreadPoolExecutor rejects 0 workers).
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=min(300, max(1, len(doc_ids)))
        ) as executor:
            futures = {
                executor.submit(
                    rerank_single_doc, doc_id, data, self.client, self.collection_name
                ): doc_id
                for doc_id in doc_ids
            }
            for future in concurrent.futures.as_completed(futures):
                scores.append(future.result())

        scores.sort(key=lambda x: x[0], reverse=True)
        self._log_scores(scores, topk)
        # Slicing handles the len(scores) < topk case on its own.
        return scores[:topk]

    def _log_scores(self, scores, topk):
        """Pretty-print the sorted (score, doc_id) list for debugging."""
        print("\n" + "=" * 80)
        print("📊 RETRIEVAL SCORES - PAGE NUMBERS WITH HIGHEST SCORES")
        print("=" * 80)
        print(f"🔍 Collection: {self.collection_name}")
        print(f"📄 Total documents found: {len(scores)}")
        print(f"🎯 Requested top-k: {topk}")
        print("-" * 80)

        # Show at most the top 10 results in detail.
        display_count = min(10, len(scores))
        for score, doc_id in scores[:display_count]:
            # doc_ids are 0-based; pages are reported 1-based.
            page_num = doc_id + 1
            relevance_level = self._get_relevance_level(score)
            print(
                f"📄 Page {page_num:2d} (doc_id: {doc_id:2d}) | Score: {score:8.4f} | {relevance_level}"
            )

        if len(scores) > display_count:
            print(f"... and {len(scores) - display_count} more results")

        print("-" * 80)
        print(f"🏆 HIGHEST SCORING PAGES:")
        for rank, (score, doc_id) in enumerate(scores[:3], 1):
            page_num = doc_id + 1
            print(f"   {rank}. Page {page_num} - Score: {score:.4f}")
        print("=" * 80 + "\n")

    def _get_relevance_level(self, score):
        """Get human-readable relevance level based on score"""
        if score >= 0.90:
            return "🟢 EXCELLENT - Highly relevant"
        elif score >= 0.80:
            return "🟡 VERY GOOD - Very relevant"
        elif score >= 0.70:
            return "🟠 GOOD - Relevant"
        elif score >= 0.60:
            return "🔵 MODERATE - Somewhat relevant"
        elif score >= 0.50:
            return "🟣 BASIC - Minimally relevant"
        else:
            return "🔴 POOR - Not relevant"

    def insert(self, data):
        """Insert one document's embeddings as individual vector rows.

        Args:
            data: Dict with ``colbert_vecs`` (sequence of vectors),
                ``doc_id`` (int) and ``filepath`` (str).

        The filepath is stored only on the first row (seq_id 0); the other
        rows carry an empty ``doc`` string.  A document with no vectors is
        a no-op instead of the previous IndexError.
        """
        colbert_vecs = list(data["colbert_vecs"])
        if not colbert_vecs:
            return
        doc_id = data["doc_id"]
        self.client.insert(
            self.collection_name,
            [
                {
                    "vector": vec,
                    "seq_id": seq_id,
                    "doc_id": doc_id,
                    "doc": data["filepath"] if seq_id == 0 else "",
                }
                for seq_id, vec in enumerate(colbert_vecs)
            ],
        )

    def get_images_as_doc(self, images_with_vectors: list):
        """Convert image/embedding dicts into insert-ready records.

        Sequential doc_ids are assigned in input order (doc_id == index).
        """
        return [
            {
                "colbert_vecs": item["colbert_vecs"],
                "doc_id": i,
                "filepath": item["filepath"],
            }
            for i, item in enumerate(images_with_vectors)
        ]

    def insert_images_data(self, image_data):
        """Insert a batch of images, one document per image."""
        for record in self.get_images_as_doc(image_data):
            self.insert(record)

    def drop_collection(self):
        """Drop the current collection from Milvus"""
        try:
            if self.client.has_collection(collection_name=self.collection_name):
                self.client.drop_collection(collection_name=self.collection_name)
                print(f"🗑️ Dropped Milvus collection: {self.collection_name}")
                return True
            else:
                print(f"⚠️ Collection {self.collection_name} does not exist in Milvus")
                return False
        except Exception as e:
            # Best-effort cleanup: report the failure instead of raising so
            # callers can treat "could not drop" as a soft failure.
            print(f"❌ Error dropping collection {self.collection_name}: {e}")
            return False