"""Export a DiffTransformerLLM checkpoint to ONNX format.

Loads a .pt checkpoint on CPU, builds a dummy token batch plus an additive
causal attention mask, and traces the model with torch.onnx.export using
dynamic batch/sequence axes so the exported graph accepts variable shapes.
"""

import argparse
import os

import torch

from inference.inference import load_model
from inference.model import DiffTransformerLLM


def main():
    """Parse CLI args, load the checkpoint, and write the ONNX file.

    Raises:
        FileNotFoundError: if the checkpoint path does not exist.
    """
    parser = argparse.ArgumentParser(description="Export DiffTransformerLLM to ONNX")
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="Path to model checkpoint (.pt)"
    )
    parser.add_argument(
        "--onnx_path", type=str, default="model.onnx", help="Output ONNX file path"
    )
    parser.add_argument(
        "--seq_len", type=int, default=32, help="Dummy input sequence length"
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=259,
        help="Vocabulary size used to bound dummy token ids (default: 259)",
    )
    args = parser.parse_args()

    # Fail fast with a clear message rather than a deep stack trace from load_model.
    if not os.path.isfile(args.checkpoint):
        raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}")

    # Export is traced on CPU in full precision; fp16/quantization would
    # bake reduced-precision weights into the graph.
    device = torch.device("cpu")
    print(f"Loading model from {args.checkpoint}")
    model = load_model(args.checkpoint, device=device, fp16=False, quantize=False)
    model.eval()

    # Dummy inputs: random token ids in [0, vocab_size) with a fixed trace shape.
    # The dynamic_axes declaration below lets the exported graph accept other
    # batch sizes / sequence lengths at runtime.
    batch_size = 1
    seq_len = args.seq_len
    input_ids = torch.randint(
        0, args.vocab_size, (batch_size, seq_len), dtype=torch.long
    )

    # Additive causal mask: 0.0 on/below the diagonal, -inf strictly above it,
    # so future positions are masked out in attention. Passed as a dynamic
    # input so callers can supply masks for other sequence lengths.
    causal_mask = torch.triu(
        torch.ones(1, seq_len, seq_len, dtype=torch.bool), diagonal=1
    )
    attn_mask = torch.zeros(1, seq_len, seq_len, dtype=torch.float32)
    attn_mask.masked_fill_(causal_mask, float("-inf"))

    print(f"Exporting to ONNX: {args.onnx_path}")
    # no_grad: tracing needs no autograd bookkeeping; avoids wasted memory/time.
    with torch.no_grad():
        torch.onnx.export(
            model,
            (input_ids, attn_mask),
            args.onnx_path,
            input_names=["input_ids", "attn_mask"],
            output_names=["logits"],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "seq_len"},
                "attn_mask": {0: "batch_size", 1: "seq_len", 2: "seq_len"},
                "logits": {0: "batch_size", 1: "seq_len"},
            },
            opset_version=17,
            do_constant_folding=True,
        )
    print(f"ONNX export complete: {args.onnx_path}")


if __name__ == "__main__":
    main()