uestc_yhr committed
Commit · 71b93be
Parent(s): 6b6db36
Add

- class_indices.json +7 -0
- model.py +377 -0
- my_dataset.py +37 -0
- predict.py +65 -0
- train.py +143 -0
- trans_effv2_weights.py +160 -0
- utils.py +175 -0
- weights/model-20.pth +3 -0
class_indices.json
ADDED
@@ -0,0 +1,7 @@
+{
+    "0": "daisy",
+    "1": "dandelion",
+    "2": "roses",
+    "3": "sunflowers",
+    "4": "tulips"
+}
model.py
ADDED
@@ -0,0 +1,377 @@
+from collections import OrderedDict
+from functools import partial
+from typing import Callable, Optional
+
+import torch.nn as nn
+import torch
+from torch import Tensor
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+    """
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf
+
+    This function is taken from rwightman's repo.
+    It can be seen here:
+    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py#L140
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    random_tensor.floor_()  # binarize
+    output = x.div(keep_prob) * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    """
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf
+    """
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+
+class ConvBNAct(nn.Module):
+    def __init__(self,
+                 in_planes: int,
+                 out_planes: int,
+                 kernel_size: int = 3,
+                 stride: int = 1,
+                 groups: int = 1,
+                 norm_layer: Optional[Callable[..., nn.Module]] = None,
+                 activation_layer: Optional[Callable[..., nn.Module]] = None):
+        super(ConvBNAct, self).__init__()
+
+        padding = (kernel_size - 1) // 2
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if activation_layer is None:
+            activation_layer = nn.SiLU  # alias Swish (torch>=1.7)
+
+        self.conv = nn.Conv2d(in_channels=in_planes,
+                              out_channels=out_planes,
+                              kernel_size=kernel_size,
+                              stride=stride,
+                              padding=padding,
+                              groups=groups,
+                              bias=False)
+
+        self.bn = norm_layer(out_planes)
+        self.act = activation_layer()
+
+    def forward(self, x):
+        result = self.conv(x)
+        result = self.bn(result)
+        result = self.act(result)
+
+        return result
+
+
+class SqueezeExcite(nn.Module):
+    def __init__(self,
+                 input_c: int,   # block input channel
+                 expand_c: int,  # block expand channel
+                 se_ratio: float = 0.25):
+        super(SqueezeExcite, self).__init__()
+        squeeze_c = int(input_c * se_ratio)
+        self.conv_reduce = nn.Conv2d(expand_c, squeeze_c, 1)
+        self.act1 = nn.SiLU()  # alias Swish
+        self.conv_expand = nn.Conv2d(squeeze_c, expand_c, 1)
+        self.act2 = nn.Sigmoid()
+
+    def forward(self, x: Tensor) -> Tensor:
+        scale = x.mean((2, 3), keepdim=True)
+        scale = self.conv_reduce(scale)
+        scale = self.act1(scale)
+        scale = self.conv_expand(scale)
+        scale = self.act2(scale)
+        return scale * x
+
+
+class MBConv(nn.Module):
+    def __init__(self,
+                 kernel_size: int,
+                 input_c: int,
+                 out_c: int,
+                 expand_ratio: int,
+                 stride: int,
+                 se_ratio: float,
+                 drop_rate: float,
+                 norm_layer: Callable[..., nn.Module]):
+        super(MBConv, self).__init__()
+
+        if stride not in [1, 2]:
+            raise ValueError("illegal stride value.")
+
+        self.has_shortcut = (stride == 1 and input_c == out_c)
+
+        activation_layer = nn.SiLU  # alias Swish
+        expanded_c = input_c * expand_ratio
+
+        # in EfficientNetV2 no MBConv uses expansion=1, so the expand conv always exists
+        assert expand_ratio != 1
+        # Point-wise expansion
+        self.expand_conv = ConvBNAct(input_c,
+                                     expanded_c,
+                                     kernel_size=1,
+                                     norm_layer=norm_layer,
+                                     activation_layer=activation_layer)
+
+        # Depth-wise convolution
+        self.dwconv = ConvBNAct(expanded_c,
+                                expanded_c,
+                                kernel_size=kernel_size,
+                                stride=stride,
+                                groups=expanded_c,
+                                norm_layer=norm_layer,
+                                activation_layer=activation_layer)
+
+        self.se = SqueezeExcite(input_c, expanded_c, se_ratio) if se_ratio > 0 else nn.Identity()
+
+        # Point-wise linear projection
+        self.project_conv = ConvBNAct(expanded_c,
+                                      out_planes=out_c,
+                                      kernel_size=1,
+                                      norm_layer=norm_layer,
+                                      activation_layer=nn.Identity)  # note: no activation here, so Identity is passed in
+
+        self.out_channels = out_c
+
+        # the drop path layer is only used when a shortcut connection exists
+        self.drop_rate = drop_rate
+        if self.has_shortcut and drop_rate > 0:
+            self.dropout = DropPath(drop_rate)
+
+    def forward(self, x: Tensor) -> Tensor:
+        result = self.expand_conv(x)
+        result = self.dwconv(result)
+        result = self.se(result)
+        result = self.project_conv(result)
+
+        if self.has_shortcut:
+            if self.drop_rate > 0:
+                result = self.dropout(result)
+            result += x
+
+        return result
+
+
+class FusedMBConv(nn.Module):
+    def __init__(self,
+                 kernel_size: int,
+                 input_c: int,
+                 out_c: int,
+                 expand_ratio: int,
+                 stride: int,
+                 se_ratio: float,
+                 drop_rate: float,
+                 norm_layer: Callable[..., nn.Module]):
+        super(FusedMBConv, self).__init__()
+
+        assert stride in [1, 2]
+        assert se_ratio == 0
+
+        self.has_shortcut = stride == 1 and input_c == out_c
+        self.drop_rate = drop_rate
+
+        self.has_expansion = expand_ratio != 1
+
+        activation_layer = nn.SiLU  # alias Swish
+        expanded_c = input_c * expand_ratio
+
+        # the expand conv only exists when expand_ratio != 1
+        if self.has_expansion:
+            # Expansion convolution
+            self.expand_conv = ConvBNAct(input_c,
+                                         expanded_c,
+                                         kernel_size=kernel_size,
+                                         stride=stride,
+                                         norm_layer=norm_layer,
+                                         activation_layer=activation_layer)
+
+            self.project_conv = ConvBNAct(expanded_c,
+                                          out_c,
+                                          kernel_size=1,
+                                          norm_layer=norm_layer,
+                                          activation_layer=nn.Identity)  # note: no activation function
+        else:
+            # case with only a project_conv
+            self.project_conv = ConvBNAct(input_c,
+                                          out_c,
+                                          kernel_size=kernel_size,
+                                          stride=stride,
+                                          norm_layer=norm_layer,
+                                          activation_layer=activation_layer)  # note: with activation function
+
+        self.out_channels = out_c
+
+        # the drop path layer is only used when a shortcut connection exists
+        self.drop_rate = drop_rate
+        if self.has_shortcut and drop_rate > 0:
+            self.dropout = DropPath(drop_rate)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if self.has_expansion:
+            result = self.expand_conv(x)
+            result = self.project_conv(result)
+        else:
+            result = self.project_conv(x)
+
+        if self.has_shortcut:
+            if self.drop_rate > 0:
+                result = self.dropout(result)
+
+            result += x
+
+        return result
+
+
+class EfficientNetV2(nn.Module):
+    def __init__(self,
+                 model_cnf: list,
+                 num_classes: int = 1000,
+                 num_features: int = 1280,
+                 dropout_rate: float = 0.2,
+                 drop_connect_rate: float = 0.2):
+        super(EfficientNetV2, self).__init__()
+
+        for cnf in model_cnf:
+            assert len(cnf) == 8
+
+        norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.1)
+
+        stem_filter_num = model_cnf[0][4]
+
+        self.stem = ConvBNAct(3,
+                              stem_filter_num,
+                              kernel_size=3,
+                              stride=2,
+                              norm_layer=norm_layer)  # default activation is SiLU
+
+        total_blocks = sum([i[0] for i in model_cnf])
+        block_id = 0
+        blocks = []
+        for cnf in model_cnf:
+            repeats = cnf[0]
+            op = FusedMBConv if cnf[-2] == 0 else MBConv
+            for i in range(repeats):
+                blocks.append(op(kernel_size=cnf[1],
+                                 input_c=cnf[4] if i == 0 else cnf[5],
+                                 out_c=cnf[5],
+                                 expand_ratio=cnf[3],
+                                 stride=cnf[2] if i == 0 else 1,
+                                 se_ratio=cnf[-1],
+                                 drop_rate=drop_connect_rate * block_id / total_blocks,
+                                 norm_layer=norm_layer))
+                block_id += 1
+        self.blocks = nn.Sequential(*blocks)
+
+        head_input_c = model_cnf[-1][-3]
+        head = OrderedDict()
+
+        head.update({"project_conv": ConvBNAct(head_input_c,
+                                               num_features,
+                                               kernel_size=1,
+                                               norm_layer=norm_layer)})  # default activation is SiLU
+
+        head.update({"avgpool": nn.AdaptiveAvgPool2d(1)})
+        head.update({"flatten": nn.Flatten()})
+
+        if dropout_rate > 0:
+            head.update({"dropout": nn.Dropout(p=dropout_rate, inplace=True)})
+        head.update({"classifier": nn.Linear(num_features, num_classes)})
+
+        self.head = nn.Sequential(head)
+
+        # initial weights
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out")
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.zeros_(m.bias)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.stem(x)
+        x = self.blocks(x)
+        x = self.head(x)
+
+        return x
+
+
+def efficientnetv2_s(num_classes: int = 1000):
+    """
+    EfficientNetV2
+    https://arxiv.org/abs/2104.00298
+    """
+    # train_size: 300, eval_size: 384
+
+    # repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio
+    model_config = [[2, 3, 1, 1, 24, 24, 0, 0],
+                    [4, 3, 2, 4, 24, 48, 0, 0],
+                    [4, 3, 2, 4, 48, 64, 0, 0],
+                    [6, 3, 2, 4, 64, 128, 1, 0.25],
+                    [9, 3, 1, 6, 128, 160, 1, 0.25],
+                    [15, 3, 2, 6, 160, 256, 1, 0.25]]
+
+    model = EfficientNetV2(model_cnf=model_config,
+                           num_classes=num_classes,
+                           dropout_rate=0.2)
+    return model
+
+
+def efficientnetv2_m(num_classes: int = 1000):
+    """
+    EfficientNetV2
+    https://arxiv.org/abs/2104.00298
+    """
+    # train_size: 384, eval_size: 480
+
+    # repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio
+    model_config = [[3, 3, 1, 1, 24, 24, 0, 0],
+                    [5, 3, 2, 4, 24, 48, 0, 0],
+                    [5, 3, 2, 4, 48, 80, 0, 0],
+                    [7, 3, 2, 4, 80, 160, 1, 0.25],
+                    [14, 3, 1, 6, 160, 176, 1, 0.25],
+                    [18, 3, 2, 6, 176, 304, 1, 0.25],
+                    [5, 3, 1, 6, 304, 512, 1, 0.25]]
+
+    model = EfficientNetV2(model_cnf=model_config,
+                           num_classes=num_classes,
+                           dropout_rate=0.3)
+    return model
+
+
+def efficientnetv2_l(num_classes: int = 1000):
+    """
+    EfficientNetV2
+    https://arxiv.org/abs/2104.00298
+    """
+    # train_size: 384, eval_size: 480
+
+    # repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio
+    model_config = [[4, 3, 1, 1, 32, 32, 0, 0],
+                    [7, 3, 2, 4, 32, 64, 0, 0],
+                    [7, 3, 2, 4, 64, 96, 0, 0],
+                    [10, 3, 2, 4, 96, 192, 1, 0.25],
+                    [19, 3, 1, 6, 192, 224, 1, 0.25],
+                    [25, 3, 2, 6, 224, 384, 1, 0.25],
+                    [7, 3, 1, 6, 384, 640, 1, 0.25]]
+
+    model = EfficientNetV2(model_cnf=model_config,
+                           num_classes=num_classes,
+                           dropout_rate=0.4)
+    return model
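
Editor's note: a quick smoke test for the model definition above. This is not part of the commit; it is a minimal sketch that assumes only model.py as added here. Because the head ends in AdaptiveAvgPool2d, any reasonable input size works and the output is a [batch, num_classes] logit tensor.

# sketch only: build the small variant and run a dummy forward pass
import torch
from model import efficientnetv2_s

model = efficientnetv2_s(num_classes=5)
model.eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 300, 300)  # train_size for the "s" variant
    logits = model(dummy)
print(logits.shape)  # expected: torch.Size([1, 5])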
my_dataset.py
ADDED
@@ -0,0 +1,37 @@
+from PIL import Image
+import torch
+from torch.utils.data import Dataset
+
+
+class MyDataSet(Dataset):
+    """Custom dataset"""
+
+    def __init__(self, images_path: list, images_class: list, transform=None):
+        self.images_path = images_path
+        self.images_class = images_class
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.images_path)
+
+    def __getitem__(self, item):
+        img = Image.open(self.images_path[item])
+        # 'RGB' is a color image, 'L' is a grayscale image
+        if img.mode != 'RGB':
+            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
+        label = self.images_class[item]
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        return img, label
+
+    @staticmethod
+    def collate_fn(batch):
+        # the official default_collate implementation can be found at
+        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
+        images, labels = tuple(zip(*batch))
+
+        images = torch.stack(images, dim=0)
+        labels = torch.as_tensor(labels)
+        return images, labels
+
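
Editor's note: the custom collate_fn above reproduces what default_collate does for (tensor, int) pairs. A hedged sketch of how the dataset and its static collate function plug into a DataLoader; the image paths and labels here are hypothetical placeholders, not files from this commit.

# sketch only: wire MyDataSet into a DataLoader with its static collate_fn
from torch.utils.data import DataLoader
from torchvision import transforms
from my_dataset import MyDataSet

tf = transforms.Compose([transforms.Resize(384),
                         transforms.CenterCrop(384),
                         transforms.ToTensor()])
# hypothetical paths/labels for illustration only
ds = MyDataSet(images_path=["flower_photos/daisy/a.jpg", "flower_photos/roses/b.jpg"],
               images_class=[0, 2],
               transform=tf)
loader = DataLoader(ds, batch_size=2, collate_fn=MyDataSet.collate_fn)
images, labels = next(iter(loader))  # images: [2, 3, 384, 384], labels: tensor([0, 2])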
predict.py
ADDED
@@ -0,0 +1,65 @@
+import os
+import json
+
+import torch
+from PIL import Image
+from torchvision import transforms
+import matplotlib.pyplot as plt
+
+from model import efficientnetv2_m as create_model
+
+
+def main():
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    img_size = {"s": [300, 384],  # train_size, val_size
+                "m": [384, 480],
+                "l": [384, 480]}
+    num_model = "s"  # keep consistent with the image size used in train.py
+
+    data_transform = transforms.Compose(
+        [transforms.Resize(img_size[num_model][1]),
+         transforms.CenterCrop(img_size[num_model][1]),
+         transforms.ToTensor(),
+         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+
+    # load image
+    img_path = "../d.jpg"
+    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
+    img = Image.open(img_path)
+    plt.imshow(img)
+    # [N, C, H, W]
+    img = data_transform(img)
+    # expand batch dimension
+    img = torch.unsqueeze(img, dim=0)
+
+    # read class_indict
+    json_path = './class_indices.json'
+    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
+
+    with open(json_path, "r") as json_file:
+        class_indict = json.load(json_file)
+
+    # create model
+    model = create_model(num_classes=5).to(device)
+    # load model weights
+    model_weight_path = "./weights/model-20.pth"
+    model.load_state_dict(torch.load(model_weight_path, map_location=device))
+    model.eval()
+    with torch.no_grad():
+        # predict class
+        output = torch.squeeze(model(img.to(device))).cpu()
+        predict = torch.softmax(output, dim=0)
+        predict_cla = torch.argmax(predict).numpy()
+
+    print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
+                                               predict[predict_cla].numpy())
+    plt.title(print_res)
+    for i in range(len(predict)):
+        print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
+                                                predict[i].numpy()))
+    plt.show()
+
+
+if __name__ == '__main__':
+    main()
train.py
ADDED
@@ -0,0 +1,143 @@
+import os
+import math
+import argparse
+
+import torch
+import torch.optim as optim
+from torch.utils.tensorboard import SummaryWriter
+from torchvision import transforms
+import torch.optim.lr_scheduler as lr_scheduler
+
+from model import efficientnetv2_m as create_model
+from my_dataset import MyDataSet
+from utils import read_split_data, train_one_epoch, evaluate
+
+
+def main(args):
+    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+
+    print(args)
+    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
+    tb_writer = SummaryWriter()
+    if os.path.exists("./weights") is False:
+        os.makedirs("./weights")
+
+    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
+
+    img_size = {"s": [300, 384],  # train_size, val_size
+                "m": [384, 480],
+                "l": [384, 480]}
+    num_model = "s"
+
+    data_transform = {
+        "train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model][0]),
+                                     transforms.RandomHorizontalFlip(),
+                                     transforms.ToTensor(),
+                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+        "val": transforms.Compose([transforms.Resize(img_size[num_model][1]),
+                                   transforms.CenterCrop(img_size[num_model][1]),
+                                   transforms.ToTensor(),
+                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}
+
+    # instantiate the training dataset
+    train_dataset = MyDataSet(images_path=train_images_path,
+                              images_class=train_images_label,
+                              transform=data_transform["train"])
+
+    # instantiate the validation dataset
+    val_dataset = MyDataSet(images_path=val_images_path,
+                            images_class=val_images_label,
+                            transform=data_transform["val"])
+
+    batch_size = args.batch_size
+    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
+    print('Using {} dataloader workers per process'.format(nw))
+    train_loader = torch.utils.data.DataLoader(train_dataset,
+                                               batch_size=batch_size,
+                                               shuffle=True,
+                                               pin_memory=True,
+                                               num_workers=nw,
+                                               collate_fn=train_dataset.collate_fn)
+
+    val_loader = torch.utils.data.DataLoader(val_dataset,
+                                             batch_size=batch_size,
+                                             shuffle=False,
+                                             pin_memory=True,
+                                             num_workers=nw,
+                                             collate_fn=val_dataset.collate_fn)
+
+    # load pre-trained weights if they exist
+    model = create_model(num_classes=args.num_classes).to(device)
+    if args.weights != "":
+        if os.path.exists(args.weights):
+            weights_dict = torch.load(args.weights, map_location=device)
+            load_weights_dict = {k: v for k, v in weights_dict.items()
+                                 if model.state_dict()[k].numel() == v.numel()}
+            print(model.load_state_dict(load_weights_dict, strict=False))
+        else:
+            raise FileNotFoundError("not found weights file: {}".format(args.weights))
+
+    # whether to freeze weights
+    if args.freeze_layers:
+        for name, para in model.named_parameters():
+            # freeze everything except the head
+            if "head" not in name:
+                para.requires_grad_(False)
+            else:
+                print("training {}".format(name))
+
+    pg = [p for p in model.parameters() if p.requires_grad]
+    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)
+    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
+    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
+    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
+
+    for epoch in range(args.epochs):
+        # train
+        train_loss, train_acc = train_one_epoch(model=model,
+                                                optimizer=optimizer,
+                                                data_loader=train_loader,
+                                                device=device,
+                                                epoch=epoch)
+
+        scheduler.step()
+
+        # validate
+        val_loss, val_acc = evaluate(model=model,
+                                     data_loader=val_loader,
+                                     device=device,
+                                     epoch=epoch)
+
+        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
+        tb_writer.add_scalar(tags[0], train_loss, epoch)
+        tb_writer.add_scalar(tags[1], train_acc, epoch)
+        tb_writer.add_scalar(tags[2], val_loss, epoch)
+        tb_writer.add_scalar(tags[3], val_acc, epoch)
+        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
+
+        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--num_classes', type=int, default=5)
+    parser.add_argument('--epochs', type=int, default=30)
+    parser.add_argument('--batch-size', type=int, default=8)
+    parser.add_argument('--lr', type=float, default=0.01)
+    parser.add_argument('--lrf', type=float, default=0.01)
+
+    # root directory of the dataset
+    # http://download.tensorflow.org/example_images/flower_photos.tgz
+    parser.add_argument('--data-path', type=str,
+                        default="../../data_set/flower_data/flower_photos")
+
+    # download model weights
+    # link: https://pan.baidu.com/s/1uZX36rvrfEss-JGj4yfzbQ  password: 5gu1
+    parser.add_argument('--weights', type=str, default='./pre_efficientnetv2-m.pth',
+                        help='initial weights path')
+    parser.add_argument('--freeze-layers', type=bool, default=True)
+    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
+
+    opt = parser.parse_args()
+
+    main(opt)
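
Editor's note: the lambda `lf` in train.py implements cosine annealing from lr down to lr * lrf over args.epochs. A small standalone sketch of the schedule, using the script's default values (lr=0.01, lrf=0.01, epochs=30); not part of the commit.

# standalone sketch of the cosine schedule used in train.py
import math

lr, lrf, epochs = 0.01, 0.01, 30
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf
for epoch in [0, 15, 29]:
    print(epoch, lr * lf(epoch))
# epoch 0  -> 0.01       (full lr)
# epoch 15 -> ~0.00505   (halfway: (1 + cos(pi/2)) / 2 = 0.5 of the decayed range)
# epoch 29 -> ~0.000127  (approaching the floor lr * lrf = 0.0001)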
trans_effv2_weights.py
ADDED
@@ -0,0 +1,160 @@
+import tensorflow as tf
+import torch
+import numpy as np
+
+
+def main(model_name: str = "efficientnetv2-s",
+         tf_weights_path: str = "./efficientnetv2-s/model",
+         stage0_num: int = 2,
+         fused_conv_num: int = 10):
+
+    except_var = ["global_step"]
+
+    new_weights = {}
+    var_list = [i for i in tf.train.list_variables(tf_weights_path) if "Exponential" not in i[0]]
+    reader = tf.train.load_checkpoint(tf_weights_path)
+    for v in var_list:
+        if v[0] in except_var:
+            continue
+        new_name = v[0].replace(model_name + "/", "").replace("/", ".")
+
+        if "stem" in v[0]:
+            new_name = new_name.replace("conv2d.kernel",
+                                        "conv.weight")
+
+            new_name = new_name.replace("tpu_batch_normalization.beta",
+                                        "bn.bias")
+            new_name = new_name.replace("tpu_batch_normalization.gamma",
+                                        "bn.weight")
+            new_name = new_name.replace("tpu_batch_normalization.moving_mean",
+                                        "bn.running_mean")
+            new_name = new_name.replace("tpu_batch_normalization.moving_variance",
+                                        "bn.running_var")
+        elif "head" in v[0]:
+            new_name = new_name.replace("conv2d.kernel",
+                                        "project_conv.conv.weight")
+            new_name = new_name.replace("dense.kernel",
+                                        "classifier.weight")
+            new_name = new_name.replace("dense.bias",
+                                        "classifier.bias")
+
+            new_name = new_name.replace("tpu_batch_normalization.beta",
+                                        "project_conv.bn.bias")
+            new_name = new_name.replace("tpu_batch_normalization.gamma",
+                                        "project_conv.bn.weight")
+            new_name = new_name.replace("tpu_batch_normalization.moving_mean",
+                                        "project_conv.bn.running_mean")
+            new_name = new_name.replace("tpu_batch_normalization.moving_variance",
+                                        "project_conv.bn.running_var")
+        elif "blocks" in v[0]:
+            # e.g. blocks_0.conv2d.kernel -> 0
+            blocks_id = new_name.split(".", maxsplit=1)[0].replace("blocks_", "")
+            new_name = new_name.replace("blocks_{}".format(blocks_id),
+                                        "blocks.{}".format(blocks_id))
+
+            if int(blocks_id) <= stage0_num - 1:  # expansion=1 fused_mbconv
+                new_name = new_name.replace("conv2d.kernel",
+                                            "project_conv.conv.weight")
+                new_name = new_name.replace("tpu_batch_normalization.beta",
+                                            "project_conv.bn.bias")
+                new_name = new_name.replace("tpu_batch_normalization.gamma",
+                                            "project_conv.bn.weight")
+                new_name = new_name.replace("tpu_batch_normalization.moving_mean",
+                                            "project_conv.bn.running_mean")
+                new_name = new_name.replace("tpu_batch_normalization.moving_variance",
+                                            "project_conv.bn.running_var")
+            else:
+                new_name = new_name.replace("blocks.{}.conv2d.kernel".format(blocks_id),
+                                            "blocks.{}.expand_conv.conv.weight".format(blocks_id))
+                new_name = new_name.replace("tpu_batch_normalization.beta",
+                                            "expand_conv.bn.bias")
+                new_name = new_name.replace("tpu_batch_normalization.gamma",
+                                            "expand_conv.bn.weight")
+                new_name = new_name.replace("tpu_batch_normalization.moving_mean",
+                                            "expand_conv.bn.running_mean")
+                new_name = new_name.replace("tpu_batch_normalization.moving_variance",
+                                            "expand_conv.bn.running_var")
+
+                if int(blocks_id) <= fused_conv_num - 1:  # fused_mbconv
+                    new_name = new_name.replace("blocks.{}.conv2d_1.kernel".format(blocks_id),
+                                                "blocks.{}.project_conv.conv.weight".format(blocks_id))
+                    new_name = new_name.replace("tpu_batch_normalization_1.beta",
+                                                "project_conv.bn.bias")
+                    new_name = new_name.replace("tpu_batch_normalization_1.gamma",
+                                                "project_conv.bn.weight")
+                    new_name = new_name.replace("tpu_batch_normalization_1.moving_mean",
+                                                "project_conv.bn.running_mean")
+                    new_name = new_name.replace("tpu_batch_normalization_1.moving_variance",
+                                                "project_conv.bn.running_var")
+                else:  # mbconv
+                    new_name = new_name.replace("blocks.{}.conv2d_1.kernel".format(blocks_id),
+                                                "blocks.{}.project_conv.conv.weight".format(blocks_id))
+
+                    new_name = new_name.replace("depthwise_conv2d.depthwise_kernel",
+                                                "dwconv.conv.weight")
+
+                    new_name = new_name.replace("tpu_batch_normalization_1.beta",
+                                                "dwconv.bn.bias")
+                    new_name = new_name.replace("tpu_batch_normalization_1.gamma",
+                                                "dwconv.bn.weight")
+                    new_name = new_name.replace("tpu_batch_normalization_1.moving_mean",
+                                                "dwconv.bn.running_mean")
+                    new_name = new_name.replace("tpu_batch_normalization_1.moving_variance",
+                                                "dwconv.bn.running_var")
+
+                    new_name = new_name.replace("tpu_batch_normalization_2.beta",
+                                                "project_conv.bn.bias")
+                    new_name = new_name.replace("tpu_batch_normalization_2.gamma",
+                                                "project_conv.bn.weight")
+                    new_name = new_name.replace("tpu_batch_normalization_2.moving_mean",
+                                                "project_conv.bn.running_mean")
+                    new_name = new_name.replace("tpu_batch_normalization_2.moving_variance",
+                                                "project_conv.bn.running_var")
+
+                    new_name = new_name.replace("se.conv2d.bias",
+                                                "se.conv_reduce.bias")
+                    new_name = new_name.replace("se.conv2d.kernel",
+                                                "se.conv_reduce.weight")
+                    new_name = new_name.replace("se.conv2d_1.bias",
+                                                "se.conv_expand.bias")
+                    new_name = new_name.replace("se.conv2d_1.kernel",
+                                                "se.conv_expand.weight")
+        else:
+            print("not recognized name: " + v[0])
+
+        var = reader.get_tensor(v[0])
+        new_var = var
+        if "conv" in new_name and "weight" in new_name and "bn" not in new_name and "dw" not in new_name:
+            assert len(var.shape) == 4
+            # conv kernel [h, w, c, n] -> [n, c, h, w]
+            new_var = np.transpose(var, (3, 2, 0, 1))
+        elif "bn" in new_name:
+            pass
+        elif "dwconv" in new_name and "weight" in new_name:
+            # dw_kernel [h, w, n, c] -> [n, c, h, w]
+            assert len(var.shape) == 4
+            new_var = np.transpose(var, (2, 3, 0, 1))
+        elif "classifier" in new_name and "weight" in new_name:
+            assert len(var.shape) == 2
+            new_var = np.transpose(var, (1, 0))
+
+        new_weights[new_name] = torch.as_tensor(new_var)
+
+    torch.save(new_weights, "pre_" + model_name + ".pth")
+
+
+if __name__ == '__main__':
+    main(model_name="efficientnetv2-s",
+         tf_weights_path="./efficientnetv2-s/model",
+         stage0_num=2,
+         fused_conv_num=10)
+
+    # main(model_name="efficientnetv2-m",
+    #      tf_weights_path="./efficientnetv2-m/model",
+    #      stage0_num=3,
+    #      fused_conv_num=13)
+
+    # main(model_name="efficientnetv2-l",
+    #      tf_weights_path="./efficientnetv2-l/model",
+    #      stage0_num=4,
+    #      fused_conv_num=18)
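
Editor's note: the layout conversions at the end of the script are its crux. TF stores regular conv kernels as [h, w, in, out] and depthwise kernels as [h, w, channels, multiplier], while PyTorch expects [out, in, h, w]. A toy shape-only check of the two transposes (no checkpoint needed; not part of the commit):

# sketch: verify the axis permutations used when converting TF kernels to PyTorch
import numpy as np

tf_kernel = np.zeros((3, 3, 24, 96))   # TF conv kernel [h, w, c_in, c_out]
pt_kernel = np.transpose(tf_kernel, (3, 2, 0, 1))
print(pt_kernel.shape)                 # (96, 24, 3, 3) = [n, c, h, w]

tf_dw = np.zeros((3, 3, 96, 1))        # TF depthwise kernel [h, w, channels, multiplier]
pt_dw = np.transpose(tf_dw, (2, 3, 0, 1))
print(pt_dw.shape)                     # (96, 1, 3, 3), matching groups=channels in PyTorch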
utils.py
ADDED
@@ -0,0 +1,175 @@
+import os
+import sys
+import json
+import pickle
+import random
+
+import torch
+from tqdm import tqdm
+
+import matplotlib.pyplot as plt
+
+
+def read_split_data(root: str, val_rate: float = 0.2):
+    random.seed(0)  # fix the seed so the random split is reproducible
+    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
+
+    # traverse the folders; one folder corresponds to one class
+    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
+    # sort to guarantee a consistent order
+    flower_class.sort()
+    # generate class names and their corresponding numeric indices
+    class_indices = dict((k, v) for v, k in enumerate(flower_class))
+    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
+    with open('class_indices.json', 'w') as json_file:
+        json_file.write(json_str)
+
+    train_images_path = []   # paths of all training images
+    train_images_label = []  # class indices of the training images
+    val_images_path = []     # paths of all validation images
+    val_images_label = []    # class indices of the validation images
+    every_class_num = []     # total number of samples per class
+    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
+    # traverse the files under each folder
+    for cla in flower_class:
+        cla_path = os.path.join(root, cla)
+        # collect the paths of all files with a supported extension
+        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
+                  if os.path.splitext(i)[-1] in supported]
+        # index corresponding to this class
+        image_class = class_indices[cla]
+        # record the number of samples of this class
+        every_class_num.append(len(images))
+        # randomly sample validation images at the given rate
+        val_path = random.sample(images, k=int(len(images) * val_rate))
+
+        for img_path in images:
+            if img_path in val_path:  # if the path was sampled, put it in the validation set
+                val_images_path.append(img_path)
+                val_images_label.append(image_class)
+            else:  # otherwise put it in the training set
+                train_images_path.append(img_path)
+                train_images_label.append(image_class)
+
+    print("{} images were found in the dataset.".format(sum(every_class_num)))
+    print("{} images for training.".format(len(train_images_path)))
+    print("{} images for validation.".format(len(val_images_path)))
+
+    plot_image = False
+    if plot_image:
+        # draw a bar chart of the number of samples per class
+        plt.bar(range(len(flower_class)), every_class_num, align='center')
+        # replace the x ticks 0,1,2,3,4 with the class names
+        plt.xticks(range(len(flower_class)), flower_class)
+        # add value labels on top of the bars
+        for i, v in enumerate(every_class_num):
+            plt.text(x=i, y=v + 5, s=str(v), ha='center')
+        # set the x label
+        plt.xlabel('image class')
+        # set the y label
+        plt.ylabel('number of images')
+        # set the title of the bar chart
+        plt.title('flower class distribution')
+        plt.show()
+
+    return train_images_path, train_images_label, val_images_path, val_images_label
+
+
+def plot_data_loader_image(data_loader):
+    batch_size = data_loader.batch_size
+    plot_num = min(batch_size, 4)
+
+    json_path = './class_indices.json'
+    assert os.path.exists(json_path), json_path + " does not exist."
+    with open(json_path, 'r') as json_file:
+        class_indices = json.load(json_file)
+
+    for data in data_loader:
+        images, labels = data
+        for i in range(plot_num):
+            # [C, H, W] -> [H, W, C]
+            img = images[i].numpy().transpose(1, 2, 0)
+            # undo the Normalize operation
+            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
+            label = labels[i].item()
+            plt.subplot(1, plot_num, i+1)
+            plt.xlabel(class_indices[str(label)])
+            plt.xticks([])  # remove the x-axis ticks
+            plt.yticks([])  # remove the y-axis ticks
+            plt.imshow(img.astype('uint8'))
+        plt.show()
+
+
+def write_pickle(list_info: list, file_name: str):
+    with open(file_name, 'wb') as f:
+        pickle.dump(list_info, f)
+
+
+def read_pickle(file_name: str) -> list:
+    with open(file_name, 'rb') as f:
+        info_list = pickle.load(f)
+    return info_list
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch):
+    model.train()
+    loss_function = torch.nn.CrossEntropyLoss()
+    accu_loss = torch.zeros(1).to(device)  # accumulated loss
+    accu_num = torch.zeros(1).to(device)   # accumulated number of correctly predicted samples
+    optimizer.zero_grad()
+
+    sample_num = 0
+    data_loader = tqdm(data_loader)
+    for step, data in enumerate(data_loader):
+        images, labels = data
+        sample_num += images.shape[0]
+
+        pred = model(images.to(device))
+        pred_classes = torch.max(pred, dim=1)[1]
+        accu_num += torch.eq(pred_classes, labels.to(device)).sum()
+
+        loss = loss_function(pred, labels.to(device))
+        loss.backward()
+        accu_loss += loss.detach()
+
+        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
+                                                                               accu_loss.item() / (step + 1),
+                                                                               accu_num.item() / sample_num)
+
+        if not torch.isfinite(loss):
+            print('WARNING: non-finite loss, ending training ', loss)
+            sys.exit(1)
+
+        optimizer.step()
+        optimizer.zero_grad()
+
+    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
+
+
+@torch.no_grad()
+def evaluate(model, data_loader, device, epoch):
+    loss_function = torch.nn.CrossEntropyLoss()
+
+    model.eval()
+
+    accu_num = torch.zeros(1).to(device)   # accumulated number of correctly predicted samples
+    accu_loss = torch.zeros(1).to(device)  # accumulated loss
+
+    sample_num = 0
+    data_loader = tqdm(data_loader)
+    for step, data in enumerate(data_loader):
+        images, labels = data
+        sample_num += images.shape[0]
+
+        pred = model(images.to(device))
+        pred_classes = torch.max(pred, dim=1)[1]
+        accu_num += torch.eq(pred_classes, labels.to(device)).sum()
+
+        loss = loss_function(pred, labels.to(device))
+        accu_loss += loss
+
+        data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
+                                                                               accu_loss.item() / (step + 1),
+                                                                               accu_num.item() / sample_num)
+
+    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
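
Editor's note: because read_split_data seeds random with 0 before sampling, repeated calls on the same directory tree produce the same train/val split. A hedged usage sketch (the path is the flower_photos layout assumed by train.py; not part of the commit):

# sketch: reproducible split check for read_split_data
from utils import read_split_data

root = "../../data_set/flower_data/flower_photos"  # one sub-folder per class
train_paths, train_labels, val_paths, val_labels = read_split_data(root, val_rate=0.2)
assert set(train_paths).isdisjoint(val_paths)      # no image appears in both splits
print(len(train_paths), len(val_paths))            # roughly an 80/20 split per class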
weights/model-20.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e3027cc78d0448540d99ed074391d48031c1ab3c6d23e3868bb83a7c9c90c9e
+size 213027833