Llama-3.2-3B-ONNX-INT8-StrongTowerApps-Research

🛡️ Project Description

This repository contains an optimized version of Meta's Llama-3.2-3B model. As part of independent research, the model was converted to the ONNX format and dynamically quantized to INT8, reducing its size from an intermediate export of ~27GB to 3.4GB.
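
As a rough sanity check on these sizes, the raw weight footprint can be estimated from the parameter count and the bytes used per parameter. The sketch below assumes about 3.21 billion parameters (the figure given on Meta's model card, not something measured from this repository) and ignores graph metadata, quantization scales, and any tensors left unquantized.

```python
# Back-of-the-envelope weight sizes for a ~3.21B-parameter model.
# ASSUMPTION: the parameter count is taken from Meta's model card,
# not measured from the files in this repository.
PARAM_COUNT = 3.21e9

# Storage cost of a single parameter for each data type.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "int8": 1}


def weight_size_gb(param_count: float, dtype: str) -> float:
    """Return the raw weight size in decimal gigabytes (1 GB = 1e9 bytes)."""
    return param_count * BYTES_PER_PARAM[dtype] / 1e9


for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>8}: {weight_size_gb(PARAM_COUNT, dtype):5.2f} GB")
```

Under that assumption the float32 weights come to roughly 12.8 GB and the 8-bit weights to roughly 3.2 GB, which is in line with the ~12GB and 3.4GB figures quoted in this document. The remaining difference is plausibly overhead such as scales, zero points, and unquantized tensors.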

🧠 Transformation Process: From Dynamic Code to Static Graph

To achieve the necessary portability and efficiency in local environments, the model went through a critical "compilation" phase of its architecture:

  1. Static Graph Conversion: The dynamic execution logic (PyTorch) was converted into a static ONNX computation graph. Every operation and connection across the model's 28 layers is defined explicitly, so inference no longer depends on the Python interpreter.
  2. KV Cache Integration: The model was exported with the text-generation-with-past task, which builds the memory logic (past_key_values) directly into the graph. Because the keys and values of earlier tokens are cached and reused, each generation step only has to process the newest token, which makes decoding significantly faster.
  3. Intermediate Technical Expansion: During this process, the original ~12GB model expanded to 27GB. This growth is attributed to loop unrolling during export, intended to optimize CPU performance, and to the comprehensive serialization of the graph in Protobuf format. This "expanded" version was the foundation for the subsequent pruning and final quantization.
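
To illustrate why the KV cache in step 2 saves work, here is a toy single-head attention in pure Python. It is only a sketch: the `project` function is a hypothetical stand-in for the learned key and value projections, and the vectors are made up. The point is that appending one key/value pair per step gives the same result as recomputing all of them from scratch.

```python
import math


def attend(query, keys, values):
    """Scaled dot-product attention for one query over lists of key/value vectors."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(dim)]


def project(token_vec, factor):
    """Hypothetical stand-in for a learned key or value projection."""
    return [factor * x for x in token_vec]


def last_output_without_cache(tokens):
    """Recompute every key and value, then attend from the last token."""
    keys = [project(t, 0.5) for t in tokens]
    values = [project(t, 2.0) for t in tokens]
    return attend(tokens[-1], keys, values)


def step_with_cache(new_token, cache):
    """Project only the new token, append it to the cache, then attend."""
    cache["keys"].append(project(new_token, 0.5))
    cache["values"].append(project(new_token, 2.0))
    return attend(new_token, cache["keys"], cache["values"])


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
cache = {"keys": [], "values": []}
for token in tokens:
    cached_out = step_with_cache(token, cache)  # one new projection pair per step

full_out = last_output_without_cache(tokens)  # len(tokens) projection pairs at once
print(cached_out, full_out)
```

Both paths produce the same attention output for the last token, but the cached path performs one projection pair per step instead of one per token in the sequence.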

⚙️ Technical Details

  • Base: meta-llama/Llama-3.2-3B
  • Final Format: ONNX (External Data)
  • Optimization: Dynamic Quantization (QUInt8)
  • Usage: Specially designed for local execution on CPU using onnxruntime.
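
The "Dynamic Quantization (QUInt8)" entry means the weights are stored as unsigned 8-bit integers together with a scale and a zero point, while activations are quantized at run time. The sketch below shows a standard affine uint8 scheme in pure Python. It is an illustration under that assumption, not the exact routine ONNX Runtime applies, and the function names are hypothetical.

```python
def quantize_uint8(values):
    """Map floats to bytes in [0, 255] using an affine scale and zero point."""
    # Extend the range to include 0.0 so that zero is represented exactly.
    lo = min(min(values), 0.0)
    hi = max(max(values), 0.0)
    scale = (hi - lo) / 255.0
    if scale == 0.0:
        scale = 1.0  # all-zero input: any positive scale works
    zero_point = round(-lo / scale)
    quantized = [max(0, min(255, round(v / scale) + zero_point)) for v in values]
    return bytes(quantized), scale, zero_point


def dequantize_uint8(quantized, scale, zero_point):
    """Recover approximate floats from the stored bytes."""
    return [(b - zero_point) * scale for b in quantized]


weights = [-1.0, -0.25, 0.0, 0.5, 1.0]
q, scale, zero_point = quantize_uint8(weights)
restored = dequantize_uint8(q, scale, zero_point)

# Each weight now occupies one byte instead of four (float32).
print(len(q), "bytes, max error:", max(abs(a - b) for a, b in zip(weights, restored)))
```

In ONNX Runtime, this kind of conversion is normally done with `onnxruntime.quantization.quantize_dynamic`, passing `weight_type=QuantType.QUInt8`.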

🚀 Usage Example

Below is a Python script (ask-model.py) demonstrating how to load the ONNX model and the tokenizer to generate text on a CPU.

Prerequisites

Make sure to install the required libraries:

pip install onnxruntime numpy transformers
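
A quick way to confirm the environment before running the script is to check that each package can be located. This is a small optional sketch using only the standard library; the helper name `missing_packages` is hypothetical.

```python
from importlib.util import find_spec

# Top-level import names of the packages listed in the pip command above.
REQUIRED = ["onnxruntime", "numpy", "transformers"]


def missing_packages(names):
    """Return the names that cannot be found in the current environment."""
    return [name for name in names if find_spec(name) is None]


missing = missing_packages(REQUIRED)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are installed.")
```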

Code Implementation

import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
import time

# --- Configuration ---
model_path = "Llama-3.2-3B-ONNX-INT8-StrongTowerApps-Research/model_quantized.onnx"

tokenizer_name = "Llama-3.2-3B-ONNX-INT8-StrongTowerApps-Research"
max_new_tokens = 500  # Maximum number of tokens to generate (raised for more complete answers)

try:
    print("1. Loading the tokenizer from the local folder...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    
    print("2. Loading the ONNX model on CPU...")
    session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    
    # Extract metadata to map the KV Cache (past_key_values)
    input_names = [i.name for i in session.get_inputs()]
    output_names = [o.name for o in session.get_outputs()]
    past_kv_names = [name for name in input_names if "past_key_values" in name]
    
    # --- Prepare the Prompt ---
    prompt = "What are the benefits of artificial intelligence in Cybersecurity?"
    print(f"\n[User]: {prompt}\n")
    print("[Llama-3.2-3B-ONNX]: ", end="", flush=True)
    
    # Tokenize the input text
    inputs = tokenizer(prompt, return_tensors="np")
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)
    seq_len = input_ids.shape[1]
    
    # position_ids: [0, 1, 2, ..., seq_len - 1]
    position_ids = np.arange(0, seq_len, dtype=np.int64).reshape(1, seq_len)
    
    # Initialize the input dictionary for ONNX
    ort_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids
    }
    
    # Initialize empty 'past_key_values' (sequence length = 0)
    for input_meta in session.get_inputs():
        if "past_key_values" in input_meta.name:
            # Reconstruct the shape: [batch_size, num_heads, 0, head_dim]
            shape = [dim if isinstance(dim, int) else (0 if i == 2 else 1) for i, dim in enumerate(input_meta.shape)]
            dtype = np.float32
            if 'int64' in input_meta.type: dtype = np.int64
            elif 'int32' in input_meta.type: dtype = np.int32
            elif 'float16' in input_meta.type: dtype = np.float16
            ort_inputs[input_meta.name] = np.zeros(shape, dtype=dtype)
            
    # --- Inference Cycle (Generation Loop) ---
    start_time = time.time()
    
    for step in range(max_new_tokens):
        # Execute the model
        outputs = session.run(None, ort_inputs)
        
        # outputs[0] holds the logits; take the greedy (argmax) prediction at the last position.
        logits = outputs[0]
        next_token_id = int(np.argmax(logits[:, -1, :], axis=-1)[0])
        
        # Stop before printing if the model predicts the end-of-sequence token
        if next_token_id == tokenizer.eos_token_id:
            break
        
        # Print the generated token in real time
        token_text = tokenizer.decode([next_token_id])
        print(token_text, end="", flush=True)
            
        # --- Update Inputs for the next step (Using KV Cache) ---
        ort_inputs["input_ids"] = np.array([[next_token_id]], dtype=np.int64)
        ort_inputs["attention_mask"] = np.concatenate([ort_inputs["attention_mask"], np.ones((1, 1), dtype=np.int64)], axis=1)
        ort_inputs["position_ids"] = np.array([[seq_len + step]], dtype=np.int64)
        # Assumes the 'present' outputs are listed in the same order as the 'past_key_values' inputs
        for past_name, present_value in zip(past_kv_names, outputs[1:]):
            ort_inputs[past_name] = present_value
            
    print(f"\n\n[INFO] Generation time: {time.time() - start_time:.2f} seconds.")
    
except Exception as e:
    print(f"\nAn error occurred: {e}")
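
The cache update in the script pairs `past_key_values` inputs with the outputs after the logits purely by position. A more defensive option is to pair them by name. The sketch below assumes the naming convention commonly produced by Optimum exports (`present.N.key` / `present.N.value` for outputs and `past_key_values.N.key` / `past_key_values.N.value` for inputs). That convention is an assumption about this model, so inspect `session.get_inputs()` and `session.get_outputs()` before relying on it. The helper name is hypothetical.

```python
import re

# ASSUMPTION: cache outputs are named "present.<layer>.<key|value>".
PRESENT_PATTERN = re.compile(r"^present\.(\d+)\.(key|value)$")


def map_present_to_past(output_names, input_names):
    """Return {output_name: input_name} for cache tensors, matched by layer and kind."""
    known_inputs = set(input_names)
    mapping = {}
    for name in output_names:
        match = PRESENT_PATTERN.match(name)
        if match is None:
            continue  # not a cache tensor, e.g. "logits"
        past_name = f"past_key_values.{match.group(1)}.{match.group(2)}"
        if past_name not in known_inputs:
            raise KeyError(f"No input named {past_name!r} for output {name!r}")
        mapping[name] = past_name
    return mapping


example_outputs = ["logits", "present.0.key", "present.0.value"]
example_inputs = ["input_ids", "attention_mask", "past_key_values.0.key", "past_key_values.0.value"]
print(map_present_to_past(example_outputs, example_inputs))
```

Inside the generation loop, the mapping could then be applied by iterating over `zip(output_names, outputs)` and assigning `ort_inputs[mapping[name]] = value` for every `name` found in the mapping.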

⚖️ Licenses

The quantization process and the optimization pipeline, including the static graph transformation and technical compilation, were engineered by Strong Tower Apps™. This distribution relies on ONNX Runtime for high-performance inference, while the underlying model remains the property of Meta under the Llama 3.2 Community License.


Researcher: Strong Tower Apps™
