# qMTEB / quantize_onnx.py
import os
from dataclasses import dataclass, field
from typing import Optional, Set
import onnx
from onnxruntime.quantization import (
quantize_dynamic,
QuantType
)
from optimum.exporters.tasks import TasksManager
from transformers import HfArgumentParser
DEFAULT_QUANTIZE_PARAMS = {
'per_channel': True,
'reduce_range': True,
}
MODEL_SPECIFIC_QUANTIZE_PARAMS = {
'whisper': {
'per_channel': False,
'reduce_range': False,
}
}
MODELS_WITHOUT_TOKENIZERS = [
'wav2vec2'
]
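
# DEFAULT_QUANTIZE_PARAMS and MODEL_SPECIFIC_QUANTIZE_PARAMS mirror the
# conversion pipeline's defaults but are not consumed elsewhere in this script.
# A minimal, illustrative sketch of how they could be merged for a given model
# type (the helper name `get_quantize_params` is hypothetical, not part of the
# original pipeline):
def get_quantize_params(model_type: str) -> dict:
    params = dict(DEFAULT_QUANTIZE_PARAMS)
    params.update(MODEL_SPECIFIC_QUANTIZE_PARAMS.get(model_type, {}))
    return params
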
@dataclass
class ConversionArguments:
"""
    Arguments used for converting Hugging Face models to ONNX.
"""
model_id: str = field(
metadata={
"help": "Model identifier"
}
)
quantize: bool = field(
default=False,
metadata={
"help": "Whether to quantize the model."
}
)
output_parent_dir: str = field(
default='./models/',
metadata={
"help": "Path where the converted model will be saved to."
}
)
task: Optional[str] = field(
default='auto',
metadata={
"help": (
"The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among:"
f" {str(list(TasksManager._TASKS_TO_AUTOMODELS.keys()))}. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder."
)
}
)
    opset: Optional[int] = field(
default=None,
metadata={
"help": (
"If specified, ONNX opset version to export the model with. Otherwise, the default opset will be used."
)
}
)
device: str = field(
default='cpu',
metadata={
"help": 'The device to use to do the export.'
}
)
skip_validation: bool = field(
default=False,
metadata={
"help": "Whether to skip validation of the converted model"
}
)
    per_channel: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to quantize weights per channel"
}
)
    reduce_range: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to quantize weights with 7-bits. It may improve the accuracy for some models running on non-VNNI machine, especially for per-channel mode"
}
)
output_attentions: bool = field(
default=False,
metadata={
"help": "Whether to output attentions from the model. NOTE: This is only supported for whisper models right now."
}
)
split_modalities: bool = field(
default=False,
metadata={
"help": "Whether to split multimodal models. NOTE: This is only supported for CLIP models right now."
}
)
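
# `per_channel` and `reduce_range` are exposed above but not forwarded to
# `quantize_dynamic` by this script. An illustrative sketch of how they could
# be wired through (the function name `quantize_with_params` is hypothetical,
# not part of the original pipeline):
def quantize_with_params(model_path: str, output_path: str,
                         per_channel: bool, reduce_range: bool) -> None:
    quantize_dynamic(
        model_input=model_path,
        model_output=output_path,
        per_channel=per_channel,
        reduce_range=reduce_range,
        weight_type=QuantType.QInt8,
    )
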
def get_operators(model: onnx.ModelProto) -> Set[str]:
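    """Collect the set of operator types used anywhere in the model graph,
    including operators inside nested subgraphs (e.g. If/Loop bodies)."""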
operators = set()
def traverse_graph(graph):
for node in graph.node:
operators.add(node.op_type)
for attr in node.attribute:
if attr.type == onnx.AttributeProto.GRAPH:
subgraph = attr.g
traverse_graph(subgraph)
traverse_graph(model.graph)
return operators
def quantize(model_path):
"""
Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU
Uses unsigned ints for activation values, signed ints for weights, per
https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection
it is faster on most CPU architectures
Args:
onnx_model_path: Path to location the exported ONNX model is stored
Returns: The Path generated for the quantized
"""
directory_path = os.path.dirname(model_path)
loaded_model = onnx.load_model(model_path)
op_types = get_operators(loaded_model)
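    # Signed int8 weights are generally faster on CPU (see the link in the
    # docstring), but graphs containing Conv nodes are quantized with unsigned
    # weights here, which avoids missing ConvInteger kernels in some
    # onnxruntime builds.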
weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8
print("quantizing to", weight_type)
quantize_dynamic(
model_input=model_path,
model_output=os.path.join(directory_path, 'model-q8.onnx'),
weight_type=weight_type,
optimize_model=False,
)
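
# Optional, illustrative sanity check (not part of the original script): load
# the quantized graph with onnxruntime and confirm a session can be created.
# The function name `check_quantized_model` is hypothetical.
def check_quantized_model(quantized_path: str) -> None:
    import onnxruntime as ort  # local import; onnxruntime is already required by onnxruntime.quantization
    session = ort.InferenceSession(quantized_path, providers=["CPUExecutionProvider"])
    print("quantized model inputs:", [inp.name for inp in session.get_inputs()])
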
def main():
"""
Example usage:
python quantize_onnx.py --model_id sentence-transformers/all-MiniLM-L6-v2-unquantized
"""
parser = HfArgumentParser(
(ConversionArguments,)
)
conv_args, = parser.parse_args_into_dataclasses()
model_id = conv_args.model_id
quantize(os.path.join(model_id, "model.onnx"))
if __name__ == '__main__':
main()