/// Text Generation Inference webserver entrypoint
use axum::http::HeaderValue;
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
use opentelemetry::sdk::trace::Sampler;
use opentelemetry::sdk::Resource;
use opentelemetry::{global, KeyValue};
use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use text_generation_client::ShardedClient;
use text_generation_router::{server, HubModelInfo};
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    max_concurrent_requests: usize,
    max_best_of: usize,
    max_stop_sequences: usize,
    max_input_length: usize,
    max_total_tokens: usize,
    max_batch_size: Option<usize>,
    waiting_served_ratio: f32,
    max_batch_total_tokens: u32,
    max_waiting_tokens: usize,
    port: u16,
    master_shard_uds_path: String,
    tokenizer_name: String,
    revision: String,
    validation_workers: usize,
    json_output: bool,
    otlp_endpoint: Option<String>,
    cors_allow_origin: Option<Vec<String>>,
}
fn main() -> Result<(), std::io::Error> {
    // Get args
    let args = Args::parse();

    // Pattern match configuration
    let Args {
        max_concurrent_requests,
        max_best_of,
        max_stop_sequences,
        max_input_length,
        max_total_tokens,
        max_batch_size,
        waiting_served_ratio,
        mut max_batch_total_tokens,
        max_waiting_tokens,
        port,
        master_shard_uds_path,
        tokenizer_name,
        revision,
        validation_workers,
        json_output,
        otlp_endpoint,
        cors_allow_origin,
    } = args;

    if validation_workers == 0 {
        panic!("validation_workers must be > 0");
    }
    // CORS allowed origins
    // Map inside the Option, parse each String into a HeaderValue,
    // then convert the whole list into an AllowOrigin
    let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
        AllowOrigin::list(
            cors_allow_origin
                .iter()
                .map(|origin| origin.parse::<HeaderValue>().unwrap()),
        )
    });
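    // For illustration (hypothetical value): Some(vec!["https://ui.example.com".to_string()])
    // becomes AllowOrigin::list([...]); any entry that is not a valid HeaderValue makes the
    // `unwrap()` above panic at startup.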
    // Parse Huggingface hub token
    let authorization_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();

    // Tokenizer instance
    // This will only be used to validate payloads
    let local_path = Path::new(&tokenizer_name);
    let local_model = local_path.exists() && local_path.is_dir();
    let tokenizer = if local_model {
        // Load local tokenizer
        Tokenizer::from_file(local_path.join("tokenizer.json")).ok()
    } else {
        // Download and instantiate tokenizer
        // We need to download it outside of the Tokio runtime
        let params = FromPretrainedParameters {
            revision: revision.clone(),
            auth_token: authorization_token.clone(),
            ..Default::default()
        };
        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).ok()
    };

    // Launch Tokio runtime
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            init_logging(otlp_endpoint, json_output);
            if let Some(max_batch_size) = max_batch_size {
                tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
                max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
                tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
            }
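            // Worked example with hypothetical values: max_batch_size = 32 together with
            // max_total_tokens = 1512 is rewritten to max_batch_total_tokens = 32 * 1512 = 48384.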
            if tokenizer.is_none() {
                tracing::warn!(
                    "Could not find a fast tokenizer implementation for {tokenizer_name}"
                );
                tracing::warn!("Rust input length validation and truncation is disabled");
            }

            // Get Model info
            let model_info = match local_model {
                true => HubModelInfo {
                    model_id: tokenizer_name.clone(),
                    sha: None,
                    pipeline_tag: None,
                },
                false => get_model_info(&tokenizer_name, &revision, authorization_token).await,
            };

            // if pipeline-tag == text-generation we default to return_full_text = true
            let compat_return_full_text = match &model_info.pipeline_tag {
                None => {
                    tracing::warn!("no pipeline tag found for model {tokenizer_name}");
                    false
                }
                Some(pipeline_tag) => pipeline_tag.as_str() == "text-generation",
            };
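            // For example (illustrative tags): Some("text-generation") yields `true`, while
            // Some("text2text-generation") or a missing pipeline tag yields `false`.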
            // Instantiate sharded client from the master unix socket
            let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                .await
                .expect("Could not connect to server");
            // Clear the cache; useful if the webserver rebooted
            sharded_client
                .clear_cache(None)
                .await
                .expect("Unable to clear cache");
            // Get info from the shard
            let shard_info = sharded_client
                .info()
                .await
                .expect("Unable to get shard info");
            tracing::info!("Connected");
            // Binds on 0.0.0.0 so the server is reachable on all interfaces
            let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);

            // Run server
            server::run(
                model_info,
                shard_info,
                compat_return_full_text,
                max_concurrent_requests,
                max_best_of,
                max_stop_sequences,
                max_input_length,
                max_total_tokens,
                waiting_served_ratio,
                max_batch_total_tokens,
                max_waiting_tokens,
                sharded_client,
                tokenizer,
                validation_workers,
                addr,
                cors_allow_origin,
            )
            .await;
            Ok(())
        })
}
/// Init logging using the LOG_LEVEL env variable and the `json_output` flag:
///     - otlp_endpoint is an optional URL to an OpenTelemetry collector
///     - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (defaults to INFO)
///     - json_output switches the stdout layer from human-readable text to flattened JSON
fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
    let mut layers = Vec::new();

    // STDOUT/STDERR layer
    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_file(true)
        .with_line_number(true);

    let fmt_layer = match json_output {
        true => fmt_layer.json().flatten_event(true).boxed(),
        false => fmt_layer.boxed(),
    };
    layers.push(fmt_layer);

    // OpenTelemetry tracing layer
    if let Some(otlp_endpoint) = otlp_endpoint {
        global::set_text_map_propagator(TraceContextPropagator::new());

        let tracer = opentelemetry_otlp::new_pipeline()
            .tracing()
            .with_exporter(
                opentelemetry_otlp::new_exporter()
                    .tonic()
                    .with_endpoint(otlp_endpoint),
            )
            .with_trace_config(
                trace::config()
                    .with_resource(Resource::new(vec![KeyValue::new(
                        "service.name",
                        "text-generation-inference.router",
                    )]))
                    .with_sampler(Sampler::AlwaysOn),
            )
            .install_batch(opentelemetry::runtime::Tokio);

        if let Ok(tracer) = tracer {
            layers.push(tracing_opentelemetry::layer().with_tracer(tracer).boxed());
            axum_tracing_opentelemetry::init_propagator().unwrap();
        };
    }

    // Filter events with LOG_LEVEL
    let env_filter =
        EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(layers)
        .init();
}
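// LOG_LEVEL accepts standard tracing_subscriber EnvFilter directives; for example
// (hypothetical settings) LOG_LEVEL=debug enables everything at DEBUG and above, while
// LOG_LEVEL=info,text_generation_router=debug keeps INFO globally but DEBUG for this crate.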
/// Get model info from the Hugging Face Hub
pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo {
    let client = reqwest::Client::new();

    let mut builder = client.get(format!(
        "https://huggingface.co/api/models/{model_id}/revision/{revision}"
    ));
    if let Some(token) = token {
        builder = builder.bearer_auth(token);
    }

    let model_info = builder
        .send()
        .await
        .expect("Could not connect to hf.co")
        .text()
        .await
        .expect("error when retrieving model info from hf.co");
    serde_json::from_str(&model_info).expect("unable to parse model info")
}