File size: 1,402 Bytes
7b4bb4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import os
from copy import deepcopy

class ModelExporter(torch.nn.Module):
    """Wrap a detection model so it can be traced and written out as ONNX.

    On construction the given model is deep-copied (the caller's instance is
    left untouched), moved to ``device``, has gradients disabled on every
    parameter, is switched to eval mode, cast to float32, and finally passed
    through its own ``fuse()`` method; whatever ``fuse()`` returns is the
    module that gets exported.

    The wrapped model is expected to provide ``fuse()`` and
    ``predict(x, txt_feats=...)``.
    """

    # Size of the last dimension of the dummy ``txt_feats`` input.
    # NOTE(review): presumably the text-encoder embedding width -- confirm
    # against the model being exported.
    TXT_FEAT_DIM = 512
    # ONNX opset version requested from ``torch.onnx.export``.
    ONNX_OPSET = 17

    def __init__(self, yoloModel, device='cpu'):
        """Prepare an export-ready copy of ``yoloModel`` on ``device``.

        Args:
            yoloModel: Model to export; it is deep-copied, never modified.
            device: Device the copy (and the dummy inputs) live on.
        """
        super().__init__()
        model = deepcopy(yoloModel).to(device)
        for param in model.parameters():
            param.requires_grad = False
        model.eval()
        model.float()
        model = model.fuse()

        self.model = model
        self.device = device

    def forward(self, x, txt_feats):
        """Return ``self.model.predict(x, txt_feats=txt_feats)``."""
        return self.model.predict(x, txt_feats=txt_feats)

    def _make_dummy_inputs(self, img_width: int, img_height: int,
                           num_classes: int):
        """Build the random ``(images, txt_feats)`` pair used for tracing."""
        # PyTorch image tensors are laid out (N, C, H, W): height comes
        # before width. Passing (width, height) here would trace a transposed
        # input shape for any non-square export size.
        images = torch.randn(1, 3, img_height, img_width, device=self.device)
        txt_feats = torch.randn(1, num_classes, self.TXT_FEAT_DIM,
                                device=self.device)
        return images, txt_feats

    def export(self, output_dir: str, model_name: str, img_width: int,
               img_height: int, num_classes: int) -> str:
        """Trace the wrapped model and write ``<model_name>.onnx``.

        Args:
            output_dir: Directory for the ONNX file; created if missing.
            model_name: File name without the ``.onnx`` extension.
            img_width: Width of the dummy image input, in pixels.
            img_height: Height of the dummy image input, in pixels.
            num_classes: Size of dimension 1 of the dummy ``txt_feats`` input.

        Returns:
            Path of the written ONNX file.
        """
        x, txt_feats = self._make_dummy_inputs(img_width, img_height,
                                               num_classes)

        print(x.shape, txt_feats.shape)

        # Export model
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, model_name + ".onnx")
        with torch.no_grad():
            torch.onnx.export(
                self,
                (x, txt_feats),
                output_path,
                do_constant_folding=True,
                opset_version=self.ONNX_OPSET,
                input_names=["images", "txt_feats"],
                output_names=["output"],
            )

        return output_path