// granite-docling ONNX Rust example with the `ort` crate
// Demonstrates how to use the granite-docling ONNX model in Rust applications
use anyhow::Result;
use ndarray::{Array2, Array4};
use ort::{
    execution_providers::{CPUExecutionProvider, CUDAExecutionProvider, DirectMLExecutionProvider},
    inputs,
    session::{builder::GraphOptimizationLevel, Session},
    value::TensorRef,
};
/// granite-docling ONNX inference engine
pub struct GraniteDoclingONNX {
session: Session,
}
impl GraniteDoclingONNX {
/// Load granite-docling ONNX model
pub fn new(model_path: &str) -> Result<Self> {
println!("Loading granite-docling ONNX model from: {}", model_path);
        let session = Session::builder()?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .with_execution_providers([
                DirectMLExecutionProvider::default().build(), // Windows ML acceleration
                CUDAExecutionProvider::default().build(),     // NVIDIA acceleration
                CPUExecutionProvider::default().build(),      // Universal fallback
            ])?
            .commit_from_file(model_path)?;
// Print model information
println!("Model loaded successfully:");
        for (i, input) in session.inputs.iter().enumerate() {
            println!("  Input {}: {} {:?}", i, input.name, input.input_type);
        }
        for (i, output) in session.outputs.iter().enumerate() {
            println!("  Output {}: {} {:?}", i, output.name, output.output_type);
        }
Ok(Self { session })
}
/// Process document image to DocTags markup
    pub async fn process_document(
        &mut self,
        document_image: Array4<f32>, // [batch, channels, height, width]
        prompt: &str,
    ) -> Result<String> {
println!("Processing document with granite-docling...");
        // Prepare text inputs (simplified tokenization)
        let input_ids = self.tokenize_prompt(prompt)?;
        let attention_mask: Array2<i64> = Array2::ones((1, input_ids.len()));
        // Convert token ids to the i64 [batch, sequence] layout the model expects
        let input_ids_2d = Array2::from_shape_vec(
            (1, input_ids.len()),
            input_ids.iter().map(|&x| x as i64).collect(),
        )?;
        // Run inference (a single forward pass; see the `generate_greedy`
        // sketch below for autoregressive decoding)
        let outputs = self.session.run(inputs![
            "pixel_values" => TensorRef::from_array_view(&document_image)?,
            "input_ids" => TensorRef::from_array_view(&input_ids_2d)?,
            "attention_mask" => TensorRef::from_array_view(&attention_mask)?,
        ])?;
        // Extract logits and decode to text
        let logits = outputs["logits"].try_extract_array::<f32>()?;
let tokens = self.decode_logits_to_tokens(&logits)?;
let doctags = self.detokenize_to_doctags(&tokens)?;
println!("✅ Document processing complete");
Ok(doctags)
}
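    /// Sketch of greedy autoregressive generation.
    ///
    /// The single forward pass in `process_document` only yields logits for the
    /// prompt positions; real DocTags generation is autoregressive: take the
    /// argmax at the last position, append it to `input_ids`, and run the model
    /// again until an end-of-sequence token appears. A minimal sketch, assuming
    /// the exported graph accepts a growing `input_ids` without KV-cache inputs,
    /// and assuming a caller-supplied `eos_token_id` (read the real value from
    /// the model's generation config).
    pub fn generate_greedy(
        &mut self,
        document_image: &Array4<f32>,
        mut input_ids: Vec<i64>,
        eos_token_id: i64,
        max_new_tokens: usize,
    ) -> Result<Vec<i64>> {
        for _ in 0..max_new_tokens {
            let ids = Array2::from_shape_vec((1, input_ids.len()), input_ids.clone())?;
            let mask: Array2<i64> = Array2::ones((1, input_ids.len()));
            let outputs = self.session.run(inputs![
                "pixel_values" => TensorRef::from_array_view(document_image)?,
                "input_ids" => TensorRef::from_array_view(&ids)?,
                "attention_mask" => TensorRef::from_array_view(&mask)?,
            ])?;
            let logits = outputs["logits"].try_extract_array::<f32>()?;
            // Argmax over the vocabulary at the last sequence position
            let seq = logits.index_axis(ndarray::Axis(0), 0);
            let last = seq.index_axis(ndarray::Axis(0), seq.shape()[0] - 1);
            let next = last
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx as i64)
                .unwrap_or(0);
            if next == eos_token_id {
                break;
            }
            input_ids.push(next);
        }
        Ok(input_ids)
    }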
    /// Simple placeholder tokenization (in practice, use a proper tokenizer)
    fn tokenize_prompt(&self, prompt: &str) -> Result<Vec<u32>> {
        // Simplified tokenization - in practice, load tokenizer.json and use
        // proper HuggingFace tokenization (see the `tokenize_with_hf` sketch below)
let tokens: Vec<u32> = prompt
.split_whitespace()
.enumerate()
.map(|(i, _)| (i + 1) as u32)
.collect();
Ok(tokens)
}
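    /// Sketch of real tokenization with the HuggingFace `tokenizers` crate.
    ///
    /// A minimal sketch, assuming `tokenizers` is added to Cargo.toml and that
    /// a `tokenizer.json` ships alongside the ONNX export. The real
    /// granite-docling prompt format may also require special/image tokens
    /// that this sketch does not insert.
    ///
    /// Usage (path is illustrative):
    ///   let tok = tokenizers::Tokenizer::from_file("granite-docling-258M-onnx/tokenizer.json")
    ///       .map_err(|e| anyhow::anyhow!("{e}"))?;
    ///   let ids = GraniteDoclingONNX::tokenize_with_hf(&tok, "Convert this document to DocTags:")?;
    fn tokenize_with_hf(tokenizer: &tokenizers::Tokenizer, prompt: &str) -> Result<Vec<u32>> {
        let encoding = tokenizer
            .encode(prompt, true) // true = add special tokens
            .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?;
        Ok(encoding.get_ids().to_vec())
    }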
    /// Decode logits to the most likely token at each position (greedy argmax)
    fn decode_logits_to_tokens(&self, logits: &ndarray::ArrayViewD<f32>) -> Result<Vec<u32>> {
        // Logits have shape [batch, sequence, vocab]: take batch 0, then find
        // the argmax over the vocabulary axis at each sequence position
        let tokens: Vec<u32> = logits
            .index_axis(ndarray::Axis(0), 0)
            .axis_iter(ndarray::Axis(0))
            .map(|position_logits| {
                position_logits
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .map(|(idx, _)| idx as u32)
                    .unwrap_or(0)
            })
            .collect();
        Ok(tokens)
    }
/// Convert tokens back to DocTags markup
fn detokenize_to_doctags(&self, tokens: &[u32]) -> Result<String> {
        // In practice, use the granite-docling tokenizer to convert tokens → text
        // (see the `detokenize_with_hf` sketch below), then parse that text as
        // DocTags markup. Simplified example:
let mock_doctags = format!(
"<doctag>\n <text>Document processed with {} tokens</text>\n</doctag>",
tokens.len()
);
Ok(mock_doctags)
}
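    /// Sketch of real detokenization with the HuggingFace `tokenizers` crate.
    ///
    /// A minimal sketch, assuming the same `tokenizers::Tokenizer` as in
    /// `tokenize_with_hf`. Whether special tokens should be skipped depends on
    /// how the DocTags tags themselves are encoded, so the `false` below is an
    /// assumption.
    fn detokenize_with_hf(tokenizer: &tokenizers::Tokenizer, tokens: &[u32]) -> Result<String> {
        tokenizer
            .decode(tokens, false) // false = keep special tokens in the output
            .map_err(|e| anyhow::anyhow!("detokenization failed: {e}"))
    }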
}
/// Preprocess a document image for granite-docling inference
pub fn preprocess_document_image(_image_path: &str) -> Result<Array4<f32>> {
    // Load the image, resize it to 512x512 (the SigLIP2 vision encoder's input
    // size), normalize with SigLIP2 parameters, and lay it out as
    // [batch, channels, height, width].
    // Simplified placeholder - see the `preprocess_with_image_crate` sketch below
    let document_image = Array4::zeros((1, 3, 512, 512));
    Ok(document_image)
}
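/// Sketch of real preprocessing with the `image` crate.
///
/// A minimal sketch, assuming the `image` crate is available and that SigLIP2
/// normalization uses mean 0.5 / std 0.5 per channel (verify against the
/// model's preprocessor_config.json). Resizes to 512x512 and lays the pixels
/// out as [batch, channels, height, width].
pub fn preprocess_with_image_crate(image_path: &str) -> Result<Array4<f32>> {
    let rgb = image::open(image_path)?
        .resize_exact(512, 512, image::imageops::FilterType::Triangle)
        .to_rgb8();
    let mut pixels = Array4::<f32>::zeros((1, 3, 512, 512));
    for (x, y, pixel) in rgb.enumerate_pixels() {
        for c in 0..3 {
            // Scale to [0, 1], then normalize: (value - mean) / std
            let value = pixel[c] as f32 / 255.0;
            pixels[[0, c, y as usize, x as usize]] = (value - 0.5) / 0.5;
        }
    }
    Ok(pixels)
}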
#[tokio::main]
async fn main() -> Result<()> {
println!("granite-docling ONNX Rust Example");
// Load granite-docling ONNX model
let model_path = "granite-docling-258M-onnx/model.onnx";
    let mut granite_docling = GraniteDoclingONNX::new(model_path)?;
// Preprocess document image
let document_image = preprocess_document_image("example_document.png")?;
// Process document
let prompt = "Convert this document to DocTags:";
let doctags = granite_docling.process_document(document_image, prompt).await?;
println!("Generated DocTags:");
println!("{}", doctags);
Ok(())
}
// Cargo.toml dependencies:
/*
[dependencies]
ort = { version = "2.0.0-rc.10", features = ["directml", "cuda", "tensorrt", "ndarray"] }
ndarray = "0.16"
anyhow = "1.0"
tokio = { version = "1.0", features = ["full"] }
*/
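// The tokenizer and image-preprocessing sketches above additionally assume
// (crate versions are illustrative):
/*
tokenizers = "0.21"
image = "0.25"
*/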