use crate::{
    config::{Configurable, WorkerConfig},
    errors::ChromaError,
    sysdb::sysdb::{GrpcSysDb, SysDb},
    types::VectorQueryResult,
};
use async_trait::async_trait;
use num_bigint::BigInt;
use parking_lot::{
    MappedRwLockReadGuard, RwLock, RwLockReadGuard, RwLockUpgradableReadGuard, RwLockWriteGuard,
};
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;

use super::distributed_hnsw_segment::DistributedHNSWSegment;
use crate::types::{EmbeddingRecord, MetadataValue, Segment, SegmentScope, VectorEmbeddingRecord};

/// Routes writes and vector queries to the HNSW segment backing each collection,
/// caching both the collection-to-segment mapping and the loaded segments.
#[derive(Clone)]
pub(crate) struct SegmentManager {
    inner: Arc<Inner>,
    sysdb: Box<dyn SysDb>,
}

/// Shared state behind the cloneable `SegmentManager` handle.
struct Inner {
    vector_segments: RwLock<HashMap<Uuid, Box<DistributedHNSWSegment>>>,
    collection_to_segment_cache: RwLock<HashMap<Uuid, Vec<Arc<Segment>>>>,
    storage_path: Box<std::path::PathBuf>,
}

impl SegmentManager {
    pub(crate) fn new(sysdb: Box<dyn SysDb>, storage_path: &std::path::Path) -> Self {
        SegmentManager {
            inner: Arc::new(Inner {
                vector_segments: RwLock::new(HashMap::new()),
                collection_to_segment_cache: RwLock::new(HashMap::new()),
                storage_path: Box::new(storage_path.to_owned()),
            }),
            sysdb,
        }
    }

    pub(crate) async fn write_record(&mut self, record: Box<EmbeddingRecord>) {
        let collection_id = record.collection_id;
        let mut target_segment = None;
        // TODO: don't assume a 1:1 mapping between collection and segment
        {
            let segments = self.get_segments(&collection_id).await;
            target_segment = match segments {
                Ok(found_segments) => {
                    if found_segments.is_empty() {
                        return; // TODO: handle no segment found
                    }
                    Some(found_segments[0].clone())
                }
                Err(_) => {
                    // TODO: return an error and log that no segment was found
                    return;
                }
            };
        }

        let target_segment = match target_segment {
            Some(segment) => segment,
            None => {
                // TODO: return an error and log that no segment was found
                return;
            }
        };

        println!("Writing to segment id {}", target_segment.id);

        let segment_cache = self.inner.vector_segments.upgradable_read();
        match segment_cache.get(&target_segment.id) {
            Some(segment) => {
                segment.write_records(vec![record]);
            }
            None => {
                let mut segment_cache = RwLockUpgradableReadGuard::upgrade(segment_cache);
                let new_segment = DistributedHNSWSegment::from_segment(
                    &target_segment,
                    &self.inner.storage_path,
                    // TODO: don't unwrap - return an error instead
                    record.embedding.as_ref().unwrap().len(),
                );
                match new_segment {
                    Ok(new_segment) => {
                        new_segment.write_records(vec![record]);
                        segment_cache.insert(target_segment.id, new_segment);
                    }
                    Err(e) => {
                        println!("Failed to create segment: {}", e);
                        // TODO: fail and log an error - failed to create/init the segment
                    }
                }
            }
        }
    }

    pub(crate) async fn get_records(
        &self,
        segment_id: &Uuid,
        ids: Vec<String>,
    ) -> Result<Vec<Box<VectorEmbeddingRecord>>, &'static str> {
        // TODO: load the segment if it is not in the cache
        let segment_cache = self.inner.vector_segments.read();
        match segment_cache.get(segment_id) {
            Some(segment) => {
                return Ok(segment.get_records(ids));
            }
            None => {
                return Err("No segment found");
            }
        }
    }

    pub(crate) async fn query_vector(
        &self,
        segment_id: &Uuid,
        vectors: &[f32],
        k: usize,
        include_vector: bool,
    ) -> Result<Vec<Box<VectorQueryResult>>, &'static str> {
        let segment_cache = self.inner.vector_segments.read();
        match segment_cache.get(segment_id) {
            Some(segment) => {
                let mut results = Vec::new();
                let (ids, distances) = segment.query(vectors, k);
                for (id, distance) in ids.iter().zip(distances.iter()) {
                    let fetched_vectors = if include_vector {
                        Some(segment.get_records(vec![id.clone()]))
                    } else {
                        None
                    };

                    let mut target_record = None;
                    if include_vector {
                        target_record = match fetched_vectors {
                            Some(fetched_vectors) => {
                                if fetched_vectors.is_empty() {
                                    return Err("No vector found");
                                }
                                let mut target_vec = None;
                                for vec in fetched_vectors.into_iter() {
                                    if vec.id == *id {
                                        target_vec = Some(vec);
                                        break;
                                    }
                                }
                                target_vec
                            }
                            None => {
                                return Err("No vector found");
                            }
                        };
                    }

                    let ret_vec = target_record.map(|record| record.vector);
                    let result = Box::new(VectorQueryResult {
                        id: id.to_string(),
                        seq_id: BigInt::from(0),
                        distance: *distance,
                        vector: ret_vec,
                    });
                    results.push(result);
                }
                return Ok(results);
            }
            None => {
                return Err("No segment found");
            }
        }
    }

    async fn get_segments(
        &mut self,
        collection_uuid: &Uuid,
    ) -> Result<MappedRwLockReadGuard<Vec<Arc<Segment>>>, &'static str> {
        let cache_guard = self.inner.collection_to_segment_cache.read();
        // This lets us return a reference to the segments along with the lock. The caller is
        // responsible for dropping the lock.
        let segments = RwLockReadGuard::try_map(cache_guard, |cache| {
            return cache.get(collection_uuid);
        });
        match segments {
            Ok(segments) => {
                return Ok(segments);
            }
            Err(_) => {
                // The data was not in the cache, so we need to fetch it from the sysdb.
                // Drop the lock since we need to upgrade it. Mapped locks cannot be upgraded,
                // so we drop the read lock and acquire a write lock instead.
                // https://github.com/Amanieu/parking_lot/issues/83
                drop(segments);
                let segments = self
                    .sysdb
                    .get_segments(
                        None,
                        None,
                        Some(SegmentScope::VECTOR),
                        None,
                        Some(collection_uuid.clone()),
                    )
                    .await;
                match segments {
                    Ok(segments) => {
                        let mut cache_guard = self.inner.collection_to_segment_cache.write();
                        let mut arc_segments = Vec::new();
                        for segment in segments {
                            arc_segments.push(Arc::new(segment));
                        }
                        cache_guard.insert(collection_uuid.clone(), arc_segments);
                        let cache_guard = RwLockWriteGuard::downgrade(cache_guard);
                        let segments = RwLockReadGuard::map(cache_guard, |cache| {
                            // This unwrap is safe because we just inserted the segments into the
                            // cache and there is currently no way to remove them.
                            return cache.get(collection_uuid).unwrap();
                        });
                        return Ok(segments);
                    }
                    Err(_) => {
                        return Err("Failed to get segments for collection from SysDB");
                    }
                }
            }
        }
    }
}

#[async_trait]
impl Configurable for SegmentManager {
    async fn try_from_config(worker_config: &WorkerConfig) -> Result<Self, Box<dyn ChromaError>> {
        // TODO: the SysDb implementation should be resolved dynamically from the config
        let sysdb = GrpcSysDb::try_from_config(worker_config).await;
        let sysdb = match sysdb {
            Ok(sysdb) => sysdb,
            Err(err) => {
                return Err(err);
            }
        };
        let path = std::path::Path::new(&worker_config.segment_manager.storage_path);
        Ok(SegmentManager::new(Box::new(sysdb), path))
    }
}