File size: 1,878 Bytes
adf0368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from inference.model import DiffTransformerLLM
from inference.inference import load_model
import argparse
import os


def main():
    """Export a DiffTransformerLLM checkpoint to an ONNX file.

    Parses command-line arguments, loads the checkpoint on CPU (fp16=False,
    quantize=False), builds a dummy ``(input_ids, attn_mask)`` pair, and
    exports the model with ``torch.onnx.export`` (opset 17) declaring dynamic
    batch and sequence axes.

    Command-line arguments:
        --checkpoint: Path to the model checkpoint (.pt). Required.
        --onnx_path: Output ONNX file path (default ``model.onnx``). Missing
            parent directories are created before export.
        --seq_len: Sequence length of the dummy export input (default 32;
            must be >= 1).
        --vocab_size: Exclusive upper bound for the random dummy token ids
            (default 259, the value previously hard-coded; must be >= 1).

    Invalid argument values terminate the program with an argparse usage
    error (exit status 2).
    """
    parser = argparse.ArgumentParser(description="Export DiffTransformerLLM to ONNX")
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="Path to model checkpoint (.pt)"
    )
    parser.add_argument(
        "--onnx_path", type=str, default="model.onnx", help="Output ONNX file path"
    )
    parser.add_argument(
        "--seq_len", type=int, default=32, help="Dummy input sequence length"
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=259,
        help="Exclusive upper bound for dummy token ids (default: 259)",
    )
    args = parser.parse_args()

    # Fail fast with a clear CLI error instead of an opaque error later:
    # torch.randint rejects an empty [0, 0) range, and a zero/negative
    # sequence length cannot produce a meaningful dummy input.
    if args.seq_len < 1:
        parser.error("--seq_len must be >= 1")
    if args.vocab_size < 1:
        parser.error("--vocab_size must be >= 1")

    device = torch.device("cpu")
    print(f"Loading model from {args.checkpoint}")
    model = load_model(args.checkpoint, device=device, fp16=False, quantize=False)
    model.eval()

    # Prepare dummy input. The token values only need to be valid ids;
    # presumably vocab_size must not exceed the model's embedding size —
    # TODO confirm against the model config.
    batch_size = 1
    seq_len = args.seq_len
    input_ids = torch.randint(
        0, args.vocab_size, (batch_size, seq_len), dtype=torch.long
    )

    # Additive causal mask: 0.0 on and below the diagonal, -inf strictly
    # above it. This will be a dynamic input to the ONNX model.
    causal_mask = torch.triu(
        torch.ones(1, seq_len, seq_len, dtype=torch.bool), diagonal=1
    )
    attn_mask = torch.zeros(1, seq_len, seq_len, dtype=torch.float32)
    attn_mask.masked_fill_(causal_mask, float("-inf"))

    # Make sure the destination directory exists so writing the .onnx file
    # does not fail on a fresh output path.
    out_dir = os.path.dirname(os.path.abspath(args.onnx_path))
    os.makedirs(out_dir, exist_ok=True)

    # Export to ONNX
    print(f"Exporting to ONNX: {args.onnx_path}")
    torch.onnx.export(
        model,
        (input_ids, attn_mask),
        args.onnx_path,
        input_names=["input_ids", "attn_mask"],
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "seq_len"},
            # NOTE(review): the dummy mask's leading dim is always 1, yet it
            # shares the "batch_size" symbol with input_ids — confirm whether
            # callers pass a per-batch mask or a broadcast (1, S, S) mask.
            "attn_mask": {0: "batch_size", 1: "seq_len", 2: "seq_len"},
            "logits": {0: "batch_size", 1: "seq_len"},
        },
        opset_version=17,
        do_constant_folding=True,
    )
    print(f"ONNX export complete: {args.onnx_path}")


if __name__ == "__main__":
    # Run the export only when this file is executed as a script, not when
    # it is imported as a module.
    main()