Spaces:
Running
Running
File size: 5,000 Bytes
287a0bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
use num_bigint::BigInt;
use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use std::collections::HashMap;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use crate::errors::ChromaError;
use crate::index::{HnswIndex, HnswIndexConfig, Index, IndexConfig};
use crate::types::{EmbeddingRecord, Operation, Segment, VectorEmbeddingRecord};
/// HNSW vector segment that maps caller-supplied string ids onto the `usize`
/// labels stored in an [`HnswIndex`].
///
/// The index and the two id maps each sit behind their own `RwLock`. Each lock
/// is taken independently, so the three structures are updated one after
/// another rather than atomically as a group.
pub(crate) struct DistributedHNSWSegment {
    /// The underlying HNSW index, keyed by internal `usize` labels.
    index: Arc<RwLock<HnswIndex>>,
    /// Next internal label to hand out; incremented once per added record.
    id: AtomicUsize,
    /// Caller-facing (string) id -> internal index label.
    user_id_to_id: Arc<RwLock<HashMap<String, usize>>>,
    /// Internal index label -> caller-facing (string) id.
    id_to_user_id: Arc<RwLock<HashMap<usize, String>>>,
    /// Index configuration the HNSW index was initialized with.
    index_config: IndexConfig,
    /// HNSW-specific configuration the HNSW index was initialized with.
    hnsw_config: HnswIndexConfig,
}
impl DistributedHNSWSegment {
    /// Creates an empty segment backed by a freshly initialized HNSW index.
    ///
    /// # Errors
    /// Propagates any error returned by `HnswIndex::init`.
    pub(crate) fn new(
        index_config: IndexConfig,
        hnsw_config: HnswIndexConfig,
    ) -> Result<Self, Box<dyn ChromaError>> {
        // TODO: log + handle an error that we failed to init the index
        let hnsw_index = HnswIndex::init(&index_config, Some(&hnsw_config))?;
        Ok(DistributedHNSWSegment {
            index: Arc::new(RwLock::new(hnsw_index)),
            id: AtomicUsize::new(0),
            user_id_to_id: Arc::new(RwLock::new(HashMap::new())),
            id_to_user_id: Arc::new(RwLock::new(HashMap::new())),
            index_config,
            hnsw_config,
        })
    }
    /// Builds a boxed segment from a `Segment` description.
    ///
    /// Both configs are derived from `segment`; `persist_path` is forwarded to
    /// `HnswIndexConfig::from_segment`.
    ///
    /// # Errors
    /// Propagates errors from either config constructor or from [`Self::new`].
    pub(crate) fn from_segment(
        segment: &Segment,
        persist_path: &std::path::Path,
        dimensionality: usize,
    ) -> Result<Box<DistributedHNSWSegment>, Box<dyn ChromaError>> {
        // NOTE(review): `as i32` silently wraps for dimensionalities above
        // i32::MAX; presumably never reached in practice — confirm.
        let index_config = IndexConfig::from_segment(segment, dimensionality as i32)?;
        let hnsw_config = HnswIndexConfig::from_segment(segment, persist_path)?;
        Ok(Box::new(DistributedHNSWSegment::new(
            index_config,
            hnsw_config,
        )?))
    }
    /// Applies a batch of embedding records to the segment.
    ///
    /// Only `Operation::Add` is implemented. Each add that carries an
    /// embedding gets a fresh internal label, is recorded in both id maps, and
    /// is then inserted into the index.
    ///
    /// The following are not applied and only print a message (or nothing):
    /// records without an embedding, records whose operation fails to parse,
    /// and `Upsert`/`Update`/`Delete`.
    pub(crate) fn write_records(&self, records: Vec<Box<EmbeddingRecord>>) {
        for record in records {
            match Operation::try_from(record.operation) {
                Ok(Operation::Add) => {
                    // TODO: make lock xor lock
                    match &record.embedding {
                        Some(vector) => {
                            let next_id =
                                self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                            // The map locks are released before the index lock is
                            // taken, so no two of these locks are ever held together here.
                            self.user_id_to_id
                                .write()
                                .insert(record.id.clone(), next_id);
                            self.id_to_user_id
                                .write()
                                .insert(next_id, record.id.clone());
                            println!("Segment adding item: {}", next_id);
                            self.index.read().add(next_id, vector);
                        }
                        None => {
                            // TODO: log an error
                            println!("No vector found in record");
                        }
                    }
                }
                // TODO: these operations are not implemented yet; records are dropped.
                Ok(Operation::Upsert) | Ok(Operation::Update) | Ok(Operation::Delete) => {}
                Err(_) => {
                    println!("Error parsing operation");
                }
            }
        }
    }
    /// Looks up the stored vectors for the given caller-facing ids.
    ///
    /// Ids that are unknown to the segment, or whose label has no vector in the
    /// index, are skipped. The result holds one record per id that was found,
    /// in input order. `seq_id` is currently always 0.
    pub(crate) fn get_records(&self, ids: Vec<String>) -> Vec<Box<VectorEmbeddingRecord>> {
        let user_id_to_id = self.user_id_to_id.read();
        let index = self.index.read();
        let mut records = Vec::with_capacity(ids.len());
        for id in ids {
            let internal_id = match user_id_to_id.get(&id) {
                Some(internal_id) => *internal_id,
                None => {
                    // TODO: Error. Skip (rather than return early) so one unknown
                    // id does not silently drop every id after it.
                    continue;
                }
            };
            match index.get(internal_id) {
                Some(vector) => {
                    records.push(Box::new(VectorEmbeddingRecord {
                        id,
                        seq_id: BigInt::from(0),
                        vector,
                    }));
                }
                None => {
                    // TODO: error
                }
            }
        }
        records
    }
    /// Runs a k-nearest-neighbour query against the index.
    ///
    /// Returns two parallel vectors, caller-facing ids and distances, where
    /// element `i` of each refers to the same hit. A hit whose internal label
    /// has no caller-facing id is dropped from BOTH vectors, so they stay
    /// aligned and equal in length.
    pub(crate) fn query(&self, vector: &[f32], k: usize) -> (Vec<String>, Vec<f32>) {
        let index = self.index.read();
        let (ids, distances) = index.query(vector, k);
        let user_ids = self.id_to_user_id.read();
        let mut return_user_ids = Vec::with_capacity(distances.len());
        let mut return_distances = Vec::with_capacity(distances.len());
        for (id, distance) in ids.into_iter().zip(distances) {
            match user_ids.get(&id) {
                Some(user_id) => {
                    return_user_ids.push(user_id.clone());
                    return_distances.push(distance);
                }
                None => {
                    // TODO: error
                }
            }
        }
        (return_user_ids, return_distances)
    }
}
|