Niv Sardi commited on
Commit
3d78160
1 Parent(s): eb42660

python/write_data: properly support split and yolo v5/6

Browse files
Files changed (1) hide show
  1. python/write_data.py +10 -9
python/write_data.py CHANGED
@@ -4,15 +4,15 @@ import argparse
4
 
5
  from common import defaults
6
 
7
- YOLO_TEMPLATES = {
8
  5: '''
9
- train: %%datapath%%/squares
10
- val: %%datapath%%squares
11
  ''',
12
  6: '''
13
- train: %%datapath%%/squares/images
14
- val: %%datapath%%/squares/images
15
- test: %%datapath%%/squares/images
16
 
17
  is_coco: False
18
  '''
@@ -23,14 +23,14 @@ def gen_data_yaml(bcos, datapath='../data', version=6):
23
  return f'''
24
  # this file is autogenerated by write_data.py for YOLO version {version}
25
 
26
- {YOLO_TEMPLATES[version].replace('%%datapath%%', datapath)}
27
 
28
  nc: {len(bcos.keys())}
29
  names: {names}
30
  '''
31
 
32
  if __name__ == '__main__':
33
- parser = argparse.ArgumentParser(description='creates a YOLOv5 data.yaml')
34
  parser.add_argument('csv', metavar='csv', type=str,
35
  help='csv file', default=defaults.MAIN_CSV_PATH)
36
  parser.add_argument('--version', metavar='version', type=int,
@@ -39,4 +39,5 @@ if __name__ == '__main__':
39
  help='data path', default=defaults.DATA_PATH)
40
  args = parser.parse_args()
41
  bcos = entity.read_entities(args.csv)
42
- print(gen_data_yaml(bcos, args.data, args.version))
 
 
4
 
5
  from common import defaults
6
 
7
+ YOLO_DATA_TEMPLATES = {
8
  5: '''
9
+ train: %%datapath%%/split/train
10
+ val: %%datapath%%/split/val
11
  ''',
12
  6: '''
13
+ train: %%datapath%%/split/images/train
14
+ val: %%datapath%%/split/images/val
15
+ test: %%datapath%%/split/images/test
16
 
17
  is_coco: False
18
  '''
 
23
  return f'''
24
  # this file is autogenerated by write_data.py for YOLO version {version}
25
 
26
+ {YOLO_DATA_TEMPLATES[version].replace('%%datapath%%', datapath)}
27
 
28
  nc: {len(bcos.keys())}
29
  names: {names}
30
  '''
31
 
32
  if __name__ == '__main__':
33
+ parser = argparse.ArgumentParser(description='creates a YOLOv{5,6} data.yaml and trains it')
34
  parser.add_argument('csv', metavar='csv', type=str,
35
  help='csv file', default=defaults.MAIN_CSV_PATH)
36
  parser.add_argument('--version', metavar='version', type=int,
 
39
  help='data path', default=defaults.DATA_PATH)
40
  args = parser.parse_args()
41
  bcos = entity.read_entities(args.csv)
42
+ with open(f'{defaults.DATA_PATH}/data.yaml', 'w') as f:
43
+ f.write(gen_data_yaml(bcos, args.data, args.version))