Spaces:
Build error
Build error
| use anyhow::Result; | |
| use clap::ValueEnum; | |
| use half::f16; | |
| use ndarray::{Array, CowArray, IxDyn}; | |
| use ort::{ | |
| CPUExecutionProvider, CUDAExecutionProvider, ExecutionProvider, ExecutionProviderDispatch, | |
| TensorRTExecutionProvider, | |
| }; | |
| use ort::{Session, SessionBuilder}; | |
| use ort::{TensorElementType, ValueType}; | |
| use regex::Regex; | |
/// Task a YOLO model is used for.
///
/// `Clone` is needed by `OrtBackend::task`, which returns a copy of the stored task.
// NOTE(review): the file imports `clap::ValueEnum` but nothing visible uses it; this enum
// was presumably meant to derive it for CLI parsing — confirm and add it if so.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum YOLOTask {
    /// Selected by the metadata value `"classify"`.
    Classify,
    /// Selected by the metadata value `"detect"`.
    Detect,
    /// Selected by the metadata value `"pose"`.
    Pose,
    /// Selected by the metadata value `"segment"`.
    Segment,
}
/// ONNX Runtime execution provider.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OrtEP {
    /// Run on the CPU.
    CPU,
    /// Run with CUDA on the device with the given id.
    CUDA(i32),
    /// Run with TensorRT on the device with the given id.
    Trt(i32),
}
/// Batch-size profile of a model.
///
/// All three values are written into the TensorRT optimisation-profile strings;
/// `opt` is also the value checked against a model with a fixed batch dimension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Batch {
    /// Optimal batch size.
    pub opt: u32,
    /// Minimum batch size.
    pub min: u32,
    /// Maximum batch size.
    pub max: u32,
}
| impl Default for Batch { | |
| fn default() -> Self { | |
| Self { | |
| opt: 1, | |
| min: 1, | |
| max: 1, | |
| } | |
| } | |
| } | |
| pub struct OrtInputs { | |
| // ONNX model inputs attrs | |
| pub shapes: Vec<Vec<i64>>, | |
| //pub dtypes: Vec<TensorElementDataType>, | |
| pub dtypes: Vec<TensorElementType>, | |
| pub names: Vec<String>, | |
| pub sizes: Vec<Vec<u32>>, | |
| } | |
| impl OrtInputs { | |
| pub fn new(session: &Session) -> Self { | |
| let mut shapes = Vec::new(); | |
| let mut dtypes = Vec::new(); | |
| let mut names = Vec::new(); | |
| for i in session.inputs.iter() { | |
| /* let shape: Vec<i32> = i | |
| .dimensions() | |
| .map(|x| if let Some(x) = x { x as i32 } else { -1i32 }) | |
| .collect(); | |
| shapes.push(shape); */ | |
| if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type { | |
| dtypes.push(ty.clone()); | |
| let shape = dimensions.clone(); | |
| shapes.push(shape); | |
| } else { | |
| panic!("不支持的数据格式, {} - {}", file!(), line!()); | |
| } | |
| //dtypes.push(i.input_type); | |
| names.push(i.name.clone()); | |
| } | |
| Self { | |
| shapes, | |
| dtypes, | |
| names, | |
| ..Default::default() | |
| } | |
| } | |
| } | |
| pub struct OrtConfig { | |
| // ORT config | |
| pub f: String, | |
| pub task: Option<YOLOTask>, | |
| pub ep: OrtEP, | |
| pub trt_fp16: bool, | |
| pub batch: Batch, | |
| pub image_size: (Option<u32>, Option<u32>), | |
| } | |
| pub struct OrtBackend { | |
| // ORT engine | |
| session: Session, | |
| task: YOLOTask, | |
| ep: OrtEP, | |
| batch: Batch, | |
| inputs: OrtInputs, | |
| } | |
| impl OrtBackend { | |
| pub fn build(args: OrtConfig) -> Result<Self> { | |
| // build env & session | |
| // in version 2.x environment is removed | |
| /* let env = ort::EnvironmentBuilder | |
| ::with_name("YOLOv8") | |
| .build()? | |
| .into_arc(); */ | |
| let sessionbuilder = SessionBuilder::new()?; | |
| let session = sessionbuilder.commit_from_file(&args.f)?; | |
| //let session = SessionBuilder::new(&env)?.with_model_from_file(&args.f)?; | |
| // get inputs | |
| let mut inputs = OrtInputs::new(&session); | |
| // batch size | |
| let mut batch = args.batch; | |
| let batch = if inputs.shapes[0][0] == -1 { | |
| batch | |
| } else { | |
| assert_eq!( | |
| inputs.shapes[0][0] as u32, batch.opt, | |
| "Expected batch size: {}, got {}. Try using `--batch {}`.", | |
| inputs.shapes[0][0] as u32, batch.opt, inputs.shapes[0][0] as u32 | |
| ); | |
| batch.opt = inputs.shapes[0][0] as u32; | |
| batch | |
| }; | |
| // input size: height and width | |
| let height = if inputs.shapes[0][2] == -1 { | |
| match args.image_size.0 { | |
| Some(height) => height, | |
| None => panic!("Failed to get model height. Make it explicit with `--height`"), | |
| } | |
| } else { | |
| inputs.shapes[0][2] as u32 | |
| }; | |
| let width = if inputs.shapes[0][3] == -1 { | |
| match args.image_size.1 { | |
| Some(width) => width, | |
| None => panic!("Failed to get model width. Make it explicit with `--width`"), | |
| } | |
| } else { | |
| inputs.shapes[0][3] as u32 | |
| }; | |
| inputs.sizes.push(vec![height, width]); | |
| // build provider | |
| let (ep, provider) = match args.ep { | |
| OrtEP::CUDA(device_id) => Self::set_ep_cuda(device_id), | |
| OrtEP::Trt(device_id) => Self::set_ep_trt(device_id, args.trt_fp16, &batch, &inputs), | |
| _ => ( | |
| OrtEP::CPU, | |
| ExecutionProviderDispatch::from(CPUExecutionProvider::default()), | |
| ), | |
| }; | |
| // build session again with the new provider | |
| let session = SessionBuilder::new()? | |
| // .with_optimization_level(ort::GraphOptimizationLevel::Level3)? | |
| .with_execution_providers([provider])? | |
| .commit_from_file(args.f)?; | |
| // task: using given one or guessing | |
| let task = match args.task { | |
| Some(task) => task, | |
| None => match session.metadata() { | |
| Err(_) => panic!("No metadata found. Try making it explicit by `--task`"), | |
| Ok(metadata) => match metadata.custom("task") { | |
| Err(_) => panic!("Can not get custom value. Try making it explicit by `--task`"), | |
| Ok(value) => match value { | |
| None => panic!("No corresponding value of `task` found in metadata. Make it explicit by `--task`"), | |
| Some(task) => match task.as_str() { | |
| "classify" => YOLOTask::Classify, | |
| "detect" => YOLOTask::Detect, | |
| "pose" => YOLOTask::Pose, | |
| "segment" => YOLOTask::Segment, | |
| x => todo!("{:?} is not supported for now!", x), | |
| }, | |
| }, | |
| }, | |
| }, | |
| }; | |
| Ok(Self { | |
| session, | |
| task, | |
| ep, | |
| batch, | |
| inputs, | |
| }) | |
| } | |
| pub fn fetch_inputs_from_session( | |
| session: &Session, | |
| ) -> (Vec<Vec<i64>>, Vec<TensorElementType>, Vec<String>) { | |
| // get inputs attrs from ONNX model | |
| let mut shapes = Vec::new(); | |
| let mut dtypes = Vec::new(); | |
| let mut names = Vec::new(); | |
| for i in session.inputs.iter() { | |
| if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type { | |
| dtypes.push(ty.clone()); | |
| let shape = dimensions.clone(); | |
| shapes.push(shape); | |
| } else { | |
| panic!("不支持的数据格式, {} - {}", file!(), line!()); | |
| } | |
| names.push(i.name.clone()); | |
| } | |
| (shapes, dtypes, names) | |
| } | |
| pub fn set_ep_cuda(device_id: i32) -> (OrtEP, ExecutionProviderDispatch) { | |
| let cuda_provider = CUDAExecutionProvider::default().with_device_id(device_id); | |
| if let Ok(true) = cuda_provider.is_available() { | |
| ( | |
| OrtEP::CUDA(device_id), | |
| ExecutionProviderDispatch::from(cuda_provider), //PlantForm::CUDA(cuda_provider) | |
| ) | |
| } else { | |
| println!("> CUDA is not available! Using CPU."); | |
| ( | |
| OrtEP::CPU, | |
| ExecutionProviderDispatch::from(CPUExecutionProvider::default()), //PlantForm::CPU(CPUExecutionProvider::default()) | |
| ) | |
| } | |
| } | |
| pub fn set_ep_trt( | |
| device_id: i32, | |
| fp16: bool, | |
| batch: &Batch, | |
| inputs: &OrtInputs, | |
| ) -> (OrtEP, ExecutionProviderDispatch) { | |
| // set TensorRT | |
| let trt_provider = TensorRTExecutionProvider::default().with_device_id(device_id); | |
| //trt_provider. | |
| if let Ok(true) = trt_provider.is_available() { | |
| let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]); | |
| if inputs.dtypes[0] == TensorElementType::Float16 && !fp16 { | |
| panic!( | |
| "Dtype mismatch! Expected: Float32, got: {:?}. You should use `--fp16`", | |
| inputs.dtypes[0] | |
| ); | |
| } | |
| // dynamic shape: input_tensor_1:dim_1xdim_2x...,input_tensor_2:dim_3xdim_4x...,... | |
| let mut opt_string = String::new(); | |
| let mut min_string = String::new(); | |
| let mut max_string = String::new(); | |
| for name in inputs.names.iter() { | |
| let s_opt = format!("{}:{}x3x{}x{},", name, batch.opt, height, width); | |
| let s_min = format!("{}:{}x3x{}x{},", name, batch.min, height, width); | |
| let s_max = format!("{}:{}x3x{}x{},", name, batch.max, height, width); | |
| opt_string.push_str(s_opt.as_str()); | |
| min_string.push_str(s_min.as_str()); | |
| max_string.push_str(s_max.as_str()); | |
| } | |
| let _ = opt_string.pop(); | |
| let _ = min_string.pop(); | |
| let _ = max_string.pop(); | |
| let trt_provider = trt_provider | |
| .with_profile_opt_shapes(opt_string) | |
| .with_profile_min_shapes(min_string) | |
| .with_profile_max_shapes(max_string) | |
| .with_fp16(fp16) | |
| .with_timing_cache(true); | |
| ( | |
| OrtEP::Trt(device_id), | |
| ExecutionProviderDispatch::from(trt_provider), | |
| ) | |
| } else { | |
| println!("> TensorRT is not available! Try using CUDA..."); | |
| Self::set_ep_cuda(device_id) | |
| } | |
| } | |
| pub fn fetch_from_metadata(&self, key: &str) -> Option<String> { | |
| // fetch value from onnx model file by key | |
| match self.session.metadata() { | |
| Err(_) => None, | |
| Ok(metadata) => match metadata.custom(key) { | |
| Err(_) => None, | |
| Ok(value) => value, | |
| }, | |
| } | |
| } | |
| pub fn run(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> { | |
| // ORT inference | |
| match self.dtype() { | |
| TensorElementType::Float16 => self.run_fp16(xs, profile), | |
| TensorElementType::Float32 => self.run_fp32(xs, profile), | |
| _ => todo!(), | |
| } | |
| } | |
| pub fn run_fp16(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> { | |
| // f32->f16 | |
| let t = std::time::Instant::now(); | |
| let xs = xs.mapv(f16::from_f32); | |
| if profile { | |
| println!("[ORT f32->f16]: {:?}", t.elapsed()); | |
| } | |
| // h2d | |
| let t = std::time::Instant::now(); | |
| let xs = CowArray::from(xs); | |
| if profile { | |
| println!("[ORT H2D]: {:?}", t.elapsed()); | |
| } | |
| // run | |
| let t = std::time::Instant::now(); | |
| let ys = self.session.run(ort::inputs![xs.view()]?)?; | |
| if profile { | |
| println!("[ORT Inference]: {:?}", t.elapsed()); | |
| } | |
| // d2h | |
| Ok(ys | |
| .iter() | |
| .map(|(_k, v)| { | |
| // d2h | |
| let t = std::time::Instant::now(); | |
| let v = v.try_extract_tensor().unwrap(); | |
| //let v = v.try_extract::<_>().unwrap().view().clone().into_owned(); | |
| if profile { | |
| println!("[ORT D2H]: {:?}", t.elapsed()); | |
| } | |
| // f16->f32 | |
| let t_ = std::time::Instant::now(); | |
| let v = v.mapv(f16::to_f32); | |
| if profile { | |
| println!("[ORT f16->f32]: {:?}", t_.elapsed()); | |
| } | |
| v | |
| }) | |
| .collect::<Vec<Array<_, _>>>()) | |
| } | |
| pub fn run_fp32(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> { | |
| // h2d | |
| let t = std::time::Instant::now(); | |
| let xs = CowArray::from(xs); | |
| if profile { | |
| println!("[ORT H2D]: {:?}", t.elapsed()); | |
| } | |
| // run | |
| let t = std::time::Instant::now(); | |
| let ys = self.session.run(ort::inputs![xs.view()]?)?; | |
| if profile { | |
| println!("[ORT Inference]: {:?}", t.elapsed()); | |
| } | |
| // d2h | |
| Ok(ys | |
| .iter() | |
| .map(|(_k, v)| { | |
| let t = std::time::Instant::now(); | |
| let v = v.try_extract_tensor::<f32>().unwrap().into_owned(); | |
| //let x = x.try_extract::<_>().unwrap().view().clone().into_owned(); | |
| if profile { | |
| println!("[ORT D2H]: {:?}", t.elapsed()); | |
| } | |
| v | |
| }) | |
| .collect::<Vec<Array<_, _>>>()) | |
| } | |
| pub fn output_shapes(&self) -> Vec<Vec<i64>> { | |
| let mut shapes = Vec::new(); | |
| for output in &self.session.outputs { | |
| if let ValueType::Tensor { ty: _, dimensions } = &output.output_type { | |
| let shape = dimensions.clone(); | |
| shapes.push(shape); | |
| } else { | |
| panic!("not support data format, {} - {}", file!(), line!()); | |
| } | |
| } | |
| shapes | |
| } | |
| pub fn output_dtypes(&self) -> Vec<TensorElementType> { | |
| let mut dtypes = Vec::new(); | |
| for output in &self.session.outputs { | |
| if let ValueType::Tensor { ty, dimensions: _ } = &output.output_type { | |
| dtypes.push(ty.clone()); | |
| } else { | |
| panic!("not support data format, {} - {}", file!(), line!()); | |
| } | |
| } | |
| dtypes | |
| } | |
| pub fn input_shapes(&self) -> &Vec<Vec<i64>> { | |
| &self.inputs.shapes | |
| } | |
| pub fn input_names(&self) -> &Vec<String> { | |
| &self.inputs.names | |
| } | |
| pub fn input_dtypes(&self) -> &Vec<TensorElementType> { | |
| &self.inputs.dtypes | |
| } | |
| pub fn dtype(&self) -> TensorElementType { | |
| self.input_dtypes()[0] | |
| } | |
| pub fn height(&self) -> u32 { | |
| self.inputs.sizes[0][0] | |
| } | |
| pub fn width(&self) -> u32 { | |
| self.inputs.sizes[0][1] | |
| } | |
| pub fn is_height_dynamic(&self) -> bool { | |
| self.input_shapes()[0][2] == -1 | |
| } | |
| pub fn is_width_dynamic(&self) -> bool { | |
| self.input_shapes()[0][3] == -1 | |
| } | |
| pub fn batch(&self) -> u32 { | |
| self.batch.opt | |
| } | |
| pub fn is_batch_dynamic(&self) -> bool { | |
| self.input_shapes()[0][0] == -1 | |
| } | |
| pub fn ep(&self) -> &OrtEP { | |
| &self.ep | |
| } | |
| pub fn task(&self) -> YOLOTask { | |
| self.task.clone() | |
| } | |
| pub fn names(&self) -> Option<Vec<String>> { | |
| // class names, metadata parsing | |
| // String format: `{0: 'person', 1: 'bicycle', 2: 'sports ball', ..., 27: "yellow_lady's_slipper"}` | |
| match self.fetch_from_metadata("names") { | |
| Some(names) => { | |
| let re = Regex::new(r#"(['"])([-()\w '"]+)(['"])"#).unwrap(); | |
| let mut names_ = vec![]; | |
| for (_, [_, name, _]) in re.captures_iter(&names).map(|x| x.extract()) { | |
| names_.push(name.to_string()); | |
| } | |
| Some(names_) | |
| } | |
| None => None, | |
| } | |
| } | |
| pub fn nk(&self) -> Option<u32> { | |
| // num_keypoints, metadata parsing: String `nk` in onnx model: `[17, 3]` | |
| match self.fetch_from_metadata("kpt_shape") { | |
| None => None, | |
| Some(kpt_string) => { | |
| let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap(); | |
| let caps = re.captures(&kpt_string).unwrap(); | |
| Some(caps.get(1).unwrap().as_str().parse::<u32>().unwrap()) | |
| } | |
| } | |
| } | |
| pub fn nc(&self) -> Option<u32> { | |
| // num_classes | |
| match self.names() { | |
| // by names | |
| Some(names) => Some(names.len() as u32), | |
| None => match self.task() { | |
| // by task calculation | |
| YOLOTask::Classify => Some(self.output_shapes()[0][1] as u32), | |
| YOLOTask::Detect => { | |
| if self.output_shapes()[0][1] == -1 { | |
| None | |
| } else { | |
| // cxywhclss | |
| Some(self.output_shapes()[0][1] as u32 - 4) | |
| } | |
| } | |
| YOLOTask::Pose => { | |
| match self.nk() { | |
| None => None, | |
| Some(nk) => { | |
| if self.output_shapes()[0][1] == -1 { | |
| None | |
| } else { | |
| // cxywhclss3*kpt | |
| Some(self.output_shapes()[0][1] as u32 - 4 - 3 * nk) | |
| } | |
| } | |
| } | |
| } | |
| YOLOTask::Segment => { | |
| if self.output_shapes()[0][1] == -1 { | |
| None | |
| } else { | |
| // cxywhclssnm | |
| Some((self.output_shapes()[0][1] - self.output_shapes()[1][1]) as u32 - 4) | |
| } | |
| } | |
| }, | |
| } | |
| } | |
| pub fn nm(&self) -> Option<u32> { | |
| // num_masks | |
| match self.task() { | |
| YOLOTask::Segment => Some(self.output_shapes()[1][1] as u32), | |
| _ => None, | |
| } | |
| } | |
| pub fn na(&self) -> Option<u32> { | |
| // num_anchors | |
| match self.task() { | |
| YOLOTask::Segment | YOLOTask::Detect | YOLOTask::Pose => { | |
| if self.output_shapes()[0][2] == -1 { | |
| None | |
| } else { | |
| Some(self.output_shapes()[0][2] as u32) | |
| } | |
| } | |
| _ => None, | |
| } | |
| } | |
| pub fn author(&self) -> Option<String> { | |
| self.fetch_from_metadata("author") | |
| } | |
| pub fn version(&self) -> Option<String> { | |
| self.fetch_from_metadata("version") | |
| } | |
| } | |