Upload 2 files
Browse files- ResNet_int_NHWC.onnx +3 -0
- eval_onnx.py +1 -4
ResNet_int_NHWC.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d7f27d4fb250334533478861df64a192e141d4c7877ec465ce1cafd241a7dcf7
|
3 |
+
size 102275699
|
eval_onnx.py
CHANGED
@@ -68,7 +68,6 @@ def accuracy(output: torch.Tensor,
|
|
68 |
output: Prediction of the model.
|
69 |
target: Ground truth labels.
|
70 |
topk: Topk accuracy to compute.
|
71 |
-
|
72 |
Returns:
|
73 |
Accuracy results according to 'topk'.
|
74 |
"""
|
@@ -91,13 +90,11 @@ def prepare_data_loader(data_dir: str,
|
|
91 |
batch_size: int = 100,
|
92 |
workers: int = 8) -> torch.utils.data.DataLoader:
|
93 |
"""Returns a validation data loader of ImageNet by given `data_dir`.
|
94 |
-
|
95 |
Args:
|
96 |
data_dir: Directory where images stores. There must be a subdirectory named
|
97 |
'validation' that stores the validation set of ImageNet.
|
98 |
batch_size: Batch size of data loader.
|
99 |
workers: How many subprocesses to use for data loading.
|
100 |
-
|
101 |
Returns:
|
102 |
An object of torch.utils.data.DataLoader.
|
103 |
"""
|
@@ -144,7 +141,7 @@ def val_imagenet():
|
|
144 |
val_loader = tqdm(val_loader, file=sys.stdout)
|
145 |
with torch.no_grad():
|
146 |
for batch_idx, (images, targets) in enumerate(val_loader):
|
147 |
-
inputs, targets = images.numpy(), targets
|
148 |
ort_inputs = {ort_session.get_inputs()[0].name: inputs}
|
149 |
|
150 |
outputs = ort_session.run(None, ort_inputs)
|
|
|
68 |
output: Prediction of the model.
|
69 |
target: Ground truth labels.
|
70 |
topk: Topk accuracy to compute.
|
|
|
71 |
Returns:
|
72 |
Accuracy results according to 'topk'.
|
73 |
"""
|
|
|
90 |
batch_size: int = 100,
|
91 |
workers: int = 8) -> torch.utils.data.DataLoader:
|
92 |
"""Returns a validation data loader of ImageNet by given `data_dir`.
|
|
|
93 |
Args:
|
94 |
data_dir: Directory where images stores. There must be a subdirectory named
|
95 |
'validation' that stores the validation set of ImageNet.
|
96 |
batch_size: Batch size of data loader.
|
97 |
workers: How many subprocesses to use for data loading.
|
|
|
98 |
Returns:
|
99 |
An object of torch.utils.data.DataLoader.
|
100 |
"""
|
|
|
141 |
val_loader = tqdm(val_loader, file=sys.stdout)
|
142 |
with torch.no_grad():
|
143 |
for batch_idx, (images, targets) in enumerate(val_loader):
|
144 |
+
inputs, targets = images.numpy().transpose(0, 2, 3, 1), targets
|
145 |
ort_inputs = {ort_session.get_inputs()[0].name: inputs}
|
146 |
|
147 |
outputs = ort_session.run(None, ort_inputs)
|