wangyuwy commited on
Commit
0f942c8
1 Parent(s): 9d8126f

Upload 2 files

Browse files
Files changed (2) hide show
  1. ResNet_int_NHWC.onnx +3 -0
  2. eval_onnx.py +1 -4
ResNet_int_NHWC.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7f27d4fb250334533478861df64a192e141d4c7877ec465ce1cafd241a7dcf7
3
+ size 102275699
eval_onnx.py CHANGED
@@ -68,7 +68,6 @@ def accuracy(output: torch.Tensor,
68
  output: Prediction of the model.
69
  target: Ground truth labels.
70
  topk: Topk accuracy to compute.
71
-
72
  Returns:
73
  Accuracy results according to 'topk'.
74
  """
@@ -91,13 +90,11 @@ def prepare_data_loader(data_dir: str,
91
  batch_size: int = 100,
92
  workers: int = 8) -> torch.utils.data.DataLoader:
93
  """Returns a validation data loader of ImageNet by given `data_dir`.
94
-
95
  Args:
96
  data_dir: Directory where images stores. There must be a subdirectory named
97
  'validation' that stores the validation set of ImageNet.
98
  batch_size: Batch size of data loader.
99
  workers: How many subprocesses to use for data loading.
100
-
101
  Returns:
102
  An object of torch.utils.data.DataLoader.
103
  """
@@ -144,7 +141,7 @@ def val_imagenet():
144
  val_loader = tqdm(val_loader, file=sys.stdout)
145
  with torch.no_grad():
146
  for batch_idx, (images, targets) in enumerate(val_loader):
147
- inputs, targets = images.numpy(), targets
148
  ort_inputs = {ort_session.get_inputs()[0].name: inputs}
149
 
150
  outputs = ort_session.run(None, ort_inputs)
 
68
  output: Prediction of the model.
69
  target: Ground truth labels.
70
  topk: Topk accuracy to compute.
 
71
  Returns:
72
  Accuracy results according to 'topk'.
73
  """
 
90
  batch_size: int = 100,
91
  workers: int = 8) -> torch.utils.data.DataLoader:
92
  """Returns a validation data loader of ImageNet by given `data_dir`.
 
93
  Args:
94
  data_dir: Directory where images stores. There must be a subdirectory named
95
  'validation' that stores the validation set of ImageNet.
96
  batch_size: Batch size of data loader.
97
  workers: How many subprocesses to use for data loading.
 
98
  Returns:
99
  An object of torch.utils.data.DataLoader.
100
  """
 
141
  val_loader = tqdm(val_loader, file=sys.stdout)
142
  with torch.no_grad():
143
  for batch_idx, (images, targets) in enumerate(val_loader):
144
+ inputs, targets = images.numpy().transpose(0, 2, 3, 1), targets
145
  ort_inputs = {ort_session.get_inputs()[0].name: inputs}
146
 
147
  outputs = ort_session.run(None, ort_inputs)