Spaces:
Running
Running
use crate::{ | |
config::{Configurable, WorkerConfig}, | |
errors::ChromaError, | |
sysdb::sysdb::{GrpcSysDb, SysDb}, | |
types::VectorQueryResult, | |
}; | |
use async_trait::async_trait; | |
use k8s_openapi::api::node; | |
use num_bigint::BigInt; | |
use parking_lot::{ | |
MappedRwLockReadGuard, RwLock, RwLockReadGuard, RwLockUpgradableReadGuard, RwLockWriteGuard, | |
}; | |
use std::collections::HashMap; | |
use std::sync::Arc; | |
use uuid::Uuid; | |
use super::distributed_hnsw_segment::DistributedHNSWSegment; | |
use crate::types::{EmbeddingRecord, MetadataValue, Segment, SegmentScope, VectorEmbeddingRecord}; | |
/// Routes record writes and vector queries to the vector segments of a collection.
///
/// Segment metadata for a collection is looked up through `sysdb` on a cache miss, and the
/// instantiated `DistributedHNSWSegment`s are kept in the shared `Inner` state.
pub(crate) struct SegmentManager {
// Shared caches and storage path, held behind an `Arc`.
inner: Arc<Inner>,
// Client used by `get_segments` to fetch a collection's segments when they are not cached.
sysdb: Box<dyn SysDb>,
}
/// Shared state of a `SegmentManager`: the two lock-protected caches and the storage path.
struct Inner {
// Segment id -> instantiated HNSW segment. Filled lazily by `write_record`; `get_records` and
// `query_vector` only read it.
vector_segments: RwLock<HashMap<Uuid, Box<DistributedHNSWSegment>>>,
// Collection id -> segments reported by SysDb for that collection. Entries are inserted by
// `get_segments` and never removed.
collection_to_segment_cache: RwLock<HashMap<Uuid, Vec<Arc<Segment>>>>,
// Path handed to `DistributedHNSWSegment::from_segment` when a segment is instantiated.
storage_path: Box<std::path::PathBuf>,
}
impl SegmentManager { | |
pub(crate) fn new(sysdb: Box<dyn SysDb>, storage_path: &std::path::Path) -> Self { | |
SegmentManager { | |
inner: Arc::new(Inner { | |
vector_segments: RwLock::new(HashMap::new()), | |
collection_to_segment_cache: RwLock::new(HashMap::new()), | |
storage_path: Box::new(storage_path.to_owned()), | |
}), | |
sysdb: sysdb, | |
} | |
} | |
pub(crate) async fn write_record(&mut self, record: Box<EmbeddingRecord>) { | |
let collection_id = record.collection_id; | |
let mut target_segment = None; | |
// TODO: don't assume 1:1 mapping between collection and segment | |
{ | |
let segments = self.get_segments(&collection_id).await; | |
target_segment = match segments { | |
Ok(found_segments) => { | |
if found_segments.len() == 0 { | |
return; // TODO: handle no segment found | |
} | |
Some(found_segments[0].clone()) | |
} | |
Err(_) => { | |
// TODO: throw an error and log no segment found | |
return; | |
} | |
}; | |
} | |
let target_segment = match target_segment { | |
Some(segment) => segment, | |
None => { | |
// TODO: throw an error and log no segment found | |
return; | |
} | |
}; | |
println!("Writing to segment id {}", target_segment.id); | |
let segment_cache = self.inner.vector_segments.upgradable_read(); | |
match segment_cache.get(&target_segment.id) { | |
Some(segment) => { | |
segment.write_records(vec![record]); | |
} | |
None => { | |
let mut segment_cache = RwLockUpgradableReadGuard::upgrade(segment_cache); | |
let new_segment = DistributedHNSWSegment::from_segment( | |
&target_segment, | |
&self.inner.storage_path, | |
// TODO: Don't unwrap - throw an error | |
record.embedding.as_ref().unwrap().len(), | |
); | |
match new_segment { | |
Ok(new_segment) => { | |
new_segment.write_records(vec![record]); | |
segment_cache.insert(target_segment.id, new_segment); | |
} | |
Err(e) => { | |
println!("Failed to create segment error {}", e); | |
// TODO: fail and log an error - failed to create/init segment | |
} | |
} | |
} | |
} | |
} | |
pub(crate) async fn get_records( | |
&self, | |
segment_id: &Uuid, | |
ids: Vec<String>, | |
) -> Result<Vec<Box<VectorEmbeddingRecord>>, &'static str> { | |
// TODO: Load segment if not in cache | |
let segment_cache = self.inner.vector_segments.read(); | |
match segment_cache.get(segment_id) { | |
Some(segment) => { | |
return Ok(segment.get_records(ids)); | |
} | |
None => { | |
return Err("No segment found"); | |
} | |
} | |
} | |
pub(crate) async fn query_vector( | |
&self, | |
segment_id: &Uuid, | |
vectors: &[f32], | |
k: usize, | |
include_vector: bool, | |
) -> Result<Vec<Box<VectorQueryResult>>, &'static str> { | |
let segment_cache = self.inner.vector_segments.read(); | |
match segment_cache.get(segment_id) { | |
Some(segment) => { | |
let mut results = Vec::new(); | |
let (ids, distances) = segment.query(vectors, k); | |
for (id, distance) in ids.iter().zip(distances.iter()) { | |
let fetched_vector = match include_vector { | |
true => Some(segment.get_records(vec![id.clone()])), | |
false => None, | |
}; | |
let mut target_record = None; | |
if include_vector { | |
target_record = match fetched_vector { | |
Some(fetched_vectors) => { | |
if fetched_vectors.len() == 0 { | |
return Err("No vector found"); | |
} | |
let mut target_vec = None; | |
for vec in fetched_vectors.into_iter() { | |
if vec.id == *id { | |
target_vec = Some(vec); | |
break; | |
} | |
} | |
target_vec | |
} | |
None => { | |
return Err("No vector found"); | |
} | |
}; | |
} | |
let ret_vec = match target_record { | |
Some(target_record) => Some(target_record.vector), | |
None => None, | |
}; | |
let result = Box::new(VectorQueryResult { | |
id: id.to_string(), | |
seq_id: BigInt::from(0), | |
distance: *distance, | |
vector: ret_vec, | |
}); | |
results.push(result); | |
} | |
return Ok(results); | |
} | |
None => { | |
return Err("No segment found"); | |
} | |
} | |
} | |
async fn get_segments( | |
&mut self, | |
collection_uuid: &Uuid, | |
) -> Result<MappedRwLockReadGuard<Vec<Arc<Segment>>>, &'static str> { | |
let cache_guard = self.inner.collection_to_segment_cache.read(); | |
// This lets us return a reference to the segments with the lock. The caller is responsible | |
// dropping the lock. | |
let segments = RwLockReadGuard::try_map(cache_guard, |cache| { | |
return cache.get(&collection_uuid); | |
}); | |
match segments { | |
Ok(segments) => { | |
return Ok(segments); | |
} | |
Err(_) => { | |
// Data was not in the cache, so we need to get it from the database | |
// Drop the lock since we need to upgrade it | |
// Mappable locks cannot be upgraded, so we need to drop the lock and re-acquire it | |
// https://github.com/Amanieu/parking_lot/issues/83 | |
drop(segments); | |
let segments = self | |
.sysdb | |
.get_segments( | |
None, | |
None, | |
Some(SegmentScope::VECTOR), | |
None, | |
Some(collection_uuid.clone()), | |
) | |
.await; | |
match segments { | |
Ok(segments) => { | |
let mut cache_guard = self.inner.collection_to_segment_cache.write(); | |
let mut arc_segments = Vec::new(); | |
for segment in segments { | |
arc_segments.push(Arc::new(segment)); | |
} | |
cache_guard.insert(collection_uuid.clone(), arc_segments); | |
let cache_guard = RwLockWriteGuard::downgrade(cache_guard); | |
let segments = RwLockReadGuard::map(cache_guard, |cache| { | |
// This unwrap is safe because we just inserted the segments into the cache and currently, | |
// there is no way to remove segments from the cache. | |
return cache.get(&collection_uuid).unwrap(); | |
}); | |
return Ok(segments); | |
} | |
Err(e) => { | |
return Err("Failed to get segments for collection from SysDB"); | |
} | |
} | |
} | |
} | |
} | |
} | |
impl Configurable for SegmentManager { | |
async fn try_from_config(worker_config: &WorkerConfig) -> Result<Self, Box<dyn ChromaError>> { | |
// TODO: Sysdb should have a dynamic resolution in sysdb | |
let sysdb = GrpcSysDb::try_from_config(worker_config).await; | |
let sysdb = match sysdb { | |
Ok(sysdb) => sysdb, | |
Err(err) => { | |
return Err(err); | |
} | |
}; | |
let path = std::path::Path::new(&worker_config.segment_manager.storage_path); | |
Ok(SegmentManager::new(Box::new(sysdb), path)) | |
} | |
} | |