//! Text Generation Inference Webserver
mod health;
mod infer;
mod queue;
pub mod server;
mod validation;

use infer::Infer;
use queue::{Entry, Queue};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validation::Validation;

/// Hub type
#[derive(Clone, Debug, Deserialize)]
pub struct HubModelInfo {
    #[serde(rename(deserialize = "id"))]
    pub model_id: String,
    pub sha: Option<String>,
    pub pipeline_tag: Option<String>,
}

#[derive(Clone, Debug, Serialize, ToSchema)]
pub struct Info {
    /// Model info
    #[schema(example = "bigscience/bloom-560m")]
    pub model_id: String,
    #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
    pub model_sha: Option<String>,
    #[schema(example = "torch.float16")]
    pub model_dtype: String,
    #[schema(example = "cuda")]
    pub model_device_type: String,
    #[schema(nullable = true, example = "text-generation")]
    pub model_pipeline_tag: Option<String>,
    /// Router Parameters
    #[schema(example = "128")]
    pub max_concurrent_requests: usize,
    #[schema(example = "2")]
    pub max_best_of: usize,
    #[schema(example = "4")]
    pub max_stop_sequences: usize,
    #[schema(example = "1024")]
    pub max_input_length: usize,
    #[schema(example = "2048")]
    pub max_total_tokens: usize,
    #[schema(example = "1.2")]
    pub waiting_served_ratio: f32,
    #[schema(example = "32000")]
    pub max_batch_total_tokens: u32,
    #[schema(example = "20")]
    pub max_waiting_tokens: usize,
    #[schema(example = "2")]
    pub validation_workers: usize,
    /// Router Info
    #[schema(example = "0.5.0")]
    pub version: &'static str,
    #[schema(nullable = true, example = "null")]
    pub sha: Option<&'static str>,
    #[schema(nullable = true, example = "null")]
    pub docker_label: Option<&'static str>,
}

/// User-supplied generation parameters.
#[derive(Clone, Debug, Deserialize, ToSchema)]
pub(crate) struct GenerateParameters {
    #[serde(default)]
    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 1)]
    pub best_of: Option<usize>,
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0.0,
        nullable = true,
        default = "null",
        example = 0.5
    )]
    pub temperature: Option<f32>,
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0.0,
        nullable = true,
        default = "null",
        example = 1.03
    )]
    pub repetition_penalty: Option<f32>,
    #[serde(default)]
    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
    pub top_k: Option<i32>,
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0.0,
        maximum = 1.0,
        nullable = true,
        default = "null",
        example = 0.95
    )]
    pub top_p: Option<f32>,
    #[serde(default)]
    #[schema(
        exclusive_minimum = 0.0,
        maximum = 1.0,
        nullable = true,
        default = "null",
        example = 0.95
    )]
    pub typical_p: Option<f32>,
    #[serde(default)]
    #[schema(default = "false", example = true)]
    pub do_sample: bool,
    #[serde(default = "default_max_new_tokens")]
    #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
    pub max_new_tokens: u32,
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = false)]
    pub return_full_text: Option<bool>,
(["photographer"]))] pub stop: Vec, #[serde(default)] #[schema(nullable = true, default = "null", example = "null")] pub truncate: Option, #[serde(default)] #[schema(default = "false", example = true)] pub watermark: bool, #[serde(default)] #[schema(default = "true")] pub details: bool, #[serde(default)] #[schema( exclusive_minimum = 0, nullable = true, default = "null", example = "null" )] pub seed: Option, } fn default_max_new_tokens() -> u32 { 20 } fn default_parameters() -> GenerateParameters { GenerateParameters { best_of: None, temperature: None, repetition_penalty: None, top_k: None, top_p: None, typical_p: None, do_sample: false, max_new_tokens: default_max_new_tokens(), return_full_text: None, stop: Vec::new(), truncate: None, watermark: false, details: false, seed: None, } } #[derive(Clone, Debug, Deserialize, ToSchema)] pub(crate) struct GenerateRequest { #[schema(example = "My name is Olivier and I")] pub inputs: String, #[serde(default = "default_parameters")] pub parameters: GenerateParameters, } #[derive(Clone, Debug, Deserialize, ToSchema)] pub(crate) struct CompatGenerateRequest { #[schema(example = "My name is Olivier and I")] pub inputs: String, #[serde(default = "default_parameters")] pub parameters: GenerateParameters, #[serde(default)] #[allow(dead_code)] pub stream: bool, } impl From for GenerateRequest { fn from(req: CompatGenerateRequest) -> Self { Self { inputs: req.inputs, parameters: req.parameters, } } } #[derive(Debug, Serialize, ToSchema)] pub struct PrefillToken { #[schema(example = 0)] id: u32, #[schema(example = "test")] text: String, #[schema(nullable = true, example = - 0.34)] logprob: f32, } #[derive(Debug, Serialize, ToSchema)] pub struct Token { #[schema(example = 0)] id: u32, #[schema(example = "test")] text: String, #[schema(nullable = true, example = - 0.34)] logprob: f32, #[schema(example = "false")] special: bool, } #[derive(Serialize, ToSchema)] #[serde(rename_all(serialize = "snake_case"))] pub(crate) enum FinishReason { #[schema(rename = "length")] Length, #[serde(rename = "eos_token")] #[schema(rename = "eos_token")] EndOfSequenceToken, #[schema(rename = "stop_sequence")] StopSequence, } #[derive(Serialize, ToSchema)] pub(crate) struct BestOfSequence { #[schema(example = "test")] pub generated_text: String, #[schema(example = "length")] pub finish_reason: FinishReason, #[schema(example = 1)] pub generated_tokens: u32, #[schema(nullable = true, example = 42)] pub seed: Option, pub prefill: Vec, pub tokens: Vec, } #[derive(Serialize, ToSchema)] pub(crate) struct Details { #[schema(example = "length")] pub finish_reason: FinishReason, #[schema(example = 1)] pub generated_tokens: u32, #[schema(nullable = true, example = 42)] pub seed: Option, pub prefill: Vec, pub tokens: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub best_of_sequences: Option>, } #[derive(Serialize, ToSchema)] pub(crate) struct GenerateResponse { #[schema(example = "test")] pub generated_text: String, #[serde(skip_serializing_if = "Option::is_none")] pub details: Option
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamDetails {
    #[schema(example = "length")]
    pub finish_reason: FinishReason,
    #[schema(example = 1)]
    pub generated_tokens: u32,
    #[schema(nullable = true, example = 42)]
    pub seed: Option<u64>,
}

/// Streaming generation response (one event per generated token).
#[derive(Serialize, ToSchema)]
pub(crate) struct StreamResponse {
    pub token: Token,
    #[schema(nullable = true, default = "null", example = "test")]
    pub generated_text: Option<String>,
    #[schema(nullable = true, default = "null")]
    pub details: Option<StreamDetails>,
}

/// Error payload returned by the router.
#[derive(Serialize, ToSchema)]
pub(crate) struct ErrorResponse {
    pub error: String,
    pub error_type: String,
}

#[cfg(test)]
mod tests {
    use std::io::Write;
    use tokenizers::Tokenizer;

    /// Download the GPT-2 tokenizer once and cache it on disk for the test suite.
    pub(crate) async fn get_tokenizer() -> Tokenizer {
        if !std::path::Path::new("tokenizer.json").exists() {
            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
                .await
                .unwrap()
                .bytes()
                .await
                .unwrap();
            let mut file = std::fs::File::create("tokenizer.json").unwrap();
            file.write_all(&content).unwrap();
        }
        Tokenizer::from_file("tokenizer.json").unwrap()
    }
}
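
// Illustrative sketch (not part of the upstream test suite): the module below shows how the
// `#[serde(default)]` / `default_parameters` wiring above is expected to behave when a client
// omits `parameters`. It assumes `serde_json` is available to the crate; the module and test
// names are hypothetical.
#[cfg(test)]
mod request_defaults_example {
    use super::{default_max_new_tokens, GenerateRequest};

    #[test]
    fn omitted_parameters_fall_back_to_defaults() {
        // Only `inputs` is provided; `parameters` should be filled in by `default_parameters()`.
        let request: GenerateRequest =
            serde_json::from_str(r#"{"inputs": "My name is Olivier and I"}"#).unwrap();

        assert_eq!(request.parameters.max_new_tokens, default_max_new_tokens());
        assert!(!request.parameters.do_sample);
        assert!(request.parameters.temperature.is_none());
        assert!(request.parameters.stop.is_empty());
    }
}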