Niv Sardi commited on
Commit
e04a33d
1 Parent(s): 26ef429

entities: make read_entities return entities

Browse files

Signed-off-by: Niv Sardi <xaiki@evilgiggle.com>

Files changed (3) hide show
  1. python/entity.py +1 -1
  2. python/write_data.py +1 -1
  3. train.sh +60 -0
python/entity.py CHANGED
@@ -7,7 +7,7 @@ from common import defaults
7
  def read_entities(fn = defaults.MAIN_CSV_PATH):
8
  with open(fn, newline='') as csvfile:
9
  reader = csv.DictReader(csvfile)
10
- bcos = { d['bco']:d for d in reader}
11
  return bcos
12
 
13
  class Entity(NamedTuple):
 
7
  def read_entities(fn = defaults.MAIN_CSV_PATH):
8
  with open(fn, newline='') as csvfile:
9
  reader = csv.DictReader(csvfile)
10
+ bcos = { d['bco']:Entity.from_dict(d) for d in reader}
11
  return bcos
12
 
13
  class Entity(NamedTuple):
python/write_data.py CHANGED
@@ -5,7 +5,7 @@ import argparse
5
  from common import defaults
6
 
7
  def gen_data_yaml(bcos, datapath='../data'):
8
- names = [f"{d['name']}" for d in bcos.values()]
9
  return f'''
10
  # this file is autogenerated by write_data.py
11
 
 
5
  from common import defaults
6
 
7
  def gen_data_yaml(bcos, datapath='../data'):
8
+ names = [f"{d.name}" for d in bcos.values()]
9
  return f'''
10
  # this file is autogenerated by write_data.py
11
 
train.sh ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/sh
2
+
3
+ set -e
4
+
5
+ PY=python3
6
+ DATA_PATH=${PWD}/data
7
+
8
+ ${PY} ./python/write_data.py ./data/entities.csv --data ${DATA_PATH} > ${DATA_PATH}/data.yaml
9
+ grep nc ${DATA_PATH}/data.yaml > ${DATA_PATH}/custom_yolov5s.yaml
10
+ cat <<EOF >> ${DATA_PATH}/custom_yolov5s.yaml
11
+ # parameters
12
+ nc: {num_classes} # number of classes # CHANGED HERE
13
+ depth_multiple: 0.33 # model depth multiple
14
+ width_multiple: 0.50 # layer channel multiple
15
+
16
+ # anchors
17
+ anchors:
18
+ - [10,13, 16,30, 33,23] # P3/8
19
+ - [30,61, 62,45, 59,119] # P4/16
20
+ - [116,90, 156,198, 373,326] # P5/32
21
+
22
+ # YOLOv5 backbone
23
+ backbone:
24
+ # [from, number, module, args]
25
+ [[-1, 1, Focus, [64, 3]], # 0-P1/2
26
+ [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
27
+ [-1, 3, BottleneckCSP, [128]],
28
+ [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
29
+ [-1, 9, BottleneckCSP, [256]],
30
+ [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
31
+ [-1, 9, BottleneckCSP, [512]],
32
+ [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
33
+ [-1, 1, SPP, [1024, [5, 9, 13]]],
34
+ [-1, 3, BottleneckCSP, [1024, False]], # 9
35
+ ]
36
+
37
+ # YOLOv5 head
38
+ head:
39
+ [[-1, 1, Conv, [512, 1, 1]],
40
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
41
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
42
+ [-1, 3, BottleneckCSP, [512, False]], # 13
43
+
44
+ [-1, 1, Conv, [256, 1, 1]],
45
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
46
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
47
+ [-1, 3, BottleneckCSP, [256, False]], # 17 (P3/8-small)
48
+
49
+ [-1, 1, Conv, [256, 3, 2]],
50
+ [[-1, 14], 1, Concat, [1]], # cat head P4
51
+ [-1, 3, BottleneckCSP, [512, False]], # 20 (P4/16-medium)
52
+
53
+ [-1, 1, Conv, [512, 3, 2]],
54
+ [[-1, 10], 1, Concat, [1]], # cat head P5
55
+ [-1, 3, BottleneckCSP, [1024, False]], # 23 (P5/32-large)
56
+
57
+ [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
58
+ ]
59
+ EOF
60
+ (cd yolov5; ${PY} train.py --img 416 --batch 80 --epochs 1000 --data ${DATA_PATH}/data.yaml --cfg ${DATA_PATH}/custom_yolov5s.yaml --weights '')