import os
from dataclasses import dataclass, field
from typing import Optional, Set

import onnx
from onnxruntime.quantization import (
    quantize_dynamic,
    QuantType
)

from optimum.exporters.tasks import TasksManager
from transformers import HfArgumentParser

# Default dynamic-quantization settings.
DEFAULT_QUANTIZE_PARAMS = {
    'per_channel': True,
    'reduce_range': True,
}

# Per-model-type overrides, applied on top of the defaults above.
MODEL_SPECIFIC_QUANTIZE_PARAMS = {
    'whisper': {
        'per_channel': False,
        'reduce_range': False,
    }
}

MODELS_WITHOUT_TOKENIZERS = [
    'wav2vec2'
]


@dataclass
class ConversionArguments:
    """
    Arguments used for converting HuggingFace models to ONNX.
    """

    model_id: str = field(
        metadata={
            "help": "Model identifier"
        }
    )
    quantize: bool = field(
        default=False,
        metadata={
            "help": "Whether to quantize the model."
        }
    )
    output_parent_dir: str = field(
        default='./models/',
        metadata={
            "help": "Path where the converted model will be saved."
        }
    )

    task: Optional[str] = field(
        default='auto',
        metadata={
            "help": (
                "The task to export the model for. If not specified, the task is auto-inferred from the model. Available tasks depend on the model, but are among:"
                f" {list(TasksManager._TASKS_TO_AUTOMODELS.keys())}. For decoder models, use `xxx-with-past` to export the model with past key values in the decoder."
            )
        }
    )

    opset: Optional[int] = field(
        default=None,
        metadata={
            "help": "If specified, the ONNX opset version to export the model with. Otherwise, the default opset is used."
        }
    )

    device: str = field(
        default='cpu',
        metadata={
            "help": "The device to use for the export."
        }
    )
    skip_validation: bool = field(
        default=False,
        metadata={
            "help": "Whether to skip validation of the converted model."
        }
    )

    per_channel: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether to quantize weights per channel."
        }
    )
    reduce_range: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether to quantize weights with 7 bits. This may improve accuracy for some models running on non-VNNI machines, especially in per-channel mode."
        }
    )

    output_attentions: bool = field(
        default=False,
        metadata={
            "help": "Whether to output attentions from the model. NOTE: This is currently only supported for whisper models."
        }
    )

    split_modalities: bool = field(
        default=False,
        metadata={
            "help": "Whether to split multimodal models. NOTE: This is currently only supported for CLIP models."
        }
    )

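# Sketch (this helper is not called anywhere in this script): one way the
# quantization defaults above could be resolved for a given model type, with
# explicit command-line flags taking precedence over the per-model overrides.
def get_quantize_params(model_type: str, conv_args: ConversionArguments) -> dict:
    params = {
        **DEFAULT_QUANTIZE_PARAMS,
        **MODEL_SPECIFIC_QUANTIZE_PARAMS.get(model_type, {}),
    }
    if conv_args.per_channel is not None:
        params['per_channel'] = conv_args.per_channel
    if conv_args.reduce_range is not None:
        params['reduce_range'] = conv_args.reduce_range
    return params
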

def get_operators(model: onnx.ModelProto) -> Set[str]:
    """Collect the set of operator types used anywhere in the model's graph."""
    operators = set()

    def traverse_graph(graph):
        for node in graph.node:
            operators.add(node.op_type)
            # Recurse into subgraphs (e.g. the bodies of If/Loop/Scan nodes).
            for attr in node.attribute:
                if attr.type == onnx.AttributeProto.GRAPH:
                    traverse_graph(attr.g)

    traverse_graph(model.graph)
    return operators

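# Example (hypothetical path): inspect which ops an exported model uses, e.g.
# to pick a weight type, mirroring the check in quantize() below:
#
#   ops = get_operators(onnx.load_model('./models/my-model/model.onnx'))
#   weight_type = QuantType.QUInt8 if 'Conv' in ops else QuantType.QInt8
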

def quantize(model_path):
    """
    Quantize the weights of the model from float32 to int8 to allow very
    efficient inference on modern CPUs.

    Uses unsigned ints for activation values and signed ints for weights, per
    https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection
    which is faster on most CPU architectures.

    Args:
        model_path: Path to the exported ONNX model.

    Returns:
        The path of the quantized model.
    """
    directory_path = os.path.dirname(model_path)

    # Use unsigned weights if the graph contains a Conv node: onnxruntime's
    # integer Conv implementation does not support signed 8-bit weights.
    loaded_model = onnx.load_model(model_path)
    op_types = get_operators(loaded_model)
    weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8
    print(f"Quantizing weights to {weight_type}")

    output_path = os.path.join(directory_path, 'model-q8.onnx')
    quantize_dynamic(
        model_input=model_path,
        model_output=output_path,
        weight_type=weight_type,
        optimize_model=False,
    )
    return output_path

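# Example (hypothetical path): the quantized model is written next to the
# input file; the float32 original is left untouched.
#
#   quantized_path = quantize('./models/my-model/model.onnx')
#   # -> './models/my-model/model-q8.onnx'
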

def main():
    """
    Example usage:
    python quantize_onnx.py --model_id sentence-transformers/all-MiniLM-L6-v2-unquantized
    """
    parser = HfArgumentParser(
        (ConversionArguments,)
    )
    conv_args, = parser.parse_args_into_dataclasses()

    model_id = conv_args.model_id

    quantize(os.path.join(model_id, "model.onnx"))


if __name__ == '__main__':
    main()