File size: 5,000 Bytes
287a0bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
use num_bigint::BigInt;
use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use std::collections::HashMap;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;

use crate::errors::ChromaError;
use crate::index::{HnswIndex, HnswIndexConfig, Index, IndexConfig};
use crate::types::{EmbeddingRecord, Operation, Segment, VectorEmbeddingRecord};

/// A segment that stores embeddings in an HNSW index and translates between
/// caller-supplied string ids and the `usize` ids used inside the index.
///
/// All mutable state sits behind locks or an atomic, so the methods on this
/// type take `&self`.
pub(crate) struct DistributedHNSWSegment {
    /// The HNSW index holding the vectors, keyed by internal id.
    index: Arc<RwLock<HnswIndex>>,
    /// Counter that hands out the next internal id; starts at 0 and is
    /// incremented once per added record.
    id: AtomicUsize,
    /// Maps a record's user-facing id to the internal id it was assigned.
    user_id_to_id: Arc<RwLock<HashMap<String, usize>>>,
    /// Reverse of `user_id_to_id`: maps an internal id back to the user-facing id.
    id_to_user_id: Arc<RwLock<HashMap<usize, String>>>,
    /// Index configuration the HNSW index was initialised with.
    index_config: IndexConfig,
    /// HNSW-specific configuration the index was initialised with.
    hnsw_config: HnswIndexConfig,
}

impl DistributedHNSWSegment {
    /// Builds a segment around a freshly initialised HNSW index.
    ///
    /// The internal id counter starts at 0 and both id maps start empty.
    ///
    /// # Errors
    ///
    /// Returns the error from `HnswIndex::init` unchanged when the index
    /// cannot be initialised.
    pub(crate) fn new(
        index_config: IndexConfig,
        hnsw_config: HnswIndexConfig,
    ) -> Result<Self, Box<dyn ChromaError>> {
        // TODO: log + handle an error that we failed to init the index
        let hnsw_index = HnswIndex::init(&index_config, Some(&hnsw_config))?;
        Ok(DistributedHNSWSegment {
            index: Arc::new(RwLock::new(hnsw_index)),
            id: AtomicUsize::new(0),
            user_id_to_id: Arc::new(RwLock::new(HashMap::new())),
            id_to_user_id: Arc::new(RwLock::new(HashMap::new())),
            index_config,
            hnsw_config,
        })
    }

    /// Derives both configurations from `segment` and builds a boxed segment.
    ///
    /// `persist_path` is forwarded to `HnswIndexConfig::from_segment` and
    /// `dimensionality` to `IndexConfig::from_segment`.
    ///
    /// # Errors
    ///
    /// Propagates any error from deriving either configuration or from
    /// [`DistributedHNSWSegment::new`].
    pub(crate) fn from_segment(
        segment: &Segment,
        persist_path: &std::path::Path,
        dimensionality: usize,
    ) -> Result<Box<DistributedHNSWSegment>, Box<dyn ChromaError>> {
        // NOTE(review): `as i32` silently truncates a dimensionality above
        // `i32::MAX`; kept because no error value can be constructed here.
        let index_config = IndexConfig::from_segment(segment, dimensionality as i32)?;
        let hnsw_config = HnswIndexConfig::from_segment(segment, persist_path)?;
        DistributedHNSWSegment::new(index_config, hnsw_config).map(Box::new)
    }

    /// Applies a batch of records to the segment.
    ///
    /// Only `Add` has an effect: the record gets the next internal id, both id
    /// maps are updated, and its embedding is added to the index. An `Add`
    /// without an embedding is reported on stdout and skipped. `Upsert`,
    /// `Update` and `Delete` are currently no-ops, and an operation that
    /// fails to parse is reported on stdout and skipped.
    pub(crate) fn write_records(&self, records: Vec<Box<EmbeddingRecord>>) {
        for record in records {
            match Operation::try_from(record.operation) {
                Ok(Operation::Add) => {
                    // TODO: make lock xor lock
                    match &record.embedding {
                        Some(vector) => {
                            let next_id = self.id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                            // Each write guard is a temporary, so each map lock
                            // is released at the end of its own statement.
                            self.user_id_to_id
                                .write()
                                .insert(record.id.clone(), next_id);
                            self.id_to_user_id
                                .write()
                                .insert(next_id, record.id.clone());
                            println!("Segment adding item: {}", next_id);
                            // The index is updated last, after both maps know the id.
                            self.index.read().add(next_id, vector);
                        }
                        None => {
                            // TODO: log an error
                            println!("No vector found in record");
                        }
                    }
                }
                // Not implemented yet.
                Ok(Operation::Upsert) | Ok(Operation::Update) | Ok(Operation::Delete) => {}
                Err(_) => {
                    println!("Error parsing operation");
                }
            }
        }
    }

    /// Looks up the stored vectors for the given user ids.
    ///
    /// Returns one record per id that is known to the segment and has a
    /// vector in the index, in input order. Ids that are unknown, or whose
    /// vector is missing from the index, are skipped, so the result may be
    /// shorter than `ids`. `seq_id` is always 0.
    pub(crate) fn get_records(&self, ids: Vec<String>) -> Vec<Box<VectorEmbeddingRecord>> {
        let mut records = Vec::with_capacity(ids.len());
        let user_id_to_id = self.user_id_to_id.read();
        let index = self.index.read();
        for id in ids {
            let internal_id = match user_id_to_id.get(&id) {
                Some(internal_id) => *internal_id,
                None => {
                    // TODO: Error
                    // Skip only this id; returning here would silently drop
                    // every id that follows an unknown one.
                    continue;
                }
            };
            match index.get(internal_id) {
                Some(vector) => {
                    records.push(Box::new(VectorEmbeddingRecord {
                        id,
                        seq_id: BigInt::from(0),
                        vector,
                    }));
                }
                None => {
                    // TODO: error
                }
            }
        }
        records
    }

    /// Queries the index with `vector` and `k`, translating the hits to user ids.
    ///
    /// Returns `(user_ids, distances)`; the two vectors always have the same
    /// length and entry `i` of one corresponds to entry `i` of the other. A
    /// hit whose internal id has no user id mapping is dropped together with
    /// its distance.
    pub(crate) fn query(&self, vector: &[f32], k: usize) -> (Vec<String>, Vec<f32>) {
        // The read guard is a temporary, so the index lock is released as soon
        // as the query returns rather than being held during id translation.
        let (ids, distances) = self.index.read().query(vector, k);
        let user_ids = self.id_to_user_id.read();
        let mut return_user_ids = Vec::with_capacity(distances.len());
        let mut return_distances = Vec::with_capacity(distances.len());
        // Walk ids and distances in lockstep so an unmapped id cannot shift
        // the remaining distances out of alignment with their user ids.
        for (id, distance) in ids.into_iter().zip(distances) {
            match user_ids.get(&id) {
                Some(user_id) => {
                    return_user_ids.push(user_id.clone());
                    return_distances.push(distance);
                }
                None => {
                    // TODO: error
                }
            }
        }
        (return_user_ids, return_distances)
    }
}