TheEeeeLin committed

Commit b83973e · 1 Parent(s): eac3bed

Upload 25 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ readme_files/6.gif filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,572 @@
- ---
- title: Resnet50-cats Vs Dogs
- emoji: 🏢
- colorFrom: red
- colorTo: green
- sdk: gradio
- sdk_version: 4.13.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
# Resnet50-cats_vs_dogs

[![zhihu](https://img.shields.io/badge/知乎-zhihu-blue)](https://zhuanlan.zhihu.com/p/676430630)

Cat vs. dog classification is one of the most fundamental tasks in computer vision: if MNIST handwritten digit recognition is the "Hello World" of CV, then cat vs. dog classification is the next stop on the journey.

In this article I will use three open-source tools, PyTorch, SwanLab, and Gradio, to walk you through the whole process: **preparing the dataset, writing the code, visualizing training, and building a demo web page**.

> Code: [Github](https://github.com/xiaolin199912/Resnet50-cats_vs_dogs)
>
> Online demo: [SwanHub](https://swanhub.co/ZeYiLin/Resnet50-cats_vs_dogs/demo)
>
> Dataset: [Baidu Netdisk](https://pan.baidu.com/s/1qYa13SxFM0AirzDyFMy0mQ), extraction code: 1ybm
>
> The three open-source libraries: [pytorch](https://github.com/pytorch/pytorch), [SwanLab](https://github.com/SwanHubX/SwanLab), [Gradio](https://github.com/gradio-app/gradio)

# 1. Preparation

## 1.1 Install the Python libraries

You need to install these four libraries:

```bash
torch>=1.12.0
torchvision>=0.13.0
swanlab>=0.1.2
gradio
```

Install command (the version specifiers are quoted so the shell does not treat `>=` as a redirection):

```bash
pip install "torch>=1.12.0" "torchvision>=0.13.0" "swanlab>=0.1.2" gradio
```
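To confirm the environment is ready, a quick sanity check (a minimal sketch, not part of the tutorial itself; it relies on `importlib.metadata`, which needs Python 3.8+) prints the installed versions:

```python
from importlib.metadata import version

# Print the installed version of each required package
for pkg in ("torch", "torchvision", "swanlab", "gradio"):
    print(pkg, version(pkg))
```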
## 1.2 Create the project files

Open a folder and create the following five files and folders (a layout sketch follows the list):

![Project files](readme_files/1.png)

Their roles are:

* `checkpoint`: folder that stores the model weights produced during training.
* `datasets`: folder that holds the dataset.
* `app.py`: Python script that runs the Gradio demo.
* `load_datasets.py`: loads the dataset; it contains the preprocessing and loading steps that deliver the data to the model in the right format.
* `train.py`: the core training script. It contains the model loading, the training loop, the choice of loss function, the optimizer configuration, and the other key pieces that define how the model is trained on the data.
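
For reference, the resulting layout should look roughly like this (a sketch of what the list above implies; the screenshot is not reproduced here):

```
.
├── checkpoint/
├── datasets/
├── app.py
├── load_datasets.py
└── train.py
```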
## 1.3 Download the cat vs. dog dataset

The data comes from the [cats_and_dogs dataset](https://modelscope.cn/datasets/tany0699/cats_and_dogs/summary) on ModelScope, with a training set of 275 images and a test set of 70 images, less than 10 MB in total.
I have tidied the data up a bit, so I recommend downloading it from the Baidu Netdisk link below instead:

> Baidu Netdisk: <https://pan.baidu.com/s/1qYa13SxFM0AirzDyFMy0mQ>, extraction code: 1ybm

![Dataset archive](readme_files/2.png)

Put the dataset into the `datasets` folder:

![datasets folder](readme_files/3.png)

OK, now let's move on to the training part!

> PS: if you want to train the cat vs. dog classifier on a larger dataset, see the related links at the end of this article.

# 2. Training

PS: if you just want the complete code and the results, jump straight to section 2.9.

## 2.1 load_datasets.py

We first need a class, `DatasetLoader`, that reads and preprocesses the dataset; we write it in `load_datasets.py`.
Before writing the class, let's look at the data.
In the `datasets` directory, `train.csv` and `val.csv` record the relative paths of the training and validation images (the first column is the relative image path, the second column is the label, 0 for cat and 1 for dog):

![train.csv and val.csv](readme_files/4.png)

![Left: train.csv; right: images in the cat folder inside the train folder](readme_files/5.png)

The left image shows `train.csv`; the right image shows the images in the `cat` folder inside the `train` folder.

So the goal is clear:

1. Parse the two csv files to get the image paths and labels
2. Read the images from the relative paths
3. Preprocess the images
4. Return the preprocessed images together with their labels

With the goal settled, let's write the `DatasetLoader` class:

```python
import csv
import os
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset


class DatasetLoader(Dataset):
    def __init__(self, csv_path):
        self.csv_file = csv_path
        with open(self.csv_file, 'r') as file:
            self.data = list(csv.reader(file))

        self.current_dir = os.path.dirname(os.path.abspath(__file__))

    def preprocess_image(self, image_path):
        full_path = os.path.join(self.current_dir, 'datasets', image_path)
        image = Image.open(full_path)
        image_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        return image_transform(image)

    def __getitem__(self, index):
        image_path, label = self.data[index]
        image = self.preprocess_image(image_path)
        return image, int(label)

    def __len__(self):
        return len(self.data)
```

`DatasetLoader` consists of four parts:

1. `__init__`: takes a single argument, `csv_path`. The csv file passed in is read and stored in `self.data`. `self.current_dir` holds the absolute path of the directory the code lives in, which is used later to locate the images.

2. `preprocess_image`: handles image preprocessing. It first builds the absolute path of the image file, then opens the image with PIL and applies a series of transforms: resize to 256x256, convert to a tensor, and normalize. It returns the preprocessed image.

3. `__getitem__`: when the dataset is iterated over, `__getitem__` returns the sample at the given `index`, i.e. an image and its label. It looks up the image path and label in `self.data`, calls `preprocess_image` to process the image, and returns the processed image together with the label cast to an integer.

4. `__len__`: returns the total number of images in the dataset.
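
As a quick sanity check (a small sketch, not part of the tutorial files, assuming the dataset from section 1.3 is already in `datasets/`), you can instantiate the class and inspect one sample:

```python
from load_datasets import DatasetLoader

# Load the training split and look at the first sample
dataset = DatasetLoader("datasets/train.csv")
print(len(dataset))        # total number of training images

image, label = dataset[0]  # calls __getitem__
print(image.shape)         # torch.Size([3, 256, 256]) for an RGB image, after the transforms above
print(label)               # 0 = cat, 1 = dog
```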
## 2.2 Load the dataset

> From this section on, the code goes into `train.py`.

```python
from torch.utils.data import DataLoader
from load_datasets import DatasetLoader

batch_size = 8

TrainDataset = DatasetLoader("datasets/train.csv")
ValDataset = DatasetLoader("datasets/val.csv")
TrainDataLoader = DataLoader(TrainDataset, batch_size=batch_size, shuffle=True)
ValDataLoader = DataLoader(ValDataset, batch_size=batch_size, shuffle=False)
```

We instantiate `DatasetLoader` with the paths of the two csv files and wrap the result in PyTorch's `DataLoader`. `DataLoader` takes two additional arguments here:

* `batch_size`: how many images each batch contains. In deep learning we usually do not process all the data at once; instead we split it into small batches, which helps the model learn faster and also saves memory. Here we set batch_size = 8, so each batch contains 8 images.
* `shuffle`: whether to reshuffle the data at the start of every epoch. This is typically enabled for the training set so that the data order differs from epoch to epoch, which helps the model generalize. With `shuffle=True`, the data is reshuffled at the beginning of each epoch. Here we shuffle during training and keep the order fixed during testing.
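
To see what one batch actually looks like (a throwaway sketch under the settings above), you can pull a single batch from the loader and check its shapes:

```python
# Fetch one batch from the training loader and inspect its shapes
images, labels = next(iter(TrainDataLoader))
print(images.shape)  # torch.Size([8, 3, 256, 256]) with batch_size = 8
print(labels.shape)  # torch.Size([8]), one integer label per image
```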
## 2.3 Load the ResNet50 model

For the model we use the classic **ResNet50**. This article does not go into the details of the architecture and focuses on the engineering instead.
We create a resnet50 model with **torchvision** and load weights pre-trained on the ImageNet-1k dataset:

```python
import torchvision
from torchvision.models import ResNet50_Weights

# Load the pre-trained ResNet50 model
model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
```

Cat vs. dog classification is a binary task, while the resnet50 provided by torchvision has a 1000-class head by default, so we replace the output dimension of the final fully connected layer with 2:

```python
import torch
import torchvision
from torchvision.models import ResNet50_Weights

num_classes = 2

# Load the pre-trained ResNet50 model
model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Replace the output dimension of the fully connected layer with num_classes
in_features = model.fc.in_features
model.fc = torch.nn.Linear(in_features, num_classes)
```
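A quick way to confirm the new head is wired correctly (a minimal sketch, not part of the original tutorial code) is to push a dummy batch through the model and check that it produces 2 logits per image:

```python
import torch

# One fake image with the same size the dataloader produces (3 x 256 x 256)
dummy = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # torch.Size([1, 2]) -> one logit per class (cat, dog)
```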
## 2.4 Select cuda/mps/cpu

If your machine has an **NVIDIA GPU**, cuda will speed up training dramatically;
if it is a **MacBook with Apple Silicon (an M-series chip)**, mps will likewise speed it up a lot;
otherwise, fall back to cpu:

```python
# Check whether mps is available
try:
    use_mps = torch.backends.mps.is_available()
except AttributeError:
    use_mps = False

# Check whether cuda is available
if torch.cuda.is_available():
    device = "cuda"
elif use_mps:
    device = "mps"
else:
    device = "cpu"
```

Move the model to the selected device:

```python
model.to(torch.device(device))
```

## 2.5 Hyperparameters, optimizer, and loss function

**Hyperparameters**
We train for 20 epochs with a learning rate of 1e-4, a batch size of 8, and 2 classes.

```python
num_epochs = 20
lr = 1e-4
batch_size = 8
num_classes = 2
```

### Loss function and optimizer

We use cross-entropy loss and the Adam optimizer.

```python
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
```
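As a reminder of what this loss expects (an illustrative sketch only, using the `criterion` defined above), `CrossEntropyLoss` takes raw logits of shape `(batch, num_classes)` and integer class labels; no softmax is applied beforehand:

```python
import torch

# Fake logits for a batch of 4 images plus their integer labels
logits = torch.randn(4, 2)           # raw, unnormalized model outputs (no softmax)
labels = torch.tensor([0, 1, 1, 0])  # 0 = cat, 1 = dog
loss = criterion(logits, labels)     # scalar tensor: mean cross-entropy over the batch
print(loss)
```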
## 2.6 Initialize SwanLab

During training we use the `swanlab` library for experiment management and metric visualization.
[swanlab](https://github.com/SwanHubX/SwanLab) is an open-source training visualization library similar to Tensorboard, with a lighter footprint and a friendlier API. Besides metrics, it also automatically records the training logs, the hardware environment, the Python environment, the training duration, and other information.

![SwanLab dashboard](readme_files/6.gif)

### Set the initial configuration

swanlab uses `swanlab.init` to set the experiment name, the experiment description, and the logged hyperparameters.

```python
import swanlab

swanlab.init(
    # Experiment name
    experiment_name="ResNet50",
    # Experiment description
    description="Train ResNet50 for cat and dog classification.",
    # Logged hyperparameters
    config={
        "model": "resnet50",
        "optim": "Adam",
        "lr": lr,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "num_class": num_classes,
        "device": device,
    }
)
```

### Track the key metrics

swanlab uses `swanlab.log` to record the key metrics; concrete usage appears in sections 2.7 and 2.8.
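
In short, `swanlab.log` takes a dictionary that maps a metric name to a value; for example (a minimal sketch, the value here is just a placeholder):

```python
# Each call appends one data point to the chart of the named metric
swanlab.log({"train_loss": 0.123})
```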
## 2.7 Training function

We define a training function `train`:

```python
def train(model, device, train_dataloader, optimizer, criterion, epoch):
    model.train()
    for iter, (inputs, labels) in enumerate(train_dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, iter + 1, len(TrainDataLoader),
                                                                      loss.item()))
        swanlab.log({"train_loss": loss.item()})
```

The training logic is straightforward: we iterate over `train_dataloader`, take one batch of images and labels at a time, pass it through the resnet50 model to get predictions, feed the predictions and labels into the loss function to compute the cross-entropy loss, backpropagate the loss, and let the Adam optimizer update the model parameters, over and over again.

The metric we care most about during training is the loss value `loss`, so we track its evolution with `swanlab.log`.

## 2.8 Test function

We define a test function `test`:

```python
def test(model, device, test_dataloader, epoch):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total * 100
    print('Accuracy: {:.2f}%'.format(accuracy))
    swanlab.log({"test_acc": accuracy})
```

The test logic is equally simple: we iterate over `test_dataloader`, pass the test images through the resnet50 model, compare the predictions with the labels, and compute the overall accuracy.

The metric we care most about during testing is the accuracy `accuracy`, so we track its evolution with `swanlab.log`.
## 2.9 Complete training code

We train for `num_epochs` epochs, evaluate every 4 epochs, and save the weights at the end:

```python
for epoch in range(1, num_epochs + 1):
    train(model, device, TrainDataLoader, optimizer, criterion, epoch)
    if epoch % 4 == 0:
        accuracy = test(model, device, ValDataLoader, epoch)

if not os.path.exists("checkpoint"):
    os.makedirs("checkpoint")
torch.save(model.state_dict(), 'checkpoint/latest_checkpoint.pth')
print("Training complete")
```

The complete `train.py`, with everything put together:

```python
import torch
import torchvision
from torchvision.models import ResNet50_Weights
import swanlab
from torch.utils.data import DataLoader
from load_datasets import DatasetLoader
import os


# Define the training function
def train(model, device, train_dataloader, optimizer, criterion, epoch):
    model.train()
    for iter, (inputs, labels) in enumerate(train_dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, iter + 1, len(TrainDataLoader),
                                                                      loss.item()))
        swanlab.log({"train_loss": loss.item()})


# Define the test function
def test(model, device, test_dataloader, epoch):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total * 100
    print('Accuracy: {:.2f}%'.format(accuracy))
    swanlab.log({"test_acc": accuracy})


if __name__ == "__main__":
    num_epochs = 20
    lr = 1e-4
    batch_size = 8
    num_classes = 2

    # Select the device
    try:
        use_mps = torch.backends.mps.is_available()
    except AttributeError:
        use_mps = False

    if torch.cuda.is_available():
        device = "cuda"
    elif use_mps:
        device = "mps"
    else:
        device = "cpu"

    # Initialize swanlab
    swanlab.init(
        experiment_name="ResNet50",
        description="Train ResNet50 for cat and dog classification.",
        config={
            "model": "resnet50",
            "optim": "Adam",
            "lr": lr,
            "batch_size": batch_size,
            "num_epochs": num_epochs,
            "num_class": num_classes,
            "device": device,
        }
    )

    TrainDataset = DatasetLoader("datasets/train.csv")
    ValDataset = DatasetLoader("datasets/val.csv")
    TrainDataLoader = DataLoader(TrainDataset, batch_size=batch_size, shuffle=True)
    ValDataLoader = DataLoader(ValDataset, batch_size=batch_size, shuffle=False)

    # Load the ResNet50 model
    model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

    # Replace the fully connected layer with a 2-class head
    in_features = model.fc.in_features
    model.fc = torch.nn.Linear(in_features, num_classes)

    model.to(torch.device(device))
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Start training
    for epoch in range(1, num_epochs + 1):
        train(model, device, TrainDataLoader, optimizer, criterion, epoch)  # Train for one epoch

        if epoch % 4 == 0:  # Test every 4 epochs
            accuracy = test(model, device, ValDataLoader, epoch)

    # Save the weights
    if not os.path.exists("checkpoint"):
        os.makedirs("checkpoint")
    torch.save(model.state_dict(), 'checkpoint/latest_checkpoint.pth')
    print("Training complete")
```
## 2.10 Start training!

Run `train.py`:

![Training output](readme_files/7.png)

Now open a terminal and run `swanlab watch` to start the SwanLab experiment dashboard:

![swanlab watch](readme_files/8.png)

Open http://127.0.0.1:5092 in a browser to see the experiment dashboard.

The default page is the Project Dashboard, which shows the project information and a table comparing the experiments:

![Project Dashboard](readme_files/9.png)

Open the experiment that is currently running and you will see the full curves of `train_loss` and `test_acc`:

![Training curves](readme_files/10.png)

Switch to the **Overview tab**, which records all kinds of information about the experiment: the parameters passed to **swanlab.init**, the final metrics, the experiment status, the training duration, the Git repository link, the hostname, the operating system, the Python version, the hardware configuration, and so on.

**The trained model reaches 100% accuracy on the test set.**

![Overview tab](readme_files/11.png)

With that, we have finished training and evaluating the model and obtained a cat vs. dog classifier that performs very well, with the weights saved in the `checkpoint` directory.

Next, let's build a demo web page on top of the trained weights~
# 3. Gradio demo

Gradio is an open-source Python library designed to help data scientists, researchers, and machine learning developers quickly build and share user interfaces for machine learning models.

Here we use Gradio to build a demo page for the cat vs. dog classifier, written in `app.py`:

```python
import gradio as gr
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision

# Load a model with the same structure as the one used during training
def load_model(checkpoint_path, num_classes):
    # Select the device, then rebuild the ResNet50 architecture
    try:
        use_mps = torch.backends.mps.is_available()
    except AttributeError:
        use_mps = False

    if torch.cuda.is_available():
        device = "cuda"
    elif use_mps:
        device = "mps"
    else:
        device = "cpu"

    model = torchvision.models.resnet50(weights=None)
    in_features = model.fc.in_features
    model.fc = torch.nn.Linear(in_features, num_classes)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()  # Set model to evaluation mode
    return model

# Load an image and apply the necessary transforms
def process_image(image, image_size):
    # Define the same transforms as used during training
    preprocessing = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = preprocessing(image).unsqueeze(0)
    return image


# Predict the image class and return the class probabilities
def predict(image):
    classes = {'0': 'cat', '1': 'dog'}  # Update or extend this dictionary based on your actual classes
    image = process_image(image, 256)  # Using the image size from training
    with torch.no_grad():
        outputs = model(image)
        probabilities = F.softmax(outputs, dim=1).squeeze()  # Apply softmax to get probabilities
    # Mapping class labels to probabilities
    class_probabilities = {classes[str(i)]: float(prob) for i, prob in enumerate(probabilities)}
    return class_probabilities


# Path to the trained model weights
checkpoint_path = 'checkpoint/latest_checkpoint.pth'
num_classes = 2
model = load_model(checkpoint_path, num_classes)

# Define the Gradio Interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=num_classes),
    title="Cat vs Dog Classifier",
)

if __name__ == "__main__":
    iface.launch()
```
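If you want to share the demo beyond your own machine, Gradio's `launch` also accepts a `share` flag that creates a temporary public URL (optional, needs an internet connection; this is a standard Gradio option, not something the tutorial itself uses). A minimal variant of the last lines above:

```python
if __name__ == "__main__":
    # share=True asks Gradio to create a temporary public URL in addition to the local one
    iface.launch(share=True)
```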
Running the program produces the following output:

![Gradio launch output](readme_files/12.png)

Open the link and the cat vs. dog demo page appears:

![Demo page](readme_files/13.png)

Try it with a cat picture and a dog picture:

![Cat prediction](readme_files/14.png)

![Dog prediction](readme_files/15.png)

The results look great!

With that, we have completed the whole process of training a cat vs. dog classifier with the three open-source tools PyTorch, SwanLab, and Gradio. For anything else you want to know, see the related links or leave a comment on this article.

If this helped, please give it a Star~

# 4. Related links

* SwanLab: <https://github.com/SwanHubX/SwanLab>
* Cat vs. dog classification code: <https://github.com/xiaolin199912/Resnet50-cats_vs_dogs>
* Online demo: https://swanhub.co/ZeYiLin/Resnet50-cats_vs_dogs/demo
* Cat vs. dog dataset (300 images): <https://modelscope.cn/datasets/tany0699/cats_and_dogs/summary>
* Baidu Netdisk download: <https://pan.baidu.com/s/1qYa13SxFM0AirzDyFMy0mQ>, extraction code: 1ybm
* Cat vs. dog dataset (10k images): <https://modelscope.cn/datasets/XCsunny/cat_vs_dog_class/summary>
app.py ADDED
@@ -0,0 +1,71 @@
import gradio as gr
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision


# Load a model with the same structure as the one used during training
def load_model(checkpoint_path, num_classes):
    # Select the device, then rebuild the ResNet50 architecture
    try:
        use_mps = torch.backends.mps.is_available()
    except AttributeError:
        use_mps = False

    if torch.cuda.is_available():
        device = "cuda"
    elif use_mps:
        device = "mps"
    else:
        device = "cpu"

    model = torchvision.models.resnet50(weights=None)
    in_features = model.fc.in_features
    model.fc = torch.nn.Linear(in_features, num_classes)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()  # Set model to evaluation mode
    return model


# Load an image and apply the necessary transforms
def process_image(image, image_size):
    # Define the same transforms as used during training
    preprocessing = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = preprocessing(image).unsqueeze(0)
    return image


# Predict the image class and return the class probabilities
def predict(image):
    classes = {'0': 'cat', '1': 'dog'}  # Update or extend this dictionary based on your actual classes
    image = process_image(image, 256)  # Using the image size from training
    with torch.no_grad():
        outputs = model(image)
        probabilities = F.softmax(outputs, dim=1).squeeze()  # Apply softmax to get probabilities
    # Mapping class labels to probabilities
    class_probabilities = {classes[str(i)]: float(prob) for i, prob in enumerate(probabilities)}
    return class_probabilities


# Path to the trained model weights
checkpoint_path = 'checkpoint/latest_checkpoint.pth'
num_classes = 2
model = load_model(checkpoint_path, num_classes)

# Define the Gradio Interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=num_classes),
    title="Cat vs Dog Classifier",
    examples=["test_images/test_cat.jpg", "test_images/test_dog.jpg"]
)

if __name__ == "__main__":
    iface.launch()
checkpoint/.gitkeep ADDED
File without changes
checkpoint/latest_checkpoint.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1b5a1ebfec3b9a5d20bf36fd75dec3f82005bf0dc0dbffb4e6efb4ecc428e464
size 94364890
datasets/.gitkeep ADDED
File without changes
load_datasets.py ADDED
@@ -0,0 +1,42 @@
import csv
import os
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset


class DatasetLoader(Dataset):
    def __init__(self, csv_path):
        self.csv_file = csv_path
        with open(self.csv_file, 'r') as file:
            self.data = list(csv.reader(file))

        self.current_dir = os.path.dirname(os.path.abspath(__file__))

    def preprocess_image(self, image_path):
        """
        Preprocess the image: Read the image, apply transformations, and return the transformed image.
        """
        full_path = os.path.join(self.current_dir, 'datasets', image_path)
        image = Image.open(full_path)
        image_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        return image_transform(image)

    def __getitem__(self, index):
        """
        Return the preprocessed image and its label at the specified index from the dataset.
        """
        image_path, label = self.data[index]
        image = self.preprocess_image(image_path)
        return image, int(label)

    def __len__(self):
        """
        Return the number of items in the dataset.
        """
        return len(self.data)
readme_files/1.png ADDED
readme_files/10.png ADDED
readme_files/11.png ADDED
readme_files/12.png ADDED
readme_files/13.png ADDED
readme_files/14.png ADDED
readme_files/15.png ADDED
readme_files/2.png ADDED
readme_files/3.png ADDED
readme_files/4.png ADDED
readme_files/5.png ADDED
readme_files/6.gif ADDED

Git LFS Details

  • SHA256: 9c0179b825aeb1794ce592108247d8210d31222f13c016b03d5639588d2c6080
  • Pointer size: 132 Bytes
  • Size of remote file: 6.87 MB
readme_files/7.png ADDED
readme_files/8.png ADDED
readme_files/9.png ADDED
requirements.txt ADDED
@@ -0,0 +1,4 @@
torch>=1.12.0
torchvision>=0.13.0
swanlab>=0.1.2
gradio
test_images/test_cat.jpg ADDED
test_images/test_dog.jpg ADDED
train.py ADDED
@@ -0,0 +1,102 @@
import torch
import torchvision
from torchvision.models import ResNet50_Weights
import swanlab
from torch.utils.data import DataLoader
from load_datasets import DatasetLoader
import os


# Define train function
def train(model, device, train_dataloader, optimizer, criterion, epoch):
    model.train()
    for iter, (inputs, labels) in enumerate(train_dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, iter + 1, len(TrainDataLoader),
                                                                      loss.item()))
        swanlab.log({"train_loss": loss.item()})


# Define test function
def test(model, device, test_dataloader, epoch):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total * 100
    print('Accuracy: {:.2f}%'.format(accuracy))
    swanlab.log({"test_acc": accuracy})


if __name__ == "__main__":
    num_epochs = 20
    lr = 1e-4
    batch_size = 16
    num_classes = 2

    try:
        use_mps = torch.backends.mps.is_available()
    except AttributeError:
        use_mps = False

    if torch.cuda.is_available():
        device = "cuda"
    elif use_mps:
        device = "mps"
    else:
        device = "cpu"

    # Initialize swanlab
    swanlab.init(
        experiment_name="ResNet50",
        description="Train ResNet50 for cat and dog classification.",
        config={
            "model": "resnet50",
            "optim": "Adam",
            "lr": lr,
            "batch_size": batch_size,
            "num_epochs": num_epochs,
            "num_class": num_classes,
            "device": device,
        }
    )

    TrainDataset = DatasetLoader("datasets/train.csv")
    ValDataset = DatasetLoader("datasets/val.csv")
    TrainDataLoader = DataLoader(TrainDataset, batch_size=batch_size, shuffle=True)
    ValDataLoader = DataLoader(ValDataset, batch_size=batch_size, shuffle=False)

    # Load the pre-trained ResNet50 model
    model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

    # Replace the last fully connected layer.
    in_features = model.fc.in_features
    model.fc = torch.nn.Linear(in_features, num_classes)

    # Train
    model.to(torch.device(device))
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, num_epochs + 1):
        train(model, device, TrainDataLoader, optimizer, criterion, epoch)  # Train for one epoch

        if epoch % 4 == 0:  # Test every 4 epochs
            accuracy = test(model, device, ValDataLoader, epoch)

    if not os.path.exists("checkpoint"):
        os.makedirs("checkpoint")
    torch.save(model.state_dict(), 'checkpoint/latest_checkpoint.pth')
    print("Training complete")