/// Text Generation Inference webserver entrypoint
use axum::http::HeaderValue;
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use text_generation_client::ShardedClient;
use text_generation_router::{server, HubModelInfo};
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    max_concurrent_requests: usize,
    max_best_of: usize,
    max_stop_sequences: usize,
    max_input_length: usize,
    max_total_tokens: usize,
    max_batch_size: Option<usize>,
    waiting_served_ratio: f32,
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
    port: u16,
    master_shard_uds_path: String,
    tokenizer_name: String,
    revision: String,
    validation_workers: usize,
    json_output: bool,
    otlp_endpoint: Option<String>,
    cors_allow_origin: Option<Vec<String>>,
}
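// Illustrative invocation only: the binary name, flag spellings, and values below are
// assumptions, not taken from this file. The deprecation warnings in `main` reference
// kebab-case flags such as `--max-batch-size`, so the fields above are expected to map
// to long CLI options in the same style:
//
//   LOG_LEVEL=info text-generation-router \
//       --tokenizer-name bigscience/bloom-560m \
//       --port 3000 \
//       --master-shard-uds-path /tmp/text-generation-server-0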
fn main() -> Result<(), std::io::Error> {
    // Get args
    let args = Args::parse();

    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
        max_batch_size,
        waiting_served_ratio,
        mut max_batch_total_tokens,
        max_waiting_tokens,
        port,
        master_shard_uds_path,
        tokenizer_name,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        cors_allow_origin,
    } = args;

    if validation_workers == 0 {
        panic!("validation_workers must be > 0");
    }
    // CORS allowed origins:
    // map inside the Option, parse each String origin into a HeaderValue,
    // then collect the list into an AllowOrigin
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });
    // Parse Huggingface hub token
    let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);
    let local_model = local_path.exists() && local_path.is_dir();
    let tokenizer = if local_model {
        // Load local tokenizer
        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
    } else {
        // Download and instantiate tokenizer
        // We need to download it outside of the Tokio runtime
        let params = FromPretrainedParameters {
            revision: revision.clone(),
            auth_token: authorization_token.clone(),
            ..Default::default()
        };
        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
    };

    // Launch Tokio runtime
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            init_logging(otlp_endpoint, json_output);
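            // `max_batch_size` is kept only for backward compatibility: when it is set,
            // derive an equivalent token budget and override `max_batch_total_tokens`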
            if let Some(max_batch_size) = max_batch_size {
                tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
                max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
                tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
            }

            if tokenizer.is_none() {
                tracing::warn!(
                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
                );
                tracing::warn!("Rust input length validation and truncation is disabled");
            }

            // Get Model info
            let model_info = match local_model {
                true => HubModelInfo {
                    model_id: tokenizer_name.clone(),
                    sha: None,
                    pipeline_tag: None,
                },
                false => get_model_info(&tokenizer_name, &revision, authorization_token).await,
            };

            // if pipeline-tag == text-generation we default to return_full_text = true
            let compat_return_full_text = match &model_info.pipeline_tag {
                None => {
                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
                    false
                }
                Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
            };

            // Instantiate sharded client from the master unix socket
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .expect("Could not connect to server");
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache(None)
                .await
                .expect("Unable to clear cache");
            // Get info from the shard
            let shard_info = sharded_client
                .info()
                .await
                .expect("Unable to get shard info");
            tracing::info!("Connected");
            // Binds on 0.0.0.0 so the webserver is reachable on all interfaces
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
            // Run server
            server::run(
                model_info,
                shard_info,
                compat_return_full_text,
                max_concurrent_requests,
                max_best_of,
                max_stop_sequences,
                max_input_length,
                max_total_tokens,
                waiting_served_ratio,
                max_batch_total_tokens,
                max_waiting_tokens,
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
                cors_allow_origin,
            )
            .await;
            Ok(())
        })
}
/// Init logging using the LOG_LEVEL env variable and the function arguments:
/// - otlp_endpoint is an optional URL to an OpenTelemetry collector
/// - json_output switches the log output from human-readable text to flattened JSON
/// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);
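    // Emit flattened JSON events when `json_output` is set; otherwise keep the
    // default human-readable text formatting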
    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
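        // Propagate the W3C trace context so spans can be correlated across services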
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        "text-generation-inference.router",
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);
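        // Only add the OpenTelemetry layer if the exporter pipeline was built successfully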
        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            axum_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}
/// get model info from the Huggingface Hub
pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo {
    let client = reqwest::Client::new();
    let mut builder = client.get(format!(
        "https://huggingface.co/api/models/{model_id}/revision/{revision}"
    ));
    if let Some(token) = token {
        builder = builder.bearer_auth(token);
    }
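    // Fetch the raw JSON body and deserialize it into HubModelInfo; any failure here
    // aborts router startup with the corresponding error message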
    let model_info = builder
        .send()
        .await
        .expect("Could not connect to hf.co")
        .text()
        .await
        .expect("error when retrieving model info from hf.co");

    serde_json::from_str(&model_info).expect("unable to parse model info")
}