# DAT-Byte-Demo / export_onnx.py
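"""Export a DiffTransformerLLM checkpoint to ONNX.

The exported graph takes byte-level token IDs and an additive attention mask
("input_ids" and "attn_mask") and returns "logits", with the batch and
sequence dimensions marked as dynamic axes.

Example invocation (the checkpoint path is illustrative):

    python export_onnx.py --checkpoint checkpoints/dat_byte.pt --onnx_path model.onnx --seq_len 32
"""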
import argparse
import os

import torch

from inference.inference import load_model
from inference.model import DiffTransformerLLM


def main():
parser = argparse.ArgumentParser(description="Export DiffTransformerLLM to ONNX")
parser.add_argument(
"--checkpoint", type=str, required=True, help="Path to model checkpoint (.pt)"
)
parser.add_argument(
"--onnx_path", type=str, default="model.onnx", help="Output ONNX file path"
)
parser.add_argument(
"--seq_len", type=int, default=32, help="Dummy input sequence length"
)
args = parser.parse_args()

    device = torch.device("cpu")
print(f"Loading model from {args.checkpoint}")
model = load_model(args.checkpoint, device=device, fp16=False, quantize=False)
model.eval()

    # Prepare dummy input
batch_size = 1
seq_len = args.seq_len
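    # Byte-level dummy tokens; the vocabulary size of 259 presumably covers the
    # 256 raw byte values plus a few special tokens.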
input_ids = torch.randint(0, 259, (batch_size, seq_len), dtype=torch.long)
# Create a dummy causal mask. This will be a dynamic input to the ONNX model.
causal_mask = torch.triu(
torch.ones(1, seq_len, seq_len, dtype=torch.bool), diagonal=1
)
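    # Convert the boolean mask to an additive float mask: 0 where attention is
    # allowed, -inf on future positions.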
attn_mask = torch.zeros(1, seq_len, seq_len, dtype=torch.float32)
attn_mask.masked_fill_(causal_mask, float("-inf"))

    # Export to ONNX
print(f"Exporting to ONNX: {args.onnx_path}")
torch.onnx.export(
model,
(input_ids, attn_mask),
args.onnx_path,
input_names=["input_ids", "attn_mask"],
output_names=["logits"],
dynamic_axes={
"input_ids": {0: "batch_size", 1: "seq_len"},
"attn_mask": {0: "batch_size", 1: "seq_len", 2: "seq_len"},
"logits": {0: "batch_size", 1: "seq_len"},
},
opset_version=17,
do_constant_folding=True,
)
print(f"ONNX export complete: {args.onnx_path}")


if __name__ == "__main__":
main()
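
# A minimal smoke test of the exported graph, sketched as a comment; it assumes
# onnxruntime and numpy are installed and is not part of the export flow:
#
#     import numpy as np
#     import onnxruntime as ort
#
#     sess = ort.InferenceSession("model.onnx")
#     ids = np.random.randint(0, 259, size=(1, 8), dtype=np.int64)
#     mask = np.triu(np.full((1, 8, 8), float("-inf"), dtype=np.float32), k=1)
#     (logits,) = sess.run(["logits"], {"input_ids": ids, "attn_mask": mask})
#     print(logits.shape)  # expected (1, 8, vocab_size)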