Spaces:
Running
Running
use std::f32::consts::E; | |
use crate::chroma_proto; | |
use crate::chroma_proto::{ | |
GetVectorsRequest, GetVectorsResponse, QueryVectorsRequest, QueryVectorsResponse, | |
}; | |
use crate::config::{Configurable, WorkerConfig}; | |
use crate::errors::ChromaError; | |
use crate::segment::SegmentManager; | |
use crate::types::ScalarEncoding; | |
use async_trait::async_trait; | |
use kube::core::request; | |
use tonic::{transport::Server, Request, Response, Status}; | |
use uuid::Uuid; | |
pub struct WorkerServer { | |
segment_manager: Option<SegmentManager>, | |
port: u16, | |
} | |
impl Configurable for WorkerServer { | |
async fn try_from_config(config: &WorkerConfig) -> Result<Self, Box<dyn ChromaError>> { | |
Ok(WorkerServer { | |
segment_manager: None, | |
port: config.my_port, | |
}) | |
} | |
} | |
impl WorkerServer { | |
pub(crate) async fn run(worker: WorkerServer) -> Result<(), Box<dyn std::error::Error>> { | |
let addr = format!("[::]:{}", worker.port).parse().unwrap(); | |
println!("Worker listening on {}", addr); | |
let server = Server::builder() | |
.add_service(chroma_proto::vector_reader_server::VectorReaderServer::new( | |
worker, | |
)) | |
.serve(addr) | |
.await?; | |
println!("Worker shutting down"); | |
Ok(()) | |
} | |
pub(crate) fn set_segment_manager(&mut self, segment_manager: SegmentManager) { | |
self.segment_manager = Some(segment_manager); | |
} | |
} | |
impl chroma_proto::vector_reader_server::VectorReader for WorkerServer { | |
async fn get_vectors( | |
&self, | |
request: Request<GetVectorsRequest>, | |
) -> Result<Response<GetVectorsResponse>, Status> { | |
let request = request.into_inner(); | |
let segment_uuid = match Uuid::parse_str(&request.segment_id) { | |
Ok(uuid) => uuid, | |
Err(_) => { | |
return Err(Status::invalid_argument("Invalid UUID")); | |
} | |
}; | |
let segment_manager = match self.segment_manager { | |
Some(ref segment_manager) => segment_manager, | |
None => { | |
return Err(Status::internal("No segment manager found")); | |
} | |
}; | |
let records = match segment_manager | |
.get_records(&segment_uuid, request.ids) | |
.await | |
{ | |
Ok(records) => records, | |
Err(e) => { | |
return Err(Status::internal(format!("Error getting records: {}", e))); | |
} | |
}; | |
let mut proto_records = Vec::new(); | |
for record in records { | |
let sed_id_bytes = record.seq_id.to_bytes_le(); | |
let dim = record.vector.len(); | |
let proto_vector = (record.vector, ScalarEncoding::FLOAT32, dim).try_into(); | |
match proto_vector { | |
Ok(proto_vector) => { | |
let proto_record = chroma_proto::VectorEmbeddingRecord { | |
id: record.id, | |
seq_id: sed_id_bytes.1, | |
vector: Some(proto_vector), | |
}; | |
proto_records.push(proto_record); | |
} | |
Err(e) => { | |
return Err(Status::internal(format!("Error converting vector: {}", e))); | |
} | |
} | |
} | |
let resp = chroma_proto::GetVectorsResponse { | |
records: proto_records, | |
}; | |
Ok(Response::new(resp)) | |
} | |
async fn query_vectors( | |
&self, | |
request: Request<QueryVectorsRequest>, | |
) -> Result<Response<QueryVectorsResponse>, Status> { | |
let request = request.into_inner(); | |
let segment_uuid = match Uuid::parse_str(&request.segment_id) { | |
Ok(uuid) => uuid, | |
Err(_) => { | |
return Err(Status::invalid_argument("Invalid Segment UUID")); | |
} | |
}; | |
let segment_manager = match self.segment_manager { | |
Some(ref segment_manager) => segment_manager, | |
None => { | |
return Err(Status::internal("No segment manager found")); | |
} | |
}; | |
let mut proto_results_for_all = Vec::new(); | |
for proto_query_vector in request.vectors { | |
let (query_vector, encoding) = match proto_query_vector.try_into() { | |
Ok((vector, encoding)) => (vector, encoding), | |
Err(e) => { | |
return Err(Status::internal(format!("Error converting vector: {}", e))); | |
} | |
}; | |
let results = match segment_manager | |
.query_vector( | |
&segment_uuid, | |
&query_vector, | |
request.k as usize, | |
request.include_embeddings, | |
) | |
.await | |
{ | |
Ok(results) => results, | |
Err(e) => { | |
return Err(Status::internal(format!("Error querying segment: {}", e))); | |
} | |
}; | |
let mut proto_results = Vec::new(); | |
for query_result in results { | |
let proto_result = chroma_proto::VectorQueryResult { | |
id: query_result.id, | |
seq_id: query_result.seq_id.to_bytes_le().1, | |
distance: query_result.distance, | |
vector: match query_result.vector { | |
Some(vector) => { | |
match (vector, ScalarEncoding::FLOAT32, query_vector.len()).try_into() { | |
Ok(proto_vector) => Some(proto_vector), | |
Err(e) => { | |
return Err(Status::internal(format!( | |
"Error converting vector: {}", | |
e | |
))); | |
} | |
} | |
} | |
None => None, | |
}, | |
}; | |
proto_results.push(proto_result); | |
} | |
let vector_query_results = chroma_proto::VectorQueryResults { | |
results: proto_results, | |
}; | |
proto_results_for_all.push(vector_query_results); | |
} | |
let resp = chroma_proto::QueryVectorsResponse { | |
results: proto_results_for_all, | |
}; | |
return Ok(Response::new(resp)); | |
} | |
} | |