// granite-docling ONNX Rust example using the `ort` crate
// Demonstrates how to run the granite-docling ONNX model from Rust applications

use anyhow::Result;
use ndarray::{Array2, Array4};
use ort::{
    execution_providers::{CPUExecutionProvider, CUDAExecutionProvider, DirectMLExecutionProvider},
    inputs,
    session::{builder::GraphOptimizationLevel, Session},
    value::TensorRef,
};

/// granite-docling ONNX inference engine
pub struct GraniteDoclingONNX {
    session: Session,
}

impl GraniteDoclingONNX {
    /// Load granite-docling ONNX model
    pub fn new(model_path: &str) -> Result<Self> {
        println!("Loading granite-docling ONNX model from: {}", model_path);

        // ort 2.0 registers execution providers as built provider configs;
        // providers unavailable at runtime are skipped in order.
        let session = Session::builder()?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .with_execution_providers([
                DirectMLExecutionProvider::default().build(), // Windows ML acceleration
                CUDAExecutionProvider::default().build(),     // NVIDIA acceleration
                CPUExecutionProvider::default().build(),      // Universal fallback
            ])?
            .commit_from_file(model_path)?;

        // Print model information (`inputs`/`outputs` are fields on Session)
        println!("Model loaded successfully:");
        for (i, input) in session.inputs.iter().enumerate() {
            println!("  Input {}: {} {:?}", i, input.name, input.input_type);
        }
        for (i, output) in session.outputs.iter().enumerate() {
            println!("  Output {}: {} {:?}", i, output.name, output.output_type);
        }

        Ok(Self { session })
    }

    /// Process a document image into DocTags markup
    pub async fn process_document(
        &mut self,
        document_image: Array4<f32>, // [batch, channels, height, width]
        prompt: &str,
    ) -> Result<String> {
        println!("Processing document with granite-docling...");

        // Prepare text inputs (simplified tokenization; see the HF sketch below)
        let input_ids = self.tokenize_prompt(prompt)?;
        let attention_mask = Array2::<i64>::ones((1, input_ids.len()));

        // Convert token ids to the i64 [batch, seq_len] layout the model expects
        let input_ids_2d = Array2::from_shape_vec(
            (1, input_ids.len()),
            input_ids.iter().map(|&x| x as i64).collect(),
        )?;

        // Run inference, copying the logits out so the session borrow ends
        // before further methods are called on self (`try_extract_array`
        // requires the `ndarray` feature of ort)
        let logits = {
            let outputs = self.session.run(inputs![
                "pixel_values" => TensorRef::from_array_view(&document_image)?,
                "input_ids" => TensorRef::from_array_view(&input_ids_2d)?,
                "attention_mask" => TensorRef::from_array_view(&attention_mask)?,
            ])?;
            outputs["logits"].try_extract_array::<f32>()?.to_owned()
        };

        // Decode logits to tokens, then tokens to DocTags text
        let tokens = self.decode_logits_to_tokens(&logits.view())?;
        let doctags = self.detokenize_to_doctags(&tokens)?;

        println!("✅ Document processing complete");
        Ok(doctags)
    }
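
    /// A minimal sketch of greedy autoregressive decoding, for contrast: the
    /// single forward pass in `process_document` only scores the prompt, while
    /// real generation re-runs the model, appending one token per step. This
    /// assumes a KV-cache-free export that accepts the full sequence each step
    /// and an EOS id of 2; both are assumptions to verify against the actual
    /// export and tokenizer config.
    pub async fn generate_greedy(
        &mut self,
        document_image: Array4<f32>,
        prompt: &str,
        max_new_tokens: usize,
    ) -> Result<Vec<u32>> {
        const EOS_TOKEN_ID: u32 = 2; // assumption: check tokenizer_config.json

        let mut ids = self.tokenize_prompt(prompt)?;
        for _ in 0..max_new_tokens {
            let input_ids = Array2::from_shape_vec(
                (1, ids.len()),
                ids.iter().map(|&x| x as i64).collect(),
            )?;
            let attention_mask = Array2::<i64>::ones((1, ids.len()));

            let logits = {
                let outputs = self.session.run(inputs![
                    "pixel_values" => TensorRef::from_array_view(&document_image)?,
                    "input_ids" => TensorRef::from_array_view(&input_ids)?,
                    "attention_mask" => TensorRef::from_array_view(&attention_mask)?,
                ])?;
                outputs["logits"].try_extract_array::<f32>()?.to_owned()
            };

            // Greedy step: the argmax at the last sequence position
            let next = *self.decode_logits_to_tokens(&logits.view())?.last().unwrap_or(&0);
            if next == EOS_TOKEN_ID {
                break;
            }
            ids.push(next);
        }

        Ok(ids)
    }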

    /// Simple whitespace tokenization (placeholder; see the HF sketch below)
    fn tokenize_prompt(&self, prompt: &str) -> Result<Vec<u32>> {
        // Simplified stand-in that assigns sequential ids to whitespace-separated
        // words. In practice, load tokenizer.json and use proper HuggingFace
        // tokenization.
        let tokens: Vec<u32> = prompt
            .split_whitespace()
            .enumerate()
            .map(|(i, _)| (i + 1) as u32)
            .collect();

        Ok(tokens)
    }
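
    /// A hedged sketch of real tokenization with the HuggingFace `tokenizers`
    /// crate. The `hf-tokenizer` feature, the `tokenizers` dependency, and the
    /// tokenizer.json path are illustrative assumptions, not part of this repo.
    #[cfg(feature = "hf-tokenizer")]
    fn tokenize_prompt_hf(&self, prompt: &str) -> Result<Vec<u32>> {
        use tokenizers::Tokenizer;

        let tokenizer = Tokenizer::from_file("granite-docling-258M-onnx/tokenizer.json")
            .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;
        let encoding = tokenizer
            .encode(prompt, true) // true = add special tokens
            .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?;

        Ok(encoding.get_ids().to_vec())
    }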

    /// Greedy-decode logits to the most likely token at each position
    fn decode_logits_to_tokens(&self, logits: &ndarray::ArrayViewD<f32>) -> Result<Vec<u32>> {
        // Logits have shape [batch, seq_len, vocab_size]; iterate over sequence
        // positions (Axis(1), not the vocab Axis(2)) and take the argmax over
        // the vocab dimension. Assumes batch size 1.
        let tokens: Vec<u32> = logits
            .axis_iter(ndarray::Axis(1))
            .map(|position| {
                position
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| {
                        a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .map(|(idx, _)| idx as u32)
                    .unwrap_or(0)
            })
            .collect();

        Ok(tokens)
    }

    /// Convert tokens back to DocTags markup
    fn detokenize_to_doctags(&self, tokens: &[u32]) -> Result<String> {
        // In practice, use the granite-docling tokenizer to convert tokens to
        // text, then parse that text as DocTags markup (sketched below).

        // Simplified placeholder
        let mock_doctags = format!(
            "<doctag>\n  <text>Document processed with {} tokens</text>\n</doctag>",
            tokens.len()
        );

        Ok(mock_doctags)
    }
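
    /// A hedged sketch of real detokenization via the `tokenizers` crate,
    /// under the same assumed feature flag and tokenizer.json path as the
    /// encoding sketch above.
    #[cfg(feature = "hf-tokenizer")]
    fn detokenize_to_doctags_hf(&self, tokens: &[u32]) -> Result<String> {
        use tokenizers::Tokenizer;

        let tokenizer = Tokenizer::from_file("granite-docling-258M-onnx/tokenizer.json")
            .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;

        tokenizer
            .decode(tokens, true) // true = skip special tokens
            .map_err(|e| anyhow::anyhow!("detokenization failed: {e}"))
    }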
}

/// Preprocess a document image for granite-docling inference
pub fn preprocess_document_image(image_path: &str) -> Result<Array4<f32>> {
    // Real preprocessing: load the image, resize to 512x512 (the SigLIP2 input
    // size), normalize with the SigLIP2 parameters, and lay it out as
    // [batch, channels, height, width].

    // Simplified placeholder: a zero tensor of the right shape. See the
    // `image`-crate sketch below for a real implementation.
    let _ = image_path;
    let document_image = Array4::zeros((1, 3, 512, 512));

    Ok(document_image)
}
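
/// A hedged sketch of real preprocessing with the `image` crate (an
/// illustrative dependency behind an assumed `image-preprocess` feature).
/// A per-channel mean/std of 0.5 is the usual SigLIP normalization but is an
/// assumption here; check the exported preprocessor_config.json.
#[cfg(feature = "image-preprocess")]
pub fn preprocess_document_image_real(image_path: &str) -> Result<Array4<f32>> {
    use image::imageops::FilterType;

    let img = image::open(image_path)?
        .resize_exact(512, 512, FilterType::Triangle)
        .to_rgb8();

    let mut tensor = Array4::<f32>::zeros((1, 3, 512, 512));
    for (x, y, pixel) in img.enumerate_pixels() {
        for c in 0..3 {
            // Scale to [0, 1], then normalize to roughly [-1, 1]
            let v = pixel[c] as f32 / 255.0;
            tensor[[0, c, y as usize, x as usize]] = (v - 0.5) / 0.5;
        }
    }

    Ok(tensor)
}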

#[tokio::main]
async fn main() -> Result<()> {
    println!("granite-docling ONNX Rust Example");

    // Load granite-docling ONNX model
    let model_path = "granite-docling-258M-onnx/model.onnx";
    let mut granite_docling = GraniteDoclingONNX::new(model_path)?;

    // Preprocess document image
    let document_image = preprocess_document_image("example_document.png")?;

    // Process document
    let prompt = "Convert this document to DocTags:";
    let doctags = granite_docling.process_document(document_image, prompt).await?;

    println!("Generated DocTags:");
    println!("{}", doctags);

    Ok(())
}

// Cargo.toml dependencies (ort 2.0.0-rc.10 pairs with ndarray 0.16; the
// `ndarray` feature of ort enables the tensor interop used above):
/*
[dependencies]
ort = { version = "2.0.0-rc.10", features = ["directml", "cuda", "tensorrt", "ndarray"] }
ndarray = "0.16"
anyhow = "1.0"
tokio = { version = "1.0", features = ["full"] }
*/
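
// Assumed optional wiring for the hedged sketches above (feature names,
// versions, and crate choices are illustrative, not part of the original):
/*
tokenizers = { version = "0.20", optional = true }
image = { version = "0.25", optional = true }

[features]
hf-tokenizer = ["dep:tokenizers"]
image-preprocess = ["dep:image"]
*/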