|
import torch |
|
import os |
|
from copy import deepcopy |
|
|
|
class ModelExporter(torch.nn.Module):
    """Wrap a detection model so it can be traced and exported to ONNX.

    The wrapped model must provide ``fuse()`` (whose return value is kept as
    the model to export) and ``predict(x, txt_feats=...)``, which is what
    :meth:`forward` delegates to.
    """

    def __init__(self, yoloModel, device='cpu'):
        """Prepare a frozen, inference-only copy of ``yoloModel``.

        Args:
            yoloModel: Source model. It is deep-copied, so the caller's
                instance is not modified.
            device: Device the copy is moved to and on which the dummy export
                inputs are later created.
        """
        super().__init__()

        # Work on a private copy: freezing and fusing must not leak back into
        # the model the caller is still holding.
        model = deepcopy(yoloModel).to(device)
        for param in model.parameters():
            param.requires_grad = False
        model.eval()
        model.float()

        self.model = model.fuse()
        self.device = device

    def forward(self, x, txt_feats):
        """Run the wrapped model's ``predict`` on an image batch and text features."""
        return self.model.predict(x, txt_feats=txt_feats)

    @staticmethod
    def _dummy_inputs(img_width, img_height, num_classes, device, txt_feat_dim=512):
        """Build the random ``(images, txt_feats)`` pair used to trace the model.

        Images use PyTorch's (N, C, H, W) layout, so height comes before
        width. Both tensors are created directly on ``device``.
        """
        images = torch.randn(
            1, 3, img_height, img_width, requires_grad=False, device=device
        )
        txt_feats = torch.randn(
            1, num_classes, txt_feat_dim, requires_grad=False, device=device
        )
        return images, txt_feats

    def export(self, output_dir, model_name, img_width, img_height, num_classes,
               *, txt_feat_dim=512):
        """Export the wrapped model to ``<output_dir>/<model_name>.onnx``.

        Args:
            output_dir: Directory to write into; created if missing. An empty
                string writes to the current working directory.
            model_name: File name without the ``.onnx`` extension.
            img_width: Width of the dummy input image, in pixels.
            img_height: Height of the dummy input image, in pixels.
            num_classes: Size of the second dimension of the dummy text
                features.
            txt_feat_dim: Size of the last dimension of the dummy text
                features. Defaults to 512, the value previously hard-coded;
                presumably the text-embedding width -- confirm against the
                text encoder in use.

        Returns:
            Path of the written ONNX file.
        """
        x, txt_feats = self._dummy_inputs(
            img_width, img_height, num_classes, self.device, txt_feat_dim
        )
        print(x.shape, txt_feats.shape)

        # os.makedirs("") raises, so only create a directory when one is given.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, model_name + ".onnx")

        with torch.no_grad():
            torch.onnx.export(
                self,
                (x, txt_feats),
                output_path,
                do_constant_folding=True,
                opset_version=17,
                input_names=["images", "txt_feats"],
                output_names=["output"],
            )

        return output_path