Spaces:
Running
Running
File size: 7,922 Bytes
287a0bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 |
// A scheduler receives embedding records for a given batch of documents
// and schedules them to be ingested by the segment manager.
use crate::{
system::{Component, ComponentContext, Handler, Receiver},
types::EmbeddingRecord,
};
use async_trait::async_trait;
use rand::prelude::SliceRandom;
use rand::Rng;
use std::{
collections::{btree_map::Range, HashMap},
fmt::{Debug, Formatter, Result},
sync::Arc,
};
/// Component that accepts `(tenant, record)` messages, buffers them in one
/// bounded channel per tenant, and forwards them to subscribers from a
/// background loop spawned in `on_start`.
pub(crate) struct RoundRobinScheduler {
// Placeholder: the scheduler is eventually meant to hand records to a segment
// manager component. That field does not exist yet:
// segment_manager: SegmentManager
/// Wake-up handle for the background loop. Stored by the `SleepMessage`
/// handler and consumed (taken and fired) by the record handler.
curr_wake_up: Option<tokio::sync::oneshot::Sender<WakeMessage>>,
/// Sending half of each tenant's record queue, keyed by tenant name. The
/// matching receiving half is handed to the background loop.
tenant_to_queue: HashMap<String, tokio::sync::mpsc::Sender<Box<EmbeddingRecord>>>,
/// Channel used to tell the background loop about a new tenant queue.
/// `None` until `on_start` has created it.
new_tenant_channel: Option<tokio::sync::mpsc::Sender<NewTenantMessage>>,
/// Downstream receivers of scheduled records. `Some` from construction until
/// `on_start` takes the list and moves it into the background loop.
subscribers: Option<Vec<Box<dyn Receiver<Box<EmbeddingRecord>>>>>,
}
impl Debug for RoundRobinScheduler {
    /// Formats the scheduler as an opaque struct named `Scheduler`; no fields
    /// are included in the output.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
        let mut repr = f.debug_struct("Scheduler");
        repr.finish()
    }
}
impl RoundRobinScheduler {
    /// Creates a scheduler with no tenants, no wake-up handle, no
    /// registration channel, and an empty subscriber list.
    pub(crate) fn new() -> Self {
        Self {
            curr_wake_up: None,
            tenant_to_queue: HashMap::new(),
            new_tenant_channel: None,
            subscribers: Some(Vec::new()),
        }
    }

    /// Adds a downstream receiver for scheduled records.
    ///
    /// The subscriber list is taken out of `self` by `on_start`, so a
    /// subscriber passed in after that point is dropped without being
    /// registered.
    pub(crate) fn subscribe(&mut self, subscriber: Box<dyn Receiver<Box<EmbeddingRecord>>>) {
        if let Some(registered) = self.subscribers.as_mut() {
            registered.push(subscriber);
        }
    }
}
impl Component for RoundRobinScheduler {
fn queue_size(&self) -> usize {
1000
}
fn on_start(&mut self, ctx: &ComponentContext<Self>) {
let sleep_sender = ctx.sender.clone();
let (new_tenant_tx, mut new_tenant_rx) = tokio::sync::mpsc::channel(1000);
self.new_tenant_channel = Some(new_tenant_tx);
let cancellation_token = ctx.cancellation_token.clone();
let subscribers = self.subscribers.take();
let mut subscribers = match subscribers {
Some(subscribers) => subscribers,
None => {
// TODO: log + error
return;
}
};
tokio::spawn(async move {
let mut tenant_queues: HashMap<
String,
tokio::sync::mpsc::Receiver<Box<EmbeddingRecord>>,
> = HashMap::new();
loop {
// TODO: handle cancellation
let mut did_work = false;
for tenant_queue in tenant_queues.values_mut() {
match tenant_queue.try_recv() {
Ok(message) => {
// Randomly pick a subscriber to send the message to
// This serves as a crude load balancing between available threads
// Future improvements here could be
// - Use a work stealing scheduler
// - Use rayon
// - We need to enforce partial order over writes to a given key
// so we need a mechanism to ensure that all writes to a given key
// occur in order
let mut subscriber = None;
{
let mut rng = rand::thread_rng();
subscriber = subscribers.choose_mut(&mut rng);
}
match subscriber {
Some(subscriber) => {
let res = subscriber.send(message).await;
}
None => {}
}
did_work = true;
}
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {
continue;
}
Err(_) => {
// TODO: Handle a erroneous channel
// log an error
continue;
}
};
}
match new_tenant_rx.try_recv() {
Ok(new_tenant_message) => {
tenant_queues.insert(new_tenant_message.tenant, new_tenant_message.channel);
}
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {
// no - op
}
Err(_) => {
// TODO: handle erroneous channel
// log an error
continue;
}
};
if !did_work {
// Send a sleep message to the sender
let (wake_tx, wake_rx) = tokio::sync::oneshot::channel();
let sleep_res = sleep_sender.send(SleepMessage { sender: wake_tx }).await;
let wake_res = wake_rx.await;
}
}
});
}
}
#[async_trait]
impl Handler<(String, Box<EmbeddingRecord>)> for RoundRobinScheduler {
    /// Queues one embedding record for its tenant and wakes the scheduling
    /// loop if it is parked.
    ///
    /// `message` is `(tenant, record)`. The first record seen for a tenant
    /// creates that tenant's bounded queue (capacity 1000) and announces its
    /// receiving half to the scheduling loop. The tenant is only remembered
    /// once that announcement has been accepted, so a failed registration is
    /// retried on the tenant's next record instead of leaving a queue that
    /// nobody reads.
    ///
    /// The record is dropped if the registration channel has not been created
    /// yet or the registration cannot be delivered.
    async fn handle(
        &mut self,
        message: (String, Box<EmbeddingRecord>),
        _ctx: &ComponentContext<Self>,
    ) {
        let (tenant, embedding_record) = message;

        // Check if the tenant is already known; if not, the scheduler loop has
        // to be informed of the new tenant first.
        if !self.tenant_to_queue.contains_key(&tenant) {
            let new_tenant_channel = match self.new_tenant_channel.as_ref() {
                Some(channel) => channel,
                None => {
                    // TODO: this is an error
                    // It should always be populated by on_start
                    return;
                }
            };
            // Create a new channel for the tenant
            let (sender, receiver) = tokio::sync::mpsc::channel(1000);
            // Send the new tenant message to the scheduler loop
            let registration = NewTenantMessage {
                tenant: tenant.clone(),
                channel: receiver,
            };
            if new_tenant_channel.send(registration).await.is_err() {
                // The scheduler loop is no longer listening, so the receiving
                // half was dropped along with the failed message.
                // TODO: log an error
                return;
            }
            // Only now add the tenant to the tenant set: its queue is
            // guaranteed to have a reader.
            self.tenant_to_queue.insert(tenant.clone(), sender);
        }

        // Send the embedding record to the tenant's channel. This waits while
        // the tenant's queue is full.
        if let Some(queue) = self.tenant_to_queue.get(&tenant) {
            // TODO: handle a failed send
            let _ = queue.send(embedding_record).await;
        }

        // If the scheduler loop has parked itself, wake it up.
        // NOTE: a wake-up handle is only present once the loop's sleep request
        // has been handled; until then there is nothing to fire here.
        if let Some(wake_up) = self.curr_wake_up.take() {
            // TODO: handle a failed wake-up
            let _ = wake_up.send(WakeMessage {});
        }
    }
}
#[async_trait]
impl Handler<SleepMessage> for RoundRobinScheduler {
async fn handle(&mut self, message: SleepMessage, _ctx: &ComponentContext<Self>) {
// Set the current wake up channel
self.curr_wake_up = Some(message.sender);
}
}
/// Used by the round robin scheduler to wake its scheduler loop.
///
/// Carries no data; it is sent over the one-shot channel delivered in a
/// `SleepMessage`.
#[derive(Debug)]
struct WakeMessage {}
/// Sent by the round robin scheduler's loop to its own component when there is
/// no work to be done. This allows the component to wake the loop up when
/// there is work to be scheduled.
#[derive(Debug)]
struct SleepMessage {
/// One-shot handle the component fires with a `WakeMessage` to resume the loop.
sender: tokio::sync::oneshot::Sender<WakeMessage>,
}
/// Announces a newly created tenant queue to the scheduler loop.
struct NewTenantMessage {
/// Name of the tenant; used as the key for the queue inside the loop.
tenant: String,
/// Receiving half of the tenant's record queue, from which the loop pulls
/// records to forward.
channel: tokio::sync::mpsc::Receiver<Box<EmbeddingRecord>>,
}
|