Niv Sardi commited on
Commit
eb42660
1 Parent(s): a7ac778

python/split: support yolov5 and v6

Browse files

they have slightly different data formats because…

Files changed (2) hide show
  1. python/split.py +16 -2
  2. run.sh +2 -1
python/split.py CHANGED
@@ -3,6 +3,17 @@ import os
3
  import math
4
  from common import defaults, mkdir
5
 
 
 
 
 
 
 
 
 
 
 
 
6
  if __name__ == '__main__':
7
  import argparse
8
  parser = argparse.ArgumentParser(description='splits a yolo dataset between different data partitions')
@@ -12,8 +23,11 @@ if __name__ == '__main__':
12
  help='data path', default=['train:0.8', 'val:0.1', 'test:0.1'])
13
  parser.add_argument('--dest', metavar='dest', type=str,
14
  help='dest path', default=defaults.SPLIT_DATA_PATH)
 
 
15
 
16
  args = parser.parse_args()
 
17
 
18
  def image_to_label(i):
19
  l = i.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt')
@@ -28,8 +42,8 @@ if __name__ == '__main__':
28
  p = np + 1
29
  np = min(p + math.floor(len(images)*float(r)), len(images))
30
 
31
- cpi = os.path.join(args.dest, d, 'images')
32
- cpl = os.path.join(args.dest, d, 'labels')
33
  rpi = os.path.relpath(os.path.join(args.datapath, 'images'), cpi)
34
  rpl = os.path.relpath(os.path.join(args.datapath, 'labels'), cpl)
35
 
 
3
  import math
4
  from common import defaults, mkdir
5
 
6
+ PATHS = {
7
+ 6: {
8
+ 'images': lambda dest, d: os.path.join(dest, 'images', d ),
9
+ 'labels': lambda dest, d: os.path.join(dest, 'labels', d )
10
+ },
11
+ 5: {
12
+ 'images': lambda desd, d: os.path.join(dest, d, 'images'),
13
+ 'labels': lambda desd, d: os.path.join(dest, d, 'labels'),
14
+ }
15
+ }
16
+
17
  if __name__ == '__main__':
18
  import argparse
19
  parser = argparse.ArgumentParser(description='splits a yolo dataset between different data partitions')
 
23
  help='data path', default=['train:0.8', 'val:0.1', 'test:0.1'])
24
  parser.add_argument('--dest', metavar='dest', type=str,
25
  help='dest path', default=defaults.SPLIT_DATA_PATH)
26
+ parser.add_argument('--yolo', metavar='yolo', type=int,
27
+ help='yolo version', default=6)
28
 
29
  args = parser.parse_args()
30
+ assert(PATHS[args.yolo])
31
 
32
  def image_to_label(i):
33
  l = i.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt')
 
42
  p = np + 1
43
  np = min(p + math.floor(len(images)*float(r)), len(images))
44
 
45
+ cpi = PATHS[args.yolo]['images'](args.dest, d)
46
+ cpl = PATHS[args.yolo]['labels'](args.dest, d)
47
  rpi = os.path.relpath(os.path.join(args.datapath, 'images'), cpi)
48
  rpl = os.path.relpath(os.path.join(args.datapath, 'labels'), cpl)
49
 
run.sh CHANGED
@@ -3,6 +3,7 @@ set -e
3
 
4
  PY=python3
5
  PARALLEL=$(cat /proc/cpuinfo | grep processor | wc -l)
 
6
 
7
  echo "📊 detected ${PARALLEL} cores"
8
  echo "🏛 fetching entities"
@@ -16,6 +17,6 @@ ${PY} ./python/augment.py
16
  echo "🖼 croping augmented data"
17
  ${PY} ./python/crop.py ./data/augmented/images
18
  echo "✂ split dataset into train, val and test groups"
19
- ${PY} ./python/split.py ./data/squares/
20
  echo "🧠 train model"
21
  sh train.sh
 
3
 
4
  PY=python3
5
  PARALLEL=$(cat /proc/cpuinfo | grep processor | wc -l)
6
+ YOLO=6
7
 
8
  echo "📊 detected ${PARALLEL} cores"
9
  echo "🏛 fetching entities"
 
17
  echo "🖼 croping augmented data"
18
  ${PY} ./python/crop.py ./data/augmented/images
19
  echo "✂ split dataset into train, val and test groups"
20
+ ${PY} ./python/split.py ./data/squares/ --yolo $YOLO
21
  echo "🧠 train model"
22
  sh train.sh