Spaces:
Sleeping
Sleeping
| // A scheduler recieves embedding records for a given batch of documents | |
| // and schedules them to be ingested to the segment manager | |
| use crate::{ | |
| system::{Component, ComponentContext, Handler, Receiver}, | |
| types::EmbeddingRecord, | |
| }; | |
| use async_trait::async_trait; | |
| use rand::prelude::SliceRandom; | |
| use rand::Rng; | |
| use std::{ | |
| collections::{btree_map::Range, HashMap}, | |
| fmt::{Debug, Formatter, Result}, | |
| sync::Arc, | |
| }; | |
| pub(crate) struct RoundRobinScheduler { | |
| // The segment manager to schedule to, a segment manager is a component | |
| // segment_manager: SegmentManager | |
| curr_wake_up: Option<tokio::sync::oneshot::Sender<WakeMessage>>, | |
| tenant_to_queue: HashMap<String, tokio::sync::mpsc::Sender<Box<EmbeddingRecord>>>, | |
| new_tenant_channel: Option<tokio::sync::mpsc::Sender<NewTenantMessage>>, | |
| subscribers: Option<Vec<Box<dyn Receiver<Box<EmbeddingRecord>>>>>, | |
| } | |
| impl Debug for RoundRobinScheduler { | |
| fn fmt(&self, f: &mut Formatter<'_>) -> Result { | |
| f.debug_struct("Scheduler").finish() | |
| } | |
| } | |
| impl RoundRobinScheduler { | |
| pub(crate) fn new() -> Self { | |
| RoundRobinScheduler { | |
| curr_wake_up: None, | |
| tenant_to_queue: HashMap::new(), | |
| new_tenant_channel: None, | |
| subscribers: Some(Vec::new()), | |
| } | |
| } | |
| pub(crate) fn subscribe(&mut self, subscriber: Box<dyn Receiver<Box<EmbeddingRecord>>>) { | |
| match self.subscribers { | |
| Some(ref mut subscribers) => { | |
| subscribers.push(subscriber); | |
| } | |
| None => {} | |
| } | |
| } | |
| } | |
| impl Component for RoundRobinScheduler { | |
| fn queue_size(&self) -> usize { | |
| 1000 | |
| } | |
| fn on_start(&mut self, ctx: &ComponentContext<Self>) { | |
| let sleep_sender = ctx.sender.clone(); | |
| let (new_tenant_tx, mut new_tenant_rx) = tokio::sync::mpsc::channel(1000); | |
| self.new_tenant_channel = Some(new_tenant_tx); | |
| let cancellation_token = ctx.cancellation_token.clone(); | |
| let subscribers = self.subscribers.take(); | |
| let mut subscribers = match subscribers { | |
| Some(subscribers) => subscribers, | |
| None => { | |
| // TODO: log + error | |
| return; | |
| } | |
| }; | |
| tokio::spawn(async move { | |
| let mut tenant_queues: HashMap< | |
| String, | |
| tokio::sync::mpsc::Receiver<Box<EmbeddingRecord>>, | |
| > = HashMap::new(); | |
| loop { | |
| // TODO: handle cancellation | |
| let mut did_work = false; | |
| for tenant_queue in tenant_queues.values_mut() { | |
| match tenant_queue.try_recv() { | |
| Ok(message) => { | |
| // Randomly pick a subscriber to send the message to | |
| // This serves as a crude load balancing between available threads | |
| // Future improvements here could be | |
| // - Use a work stealing scheduler | |
| // - Use rayon | |
| // - We need to enforce partial order over writes to a given key | |
| // so we need a mechanism to ensure that all writes to a given key | |
| // occur in order | |
| let mut subscriber = None; | |
| { | |
| let mut rng = rand::thread_rng(); | |
| subscriber = subscribers.choose_mut(&mut rng); | |
| } | |
| match subscriber { | |
| Some(subscriber) => { | |
| let res = subscriber.send(message).await; | |
| } | |
| None => {} | |
| } | |
| did_work = true; | |
| } | |
| Err(tokio::sync::mpsc::error::TryRecvError::Empty) => { | |
| continue; | |
| } | |
| Err(_) => { | |
| // TODO: Handle a erroneous channel | |
| // log an error | |
| continue; | |
| } | |
| }; | |
| } | |
| match new_tenant_rx.try_recv() { | |
| Ok(new_tenant_message) => { | |
| tenant_queues.insert(new_tenant_message.tenant, new_tenant_message.channel); | |
| } | |
| Err(tokio::sync::mpsc::error::TryRecvError::Empty) => { | |
| // no - op | |
| } | |
| Err(_) => { | |
| // TODO: handle erroneous channel | |
| // log an error | |
| continue; | |
| } | |
| }; | |
| if !did_work { | |
| // Send a sleep message to the sender | |
| let (wake_tx, wake_rx) = tokio::sync::oneshot::channel(); | |
| let sleep_res = sleep_sender.send(SleepMessage { sender: wake_tx }).await; | |
| let wake_res = wake_rx.await; | |
| } | |
| } | |
| }); | |
| } | |
| } | |
| impl Handler<(String, Box<EmbeddingRecord>)> for RoundRobinScheduler { | |
| async fn handle( | |
| &mut self, | |
| message: (String, Box<EmbeddingRecord>), | |
| _ctx: &ComponentContext<Self>, | |
| ) { | |
| let (tenant, embedding_record) = message; | |
| // Check if the tenant is already in the tenant set, if not we need to inform the scheduler loop | |
| // of a new tenant | |
| if self.tenant_to_queue.get(&tenant).is_none() { | |
| // Create a new channel for the tenant | |
| let (sender, reciever) = tokio::sync::mpsc::channel(1000); | |
| // Add the tenant to the tenant set | |
| self.tenant_to_queue.insert(tenant.clone(), sender); | |
| // Send the new tenant message to the scheduler loop | |
| let new_tenant_channel = match self.new_tenant_channel { | |
| Some(ref mut channel) => channel, | |
| None => { | |
| // TODO: this is an error | |
| // It should always be populated by on_start | |
| return; | |
| } | |
| }; | |
| let res = new_tenant_channel | |
| .send(NewTenantMessage { | |
| tenant: tenant.clone(), | |
| channel: reciever, | |
| }) | |
| .await; | |
| // TODO: handle this res | |
| } | |
| // Send the embedding record to the tenant's channel | |
| let res = self | |
| .tenant_to_queue | |
| .get(&tenant) | |
| .unwrap() | |
| .send(embedding_record) | |
| .await; | |
| // TODO: handle this res | |
| // Check if the scheduler is sleeping, if so wake it up | |
| // TODO: we need to init with a wakeup otherwise we are off by one | |
| if self.curr_wake_up.is_some() { | |
| // Send a wake up message to the scheduler loop | |
| let res = self.curr_wake_up.take().unwrap().send(WakeMessage {}); | |
| // TOOD: handle this res | |
| } | |
| } | |
| } | |
| impl Handler<SleepMessage> for RoundRobinScheduler { | |
| async fn handle(&mut self, message: SleepMessage, _ctx: &ComponentContext<Self>) { | |
| // Set the current wake up channel | |
| self.curr_wake_up = Some(message.sender); | |
| } | |
| } | |
| /// Used by round robin scheduler to wake its scheduler loop | |
| struct WakeMessage {} | |
| /// The round robin scheduler will sleep when there is no work to be done and send a sleep message | |
| /// this allows the manager to wake it up when there is work to be scheduled | |
| struct SleepMessage { | |
| sender: tokio::sync::oneshot::Sender<WakeMessage>, | |
| } | |
| struct NewTenantMessage { | |
| tenant: String, | |
| channel: tokio::sync::mpsc::Receiver<Box<EmbeddingRecord>>, | |
| } | |