"""Export the ControlNeXt + IP-Adapter SDXL UNet to ONNX.

Steps performed at import/run time:

1. Load the base UNet and overlay the ControlNeXt UNet weights.
2. Wrap the UNet with ``IPAdapterXL`` and take the UNet it exposes.
3. Trace the UNet with dummy inputs and export it to ``unet/model.onnx``.
4. Reload the exported graph and re-save it to ``unet_optimize/model.onnx``
   with all tensors stored in a single external ``weights.pb`` file.
"""
from pathlib import Path

import onnx
import torch
from safetensors.torch import load_file

from ip_adapter import IPAdapterXL
from models.unet import UNet2DConditionModel

output_path = Path('/home/new_onnx/unet')

# --- Model construction -----------------------------------------------------
unet = UNet2DConditionModel.from_pretrained(
    "neta-art/neta-xl-2.0",
    subfolder="unet",
)

# strict=False: the checkpoint is applied on top of the pretrained weights
# without requiring its keys to match the model's state dict exactly.
state_dict = load_file('/home/ControlNeXt/ControlNeXt-SDXL/unet.safetensors')
unet.load_state_dict(state_dict, strict=False)

image_encoder_path = "h94/IP-Adapter"
ip_ckpt = "h94/IP-Adapter"
device = 'cpu'

ip_model = IPAdapterXL(unet, image_encoder_path, ip_ckpt, device, num_tokens=4)
unet = ip_model.unet

# --- Dummy inputs used for tracing ------------------------------------------
sample = torch.randn((1, 4, 128, 128))
timestep = torch.rand(1, dtype=torch.float32)
# NOTE(review): 81 is presumably 77 text tokens + 4 image tokens
# (num_tokens=4 above) -- confirm against the IP-Adapter pipeline.
encoder_hidden_state = torch.randn((1, 81, 2048))
mid_block_additional_residual_scale = torch.tensor([1], dtype=torch.float32)
mid_block_additional_residual = torch.randn((1, 320, 128, 128), dtype=torch.float32)

# Positional order must match `input_names` passed to the exporter below.
dummy_inputs = (
    sample,
    timestep,
    encoder_hidden_state,
    mid_block_additional_residual,
    mid_block_additional_residual_scale,
)

# --- ONNX export ------------------------------------------------------------
onnx_output_path = output_path / "unet" / "model.onnx"
# The exporter writes into this directory but does not create it, so make
# sure it exists before exporting.
onnx_output_path.parent.mkdir(parents=True, exist_ok=True)

torch.onnx.export(
    unet,
    dummy_inputs,
    str(onnx_output_path),  # pass the path as a string to ensure compatibility
    export_params=True,
    opset_version=18,
    do_constant_folding=True,
    input_names=['sample', 'timestep', 'encoder_hidden_state', 'control_out', 'control_scale'],
    output_names=['predict_noise'],
    # Only the axes listed here are dynamic; `timestep` and `control_scale`
    # keep the fixed shapes of their dummy inputs.
    dynamic_axes={
        "sample": {0: "B"},
        "encoder_hidden_state": {0: "B", 1: "1B", 2: '2B'},
        "control_out": {0: "B"},
        "predict_noise": {0: 'B'},
    },
)

# --- Re-save with a single external weights file ----------------------------
unet_opt_graph = onnx.load(str(onnx_output_path))

unet_optimize = output_path / "unet_optimize" / "model.onnx"
unet_optimize.parent.mkdir(parents=True, exist_ok=True)

# NOTE(review): drop any weights file left by a previous run so the new
# tensors are not written into stale data (onnx is believed to append to an
# existing external-data file rather than replace it -- confirm).
(unet_optimize.parent / "weights.pb").unlink(missing_ok=True)

onnx.save_model(
    unet_opt_graph,
    str(unet_optimize),
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="weights.pb",
)