Spaces:
Runtime error
Runtime error
Ibai Gorordo
commited on
Commit
•
7b4bb4b
1
Parent(s):
bc68550
Add files
Browse files- ExportModel.py +40 -0
- app.py +40 -0
- requirements.txt +2 -0
ExportModel.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
from copy import deepcopy
|
4 |
+
|
5 |
+
class ModelExporter(torch.nn.Module):
|
6 |
+
def __init__(self, yoloModel, device='cpu'):
|
7 |
+
super(ModelExporter, self).__init__()
|
8 |
+
model = deepcopy(yoloModel).to(device)
|
9 |
+
for p in model.parameters():
|
10 |
+
p.requires_grad = False
|
11 |
+
model.eval()
|
12 |
+
model.float()
|
13 |
+
model = model.fuse()
|
14 |
+
|
15 |
+
self.model = model
|
16 |
+
self.device = device
|
17 |
+
|
18 |
+
def forward(self, x, txt_feats):
|
19 |
+
return self.model.predict(x, txt_feats=txt_feats)
|
20 |
+
|
21 |
+
def export(self, output_dir, model_name, img_width, img_height, num_classes):
|
22 |
+
x = torch.randn(1, 3, img_width, img_height, requires_grad=False).to(self.device)
|
23 |
+
txt_feats = torch.randn(1, num_classes, 512, requires_grad=False).to(self.device)
|
24 |
+
|
25 |
+
print(x.shape, txt_feats.shape)
|
26 |
+
|
27 |
+
# Export model
|
28 |
+
onnx_name = model_name + ".onnx"
|
29 |
+
os.makedirs(output_dir, exist_ok=True)
|
30 |
+
output_path = f"{output_dir}/{onnx_name}"
|
31 |
+
with torch.no_grad():
|
32 |
+
torch.onnx.export(self,
|
33 |
+
(x, txt_feats),
|
34 |
+
output_path,
|
35 |
+
do_constant_folding=True,
|
36 |
+
opset_version=17,
|
37 |
+
input_names=["images", "txt_feats"],
|
38 |
+
output_names=["output"])
|
39 |
+
|
40 |
+
return output_path
|
app.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from ultralytics import YOLOWorld
|
3 |
+
|
4 |
+
from ExportModel import ModelExporter
|
5 |
+
|
6 |
+
model_list = ['yolov8s-worldv2', 'yolov8m-worldv2', 'yolov8l-worldv2', 'yolov8x-worldv2']
|
7 |
+
|
8 |
+
def export_model(model, width, height, num_classes):
|
9 |
+
model_name = model
|
10 |
+
img_width = width
|
11 |
+
img_height = height
|
12 |
+
num_classes = num_classes
|
13 |
+
|
14 |
+
yoloModel = YOLOWorld(model_name)
|
15 |
+
yoloModel.set_classes([""] * num_classes)
|
16 |
+
|
17 |
+
# Initialize model exporter
|
18 |
+
export_model = ModelExporter(yoloModel.model)
|
19 |
+
|
20 |
+
# Export model
|
21 |
+
output_path = export_model.export("temp", model_name, img_width, img_height, num_classes)
|
22 |
+
|
23 |
+
return output_path
|
24 |
+
|
25 |
+
|
26 |
+
demo = gr.Interface(
|
27 |
+
export_model,
|
28 |
+
[
|
29 |
+
gr.Dropdown(model_list, label="model", value=model_list[0]),
|
30 |
+
gr.Slider(32, 4096, step=32, value=640, label="width"),
|
31 |
+
gr.Slider(32, 4096, step=32, value=480, label="height"),
|
32 |
+
gr.Number(label="num_classes", value=1),
|
33 |
+
],
|
34 |
+
"file",
|
35 |
+
title="ONNX Export Ultralytics YOLO-World Open Vocabulary Object Detection",
|
36 |
+
description="Demo to export Ultralytics YOLO-World Open Vocabulary Object Detection model to ONNX",
|
37 |
+
api_name="export"
|
38 |
+
)
|
39 |
+
if __name__ == "__main__":
|
40 |
+
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
ultralytics
|
2 |
+
gradio
|