from collections import defaultdict, OrderedDict
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Union

import numpy as np
from openvino.runtime import Model, Node
from openvino.runtime.op import Parameter, Constant
import openvino.runtime.opset12 as opset
from openvino.runtime.utils.types import get_element_type
import openvino as ov
from tqdm.auto import tqdm

# Operation types that can be tracked, mapped to the opset factory used to rebuild them in isolation
OPERATION_TYPE_MAP = {"MatMul": opset.matmul, "Convolution": opset.convolution}

# Runtime info key written on nodes selected for full precision execution
ORIGINAL_PRECISION_RT_INFO_NAME = "precise_0"


@dataclass
class TrackedNodeInfo:
    """
    Data associated with a node tracked for upcasting
    """

    node: Node  # Target node to track
    snr: Optional[float] = None  # SNR of the target node
    input_nodes: Optional[List[Node]] = None  # Input nodes of the target node
    result_node: Optional[Node] = None  # Result node of the target node
    input_result_nodes: Optional[Dict[Node, Node]] = None  # Result nodes of non-const inputs of the target node
    node_value_full_precision: Optional[np.ndarray] = None  # Result of the node in full precision
    node_value_half_precision: Optional[np.ndarray] = None  # Result of the node in half precision
    input_values_full_precision: Optional[List[np.ndarray]] = None  # Results of the target node inputs in full precision


def partially_upcast_nodes_to_fp32(
    orig_model: Model,
    example_input: Union[List, Dict],
    half_type: str = "f16",
    batch_size: int = 50,
    operation_types: List[str] = None,
    upcast_ratio: float = 0.1,
    verbose: bool = False,
) -> Model:
    """
    Transform a model to upcast some nodes to be executed in full precision instead of half precision. These nodes are
    marked with runtime info flag.
    Nodes are selected based on Signal-to-Noise Ratio (SNR) metric: upcast_ratio fraction of tracked nodes with the
    lowest SNR are marked for full precision execution.

    Note: Input model should have fp16 weights (i.e. saved with compress_to_fp16=True) in order to conserve calibration
    memory.

    :param orig_model: Model to process
    :param example_input: Example input for model inference
    :param half_type: Either "f16" or "bf16"
    :param batch_size: Number of nodes to process together during a single model inference. The lower the value is,
        the less memory footprint is, but the larger is the processing time. The value of -1 is used to disable
        batching.
    :param operation_types: Types of operations to consider. If None, MatMuls and Convolutions are considered.
    :param upcast_ratio: Fraction of nodes to upcast (with the lowest SNR). 0 - do not upcast anything, 1 - upcast
        every operation of the given types.
    :param verbose: If True, prints progress output.
    :return: Upcasted OV model with some nodes marked for full precision execution.
    :raises ValueError: If half_type or an operation type is not supported, if upcast_ratio is outside of the
        [0, 1] range, or if batch_size is neither a positive number nor -1.
    """
    if half_type not in ("f16", "bf16"):
        raise ValueError(f"Half type must be either 'f16' or 'bf16'. Got {half_type}.")
    if half_type == "bf16":
        print("Warning! Calibration currently does not provide any improvement for bf16 type.")

    # A ratio outside of [0, 1] would either select a wrong subset of nodes (negative slice bound)
    # or index past the end of the SNR list, so reject it explicitly.
    if not 0.0 <= upcast_ratio <= 1.0:
        raise ValueError(f"Upcast ratio must be in the [0, 1] range. Got {upcast_ratio}.")
    # Zero or negative steps (other than the -1 sentinel) cannot be used to iterate over batches.
    if batch_size != -1 and batch_size <= 0:
        raise ValueError(f"Batch size must be either a positive number or -1. Got {batch_size}.")

    if operation_types is None:
        operation_types = ["MatMul", "Convolution"]
    for op_type in operation_types:
        if op_type not in OPERATION_TYPE_MAP:
            raise ValueError(f"Operation type must be one of the following {list(OPERATION_TYPE_MAP.keys())}. " f"Got {op_type}.")
    if verbose:
        print(f"The following operation types will be considered: {operation_types}")

    device = "GPU" if half_type == "f16" else "CPU"

    nodes_to_track_names = get_nodes_to_track(orig_model, operation_types)
    if len(nodes_to_track_names) == 0:
        if verbose:
            print("Warning. Not found any operations of the given type(s).")
        return orig_model.clone()

    if verbose:
        print("Started upcasting")

    if upcast_ratio == 0.0:
        if verbose:
            print("Skipping algorithm because upcast ratio equals 0.0. Nothing to upcast.")
        node_to_upcast_names = []
    elif upcast_ratio == 1.0:
        if verbose:
            print("Skipping algorithm because upcast ratio equals 1.0. Upcasting all nodes of the given type(s).")
        node_to_upcast_names = nodes_to_track_names
    else:
        node_names_and_snrs = _collect_node_snrs(
            orig_model,
            example_input,
            nodes_to_track_names,
            batch_size,
            device,
            half_type,
            verbose,
        )
        # Lowest SNR first: these are the nodes that suffer the most from half precision
        node_names_and_snrs = sorted(node_names_and_snrs, key=lambda it: it[1])
        node_names, node_snrs = tuple(zip(*node_names_and_snrs))

        n_nodes = len(node_names)
        nodes_to_upcast_cnt = int(np.ceil(n_nodes * upcast_ratio))
        node_to_upcast_names = node_names[:nodes_to_upcast_cnt]

        if verbose:
            snr_quantile = node_snrs[nodes_to_upcast_cnt - 1]
            print(f"Upcasted {nodes_to_upcast_cnt}/{n_nodes} nodes with SNR less than {snr_quantile:.2f}.")
            for node_name, node_snr in node_names_and_snrs[:nodes_to_upcast_cnt]:
                print(node_name, node_snr)

    new_model = orig_model.clone()
    mark_nodes_to_upcast_to_fp32(new_model, node_to_upcast_names)
    return new_model


def _collect_node_snrs(
    orig_model: Model,
    example_input: Union[List, Dict],
    nodes_to_track_names: List[str],
    batch_size: int,
    device: str,
    half_type: str,
    verbose: bool,
) -> List[Tuple[str, float]]:
    """
    Compute the SNR of every tracked node, processing the nodes batch by batch.

    :param orig_model: Model to process; it is cloned for every batch and never modified.
    :param example_input: Example input for model inference
    :param nodes_to_track_names: Friendly names of the nodes to compute SNR for
    :param batch_size: Number of nodes per batch; -1 or a value larger than the number of nodes means a single batch
    :param device: Device used to infer single operations in half precision
    :param half_type: Half precision type passed as inference precision hint
    :param verbose: If True, shows a progress bar.
    :return: List of (node friendly name, SNR) pairs in the order of nodes_to_track_names.
    """
    if batch_size == -1 or batch_size > len(nodes_to_track_names):
        batch_size = len(nodes_to_track_names)

    node_names_and_snrs = []
    for i in tqdm(
        range(0, len(nodes_to_track_names), batch_size),
        desc="Processing batches",
        disable=not verbose,
    ):
        # Work on a fresh clone so that outputs added for one batch do not pile up for the next one
        model = orig_model.clone()
        name_to_node_map = {op.get_friendly_name(): op for op in model.get_ops()}
        nodes_to_track_batch = [TrackedNodeInfo(name_to_node_map[node_name]) for node_name in nodes_to_track_names[i : i + batch_size]]

        # Add outputs for non-constant inputs of tracked nodes
        insert_outputs_for_tracked_ops(model, nodes_to_track_batch)
        # Infer model to collect tracked operation results and results of their inputs in full precision
        infer_full_net(nodes_to_track_batch, model, example_input)
        # Infer nodes in half precision one by one using full precision inputs, collect half precision results
        infer_nodes(nodes_to_track_batch, device, half_type)

        # Compute operation SNR based on full precision and half precision results
        for node_info in nodes_to_track_batch:
            try:
                snr = compute_snr(
                    node_info.node_value_full_precision,
                    node_info.node_value_half_precision,
                )
            except RuntimeError as e:
                # TODO: find the reason behind this
                if node_info.node.get_type_name() in [
                    "Add",
                    "Concat",
                ] and "Shape mismatch" in str(e):
                    print(
                        "Warning.",
                        str(e),
                        node_info.node.get_friendly_name(),
                        node_info.node.get_type_name(),
                        [(inp_node.get_friendly_name(), inp_node.get_type_name()) for inp_node in node_info.input_nodes],
                    )
                    # Treat the node as noise-free so that it is never selected for upcasting
                    snr = np.finfo(np.float32).max
                else:
                    raise
            node_names_and_snrs.append((node_info.node.get_friendly_name(), snr))
    return node_names_and_snrs


def get_nodes_to_track(model: Model, operation_types: List[str]) -> List:
    """
    Collect friendly names of operations that should be tracked.

    An operation is tracked when its type is listed in operation_types and none of the consumers of its first output
    is a "Result" node.

    :param model: Model to look for operations in
    :param operation_types: Operation type names to consider
    :return: Friendly names of the matching operations, in the order returned by model.get_ordered_ops().
    """
    nodes_to_track = []
    for op in model.get_ordered_ops():
        if op.get_type_name() not in operation_types:
            continue
        consumers = op.output(0).get_target_inputs()
        if all(consumer.get_node().get_type_name() != "Result" for consumer in consumers):
            nodes_to_track.append(op.get_friendly_name())
    return nodes_to_track


def insert_outputs_for_tracked_ops(model: Model, nodes_to_track: List[TrackedNodeInfo]) -> None:
    """
    Add model outputs for tracked nodes and for their non-constant inputs.

    Fills input_nodes, result_node and input_result_nodes of every entry of nodes_to_track. The model is modified
    in place: one output is added per unique node that has to be observed.

    :param model: Model to add outputs to
    :param nodes_to_track: Tracked node descriptors to fill
    """
    # Unique nodes to add an output for, in insertion order (order must match the order of created outputs)
    node_to_output_map = OrderedDict()
    # For every observed node: which tracked nodes need it and in which role
    node_to_node_info_map = defaultdict(list)
    for node_info in nodes_to_track:
        node = node_info.node
        node_to_node_info_map[node].append((node_info, "parent"))  # add as a parent node
        if node not in node_to_output_map:
            node_to_output_map[node] = node.output(0)
        node_info.input_nodes = []
        for inp_value in node.input_values():
            child_node = inp_value.get_node()
            node_info.input_nodes.append(child_node)
            # Do not add outputs for constant nodes
            if child_node.get_type_name() != "Constant" and not is_constant_path(child_node):
                node_to_node_info_map[child_node].append((node_info, "child"))  # add as a child node
                if child_node not in node_to_output_map:
                    node_to_output_map[child_node] = child_node.output(0)

    outputs = model.add_outputs(list(node_to_output_map.values()))
    for output, node in zip(outputs, node_to_output_map.keys()):
        # Value matching will be done later based on result node friendly names
        result_node = output.node
        for node_info, parent_label in node_to_node_info_map[node]:
            is_parent = parent_label == "parent"
            if is_parent:
                node_info.result_node = result_node
            else:
                if node_info.input_result_nodes is None:
                    node_info.input_result_nodes = {}
                node_info.input_result_nodes[node] = result_node


def get_const_value_from_ovmodel(node: Union[Constant, Node]) -> np.ndarray:
    """
    Retrieve the value of a constant input directly from the model.

    :param node: Either a "Constant" node or a "Convert" node recognized by is_constant_path
    :return: Data of the constant. For a convert path the data of the constant feeding the convert is returned as is.
    :raises Exception: If the node is neither a constant nor a constant path.
    """
    if node.get_type_name() == "Constant":
        assert node.get_element_type() not in [
            ov.Type.f16,
            ov.Type.bf16,
        ], f"{node.get_friendly_name()}, {node.get_element_type()}"
        return node.get_data()
    elif is_constant_path(node):
        # If model is compressed and constant values flow through decompression convert
        const_node = node.input_value(0).get_node()
        assert const_node.get_type_name() == "Constant"
        assert const_node.get_element_type().is_real(), const_node.get_element_type()
        return const_node.get_data()  # return f16 weight
    else:
        raise Exception(f"Cannot get const values from ov.Model for {node.get_friendly_name()} with type {node.get_type_name()}")


def is_constant_path(node: Node) -> bool:
    """
    Check whether a node is a "Convert" that carries a constant value.

    :param node: Node to check
    :return: True if the node is a "Convert" whose "is_decompression_0" runtime info entry is non-empty or whose
        first input is a "Constant"; False otherwise.
    """
    if node.get_type_name() != "Convert":
        return False
    # NOTE(review): presumably this entry marks weight decompression converts - confirm against OpenVINO docs
    if len(node.get_rt_info()["is_decompression_0"].aslist()) > 0:
        return True
    if node.input_value(0).get_node().get_type_name() == "Constant":
        return True
    return False


def infer_full_net(nodes_to_track: List[TrackedNodeInfo], orig_model: Model, example_inputs: Union[List, Dict]) -> None:
    """
    Infer the model on CPU with f32 inference precision hint and store results in tracked node descriptors.

    Fills node_value_full_precision and input_values_full_precision of every entry of nodes_to_track. Expects that
    insert_outputs_for_tracked_ops was already applied to the model and descriptors.

    :param nodes_to_track: Tracked node descriptors to fill
    :param orig_model: Model with outputs added for tracked nodes and their non-constant inputs
    :param example_inputs: Example input for model inference
    """
    core = ov.Core()
    exec_net = core.compile_model(orig_model, "CPU", config={"INFERENCE_PRECISION_HINT": "f32"})
    request = exec_net.create_infer_request()
    results = request.infer(example_inputs, share_inputs=True, share_outputs=True)

    # Results are matched to tracked nodes through friendly names of result nodes
    friendly_name_to_result_map = {}
    for key, val in results.items():
        result_node = key.node
        friendly_name_to_result_map[result_node.get_friendly_name()] = val

    for node_info in nodes_to_track:
        node_info.node_value_full_precision = friendly_name_to_result_map[node_info.result_node.get_friendly_name()]
        node_info.input_values_full_precision = []
        for input_node in node_info.input_nodes:
            if input_node.get_type_name() == "Constant" or is_constant_path(input_node):
                # If input is constant, retrieve its value from model
                input_value = get_const_value_from_ovmodel(input_node)
            else:
                # If input is not constant, retrieve its input from inference results
                input_value = friendly_name_to_result_map[node_info.input_result_nodes[input_node].get_friendly_name()]
            node_info.input_values_full_precision.append(input_value)


def infer_nodes(nodes_to_track: List[TrackedNodeInfo], device: str, precision: str) -> None:
    """
    Infer every tracked node separately, see infer_tracked_op.

    :param nodes_to_track: Tracked node descriptors with full precision input values collected
    :param device: Device to infer operations on
    :param precision: Inference precision hint to use
    """
    for node_info in nodes_to_track:
        infer_tracked_op(node_info, device, precision)


def infer_tracked_op(node_info: TrackedNodeInfo, device: str, precision: str) -> None:
    """
    Rebuild a tracked operation as a standalone model, infer it and store the result in node_value_half_precision.

    The standalone model gets one parameter per input of the tracked node and is fed with the input values collected
    during the full precision inference.

    :param node_info: Tracked node descriptor with input_values_full_precision filled
    :param device: Device to compile the standalone model for
    :param precision: Value of the inference precision hint
    """
    parameters = []
    inputs = []
    input_values = node_info.input_values_full_precision
    for input_value in input_values:
        parameter = Parameter(get_element_type(input_value.dtype), ov.PartialShape(input_value.shape))
        if input_value.dtype == np.float16:
            # Convert f16 weight to f32
            convert_node = opset.convert(parameter, "f32")
            inputs.append(convert_node)
        else:
            inputs.append(parameter)
        parameters.append(parameter)

    node = node_info.node
    try:
        call_attributes = node.get_attributes()
        # Below are some op workarounds
        if node.get_type_name() == "Divide" and "m_pythondiv" in call_attributes:
            del call_attributes["m_pythondiv"]
        if node.get_type_name() == "Broadcast" and "mode" in call_attributes:
            call_attributes["broadcast_spec"] = call_attributes["mode"]
            del call_attributes["mode"]

        if node.get_type_name() == "Concat":
            new_op = OPERATION_TYPE_MAP[node.get_type_name()](inputs, **call_attributes)
        else:
            new_op = OPERATION_TYPE_MAP[node.get_type_name()](*inputs, **call_attributes)

        ov_model = ov.Model([new_op], parameters=parameters)
        exec_net = ov.Core().compile_model(ov_model, device, config={"INFERENCE_PRECISION_HINT": precision})
        request = exec_net.create_infer_request()
        result = request.infer(input_values, share_inputs=True, share_outputs=True)
    except Exception:
        # Print the failing operation before propagating the error unchanged
        print(
            "Operation inference error",
            node.get_type_name(),
            node.get_friendly_name(),
            inputs,
            node.get_attributes(),
        )
        raise

    node_info.node_value_half_precision = result[0]
    assert len(result) == 1


def is_model_partially_upcasted(model) -> bool:
    """
    Check whether any operation of a supported type carries the full precision runtime info flag.

    :param model: Model to check
    :return: True if at least one node with a type from OPERATION_TYPE_MAP has the flag; False otherwise.
    """
    for node in model.get_ordered_ops():
        if node.get_type_name() not in OPERATION_TYPE_MAP:
            continue
        if ORIGINAL_PRECISION_RT_INFO_NAME in node.get_rt_info().keys():
            return True
    return False


def mark_nodes_to_upcast_to_fp32(model: ov.Model, nodes_with_errors: List[str]) -> None:
    """
    Set the full precision runtime info flag on nodes with the given friendly names. The model is modified in place.

    :param model: Model to mark nodes in
    :param nodes_with_errors: Friendly names of nodes to mark; every name must be present in the model
    """
    nodes_to_mark = set(nodes_with_errors)
    for node in model.get_ordered_ops():
        if node.get_friendly_name() in nodes_to_mark:
            node.get_rt_info()[ORIGINAL_PRECISION_RT_INFO_NAME] = ""
            nodes_to_mark.remove(node.get_friendly_name())
    # Names left over were not found in the model
    assert len(nodes_to_mark) == 0, nodes_to_mark


def compute_snr(x, y):
    """
    Compute Signal-to-Noise Ratio in decibels: 20 * log10(||x|| / ||x - y||).

    :param x: Original value (full precision)
    :param y: Value with noise (half precision)
    :return: SNR value. NaN values of inputs are replaced before the computation and positive infinities are clamped
        to the largest float32 value, which is also returned when both signal and noise norms are zero.
    :raises RuntimeError: If x and y have a different number of elements.
    """
    x, y = x.astype(np.float32), y.astype(np.float32)
    max_value = np.finfo(np.float32).max

    if np.prod(x.shape) != np.prod(y.shape):
        raise RuntimeError(f"Shape mismatch: {x.shape}, {y.shape}.")

    x = np.nan_to_num(x, posinf=max_value)
    y = np.nan_to_num(y, posinf=max_value)

    Ps = np.linalg.norm(x)
    Pn = np.nan_to_num(np.linalg.norm(x - y), posinf=max_value)

    if Ps == Pn == 0.0:
        return max_value

    # Zero noise or zero signal is an expected situation here: the resulting infinities are
    # handled by nan_to_num below, so there is no point in emitting NumPy runtime warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        snr = np.nan_to_num(20 * np.log10(Ps / Pn), posinf=max_value)
    return snr