import types
import warnings
from functools import wraps
from pathlib import Path
from typing import Optional, Tuple

import openvino as ov
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast

try:
    from optimum.exporters.openvino.stateful import make_stateful
    from optimum.exporters.openvino.stateful import fuse_cache_reorder
except ImportError:
    warnings.warn("We recommend updating optimum-intel to get optimal performance")
    make_stateful = None
    fuse_cache_reorder = None


def register_configs():
    """Map minicpm and qwen2 to the llama export config so TasksManager can export them."""
    from optimum.exporters.tasks import TasksManager

    TasksManager._SUPPORTED_MODEL_TYPE["minicpm"] = TasksManager._SUPPORTED_MODEL_TYPE["llama"]
    TasksManager._SUPPORTED_MODEL_TYPE["qwen2"] = TasksManager._SUPPORTED_MODEL_TYPE["llama"]
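
# A minimal sketch of the intended effect (the model id below is a hypothetical
# example, not part of this module): with the llama config reused for these
# architectures, the regular optimum-intel export path should accept them as well.
#
#     from optimum.intel.openvino import OVModelForCausalLM
#
#     register_configs()
#     ov_model = OVModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B", export=True)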


def patch_stateful(ov_model, model_type):
    """Hide KV-cache inputs and outputs inside the model as internal state."""
    key_value_input_names = [
        key.get_any_name() for key in ov_model.inputs if any("key_values" in key_name for key_name in key.get_names())
    ]
    key_value_output_names = [
        key.get_any_name() for key in ov_model.outputs if any("present" in key_name for key_name in key.get_names())
    ]
    not_kv_inputs = [
        input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names())
    ]
    if not key_value_input_names or not key_value_output_names:
        return
    # ChatGLM keeps its KV cache sequence-first, so the batch dimension is 1 there
    batch_dim = 1 if model_type == "chatglm" else 0
    num_attention_heads = 1

    fuse_cache_reorder(ov_model, not_kv_inputs, key_value_input_names, batch_dim)
    make_stateful(
        ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None
    )
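
# After make_stateful the KV-cache ports are removed from the model signature and the
# cache is kept as internal state of the infer request. A consumption sketch, assuming
# an IR produced by the converters below (input names per the conversion code):
#
#     compiled = ov.Core().compile_model("model_dir/openvino_model.xml", "CPU")
#     request = compiled.create_infer_request()
#     request.reset_state()  # start a fresh sequence
#     results = request.infer({"input_ids": input_ids, "attention_mask": attention_mask})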


def flattenize_inputs(inputs):
    """
    Helper function for flattening nested inputs, e.g. [t1, (t2, None)] -> [t1, t2]
    """
    flatten_inputs = []
    for input_data in inputs:
        if input_data is None:
            continue
        if isinstance(input_data, (list, tuple)):
            flatten_inputs.extend(flattenize_inputs(input_data))
        else:
            flatten_inputs.append(input_data)
    return flatten_inputs


def cleanup_torchscript_cache():
    """
    Helper for removing cached model representation
    """
    torch._C._jit_clear_class_registry()
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    torch.jit._state._clear_class_state()
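
# ov.convert_model traces the PyTorch module through TorchScript, and repeated
# conversions in a single process accumulate cached class state. Each converter in
# this module therefore follows the same teardown pattern after saving:
#
#     ov.save_model(ov_model, ov_out_path)
#     del ov_model
#     cleanup_torchscript_cache()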


def convert_mpt(pt_model: torch.nn.Module, model_path: Path):
    """
    MPT model conversion function

    Params:
      pt_model: PyTorch model
      model_path: path for saving model
    Returns:
      None
    """
    ov_out_path = Path(model_path) / "openvino_model.xml"
    pt_model.config.save_pretrained(ov_out_path.parent)
    pt_model.config.use_cache = True
    outs = pt_model(
        input_ids=torch.ones((1, 10), dtype=torch.long),
        attention_mask=torch.ones((1, 10), dtype=torch.long),
    )
    inputs = ["input_ids"]
    outputs = ["logits"]

    dynamic_shapes = {"input_ids": {0: "batch_size", 1: "seq_len"}, "attention_mask": {0: "batch_size", 1: "seq_len"}}
    for idx in range(len(outs.past_key_values)):
        inputs.extend([f"past_key_values.{idx}.key", f"past_key_values.{idx}.value"])
        # note: keys and values use different sequence dimensions in the MPT cache layout
        dynamic_shapes[inputs[-1]] = {0: "batch_size", 2: "past_sequence + sequence"}
        dynamic_shapes[inputs[-2]] = {0: "batch_size", 3: "past_sequence + sequence"}
        outputs.extend([f"present.{idx}.key", f"present.{idx}.value"])

    inputs.append("attention_mask")
    dummy_inputs = {
        "input_ids": torch.ones((1, 2), dtype=torch.long),
        "past_key_values": outs.past_key_values,
        "attention_mask": torch.ones((1, 12), dtype=torch.long),
    }
    pt_model.config.torchscript = True
    orig_forward = pt_model.forward

    # TorchScript tracing cannot handle the tuple of past_key_values directly, so the
    # forward is wrapped to convert it to a list on the way in and back on the way out
    @wraps(orig_forward)
    def ts_patched_forward(
        input_ids: torch.Tensor,
        past_key_values: Tuple[Tuple[torch.Tensor]],
        attention_mask: torch.Tensor,
    ):
        pkv_list = list(past_key_values)
        outs = orig_forward(input_ids=input_ids, past_key_values=pkv_list, attention_mask=attention_mask)
        return (outs.logits, tuple(outs.past_key_values))

    pt_model.forward = ts_patched_forward
    ov_model = ov.convert_model(pt_model, example_input=dummy_inputs)
    pt_model.forward = orig_forward
    for inp_name, m_input, input_data in zip(inputs, ov_model.inputs, flattenize_inputs(dummy_inputs.values())):
        input_node = m_input.get_node()
        if input_node.element_type == ov.Type.dynamic:
            m_input.get_node().set_element_type(ov.Type.f32)
        shape = list(input_data.shape)
        if inp_name in dynamic_shapes:
            for k in dynamic_shapes[inp_name]:
                shape[k] = -1
        input_node.set_partial_shape(ov.PartialShape(shape))
        m_input.get_tensor().set_names({inp_name})

    for out, out_name in zip(ov_model.outputs, outputs):
        out.get_tensor().set_names({out_name})

    ov_model.validate_nodes_and_infer_types()
    if make_stateful is not None:
        patch_stateful(ov_model, "mpt")
    ov.save_model(ov_model, ov_out_path)
    del ov_model
    cleanup_torchscript_cache()
    del pt_model


def convert_baichuan(pt_model: torch.nn.Module, model_path: Path):
    """
    Baichuan model conversion function

    Params:
      pt_model: PyTorch model
      model_path: path for saving model
    Returns:
      None
    """
    ov_out_path = Path(model_path) / "openvino_model.xml"
    pt_model.config.save_pretrained(ov_out_path.parent)
    pt_model.config.use_cache = True
    outs = pt_model(
        input_ids=torch.ones((1, 10), dtype=torch.long),
        attention_mask=torch.ones((1, 10), dtype=torch.long),
    )
    inputs = ["input_ids", "attention_mask"]
    outputs = ["logits"]

    dynamic_shapes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
    }
    for idx in range(len(outs.past_key_values)):
        inputs.extend([f"past_key_values.{idx}.key", f"past_key_values.{idx}.value"])
        dynamic_shapes[inputs[-1]] = {0: "batch_size", 2: "past_sequence + sequence"}
        dynamic_shapes[inputs[-2]] = {0: "batch_size", 2: "past_sequence + sequence"}
        outputs.extend([f"present.{idx}.key", f"present.{idx}.value"])

    dummy_inputs = {
        "input_ids": torch.ones((1, 2), dtype=torch.long),
        "attention_mask": torch.ones((1, 12), dtype=torch.long),
        "past_key_values": outs.past_key_values,
    }
    pt_model.config.torchscript = True
    ov_model = ov.convert_model(pt_model, example_input=dummy_inputs)
    for inp_name, m_input, input_data in zip(inputs, ov_model.inputs, flattenize_inputs(dummy_inputs.values())):
        input_node = m_input.get_node()
        if input_node.element_type == ov.Type.dynamic:
            m_input.get_node().set_element_type(ov.Type.f32)
        shape = list(input_data.shape)
        if inp_name in dynamic_shapes:
            for k in dynamic_shapes[inp_name]:
                shape[k] = -1
        input_node.set_partial_shape(ov.PartialShape(shape))
        m_input.get_tensor().set_names({inp_name})

    for out, out_name in zip(ov_model.outputs, outputs):
        out.get_tensor().set_names({out_name})

    ov_model.validate_nodes_and_infer_types()
    if make_stateful is not None:
        patch_stateful(ov_model, "baichuan")
    ov.save_model(ov_model, ov_out_path)
    del ov_model
    cleanup_torchscript_cache()
    del pt_model


@torch.jit.script_if_tracing
def _chatglm2_get_context_layer(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor):
    mask = torch.zeros((query_layer.shape[-2], key_layer.shape[-2]), dtype=query_layer.dtype)
    if query_layer.shape[2] == key_layer.shape[2]:
        # prefill step: apply a causal mask; at single-token decode no mask is needed
        tmp_mask = torch.ones((query_layer.shape[-2], key_layer.shape[-2]), dtype=torch.bool).triu(diagonal=1)
        mask.masked_fill_(tmp_mask, float("-inf"))

    context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attn_mask=mask)
    return context_layer
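
# torch.jit.script_if_tracing scripts the helper above when it is hit during tracing,
# so the shape-dependent branch (prefill with a causal mask vs. single-token decode
# with no mask) stays as real control flow in the exported graph instead of being
# frozen to the shapes of the tracing example.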


def _core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask):
    query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
    if attention_mask is None:
        context_layer = _chatglm2_get_context_layer(query_layer, key_layer, value_layer)
    else:
        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer, key_layer, value_layer, attention_mask
        )
    context_layer = context_layer.permute(2, 0, 1, 3)
    new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
    context_layer = context_layer.reshape(*new_context_layer_shape)

    return context_layer


@torch.jit.script_if_tracing
def _get_chatglm_attention_mask(input_ids, past_key):
    mask = torch.zeros((input_ids.shape[1], past_key.shape[0] + input_ids.shape[1]), dtype=past_key.dtype)
    if past_key.shape[0] == 0:
        tmp_mask = torch.ones((input_ids.shape[1], past_key.shape[0] + input_ids.shape[1]), dtype=torch.bool).triu(diagonal=1)
        mask.masked_fill_(tmp_mask, float("-inf"))
    return mask


def _chatglm_transformer_forward(
    self,
    input_ids,
    position_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.BoolTensor] = None,
    full_attention_mask: Optional[torch.BoolTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    batch_size, seq_length = input_ids.shape

    if inputs_embeds is None:
        inputs_embeds = self.embedding(input_ids)

    if self.pre_seq_len is not None:
        if past_key_values is None:
            past_key_values = self.get_prompt(
                batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
            )
        if attention_mask is not None:
            attention_mask = torch.cat(
                [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
            )

    if full_attention_mask is None:
        if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
            full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
        elif past_key_values is not None:
            full_attention_mask = torch.ones(
                batch_size, seq_length, seq_length, device=input_ids.device, dtype=torch.float
            ) * float("-inf")
            full_attention_mask.triu_(diagonal=1)
            past_length = 0
            if past_key_values:
                past_length = past_key_values[0][0].shape[0]
            if past_length:
                full_attention_mask = torch.cat(
                    (torch.zeros(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask),
                    dim=-1,
                )
            full_attention_mask.unsqueeze_(1)

    rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
    if position_ids is not None:
        rotary_pos_emb = rotary_pos_emb[position_ids]
    else:
        rotary_pos_emb = rotary_pos_emb[None, :seq_length]
    rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

    hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
        inputs_embeds,
        full_attention_mask,
        rotary_pos_emb=rotary_pos_emb,
        kv_caches=past_key_values,
        use_cache=use_cache,
        output_hidden_states=output_hidden_states,
    )

    if not return_dict:
        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
    )


def _patch_chatglm_forward(model: "PreTrainedModel"):
    model.transformer.forward = types.MethodType(_chatglm_transformer_forward, model.transformer)
    for block in model.transformer.encoder.layers:
        block.self_attention.core_attention.forward = types.MethodType(
            _core_attention_forward, block.self_attention.core_attention
        )


def convert_chatglm(pt_model: torch.nn.Module, model_path: Path):
    """
    ChatGLM model conversion function

    Params:
      pt_model: PyTorch model
      model_path: path for saving model
    Returns:
      None
    """
    _patch_chatglm_forward(pt_model)
    ov_out_path = Path(model_path) / "openvino_model.xml"
    pt_model.config.save_pretrained(ov_out_path.parent)
    pt_model.config.use_cache = True
    outs = pt_model(
        input_ids=torch.ones((1, 10), dtype=torch.long),
        position_ids=torch.arange(0, 10, dtype=torch.long).unsqueeze(0),
    )
    inputs = ["input_ids"]
    outputs = ["logits"]

    dynamic_shapes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "position_ids": {0: "batch_size", 1: "seq_len"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
    }
    inputs += ["position_ids", "attention_mask"]
    for idx in range(len(outs.past_key_values)):
        inputs.extend([f"past_key_values.{idx}.key", f"past_key_values.{idx}.value"])
        # ChatGLM KV cache is sequence-first: (past_sequence + sequence, batch, ...)
        dynamic_shapes[inputs[-1]] = {0: "past_sequence + sequence", 1: "batch_size"}
        dynamic_shapes[inputs[-2]] = {0: "past_sequence + sequence", 1: "batch_size"}
        outputs.extend([f"present.{idx}.key", f"present.{idx}.value"])

    dummy_inputs = {
        "input_ids": torch.ones((1, 1), dtype=torch.long),
        "position_ids": torch.tensor([[10]], dtype=torch.long),
        "attention_mask": torch.ones((1, 11), dtype=torch.long),
        "past_key_values": outs.past_key_values,
    }
    pt_model.config.torchscript = True
    ov_model = ov.convert_model(pt_model, example_input=dummy_inputs)
    for inp_name, m_input, input_data in zip(inputs, ov_model.inputs, flattenize_inputs(dummy_inputs.values())):
        input_node = m_input.get_node()
        if input_node.element_type == ov.Type.dynamic:
            m_input.get_node().set_element_type(ov.Type.f32)
        shape = list(input_data.shape)
        if inp_name in dynamic_shapes:
            for k in dynamic_shapes[inp_name]:
                shape[k] = -1
        input_node.set_partial_shape(ov.PartialShape(shape))
        m_input.get_tensor().set_names({inp_name})

    for out, out_name in zip(ov_model.outputs, outputs):
        out.get_tensor().set_names({out_name})

    ov_model.validate_nodes_and_infer_types()
    if make_stateful is not None:
        patch_stateful(ov_model, "chatglm")
    ov.save_model(ov_model, ov_out_path)
    del ov_model
    cleanup_torchscript_cache()
    del pt_model


def convert_gemma(pt_model: torch.nn.Module, model_path: Path):
    """
    Gemma model conversion function

    Params:
      pt_model: PyTorch model
      model_path: path for saving model
    Returns:
      None
    """
    ov_out_path = Path(model_path) / "openvino_model.xml"
    pt_model.config.save_pretrained(ov_out_path.parent)
    pt_model.config.use_cache = True
    outs = pt_model(input_ids=torch.ones((2, 10), dtype=torch.long))
    inputs = ["input_ids"]
    outputs = ["logits"]

    dynamic_shapes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "position_ids": {0: "batch_size", 1: "seq_len"},
    }
    inputs += ["attention_mask", "position_ids"]
    for idx in range(len(outs.past_key_values)):
        inputs.extend([f"past_key_values.{idx}.key", f"past_key_values.{idx}.value"])
        dynamic_shapes[inputs[-1]] = {0: "batch_size", 2: "past_sequence + sequence"}
        dynamic_shapes[inputs[-2]] = {0: "batch_size", 2: "past_sequence + sequence"}
        outputs.extend([f"present.{idx}.key", f"present.{idx}.value"])

    dummy_inputs = {
        "input_ids": torch.ones((2, 2), dtype=torch.long),
        "attention_mask": torch.ones((2, 12), dtype=torch.long),
        "position_ids": torch.tensor([[10, 11], [10, 11]], dtype=torch.long),
        "past_key_values": outs.past_key_values,
    }
    pt_model.config.torchscript = True
    ov_model = ov.convert_model(pt_model, example_input=dummy_inputs)
    for inp_name, m_input, input_data in zip(inputs, ov_model.inputs, flattenize_inputs(dummy_inputs.values())):
        input_node = m_input.get_node()
        if input_node.element_type == ov.Type.dynamic:
            m_input.get_node().set_element_type(ov.Type.f32)
        shape = list(input_data.shape)
        if inp_name in dynamic_shapes:
            for k in dynamic_shapes[inp_name]:
                shape[k] = -1
        input_node.set_partial_shape(ov.PartialShape(shape))
        m_input.get_tensor().set_names({inp_name})

    for out, out_name in zip(ov_model.outputs, outputs):
        out.get_tensor().set_names({out_name})

    ov_model.validate_nodes_and_infer_types()
    if make_stateful is not None:
        patch_stateful(ov_model, "gemma")
    ov.save_model(ov_model, ov_out_path)
    del ov_model
    cleanup_torchscript_cache()
    del pt_model


def convert_mpnet(pt_model: torch.nn.Module, model_path: Path):
    """MPNet embedding model conversion function"""
    ov_out_path = Path(model_path) / "openvino_model.xml"
    dummy_inputs = {
        "input_ids": torch.ones((1, 10), dtype=torch.long),
        "attention_mask": torch.ones((1, 10), dtype=torch.long),
    }
    ov_model = ov.convert_model(pt_model, example_input=dummy_inputs)
    ov.save_model(ov_model, ov_out_path)


def convert_bert(pt_model: torch.nn.Module, model_path: Path):
    """BERT embedding model conversion function"""
    ov_out_path = Path(model_path) / "openvino_model.xml"
    dummy_inputs = {
        "input_ids": torch.ones((1, 10), dtype=torch.long),
        "attention_mask": torch.ones((1, 10), dtype=torch.long),
        "token_type_ids": torch.zeros((1, 10), dtype=torch.long),
    }
    ov_model = ov.convert_model(pt_model, example_input=dummy_inputs)
    ov.save_model(ov_model, ov_out_path)


converters = {
    # LLM converters
    "mpt": convert_mpt,
    "chatglm3": convert_chatglm,
    "baichuan2": convert_baichuan,
    "gemma": convert_gemma,
    # embedding model converters
    "all-mpnet-base-v2": convert_mpnet,
    "text2vec-large-chinese": convert_bert,
}
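
# A minimal dispatch sketch (the model id is a hypothetical example; ChatGLM and
# Baichuan checkpoints need trust_remote_code=True):
#
#     from transformers import AutoModelForCausalLM
#
#     pt_model = AutoModelForCausalLM.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
#     converters["chatglm3"](pt_model, Path("chatglm3-ov"))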