uestc_yhr committed
Commit 71b93be
1 Parent(s): 6b6db36
Files changed (8)
  1. class_indices.json +7 -0
  2. model.py +377 -0
  3. my_dataset.py +37 -0
  4. predict.py +65 -0
  5. train.py +143 -0
  6. trans_effv2_weights.py +160 -0
  7. utils.py +175 -0
  8. weights/model-20.pth +3 -0
class_indices.json ADDED
@@ -0,0 +1,7 @@
+ {
+     "0": "daisy",
+     "1": "dandelion",
+     "2": "roses",
+     "3": "sunflowers",
+     "4": "tulips"
+ }
model.py ADDED
@@ -0,0 +1,377 @@
+ from collections import OrderedDict
+ from functools import partial
+ from typing import Callable, Optional
+
+ import torch.nn as nn
+ import torch
+ from torch import Tensor
+
+
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
+     """
+     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf
+
+     This function is taken from rwightman's pytorch-image-models.
+     It can be seen here:
+     https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py#L140
+     """
+     if drop_prob == 0. or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+     random_tensor.floor_()  # binarize
+     output = x.div(keep_prob) * random_tensor
+     return output
+
+
+ class DropPath(nn.Module):
+     """
+     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf
+     """
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
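As a quick illustration (an editorial sketch, not part of the committed file): in training mode drop_path zeroes entire samples with probability drop_prob and rescales the survivors by 1/keep_prob, so the expected activation is unchanged.

    import torch
    from model import drop_path

    x = torch.ones(4, 3, 2, 2)
    y = drop_path(x, drop_prob=0.5, training=True)
    # each sample in y is now either all zeros or all 2.0 (= 1 / keep_prob)
    print(y[:, 0, 0, 0])  # e.g. tensor([2., 0., 2., 2.])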
+
+
+ class ConvBNAct(nn.Module):
+     def __init__(self,
+                  in_planes: int,
+                  out_planes: int,
+                  kernel_size: int = 3,
+                  stride: int = 1,
+                  groups: int = 1,
+                  norm_layer: Optional[Callable[..., nn.Module]] = None,
+                  activation_layer: Optional[Callable[..., nn.Module]] = None):
+         super(ConvBNAct, self).__init__()
+
+         padding = (kernel_size - 1) // 2
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+         if activation_layer is None:
+             activation_layer = nn.SiLU  # alias Swish (torch>=1.7)
+
+         self.conv = nn.Conv2d(in_channels=in_planes,
+                               out_channels=out_planes,
+                               kernel_size=kernel_size,
+                               stride=stride,
+                               padding=padding,
+                               groups=groups,
+                               bias=False)
+
+         self.bn = norm_layer(out_planes)
+         self.act = activation_layer()
+
+     def forward(self, x):
+         result = self.conv(x)
+         result = self.bn(result)
+         result = self.act(result)
+
+         return result
+
+
+ class SqueezeExcite(nn.Module):
+     def __init__(self,
+                  input_c: int,   # block input channel
+                  expand_c: int,  # block expand channel
+                  se_ratio: float = 0.25):
+         super(SqueezeExcite, self).__init__()
+         squeeze_c = int(input_c * se_ratio)
+         self.conv_reduce = nn.Conv2d(expand_c, squeeze_c, 1)
+         self.act1 = nn.SiLU()  # alias Swish
+         self.conv_expand = nn.Conv2d(squeeze_c, expand_c, 1)
+         self.act2 = nn.Sigmoid()
+
+     def forward(self, x: Tensor) -> Tensor:
+         scale = x.mean((2, 3), keepdim=True)
+         scale = self.conv_reduce(scale)
+         scale = self.act1(scale)
+         scale = self.conv_expand(scale)
+         scale = self.act2(scale)
+         return scale * x
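A shape sanity check for SqueezeExcite (illustrative only): the gate is computed from a global average over H and W, squeezed to int(input_c * se_ratio) channels, expanded back, and broadcast over the feature map.

    import torch
    from model import SqueezeExcite

    se = SqueezeExcite(input_c=24, expand_c=96)  # squeeze_c = int(24 * 0.25) = 6
    x = torch.randn(2, 96, 32, 32)               # [N, expand_c, H, W]
    print(se(x).shape)                           # torch.Size([2, 96, 32, 32])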
+
+
+ class MBConv(nn.Module):
+     def __init__(self,
+                  kernel_size: int,
+                  input_c: int,
+                  out_c: int,
+                  expand_ratio: int,
+                  stride: int,
+                  se_ratio: float,
+                  drop_rate: float,
+                  norm_layer: Callable[..., nn.Module]):
+         super(MBConv, self).__init__()
+
+         if stride not in [1, 2]:
+             raise ValueError("illegal stride value.")
+
+         self.has_shortcut = (stride == 1 and input_c == out_c)
+
+         activation_layer = nn.SiLU  # alias Swish
+         expanded_c = input_c * expand_ratio
+
+         # In EfficientNetV2, MBConv never uses expansion=1, so the point-wise expand conv always exists
+         assert expand_ratio != 1
+         # Point-wise expansion
+         self.expand_conv = ConvBNAct(input_c,
+                                      expanded_c,
+                                      kernel_size=1,
+                                      norm_layer=norm_layer,
+                                      activation_layer=activation_layer)
+
+         # Depth-wise convolution
+         self.dwconv = ConvBNAct(expanded_c,
+                                 expanded_c,
+                                 kernel_size=kernel_size,
+                                 stride=stride,
+                                 groups=expanded_c,
+                                 norm_layer=norm_layer,
+                                 activation_layer=activation_layer)
+
+         self.se = SqueezeExcite(input_c, expanded_c, se_ratio) if se_ratio > 0 else nn.Identity()
+
+         # Point-wise linear projection
+         self.project_conv = ConvBNAct(expanded_c,
+                                       out_planes=out_c,
+                                       kernel_size=1,
+                                       norm_layer=norm_layer,
+                                       activation_layer=nn.Identity)  # note: no activation here, so Identity is passed in
+
+         self.out_channels = out_c
+
+         # the DropPath layer is only used when there is a shortcut connection
+         self.drop_rate = drop_rate
+         if self.has_shortcut and drop_rate > 0:
+             self.dropout = DropPath(drop_rate)
+
+     def forward(self, x: Tensor) -> Tensor:
+         result = self.expand_conv(x)
+         result = self.dwconv(result)
+         result = self.se(result)
+         result = self.project_conv(result)
+
+         if self.has_shortcut:
+             if self.drop_rate > 0:
+                 result = self.dropout(result)
+             result += x
+
+         return result
+
+
+ class FusedMBConv(nn.Module):
+     def __init__(self,
+                  kernel_size: int,
+                  input_c: int,
+                  out_c: int,
+                  expand_ratio: int,
+                  stride: int,
+                  se_ratio: float,
+                  drop_rate: float,
+                  norm_layer: Callable[..., nn.Module]):
+         super(FusedMBConv, self).__init__()
+
+         assert stride in [1, 2]
+         assert se_ratio == 0
+
+         self.has_shortcut = stride == 1 and input_c == out_c
+         self.has_expansion = expand_ratio != 1
+
+         activation_layer = nn.SiLU  # alias Swish
+         expanded_c = input_c * expand_ratio
+
+         # the expand conv only exists when expand_ratio != 1
+         if self.has_expansion:
+             # Expansion convolution
+             self.expand_conv = ConvBNAct(input_c,
+                                          expanded_c,
+                                          kernel_size=kernel_size,
+                                          stride=stride,
+                                          norm_layer=norm_layer,
+                                          activation_layer=activation_layer)
+
+             self.project_conv = ConvBNAct(expanded_c,
+                                           out_c,
+                                           kernel_size=1,
+                                           norm_layer=norm_layer,
+                                           activation_layer=nn.Identity)  # note: no activation function
+         else:
+             # case with only a project_conv
+             self.project_conv = ConvBNAct(input_c,
+                                           out_c,
+                                           kernel_size=kernel_size,
+                                           stride=stride,
+                                           norm_layer=norm_layer,
+                                           activation_layer=activation_layer)  # note: with activation function
+
+         self.out_channels = out_c
+
+         # the DropPath layer is only used when there is a shortcut connection
+         self.drop_rate = drop_rate
+         if self.has_shortcut and drop_rate > 0:
+             self.dropout = DropPath(drop_rate)
+
+     def forward(self, x: Tensor) -> Tensor:
+         if self.has_expansion:
+             result = self.expand_conv(x)
+             result = self.project_conv(result)
+         else:
+             result = self.project_conv(x)
+
+         if self.has_shortcut:
+             if self.drop_rate > 0:
+                 result = self.dropout(result)
+
+             result += x
+
+         return result
+
+
+ class EfficientNetV2(nn.Module):
+     def __init__(self,
+                  model_cnf: list,
+                  num_classes: int = 1000,
+                  num_features: int = 1280,
+                  dropout_rate: float = 0.2,
+                  drop_connect_rate: float = 0.2):
+         super(EfficientNetV2, self).__init__()
+
+         for cnf in model_cnf:
+             assert len(cnf) == 8
+
+         norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.1)
+
+         stem_filter_num = model_cnf[0][4]
+
+         self.stem = ConvBNAct(3,
+                               stem_filter_num,
+                               kernel_size=3,
+                               stride=2,
+                               norm_layer=norm_layer)  # default activation is SiLU
+
+         total_blocks = sum([i[0] for i in model_cnf])
+         block_id = 0
+         blocks = []
+         for cnf in model_cnf:
+             repeats = cnf[0]
+             op = FusedMBConv if cnf[-2] == 0 else MBConv
+             for i in range(repeats):
+                 blocks.append(op(kernel_size=cnf[1],
+                                  input_c=cnf[4] if i == 0 else cnf[5],
+                                  out_c=cnf[5],
+                                  expand_ratio=cnf[3],
+                                  stride=cnf[2] if i == 0 else 1,
+                                  se_ratio=cnf[-1],
+                                  drop_rate=drop_connect_rate * block_id / total_blocks,
+                                  norm_layer=norm_layer))
+                 block_id += 1
+         self.blocks = nn.Sequential(*blocks)
+
+         head_input_c = model_cnf[-1][-3]
+         head = OrderedDict()
+
+         head.update({"project_conv": ConvBNAct(head_input_c,
+                                                num_features,
+                                                kernel_size=1,
+                                                norm_layer=norm_layer)})  # default activation is SiLU
+
+         head.update({"avgpool": nn.AdaptiveAvgPool2d(1)})
+         head.update({"flatten": nn.Flatten()})
+
+         if dropout_rate > 0:
+             head.update({"dropout": nn.Dropout(p=dropout_rate, inplace=True)})
+         head.update({"classifier": nn.Linear(num_features, num_classes)})
+
+         self.head = nn.Sequential(head)
+
+         # initialize weights
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode="fan_out")
+                 if m.bias is not None:
+                     nn.init.zeros_(m.bias)
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.ones_(m.weight)
+                 nn.init.zeros_(m.bias)
+             elif isinstance(m, nn.Linear):
+                 nn.init.normal_(m.weight, 0, 0.01)
+                 nn.init.zeros_(m.bias)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.stem(x)
+         x = self.blocks(x)
+         x = self.head(x)
+
+         return x
+
+
+ def efficientnetv2_s(num_classes: int = 1000):
+     """
+     EfficientNetV2-S
+     https://arxiv.org/abs/2104.00298
+     """
+     # train_size: 300, eval_size: 384
+
+     # repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio
+     model_config = [[2, 3, 1, 1, 24, 24, 0, 0],
+                     [4, 3, 2, 4, 24, 48, 0, 0],
+                     [4, 3, 2, 4, 48, 64, 0, 0],
+                     [6, 3, 2, 4, 64, 128, 1, 0.25],
+                     [9, 3, 1, 6, 128, 160, 1, 0.25],
+                     [15, 3, 2, 6, 160, 256, 1, 0.25]]
+
+     model = EfficientNetV2(model_cnf=model_config,
+                            num_classes=num_classes,
+                            dropout_rate=0.2)
+     return model
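A minimal usage sketch (editorial, assuming it runs next to model.py). Each config row reads [repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio]; e.g. [4, 3, 2, 4, 24, 48, 0, 0] builds four FusedMBConv blocks (operator=0) with 3x3 kernels, the first at stride 2, expansion factor 4, mapping 24 to 48 channels.

    import torch
    from model import efficientnetv2_s

    model = efficientnetv2_s(num_classes=5)
    model.eval()
    x = torch.randn(1, 3, 300, 300)  # train_size of the "s" variant
    with torch.no_grad():
        print(model(x).shape)  # torch.Size([1, 5])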
+
+
+ def efficientnetv2_m(num_classes: int = 1000):
+     """
+     EfficientNetV2-M
+     https://arxiv.org/abs/2104.00298
+     """
+     # train_size: 384, eval_size: 480
+
+     # repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio
+     model_config = [[3, 3, 1, 1, 24, 24, 0, 0],
+                     [5, 3, 2, 4, 24, 48, 0, 0],
+                     [5, 3, 2, 4, 48, 80, 0, 0],
+                     [7, 3, 2, 4, 80, 160, 1, 0.25],
+                     [14, 3, 1, 6, 160, 176, 1, 0.25],
+                     [18, 3, 2, 6, 176, 304, 1, 0.25],
+                     [5, 3, 1, 6, 304, 512, 1, 0.25]]
+
+     model = EfficientNetV2(model_cnf=model_config,
+                            num_classes=num_classes,
+                            dropout_rate=0.3)
+     return model
+
+
+ def efficientnetv2_l(num_classes: int = 1000):
+     """
+     EfficientNetV2-L
+     https://arxiv.org/abs/2104.00298
+     """
+     # train_size: 384, eval_size: 480
+
+     # repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio
+     model_config = [[4, 3, 1, 1, 32, 32, 0, 0],
+                     [7, 3, 2, 4, 32, 64, 0, 0],
+                     [7, 3, 2, 4, 64, 96, 0, 0],
+                     [10, 3, 2, 4, 96, 192, 1, 0.25],
+                     [19, 3, 1, 6, 192, 224, 1, 0.25],
+                     [25, 3, 2, 6, 224, 384, 1, 0.25],
+                     [7, 3, 1, 6, 384, 640, 1, 0.25]]
+
+     model = EfficientNetV2(model_cnf=model_config,
+                            num_classes=num_classes,
+                            dropout_rate=0.4)
+     return model
my_dataset.py ADDED
@@ -0,0 +1,37 @@
+ from PIL import Image
+ import torch
+ from torch.utils.data import Dataset
+
+
+ class MyDataSet(Dataset):
+     """Custom dataset"""
+
+     def __init__(self, images_path: list, images_class: list, transform=None):
+         self.images_path = images_path
+         self.images_class = images_class
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.images_path)
+
+     def __getitem__(self, item):
+         img = Image.open(self.images_path[item])
+         # 'RGB' means a color image, 'L' a grayscale one
+         if img.mode != 'RGB':
+             raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
+         label = self.images_class[item]
+
+         if self.transform is not None:
+             img = self.transform(img)
+
+         return img, label
+
+     @staticmethod
+     def collate_fn(batch):
+         # for the official default_collate implementation, see
+         # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
+         images, labels = tuple(zip(*batch))
+
+         images = torch.stack(images, dim=0)
+         labels = torch.as_tensor(labels)
+         return images, labels
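A small usage sketch (editorial; the file names are hypothetical placeholders): MyDataSet yields transformed tensors with integer labels, and collate_fn stacks them into a batch.

    from torch.utils.data import DataLoader
    from torchvision import transforms
    from my_dataset import MyDataSet

    dataset = MyDataSet(images_path=["flower_photos/daisy/a.jpg", "flower_photos/roses/b.jpg"],  # hypothetical files
                        images_class=[0, 2],
                        transform=transforms.Compose([transforms.Resize(384),
                                                      transforms.CenterCrop(384),
                                                      transforms.ToTensor()]))
    loader = DataLoader(dataset, batch_size=2, collate_fn=MyDataSet.collate_fn)
    images, labels = next(iter(loader))  # images: [2, 3, 384, 384], labels: tensor([0, 2])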
predict.py ADDED
@@ -0,0 +1,65 @@
+ import os
+ import json
+
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+ import matplotlib.pyplot as plt
+
+ from model import efficientnetv2_m as create_model
+
+
+ def main():
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+     img_size = {"s": [300, 384],  # train_size, val_size
+                 "m": [384, 480],
+                 "l": [384, 480]}
+     num_model = "s"  # note: "s" sizes are used here (matching train.py) even though efficientnetv2_m is imported above
+
+     data_transform = transforms.Compose(
+         [transforms.Resize(img_size[num_model][1]),
+          transforms.CenterCrop(img_size[num_model][1]),
+          transforms.ToTensor(),
+          transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+
+     # load image
+     img_path = "../d.jpg"
+     assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
+     img = Image.open(img_path)
+     plt.imshow(img)
+     # [N, C, H, W]
+     img = data_transform(img)
+     # expand batch dimension
+     img = torch.unsqueeze(img, dim=0)
+
+     # read class_indict
+     json_path = './class_indices.json'
+     assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
+
+     with open(json_path, "r") as json_file:
+         class_indict = json.load(json_file)
+
+     # create model
+     model = create_model(num_classes=5).to(device)
+     # load model weights
+     model_weight_path = "./weights/model-20.pth"
+     model.load_state_dict(torch.load(model_weight_path, map_location=device))
+     model.eval()
+     with torch.no_grad():
+         # predict class
+         output = torch.squeeze(model(img.to(device))).cpu()
+         predict = torch.softmax(output, dim=0)
+         predict_cla = torch.argmax(predict).numpy()
+
+     print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
+                                                predict[predict_cla].numpy())
+     plt.title(print_res)
+     for i in range(len(predict)):
+         print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
+                                                 predict[i].numpy()))
+     plt.show()
+
+
+ if __name__ == '__main__':
+     main()
train.py ADDED
@@ -0,0 +1,143 @@
+ import os
+ import math
+ import argparse
+
+ import torch
+ import torch.optim as optim
+ from torch.utils.tensorboard import SummaryWriter
+ from torchvision import transforms
+ import torch.optim.lr_scheduler as lr_scheduler
+
+ from model import efficientnetv2_m as create_model
+ from my_dataset import MyDataSet
+ from utils import read_split_data, train_one_epoch, evaluate
+
+
+ def main(args):
+     device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+
+     print(args)
+     print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
+     tb_writer = SummaryWriter()
+     if not os.path.exists("./weights"):
+         os.makedirs("./weights")
+
+     train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
+
+     img_size = {"s": [300, 384],  # train_size, val_size
+                 "m": [384, 480],
+                 "l": [384, 480]}
+     num_model = "s"  # note: "s" sizes are used (matching predict.py) even though efficientnetv2_m is imported above
+
+     data_transform = {
+         "train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model][0]),
+                                      transforms.RandomHorizontalFlip(),
+                                      transforms.ToTensor(),
+                                      transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+         "val": transforms.Compose([transforms.Resize(img_size[num_model][1]),
+                                    transforms.CenterCrop(img_size[num_model][1]),
+                                    transforms.ToTensor(),
+                                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}
+
+     # instantiate the training dataset
+     train_dataset = MyDataSet(images_path=train_images_path,
+                               images_class=train_images_label,
+                               transform=data_transform["train"])
+
+     # instantiate the validation dataset
+     val_dataset = MyDataSet(images_path=val_images_path,
+                             images_class=val_images_label,
+                             transform=data_transform["val"])
+
+     batch_size = args.batch_size
+     nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
+     print('Using {} dataloader workers per process'.format(nw))
+     train_loader = torch.utils.data.DataLoader(train_dataset,
+                                                batch_size=batch_size,
+                                                shuffle=True,
+                                                pin_memory=True,
+                                                num_workers=nw,
+                                                collate_fn=train_dataset.collate_fn)
+
+     val_loader = torch.utils.data.DataLoader(val_dataset,
+                                              batch_size=batch_size,
+                                              shuffle=False,
+                                              pin_memory=True,
+                                              num_workers=nw,
+                                              collate_fn=val_dataset.collate_fn)
+
+     # load pre-trained weights if they exist
+     model = create_model(num_classes=args.num_classes).to(device)
+     if args.weights != "":
+         if os.path.exists(args.weights):
+             weights_dict = torch.load(args.weights, map_location=device)
+             load_weights_dict = {k: v for k, v in weights_dict.items()
+                                  if model.state_dict()[k].numel() == v.numel()}
+             print(model.load_state_dict(load_weights_dict, strict=False))
+         else:
+             raise FileNotFoundError("weights file not found: {}".format(args.weights))
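To see what the numel filter above does (a standalone editorial sketch with hypothetical tensors): entries whose element counts differ, such as a 1000-class pre-trained classifier versus the 5-class head here, are dropped, and strict=False lets load_state_dict report them as missing instead of raising.

    import torch

    pretrained = {"head.classifier.weight": torch.zeros(1000, 1280)}  # hypothetical checkpoint entry
    own_state = {"head.classifier.weight": torch.zeros(5, 1280)}
    kept = {k: v for k, v in pretrained.items() if own_state[k].numel() == v.numel()}
    print(kept)  # {} -- the incompatible head is skipped; shape-compatible backbone weights would pass through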
+
+     # optionally freeze weights
+     if args.freeze_layers:
+         for name, para in model.named_parameters():
+             # freeze everything except the head
+             if "head" not in name:
+                 para.requires_grad_(False)
+             else:
+                 print("training {}".format(name))
+
+     pg = [p for p in model.parameters() if p.requires_grad]
+     optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)
+     # Scheduler https://arxiv.org/pdf/1812.01187.pdf
+     lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
+     scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
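The lambda above is a multiplier on the base lr that decays along a half cosine from 1.0 at epoch 0 toward lrf. With the defaults epochs=30 and lrf=0.01 (an editorial illustration, not part of the file):

    import math

    epochs, lrf = 30, 0.01
    lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf
    print(lf(0), lf(15), lf(30))  # ~1.0, ~0.505, 0.01, so the lr decays 0.01 -> ~0.005 -> 1e-4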
+
+     for epoch in range(args.epochs):
+         # train
+         train_loss, train_acc = train_one_epoch(model=model,
+                                                 optimizer=optimizer,
+                                                 data_loader=train_loader,
+                                                 device=device,
+                                                 epoch=epoch)
+
+         scheduler.step()
+
+         # validate
+         val_loss, val_acc = evaluate(model=model,
+                                      data_loader=val_loader,
+                                      device=device,
+                                      epoch=epoch)
+
+         tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
+         tb_writer.add_scalar(tags[0], train_loss, epoch)
+         tb_writer.add_scalar(tags[1], train_acc, epoch)
+         tb_writer.add_scalar(tags[2], val_loss, epoch)
+         tb_writer.add_scalar(tags[3], val_acc, epoch)
+         tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
+
+         torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--num_classes', type=int, default=5)
+     parser.add_argument('--epochs', type=int, default=30)
+     parser.add_argument('--batch-size', type=int, default=8)
+     parser.add_argument('--lr', type=float, default=0.01)
+     parser.add_argument('--lrf', type=float, default=0.01)
+
+     # dataset root directory
+     # http://download.tensorflow.org/example_images/flower_photos.tgz
+     parser.add_argument('--data-path', type=str,
+                         default="../../data_set/flower_data/flower_photos")
+
+     # download model weights
+     # link: https://pan.baidu.com/s/1uZX36rvrfEss-JGj4yfzbQ  password: 5gu1
+     parser.add_argument('--weights', type=str, default='./pre_efficientnetv2-m.pth',
+                         help='initial weights path')
+     parser.add_argument('--freeze-layers', type=bool, default=True)
+     parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
+
+     opt = parser.parse_args()
+
+     main(opt)
trans_effv2_weights.py ADDED
@@ -0,0 +1,160 @@
+ import tensorflow as tf
+ import torch
+ import numpy as np
+
+
+ def main(model_name: str = "efficientnetv2-s",
+          tf_weights_path: str = "./efficientnetv2-s/model",
+          stage0_num: int = 2,
+          fused_conv_num: int = 10):
+
+     except_var = ["global_step"]
+
+     new_weights = {}
+     var_list = [i for i in tf.train.list_variables(tf_weights_path) if "Exponential" not in i[0]]
+     reader = tf.train.load_checkpoint(tf_weights_path)
+     for v in var_list:
+         if v[0] in except_var:
+             continue
+         new_name = v[0].replace(model_name + "/", "").replace("/", ".")
+
+         if "stem" in v[0]:
+             new_name = new_name.replace("conv2d.kernel",
+                                         "conv.weight")
+
+             new_name = new_name.replace("tpu_batch_normalization.beta",
+                                         "bn.bias")
+             new_name = new_name.replace("tpu_batch_normalization.gamma",
+                                         "bn.weight")
+             new_name = new_name.replace("tpu_batch_normalization.moving_mean",
+                                         "bn.running_mean")
+             new_name = new_name.replace("tpu_batch_normalization.moving_variance",
+                                         "bn.running_var")
+         elif "head" in v[0]:
+             new_name = new_name.replace("conv2d.kernel",
+                                         "project_conv.conv.weight")
+             new_name = new_name.replace("dense.kernel",
+                                         "classifier.weight")
+             new_name = new_name.replace("dense.bias",
+                                         "classifier.bias")
+
+             new_name = new_name.replace("tpu_batch_normalization.beta",
+                                         "project_conv.bn.bias")
+             new_name = new_name.replace("tpu_batch_normalization.gamma",
+                                         "project_conv.bn.weight")
+             new_name = new_name.replace("tpu_batch_normalization.moving_mean",
+                                         "project_conv.bn.running_mean")
+             new_name = new_name.replace("tpu_batch_normalization.moving_variance",
+                                         "project_conv.bn.running_var")
+         elif "blocks" in v[0]:
+             # e.g. blocks_0.conv2d.kernel -> 0
+             blocks_id = new_name.split(".", maxsplit=1)[0].replace("blocks_", "")
+             new_name = new_name.replace("blocks_{}".format(blocks_id),
+                                         "blocks.{}".format(blocks_id))
+
+             if int(blocks_id) <= stage0_num - 1:  # expansion=1 fused_mbconv
+                 new_name = new_name.replace("conv2d.kernel",
+                                             "project_conv.conv.weight")
+                 new_name = new_name.replace("tpu_batch_normalization.beta",
+                                             "project_conv.bn.bias")
+                 new_name = new_name.replace("tpu_batch_normalization.gamma",
+                                             "project_conv.bn.weight")
+                 new_name = new_name.replace("tpu_batch_normalization.moving_mean",
+                                             "project_conv.bn.running_mean")
+                 new_name = new_name.replace("tpu_batch_normalization.moving_variance",
+                                             "project_conv.bn.running_var")
+             else:
+                 new_name = new_name.replace("blocks.{}.conv2d.kernel".format(blocks_id),
+                                             "blocks.{}.expand_conv.conv.weight".format(blocks_id))
+                 new_name = new_name.replace("tpu_batch_normalization.beta",
+                                             "expand_conv.bn.bias")
+                 new_name = new_name.replace("tpu_batch_normalization.gamma",
+                                             "expand_conv.bn.weight")
+                 new_name = new_name.replace("tpu_batch_normalization.moving_mean",
+                                             "expand_conv.bn.running_mean")
+                 new_name = new_name.replace("tpu_batch_normalization.moving_variance",
+                                             "expand_conv.bn.running_var")
+
+                 if int(blocks_id) <= fused_conv_num - 1:  # fused_mbconv
+                     new_name = new_name.replace("blocks.{}.conv2d_1.kernel".format(blocks_id),
+                                                 "blocks.{}.project_conv.conv.weight".format(blocks_id))
+                     new_name = new_name.replace("tpu_batch_normalization_1.beta",
+                                                 "project_conv.bn.bias")
+                     new_name = new_name.replace("tpu_batch_normalization_1.gamma",
+                                                 "project_conv.bn.weight")
+                     new_name = new_name.replace("tpu_batch_normalization_1.moving_mean",
+                                                 "project_conv.bn.running_mean")
+                     new_name = new_name.replace("tpu_batch_normalization_1.moving_variance",
+                                                 "project_conv.bn.running_var")
+                 else:  # mbconv
+                     new_name = new_name.replace("blocks.{}.conv2d_1.kernel".format(blocks_id),
+                                                 "blocks.{}.project_conv.conv.weight".format(blocks_id))
+
+                     new_name = new_name.replace("depthwise_conv2d.depthwise_kernel",
+                                                 "dwconv.conv.weight")
+
+                     new_name = new_name.replace("tpu_batch_normalization_1.beta",
+                                                 "dwconv.bn.bias")
+                     new_name = new_name.replace("tpu_batch_normalization_1.gamma",
+                                                 "dwconv.bn.weight")
+                     new_name = new_name.replace("tpu_batch_normalization_1.moving_mean",
+                                                 "dwconv.bn.running_mean")
+                     new_name = new_name.replace("tpu_batch_normalization_1.moving_variance",
+                                                 "dwconv.bn.running_var")
+
+                     new_name = new_name.replace("tpu_batch_normalization_2.beta",
+                                                 "project_conv.bn.bias")
+                     new_name = new_name.replace("tpu_batch_normalization_2.gamma",
+                                                 "project_conv.bn.weight")
+                     new_name = new_name.replace("tpu_batch_normalization_2.moving_mean",
+                                                 "project_conv.bn.running_mean")
+                     new_name = new_name.replace("tpu_batch_normalization_2.moving_variance",
+                                                 "project_conv.bn.running_var")
+
+                     new_name = new_name.replace("se.conv2d.bias",
+                                                 "se.conv_reduce.bias")
+                     new_name = new_name.replace("se.conv2d.kernel",
+                                                 "se.conv_reduce.weight")
+                     new_name = new_name.replace("se.conv2d_1.bias",
+                                                 "se.conv_expand.bias")
+                     new_name = new_name.replace("se.conv2d_1.kernel",
+                                                 "se.conv_expand.weight")
+         else:
+             print("not recognized name: " + v[0])
+
+         var = reader.get_tensor(v[0])
+         new_var = var
+         if "conv" in new_name and "weight" in new_name and "bn" not in new_name and "dw" not in new_name:
+             assert len(var.shape) == 4
+             # conv kernel [h, w, c, n] -> [n, c, h, w]
+             new_var = np.transpose(var, (3, 2, 0, 1))
+         elif "bn" in new_name:
+             pass
+         elif "dwconv" in new_name and "weight" in new_name:
+             # dw_kernel [h, w, n, c] -> [n, c, h, w]
+             assert len(var.shape) == 4
+             new_var = np.transpose(var, (2, 3, 0, 1))
+         elif "classifier" in new_name and "weight" in new_name:
+             assert len(var.shape) == 2
+             new_var = np.transpose(var, (1, 0))
+
+         new_weights[new_name] = torch.as_tensor(new_var)
+
+     torch.save(new_weights, "pre_" + model_name + ".pth")
+
+
+ if __name__ == '__main__':
+     main(model_name="efficientnetv2-s",
+          tf_weights_path="./efficientnetv2-s/model",
+          stage0_num=2,
+          fused_conv_num=10)
+
+     # main(model_name="efficientnetv2-m",
+     #      tf_weights_path="./efficientnetv2-m/model",
+     #      stage0_num=3,
+     #      fused_conv_num=13)
+
+     # main(model_name="efficientnetv2-l",
+     #      tf_weights_path="./efficientnetv2-l/model",
+     #      stage0_num=4,
+     #      fused_conv_num=18)
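As a standalone check of the two kernel transposes used above (illustrative shapes only): TF stores standard conv kernels as [h, w, in, out] and depthwise kernels as [h, w, channels, multiplier], while PyTorch expects [out, in, h, w].

    import numpy as np

    k_conv = np.zeros((3, 3, 24, 48))                # TF conv kernel [h, w, c, n]
    print(np.transpose(k_conv, (3, 2, 0, 1)).shape)  # (48, 24, 3, 3) -> PyTorch [n, c, h, w]
    k_dw = np.zeros((3, 3, 96, 1))                   # TF depthwise kernel [h, w, n, 1]
    print(np.transpose(k_dw, (2, 3, 0, 1)).shape)    # (96, 1, 3, 3) -> PyTorch grouped conv weight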
utils.py ADDED
@@ -0,0 +1,175 @@
+ import os
+ import sys
+ import json
+ import pickle
+ import random
+
+ import torch
+ from tqdm import tqdm
+
+ import matplotlib.pyplot as plt
+
+
+ def read_split_data(root: str, val_rate: float = 0.2):
+     random.seed(0)  # fix the seed so the random split is reproducible
+     assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
+
+     # traverse the folders; one folder corresponds to one class
+     flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
+     # sort to guarantee a consistent order
+     flower_class.sort()
+     # generate the class names and their corresponding numeric indices
+     class_indices = dict((k, v) for v, k in enumerate(flower_class))
+     json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
+     with open('class_indices.json', 'w') as json_file:
+         json_file.write(json_str)
+
+     train_images_path = []   # all image paths of the training set
+     train_images_label = []  # class indices of the training images
+     val_images_path = []     # all image paths of the validation set
+     val_images_label = []    # class indices of the validation images
+     every_class_num = []     # total number of samples per class
+     supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
+     # traverse the files under each folder
+     for cla in flower_class:
+         cla_path = os.path.join(root, cla)
+         # collect the paths of all supported files
+         images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
+                   if os.path.splitext(i)[-1] in supported]
+         # get the index of this class
+         image_class = class_indices[cla]
+         # record the number of samples in this class
+         every_class_num.append(len(images))
+         # randomly sample validation images at the given ratio
+         val_path = random.sample(images, k=int(len(images) * val_rate))
+
+         for img_path in images:
+             if img_path in val_path:  # if the path was sampled for validation, store it in the validation set
+                 val_images_path.append(img_path)
+                 val_images_label.append(image_class)
+             else:  # otherwise store it in the training set
+                 train_images_path.append(img_path)
+                 train_images_label.append(image_class)
+
+     print("{} images were found in the dataset.".format(sum(every_class_num)))
+     print("{} images for training.".format(len(train_images_path)))
+     print("{} images for validation.".format(len(val_images_path)))
+
+     plot_image = False
+     if plot_image:
+         # draw a bar chart of the number of samples per class
+         plt.bar(range(len(flower_class)), every_class_num, align='center')
+         # replace the x ticks 0,1,2,3,4 with the class names
+         plt.xticks(range(len(flower_class)), flower_class)
+         # add value labels on top of the bars
+         for i, v in enumerate(every_class_num):
+             plt.text(x=i, y=v + 5, s=str(v), ha='center')
+         # set the x-axis label
+         plt.xlabel('image class')
+         # set the y-axis label
+         plt.ylabel('number of images')
+         # set the title of the bar chart
+         plt.title('flower class distribution')
+         plt.show()
+
+     return train_images_path, train_images_label, val_images_path, val_images_label
+
77
+
78
+ def plot_data_loader_image(data_loader):
79
+ batch_size = data_loader.batch_size
80
+ plot_num = min(batch_size, 4)
81
+
82
+ json_path = './class_indices.json'
83
+ assert os.path.exists(json_path), json_path + " does not exist."
84
+ json_file = open(json_path, 'r')
85
+ class_indices = json.load(json_file)
86
+
87
+ for data in data_loader:
88
+ images, labels = data
89
+ for i in range(plot_num):
90
+ # [C, H, W] -> [H, W, C]
91
+ img = images[i].numpy().transpose(1, 2, 0)
92
+ # 反Normalize操作
93
+ img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
94
+ label = labels[i].item()
95
+ plt.subplot(1, plot_num, i+1)
96
+ plt.xlabel(class_indices[str(label)])
97
+ plt.xticks([]) # 去掉x轴的刻度
98
+ plt.yticks([]) # 去掉y轴的刻度
99
+ plt.imshow(img.astype('uint8'))
100
+ plt.show()
101
+
102
+
103
+ def write_pickle(list_info: list, file_name: str):
104
+ with open(file_name, 'wb') as f:
105
+ pickle.dump(list_info, f)
106
+
107
+
108
+ def read_pickle(file_name: str) -> list:
109
+ with open(file_name, 'rb') as f:
110
+ info_list = pickle.load(f)
111
+ return info_list
112
+
113
+
114
+ def train_one_epoch(model, optimizer, data_loader, device, epoch):
115
+ model.train()
116
+ loss_function = torch.nn.CrossEntropyLoss()
117
+ accu_loss = torch.zeros(1).to(device) # 累计损失
118
+ accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数
119
+ optimizer.zero_grad()
120
+
121
+ sample_num = 0
122
+ data_loader = tqdm(data_loader)
123
+ for step, data in enumerate(data_loader):
124
+ images, labels = data
125
+ sample_num += images.shape[0]
126
+
127
+ pred = model(images.to(device))
128
+ pred_classes = torch.max(pred, dim=1)[1]
129
+ accu_num += torch.eq(pred_classes, labels.to(device)).sum()
130
+
131
+ loss = loss_function(pred, labels.to(device))
132
+ loss.backward()
133
+ accu_loss += loss.detach()
134
+
135
+ data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
136
+ accu_loss.item() / (step + 1),
137
+ accu_num.item() / sample_num)
138
+
139
+ if not torch.isfinite(loss):
140
+ print('WARNING: non-finite loss, ending training ', loss)
141
+ sys.exit(1)
142
+
143
+ optimizer.step()
144
+ optimizer.zero_grad()
145
+
146
+ return accu_loss.item() / (step + 1), accu_num.item() / sample_num
147
+
148
+
149
+ @torch.no_grad()
150
+ def evaluate(model, data_loader, device, epoch):
151
+ loss_function = torch.nn.CrossEntropyLoss()
152
+
153
+ model.eval()
154
+
155
+ accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数
156
+ accu_loss = torch.zeros(1).to(device) # 累计损失
157
+
158
+ sample_num = 0
159
+ data_loader = tqdm(data_loader)
160
+ for step, data in enumerate(data_loader):
161
+ images, labels = data
162
+ sample_num += images.shape[0]
163
+
164
+ pred = model(images.to(device))
165
+ pred_classes = torch.max(pred, dim=1)[1]
166
+ accu_num += torch.eq(pred_classes, labels.to(device)).sum()
167
+
168
+ loss = loss_function(pred, labels.to(device))
169
+ accu_loss += loss
170
+
171
+ data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
172
+ accu_loss.item() / (step + 1),
173
+ accu_num.item() / sample_num)
174
+
175
+ return accu_loss.item() / (step + 1), accu_num.item() / sample_num
weights/model-20.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1e3027cc78d0448540d99ed074391d48031c1ab3c6d23e3868bb83a7c9c90c9e
+ size 213027833