Spaces:
Sleeping
Sleeping
import torch | |
from inference.model import DiffTransformerLLM | |
from inference.inference import load_model | |
import argparse | |
import os | |
def main(argv=None):
    """Export a DiffTransformerLLM checkpoint to an ONNX file.

    Parses the command-line options, loads the checkpoint on CPU through
    ``load_model`` (with fp16 and quantization disabled), builds a dummy
    ``input_ids`` tensor plus an additive causal ``attn_mask``, and hands
    them to ``torch.onnx.export`` with dynamic batch/sequence axes.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, so running the file as a
            script behaves as before.

    Raises:
        SystemExit: Raised by argparse for invalid arguments, including a
            non-positive ``--seq_len`` or ``--vocab_size``.
    """
    parser = argparse.ArgumentParser(description="Export DiffTransformerLLM to ONNX")
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="Path to model checkpoint (.pt)"
    )
    parser.add_argument(
        "--onnx_path", type=str, default="model.onnx", help="Output ONNX file path"
    )
    parser.add_argument(
        "--seq_len", type=int, default=32, help="Dummy input sequence length"
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=259,
        help="Exclusive upper bound for the random dummy token ids",
    )
    args = parser.parse_args(argv)

    # Reject unusable sizes up front, before the (potentially slow) checkpoint
    # load, instead of failing later inside tensor construction or the export.
    if args.seq_len < 1:
        parser.error("--seq_len must be a positive integer")
    if args.vocab_size < 1:
        parser.error("--vocab_size must be a positive integer")

    device = torch.device("cpu")
    print(f"Loading model from {args.checkpoint}")
    model = load_model(args.checkpoint, device=device, fp16=False, quantize=False)
    model.eval()

    # Prepare dummy input
    batch_size = 1
    seq_len = args.seq_len
    input_ids = torch.randint(
        0, args.vocab_size, (batch_size, seq_len), dtype=torch.long
    )

    # Create a dummy causal mask. This will be a dynamic input to the ONNX model.
    # Positions strictly above the diagonal are set to -inf; all others stay 0.
    causal_mask = torch.triu(
        torch.ones(1, seq_len, seq_len, dtype=torch.bool), diagonal=1
    )
    attn_mask = torch.zeros(1, seq_len, seq_len, dtype=torch.float32)
    attn_mask.masked_fill_(causal_mask, float("-inf"))

    # Make sure the destination directory exists so the export does not fail
    # on a missing parent folder. A bare filename has no directory component.
    output_dir = os.path.dirname(args.onnx_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Export to ONNX
    print(f"Exporting to ONNX: {args.onnx_path}")
    torch.onnx.export(
        model,
        (input_ids, attn_mask),
        args.onnx_path,
        input_names=["input_ids", "attn_mask"],
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "seq_len"},
            "attn_mask": {0: "batch_size", 1: "seq_len", 2: "seq_len"},
            "logits": {0: "batch_size", 1: "seq_len"},
        },
        opset_version=17,
        do_constant_folding=True,
    )
    print(f"ONNX export complete: {args.onnx_path}")
# Run the exporter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()