glenn-jocher commited on
Commit
9b6dba6
1 Parent(s): df7706d

Update `dataset_stats()` to list of dicts (#3657)

Browse files

* Update `dataset_stats()` to list of dicts

@KalenMike

* Update datasets.py

Files changed (1) hide show
  1. utils/datasets.py +9 -3
utils/datasets.py CHANGED
@@ -1099,6 +1099,11 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
1099
  autodownload: Attempt to download dataset if not found locally
1100
  verbose: Print stats dictionary
1101
  """
 
 
 
 
 
1102
  with open(check_file(path)) as f:
1103
  data = yaml.safe_load(f) # data dict
1104
  check_dataset(data, autodownload) # download dataset if missing
@@ -1118,12 +1123,13 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
1118
  stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
1119
  'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
1120
  'per_class': (x > 0).sum(0).tolist()},
1121
- 'labels': {str(Path(k).name): v.tolist() for k, v in zip(dataset.img_files, dataset.labels)}}
 
1122
 
1123
  # Save, print and return
1124
  with open(cache_path.with_suffix('.json'), 'w') as f:
1125
  json.dump(stats, f) # save stats *.json
1126
  if verbose:
1127
- print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
1128
- # print(json.dumps(stats, indent=2, sort_keys=False))
1129
  return stats
 
1099
  autodownload: Attempt to download dataset if not found locally
1100
  verbose: Print stats dictionary
1101
  """
1102
+
1103
+ def round_labels(labels):
1104
+ # Update labels to integer class and 6 decimal place floats
1105
+ return [[int(c), *[round(x, 6) for x in points]] for c, *points in labels]
1106
+
1107
  with open(check_file(path)) as f:
1108
  data = yaml.safe_load(f) # data dict
1109
  check_dataset(data, autodownload) # download dataset if missing
 
1123
  stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
1124
  'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
1125
  'per_class': (x > 0).sum(0).tolist()},
1126
+ 'labels': [{str(Path(k).name): round_labels(v.tolist())} for k, v in
1127
+ zip(dataset.img_files, dataset.labels)]}
1128
 
1129
  # Save, print and return
1130
  with open(cache_path.with_suffix('.json'), 'w') as f:
1131
  json.dump(stats, f) # save stats *.json
1132
  if verbose:
1133
+ print(json.dumps(stats, indent=2, sort_keys=False))
1134
+ # print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
1135
  return stats