// A scheduler recieves embedding records for a given batch of documents // and schedules them to be ingested to the segment manager use crate::{ system::{Component, ComponentContext, Handler, Receiver}, types::EmbeddingRecord, }; use async_trait::async_trait; use rand::prelude::SliceRandom; use rand::Rng; use std::{ collections::{btree_map::Range, HashMap}, fmt::{Debug, Formatter, Result}, sync::Arc, }; pub(crate) struct RoundRobinScheduler { // The segment manager to schedule to, a segment manager is a component // segment_manager: SegmentManager curr_wake_up: Option>, tenant_to_queue: HashMap>>, new_tenant_channel: Option>, subscribers: Option>>>>, } impl Debug for RoundRobinScheduler { fn fmt(&self, f: &mut Formatter<'_>) -> Result { f.debug_struct("Scheduler").finish() } } impl RoundRobinScheduler { pub(crate) fn new() -> Self { RoundRobinScheduler { curr_wake_up: None, tenant_to_queue: HashMap::new(), new_tenant_channel: None, subscribers: Some(Vec::new()), } } pub(crate) fn subscribe(&mut self, subscriber: Box>>) { match self.subscribers { Some(ref mut subscribers) => { subscribers.push(subscriber); } None => {} } } } impl Component for RoundRobinScheduler { fn queue_size(&self) -> usize { 1000 } fn on_start(&mut self, ctx: &ComponentContext) { let sleep_sender = ctx.sender.clone(); let (new_tenant_tx, mut new_tenant_rx) = tokio::sync::mpsc::channel(1000); self.new_tenant_channel = Some(new_tenant_tx); let cancellation_token = ctx.cancellation_token.clone(); let subscribers = self.subscribers.take(); let mut subscribers = match subscribers { Some(subscribers) => subscribers, None => { // TODO: log + error return; } }; tokio::spawn(async move { let mut tenant_queues: HashMap< String, tokio::sync::mpsc::Receiver>, > = HashMap::new(); loop { // TODO: handle cancellation let mut did_work = false; for tenant_queue in tenant_queues.values_mut() { match tenant_queue.try_recv() { Ok(message) => { // Randomly pick a subscriber to send the message to // This serves as a crude load balancing between available threads // Future improvements here could be // - Use a work stealing scheduler // - Use rayon // - We need to enforce partial order over writes to a given key // so we need a mechanism to ensure that all writes to a given key // occur in order let mut subscriber = None; { let mut rng = rand::thread_rng(); subscriber = subscribers.choose_mut(&mut rng); } match subscriber { Some(subscriber) => { let res = subscriber.send(message).await; } None => {} } did_work = true; } Err(tokio::sync::mpsc::error::TryRecvError::Empty) => { continue; } Err(_) => { // TODO: Handle a erroneous channel // log an error continue; } }; } match new_tenant_rx.try_recv() { Ok(new_tenant_message) => { tenant_queues.insert(new_tenant_message.tenant, new_tenant_message.channel); } Err(tokio::sync::mpsc::error::TryRecvError::Empty) => { // no - op } Err(_) => { // TODO: handle erroneous channel // log an error continue; } }; if !did_work { // Send a sleep message to the sender let (wake_tx, wake_rx) = tokio::sync::oneshot::channel(); let sleep_res = sleep_sender.send(SleepMessage { sender: wake_tx }).await; let wake_res = wake_rx.await; } } }); } } #[async_trait] impl Handler<(String, Box)> for RoundRobinScheduler { async fn handle( &mut self, message: (String, Box), _ctx: &ComponentContext, ) { let (tenant, embedding_record) = message; // Check if the tenant is already in the tenant set, if not we need to inform the scheduler loop // of a new tenant if self.tenant_to_queue.get(&tenant).is_none() { // Create a new channel for the tenant let (sender, reciever) = tokio::sync::mpsc::channel(1000); // Add the tenant to the tenant set self.tenant_to_queue.insert(tenant.clone(), sender); // Send the new tenant message to the scheduler loop let new_tenant_channel = match self.new_tenant_channel { Some(ref mut channel) => channel, None => { // TODO: this is an error // It should always be populated by on_start return; } }; let res = new_tenant_channel .send(NewTenantMessage { tenant: tenant.clone(), channel: reciever, }) .await; // TODO: handle this res } // Send the embedding record to the tenant's channel let res = self .tenant_to_queue .get(&tenant) .unwrap() .send(embedding_record) .await; // TODO: handle this res // Check if the scheduler is sleeping, if so wake it up // TODO: we need to init with a wakeup otherwise we are off by one if self.curr_wake_up.is_some() { // Send a wake up message to the scheduler loop let res = self.curr_wake_up.take().unwrap().send(WakeMessage {}); // TOOD: handle this res } } } #[async_trait] impl Handler for RoundRobinScheduler { async fn handle(&mut self, message: SleepMessage, _ctx: &ComponentContext) { // Set the current wake up channel self.curr_wake_up = Some(message.sender); } } /// Used by round robin scheduler to wake its scheduler loop #[derive(Debug)] struct WakeMessage {} /// The round robin scheduler will sleep when there is no work to be done and send a sleep message /// this allows the manager to wake it up when there is work to be scheduled #[derive(Debug)] struct SleepMessage { sender: tokio::sync::oneshot::Sender, } struct NewTenantMessage { tenant: String, channel: tokio::sync::mpsc::Receiver>, }