use num_bigint::BigInt; use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use std::collections::HashMap; use std::sync::atomic::AtomicUsize; use std::sync::Arc; use crate::errors::ChromaError; use crate::index::{HnswIndex, HnswIndexConfig, Index, IndexConfig}; use crate::types::{EmbeddingRecord, Operation, Segment, VectorEmbeddingRecord}; pub(crate) struct DistributedHNSWSegment { index: Arc>, id: AtomicUsize, user_id_to_id: Arc>>, id_to_user_id: Arc>>, index_config: IndexConfig, hnsw_config: HnswIndexConfig, } impl DistributedHNSWSegment { pub(crate) fn new( index_config: IndexConfig, hnsw_config: HnswIndexConfig, ) -> Result> { let hnsw_index = HnswIndex::init(&index_config, Some(&hnsw_config)); let hnsw_index = match hnsw_index { Ok(index) => index, Err(e) => { // TODO: log + handle an error that we failed to init the index return Err(e); } }; let index = Arc::new(RwLock::new(hnsw_index)); return Ok(DistributedHNSWSegment { index: index, id: AtomicUsize::new(0), user_id_to_id: Arc::new(RwLock::new(HashMap::new())), id_to_user_id: Arc::new(RwLock::new(HashMap::new())), index_config: index_config, hnsw_config, }); } pub(crate) fn from_segment( segment: &Segment, persist_path: &std::path::Path, dimensionality: usize, ) -> Result, Box> { let index_config = IndexConfig::from_segment(&segment, dimensionality as i32)?; let hnsw_config = HnswIndexConfig::from_segment(segment, persist_path)?; Ok(Box::new(DistributedHNSWSegment::new( index_config, hnsw_config, )?)) } pub(crate) fn write_records(&self, records: Vec>) { for record in records { let op = Operation::try_from(record.operation); match op { Ok(Operation::Add) => { // TODO: make lock xor lock match &record.embedding { Some(vector) => { let next_id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst); self.user_id_to_id .write() .insert(record.id.clone(), next_id); self.id_to_user_id .write() .insert(next_id, record.id.clone()); println!("Segment adding item: {}", next_id); self.index.read().add(next_id, &vector); } None => { // TODO: log an error println!("No vector found in record"); } } } Ok(Operation::Upsert) => {} Ok(Operation::Update) => {} Ok(Operation::Delete) => {} Err(_) => { println!("Error parsing operation"); } } } } pub(crate) fn get_records(&self, ids: Vec) -> Vec> { let mut records = Vec::new(); let user_id_to_id = self.user_id_to_id.read(); let index = self.index.read(); for id in ids { let internal_id = match user_id_to_id.get(&id) { Some(internal_id) => internal_id, None => { // TODO: Error return records; } }; let vector = index.get(*internal_id); match vector { Some(vector) => { let record = VectorEmbeddingRecord { id: id, seq_id: BigInt::from(0), vector, }; records.push(Box::new(record)); } None => { // TODO: error } } } return records; } pub(crate) fn query(&self, vector: &[f32], k: usize) -> (Vec, Vec) { let index = self.index.read(); let mut return_user_ids = Vec::new(); let (ids, distances) = index.query(vector, k); let user_ids = self.id_to_user_id.read(); for id in ids { match user_ids.get(&id) { Some(user_id) => return_user_ids.push(user_id.clone()), None => { // TODO: error } }; } return (return_user_ids, distances); } }