Kedreamix committed on
Commit
4a3ab35
1 Parent(s): bea3c7d

Main inference code for YoloGesture

.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ model_data/simhei.ttf filter=lfs diff=lfs merge=lfs -text
img/anticlockwise.jpg ADDED
img/back.jpg ADDED
img/clockwise.jpg ADDED
img/down.jpg ADDED
img/front.jpg ADDED
img/left.jpg ADDED
img/right.jpg ADDED
img/up.jpg ADDED
model_data/gesture.yaml ADDED
@@ -0,0 +1,20 @@
+ #------------------------------detect.py--------------------------------#
+ # This part enables semi-automatic data annotation to lighten the labeling workload; it needs a pre-trained weight file and saves results in Labelme format
+ # dir_origin_path   where the images are stored
+ # dir_save_path     where the annotations are saved
+ # ----------------------------------------------------------------------#
+ dir_detect_path: ./JPEGImages
+ detect_save_path: ./Annotation
+
+ # ----------------------------- train.py -------------------------------#
+ nc: 8 # number of classes
+ classes: ["up","down","left","right","front","back","clockwise","anticlockwise"] # class names
+ confidence: 0.5 # confidence threshold
+ nms_iou: 0.3
+ letterbox_image: False
+
+ lr_decay_type: cos # learning-rate decay schedule; options are step and cos
+ # Whether to use multi-threaded data loading.
+ # Enabling it speeds up data loading but uses more memory.
+ # Machines with limited memory can set this to 2 or 0; on Windows, 0 is recommended.
+ num_workers: 4
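
For reference, a minimal sketch (not part of this commit, assuming PyYAML is installed) of loading gesture.yaml at inference time; the key names mirror the file above:

import yaml

# Read the gesture config; keys follow model_data/gesture.yaml above.
with open("model_data/gesture.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

assert cfg["nc"] == len(cfg["classes"])   # 8 gesture classes
print(cfg["classes"], cfg["confidence"], cfg["nms_iou"])
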
model_data/gesture_classes.txt ADDED
@@ -0,0 +1,8 @@
+ up
+ down
+ left
+ right
+ front
+ back
+ clockwise
+ anticlockwise
model_data/simhei.ttf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aa4560dd8fe5645745fed3ffa301c3ca4d6c03cbd738145b613303961ba733b8
+ size 9753388
model_data/yolo_anchors.txt ADDED
@@ -0,0 +1 @@
+ 12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401
model_data/yolotiny_anchors.txt ADDED
@@ -0,0 +1 @@
+ 10,14, 23,27, 37,58, 81,82, 135,169, 344,319
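
Both anchor files store one line of flat, comma-separated w,h pairs. A minimal parsing sketch (an assumption about how these files would be read, not code from this commit):

import numpy as np

def load_anchors(path):
    # One line of comma-separated values; reshape into (w, h) pairs.
    with open(path) as f:
        values = [float(v) for v in f.readline().split(",")]
    return np.array(values).reshape(-1, 2)

anchors      = load_anchors("model_data/yolo_anchors.txt")      # 9 anchors for YOLOv4
tiny_anchors = load_anchors("model_data/yolotiny_anchors.txt")  # 6 anchors for the tiny model
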
nets/CSPdarknet.py ADDED
@@ -0,0 +1,174 @@
+ import math
+ from collections import OrderedDict
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ #-------------------------------------------------#
+ #   Mish activation function
+ #-------------------------------------------------#
+ class Mish(nn.Module):
+     def __init__(self):
+         super(Mish, self).__init__()
+
+     def forward(self, x):
+         return x * torch.tanh(F.softplus(x))
+
+ #---------------------------------------------------#
+ #   Convolution block -> convolution + normalization + activation
+ #   Conv2d + BatchNormalization + Mish
+ #---------------------------------------------------#
+ class BasicConv(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride=1):
+         super(BasicConv, self).__init__()
+
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, kernel_size//2, bias=False)
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.activation = Mish()
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         x = self.activation(x)
+         return x
+
+ #---------------------------------------------------#
+ #   Building block of the CSPdarknet structure blocks:
+ #   the residual blocks stacked inside them
+ #---------------------------------------------------#
+ class Resblock(nn.Module):
+     def __init__(self, channels, hidden_channels=None):
+         super(Resblock, self).__init__()
+
+         if hidden_channels is None:
+             hidden_channels = channels
+
+         self.block = nn.Sequential(
+             BasicConv(channels, hidden_channels, 1),
+             BasicConv(hidden_channels, channels, 3)
+         )
+
+     def forward(self, x):
+         return x + self.block(x)
+
+ #--------------------------------------------------------------------#
+ #   CSPdarknet structure block.
+ #   First compresses height and width with ZeroPadding2D and a stride-2x2 convolution block,
+ #   then builds a large residual shortcut (shortconv) that bypasses many residual structures.
+ #   The trunk loops over num_blocks; each iteration is a residual structure.
+ #   The whole CSPdarknet structure block is one large residual block plus many small inner ones.
+ #--------------------------------------------------------------------#
+ class Resblock_body(nn.Module):
+     def __init__(self, in_channels, out_channels, num_blocks, first):
+         super(Resblock_body, self).__init__()
+         #----------------------------------------------------------------#
+         #   Compress height and width with a stride-2x2 convolution block
+         #----------------------------------------------------------------#
+         self.downsample_conv = BasicConv(in_channels, out_channels, 3, stride=2)
+
+         if first:
+             #--------------------------------------------------------------------------#
+             #   Build a large residual shortcut, self.split_conv0, bypassing many residual structures
+             #--------------------------------------------------------------------------#
+             self.split_conv0 = BasicConv(out_channels, out_channels, 1)
+
+             #----------------------------------------------------------------#
+             #   The trunk loops over num_blocks; each iteration is a residual structure
+             #----------------------------------------------------------------#
+             self.split_conv1 = BasicConv(out_channels, out_channels, 1)
+             self.blocks_conv = nn.Sequential(
+                 Resblock(channels=out_channels, hidden_channels=out_channels//2),
+                 BasicConv(out_channels, out_channels, 1)
+             )
+
+             self.concat_conv = BasicConv(out_channels*2, out_channels, 1)
+         else:
+             #--------------------------------------------------------------------------#
+             #   Build a large residual shortcut, self.split_conv0, bypassing many residual structures
+             #--------------------------------------------------------------------------#
+             self.split_conv0 = BasicConv(out_channels, out_channels//2, 1)
+
+             #----------------------------------------------------------------#
+             #   The trunk loops over num_blocks; each iteration is a residual structure
+             #----------------------------------------------------------------#
+             self.split_conv1 = BasicConv(out_channels, out_channels//2, 1)
+             self.blocks_conv = nn.Sequential(
+                 *[Resblock(out_channels//2) for _ in range(num_blocks)],
+                 BasicConv(out_channels//2, out_channels//2, 1)
+             )
+
+             self.concat_conv = BasicConv(out_channels, out_channels, 1)
+
+     def forward(self, x):
+         x = self.downsample_conv(x)
+
+         x0 = self.split_conv0(x)
+
+         x1 = self.split_conv1(x)
+         x1 = self.blocks_conv(x1)
+
+         #------------------------------------#
+         #   Stack the large residual shortcut back on
+         #------------------------------------#
+         x = torch.cat([x1, x0], dim=1)
+         #------------------------------------#
+         #   Finally integrate the channels
+         #------------------------------------#
+         x = self.concat_conv(x)
+
+         return x
+
+ #---------------------------------------------------#
+ #   Main body of CSPdarknet53.
+ #   Input is a 416x416x3 image;
+ #   output is three effective feature layers.
+ #---------------------------------------------------#
+ class CSPDarkNet(nn.Module):
+     def __init__(self, layers):
+         super(CSPDarkNet, self).__init__()
+         self.inplanes = 32
+         # 416,416,3 -> 416,416,32
+         self.conv1 = BasicConv(3, self.inplanes, kernel_size=3, stride=1)
+         self.feature_channels = [64, 128, 256, 512, 1024]
+
+         self.stages = nn.ModuleList([
+             # 416,416,32 -> 208,208,64
+             Resblock_body(self.inplanes, self.feature_channels[0], layers[0], first=True),
+             # 208,208,64 -> 104,104,128
+             Resblock_body(self.feature_channels[0], self.feature_channels[1], layers[1], first=False),
+             # 104,104,128 -> 52,52,256
+             Resblock_body(self.feature_channels[1], self.feature_channels[2], layers[2], first=False),
+             # 52,52,256 -> 26,26,512
+             Resblock_body(self.feature_channels[2], self.feature_channels[3], layers[3], first=False),
+             # 26,26,512 -> 13,13,1024
+             Resblock_body(self.feature_channels[3], self.feature_channels[4], layers[4], first=False)
+         ])
+
+         self.num_features = 1
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 m.weight.data.normal_(0, math.sqrt(2. / n))
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+
+     def forward(self, x):
+         x = self.conv1(x)
+
+         x = self.stages[0](x)
+         x = self.stages[1](x)
+         out3 = self.stages[2](x)
+         out4 = self.stages[3](out3)
+         out5 = self.stages[4](out4)
+
+         return out3, out4, out5
+
+ def darknet53(pretrained):
+     model = CSPDarkNet([1, 2, 8, 8, 4])
+     if pretrained:
+         model.load_state_dict(torch.load("model_data/CSPdarknet53_backbone_weights.pth"))
+     return model
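
A quick sketch (not part of the commit) that checks the three feature-layer shapes documented in the comments above; pretrained=False avoids the hard-coded weight path:

import torch
from nets.CSPdarknet import darknet53

model = darknet53(pretrained=False).eval()
with torch.no_grad():
    out3, out4, out5 = model(torch.randn(1, 3, 416, 416))
# Expected: (1, 256, 52, 52), (1, 512, 26, 26), (1, 1024, 13, 13)
print(out3.shape, out4.shape, out5.shape)
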
nets/CSPdarknet53_tiny.py ADDED
@@ -0,0 +1,143 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+
+
+ #-------------------------------------------------#
+ #   Convolution block
+ #   Conv2d + BatchNorm2d + LeakyReLU
+ #-------------------------------------------------#
+ class BasicConv(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride=1):
+         super(BasicConv, self).__init__()
+
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, kernel_size//2, bias=False)
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.activation = nn.LeakyReLU(0.1)
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         x = self.activation(x)
+         return x
+
+
+ '''
+                     input
+                       |
+                   BasicConv
+                       -----------------------
+                       |                     |
+                  route_group              route
+                       |                     |
+                   BasicConv                 |
+                       |                     |
+     -------------------                     |
+     |                 |                     |
+  route_1          BasicConv                 |
+     |                 |                     |
+     -----------------cat                    |
+                       |                     |
+         ----      BasicConv                 |
+         |             |                     |
+       feat           cat---------------------
+                       |
+                  MaxPooling2D
+ '''
+ #---------------------------------------------------#
+ #   Structure block of CSPdarknet53-tiny.
+ #   It has one large residual shortcut
+ #   that bypasses many residual structures.
+ #---------------------------------------------------#
+ class Resblock_body(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(Resblock_body, self).__init__()
+         self.out_channels = out_channels
+
+         self.conv1 = BasicConv(in_channels, out_channels, 3)
+
+         self.conv2 = BasicConv(out_channels//2, out_channels//2, 3)
+         self.conv3 = BasicConv(out_channels//2, out_channels//2, 3)
+
+         self.conv4 = BasicConv(out_channels, out_channels, 1)
+         self.maxpool = nn.MaxPool2d([2,2],[2,2])
+
+     def forward(self, x):
+         # Integrate features with a 3x3 convolution
+         x = self.conv1(x)
+         # Split off a large residual shortcut, route
+         route = x
+
+         c = self.out_channels
+         # Split the channels of the feature layer and take the second half as the trunk
+         x = torch.split(x, c//2, dim = 1)[1]
+         # Apply a 3x3 convolution to the trunk
+         x = self.conv2(x)
+         # Split off a small residual shortcut, route_1
+         route1 = x
+         # Apply another 3x3 convolution to the trunk
+         x = self.conv3(x)
+         # Concatenate the trunk with the residual branch
+         x = torch.cat([x,route1], dim = 1)
+
+         # Apply a 1x1 convolution to the concatenated result
+         x = self.conv4(x)
+         feat = x
+         x = torch.cat([route, x], dim = 1)
+
+         # Compress height and width with max pooling
+         x = self.maxpool(x)
+         return x,feat
+
+ class CSPDarkNet(nn.Module):
+     def __init__(self):
+         super(CSPDarkNet, self).__init__()
+         # First compress height and width with two 3x3 convolutions of stride 2x2
+         # 416,416,3 -> 208,208,32 -> 104,104,64
+         self.conv1 = BasicConv(3, 32, kernel_size=3, stride=2)
+         self.conv2 = BasicConv(32, 64, kernel_size=3, stride=2)
+
+         # 104,104,64 -> 52,52,128
+         self.resblock_body1 = Resblock_body(64, 64)
+         # 52,52,128 -> 26,26,256
+         self.resblock_body2 = Resblock_body(128, 128)
+         # 26,26,256 -> 13,13,512
+         self.resblock_body3 = Resblock_body(256, 256)
+         # 13,13,512 -> 13,13,512
+         self.conv3 = BasicConv(512, 512, kernel_size=3)
+
+         self.num_features = 1
+         # Weight initialization
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 m.weight.data.normal_(0, math.sqrt(2. / n))
+             elif isinstance(m, nn.BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+
+
+     def forward(self, x):
+         # 416,416,3 -> 208,208,32 -> 104,104,64
+         x = self.conv1(x)
+         x = self.conv2(x)
+
+         # 104,104,64 -> 52,52,128
+         x, _ = self.resblock_body1(x)
+         # 52,52,128 -> 26,26,256
+         x, _ = self.resblock_body2(x)
+         # 26,26,256 -> x: 13,13,512
+         #           -> feat1: 26,26,256
+         x, feat1 = self.resblock_body3(x)
+
+         # 13,13,512 -> 13,13,512
+         x = self.conv3(x)
+         feat2 = x
+         return feat1,feat2
+
+ def darknet53_tiny(pretrained, **kwargs):
+     model = CSPDarkNet()
+     if pretrained:
+         model.load_state_dict(torch.load("model_data/CSPdarknet53_tiny_backbone_weights.pth"))
+     return model
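
The same shape check for the tiny backbone (again a sketch, not part of the commit):

import torch
from nets.CSPdarknet53_tiny import darknet53_tiny

model = darknet53_tiny(pretrained=False).eval()
with torch.no_grad():
    feat1, feat2 = model(torch.randn(1, 3, 416, 416))
# Expected: (1, 256, 26, 26) and (1, 512, 13, 13)
print(feat1.shape, feat2.shape)
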
nets/__init__.py ADDED
@@ -0,0 +1 @@
+ #
nets/attention.py ADDED
@@ -0,0 +1,114 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+ class se_block(nn.Module):
+     def __init__(self, channel, ratio=16):
+         super(se_block, self).__init__()
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.fc = nn.Sequential(
+             nn.Linear(channel, channel // ratio, bias=False),
+             nn.ReLU(inplace=True),
+             nn.Linear(channel // ratio, channel, bias=False),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         b, c, _, _ = x.size()
+         y = self.avg_pool(x).view(b, c)
+         y = self.fc(y).view(b, c, 1, 1)
+         return x * y
+
+ class ChannelAttention(nn.Module):
+     def __init__(self, in_planes, ratio=8):
+         super(ChannelAttention, self).__init__()
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.max_pool = nn.AdaptiveMaxPool2d(1)
+
+         # Use 1x1 convolutions instead of fully connected layers
+         self.fc1   = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
+         self.relu1 = nn.ReLU()
+         self.fc2   = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
+
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
+         max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
+         out = avg_out + max_out
+         return self.sigmoid(out)
+
+ class SpatialAttention(nn.Module):
+     def __init__(self, kernel_size=7):
+         super(SpatialAttention, self).__init__()
+
+         assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
+         padding = 3 if kernel_size == 7 else 1
+         self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         avg_out = torch.mean(x, dim=1, keepdim=True)
+         max_out, _ = torch.max(x, dim=1, keepdim=True)
+         x = torch.cat([avg_out, max_out], dim=1)
+         x = self.conv1(x)
+         return self.sigmoid(x)
+
+ class cbam_block(nn.Module):
+     def __init__(self, channel, ratio=8, kernel_size=7):
+         super(cbam_block, self).__init__()
+         self.channelattention = ChannelAttention(channel, ratio=ratio)
+         self.spatialattention = SpatialAttention(kernel_size=kernel_size)
+
+     def forward(self, x):
+         x = x*self.channelattention(x)
+         x = x*self.spatialattention(x)
+         return x
+
+ class eca_block(nn.Module):
+     def __init__(self, channel, b=1, gamma=2):
+         super(eca_block, self).__init__()
+         kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
+         kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
+
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         y = self.avg_pool(x)
+         y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
+         y = self.sigmoid(y)
+         return x * y.expand_as(x)
+
+ class CA_Block(nn.Module):
+     def __init__(self, channel, reduction=16):
+         super(CA_Block, self).__init__()
+
+         self.conv_1x1 = nn.Conv2d(in_channels=channel, out_channels=channel//reduction, kernel_size=1, stride=1, bias=False)
+
+         self.relu = nn.ReLU()
+         self.bn   = nn.BatchNorm2d(channel//reduction)
+
+         self.F_h = nn.Conv2d(in_channels=channel//reduction, out_channels=channel, kernel_size=1, stride=1, bias=False)
+         self.F_w = nn.Conv2d(in_channels=channel//reduction, out_channels=channel, kernel_size=1, stride=1, bias=False)
+
+         self.sigmoid_h = nn.Sigmoid()
+         self.sigmoid_w = nn.Sigmoid()
+
+     def forward(self, x):
+         _, _, h, w = x.size()
+
+         x_h = torch.mean(x, dim = 3, keepdim = True).permute(0, 1, 3, 2)
+         x_w = torch.mean(x, dim = 2, keepdim = True)
+
+         x_cat_conv_relu = self.relu(self.bn(self.conv_1x1(torch.cat((x_h, x_w), 3))))
+
+         x_cat_conv_split_h, x_cat_conv_split_w = x_cat_conv_relu.split([h, w], 3)
+
+         s_h = self.sigmoid_h(self.F_h(x_cat_conv_split_h.permute(0, 1, 3, 2)))
+         s_w = self.sigmoid_w(self.F_w(x_cat_conv_split_w))
+
+         out = x * s_h.expand_as(x) * s_w.expand_as(x)
+         return out
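
All four attention blocks are shape-preserving and can be dropped onto any feature map; a quick sketch (not part of the commit):

import torch
from nets.attention import se_block, cbam_block, eca_block, CA_Block

x = torch.randn(1, 256, 26, 26)
for block_cls in (se_block, cbam_block, eca_block, CA_Block):
    block = block_cls(256)
    assert block(x).shape == x.shape  # attention reweights features, never reshapes them
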
nets/yolo.py ADDED
@@ -0,0 +1,185 @@
+ from collections import OrderedDict
+
+ import torch
+ import torch.nn as nn
+
+ from nets.CSPdarknet import darknet53
+
+
+ def conv2d(filter_in, filter_out, kernel_size, stride=1):
+     pad = (kernel_size - 1) // 2 if kernel_size else 0
+     return nn.Sequential(OrderedDict([
+         ("conv", nn.Conv2d(filter_in, filter_out, kernel_size=kernel_size, stride=stride, padding=pad, bias=False)),
+         ("bn", nn.BatchNorm2d(filter_out)),
+         ("relu", nn.LeakyReLU(0.1)),
+     ]))
+
+ #---------------------------------------------------#
+ #   SPP structure: pooling with kernels of several sizes,
+ #   then stacking the pooled results
+ #---------------------------------------------------#
+ class SpatialPyramidPooling(nn.Module):
+     def __init__(self, pool_sizes=[5, 9, 13]):
+         super(SpatialPyramidPooling, self).__init__()
+
+         self.maxpools = nn.ModuleList([nn.MaxPool2d(pool_size, 1, pool_size//2) for pool_size in pool_sizes])
+
+     def forward(self, x):
+         features = [maxpool(x) for maxpool in self.maxpools[::-1]]
+         features = torch.cat(features + [x], dim=1)
+
+         return features
+
+ #---------------------------------------------------#
+ #   Convolution + upsampling
+ #---------------------------------------------------#
+ class Upsample(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(Upsample, self).__init__()
+
+         self.upsample = nn.Sequential(
+             conv2d(in_channels, out_channels, 1),
+             nn.Upsample(scale_factor=2, mode='nearest')
+         )
+
+     def forward(self, x,):
+         x = self.upsample(x)
+         return x
+
+ #---------------------------------------------------#
+ #   Three-convolution block
+ #---------------------------------------------------#
+ def make_three_conv(filters_list, in_filters):
+     m = nn.Sequential(
+         conv2d(in_filters, filters_list[0], 1),
+         conv2d(filters_list[0], filters_list[1], 3),
+         conv2d(filters_list[1], filters_list[0], 1),
+     )
+     return m
+
+ #---------------------------------------------------#
+ #   Five-convolution block
+ #---------------------------------------------------#
+ def make_five_conv(filters_list, in_filters):
+     m = nn.Sequential(
+         conv2d(in_filters, filters_list[0], 1),
+         conv2d(filters_list[0], filters_list[1], 3),
+         conv2d(filters_list[1], filters_list[0], 1),
+         conv2d(filters_list[0], filters_list[1], 3),
+         conv2d(filters_list[1], filters_list[0], 1),
+     )
+     return m
+
+ #---------------------------------------------------#
+ #   Final yolov4 output head
+ #---------------------------------------------------#
+ def yolo_head(filters_list, in_filters):
+     m = nn.Sequential(
+         conv2d(in_filters, filters_list[0], 3),
+         nn.Conv2d(filters_list[0], filters_list[1], 1),
+     )
+     return m
+
+ #---------------------------------------------------#
+ #   yolo_body
+ #---------------------------------------------------#
+ class YoloBody(nn.Module):
+     def __init__(self, anchors_mask, num_classes, pretrained = False):
+         super(YoloBody, self).__init__()
+         #---------------------------------------------------#
+         #   Build the CSPdarknet53 backbone model.
+         #   It returns three effective feature layers with shapes:
+         #   52,52,256
+         #   26,26,512
+         #   13,13,1024
+         #---------------------------------------------------#
+         self.backbone = darknet53(pretrained)
+
+         self.conv1 = make_three_conv([512,1024],1024)
+         self.SPP   = SpatialPyramidPooling()
+         self.conv2 = make_three_conv([512,1024],2048)
+
+         self.upsample1       = Upsample(512,256)
+         self.conv_for_P4     = conv2d(512,256,1)
+         self.make_five_conv1 = make_five_conv([256, 512],512)
+
+         self.upsample2       = Upsample(256,128)
+         self.conv_for_P3     = conv2d(256,128,1)
+         self.make_five_conv2 = make_five_conv([128, 256],256)
+
+         # 3*(5+num_classes) = 3*(5+20) = 3*(4+1+20) = 75
+         self.yolo_head3      = yolo_head([256, len(anchors_mask[0]) * (5 + num_classes)],128)
+
+         self.down_sample1    = conv2d(128,256,3,stride=2)
+         self.make_five_conv3 = make_five_conv([256, 512],512)
+
+         # 3*(5+num_classes) = 3*(5+20) = 3*(4+1+20) = 75
+         self.yolo_head2      = yolo_head([512, len(anchors_mask[1]) * (5 + num_classes)],256)
+
+         self.down_sample2    = conv2d(256,512,3,stride=2)
+         self.make_five_conv4 = make_five_conv([512, 1024],1024)
+
+         # 3*(5+num_classes) = 3*(5+20) = 3*(4+1+20) = 75
+         self.yolo_head1      = yolo_head([1024, len(anchors_mask[2]) * (5 + num_classes)],512)
+
+
+     def forward(self, x):
+         #  backbone
+         x2, x1, x0 = self.backbone(x)
+
+         # 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512 -> 13,13,2048
+         P5 = self.conv1(x0)
+         P5 = self.SPP(P5)
+         # 13,13,2048 -> 13,13,512 -> 13,13,1024 -> 13,13,512
+         P5 = self.conv2(P5)
+
+         # 13,13,512 -> 13,13,256 -> 26,26,256
+         P5_upsample = self.upsample1(P5)
+         # 26,26,512 -> 26,26,256
+         P4 = self.conv_for_P4(x1)
+         # 26,26,256 + 26,26,256 -> 26,26,512
+         P4 = torch.cat([P4,P5_upsample],axis=1)
+         # 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256
+         P4 = self.make_five_conv1(P4)
+
+         # 26,26,256 -> 26,26,128 -> 52,52,128
+         P4_upsample = self.upsample2(P4)
+         # 52,52,256 -> 52,52,128
+         P3 = self.conv_for_P3(x2)
+         # 52,52,128 + 52,52,128 -> 52,52,256
+         P3 = torch.cat([P3,P4_upsample],axis=1)
+         # 52,52,256 -> 52,52,128 -> 52,52,256 -> 52,52,128 -> 52,52,256 -> 52,52,128
+         P3 = self.make_five_conv2(P3)
+
+         # 52,52,128 -> 26,26,256
+         P3_downsample = self.down_sample1(P3)
+         # 26,26,256 + 26,26,256 -> 26,26,512
+         P4 = torch.cat([P3_downsample,P4],axis=1)
+         # 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256
+         P4 = self.make_five_conv3(P4)
+
+         # 26,26,256 -> 13,13,512
+         P4_downsample = self.down_sample2(P4)
+         # 13,13,512 + 13,13,512 -> 13,13,1024
+         P5 = torch.cat([P4_downsample,P5],axis=1)
+         # 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512
+         P5 = self.make_five_conv4(P5)
+
+         #---------------------------------------------------#
+         #   Third feature layer
+         #   y3 = (batch_size,75,52,52)
+         #---------------------------------------------------#
+         out2 = self.yolo_head3(P3)
+         #---------------------------------------------------#
+         #   Second feature layer
+         #   y2 = (batch_size,75,26,26)
+         #---------------------------------------------------#
+         out1 = self.yolo_head2(P4)
+         #---------------------------------------------------#
+         #   First feature layer
+         #   y1 = (batch_size,75,13,13)
+         #---------------------------------------------------#
+         out0 = self.yolo_head1(P5)
+
+         return out0, out1, out2
+
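
A sketch (not part of the commit) of a forward pass through the full YOLOv4 body with the 8 gesture classes; the anchors_mask matches the default in yolo_training.py below:

import torch
from nets.yolo import YoloBody

anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
model = YoloBody(anchors_mask, num_classes=8, pretrained=False).eval()
with torch.no_grad():
    out0, out1, out2 = model(torch.randn(1, 3, 416, 416))
# Each head outputs 3 * (5 + 8) = 39 channels:
# (1, 39, 13, 13), (1, 39, 26, 26), (1, 39, 52, 52)
print(out0.shape, out1.shape, out2.shape)
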
nets/yolo_tiny.py ADDED
@@ -0,0 +1,99 @@
+ import torch
+ import torch.nn as nn
+
+ from nets.CSPdarknet53_tiny import darknet53_tiny
+ from nets.attention import cbam_block, eca_block, se_block, CA_Block
+
+ attention_block = [se_block, cbam_block, eca_block, CA_Block]
+
+ #-------------------------------------------------#
+ #   Convolution block -> convolution + normalization + activation
+ #   Conv2d + BatchNormalization + LeakyReLU
+ #-------------------------------------------------#
+ class BasicConv(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride=1):
+         super(BasicConv, self).__init__()
+
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, kernel_size//2, bias=False)
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.activation = nn.LeakyReLU(0.1)
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         x = self.activation(x)
+         return x
+
+ #---------------------------------------------------#
+ #   Convolution + upsampling
+ #---------------------------------------------------#
+ class Upsample(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(Upsample, self).__init__()
+
+         self.upsample = nn.Sequential(
+             BasicConv(in_channels, out_channels, 1),
+             nn.Upsample(scale_factor=2, mode='nearest')
+         )
+
+     def forward(self, x,):
+         x = self.upsample(x)
+         return x
+
+ #---------------------------------------------------#
+ #   Final yolov4 output head
+ #---------------------------------------------------#
+ def yolo_head(filters_list, in_filters):
+     m = nn.Sequential(
+         BasicConv(in_filters, filters_list[0], 3),
+         nn.Conv2d(filters_list[0], filters_list[1], 1),
+     )
+     return m
+ #---------------------------------------------------#
+ #   yolo_body
+ #---------------------------------------------------#
+ class YoloBodytiny(nn.Module):
+     def __init__(self, anchors_mask, num_classes, phi=0, pretrained=False):
+         super(YoloBodytiny, self).__init__()
+         self.phi      = phi
+         self.backbone = darknet53_tiny(pretrained)
+
+         self.conv_for_P5 = BasicConv(512,256,1)
+         self.yolo_headP5 = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)],256)
+
+         self.upsample    = Upsample(256,128)
+         self.yolo_headP4 = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)],384)
+
+         if 1 <= self.phi and self.phi <= 4:
+             self.feat1_att    = attention_block[self.phi - 1](256)
+             self.feat2_att    = attention_block[self.phi - 1](512)
+             self.upsample_att = attention_block[self.phi - 1](128)
+
+     def forward(self, x):
+         #---------------------------------------------------#
+         #   Run the CSPdarknet53_tiny backbone model.
+         #   feat1 has shape 26,26,256
+         #   feat2 has shape 13,13,512
+         #---------------------------------------------------#
+         feat1, feat2 = self.backbone(x)
+         if 1 <= self.phi and self.phi <= 4:
+             feat1 = self.feat1_att(feat1)
+             feat2 = self.feat2_att(feat2)
+
+         # 13,13,512 -> 13,13,256
+         P5 = self.conv_for_P5(feat2)
+         # 13,13,256 -> 13,13,512 -> 13,13,255
+         out0 = self.yolo_headP5(P5)
+
+         # 13,13,256 -> 13,13,128 -> 26,26,128
+         P5_Upsample = self.upsample(P5)
+         # 26,26,256 + 26,26,128 -> 26,26,384
+         if 1 <= self.phi and self.phi <= 4:
+             P5_Upsample = self.upsample_att(P5_Upsample)
+         P4 = torch.cat([P5_Upsample,feat1],axis=1)
+
+         # 26,26,384 -> 26,26,256 -> 26,26,255
+         out1 = self.yolo_headP4(P4)
+
+         return out0, out1
+
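
The tiny variant works the same way with two heads; phi picks the attention block (0 = none, 1-4 = SE/CBAM/ECA/CA). A sketch (not part of the commit; the two-entry anchors_mask here is illustrative, since only the lengths of its entries matter for construction):

import torch
from nets.yolo_tiny import YoloBodytiny

anchors_mask = [[3, 4, 5], [0, 1, 2]]  # hypothetical two-head mask
model = YoloBodytiny(anchors_mask, num_classes=8, phi=2, pretrained=False).eval()
with torch.no_grad():
    out0, out1 = model(torch.randn(1, 3, 416, 416))
# (1, 39, 13, 13) and (1, 39, 26, 26)
print(out0.shape, out1.shape)
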
nets/yolo_training.py ADDED
@@ -0,0 +1,476 @@
+ import math
+ from functools import partial
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+
+ class YOLOLoss(nn.Module):
+     def __init__(self, anchors, num_classes, input_shape, cuda, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]], label_smoothing = 0, focal_loss = False, alpha = 0.25, gamma = 2):
+         super(YOLOLoss, self).__init__()
+         #-----------------------------------------------------------#
+         #   The 13x13 feature layer uses anchors [142, 110],[192, 243],[459, 401]
+         #   The 26x26 feature layer uses anchors [36, 75],[76, 55],[72, 146]
+         #   The 52x52 feature layer uses anchors [12, 16],[19, 36],[40, 28]
+         #-----------------------------------------------------------#
+         self.anchors         = anchors
+         self.num_classes     = num_classes
+         self.bbox_attrs      = 5 + num_classes
+         self.input_shape     = input_shape
+         self.anchors_mask    = anchors_mask
+         self.label_smoothing = label_smoothing
+
+         self.balance   = [0.4, 1.0, 4]
+         self.box_ratio = 0.05
+         self.obj_ratio = 5 * (input_shape[0] * input_shape[1]) / (416 ** 2)
+         self.cls_ratio = 1 * (num_classes / 80)
+
+         self.focal_loss       = focal_loss
+         self.focal_loss_ratio = 10
+         self.alpha            = alpha
+         self.gamma            = gamma
+
+         self.ignore_threshold = 0.5
+         self.cuda             = cuda
+
+     def clip_by_tensor(self, t, t_min, t_max):
+         t = t.float()
+         result = (t >= t_min).float() * t + (t < t_min).float() * t_min
+         result = (result <= t_max).float() * result + (result > t_max).float() * t_max
+         return result
+
+     def MSELoss(self, pred, target):
+         return torch.pow(pred - target, 2)
+
+     def BCELoss(self, pred, target):
+         epsilon = 1e-7
+         pred    = self.clip_by_tensor(pred, epsilon, 1.0 - epsilon)
+         output  = - target * torch.log(pred) - (1.0 - target) * torch.log(1.0 - pred)
+         return output
+
+     def box_ciou(self, b1, b2):
+         """
+         Inputs:
+         ----------
+         b1: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh
+         b2: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh
+
+         Returns:
+         -------
+         ciou: tensor, shape=(batch, feat_w, feat_h, anchor_num, 1)
+         """
+         #----------------------------------------------------#
+         #   Compute the top-left and bottom-right corners of the predicted boxes
+         #----------------------------------------------------#
+         b1_xy      = b1[..., :2]
+         b1_wh      = b1[..., 2:4]
+         b1_wh_half = b1_wh/2.
+         b1_mins    = b1_xy - b1_wh_half
+         b1_maxes   = b1_xy + b1_wh_half
+         #----------------------------------------------------#
+         #   Compute the top-left and bottom-right corners of the ground-truth boxes
+         #----------------------------------------------------#
+         b2_xy      = b2[..., :2]
+         b2_wh      = b2[..., 2:4]
+         b2_wh_half = b2_wh/2.
+         b2_mins    = b2_xy - b2_wh_half
+         b2_maxes   = b2_xy + b2_wh_half
+
+         #----------------------------------------------------#
+         #   Compute the IoU of all ground-truth and predicted boxes
+         #----------------------------------------------------#
+         intersect_mins  = torch.max(b1_mins, b2_mins)
+         intersect_maxes = torch.min(b1_maxes, b2_maxes)
+         intersect_wh    = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
+         intersect_area  = intersect_wh[..., 0] * intersect_wh[..., 1]
+         b1_area         = b1_wh[..., 0] * b1_wh[..., 1]
+         b2_area         = b2_wh[..., 0] * b2_wh[..., 1]
+         union_area      = b1_area + b2_area - intersect_area
+         iou             = intersect_area / torch.clamp(union_area,min = 1e-6)
+
+         #----------------------------------------------------#
+         #   Compute the distance between the two centers
+         #----------------------------------------------------#
+         center_distance = torch.sum(torch.pow((b1_xy - b2_xy), 2), axis=-1)
+
+         #----------------------------------------------------#
+         #   Find the corners of the smallest box enclosing both boxes
+         #----------------------------------------------------#
+         enclose_mins  = torch.min(b1_mins, b2_mins)
+         enclose_maxes = torch.max(b1_maxes, b2_maxes)
+         enclose_wh    = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
+         #----------------------------------------------------#
+         #   Compute the diagonal distance of the enclosing box
+         #----------------------------------------------------#
+         enclose_diagonal = torch.sum(torch.pow(enclose_wh,2), axis=-1)
+         ciou             = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal,min = 1e-6)
+
+         v     = (4 / (math.pi ** 2)) * torch.pow((torch.atan(b1_wh[..., 0] / torch.clamp(b1_wh[..., 1],min = 1e-6)) - torch.atan(b2_wh[..., 0] / torch.clamp(b2_wh[..., 1], min = 1e-6))), 2)
+         alpha = v / torch.clamp((1.0 - iou + v), min=1e-6)
+         ciou  = ciou - alpha * v
+         return ciou
+
+     #---------------------------------------------------#
+     #   Label smoothing
+     #---------------------------------------------------#
+     def smooth_labels(self, y_true, label_smoothing, num_classes):
+         return y_true * (1.0 - label_smoothing) + label_smoothing / num_classes
+
+     def forward(self, l, input, targets=None):
+         #----------------------------------------------------#
+         #   l indicates which effective feature layer is used
+         #   input has shape  bs, 3*(5+num_classes), 13, 13
+         #                    bs, 3*(5+num_classes), 26, 26
+         #                    bs, 3*(5+num_classes), 52, 52
+         #   targets holds the ground-truth boxes [batch_size, num_gt, 5]
+         #----------------------------------------------------#
+         #--------------------------------#
+         #   Get the image count and the feature layer's height and width
+         #--------------------------------#
+         bs   = input.size(0)
+         in_h = input.size(2)
+         in_w = input.size(3)
+         #-----------------------------------------------------------------------#
+         #   Compute the stride:
+         #   how many pixels of the original image one feature point covers.
+         #
+         #   A 13x13 feature layer means one feature point covers 32 pixels;
+         #   26x26 means 16 pixels;
+         #   52x52 means 8 pixels.
+         #   stride_h = stride_w = 32, 16, 8
+         #-----------------------------------------------------------------------#
+         stride_h = self.input_shape[0] / in_h
+         stride_w = self.input_shape[1] / in_w
+         #-------------------------------------------------#
+         #   The scaled_anchors obtained here are relative to the feature layer
+         #-------------------------------------------------#
+         scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in self.anchors]
+         #-----------------------------------------------#
+         #   There are three inputs in total; their shapes are
+         #   bs, 3 * (5+num_classes), 13, 13 => bs, 3, 5 + num_classes, 13, 13 => batch_size, 3, 13, 13, 5 + num_classes
+
+         #   batch_size, 3, 13, 13, 5 + num_classes
+         #   batch_size, 3, 26, 26, 5 + num_classes
+         #   batch_size, 3, 52, 52, 5 + num_classes
+         #-----------------------------------------------#
+         prediction = input.view(bs, len(self.anchors_mask[l]), self.bbox_attrs, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()
+
+         #-----------------------------------------------#
+         #   Adjustment parameters for the prior boxes' centers
+         #-----------------------------------------------#
+         x = torch.sigmoid(prediction[..., 0])
+         y = torch.sigmoid(prediction[..., 1])
+         #-----------------------------------------------#
+         #   Adjustment parameters for the prior boxes' width and height
+         #-----------------------------------------------#
+         w = prediction[..., 2]
+         h = prediction[..., 3]
+         #-----------------------------------------------#
+         #   Objectness confidence: whether an object is present
+         #-----------------------------------------------#
+         conf = torch.sigmoid(prediction[..., 4])
+         #-----------------------------------------------#
+         #   Class confidence
+         #-----------------------------------------------#
+         pred_cls = torch.sigmoid(prediction[..., 5:])
+
+         #-----------------------------------------------#
+         #   Get the prediction the network should produce
+         #-----------------------------------------------#
+         y_true, noobj_mask, box_loss_scale = self.get_target(l, targets, scaled_anchors, in_h, in_w)
+
+         #---------------------------------------------------------------#
+         #   Decode the predictions and measure their overlap with the ground truth.
+         #   Predictions that overlap too much are ignored: those feature points
+         #   already predict quite accurately and make poor negative samples.
+         #----------------------------------------------------------------#
+         noobj_mask, pred_boxes = self.get_ignore(l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask)
+
+         if self.cuda:
+             y_true         = y_true.type_as(x)
+             noobj_mask     = noobj_mask.type_as(x)
+             box_loss_scale = box_loss_scale.type_as(x)
+         #--------------------------------------------------------------------------#
+         #   box_loss_scale is the product of the gt box's width and height, both in 0-1,
+         #   so the product is also in 0-1. 2 minus that product gives large boxes a
+         #   small weight and small boxes a large weight.
+         #   With the IoU loss, the regression loss of large/medium/small targets is
+         #   not imbalanced, so this scale is unused.
+         #--------------------------------------------------------------------------#
+         box_loss_scale = 2 - box_loss_scale
+
+         loss     = 0
+         obj_mask = y_true[..., 4] == 1
+         n        = torch.sum(obj_mask)
+         if n != 0:
+             #---------------------------------------------------------------#
+             #   Compute the gap between prediction and ground truth:
+             #   loss_loc  CIoU regression loss
+             #   loss_cls  classification loss
+             #---------------------------------------------------------------#
+             ciou = self.box_ciou(pred_boxes, y_true[..., :4]).type_as(x)
+             # loss_loc = torch.mean((1 - ciou)[obj_mask] * box_loss_scale[obj_mask])
+             loss_loc = torch.mean((1 - ciou)[obj_mask])
+
+             loss_cls = torch.mean(self.BCELoss(pred_cls[obj_mask], y_true[..., 5:][obj_mask]))
+             loss     += loss_loc * self.box_ratio + loss_cls * self.cls_ratio
+
+         #---------------------------------------------------------------#
+         #   Compute the objectness confidence loss
+         #---------------------------------------------------------------#
+         if self.focal_loss:
+             pos_neg_ratio   = torch.where(obj_mask, torch.ones_like(conf) * self.alpha, torch.ones_like(conf) * (1 - self.alpha))
+             hard_easy_ratio = torch.where(obj_mask, torch.ones_like(conf) - conf, conf) ** self.gamma
+             loss_conf       = torch.mean((self.BCELoss(conf, obj_mask.type_as(conf)) * pos_neg_ratio * hard_easy_ratio)[noobj_mask.bool() | obj_mask]) * self.focal_loss_ratio
+         else:
+             loss_conf = torch.mean(self.BCELoss(conf, obj_mask.type_as(conf))[noobj_mask.bool() | obj_mask])
+         loss += loss_conf * self.balance[l] * self.obj_ratio
+         # if n != 0:
+         #     print(loss_loc * self.box_ratio, loss_cls * self.cls_ratio, loss_conf * self.balance[l] * self.obj_ratio)
+         return loss
+
+     def calculate_iou(self, _box_a, _box_b):
+         #-----------------------------------------------------------#
+         #   Compute the top-left and bottom-right corners of the ground-truth boxes
+         #-----------------------------------------------------------#
+         b1_x1, b1_x2 = _box_a[:, 0] - _box_a[:, 2] / 2, _box_a[:, 0] + _box_a[:, 2] / 2
+         b1_y1, b1_y2 = _box_a[:, 1] - _box_a[:, 3] / 2, _box_a[:, 1] + _box_a[:, 3] / 2
+         #-----------------------------------------------------------#
+         #   Compute the corners of the predicted boxes derived from the prior boxes
+         #-----------------------------------------------------------#
+         b2_x1, b2_x2 = _box_b[:, 0] - _box_b[:, 2] / 2, _box_b[:, 0] + _box_b[:, 2] / 2
+         b2_y1, b2_y2 = _box_b[:, 1] - _box_b[:, 3] / 2, _box_b[:, 1] + _box_b[:, 3] / 2
+
+         #-----------------------------------------------------------#
+         #   Convert both ground-truth and predicted boxes to corner form
+         #-----------------------------------------------------------#
+         box_a = torch.zeros_like(_box_a)
+         box_b = torch.zeros_like(_box_b)
+         box_a[:, 0], box_a[:, 1], box_a[:, 2], box_a[:, 3] = b1_x1, b1_y1, b1_x2, b1_y2
+         box_b[:, 0], box_b[:, 1], box_b[:, 2], box_b[:, 3] = b2_x1, b2_y1, b2_x2, b2_y2
+
+         #-----------------------------------------------------------#
+         #   A is the number of ground-truth boxes, B the number of prior boxes
+         #-----------------------------------------------------------#
+         A = box_a.size(0)
+         B = box_b.size(0)
+
+         #-----------------------------------------------------------#
+         #   Compute the intersection area
+         #-----------------------------------------------------------#
+         max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
+         min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
+         inter  = torch.clamp((max_xy - min_xy), min=0)
+         inter  = inter[:, :, 0] * inter[:, :, 1]
+         #-----------------------------------------------------------#
+         #   Compute the areas of the predicted and ground-truth boxes
+         #-----------------------------------------------------------#
+         area_a = ((box_a[:, 2]-box_a[:, 0]) * (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
+         area_b = ((box_b[:, 2]-box_b[:, 0]) * (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
+         #-----------------------------------------------------------#
+         #   Compute the IoU
+         #-----------------------------------------------------------#
+         union = area_a + area_b - inter
+         return inter / union  # [A,B]
+
+     def get_target(self, l, targets, anchors, in_h, in_w):
+         #-----------------------------------------------------#
+         #   Count the images in the batch
+         #-----------------------------------------------------#
+         bs = len(targets)
+         #-----------------------------------------------------#
+         #   Selects which prior boxes contain no object
+         #-----------------------------------------------------#
+         noobj_mask = torch.ones(bs, len(self.anchors_mask[l]), in_h, in_w, requires_grad = False)
+         #-----------------------------------------------------#
+         #   Makes the network pay more attention to small targets
+         #-----------------------------------------------------#
+         box_loss_scale = torch.zeros(bs, len(self.anchors_mask[l]), in_h, in_w, requires_grad = False)
+         #-----------------------------------------------------#
+         #   batch_size, 3, 13, 13, 5 + num_classes
+         #-----------------------------------------------------#
+         y_true = torch.zeros(bs, len(self.anchors_mask[l]), in_h, in_w, self.bbox_attrs, requires_grad = False)
+         for b in range(bs):
+             if len(targets[b])==0:
+                 continue
+             batch_target = torch.zeros_like(targets[b])
+             #-------------------------------------------------------#
+             #   Compute the positive samples' centers on the feature layer
+             #-------------------------------------------------------#
+             batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
+             batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
+             batch_target[:, 4] = targets[b][:, 4]
+             batch_target = batch_target.cpu()
+
+             #-------------------------------------------------------#
+             #   Reshape the ground-truth boxes to
+             #   num_true_box, 4
+             #-------------------------------------------------------#
+             gt_box = torch.FloatTensor(torch.cat((torch.zeros((batch_target.size(0), 2)), batch_target[:, 2:4]), 1))
+             #-------------------------------------------------------#
+             #   Reshape the prior boxes to
+             #   9, 4
+             #-------------------------------------------------------#
+             anchor_shapes = torch.FloatTensor(torch.cat((torch.zeros((len(anchors), 2)), torch.FloatTensor(anchors)), 1))
+             #-------------------------------------------------------#
+             #   Compute the IoU:
+             #   self.calculate_iou(gt_box, anchor_shapes) = [num_true_box, 9],
+             #   the overlap of each ground-truth box with the 9 prior boxes.
+             #   best_ns:
+             #   [max overlap of each gt box, index of the best-matching prior box]
+             #-------------------------------------------------------#
+             best_ns = torch.argmax(self.calculate_iou(gt_box, anchor_shapes), dim=-1)
+
+             for t, best_n in enumerate(best_ns):
+                 if best_n not in self.anchors_mask[l]:
+                     continue
+                 #----------------------------------------#
+                 #   Determine which of the current feature point's prior boxes this is
+                 #----------------------------------------#
+                 k = self.anchors_mask[l].index(best_n)
+                 #----------------------------------------#
+                 #   Find which grid cell the ground-truth box falls in
+                 #----------------------------------------#
+                 i = torch.floor(batch_target[t, 0]).long()
+                 j = torch.floor(batch_target[t, 1]).long()
+                 #----------------------------------------#
+                 #   Get the ground-truth box's class
+                 #----------------------------------------#
+                 c = batch_target[t, 4].long()
+
+                 #----------------------------------------#
+                 #   noobj_mask marks feature points without targets
+                 #----------------------------------------#
+                 noobj_mask[b, k, j, i] = 0
+                 #----------------------------------------#
+                 #   tx, ty are the true values of the center adjustment parameters
+                 #----------------------------------------#
+                 y_true[b, k, j, i, 0] = batch_target[t, 0]
+                 y_true[b, k, j, i, 1] = batch_target[t, 1]
+                 y_true[b, k, j, i, 2] = batch_target[t, 2]
+                 y_true[b, k, j, i, 3] = batch_target[t, 3]
+                 y_true[b, k, j, i, 4] = 1
+                 y_true[b, k, j, i, c + 5] = 1
+                 #----------------------------------------#
+                 #   xywh scale factor:
+                 #   large targets get a small loss weight, small targets a large one
+                 #----------------------------------------#
+                 box_loss_scale[b, k, j, i] = batch_target[t, 2] * batch_target[t, 3] / in_w / in_h
+         return y_true, noobj_mask, box_loss_scale
+
+     def get_ignore(self, l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask):
+         #-----------------------------------------------------#
+         #   Count the images in the batch
+         #-----------------------------------------------------#
+         bs = len(targets)
+
+         #-----------------------------------------------------#
+         #   Generate the grid; prior box centers sit at grid-cell top-left corners
+         #-----------------------------------------------------#
+         grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_h, 1).repeat(
+             int(bs * len(self.anchors_mask[l])), 1, 1).view(x.shape).type_as(x)
+         grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_w, 1).t().repeat(
+             int(bs * len(self.anchors_mask[l])), 1, 1).view(y.shape).type_as(x)
+
+         # Generate the prior boxes' widths and heights
+         scaled_anchors_l = np.array(scaled_anchors)[self.anchors_mask[l]]
+         anchor_w = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([0])).type_as(x)
+         anchor_h = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([1])).type_as(x)
+
+         anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
+         anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
+         #-------------------------------------------------------#
+         #   Compute the adjusted prior box centers and sizes
+         #-------------------------------------------------------#
+         pred_boxes_x = torch.unsqueeze(x + grid_x, -1)
+         pred_boxes_y = torch.unsqueeze(y + grid_y, -1)
+         pred_boxes_w = torch.unsqueeze(torch.exp(w) * anchor_w, -1)
+         pred_boxes_h = torch.unsqueeze(torch.exp(h) * anchor_h, -1)
+         pred_boxes   = torch.cat([pred_boxes_x, pred_boxes_y, pred_boxes_w, pred_boxes_h], dim = -1)
+
+         for b in range(bs):
+             #-------------------------------------------------------#
+             #   Reshape the predictions:
+             #   pred_boxes_for_ignore   num_anchors, 4
+             #-------------------------------------------------------#
+             pred_boxes_for_ignore = pred_boxes[b].view(-1, 4)
+             #-------------------------------------------------------#
+             #   Convert the ground-truth boxes to the feature-layer scale:
+             #   gt_box   num_true_box, 4
+             #-------------------------------------------------------#
+             if len(targets[b]) > 0:
+                 batch_target = torch.zeros_like(targets[b])
+                 #-------------------------------------------------------#
+                 #   Compute the positive samples' centers on the feature layer
+                 #-------------------------------------------------------#
+                 batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
+                 batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
+                 batch_target = batch_target[:, :4].type_as(x)
+                 #-------------------------------------------------------#
+                 #   Compute the IoU:
+                 #   anch_ious   num_true_box, num_anchors
+                 #-------------------------------------------------------#
+                 anch_ious = self.calculate_iou(batch_target, pred_boxes_for_ignore)
+                 #-------------------------------------------------------#
+                 #   Max overlap of each prior box with any ground-truth box:
+                 #   anch_ious_max   num_anchors
+                 #-------------------------------------------------------#
+                 anch_ious_max, _ = torch.max(anch_ious, dim = 0)
+                 anch_ious_max    = anch_ious_max.view(pred_boxes[b].size()[:3])
+                 noobj_mask[b][anch_ious_max > self.ignore_threshold] = 0
+         return noobj_mask, pred_boxes
+
+ def weights_init(net, init_type='normal', init_gain = 0.02):
+     def init_func(m):
+         classname = m.__class__.__name__
+         if hasattr(m, 'weight') and classname.find('Conv') != -1:
+             if init_type == 'normal':
+                 torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
+             elif init_type == 'xavier':
+                 torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
+             elif init_type == 'kaiming':
+                 torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+             elif init_type == 'orthogonal':
+                 torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
+             else:
+                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+         elif classname.find('BatchNorm2d') != -1:
+             torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
+             torch.nn.init.constant_(m.bias.data, 0.0)
+     print('initialize network with %s type' % init_type)
+     net.apply(init_func)
+
+ def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
+     def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
+         if iters <= warmup_total_iters:
+             # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
+             lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
+         elif iters >= total_iters - no_aug_iter:
+             lr = min_lr
+         else:
+             lr = min_lr + 0.5 * (lr - min_lr) * (
+                 1.0 + math.cos(math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
+             )
+         return lr
+
+     def step_lr(lr, decay_rate, step_size, iters):
+         if step_size < 1:
+             raise ValueError("step_size must be above 1.")
+         n      = iters // step_size
+         out_lr = lr * decay_rate ** n
+         return out_lr
+
+     if lr_decay_type == "cos":
+         warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
+         warmup_lr_start    = max(warmup_lr_ratio * lr, 1e-6)
+         no_aug_iter        = min(max(no_aug_iter_ratio * total_iters, 1), 15)
+         func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
+     else:
+         decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
+         step_size  = total_iters / step_num
+         func = partial(step_lr, lr, decay_rate, step_size)
+
+     return func
+
+ def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
+     lr = lr_scheduler_func(epoch)
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
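
A sketch (not part of the commit) wiring get_lr_scheduler and set_optimizer_lr together, matching the cos option exposed in gesture.yaml:

import torch
from nets.yolo_training import get_lr_scheduler, set_optimizer_lr

model     = torch.nn.Linear(4, 2)  # stand-in for the detector
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
lr_scheduler_func = get_lr_scheduler("cos", lr=1e-2, min_lr=1e-4, total_iters=100)

for epoch in range(100):
    set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
    # ... train one epoch at the scheduled learning rate ...
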
nets/yolotiny_training.py ADDED
@@ -0,0 +1,474 @@
+ import math
+ from functools import partial
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ class YOLOLosstiny(nn.Module):
+     def __init__(self, anchors, num_classes, input_shape, cuda, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]], label_smoothing = 0):
+         super(YOLOLosstiny, self).__init__()
+         #-----------------------------------------------------------#
+         #   The 13x13 feature layer uses anchors [81,82],[135,169],[344,319]
+         #   The 26x26 feature layer uses anchors [10,14],[23,27],[37,58]
+         #-----------------------------------------------------------#
+         self.anchors         = anchors
+         self.num_classes     = num_classes
+         self.bbox_attrs      = 5 + num_classes
+         self.input_shape     = input_shape
+         self.anchors_mask    = anchors_mask
+         self.label_smoothing = label_smoothing
+
+         self.balance   = [0.4, 1.0, 4]
+         self.box_ratio = 0.05
+         self.obj_ratio = 5 * (input_shape[0] * input_shape[1]) / (416 ** 2)
+         self.cls_ratio = 1 * (num_classes / 80)
+
+         self.ignore_threshold = 0.5
+         self.cuda             = cuda
+
+     def clip_by_tensor(self, t, t_min, t_max):
+         t = t.float()
+         result = (t >= t_min).float() * t + (t < t_min).float() * t_min
+         result = (result <= t_max).float() * result + (result > t_max).float() * t_max
+         return result
+
+     def MSELoss(self, pred, target):
+         return torch.pow(pred - target, 2)
+
+     def BCELoss(self, pred, target):
+         epsilon = 1e-7
+         pred    = self.clip_by_tensor(pred, epsilon, 1.0 - epsilon)
+         output  = - target * torch.log(pred) - (1.0 - target) * torch.log(1.0 - pred)
+         return output
+
+     def box_ciou(self, b1, b2):
+         """
+         Inputs:
+         ----------
+         b1: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh
+         b2: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh
+
+         Returns:
+         -------
+         ciou: tensor, shape=(batch, feat_w, feat_h, anchor_num, 1)
+         """
+         #----------------------------------------------------#
+         #   Compute the top-left and bottom-right corners of the predicted boxes
+         #----------------------------------------------------#
+         b1_xy      = b1[..., :2]
+         b1_wh      = b1[..., 2:4]
+         b1_wh_half = b1_wh/2.
+         b1_mins    = b1_xy - b1_wh_half
+         b1_maxes   = b1_xy + b1_wh_half
+         #----------------------------------------------------#
+         #   Compute the top-left and bottom-right corners of the ground-truth boxes
+         #----------------------------------------------------#
+         b2_xy      = b2[..., :2]
+         b2_wh      = b2[..., 2:4]
+         b2_wh_half = b2_wh/2.
+         b2_mins    = b2_xy - b2_wh_half
+         b2_maxes   = b2_xy + b2_wh_half
+
+         #----------------------------------------------------#
+         #   Compute the IoU of all ground-truth and predicted boxes
+         #----------------------------------------------------#
+         intersect_mins  = torch.max(b1_mins, b2_mins)
+         intersect_maxes = torch.min(b1_maxes, b2_maxes)
+         intersect_wh    = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
+         intersect_area  = intersect_wh[..., 0] * intersect_wh[..., 1]
+         b1_area         = b1_wh[..., 0] * b1_wh[..., 1]
+         b2_area         = b2_wh[..., 0] * b2_wh[..., 1]
+         union_area      = b1_area + b2_area - intersect_area
+         iou             = intersect_area / torch.clamp(union_area,min = 1e-6)
+
+         #----------------------------------------------------#
+         #   Compute the distance between the two centers
+         #----------------------------------------------------#
+         center_distance = torch.sum(torch.pow((b1_xy - b2_xy), 2), axis=-1)
+
+         #----------------------------------------------------#
+         #   Find the corners of the smallest box enclosing both boxes
+         #----------------------------------------------------#
+         enclose_mins  = torch.min(b1_mins, b2_mins)
+         enclose_maxes = torch.max(b1_maxes, b2_maxes)
+         enclose_wh    = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
+         #----------------------------------------------------#
+         #   Compute the diagonal distance of the enclosing box
+         #----------------------------------------------------#
+         enclose_diagonal = torch.sum(torch.pow(enclose_wh,2), axis=-1)
+         ciou             = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal,min = 1e-6)
+
+         v     = (4 / (math.pi ** 2)) * torch.pow((torch.atan(b1_wh[..., 0] / torch.clamp(b1_wh[..., 1],min = 1e-6)) - torch.atan(b2_wh[..., 0] / torch.clamp(b2_wh[..., 1], min = 1e-6))), 2)
+         alpha = v / torch.clamp((1.0 - iou + v), min=1e-6)
+         ciou  = ciou - alpha * v
+         return ciou
+
+     #---------------------------------------------------#
+     #   Label smoothing
+     #---------------------------------------------------#
+     def smooth_labels(self, y_true, label_smoothing, num_classes):
+         return y_true * (1.0 - label_smoothing) + label_smoothing / num_classes
+
+     def forward(self, l, input, targets=None):
+         #----------------------------------------------------#
+         #   l indicates which effective feature layer is used
+         #   input has shape  bs, 3*(5+num_classes), 13, 13
+         #                    bs, 3*(5+num_classes), 26, 26
+         #   targets holds the ground-truth boxes [batch_size, num_gt, 5]
+         #----------------------------------------------------#
+         #--------------------------------#
+         #   Get the image count and the feature layer's height and width
+         #--------------------------------#
+         bs   = input.size(0)
+         in_h = input.size(2)
+         in_w = input.size(3)
+         #-----------------------------------------------------------------------#
+         #   Compute the stride:
+         #   how many pixels of the original image one feature point covers.
+         #
+         #   A 13x13 feature layer means one feature point covers 32 pixels;
+         #   26x26 means 16 pixels.
+         #   stride_h = stride_w = 32, 16
+         #-----------------------------------------------------------------------#
+         stride_h = self.input_shape[0] / in_h
+         stride_w = self.input_shape[1] / in_w
+         #-------------------------------------------------#
+         #   The scaled_anchors obtained here are relative to the feature layer
+         #-------------------------------------------------#
+         scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in self.anchors]
+         #-----------------------------------------------#
+         #   There are two inputs; their shapes are
+         #   bs, 3 * (5+num_classes), 13, 13 => bs, 3, 5 + num_classes, 13, 13 => batch_size, 3, 13, 13, 5 + num_classes
+
+         #   batch_size, 3, 13, 13, 5 + num_classes
+         #   batch_size, 3, 26, 26, 5 + num_classes
+         #-----------------------------------------------#
+         prediction = input.view(bs, len(self.anchors_mask[l]), self.bbox_attrs, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()
+
+         #-----------------------------------------------#
+         #   Adjustment parameters for the prior boxes' centers
+         #-----------------------------------------------#
+         x = torch.sigmoid(prediction[..., 0])
+         y = torch.sigmoid(prediction[..., 1])
+         #-----------------------------------------------#
+         #   Adjustment parameters for the prior boxes' width and height
+         #-----------------------------------------------#
+         w = prediction[..., 2]
+         h = prediction[..., 3]
+         #-----------------------------------------------#
+         #   Objectness confidence: whether an object is present
+         #-----------------------------------------------#
+         conf = torch.sigmoid(prediction[..., 4])
+         #-----------------------------------------------#
+         #   Class confidence
+         #-----------------------------------------------#
+         pred_cls = torch.sigmoid(prediction[..., 5:])
+
+         #-----------------------------------------------#
+         #   Get the prediction the network should produce
+         #-----------------------------------------------#
+         y_true, noobj_mask, box_loss_scale = self.get_target(l, targets, scaled_anchors, in_h, in_w)
+
+         #---------------------------------------------------------------#
+         #   Decode the predictions and measure their overlap with the ground truth.
+         #   Predictions that overlap too much are ignored: those feature points
+         #   already predict quite accurately and make poor negative samples.
+         #----------------------------------------------------------------#
+         noobj_mask, pred_boxes = self.get_ignore(l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask)
+
+         if self.cuda:
+             y_true         = y_true.type_as(x)
+             noobj_mask     = noobj_mask.type_as(x)
+             box_loss_scale = box_loss_scale.type_as(x)
+         #--------------------------------------------------------------------------#
185
+ # box_loss_scale是真实框宽高的乘积,宽高均在0-1之间,因此乘积也在0-1之间。
186
+ # 2-宽高的乘积代表真实框越大,比重越小,小框的比重更大。
187
+ # 使用iou损失时,大中小目标的回归损失不存在比例失衡问题,故弃用
188
+ #--------------------------------------------------------------------------#
189
+ box_loss_scale = 2 - box_loss_scale
190
+
191
+ loss = 0
192
+ obj_mask = y_true[..., 4] == 1
193
+ n = torch.sum(obj_mask)
194
+ if n != 0:
195
+ #---------------------------------------------------------------#
196
+ # 计算预测结果和真实结果的差距
197
+ # loss_loc ciou回归损失
198
+ # loss_cls 分类损失
199
+ #---------------------------------------------------------------#
200
+ ciou = self.box_ciou(pred_boxes, y_true[..., :4]).type_as(x)
201
+ # loss_loc = torch.mean((1 - ciou)[obj_mask] * box_loss_scale[obj_mask])
202
+ loss_loc = torch.mean((1 - ciou)[obj_mask])
203
+
204
+ loss_cls = torch.mean(self.BCELoss(pred_cls[obj_mask], y_true[..., 5:][obj_mask]))
205
+ loss += loss_loc * self.box_ratio + loss_cls * self.cls_ratio
206
+
207
+ loss_conf = torch.mean(self.BCELoss(conf, obj_mask.type_as(conf))[noobj_mask.bool() | obj_mask])
208
+ loss += loss_conf * self.balance[l] * self.obj_ratio
209
+ # if n != 0:
210
+ # print(loss_loc * self.box_ratio, loss_cls * self.cls_ratio, loss_conf * self.balance[l] * self.obj_ratio)
211
+ return loss
212
+
213
+ def calculate_iou(self, _box_a, _box_b):
214
+ #-----------------------------------------------------------#
215
+ # 计算真实框的左上角和右下角
216
+ #-----------------------------------------------------------#
217
+ b1_x1, b1_x2 = _box_a[:, 0] - _box_a[:, 2] / 2, _box_a[:, 0] + _box_a[:, 2] / 2
218
+ b1_y1, b1_y2 = _box_a[:, 1] - _box_a[:, 3] / 2, _box_a[:, 1] + _box_a[:, 3] / 2
219
+ #-----------------------------------------------------------#
220
+ # 计算先验框获得的预测框的左上角和右下角
221
+ #-----------------------------------------------------------#
222
+ b2_x1, b2_x2 = _box_b[:, 0] - _box_b[:, 2] / 2, _box_b[:, 0] + _box_b[:, 2] / 2
223
+ b2_y1, b2_y2 = _box_b[:, 1] - _box_b[:, 3] / 2, _box_b[:, 1] + _box_b[:, 3] / 2
224
+
225
+ #-----------------------------------------------------------#
226
+ # 将真实框和预测框都转化成左上角右下角的形式
227
+ #-----------------------------------------------------------#
228
+ box_a = torch.zeros_like(_box_a)
229
+ box_b = torch.zeros_like(_box_b)
230
+ box_a[:, 0], box_a[:, 1], box_a[:, 2], box_a[:, 3] = b1_x1, b1_y1, b1_x2, b1_y2
231
+ box_b[:, 0], box_b[:, 1], box_b[:, 2], box_b[:, 3] = b2_x1, b2_y1, b2_x2, b2_y2
232
+
233
+ #-----------------------------------------------------------#
234
+ # A为真实框的数量,B为先验框的数量
235
+ #-----------------------------------------------------------#
236
+ A = box_a.size(0)
237
+ B = box_b.size(0)
238
+
239
+ #-----------------------------------------------------------#
240
+ # 计算交的面积
241
+ #-----------------------------------------------------------#
242
+ max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
243
+ min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
244
+ inter = torch.clamp((max_xy - min_xy), min=0)
245
+ inter = inter[:, :, 0] * inter[:, :, 1]
246
+ #-----------------------------------------------------------#
247
+ # 计算预测框和真实框各自的面积
248
+ #-----------------------------------------------------------#
249
+ area_a = ((box_a[:, 2]-box_a[:, 0]) * (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
250
+ area_b = ((box_b[:, 2]-box_b[:, 0]) * (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
251
+ #-----------------------------------------------------------#
252
+ # 求IOU
253
+ #-----------------------------------------------------------#
254
+ union = area_a + area_b - inter
255
+ return inter / union # [A,B]
256
+
257
+ def get_target(self, l, targets, anchors, in_h, in_w):
258
+ #-----------------------------------------------------#
259
+ # 计算一共有多少张图片
260
+ #-----------------------------------------------------#
261
+ bs = len(targets)
262
+ #-----------------------------------------------------#
263
+ # 用于选取哪些先验框不包含物体
264
+ #-----------------------------------------------------#
265
+ noobj_mask = torch.ones(bs, len(self.anchors_mask[l]), in_h, in_w, requires_grad = False)
266
+ #-----------------------------------------------------#
267
+ # 让网络更加去关注小目标
268
+ #-----------------------------------------------------#
269
+ box_loss_scale = torch.zeros(bs, len(self.anchors_mask[l]), in_h, in_w, requires_grad = False)
270
+ #-----------------------------------------------------#
271
+ # batch_size, 3, 13, 13, 5 + num_classes
272
+ #-----------------------------------------------------#
273
+ y_true = torch.zeros(bs, len(self.anchors_mask[l]), in_h, in_w, self.bbox_attrs, requires_grad = False)
274
+ for b in range(bs):
275
+ if len(targets[b])==0:
276
+ continue
277
+ batch_target = torch.zeros_like(targets[b])
278
+ #-------------------------------------------------------#
279
+ # 计算出正样本在特征层上的中心点
280
+ #-------------------------------------------------------#
281
+ batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
282
+ batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
283
+ batch_target[:, 4] = targets[b][:, 4]
284
+ batch_target = batch_target.cpu()
285
+
286
+ #-------------------------------------------------------#
287
+ # 将真实框转换一个形式
288
+ # num_true_box, 4
289
+ #-------------------------------------------------------#
290
+ gt_box = torch.FloatTensor(torch.cat((torch.zeros((batch_target.size(0), 2)), batch_target[:, 2:4]), 1))
291
+ #-------------------------------------------------------#
292
+ # 将先验框转换一个形式
293
+ # 9, 4
294
+ #-------------------------------------------------------#
295
+ anchor_shapes = torch.FloatTensor(torch.cat((torch.zeros((len(anchors), 2)), torch.FloatTensor(anchors)), 1))
296
+ #-------------------------------------------------------#
297
+ # 计算交并比
298
+ # self.calculate_iou(gt_box, anchor_shapes) = [num_true_box, 9]每一个真实框和9个先验框的重合情况
299
+ # best_ns:
300
+ # [每个真实框最大的重合度max_iou, 每一个真实框最重合的先验框的序号]
301
+ #-------------------------------------------------------#
302
+ iou = self.calculate_iou(gt_box, anchor_shapes)
303
+ best_ns = torch.argmax(iou, dim=-1)
304
+ sort_ns = torch.argsort(iou, dim=-1, descending=True)
305
+
306
+ def check_in_anchors_mask(index, anchors_mask):
307
+ for sub_anchors_mask in anchors_mask:
308
+ if index in sub_anchors_mask:
309
+ return True
310
+ return False
311
+
312
+ for t, best_n in enumerate(best_ns):
313
+ #----------------------------------------#
314
+ # 防止匹配到的先验框不在anchors_mask中
315
+ #----------------------------------------#
316
+ if not check_in_anchors_mask(best_n, self.anchors_mask):
317
+ for index in sort_ns[t]:
318
+ if check_in_anchors_mask(index, self.anchors_mask):
319
+ best_n = index
320
+ break
321
+
322
+ if best_n not in self.anchors_mask[l]:
323
+ continue
324
+ #----------------------------------------#
325
+ # 判断这个先验框是当前特征点的哪一个先验框
326
+ #----------------------------------------#
327
+ k = self.anchors_mask[l].index(best_n)
328
+ #----------------------------------------#
329
+ # 获得真实框属于哪个网格点
330
+ #----------------------------------------#
331
+ i = torch.floor(batch_target[t, 0]).long()
332
+ j = torch.floor(batch_target[t, 1]).long()
333
+ #----------------------------------------#
334
+ # 取出真实框的种类
335
+ #----------------------------------------#
336
+ c = batch_target[t, 4].long()
337
+
338
+ #----------------------------------------#
339
+ # noobj_mask代表无目标的特征点
340
+ #----------------------------------------#
341
+ noobj_mask[b, k, j, i] = 0
342
+ #----------------------------------------#
343
+ # tx、ty代表中心调整参数的真实值
344
+ #----------------------------------------#
345
+ y_true[b, k, j, i, 0] = batch_target[t, 0]
346
+ y_true[b, k, j, i, 1] = batch_target[t, 1]
347
+ y_true[b, k, j, i, 2] = batch_target[t, 2]
348
+ y_true[b, k, j, i, 3] = batch_target[t, 3]
349
+ y_true[b, k, j, i, 4] = 1
350
+ y_true[b, k, j, i, c + 5] = 1
351
+ #----------------------------------------#
352
+ # 用于获得xywh的比例
353
+ # 大目标loss权重小,小目标loss权重大
354
+ #----------------------------------------#
355
+ box_loss_scale[b, k, j, i] = batch_target[t, 2] * batch_target[t, 3] / in_w / in_h
356
+ return y_true, noobj_mask, box_loss_scale
357
+
358
+ def get_ignore(self, l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask):
359
+ #-----------------------------------------------------#
360
+ # 计算一共有多少张图片
361
+ #-----------------------------------------------------#
362
+ bs = len(targets)
363
+
364
+ #-----------------------------------------------------#
365
+ # 生成网格,先验框中心,网格左上角
366
+ #-----------------------------------------------------#
367
+ grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_h, 1).repeat(
368
+ int(bs * len(self.anchors_mask[l])), 1, 1).view(x.shape).type_as(x)
369
+ grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_w, 1).t().repeat(
370
+ int(bs * len(self.anchors_mask[l])), 1, 1).view(y.shape).type_as(x)
371
+
372
+ # 生成先验框的宽高
373
+ scaled_anchors_l = np.array(scaled_anchors)[self.anchors_mask[l]]
374
+ anchor_w = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([0])).type_as(x)
375
+ anchor_h = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([1])).type_as(x)
376
+
377
+ anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
378
+ anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
379
+ #-------------------------------------------------------#
380
+ # 计算调整后的先验框中心与宽高
381
+ #-------------------------------------------------------#
382
+ pred_boxes_x = torch.unsqueeze(x + grid_x, -1)
383
+ pred_boxes_y = torch.unsqueeze(y + grid_y, -1)
384
+ pred_boxes_w = torch.unsqueeze(torch.exp(w) * anchor_w, -1)
385
+ pred_boxes_h = torch.unsqueeze(torch.exp(h) * anchor_h, -1)
386
+ pred_boxes = torch.cat([pred_boxes_x, pred_boxes_y, pred_boxes_w, pred_boxes_h], dim = -1)
387
+ for b in range(bs):
388
+ #-------------------------------------------------------#
389
+ # 将预测结果转换一个形式
390
+ # pred_boxes_for_ignore num_anchors, 4
391
+ #-------------------------------------------------------#
392
+ pred_boxes_for_ignore = pred_boxes[b].view(-1, 4)
393
+ #-------------------------------------------------------#
394
+ # 计算真实框,并把真实框转换成相对于特征层的大小
395
+ # gt_box num_true_box, 4
396
+ #-------------------------------------------------------#
397
+ if len(targets[b]) > 0:
398
+ batch_target = torch.zeros_like(targets[b])
399
+ #-------------------------------------------------------#
400
+ # 计算出正样本在特征层上的中心点
401
+ #-------------------------------------------------------#
402
+ batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
403
+ batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
404
+ batch_target = batch_target[:, :4].type_as(x)
405
+ #-------------------------------------------------------#
406
+ # 计算交并比
407
+ # anch_ious num_true_box, num_anchors
408
+ #-------------------------------------------------------#
409
+ anch_ious = self.calculate_iou(batch_target, pred_boxes_for_ignore)
410
+ #-------------------------------------------------------#
411
+ # 每个先验框对应真实框的最大重合度
412
+ # anch_ious_max num_anchors
413
+ #-------------------------------------------------------#
414
+ anch_ious_max, _ = torch.max(anch_ious, dim = 0)
415
+ anch_ious_max = anch_ious_max.view(pred_boxes[b].size()[:3])
416
+ noobj_mask[b][anch_ious_max > self.ignore_threshold] = 0
417
+ return noobj_mask, pred_boxes
418
+
419
+ def weights_init(net, init_type='normal', init_gain = 0.02):
420
+ def init_func(m):
421
+ classname = m.__class__.__name__
422
+ if hasattr(m, 'weight') and classname.find('Conv') != -1:
423
+ if init_type == 'normal':
424
+ torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
425
+ elif init_type == 'xavier':
426
+ torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
427
+ elif init_type == 'kaiming':
428
+ torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
429
+ elif init_type == 'orthogonal':
430
+ torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
431
+ else:
432
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
433
+ elif classname.find('BatchNorm2d') != -1:
434
+ torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
435
+ torch.nn.init.constant_(m.bias.data, 0.0)
436
+ print('initialize network with %s type' % init_type)
437
+ net.apply(init_func)
438
+
439
+ def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
440
+ def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
441
+ if iters <= warmup_total_iters:
442
+ # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
443
+ lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
444
+ elif iters >= total_iters - no_aug_iter:
445
+ lr = min_lr
446
+ else:
447
+ lr = min_lr + 0.5 * (lr - min_lr) * (
448
+ 1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
449
+ )
450
+ return lr
451
+
452
+ def step_lr(lr, decay_rate, step_size, iters):
453
+ if step_size < 1:
454
+ raise ValueError("step_size must above 1.")
455
+ n = iters // step_size
456
+ out_lr = lr * decay_rate ** n
457
+ return out_lr
458
+
459
+ if lr_decay_type == "cos":
460
+ warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
461
+ warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
462
+ no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
463
+ func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
464
+ else:
465
+ decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
466
+ step_size = total_iters / step_num
467
+ func = partial(step_lr, lr, decay_rate, step_size)
468
+
469
+ return func
470
+
471
+ def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
472
+ lr = lr_scheduler_func(epoch)
473
+ for param_group in optimizer.param_groups:
474
+ param_group['lr'] = lr
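
A minimal usage sketch (not part of this commit) of how get_lr_scheduler and set_optimizer_lr above are typically wired into a training loop. The import path, optimizer settings, learning rates and epoch count are illustrative assumptions, not values taken from this repo.

    import torch
    # assumed module path for the loss/schedule file in this commit
    from nets.yolo_training import get_lr_scheduler, set_optimizer_lr

    model = torch.nn.Linear(10, 2)  # stand-in for the YOLO network
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.937)

    # "cos" matches lr_decay_type in model_data/gesture.yaml; "step" is the other option
    lr_scheduler_func = get_lr_scheduler("cos", lr=1e-2, min_lr=1e-4, total_iters=100)

    for epoch in range(100):
        # update the learning rate before running the epoch
        set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
        # ... one epoch of training would run here ...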
utils/__init__.py ADDED
@@ -0,0 +1 @@
+ #
utils/callbacks.py ADDED
@@ -0,0 +1,71 @@
+ import datetime
+ import os
+
+ import torch
+ import matplotlib
+ matplotlib.use('Agg')
+ import scipy.signal
+ from matplotlib import pyplot as plt
+ from torch.utils.tensorboard import SummaryWriter
+
+
+ class LossHistory():
+     def __init__(self, log_dir, model, input_shape):
+         time_str = datetime.datetime.strftime(datetime.datetime.now(), '%Y_%m_%d_%H_%M_%S')
+         self.log_dir = os.path.join(log_dir, "loss_" + str(time_str))
+         self.losses = []
+         self.val_loss = []
+
+         os.makedirs(self.log_dir)
+         self.writer = SummaryWriter(self.log_dir)
+         try:
+             dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
+             self.writer.add_graph(model, dummy_input)
+         except:
+             pass
+
+     def append_loss(self, epoch, loss, val_loss):
+         if not os.path.exists(self.log_dir):
+             os.makedirs(self.log_dir)
+
+         self.losses.append(loss)
+         self.val_loss.append(val_loss)
+
+         with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
+             f.write(str(loss))
+             f.write("\n")
+         with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
+             f.write(str(val_loss))
+             f.write("\n")
+
+         self.writer.add_scalar('loss', loss, epoch)
+         self.writer.add_scalar('val_loss', val_loss, epoch)
+         self.loss_plot()
+
+     def loss_plot(self):
+         iters = range(len(self.losses))
+
+         plt.figure()
+         plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')
+         plt.plot(iters, self.val_loss, 'coral', linewidth=2, label='val loss')
+         try:
+             if len(self.losses) < 25:
+                 num = 5
+             else:
+                 num = 15
+
+             plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle='--', linewidth=2, label='smooth train loss')
+             plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle='--', linewidth=2, label='smooth val loss')
+         except:
+             pass
+
+         plt.grid(True)
+         plt.xlabel('Epoch')
+         plt.ylabel('Loss')
+         plt.legend(loc="upper right")
+
+         plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
+
+         plt.cla()
+         plt.close("all")
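
A minimal usage sketch (not part of this commit) for LossHistory. The model below is a placeholder; in this repo the YOLO network and input shape would be passed instead.

    import torch
    from utils.callbacks import LossHistory

    model = torch.nn.Conv2d(3, 8, 3)  # placeholder network
    loss_history = LossHistory("logs", model, input_shape=[416, 416])
    loss_history.append_loss(1, 0.85, 0.92)  # epoch, train loss, val loss
    # writes logs/loss_<timestamp>/epoch_loss.txt, epoch_val_loss.txt and epoch_loss.png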
utils/dataloader.py ADDED
@@ -0,0 +1,360 @@
+ from random import sample, shuffle
+
+ import cv2
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torch.utils.data.dataset import Dataset
+
+ from utils.utils import cvtColor, preprocess_input
+
+
+ class YoloDataset(Dataset):
+     def __init__(self, annotation_lines, input_shape, num_classes, epoch_length, mosaic, train, mosaic_ratio=0.7):
+         super(YoloDataset, self).__init__()
+         self.annotation_lines = annotation_lines
+         self.input_shape = input_shape
+         self.num_classes = num_classes
+         self.epoch_length = epoch_length
+         self.mosaic = mosaic
+         self.train = train
+         self.mosaic_ratio = mosaic_ratio
+
+         self.epoch_now = -1
+         self.length = len(self.annotation_lines)
+
+     def __len__(self):
+         return self.length
+
+     def __getitem__(self, index):
+         index = index % self.length
+
+         #---------------------------------------------------#
+         #   Random data augmentation during training;
+         #   no random augmentation during validation
+         #---------------------------------------------------#
+         if self.mosaic:
+             if self.rand() < 0.5 and self.epoch_now < self.epoch_length * self.mosaic_ratio:
+                 lines = sample(self.annotation_lines, 3)
+                 lines.append(self.annotation_lines[index])
+                 shuffle(lines)
+                 image, box = self.get_random_data_with_Mosaic(lines, self.input_shape)
+             else:
+                 image, box = self.get_random_data(self.annotation_lines[index], self.input_shape, random=self.train)
+         else:
+             image, box = self.get_random_data(self.annotation_lines[index], self.input_shape, random=self.train)
+         image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
+         box = np.array(box, dtype=np.float32)
+         if len(box) != 0:
+             box[:, [0, 2]] = box[:, [0, 2]] / self.input_shape[1]
+             box[:, [1, 3]] = box[:, [1, 3]] / self.input_shape[0]
+
+             box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
+             box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
+         return image, box
+
+     def rand(self, a=0, b=1):
+         return np.random.rand() * (b - a) + a
+
+     def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
+         line = annotation_line.split()
+         #------------------------------#
+         #   Read the image and convert it to RGB
+         #------------------------------#
+         image = Image.open(line[0])
+         image = cvtColor(image)
+         #------------------------------#
+         #   Image size and target size
+         #------------------------------#
+         iw, ih = image.size
+         h, w = input_shape
+         #------------------------------#
+         #   Ground-truth boxes
+         #------------------------------#
+         box = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])
+
+         if not random:
+             scale = min(w / iw, h / ih)
+             nw = int(iw * scale)
+             nh = int(ih * scale)
+             dx = (w - nw) // 2
+             dy = (h - nh) // 2
+
+             #---------------------------------#
+             #   Pad the unused part of the image with gray bars
+             #---------------------------------#
+             image = image.resize((nw, nh), Image.BICUBIC)
+             new_image = Image.new('RGB', (w, h), (128, 128, 128))
+             new_image.paste(image, (dx, dy))
+             image_data = np.array(new_image, np.float32)
+
+             #---------------------------------#
+             #   Adjust the ground-truth boxes accordingly
+             #---------------------------------#
+             if len(box) > 0:
+                 np.random.shuffle(box)
+                 box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx
+                 box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy
+                 box[:, 0:2][box[:, 0:2] < 0] = 0
+                 box[:, 2][box[:, 2] > w] = w
+                 box[:, 3][box[:, 3] > h] = h
+                 box_w = box[:, 2] - box[:, 0]
+                 box_h = box[:, 3] - box[:, 1]
+                 box = box[np.logical_and(box_w > 1, box_h > 1)]  # discard invalid boxes
+
+             return image_data, box
+
+         #------------------------------------------#
+         #   Scale the image and distort its aspect ratio
+         #------------------------------------------#
+         new_ar = iw / ih * self.rand(1 - jitter, 1 + jitter) / self.rand(1 - jitter, 1 + jitter)
+         scale = self.rand(.25, 2)
+         if new_ar < 1:
+             nh = int(scale * h)
+             nw = int(nh * new_ar)
+         else:
+             nw = int(scale * w)
+             nh = int(nw / new_ar)
+         image = image.resize((nw, nh), Image.BICUBIC)
+
+         #------------------------------------------#
+         #   Pad the unused part of the image with gray bars
+         #------------------------------------------#
+         dx = int(self.rand(0, w - nw))
+         dy = int(self.rand(0, h - nh))
+         new_image = Image.new('RGB', (w, h), (128, 128, 128))
+         new_image.paste(image, (dx, dy))
+         image = new_image
+
+         #------------------------------------------#
+         #   Flip the image
+         #------------------------------------------#
+         flip = self.rand() < .5
+         if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)
+
+         image_data = np.array(image, np.uint8)
+         #---------------------------------#
+         #   Color-space augmentation:
+         #   compute the augmentation parameters
+         #---------------------------------#
+         r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
+         #---------------------------------#
+         #   Convert the image to HSV
+         #---------------------------------#
+         hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
+         dtype = image_data.dtype
+         #---------------------------------#
+         #   Apply the transform
+         #---------------------------------#
+         x = np.arange(0, 256, dtype=r.dtype)
+         lut_hue = ((x * r[0]) % 180).astype(dtype)
+         lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
+         lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
+
+         image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
+         image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
+
+         #---------------------------------#
+         #   Adjust the ground-truth boxes accordingly
+         #---------------------------------#
+         if len(box) > 0:
+             np.random.shuffle(box)
+             box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx
+             box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy
+             if flip: box[:, [0, 2]] = w - box[:, [2, 0]]
+             box[:, 0:2][box[:, 0:2] < 0] = 0
+             box[:, 2][box[:, 2] > w] = w
+             box[:, 3][box[:, 3] > h] = h
+             box_w = box[:, 2] - box[:, 0]
+             box_h = box[:, 3] - box[:, 1]
+             box = box[np.logical_and(box_w > 1, box_h > 1)]
+
+         return image_data, box
+
+     def merge_bboxes(self, bboxes, cutx, cuty):
+         merge_bbox = []
+         for i in range(len(bboxes)):
+             for box in bboxes[i]:
+                 tmp_box = []
+                 x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
+
+                 if i == 0:
+                     if y1 > cuty or x1 > cutx:
+                         continue
+                     if y2 >= cuty and y1 <= cuty:
+                         y2 = cuty
+                     if x2 >= cutx and x1 <= cutx:
+                         x2 = cutx
+
+                 if i == 1:
+                     if y2 < cuty or x1 > cutx:
+                         continue
+                     if y2 >= cuty and y1 <= cuty:
+                         y1 = cuty
+                     if x2 >= cutx and x1 <= cutx:
+                         x2 = cutx
+
+                 if i == 2:
+                     if y2 < cuty or x2 < cutx:
+                         continue
+                     if y2 >= cuty and y1 <= cuty:
+                         y1 = cuty
+                     if x2 >= cutx and x1 <= cutx:
+                         x1 = cutx
+
+                 if i == 3:
+                     if y1 > cuty or x2 < cutx:
+                         continue
+                     if y2 >= cuty and y1 <= cuty:
+                         y2 = cuty
+                     if x2 >= cutx and x1 <= cutx:
+                         x1 = cutx
+                 tmp_box.append(x1)
+                 tmp_box.append(y1)
+                 tmp_box.append(x2)
+                 tmp_box.append(y2)
+                 tmp_box.append(box[-1])
+                 merge_bbox.append(tmp_box)
+         return merge_bbox
+
+     def get_random_data_with_Mosaic(self, annotation_line, input_shape, jitter=0.3, hue=.1, sat=0.7, val=0.4):
+         h, w = input_shape
+         min_offset_x = self.rand(0.3, 0.7)
+         min_offset_y = self.rand(0.3, 0.7)
+
+         image_datas = []
+         box_datas = []
+         index = 0
+         for line in annotation_line:
+             #---------------------------------#
+             #   Split each annotation line
+             #---------------------------------#
+             line_content = line.split()
+             #---------------------------------#
+             #   Open the image
+             #---------------------------------#
+             image = Image.open(line_content[0])
+             image = cvtColor(image)
+
+             #---------------------------------#
+             #   Image size
+             #---------------------------------#
+             iw, ih = image.size
+             #---------------------------------#
+             #   Box positions
+             #---------------------------------#
+             box = np.array([np.array(list(map(int, box.split(',')))) for box in line_content[1:]])
+
+             #---------------------------------#
+             #   Whether to flip the image
+             #---------------------------------#
+             flip = self.rand() < .5
+             if flip and len(box) > 0:
+                 image = image.transpose(Image.FLIP_LEFT_RIGHT)
+                 box[:, [0, 2]] = iw - box[:, [2, 0]]
+
+             #------------------------------------------#
+             #   Scale the image and distort its aspect ratio
+             #------------------------------------------#
+             new_ar = iw / ih * self.rand(1 - jitter, 1 + jitter) / self.rand(1 - jitter, 1 + jitter)
+             scale = self.rand(.4, 1)
+             if new_ar < 1:
+                 nh = int(scale * h)
+                 nw = int(nh * new_ar)
+             else:
+                 nw = int(scale * w)
+                 nh = int(nw / new_ar)
+             image = image.resize((nw, nh), Image.BICUBIC)
+
+             #-----------------------------------------------#
+             #   Place the image in one of the four mosaic quadrants
+             #-----------------------------------------------#
+             if index == 0:
+                 dx = int(w * min_offset_x) - nw
+                 dy = int(h * min_offset_y) - nh
+             elif index == 1:
+                 dx = int(w * min_offset_x) - nw
+                 dy = int(h * min_offset_y)
+             elif index == 2:
+                 dx = int(w * min_offset_x)
+                 dy = int(h * min_offset_y)
+             elif index == 3:
+                 dx = int(w * min_offset_x)
+                 dy = int(h * min_offset_y) - nh
+
+             new_image = Image.new('RGB', (w, h), (128, 128, 128))
+             new_image.paste(image, (dx, dy))
+             image_data = np.array(new_image)
+
+             index = index + 1
+             box_data = []
+             #---------------------------------#
+             #   Re-process the boxes
+             #---------------------------------#
+             if len(box) > 0:
+                 np.random.shuffle(box)
+                 box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx
+                 box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy
+                 box[:, 0:2][box[:, 0:2] < 0] = 0
+                 box[:, 2][box[:, 2] > w] = w
+                 box[:, 3][box[:, 3] > h] = h
+                 box_w = box[:, 2] - box[:, 0]
+                 box_h = box[:, 3] - box[:, 1]
+                 box = box[np.logical_and(box_w > 1, box_h > 1)]
+                 box_data = np.zeros((len(box), 5))
+                 box_data[:len(box)] = box
+
+             image_datas.append(image_data)
+             box_datas.append(box_data)
+
+         #---------------------------------#
+         #   Cut the four images and stitch them together
+         #---------------------------------#
+         cutx = int(w * min_offset_x)
+         cuty = int(h * min_offset_y)
+
+         new_image = np.zeros([h, w, 3])
+         new_image[:cuty, :cutx, :] = image_datas[0][:cuty, :cutx, :]
+         new_image[cuty:, :cutx, :] = image_datas[1][cuty:, :cutx, :]
+         new_image[cuty:, cutx:, :] = image_datas[2][cuty:, cutx:, :]
+         new_image[:cuty, cutx:, :] = image_datas[3][:cuty, cutx:, :]
+
+         new_image = np.array(new_image, np.uint8)
+         #---------------------------------#
+         #   Color-space augmentation:
+         #   compute the augmentation parameters
+         #---------------------------------#
+         r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
+         #---------------------------------#
+         #   Convert the image to HSV
+         #---------------------------------#
+         hue, sat, val = cv2.split(cv2.cvtColor(new_image, cv2.COLOR_RGB2HSV))
+         dtype = new_image.dtype
+         #---------------------------------#
+         #   Apply the transform
+         #---------------------------------#
+         x = np.arange(0, 256, dtype=r.dtype)
+         lut_hue = ((x * r[0]) % 180).astype(dtype)
+         lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
+         lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
+
+         new_image = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
+         new_image = cv2.cvtColor(new_image, cv2.COLOR_HSV2RGB)
+
+         #---------------------------------#
+         #   Post-process the boxes
+         #---------------------------------#
+         new_boxes = self.merge_bboxes(box_datas, cutx, cuty)
+
+         return new_image, new_boxes
+
+ # collate_fn used by the DataLoader
+ def yolo_dataset_collate(batch):
+     images = []
+     bboxes = []
+     for img, box in batch:
+         images.append(img)
+         bboxes.append(box)
+     images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
+     bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
+     return images, bboxes
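
A minimal usage sketch (not part of this commit) for building a loader from this dataset. The annotation line below is a made-up example of the expected "path x1,y1,x2,y2,class" format, reusing an image shipped with this commit; mosaic is disabled since the sample list has fewer than four lines.

    from torch.utils.data import DataLoader
    from utils.dataloader import YoloDataset, yolo_dataset_collate

    train_lines = ["img/up.jpg 30,40,200,220,0"]  # hypothetical annotation
    train_dataset = YoloDataset(train_lines, input_shape=[416, 416], num_classes=8,
                                epoch_length=100, mosaic=False, train=True)
    gen = DataLoader(train_dataset, shuffle=True, batch_size=1,
                     collate_fn=yolo_dataset_collate)
    images, targets = next(iter(gen))  # images: [1, 3, 416, 416]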
utils/utils.py ADDED
@@ -0,0 +1,62 @@
+ import numpy as np
+ from PIL import Image
+
+ #---------------------------------------------------------#
+ #   Convert the image to RGB so grayscale images do not
+ #   raise errors at prediction time.
+ #   Only RGB prediction is supported; every other image
+ #   type is converted to RGB.
+ #---------------------------------------------------------#
+ def cvtColor(image):
+     if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
+         return image
+     else:
+         image = image.convert('RGB')
+         return image
+
+ #---------------------------------------------------#
+ #   Resize the input image
+ #---------------------------------------------------#
+ def resize_image(image, size, letterbox_image):
+     iw, ih = image.size
+     w, h = size
+     if letterbox_image:
+         scale = min(w / iw, h / ih)
+         nw = int(iw * scale)
+         nh = int(ih * scale)
+
+         image = image.resize((nw, nh), Image.BICUBIC)
+         new_image = Image.new('RGB', size, (128, 128, 128))
+         new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
+     else:
+         new_image = image.resize((w, h), Image.BICUBIC)
+     return new_image
+
+ #---------------------------------------------------#
+ #   Get the classes
+ #---------------------------------------------------#
+ def get_classes(classes_path):
+     with open(classes_path, encoding='utf-8') as f:
+         class_names = f.readlines()
+     class_names = [c.strip() for c in class_names]
+     return class_names, len(class_names)
+
+ #---------------------------------------------------#
+ #   Get the anchors
+ #---------------------------------------------------#
+ def get_anchors(anchors_path):
+     '''loads the anchors from a file'''
+     with open(anchors_path, encoding='utf-8') as f:
+         anchors = f.readline()
+     anchors = [float(x) for x in anchors.split(',')]
+     anchors = np.array(anchors).reshape(-1, 2)
+     return anchors, len(anchors)
+
+ #---------------------------------------------------#
+ #   Get the learning rate
+ #---------------------------------------------------#
+ def get_lr(optimizer):
+     for param_group in optimizer.param_groups:
+         return param_group['lr']
+
+ def preprocess_input(image):
+     image /= 255.0
+     return image
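
A minimal usage sketch (not part of this commit) showing get_classes and get_anchors reading the files added under model_data/ in this commit.

    from utils.utils import get_classes, get_anchors

    class_names, num_classes = get_classes("model_data/gesture_classes.txt")
    anchors, num_anchors = get_anchors("model_data/yolo_anchors.txt")
    print(class_names)    # ['up', 'down', ..., 'anticlockwise'], num_classes == 8
    print(anchors.shape)  # (9, 2) for yolo_anchors.txt, (6, 2) for yolotiny_anchors.txt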
utils/utils_bbox.py ADDED
@@ -0,0 +1,227 @@
+ import torch
+ import torch.nn as nn
+ from torchvision.ops import nms
+ import numpy as np
+
+ class DecodeBox():
+     def __init__(self, anchors, num_classes, input_shape, anchors_mask=[[6, 7, 8], [3, 4, 5], [0, 1, 2]]):
+         super(DecodeBox, self).__init__()
+         self.anchors = anchors
+         self.num_classes = num_classes
+         self.bbox_attrs = 5 + num_classes
+         self.input_shape = input_shape
+         #-----------------------------------------------------------#
+         #   The 13x13 feature layer uses anchors [142, 110], [192, 243], [459, 401]
+         #   The 26x26 feature layer uses anchors [36, 75], [76, 55], [72, 146]
+         #   The 52x52 feature layer uses anchors [12, 16], [19, 36], [40, 28]
+         #-----------------------------------------------------------#
+         self.anchors_mask = anchors_mask
+
+     def decode_box(self, inputs):
+         outputs = []
+         for i, input in enumerate(inputs):
+             #-----------------------------------------------#
+             #   There are three inputs with shapes
+             #   batch_size, 255, 13, 13
+             #   batch_size, 255, 26, 26
+             #   batch_size, 255, 52, 52
+             #-----------------------------------------------#
+             batch_size = input.size(0)
+             input_height = input.size(2)
+             input_width = input.size(3)
+
+             #-----------------------------------------------#
+             #   For a 416x416 input
+             #   stride_h = stride_w = 32, 16, 8
+             #-----------------------------------------------#
+             stride_h = self.input_shape[0] / input_height
+             stride_w = self.input_shape[1] / input_width
+             #-------------------------------------------------#
+             #   scaled_anchors are now relative to the feature layer
+             #-------------------------------------------------#
+             scaled_anchors = [(anchor_width / stride_w, anchor_height / stride_h) for anchor_width, anchor_height in self.anchors[self.anchors_mask[i]]]
+
+             #-----------------------------------------------#
+             #   Reshape the three inputs to
+             #   batch_size, 3, 13, 13, 85
+             #   batch_size, 3, 26, 26, 85
+             #   batch_size, 3, 52, 52, 85
+             #-----------------------------------------------#
+             prediction = input.view(batch_size, len(self.anchors_mask[i]),
+                                     self.bbox_attrs, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()
+
+             #-----------------------------------------------#
+             #   Adjustments for the anchor center positions
+             #-----------------------------------------------#
+             x = torch.sigmoid(prediction[..., 0])
+             y = torch.sigmoid(prediction[..., 1])
+             #-----------------------------------------------#
+             #   Adjustments for the anchor width and height
+             #-----------------------------------------------#
+             w = prediction[..., 2]
+             h = prediction[..., 3]
+             #-----------------------------------------------#
+             #   Objectness confidence: is there an object?
+             #-----------------------------------------------#
+             conf = torch.sigmoid(prediction[..., 4])
+             #-----------------------------------------------#
+             #   Class confidences
+             #-----------------------------------------------#
+             pred_cls = torch.sigmoid(prediction[..., 5:])
+
+             FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
+             LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
+
+             #----------------------------------------------------------#
+             #   Generate the grid: anchor centers at the top-left of each cell
+             #   batch_size,3,13,13
+             #----------------------------------------------------------#
+             grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_height, 1).repeat(
+                 batch_size * len(self.anchors_mask[i]), 1, 1).view(x.shape).type(FloatTensor)
+             grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_width, 1).t().repeat(
+                 batch_size * len(self.anchors_mask[i]), 1, 1).view(y.shape).type(FloatTensor)
+
+             #----------------------------------------------------------#
+             #   Generate anchor widths and heights in grid format
+             #   batch_size,3,13,13
+             #----------------------------------------------------------#
+             anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
+             anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
+             anchor_w = anchor_w.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(w.shape)
+             anchor_h = anchor_h.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(h.shape)
+
+             #----------------------------------------------------------#
+             #   Adjust the anchors using the predictions:
+             #   first shift the anchor center towards the bottom-right,
+             #   then scale the anchor width and height.
+             #----------------------------------------------------------#
+             pred_boxes = FloatTensor(prediction[..., :4].shape)
+             pred_boxes[..., 0] = x.data + grid_x
+             pred_boxes[..., 1] = y.data + grid_y
+             pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
+             pred_boxes[..., 3] = torch.exp(h.data) * anchor_h
+
+             #----------------------------------------------------------#
+             #   Normalize the outputs to fractions of the feature map
+             #----------------------------------------------------------#
+             _scale = torch.Tensor([input_width, input_height, input_width, input_height]).type(FloatTensor)
+             output = torch.cat((pred_boxes.view(batch_size, -1, 4) / _scale,
+                                 conf.view(batch_size, -1, 1), pred_cls.view(batch_size, -1, self.num_classes)), -1)
+             outputs.append(output.data)
+         return outputs
+
+     def yolo_correct_boxes(self, box_xy, box_wh, input_shape, image_shape, letterbox_image):
+         #-----------------------------------------------------------------#
+         #   Put the y axis first: it makes it easier to multiply the boxes
+         #   by the image height and width
+         #-----------------------------------------------------------------#
+         box_yx = box_xy[..., ::-1]
+         box_hw = box_wh[..., ::-1]
+         input_shape = np.array(input_shape)
+         image_shape = np.array(image_shape)
+
+         if letterbox_image:
+             #-----------------------------------------------------------------#
+             #   The offset is the displacement of the valid image region
+             #   relative to the top-left corner of the image;
+             #   new_shape is the scaled width and height
+             #-----------------------------------------------------------------#
+             new_shape = np.round(image_shape * np.min(input_shape / image_shape))
+             offset = (input_shape - new_shape) / 2. / input_shape
+             scale = input_shape / new_shape
+
+             box_yx = (box_yx - offset) * scale
+             box_hw *= scale
+
+         box_mins = box_yx - (box_hw / 2.)
+         box_maxes = box_yx + (box_hw / 2.)
+         boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
+         boxes *= np.concatenate([image_shape, image_shape], axis=-1)
+         return boxes
+
+     def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
+         #----------------------------------------------------------#
+         #   Convert the predictions to corner format.
+         #   prediction  [batch_size, num_anchors, 85]
+         #----------------------------------------------------------#
+         box_corner = prediction.new(prediction.shape)
+         box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
+         box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
+         box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
+         box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
+         prediction[:, :, :4] = box_corner[:, :, :4]
+
+         output = [None for _ in range(len(prediction))]
+         for i, image_pred in enumerate(prediction):
+             #----------------------------------------------------------#
+             #   Take the max over the class predictions.
+             #   class_conf  [num_anchors, 1]  class confidence
+             #   class_pred  [num_anchors, 1]  class index
+             #----------------------------------------------------------#
+             class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
+
+             #----------------------------------------------------------#
+             #   First round of filtering by confidence
+             #----------------------------------------------------------#
+             conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()
+
+             #----------------------------------------------------------#
+             #   Filter the predictions by confidence
+             #----------------------------------------------------------#
+             image_pred = image_pred[conf_mask]
+             class_conf = class_conf[conf_mask]
+             class_pred = class_pred[conf_mask]
+             if not image_pred.size(0):
+                 continue
+             #-------------------------------------------------------------------------#
+             #   detections  [num_anchors, 7]
+             #   the 7 values are: x1, y1, x2, y2, obj_conf, class_conf, class_pred
+             #-------------------------------------------------------------------------#
+             detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)
+
+             #------------------------------------------#
+             #   All classes present in the predictions
+             #------------------------------------------#
+             unique_labels = detections[:, -1].cpu().unique()
+
+             if prediction.is_cuda:
+                 unique_labels = unique_labels.cuda()
+                 detections = detections.cuda()
+
+             for c in unique_labels:
+                 #------------------------------------------#
+                 #   All predictions of one class after score filtering
+                 #------------------------------------------#
+                 detections_class = detections[detections[:, -1] == c]
+
+                 #------------------------------------------#
+                 #   The official built-in non-max suppression is faster!
+                 #------------------------------------------#
+                 keep = nms(
+                     detections_class[:, :4],
+                     detections_class[:, 4] * detections_class[:, 5],
+                     nms_thres
+                 )
+                 max_detections = detections_class[keep]
+
+                 # # Sort by objectness confidence
+                 # _, conf_sort_index = torch.sort(detections_class[:, 4] * detections_class[:, 5], descending=True)
+                 # detections_class = detections_class[conf_sort_index]
+                 # # Perform non-max suppression
+                 # max_detections = []
+                 # while detections_class.size(0):
+                 #     # Take the highest-confidence detection of this class and walk down,
+                 #     # removing any detection whose overlap exceeds nms_thres
+                 #     max_detections.append(detections_class[0].unsqueeze(0))
+                 #     if len(detections_class) == 1:
+                 #         break
+                 #     ious = bbox_iou(max_detections[-1], detections_class[1:])
+                 #     detections_class = detections_class[1:][ious < nms_thres]
+                 # # Stack
+                 # max_detections = torch.cat(max_detections).data
+
+                 # Add max detections to outputs
+                 output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
+
+             if output[i] is not None:
+                 output[i] = output[i].cpu().numpy()
+                 box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4]) / 2, output[i][:, 2:4] - output[i][:, 0:2]
+                 output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
+         return output
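
A minimal usage sketch (not part of this commit) for DecodeBox. The anchors mirror model_data/yolo_anchors.txt and the 8 gesture classes, but the random tensor below is a stand-in for a real 13x13 head output, not an actual network prediction.

    import numpy as np
    import torch
    from utils.utils_bbox import DecodeBox

    anchors = np.array([[12, 16], [19, 36], [40, 28], [36, 75], [76, 55],
                        [72, 146], [142, 110], [192, 243], [459, 401]])
    decoder = DecodeBox(anchors, num_classes=8, input_shape=(416, 416))

    head_13 = torch.randn(1, 3 * (5 + 8), 13, 13)  # bs, 3*(5+num_classes), 13, 13
    outputs = decoder.decode_box([head_13])
    print(outputs[0].shape)  # torch.Size([1, 507, 13]): 3*13*13 boxes, 4+1+8 attrs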
utils/utils_fit.py ADDED
@@ -0,0 +1,128 @@
+ import os
+
+ import torch
+ from tqdm import tqdm
+
+ from utils.utils import get_lr
+
+
+ def fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
+     loss = 0
+     val_loss = 0
+
+     if local_rank == 0:
+         print('Start Train')
+         pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
+     model_train.train()
+     for iteration, batch in enumerate(gen):
+         if iteration >= epoch_step:
+             break
+
+         images, targets = batch[0], batch[1]
+         with torch.no_grad():
+             if cuda:
+                 images = images.cuda()
+                 targets = [ann.cuda() for ann in targets]
+         #----------------------#
+         #   Zero the gradients
+         #----------------------#
+         optimizer.zero_grad()
+         if not fp16:
+             #----------------------#
+             #   Forward pass
+             #----------------------#
+             outputs = model_train(images)
+
+             loss_value_all = 0
+             #----------------------#
+             #   Compute the loss
+             #----------------------#
+             for l in range(len(outputs)):
+                 loss_item = yolo_loss(l, outputs[l], targets)
+                 loss_value_all += loss_item
+             loss_value = loss_value_all
+
+             #----------------------#
+             #   Backward pass
+             #----------------------#
+             loss_value.backward()
+             optimizer.step()
+         else:
+             from torch.cuda.amp import autocast
+             with autocast():
+                 #----------------------#
+                 #   Forward pass
+                 #----------------------#
+                 outputs = model_train(images)
+
+                 loss_value_all = 0
+                 #----------------------#
+                 #   Compute the loss
+                 #----------------------#
+                 for l in range(len(outputs)):
+                     loss_item = yolo_loss(l, outputs[l], targets)
+                     loss_value_all += loss_item
+                 loss_value = loss_value_all
+
+             #----------------------#
+             #   Backward pass
+             #----------------------#
+             scaler.scale(loss_value).backward()
+             scaler.step(optimizer)
+             scaler.update()
+
+         loss += loss_value.item()
+
+         if local_rank == 0:
+             pbar.set_postfix(**{'loss': loss / (iteration + 1),
+                                 'lr': get_lr(optimizer)})
+             pbar.update(1)
+
+     if local_rank == 0:
+         pbar.close()
+         print('Finish Train')
+         print('Start Validation')
+         pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
+
+     model_train.eval()
+     for iteration, batch in enumerate(gen_val):
+         if iteration >= epoch_step_val:
+             break
+         images, targets = batch[0], batch[1]
+         with torch.no_grad():
+             if cuda:
+                 images = images.cuda()
+                 targets = [ann.cuda() for ann in targets]
+             #----------------------#
+             #   Zero the gradients
+             #----------------------#
+             optimizer.zero_grad()
+             #----------------------#
+             #   Forward pass
+             #----------------------#
+             outputs = model_train(images)
+
+             loss_value_all = 0
+             #----------------------#
+             #   Compute the loss
+             #----------------------#
+             for l in range(len(outputs)):
+                 loss_item = yolo_loss(l, outputs[l], targets)
+                 loss_value_all += loss_item
+             loss_value = loss_value_all
+
+         val_loss += loss_value.item()
+         if local_rank == 0:
+             pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)})
+             pbar.update(1)
+
+     if local_rank == 0:
+         pbar.close()
+         print('Finish Validation')
+         loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val)
+         print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
+         print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val))
+         if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
+             torch.save(model.state_dict(), os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val)))
+         # Always save the most recent weights
+         torch.save(model.state_dict(), os.path.join(save_dir, "last.pth"))
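
A brief sketch (not part of this commit) of the scaler that fit_one_epoch expects when fp16=True; with fp16=False, scaler can simply be None. The commented call illustrates the argument order only, with hypothetical values.

    import torch
    from torch.cuda.amp import GradScaler

    scaler = GradScaler() if torch.cuda.is_available() else None
    # fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch,
    #               epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda=True,
    #               fp16=True, scaler=scaler, save_period=10, save_dir="logs")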
utils/utils_map.py ADDED
@@ -0,0 +1,901 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import math
4
+ import operator
5
+ import os
6
+ import shutil
7
+ import sys
8
+
9
+ import cv2
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+
13
+ '''
14
+ 0,0 ------> x (width)
15
+ |
16
+ | (Left,Top)
17
+ | *_________
18
+ | | |
19
+ | |
20
+ y |_________|
21
+ (height) *
22
+ (Right,Bottom)
23
+ '''
24
+
25
+ def log_average_miss_rate(precision, fp_cumsum, num_images):
26
+ """
27
+ log-average miss rate:
28
+ Calculated by averaging miss rates at 9 evenly spaced FPPI points
29
+ between 10e-2 and 10e0, in log-space.
30
+
31
+ output:
32
+ lamr | log-average miss rate
33
+ mr | miss rate
34
+ fppi | false positives per image
35
+
36
+ references:
37
+ [1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the
38
+ State of the Art." Pattern Analysis and Machine Intelligence, IEEE
39
+ Transactions on 34.4 (2012): 743 - 761.
40
+ """
41
+
42
+ if precision.size == 0:
43
+ lamr = 0
44
+ mr = 1
45
+ fppi = 0
46
+ return lamr, mr, fppi
47
+
48
+ fppi = fp_cumsum / float(num_images)
49
+ mr = (1 - precision)
50
+
51
+ fppi_tmp = np.insert(fppi, 0, -1.0)
52
+ mr_tmp = np.insert(mr, 0, 1.0)
53
+
54
+ ref = np.logspace(-2.0, 0.0, num = 9)
55
+ for i, ref_i in enumerate(ref):
56
+ j = np.where(fppi_tmp <= ref_i)[-1][-1]
57
+ ref[i] = mr_tmp[j]
58
+
59
+ lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref))))
60
+
61
+ return lamr, mr, fppi
62
+
63
+ """
64
+ throw error and exit
65
+ """
66
+ def error(msg):
67
+ print(msg)
68
+ sys.exit(0)
69
+
70
+ """
71
+ check if the number is a float between 0.0 and 1.0
72
+ """
73
+ def is_float_between_0_and_1(value):
74
+ try:
75
+ val = float(value)
76
+ if val > 0.0 and val < 1.0:
77
+ return True
78
+ else:
79
+ return False
80
+ except ValueError:
81
+ return False
82
+
83
+ """
84
+ Calculate the AP given the recall and precision array
85
+ 1st) We compute a version of the measured precision/recall curve with
86
+ precision monotonically decreasing
87
+ 2nd) We compute the AP as the area under this curve by numerical integration.
88
+ """
89
+ def voc_ap(rec, prec):
90
+ """
91
+ --- Official matlab code VOC2012---
92
+ mrec=[0 ; rec ; 1];
93
+ mpre=[0 ; prec ; 0];
94
+ for i=numel(mpre)-1:-1:1
95
+ mpre(i)=max(mpre(i),mpre(i+1));
96
+ end
97
+ i=find(mrec(2:end)~=mrec(1:end-1))+1;
98
+ ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
99
+ """
100
+ rec.insert(0, 0.0) # insert 0.0 at begining of list
101
+ rec.append(1.0) # insert 1.0 at end of list
102
+ mrec = rec[:]
103
+ prec.insert(0, 0.0) # insert 0.0 at begining of list
104
+ prec.append(0.0) # insert 0.0 at end of list
105
+ mpre = prec[:]
106
+ """
107
+ This part makes the precision monotonically decreasing
108
+ (goes from the end to the beginning)
109
+ matlab: for i=numel(mpre)-1:-1:1
110
+ mpre(i)=max(mpre(i),mpre(i+1));
111
+ """
112
+ for i in range(len(mpre)-2, -1, -1):
113
+ mpre[i] = max(mpre[i], mpre[i+1])
114
+ """
115
+ This part creates a list of indexes where the recall changes
116
+ matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
117
+ """
118
+ i_list = []
119
+ for i in range(1, len(mrec)):
120
+ if mrec[i] != mrec[i-1]:
121
+ i_list.append(i) # if it was matlab would be i + 1
122
+ """
123
+ The Average Precision (AP) is the area under the curve
124
+ (numerical integration)
125
+ matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
126
+ """
127
+ ap = 0.0
128
+ for i in i_list:
129
+ ap += ((mrec[i]-mrec[i-1])*mpre[i])
130
+ return ap, mrec, mpre
131
+
132
+
133
+ """
134
+ Convert the lines of a file to a list
135
+ """
136
+ def file_lines_to_list(path):
137
+ # open txt file lines to a list
138
+ with open(path) as f:
139
+ content = f.readlines()
140
+ # remove whitespace characters like `\n` at the end of each line
141
+ content = [x.strip() for x in content]
142
+ return content
143
+
144
+ """
145
+ Draws text in image
146
+ """
147
+ def draw_text_in_image(img, text, pos, color, line_width):
148
+ font = cv2.FONT_HERSHEY_PLAIN
149
+ fontScale = 1
150
+ lineType = 1
151
+ bottomLeftCornerOfText = pos
152
+ cv2.putText(img, text,
153
+ bottomLeftCornerOfText,
154
+ font,
155
+ fontScale,
156
+ color,
157
+ lineType)
158
+ text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0]
159
+ return img, (line_width + text_width)
160
+
161
+ """
162
+ Plot - adjust axes
163
+ """
164
+ def adjust_axes(r, t, fig, axes):
165
+ # get text width for re-scaling
166
+ bb = t.get_window_extent(renderer=r)
167
+ text_width_inches = bb.width / fig.dpi
168
+ # get axis width in inches
169
+ current_fig_width = fig.get_figwidth()
170
+ new_fig_width = current_fig_width + text_width_inches
171
+ propotion = new_fig_width / current_fig_width
172
+ # get axis limit
173
+ x_lim = axes.get_xlim()
174
+ axes.set_xlim([x_lim[0], x_lim[1]*propotion])
175
+
176
+ """
177
+ Draw plot using Matplotlib
178
+ """
179
+ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
180
+ # sort the dictionary by decreasing value, into a list of tuples
181
+ sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
182
+ # unpacking the list of tuples into two lists
183
+ sorted_keys, sorted_values = zip(*sorted_dic_by_value)
184
+ #
185
+ if true_p_bar != "":
186
+ """
187
+ Special case to draw in:
188
+ - green -> TP: True Positives (object detected and matches ground-truth)
189
+ - red -> FP: False Positives (object detected but does not match ground-truth)
190
+ - orange -> FN: False Negatives (object not detected but present in the ground-truth)
191
+ """
192
+ fp_sorted = []
193
+ tp_sorted = []
194
+ for key in sorted_keys:
195
+ fp_sorted.append(dictionary[key] - true_p_bar[key])
196
+ tp_sorted.append(true_p_bar[key])
197
+ plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
198
+ plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
199
+ # add legend
200
+ plt.legend(loc='lower right')
201
+ """
202
+ Write number on side of bar
203
+ """
204
+ fig = plt.gcf() # gcf - get current figure
205
+ axes = plt.gca()
206
+ r = fig.canvas.get_renderer()
207
+ for i, val in enumerate(sorted_values):
208
+ fp_val = fp_sorted[i]
209
+ tp_val = tp_sorted[i]
210
+ fp_str_val = " " + str(fp_val)
211
+ tp_str_val = fp_str_val + " " + str(tp_val)
212
+ # trick to paint multicolor with offset:
213
+ # first paint everything and then repaint the first number
214
+ t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
215
+ plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
216
+ if i == (len(sorted_values)-1): # largest bar
217
+ adjust_axes(r, t, fig, axes)
218
+ else:
219
+ plt.barh(range(n_classes), sorted_values, color=plot_color)
220
+ """
221
+ Write number on side of bar
222
+ """
223
+ fig = plt.gcf() # gcf - get current figure
224
+ axes = plt.gca()
225
+ r = fig.canvas.get_renderer()
226
+ for i, val in enumerate(sorted_values):
227
+ str_val = " " + str(val) # add a space before
228
+ if val < 1.0:
229
+ str_val = " {0:.2f}".format(val)
230
+ t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
231
+ # re-set axes to show number inside the figure
232
+ if i == (len(sorted_values)-1): # largest bar
233
+ adjust_axes(r, t, fig, axes)
234
+ # set window title
235
+ fig.canvas.set_window_title(window_title)
236
+ # write classes in y axis
237
+ tick_font_size = 12
238
+ plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
239
+ """
240
+ Re-scale height accordingly
241
+ """
242
+ init_height = fig.get_figheight()
243
+ # comput the matrix height in points and inches
244
+ dpi = fig.dpi
245
+ height_pt = n_classes * (tick_font_size * 1.4) # 1.4 (some spacing)
246
+ height_in = height_pt / dpi
247
+ # compute the required figure height
248
+ top_margin = 0.15 # in percentage of the figure height
249
+ bottom_margin = 0.05 # in percentage of the figure height
250
+ figure_height = height_in / (1 - top_margin - bottom_margin)
251
+ # set new height
252
+ if figure_height > init_height:
253
+ fig.set_figheight(figure_height)
254
+
255
+ # set plot title
256
+ plt.title(plot_title, fontsize=14)
257
+ # set axis titles
258
+ # plt.xlabel('classes')
259
+ plt.xlabel(x_label, fontsize='large')
260
+ # adjust size of window
261
+ fig.tight_layout()
262
+ # save the plot
263
+ fig.savefig(output_path)
264
+ # show image
265
+ if to_show:
266
+ plt.show()
267
+ # close the plot
268
+ plt.close()
269
+
270
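+ # Usage sketch (hypothetical counts; mirrors the calls inside get_map below).
+ # Passing a dict of true-positive counts as true_p_bar enables the stacked
+ # TP/FP bar mode; passing "" draws plain bars:
+ #     det_counts = {"classA": 12, "classB": 9}
+ #     tp_counts  = {"classA": 10, "classB": 7}
+ #     draw_plot_func(det_counts, 2, "demo", "detection-results", "Number of objects per class",
+ #                    "./demo.png", False, 'forestgreen', tp_counts)
+ 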
+ def get_map(MINOVERLAP, draw_plot, path = './map_out'):
+     GT_PATH = os.path.join(path, 'ground-truth')
+     DR_PATH = os.path.join(path, 'detection-results')
+     IMG_PATH = os.path.join(path, 'images-optional')
+     TEMP_FILES_PATH = os.path.join(path, '.temp_files')
+     RESULTS_FILES_PATH = os.path.join(path, 'results')
+ 
+     show_animation = True
+     if os.path.exists(IMG_PATH):
+         for dirpath, dirnames, files in os.walk(IMG_PATH):
+             if not files:
+                 show_animation = False
+     else:
+         show_animation = False
+ 
+     if not os.path.exists(TEMP_FILES_PATH):
+         os.makedirs(TEMP_FILES_PATH)
+ 
+     if os.path.exists(RESULTS_FILES_PATH):
+         shutil.rmtree(RESULTS_FILES_PATH)
+     if draw_plot:
+         os.makedirs(os.path.join(RESULTS_FILES_PATH, "AP"))
+         os.makedirs(os.path.join(RESULTS_FILES_PATH, "F1"))
+         os.makedirs(os.path.join(RESULTS_FILES_PATH, "Recall"))
+         os.makedirs(os.path.join(RESULTS_FILES_PATH, "Precision"))
+     if show_animation:
+         os.makedirs(os.path.join(RESULTS_FILES_PATH, "images", "detections_one_by_one"))
+ 
+     ground_truth_files_list = glob.glob(GT_PATH + '/*.txt')
+     if len(ground_truth_files_list) == 0:
+         error("Error: No ground-truth files found!")
+     ground_truth_files_list.sort()
+     gt_counter_per_class = {}
+     counter_images_per_class = {}
+ 
+     for txt_file in ground_truth_files_list:
+         file_id = txt_file.split(".txt", 1)[0]
+         file_id = os.path.basename(os.path.normpath(file_id))
+         temp_path = os.path.join(DR_PATH, (file_id + ".txt"))
+         if not os.path.exists(temp_path):
+             error_msg = "Error. File not found: {}\n".format(temp_path)
+             error(error_msg)
+         lines_list = file_lines_to_list(txt_file)
+         bounding_boxes = []
+         is_difficult = False
+         already_seen_classes = []
+         for line in lines_list:
+             try:
+                 if "difficult" in line:
+                     class_name, left, top, right, bottom, _difficult = line.split()
+                     is_difficult = True
+                 else:
+                     class_name, left, top, right, bottom = line.split()
+             except ValueError:
+                 # the class name contains spaces, so unpack from the right instead
+                 if "difficult" in line:
+                     line_split = line.split()
+                     _difficult = line_split[-1]
+                     bottom = line_split[-2]
+                     right = line_split[-3]
+                     top = line_split[-4]
+                     left = line_split[-5]
+                     class_name = " ".join(line_split[:-5])
+                     is_difficult = True
+                 else:
+                     line_split = line.split()
+                     bottom = line_split[-1]
+                     right = line_split[-2]
+                     top = line_split[-3]
+                     left = line_split[-4]
+                     class_name = " ".join(line_split[:-4])
+ 
+             bbox = left + " " + top + " " + right + " " + bottom
+             if is_difficult:
+                 bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
+                 is_difficult = False
+             else:
+                 bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
+                 if class_name in gt_counter_per_class:
+                     gt_counter_per_class[class_name] += 1
+                 else:
+                     gt_counter_per_class[class_name] = 1
+ 
+                 if class_name not in already_seen_classes:
+                     if class_name in counter_images_per_class:
+                         counter_images_per_class[class_name] += 1
+                     else:
+                         counter_images_per_class[class_name] = 1
+                     already_seen_classes.append(class_name)
+ 
+         with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile:
+             json.dump(bounding_boxes, outfile)
+ 
+     gt_classes = list(gt_counter_per_class.keys())
+     gt_classes = sorted(gt_classes)
+     n_classes = len(gt_classes)
+ 
+     dr_files_list = glob.glob(DR_PATH + '/*.txt')
+     dr_files_list.sort()
+     for class_index, class_name in enumerate(gt_classes):
+         bounding_boxes = []
+         for txt_file in dr_files_list:
+             file_id = txt_file.split(".txt",1)[0]
+             file_id = os.path.basename(os.path.normpath(file_id))
+             temp_path = os.path.join(GT_PATH, (file_id + ".txt"))
+             if class_index == 0:
+                 if not os.path.exists(temp_path):
+                     error_msg = "Error. File not found: {}\n".format(temp_path)
+                     error(error_msg)
+             lines = file_lines_to_list(txt_file)
+             for line in lines:
+                 try:
+                     tmp_class_name, confidence, left, top, right, bottom = line.split()
+                 except ValueError:
+                     # the class name contains spaces, so unpack from the right instead
+                     line_split = line.split()
+                     bottom = line_split[-1]
+                     right = line_split[-2]
+                     top = line_split[-3]
+                     left = line_split[-4]
+                     confidence = line_split[-5]
+                     tmp_class_name = " ".join(line_split[:-5])
+ 
+                 if tmp_class_name == class_name:
+                     bbox = left + " " + top + " " + right + " " + bottom
+                     bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
+ 
+         bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
+         with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
+             json.dump(bounding_boxes, outfile)
+ 
+     sum_AP = 0.0
+     ap_dictionary = {}
+     lamr_dictionary = {}
+     with open(RESULTS_FILES_PATH + "/results.txt", 'w') as results_file:
+         results_file.write("# AP and precision/recall per class\n")
+         count_true_positives = {}
+ 
+         for class_index, class_name in enumerate(gt_classes):
+             count_true_positives[class_name] = 0
+             dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
+             dr_data = json.load(open(dr_file))
+ 
+             nd = len(dr_data)
+             tp = [0] * nd
+             fp = [0] * nd
+             score = [0] * nd
+             score05_idx = 0
+             for idx, detection in enumerate(dr_data):
+                 file_id = detection["file_id"]
+                 score[idx] = float(detection["confidence"])
+                 if score[idx] > 0.5:
+                     score05_idx = idx
+ 
+                 if show_animation:
+                     ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
+                     if len(ground_truth_img) == 0:
+                         error("Error. Image not found with id: " + file_id)
+                     elif len(ground_truth_img) > 1:
+                         error("Error. Multiple images found with id: " + file_id)
+                     else:
+                         img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
+                         img_cumulative_path = RESULTS_FILES_PATH + "/images/" + ground_truth_img[0]
+                         if os.path.isfile(img_cumulative_path):
+                             img_cumulative = cv2.imread(img_cumulative_path)
+                         else:
+                             img_cumulative = img.copy()
+                         bottom_border = 60
+                         BLACK = [0, 0, 0]
+                         img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)
+ 
+                 gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
+                 ground_truth_data = json.load(open(gt_file))
+                 ovmax = -1
+                 gt_match = -1
+                 bb = [float(x) for x in detection["bbox"].split()]
+                 for obj in ground_truth_data:
+                     if obj["class_name"] == class_name:
+                         bbgt = [ float(x) for x in obj["bbox"].split() ]
+                         bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
+                         iw = bi[2] - bi[0] + 1
+                         ih = bi[3] - bi[1] + 1
+                         if iw > 0 and ih > 0:
+                             # IoU = intersection area / union area
+                             ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
+                                             + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
+                             ov = iw * ih / ua
+                             if ov > ovmax:
+                                 ovmax = ov
+                                 gt_match = obj
+ 
+                 if show_animation:
+                     status = "NO MATCH FOUND!"
+ 
+                 min_overlap = MINOVERLAP
+                 if ovmax >= min_overlap:
+                     if "difficult" not in gt_match:
+                         if not bool(gt_match["used"]):
+                             tp[idx] = 1
+                             gt_match["used"] = True
+                             count_true_positives[class_name] += 1
+                             with open(gt_file, 'w') as f:
+                                 f.write(json.dumps(ground_truth_data))
+                             if show_animation:
+                                 status = "MATCH!"
+                         else:
+                             fp[idx] = 1
+                             if show_animation:
+                                 status = "REPEATED MATCH!"
+                 else:
+                     fp[idx] = 1
+                     if ovmax > 0:
+                         status = "INSUFFICIENT OVERLAP"
+ 
+ """
491
+ Draw image to show animation
492
+ """
493
+ if show_animation:
494
+ height, widht = img.shape[:2]
495
+ white = (255,255,255)
496
+ light_blue = (255,200,100)
497
+ green = (0,255,0)
498
+ light_red = (30,30,255)
499
+ margin = 10
500
+ # 1nd line
501
+ v_pos = int(height - margin - (bottom_border / 2.0))
502
+ text = "Image: " + ground_truth_img[0] + " "
503
+ img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
504
+ text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
505
+ img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
506
+ if ovmax != -1:
507
+ color = light_red
508
+ if status == "INSUFFICIENT OVERLAP":
509
+ text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
510
+ else:
511
+ text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
512
+ color = green
513
+ img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
514
+ # 2nd line
515
+ v_pos += int(bottom_border / 2.0)
516
+ rank_pos = str(idx+1)
517
+ text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100)
518
+ img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
519
+ color = light_red
520
+ if status == "MATCH!":
521
+ color = green
522
+ text = "Result: " + status + " "
523
+ img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
524
+
525
+ font = cv2.FONT_HERSHEY_SIMPLEX
526
+ if ovmax > 0:
527
+ bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ]
528
+ cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
529
+ cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
530
+ cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
531
+ bb = [int(i) for i in bb]
532
+ cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
533
+ cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
534
+ cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)
535
+
536
+ cv2.imshow("Animation", img)
537
+ cv2.waitKey(20)
538
+ output_img_path = RESULTS_FILES_PATH + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg"
539
+ cv2.imwrite(output_img_path, img)
540
+ cv2.imwrite(img_cumulative_path, img_cumulative)
541
+
542
+             cumsum = 0
+             for idx, val in enumerate(fp):
+                 fp[idx] += cumsum
+                 cumsum += val
+ 
+             cumsum = 0
+             for idx, val in enumerate(tp):
+                 tp[idx] += cumsum
+                 cumsum += val
+ 
+             rec = tp[:]
+             for idx, val in enumerate(tp):
+                 rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1)
+ 
+             prec = tp[:]
+             for idx, val in enumerate(tp):
+                 prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1)
+ 
+             ap, mrec, mprec = voc_ap(rec[:], prec[:])
+             F1 = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec)))
+ 
+             sum_AP += ap
+             text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100)
+ 
+             if len(prec)>0:
+                 F1_text = "{0:.2f}".format(F1[score05_idx]) + " = " + class_name + " F1 "
+                 Recall_text = "{0:.2f}%".format(rec[score05_idx]*100) + " = " + class_name + " Recall "
+                 Precision_text = "{0:.2f}%".format(prec[score05_idx]*100) + " = " + class_name + " Precision "
+             else:
+                 F1_text = "0.00" + " = " + class_name + " F1 "
+                 Recall_text = "0.00%" + " = " + class_name + " Recall "
+                 Precision_text = "0.00%" + " = " + class_name + " Precision "
+ 
+             rounded_prec = [ '%.2f' % elem for elem in prec ]
+             rounded_rec = [ '%.2f' % elem for elem in rec ]
+             results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall: " + str(rounded_rec) + "\n\n")
+             if len(prec)>0:
+                 print(text + "\t||\tscore_threshold=0.5 : " + "F1=" + "{0:.2f}".format(F1[score05_idx])\
+                     + " ; Recall=" + "{0:.2f}%".format(rec[score05_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score05_idx]*100))
+             else:
+                 print(text + "\t||\tscore_threshold=0.5 : F1=0.00 ; Recall=0.00% ; Precision=0.00%")
+             ap_dictionary[class_name] = ap
+ 
+             n_images = counter_images_per_class[class_name]
+             lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images)
+             lamr_dictionary[class_name] = lamr
+ 
+             if draw_plot:
+                 plt.plot(rec, prec, '-o')
+                 area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
+                 area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
+                 plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
+ 
+                 fig = plt.gcf()
+                 # set window title through the figure manager, which also works on newer Matplotlib
+                 fig.canvas.manager.set_window_title('AP ' + class_name)
+ 
+                 plt.title('class: ' + text)
+                 plt.xlabel('Recall')
+                 plt.ylabel('Precision')
+                 axes = plt.gca()
+                 axes.set_xlim([0.0,1.0])
+                 axes.set_ylim([0.0,1.05])
+                 fig.savefig(RESULTS_FILES_PATH + "/AP/" + class_name + ".png")
+                 plt.cla()
+ 
+                 plt.plot(score, F1, "-", color='orangered')
+                 plt.title('class: ' + F1_text + "\nscore_threshold=0.5")
+                 plt.xlabel('Score_Threshold')
+                 plt.ylabel('F1')
+                 axes = plt.gca()
+                 axes.set_xlim([0.0,1.0])
+                 axes.set_ylim([0.0,1.05])
+                 fig.savefig(RESULTS_FILES_PATH + "/F1/" + class_name + ".png")
+                 plt.cla()
+ 
+                 plt.plot(score, rec, "-H", color='gold')
+                 plt.title('class: ' + Recall_text + "\nscore_threshold=0.5")
+                 plt.xlabel('Score_Threshold')
+                 plt.ylabel('Recall')
+                 axes = plt.gca()
+                 axes.set_xlim([0.0,1.0])
+                 axes.set_ylim([0.0,1.05])
+                 fig.savefig(RESULTS_FILES_PATH + "/Recall/" + class_name + ".png")
+                 plt.cla()
+ 
+                 plt.plot(score, prec, "-s", color='palevioletred')
+                 plt.title('class: ' + Precision_text + "\nscore_threshold=0.5")
+                 plt.xlabel('Score_Threshold')
+                 plt.ylabel('Precision')
+                 axes = plt.gca()
+                 axes.set_xlim([0.0,1.0])
+                 axes.set_ylim([0.0,1.05])
+                 fig.savefig(RESULTS_FILES_PATH + "/Precision/" + class_name + ".png")
+                 plt.cla()
+ 
+         if show_animation:
+             cv2.destroyAllWindows()
+ 
+ 
+         results_file.write("\n# mAP of all classes\n")
+         mAP = sum_AP / n_classes
+         text = "mAP = {0:.2f}%".format(mAP*100)
+         results_file.write(text + "\n")
+         print(text)
+ 
+     shutil.rmtree(TEMP_FILES_PATH)
+ 
+ """
649
+ Count total of detection-results
650
+ """
651
+ det_counter_per_class = {}
652
+ for txt_file in dr_files_list:
653
+ lines_list = file_lines_to_list(txt_file)
654
+ for line in lines_list:
655
+ class_name = line.split()[0]
656
+ if class_name in det_counter_per_class:
657
+ det_counter_per_class[class_name] += 1
658
+ else:
659
+ det_counter_per_class[class_name] = 1
660
+ dr_classes = list(det_counter_per_class.keys())
661
+
662
+ """
663
+ Write number of ground-truth objects per class to results.txt
664
+ """
665
+ with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
666
+ results_file.write("\n# Number of ground-truth objects per class\n")
667
+ for class_name in sorted(gt_counter_per_class):
668
+ results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")
669
+
670
+ """
671
+ Finish counting true positives
672
+ """
673
+ for class_name in dr_classes:
674
+ if class_name not in gt_classes:
675
+ count_true_positives[class_name] = 0
676
+
677
+ """
678
+ Write number of detected objects per class to results.txt
679
+ """
680
+ with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
681
+ results_file.write("\n# Number of detected objects per class\n")
682
+ for class_name in sorted(dr_classes):
683
+ n_det = det_counter_per_class[class_name]
684
+ text = class_name + ": " + str(n_det)
685
+ text += " (tp:" + str(count_true_positives[class_name]) + ""
686
+ text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n"
687
+ results_file.write(text)
688
+
689
+ """
690
+ Plot the total number of occurences of each class in the ground-truth
691
+ """
692
+ if draw_plot:
693
+ window_title = "ground-truth-info"
694
+ plot_title = "ground-truth\n"
695
+ plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
696
+ x_label = "Number of objects per class"
697
+ output_path = RESULTS_FILES_PATH + "/ground-truth-info.png"
698
+ to_show = False
699
+ plot_color = 'forestgreen'
700
+ draw_plot_func(
701
+ gt_counter_per_class,
702
+ n_classes,
703
+ window_title,
704
+ plot_title,
705
+ x_label,
706
+ output_path,
707
+ to_show,
708
+ plot_color,
709
+ '',
710
+ )
711
+
712
+ # """
713
+ # Plot the total number of occurences of each class in the "detection-results" folder
714
+ # """
715
+ # if draw_plot:
716
+ # window_title = "detection-results-info"
717
+ # # Plot title
718
+ # plot_title = "detection-results\n"
719
+ # plot_title += "(" + str(len(dr_files_list)) + " files and "
720
+ # count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values()))
721
+ # plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
722
+ # # end Plot title
723
+ # x_label = "Number of objects per class"
724
+ # output_path = RESULTS_FILES_PATH + "/detection-results-info.png"
725
+ # to_show = False
726
+ # plot_color = 'forestgreen'
727
+ # true_p_bar = count_true_positives
728
+ # draw_plot_func(
729
+ # det_counter_per_class,
730
+ # len(det_counter_per_class),
731
+ # window_title,
732
+ # plot_title,
733
+ # x_label,
734
+ # output_path,
735
+ # to_show,
736
+ # plot_color,
737
+ # true_p_bar
738
+ # )
739
+
740
+ """
741
+ Draw log-average miss rate plot (Show lamr of all classes in decreasing order)
742
+ """
743
+ if draw_plot:
744
+ window_title = "lamr"
745
+ plot_title = "log-average miss rate"
746
+ x_label = "log-average miss rate"
747
+ output_path = RESULTS_FILES_PATH + "/lamr.png"
748
+ to_show = False
749
+ plot_color = 'royalblue'
750
+ draw_plot_func(
751
+ lamr_dictionary,
752
+ n_classes,
753
+ window_title,
754
+ plot_title,
755
+ x_label,
756
+ output_path,
757
+ to_show,
758
+ plot_color,
759
+ ""
760
+ )
761
+
762
+ """
763
+ Draw mAP plot (Show AP's of all classes in decreasing order)
764
+ """
765
+ if draw_plot:
766
+ window_title = "mAP"
767
+ plot_title = "mAP = {0:.2f}%".format(mAP*100)
768
+ x_label = "Average Precision"
769
+ output_path = RESULTS_FILES_PATH + "/mAP.png"
770
+ to_show = True
771
+ plot_color = 'royalblue'
772
+ draw_plot_func(
773
+ ap_dictionary,
774
+ n_classes,
775
+ window_title,
776
+ plot_title,
777
+ x_label,
778
+ output_path,
779
+ to_show,
780
+ plot_color,
781
+ ""
782
+ )
783
+
784
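+ # Usage sketch: once the ground-truth .txt files are in <path>/ground-truth and
+ # the detection .txt files are in <path>/detection-results, the VOC-style mAP
+ # at IoU 0.5 can be computed with:
+ #     get_map(0.5, True, path='./map_out')
+ 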
+ def preprocess_gt(gt_path, class_names):
+     image_ids = os.listdir(gt_path)
+     results = {}
+ 
+     images = []
+     bboxes = []
+     for i, image_id in enumerate(image_ids):
+         lines_list = file_lines_to_list(os.path.join(gt_path, image_id))
+         boxes_per_image = []
+         image = {}
+         image_id = os.path.splitext(image_id)[0]
+         image['file_name'] = image_id + '.jpg'
+         image['width'] = 1
+         image['height'] = 1
+         #-----------------------------------------------------------------#
+         #   Thanks to 多学学英语吧 for the tip that fixed the
+         #   'Results do not correspond to current coco set' problem
+         #-----------------------------------------------------------------#
+         image['id'] = str(image_id)
+ 
+         for line in lines_list:
+             difficult = 0
+             if "difficult" in line:
+                 line_split = line.split()
+                 left, top, right, bottom, _difficult = line_split[-5:]
+                 class_name = " ".join(line_split[:-5])
+                 difficult = 1
+             else:
+                 line_split = line.split()
+                 left, top, right, bottom = line_split[-4:]
+                 class_name = " ".join(line_split[:-4])
+ 
+             left, top, right, bottom = float(left), float(top), float(right), float(bottom)
+             cls_id = class_names.index(class_name) + 1
+             bbox = [left, top, right - left, bottom - top, difficult, str(image_id), cls_id, (right - left) * (bottom - top) - 10.0]
+             boxes_per_image.append(bbox)
+         images.append(image)
+         bboxes.extend(boxes_per_image)
+     results['images'] = images
+ 
+     categories = []
+     for i, cls in enumerate(class_names):
+         category = {}
+         category['supercategory'] = cls
+         category['name'] = cls
+         category['id'] = i + 1
+         categories.append(category)
+     results['categories'] = categories
+ 
+     annotations = []
+     for i, box in enumerate(bboxes):
+         annotation = {}
+         annotation['area'] = box[-1]
+         annotation['category_id'] = box[-2]
+         annotation['image_id'] = box[-3]
+         annotation['iscrowd'] = box[-4]
+         annotation['bbox'] = box[:4]
+         annotation['id'] = i
+         annotations.append(annotation)
+     results['annotations'] = annotations
+     return results
+ 
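+ # Expected ground-truth line format, as parsed above (one object per line;
+ # corner boxes are converted to COCO-style [x, y, w, h], and the trailing
+ # "difficult" flag is optional):
+ #     <class_name> <left> <top> <right> <bottom> [difficult]
+ 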
+ def preprocess_dr(dr_path, class_names):
+     image_ids = os.listdir(dr_path)
+     results = []
+     for image_id in image_ids:
+         lines_list = file_lines_to_list(os.path.join(dr_path, image_id))
+         image_id = os.path.splitext(image_id)[0]
+         for line in lines_list:
+             line_split = line.split()
+             confidence, left, top, right, bottom = line_split[-5:]
+             class_name = " ".join(line_split[:-5])
+             left, top, right, bottom = float(left), float(top), float(right), float(bottom)
+             result = {}
+             result["image_id"] = str(image_id)
+             result["category_id"] = class_names.index(class_name) + 1
+             result["bbox"] = [left, top, right - left, bottom - top]
+             result["score"] = float(confidence)
+             results.append(result)
+     return results
+ 
+ def get_coco_map(class_names, path):
+     from pycocotools.coco import COCO
+     from pycocotools.cocoeval import COCOeval
+ 
+     GT_PATH = os.path.join(path, 'ground-truth')
+     DR_PATH = os.path.join(path, 'detection-results')
+     COCO_PATH = os.path.join(path, 'coco_eval')
+ 
+     if not os.path.exists(COCO_PATH):
+         os.makedirs(COCO_PATH)
+ 
+     GT_JSON_PATH = os.path.join(COCO_PATH, 'instances_gt.json')
+     DR_JSON_PATH = os.path.join(COCO_PATH, 'instances_dr.json')
+ 
+     with open(GT_JSON_PATH, "w") as f:
+         results_gt = preprocess_gt(GT_PATH, class_names)
+         json.dump(results_gt, f, indent=4)
+ 
+     with open(DR_JSON_PATH, "w") as f:
+         results_dr = preprocess_dr(DR_PATH, class_names)
+         json.dump(results_dr, f, indent=4)
+ 
+     cocoGt = COCO(GT_JSON_PATH)
+     cocoDt = cocoGt.loadRes(DR_JSON_PATH)
+     cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
+     cocoEval.evaluate()
+     cocoEval.accumulate()
+     cocoEval.summarize()
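+ 
+ 
+ if __name__ == "__main__":
+     # Quick local smoke test: a sketch, assuming ./map_out/ground-truth and
+     # ./map_out/detection-results are already populated and pycocotools is
+     # installed; replace the hypothetical class list with the dataset's own.
+     example_classes = ["classA", "classB"]
+     get_map(0.5, False, path='./map_out')
+     get_coco_map(example_classes, path='./map_out')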