|
from models.unet import UNet2DConditionModel |
|
import torch |
|
from ip_adapter import IPAdapterXL |
|
from safetensors.torch import load_file |
|
import onnx |
|
from pathlib import Path |
|
|
|
# Root directory for every artifact this script writes; the exported model
# and its re-packaged copy are placed in subdirectories beneath it.
output_path = Path('/home/new_onnx/unet')
|
|
|
# Base UNet from the "neta-art/neta-xl-2.0" repo ("unet" subfolder), built
# with the project-local UNet2DConditionModel class from models.unet.
unet = UNet2DConditionModel.from_pretrained(
    "neta-art/neta-xl-2.0",
    subfolder="unet",
)

# Overlay the ControlNeXt-SDXL UNet weights on top of the base UNet.
# strict=False is kept so keys the checkpoint does not provide stay at their
# base values (presumably the checkpoint only covers part of the UNet --
# confirm against its contents). Unexpected keys, however, are tensors that
# were silently discarded (e.g. a key-prefix mismatch), so they are reported,
# and a load where nothing matched at all is an error: exporting the
# unmodified base UNet would otherwise go unnoticed.
state_dict = load_file('/home/ControlNeXt/ControlNeXt-SDXL/unet.safetensors')

load_result = unet.load_state_dict(state_dict, strict=False)
unexpected_keys = list(load_result.unexpected_keys)
matched_count = len(state_dict) - len(unexpected_keys)
if unexpected_keys:
    print(
        f"warning: {len(unexpected_keys)} of {len(state_dict)} ControlNeXt "
        f"UNet tensors matched no model parameter and were ignored, "
        f"e.g. {unexpected_keys[:5]}"
    )
if matched_count == 0:
    raise RuntimeError(
        "No ControlNeXt UNet weights were loaded into the model; "
        "refusing to export the unmodified base UNet."
    )
|
|
|
# Wrap the loaded UNet with the project-local IPAdapterXL (on CPU), then
# rebind `unet` to the adapter's UNet so that is what gets exported below.
# num_tokens=4 is presumably the number of image-prompt tokens the adapter
# appends to the text tokens -- confirm against ip_adapter.
# NOTE(review): the image encoder and the adapter checkpoint both point at
# the bare "h94/IP-Adapter" repo id. Whether this IPAdapterXL accepts a repo
# id (rather than a local directory / checkpoint file) cannot be confirmed
# from here -- verify.
device = 'cpu'
image_encoder_path = ip_ckpt = "h94/IP-Adapter"

ip_model = IPAdapterXL(unet, image_encoder_path, ip_ckpt, device, num_tokens=4)
unet = ip_model.unet
|
|
|
# Example tensors used only to trace the model during ONNX export. The tuple
# order is positional and must line up with input_names in the export call.
# The random draws occur in the order sample, timestep, encoder_hidden_state,
# residual.
sample = torch.randn(1, 4, 128, 128)
timestep = torch.rand(1, dtype=torch.float32)
encoder_hidden_state = torch.randn(1, 81, 2048)
mid_block_additional_residual_scale = torch.ones(1, dtype=torch.float32)
mid_block_additional_residual = torch.randn(1, 320, 128, 128, dtype=torch.float32)

dummy_inputs = (
    sample,
    timestep,
    encoder_hidden_state,
    mid_block_additional_residual,        # traced as "control_out"
    mid_block_additional_residual_scale,  # traced as "control_scale"
)
|
|
|
# Export the (IP-Adapter-wrapped) UNet to ONNX. Graph inputs map positionally
# onto dummy_inputs:
#   sample, timestep, encoder_hidden_state,
#   control_out   <- mid_block_additional_residual,
#   control_scale <- mid_block_additional_residual_scale
# and the single graph output is named predict_noise.
onnx_output_path = output_path / "unet" / "model.onnx"
# torch.onnx.export does not create missing parent directories.
onnx_output_path.parent.mkdir(parents=True, exist_ok=True)

# Only the forward graph is needed; disabling autograd avoids recording
# gradient history for every parameter of a large UNet while tracing.
with torch.no_grad():
    torch.onnx.export(
        unet,
        dummy_inputs,
        str(onnx_output_path),
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=['sample', 'timestep', 'encoder_hidden_state', 'control_out', 'control_scale'],
        output_names=['predict_noise'],
        # Only the axes listed here are symbolic; every other dimension
        # (e.g. the spatial dims of sample/control_out, and the shapes of
        # timestep and control_scale) is fixed to the dummy-input size.
        # NOTE(review): axis 2 of encoder_hidden_state (the hidden width) is
        # marked dynamic -- confirm the model actually accepts other widths.
        dynamic_axes={
            "sample": {0: "B"},
            "encoder_hidden_state": {0: "B", 1: "1B", 2: '2B'},
            "control_out": {0: "B"},
            "predict_noise": {0: 'B'},
        },
    )
|
|
|
# Re-serialize the exported model so that its initializers are stored in a
# single external file next to a new model.onnx. Despite the directory name,
# no graph optimization happens here -- this is purely a re-packaging step.
# onnx.load also reads any external tensor files referenced by the exported
# model (load_external_data defaults to True), so the whole model is held in
# memory at this point.
unet_opt_graph = onnx.load(str(onnx_output_path))

unet_optimize = output_path / "unet_optimize" / "model.onnx"
# onnx.save_model writes into this directory but does not create it.
unet_optimize.parent.mkdir(parents=True, exist_ok=True)

# `location` is resolved relative to the directory of unet_optimize. ONNX's
# external-data writer appends tensors at the end of an existing file, so a
# weights file left over from a previous run would keep growing with stale
# data; start from a clean file instead.
weights_location = "weights.pb"
(unet_optimize.parent / weights_location).unlink(missing_ok=True)

onnx.save_model(
    unet_opt_graph,
    str(unet_optimize),
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location=weights_location,
)
|
|