Add calibration for int8 quantization
yolo_nas_pose_to_onnx.py (+68 -2)
@@ -12,6 +12,12 @@ import onnxruntime
 import os
 from super_gradients.training.utils.visualization.pose_estimation import PoseVisualization
 import matplotlib.pyplot as plt
+from datasets import load_dataset
+from torchvision import transforms
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+import matplotlib.pyplot as plt
+
 
 os.environ['CRASH_HANDLER']='0'
 
@@ -19,7 +25,7 @@ os.environ['CRASH_HANDLER']='0'
 
 CONVERSION = True
 input_image_shape = [640, 640]
-quantization_modes = [
+quantization_modes = [ExportQuantizationMode.INT8, ExportQuantizationMode.FP16, None]
 output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT
 
 # NMS-related Setting
@@ -37,6 +43,61 @@ image_name = "https://deci-pretrained-models.s3.amazonaws.com/sample_images/beat
 # Check
 SHAPE_CHECK=True
 VISUAL_CHECK=True
+CALIBRATION_DATASET_CHECK=False
+
+# Function to convert tensor to image for visualization
+def tensor_to_image(tensor):
+    # Convert the tensor to a numpy array
+    numpy_image = tensor.numpy()
+
+    # The output of ToTensor() is in C x H x W format, convert to H x W x C
+    numpy_image = numpy_image.transpose(1, 2, 0)
+
+    # Undo the normalization (if any)
+    # numpy_image = numpy_image * std + mean  # Adjust based on your normalization
+
+    return numpy_image
+
+class HFDatasetWrapper(Dataset):
+    def __init__(self, hf_dataset, transform=None):
+        self.hf_dataset = hf_dataset
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.hf_dataset)
+
+    def __getitem__(self, idx):
+        item = self.hf_dataset[idx]
+        if self.transform:
+            item = self.transform(item)
+        return item['image']
+
+def preprocess(data):
+    # Convert byte data to PIL Image
+    image = data['image']
+
+    # Convert to RGB if not already
+    if image.mode != 'RGB':
+        image = image.convert('RGB')
+
+    # Define your transformations
+    transform = transforms.Compose([
+        transforms.Resize((640, 640)),  # Resize (example size)
+        transforms.ToTensor(),  # Convert to tensor
+        # Add normalization or other transformations if needed
+    ])
+
+    # Process Image
+    transformed = transform(image)
+
+    if CALIBRATION_DATASET_CHECK:
+        # Display the Processed Image
+        plt_image = tensor_to_image(transformed)
+        plt.imshow(plt_image)
+        plt.axis('off')  # Turn off axis numbers
+        plt.show()
+
+    return {'image': transformed}
 
 def iterate_over_flat_predictions(predictions, batch_size):
     [flat_predictions] = predictions
@@ -65,6 +126,11 @@ image = load_image(image_name)
 image = cv2.resize(image, (input_image_shape[1], input_image_shape[0]))
 image_bchw = np.transpose(np.expand_dims(image, 0), (0, 3, 1, 2))
 
+# Prepare Calibration Dataset for INT8 Quantization
+dataset = load_dataset("cppe-5", split="test")
+hf_dataset_wrapper = HFDatasetWrapper(dataset, transform=preprocess)
+calibration_loader = DataLoader(hf_dataset_wrapper, batch_size=8)
+
 for model_name in [Models.YOLO_NAS_POSE_L, Models.YOLO_NAS_POSE_M, Models.YOLO_NAS_POSE_N, Models.YOLO_NAS_POSE_S]:
     for q in quantization_modes:
 
@@ -94,7 +160,7 @@ for model_name in [Models.YOLO_NAS_POSE_L, Models.YOLO_NAS_POSE_M, Models.YOLO
         engine=ExportTargetBackend.ONNXRUNTIME,
         quantization_mode=q,
         #selective_quantizer: Optional["SelectiveQuantizer"] = None, # noqa
-
+        calibration_loader=calibration_loader,
        #calibration_method: str = "percentile",
        #calibration_batches: int = 16,
        #calibration_percentile: float = 99.99,
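The new calibration_loader=calibration_loader argument is the substance of the change: when quantization_mode is ExportQuantizationMode.INT8, the exporter draws batches from this loader to calibrate activation ranges, using the defaults visible in the commented-out parameters (percentile calibration, 16 batches, 99.99th percentile); for the FP16 and non-quantized modes the loader should simply go unused. For context, here is a minimal sketch of the export call these kwargs belong to. The output filename and the models.get(..., pretrained_weights="coco_pose") line are assumptions for illustration, not part of this diff:

# Sketch only: assumed surrounding context for the kwargs shown in the last
# hunk. The output path and pretrained_weights value are illustrative.
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.conversion import (
    DetectionOutputFormatMode,
    ExportQuantizationMode,
    ExportTargetBackend,
)

model = models.get(Models.YOLO_NAS_POSE_S, pretrained_weights="coco_pose")
export_result = model.export(
    "yolo_nas_pose_s_int8.onnx",                    # assumed output path
    engine=ExportTargetBackend.ONNXRUNTIME,
    quantization_mode=ExportQuantizationMode.INT8,  # one of the modes iterated above
    calibration_loader=calibration_loader,          # the DataLoader built from cppe-5
    output_predictions_format=DetectionOutputFormatMode.FLAT_FORMAT,
)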
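Before the four-model export loop runs, it is worth confirming the loader actually yields batches shaped like the network input: ToTensor() in preprocess produces [3, 640, 640] floats in [0, 1], and the default collate stacks them into [batch_size, 3, 640, 640]. A hypothetical sanity check, not part of the commit:

# Hypothetical check: verify one calibration batch is BCHW float32 at 640x640.
import torch

batch = next(iter(calibration_loader))
assert isinstance(batch, torch.Tensor), type(batch)
assert batch.shape[1:] == (3, 640, 640), f"unexpected shape {tuple(batch.shape)}"
assert batch.dtype == torch.float32, batch.dtype
print(f"calibration batch OK: {tuple(batch.shape)}")

One caveat: the cppe-5 test split is small (on the order of 29 images), so with batch_size=8 this loader yields only about 4 batches, fewer than the default calibration_batches=16; calibration will run on whatever is available, and a larger split may give more stable INT8 ranges.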