Spaces:
Runtime error
Runtime error
use std::sync::atomic::{AtomicBool, Ordering}; | |
use std::sync::Arc; | |
use text_generation_client::{ | |
Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters, | |
}; | |
// Note: Request ids and batch ids cannot collide. | |
const LIVENESS_ID: u64 = u64::MAX; | |
const BATCH_ID: u64 = u64::MAX; | |
pub(crate) struct Health { | |
client: ShardedClient, | |
generation_health: Arc<AtomicBool>, | |
} | |
impl Health { | |
pub(crate) fn new(client: ShardedClient, generation_health: Arc<AtomicBool>) -> Self { | |
Self { | |
client, | |
generation_health, | |
} | |
} | |
pub(crate) async fn check(&mut self) -> bool { | |
if self.generation_health.load(Ordering::SeqCst) { | |
// Generation is healthy, we only check that the shards are answering gRPC calls | |
self.client.health().await.is_ok() | |
} else { | |
// Generation is unhealthy or have not sent any generation request yet | |
// Dummy batch of 1 token and 1 generated token | |
let liveness_request = Request { | |
id: LIVENESS_ID, | |
inputs: "liveness".to_string(), | |
truncate: 10, | |
parameters: Some(NextTokenChooserParameters { | |
temperature: 1.0, | |
top_k: 0, | |
top_p: 1.0, | |
typical_p: 1.0, | |
do_sample: false, | |
seed: 0, | |
repetition_penalty: 1.0, | |
watermark: false, | |
}), | |
stopping_parameters: Some(StoppingCriteriaParameters { | |
max_new_tokens: 1, | |
stop_sequences: vec![], | |
ignore_eos_token: false, | |
}), | |
}; | |
let batch = Batch { | |
id: BATCH_ID, | |
requests: vec![liveness_request], | |
size: 1, | |
max_tokens: 2, | |
}; | |
// Skips the queue | |
let value = self.client.prefill(batch).await.is_ok(); | |
// Update generation health | |
self.generation_health.store(value, Ordering::SeqCst); | |
value | |
} | |
} | |
} | |