use std::collections::HashMap;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;

use common::cpu::CpuPermit;
use rand::prelude::StdRng;
use rand::SeedableRng;
use segment::data_types::query_context::QueryContext;
use segment::data_types::vectors::{only_default_vector, DEFAULT_VECTOR_NAME};
use segment::entry::entry_point::SegmentEntry;
use segment::fixtures::index_fixtures::random_vector;
use segment::fixtures::payload_fixtures::random_int_payload;
use segment::index::hnsw_index::graph_links::GraphLinksRam;
use segment::index::hnsw_index::hnsw::{HNSWIndex, HnswIndexOpenArgs};
use segment::index::hnsw_index::num_rayon_threads;
use segment::index::VectorIndex;
use segment::json_path::JsonPath;
use segment::segment_constructor::build_segment;
use segment::types::{
    Condition, Distance, FieldCondition, Filter, HnswConfig, Indexes, Payload, PayloadSchemaType,
    SegmentConfig, SeqNumberType, VectorDataConfig, VectorStorageType, WithPayload,
};
use serde_json::json;
use tempfile::Builder;
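
// Batched search must return exactly the same results as issuing each query
// as a separate request, on both the plain (full-scan) index and the HNSW index.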
#[test]
fn test_batch_and_single_request_equivalency() {
    let num_vectors: u64 = 1_000;
    let distance = Distance::Cosine;
    let num_payload_values = 2;
    let dim = 8;

    let mut rnd = StdRng::seed_from_u64(42);
    let dir = Builder::new().prefix("segment_dir").tempdir().unwrap();
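
    // Segment config: a single dense vector field with in-memory storage and a
    // plain (exact, full-scan) index.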
    let config = SegmentConfig {
        vector_data: HashMap::from([(
            DEFAULT_VECTOR_NAME.to_owned(),
            VectorDataConfig {
                size: dim,
                distance,
                storage_type: VectorStorageType::Memory,
                index: Indexes::Plain {},
                quantization_config: None,
                multivector_config: None,
                datatype: None,
            },
        )]),
        sparse_vector_data: Default::default(),
        payload_storage_type: Default::default(),
    };
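
    // Build the segment and index the integer payload field so it can be used in filters.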
    let int_key = "int";

    let mut segment = build_segment(dir.path(), &config, true).unwrap();
    segment
        .create_field_index(
            0,
            &JsonPath::new(int_key),
            Some(&PayloadSchemaType::Integer.into()),
        )
        .unwrap();
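
    // Fill the segment with random vectors, each carrying a random integer payload.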
    for n in 0..num_vectors {
        let idx = n.into();
        let vector = random_vector(&mut rnd, dim);
        let int_payload = random_int_payload(&mut rnd, num_payload_values..=num_payload_values);
        let payload: Payload = json!({int_key: int_payload,}).into();

        segment
            .upsert_point(n as SeqNumberType, idx, only_default_vector(&vector))
            .unwrap();
        segment
            .set_full_payload(n as SeqNumberType, idx, &payload)
            .unwrap();
    }
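
    // Plain index: compare two single filtered searches against the same two
    // queries submitted as one batch.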
    for _ in 0..10 {
        let query_vector_1 = random_vector(&mut rnd, dim).into();
        let query_vector_2 = random_vector(&mut rnd, dim).into();
        let payload_value = random_int_payload(&mut rnd, 1..=1).pop().unwrap();

        let filter = Filter::new_must(Condition::Field(FieldCondition::new_match(
            JsonPath::new(int_key),
            payload_value.into(),
        )));
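
        // Each query is first executed as its own single search request...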
        let search_res_1 = segment
            .search(
                DEFAULT_VECTOR_NAME,
                &query_vector_1,
                &WithPayload::default(),
                &false.into(),
                Some(&filter),
                10,
                None,
            )
            .unwrap();
        let search_res_2 = segment
            .search(
                DEFAULT_VECTOR_NAME,
                &query_vector_2,
                &WithPayload::default(),
                &false.into(),
                Some(&filter),
                10,
                None,
            )
            .unwrap();
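
        // ...and then the same two queries are submitted together as one batch.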
        let query_context = QueryContext::default();
        let segment_query_context = query_context.get_segment_query_context();

        let batch_res = segment
            .search_batch(
                DEFAULT_VECTOR_NAME,
                &[&query_vector_1, &query_vector_2],
                &WithPayload::default(),
                &false.into(),
                Some(&filter),
                10,
                None,
                &segment_query_context,
            )
            .unwrap();

        // Ignore the hardware counter in tests
        segment_query_context
            .take_hardware_counter()
            .discard_results();

        assert_eq!(search_res_1, batch_res[0]);
        assert_eq!(search_res_2, batch_res[1]);
    }
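
    // Build an HNSW index on top of the same segment data.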
    let hnsw_dir = Builder::new().prefix("hnsw_dir").tempdir().unwrap();
    let stopped = AtomicBool::new(false);
    let payload_index_ptr = segment.payload_index.clone();

    let m = 8;
    let ef_construct = 100;
    let full_scan_threshold = 10000;

    let hnsw_config = HnswConfig {
        m,
        ef_construct,
        full_scan_threshold,
        max_indexing_threads: 2,
        on_disk: Some(false),
        payload_m: None,
    };
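
    // Dummy CPU permit sized for the configured number of indexing threads.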
    let permit_cpu_count = num_rayon_threads(hnsw_config.max_indexing_threads);
    let permit = Arc::new(CpuPermit::dummy(permit_cpu_count as u32));

    let vector_storage = &segment.vector_data[DEFAULT_VECTOR_NAME].vector_storage;
    let quantized_vectors = &segment.vector_data[DEFAULT_VECTOR_NAME].quantized_vectors;
    let hnsw_index = HNSWIndex::<GraphLinksRam>::open(HnswIndexOpenArgs {
        path: hnsw_dir.path(),
        id_tracker: segment.id_tracker.clone(),
        vector_storage: vector_storage.clone(),
        quantized_vectors: quantized_vectors.clone(),
        payload_index: payload_index_ptr,
        hnsw_config,
        permit: Some(permit),
        stopped: &stopped,
    })
    .unwrap();
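
    // HNSW index: the same equivalence must hold between batched and single searches.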
    for _ in 0..10 {
        let query_vector_1 = random_vector(&mut rnd, dim).into();
        let query_vector_2 = random_vector(&mut rnd, dim).into();
        let payload_value = random_int_payload(&mut rnd, 1..=1).pop().unwrap();

        let filter = Filter::new_must(Condition::Field(FieldCondition::new_match(
            JsonPath::new(int_key),
            payload_value.into(),
        )));

        let search_res_1 = hnsw_index
            .search(
                &[&query_vector_1],
                Some(&filter),
                10,
                None,
                &Default::default(),
            )
            .unwrap();
        let search_res_2 = hnsw_index
            .search(
                &[&query_vector_2],
                Some(&filter),
                10,
                None,
                &Default::default(),
            )
            .unwrap();

        let batch_res = hnsw_index
            .search(
                &[&query_vector_1, &query_vector_2],
                Some(&filter),
                10,
                None,
                &Default::default(),
            )
            .unwrap();

        assert_eq!(search_res_1[0], batch_res[0]);
        assert_eq!(search_res_2[0], batch_res[1]);
    }
}