TheEeeeLin
commited on
Commit
·
b83973e
1
Parent(s):
eac3bed
Upload 25 files
Browse files- .gitattributes +1 -0
- README.md +572 -13
- app.py +71 -0
- checkpoint/.gitkeep +0 -0
- checkpoint/latest_checkpoint.pth +3 -0
- datasets/.gitkeep +0 -0
- load_datasets.py +42 -0
- readme_files/1.png +0 -0
- readme_files/10.png +0 -0
- readme_files/11.png +0 -0
- readme_files/12.png +0 -0
- readme_files/13.png +0 -0
- readme_files/14.png +0 -0
- readme_files/15.png +0 -0
- readme_files/2.png +0 -0
- readme_files/3.png +0 -0
- readme_files/4.png +0 -0
- readme_files/5.png +0 -0
- readme_files/6.gif +3 -0
- readme_files/7.png +0 -0
- readme_files/8.png +0 -0
- readme_files/9.png +0 -0
- requirements.txt +4 -0
- test_images/test_cat.jpg +0 -0
- test_images/test_dog.jpg +0 -0
- train.py +102 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
readme_files/6.gif filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,13 +1,572 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Resnet50-cats_vs_dogs
|
2 |
+
|
3 |
+
[![zhihu](https://img.shields.io/badge/知乎-zhihu-blue)](https://zhuanlan.zhihu.com/p/676430630)
|
4 |
+
|
5 |
+
猫狗分类是计算机视觉最基础的任务之一——如果说完成MNIST手写体识别是实现CV的“Hello World”,那猫狗分类就是旅程的下一站~。
|
6 |
+
|
7 |
+
这篇文章我将带大家使用PyTorch、SwanLab、Gradio三个开源工具,完成从**数据集准备、代码编写、可视化训练到构建Demo网页**的全过程。
|
8 |
+
|
9 |
+
> 代码:[Github](https://github.com/xiaolin199912/Resnet50-cats_vs_dogs)
|
10 |
+
>
|
11 |
+
> 在线Demo: [SwanHub](https://swanhub.co/ZeYiLin/Resnet50-cats_vs_dogs/demo)
|
12 |
+
>
|
13 |
+
> 数据集:[百度云](https://pan.baidu.com/s/1qYa13SxFM0AirzDyFMy0mQ) 提取码: 1ybm
|
14 |
+
>
|
15 |
+
> 三个开源库:[pytorch](https://github.com/pytorch/pytorch)、[SwanLab](https://github.com/SwanHubX/SwanLab)、[Gradio](https://github.com/gradio-app/gradio)
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
# 1. 准备部分
|
20 |
+
|
21 |
+
## 1.1 安装Python库
|
22 |
+
|
23 |
+
需要安装下面这4个库:
|
24 |
+
|
25 |
+
```bash
|
26 |
+
torch>=1.12.0
|
27 |
+
torchvision>=0.13.0
|
28 |
+
swanlab>=0.1.2
|
29 |
+
gradio
|
30 |
+
```
|
31 |
+
|
32 |
+
安装命令:
|
33 |
+
|
34 |
+
```bash
|
35 |
+
pip install torch>=1.12.0 torchvision>=0.13.0 swanlab>=0.1.2 gradio
|
36 |
+
```
|
37 |
+
|
38 |
+
## 1.2 创建文件目录
|
39 |
+
|
40 |
+
现在打开1个文件夹,新建下面这5个文件:
|
41 |
+
|
42 |
+
![在这里插入图片描述](readme_files/1.png)
|
43 |
+
|
44 |
+
它们各自的作用分别是:
|
45 |
+
|
46 |
+
* `checkpoint`:这个文件夹用于存储训练过程中生成的模型权重。
|
47 |
+
* `datasets`:这个文件夹用于放置数据集。
|
48 |
+
* `app.py`:运行Gradio Demo的Python脚本。
|
49 |
+
* `load_datasets.py`:负责载入数据集,包含了数据的预处理、加载等步骤,确保数据以适当的格式提供给模型使用。
|
50 |
+
* `train.py`:模型训练的核心脚本。它包含了模型的载入、训练循环、损失函数的选择、优化器的配置等关键组成部分,用于指导如何使用数据来训练模型。
|
51 |
+
|
52 |
+
## 1.3 下载猫狗分类数据集
|
53 |
+
|
54 |
+
数据集来源是Modelscope上的[猫狗分类数据集](https://modelscope.cn/datasets/tany0699/cats_and_dogs/summary),包含275张图像的数据集和70张图像的测试集,一共不到10MB。
|
55 |
+
我对数据做了一些整理,所以更推荐使用下面的百度网盘链接下载:
|
56 |
+
|
57 |
+
> 百度网盘:链接: <https://pan.baidu.com/s/1qYa13SxFM0AirzDyFMy0mQ> 提取码: 1ybm
|
58 |
+
|
59 |
+
![在这里插入图片描述](readme_files/2.png)
|
60 |
+
|
61 |
+
将数据集放入`datasets`文件夹:
|
62 |
+
|
63 |
+
![在这里插入图片描述](readme_files/3.png)
|
64 |
+
|
65 |
+
ok,现在我们开始训练部分!
|
66 |
+
|
67 |
+
> ps:如果你想要用更大规模的数据来训练猫狗分类模型,请前往文末的相关链接。
|
68 |
+
|
69 |
+
# 2. 训练部分
|
70 |
+
|
71 |
+
ps:如果想直接看完整代码和效果,可直接跳转到第2.9。
|
72 |
+
|
73 |
+
## 2.1 load_datasets.py
|
74 |
+
|
75 |
+
我们首先需要创建1个类`DatasetLoader`,它的作用是完成数据集的读取和预处理,我们将它写在`load_datasets.py`中。
|
76 |
+
在写这个类之前,先分析一下数据集。
|
77 |
+
在datasets目录下,`train.csv`和`val.csv`分别记录了训练集和测试集的图像相对路径(第一列是图像的相对路径,第二列是标签,0代表猫,1代表狗):
|
78 |
+
|
79 |
+
![在这里插入图片描述](readme_files/4.png)
|
80 |
+
|
81 |
+
![左图作为train.csv,右图为train文件夹中的cat文件夹中的图像](readme_files/5.png)
|
82 |
+
|
83 |
+
左图作为train.csv,右图为train文件夹中的cat文件夹中的图像。
|
84 |
+
|
85 |
+
那么我们的目标就很明确:
|
86 |
+
|
87 |
+
1. 解析这两个csv文件,获取图像相对路径和标签
|
88 |
+
2. 根据相对路径读取图像
|
89 |
+
3. 对图像做预处理
|
90 |
+
4. 返回预处理后的图像和对应标签
|
91 |
+
|
92 |
+
明确了目标后,现在我们开始写`DatasetLoader`类:
|
93 |
+
|
94 |
+
```python
|
95 |
+
import csv
|
96 |
+
import os
|
97 |
+
from torchvision import transforms
|
98 |
+
from PIL import Image
|
99 |
+
from torch.utils.data import Dataset
|
100 |
+
|
101 |
+
class DatasetLoader(Dataset):
|
102 |
+
def __init__(self, csv_path):
|
103 |
+
self.csv_file = csv_path
|
104 |
+
with open(self.csv_file, 'r') as file:
|
105 |
+
self.data = list(csv.reader(file))
|
106 |
+
|
107 |
+
self.current_dir = os.path.dirname(os.path.abspath(__file__))
|
108 |
+
|
109 |
+
def preprocess_image(self, image_path):
|
110 |
+
full_path = os.path.join(self.current_dir, 'datasets', image_path)
|
111 |
+
image = Image.open(full_path)
|
112 |
+
image_transform = transforms.Compose([
|
113 |
+
transforms.Resize((256, 256)),
|
114 |
+
transforms.ToTensor(),
|
115 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
116 |
+
])
|
117 |
+
|
118 |
+
return image_transform(image)
|
119 |
+
|
120 |
+
def __getitem__(self, index):
|
121 |
+
image_path, label = self.data[index]
|
122 |
+
image = self.preprocess_image(image_path)
|
123 |
+
return image, int(label)
|
124 |
+
|
125 |
+
def __len__(self):
|
126 |
+
return len(self.data)
|
127 |
+
```
|
128 |
+
|
129 |
+
`DatasetLoader`类由四个部分组成:
|
130 |
+
|
131 |
+
1. `__init__`:包含1个输入参数csv\_path,在外部传入`csv_path`后,将读取后的数据存入`self.data`中。`self.current_dir`则是获取了当前代码所在目录的绝对路径,为后续读取图像做准备。
|
132 |
+
|
133 |
+
2. `preprocess_image`:此函数用于图像预处理。首先,它构造图像文件的绝对路径,然后使用PIL库打开图像。接着,定义了一系列图像变换:调整图像大小至256x256、转换图像为张量、对图像进行标准化处理,最终,返回预处理后的图像。
|
134 |
+
|
135 |
+
3. `__getitem__`:当数据集类被循环调用时,`__getitem__`方法会返回指定索引index的数据,即图像和标签。首先,它根据索引从`self.data`中取出图像路径和标签。然后,调用`preprocess_image`方法来处理图像数据。最后,将处理后的图像数据和标签转换为整型后返回。
|
136 |
+
|
137 |
+
4. `__len__`:用于返回数据集的总图像数量。
|
138 |
+
|
139 |
+
## 2.2 载入数据集
|
140 |
+
> 从本节开始,代码将写在`train.py`中。
|
141 |
+
|
142 |
+
```python
|
143 |
+
from torch.utils.data import DataLoader
|
144 |
+
from load_datasets import DatasetLoader
|
145 |
+
|
146 |
+
batch_size = 8
|
147 |
+
|
148 |
+
TrainDataset = DatasetLoader("datasets/train.csv")
|
149 |
+
ValDataset = DatasetLoader("datasets/val.csv")
|
150 |
+
TrainDataLoader = DataLoader(TrainDataset, batch_size=batch_size, shuffle=True)
|
151 |
+
ValDataLoader = DataLoader(ValDataset, batch_size=batch_size, shuffle=False)
|
152 |
+
```
|
153 |
+
|
154 |
+
我们传入那两个csv文件的路径实例化`DatasetLoader`类,然后用PyTorch的`DataLoader`做一层封装。`DataLoader`可以再传入两个参数:
|
155 |
+
|
156 |
+
* `batch_size`:定义了每个数据批次包含多少张图像。在深度学习中,我们通常不会一次性地处理所有数据,而是将数据划分为小批次。这有助于模型更快地学习,并且还可以节省内存。在这里我们定义batch\_size = 8,即每个批次将包含8个图像。
|
157 |
+
* `shuffle`:定义了是否在每个循环轮次(epoch)开始时随机打乱数据。这通常用于训练数据集以保证每个epoch的数据顺序不同,从而帮助模型更好地泛化。如果设置为True,那么在每个epoch开始时,数据将被打乱。在这里我们让训练时打乱,测试时不打乱。
|
158 |
+
|
159 |
+
## 2.3 载入ResNet50模型
|
160 |
+
|
161 |
+
模型我们选用经典的**ResNet50**,模型的具体原理本文就不细说了,重点放在工程实现上。
|
162 |
+
我们使用**torchvision**来创建1个resnet50模型,并载入在Imagenet1k数据集上预训练好的权重:
|
163 |
+
|
164 |
+
```python
|
165 |
+
from torchvision.models import ResNet50_Weights
|
166 |
+
|
167 |
+
# 加载预训练的ResNet50模型
|
168 |
+
model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
|
169 |
+
```
|
170 |
+
|
171 |
+
因为猫狗分类是个2分类任务,而torchvision提供的resnet50默认是1000分类,所以我们需要把模型最后的全连接层的输出维度替换为2:
|
172 |
+
|
173 |
+
```python
|
174 |
+
from torchvision.models import ResNet50_Weights
|
175 |
+
|
176 |
+
num_classes=2
|
177 |
+
|
178 |
+
# 加载预训练的ResNet50模型
|
179 |
+
model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
|
180 |
+
|
181 |
+
# 将全连接层的输出维度替换为num_classes
|
182 |
+
in_features = model.fc.in_features
|
183 |
+
model.fc = torch.nn.Linear(in_features, num_classes)
|
184 |
+
```
|
185 |
+
|
186 |
+
## 2.4 设置cuda/mps/cpu
|
187 |
+
|
188 |
+
如果你的电脑是**英伟达显卡**,那么cuda可以极大加速你的训练;
|
189 |
+
如果你的电脑是**Macbook Apple Sillicon(M系列芯片)**,那么mps同样可以极大加速你的训练;
|
190 |
+
如果都不是,那就选用cpu:
|
191 |
+
|
192 |
+
```python
|
193 |
+
#检测是否支持mps
|
194 |
+
try:
|
195 |
+
use_mps = torch.backends.mps.is_available()
|
196 |
+
except AttributeError:
|
197 |
+
use_mps = False
|
198 |
+
|
199 |
+
#检测是否支持cuda
|
200 |
+
if torch.cuda.is_available():
|
201 |
+
device = "cuda"
|
202 |
+
elif use_mps:
|
203 |
+
device = "mps"
|
204 |
+
else:
|
205 |
+
device = "cpu"
|
206 |
+
```
|
207 |
+
|
208 |
+
将模型加载到对应的device中:
|
209 |
+
|
210 |
+
```python
|
211 |
+
model.to(torch.device(device))
|
212 |
+
```
|
213 |
+
|
214 |
+
## 2.5 设置超参数、优化器、损失函数
|
215 |
+
|
216 |
+
**超参数**
|
217 |
+
设置训练轮次为20轮,学习率为1e-4,训练批次为8,分类数为2分类。
|
218 |
+
|
219 |
+
```python
|
220 |
+
num_epochs = 20
|
221 |
+
lr = 1e-4
|
222 |
+
batch_size = 8
|
223 |
+
num_classes = 2
|
224 |
+
```
|
225 |
+
|
226 |
+
### 损失函数与优化器
|
227 |
+
|
228 |
+
设置损失函数为交叉熵损失,优化器为Adam。
|
229 |
+
|
230 |
+
```python
|
231 |
+
criterion = torch.nn.CrossEntropyLoss()
|
232 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
233 |
+
```
|
234 |
+
|
235 |
+
## 2.6 初始化SwanLab
|
236 |
+
|
237 |
+
在训练中我们使用`swanlab`库作为实验管理与指标可视化工具。
|
238 |
+
[swanlab](https://github.com/SwanHubX/SwanLab)是一个类似Tensorboard的开源训练图表可视化库,有着更轻量的体积与更友好的API,除了能记录指标,还能自动记录训练的logging、硬件环境、Python环境、训练时间等信息。
|
239 |
+
|
240 |
+
![在这里插入图片描述](readme_files/6.gif)
|
241 |
+
|
242 |
+
### 设置初始化配置参数
|
243 |
+
|
244 |
+
swanlab库使用`swanlab.init`设置实验名、实验介绍和记录超参数。
|
245 |
+
|
246 |
+
```python
|
247 |
+
import swanlab
|
248 |
+
|
249 |
+
swanlab.init(
|
250 |
+
# 设置实验名
|
251 |
+
experiment_name="ResNet50",
|
252 |
+
# 设置实验介绍
|
253 |
+
description="Train ResNet50 for cat and dog classification.",
|
254 |
+
# 记录超参数
|
255 |
+
config={
|
256 |
+
"model": "resnet50",
|
257 |
+
"optim": "Adam",
|
258 |
+
"lr": lr,
|
259 |
+
"batch_size": batch_size,
|
260 |
+
"num_epochs": num_epochs,
|
261 |
+
"num_class": num_classes,
|
262 |
+
"device": device,
|
263 |
+
}
|
264 |
+
)
|
265 |
+
```
|
266 |
+
|
267 |
+
### 跟踪关键指标
|
268 |
+
|
269 |
+
swanlab库使用`swanlab.log`来记录关键指标,具体使用案例见2.7和2.8。
|
270 |
+
|
271 |
+
## 2.7 训练函数
|
272 |
+
|
273 |
+
我们定义1个训练函数`train`:
|
274 |
+
|
275 |
+
```python
|
276 |
+
def train(model, device, train_dataloader, optimizer, criterion, epoch):
|
277 |
+
model.train()
|
278 |
+
for iter, (inputs, labels) in enumerate(train_loader):
|
279 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
280 |
+
optimizer.zero_grad()
|
281 |
+
outputs = model(inputs)
|
282 |
+
loss = criterion(outputs, labels)
|
283 |
+
loss.backward()
|
284 |
+
optimizer.step()
|
285 |
+
print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, iter + 1, len(TrainDataLoader),
|
286 |
+
loss.item()))
|
287 |
+
swanlab.log({"train_loss": loss.item()})
|
288 |
+
```
|
289 |
+
|
290 |
+
训练的逻辑很简单:我们循环调用`train_dataloader`,每次取出1个batch\_size的图像和标签,传入到resnet50模型中得到预测结果,将结果和标签传入损失函数中计算交叉熵损失,最后根据损失计算反向传播,Adam优化器执行模型参数更新,循环往复。
|
291 |
+
|
292 |
+
在训练中我们最关心的指标是损失值`loss`,所以我们用`swanlab.log`跟踪它的变化。
|
293 |
+
|
294 |
+
## 2.8 测试函数
|
295 |
+
|
296 |
+
我们定义1个测试函数`test`:
|
297 |
+
|
298 |
+
```python
|
299 |
+
def test(model, device, test_dataloader, epoch):
|
300 |
+
model.eval()
|
301 |
+
correct = 0
|
302 |
+
total = 0
|
303 |
+
with torch.no_grad():
|
304 |
+
for inputs, labels in test_dataloader:
|
305 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
306 |
+
outputs = model(inputs)
|
307 |
+
_, predicted = torch.max(outputs.data, 1)
|
308 |
+
total += labels.size(0)
|
309 |
+
correct += (predicted == labels).sum().item()
|
310 |
+
accuracy = correct / total * 100
|
311 |
+
print('Accuracy: {:.2f}%'.format(accuracy))
|
312 |
+
swanlab.log({"test_acc": accuracy})
|
313 |
+
```
|
314 |
+
|
315 |
+
测试的逻辑同样很简单:我们循环调用`test_dataloader`,将测试集的图像传入到resnet50模型中得到预测结果,与标签进行对比,计算整体的准确率。
|
316 |
+
|
317 |
+
在测试中我们最关心的指标是准确率`accuracy`,所以我们用`swanlab.log`跟踪它的变化。
|
318 |
+
|
319 |
+
## 2.9 完整训练代码
|
320 |
+
|
321 |
+
我们一共训练`num_epochs`轮,每4轮进行测试,并在最后保存权重文件:
|
322 |
+
|
323 |
+
```python
|
324 |
+
for epoch in range(1, num_epochs + 1):
|
325 |
+
train(model, device, TrainDataLoader, optimizer, criterion, epoch)
|
326 |
+
if epoch % 4 == 0:
|
327 |
+
accuracy = test(model, device, ValDataLoader, epoch)
|
328 |
+
|
329 |
+
if not os.path.exists("checkpoint"):
|
330 |
+
os.makedirs("checkpoint")
|
331 |
+
torch.save(model.state_dict(), 'checkpoint/latest_checkpoint.pth')
|
332 |
+
print("Training complete")
|
333 |
+
```
|
334 |
+
|
335 |
+
组合后的完整`train.py`代码:
|
336 |
+
|
337 |
+
```python
|
338 |
+
import torch
|
339 |
+
import torchvision
|
340 |
+
from torchvision.models import ResNet50_Weights
|
341 |
+
import swanlab
|
342 |
+
from torch.utils.data import DataLoader
|
343 |
+
from load_datasets_simple import DatasetLoader
|
344 |
+
import os
|
345 |
+
|
346 |
+
|
347 |
+
# 定义训练函数
|
348 |
+
def train(model, device, train_dataloader, optimizer, criterion, epoch):
|
349 |
+
model.train()
|
350 |
+
for iter, (inputs, labels) in enumerate(train_dataloader):
|
351 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
352 |
+
optimizer.zero_grad()
|
353 |
+
outputs = model(inputs)
|
354 |
+
loss = criterion(outputs, labels)
|
355 |
+
loss.backward()
|
356 |
+
optimizer.step()
|
357 |
+
print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, iter + 1, len(TrainDataLoader),
|
358 |
+
loss.item()))
|
359 |
+
swanlab.log({"train_loss": loss.item()})
|
360 |
+
|
361 |
+
|
362 |
+
# 定义测试函数
|
363 |
+
def test(model, device, test_dataloader, epoch):
|
364 |
+
model.eval()
|
365 |
+
correct = 0
|
366 |
+
total = 0
|
367 |
+
with torch.no_grad():
|
368 |
+
for inputs, labels in test_dataloader:
|
369 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
370 |
+
outputs = model(inputs)
|
371 |
+
_, predicted = torch.max(outputs.data, 1)
|
372 |
+
total += labels.size(0)
|
373 |
+
correct += (predicted == labels).sum().item()
|
374 |
+
accuracy = correct / total * 100
|
375 |
+
print('Accuracy: {:.2f}%'.format(accuracy))
|
376 |
+
swanlab.log({"test_acc": accuracy})
|
377 |
+
|
378 |
+
|
379 |
+
if __name__ == "__main__":
|
380 |
+
num_epochs = 20
|
381 |
+
lr = 1e-4
|
382 |
+
batch_size = 8
|
383 |
+
num_classes = 2
|
384 |
+
|
385 |
+
# 设置device
|
386 |
+
try:
|
387 |
+
use_mps = torch.backends.mps.is_available()
|
388 |
+
except AttributeError:
|
389 |
+
use_mps = False
|
390 |
+
|
391 |
+
if torch.cuda.is_available():
|
392 |
+
device = "cuda"
|
393 |
+
elif use_mps:
|
394 |
+
device = "mps"
|
395 |
+
else:
|
396 |
+
device = "cpu"
|
397 |
+
|
398 |
+
# 初始化swanlab
|
399 |
+
swanlab.init(
|
400 |
+
experiment_name="ResNet50",
|
401 |
+
description="Train ResNet50 for cat and dog classification.",
|
402 |
+
config={
|
403 |
+
"model": "resnet50",
|
404 |
+
"optim": "Adam",
|
405 |
+
"lr": lr,
|
406 |
+
"batch_size": batch_size,
|
407 |
+
"num_epochs": num_epochs,
|
408 |
+
"num_class": num_classes,
|
409 |
+
"device": device,
|
410 |
+
}
|
411 |
+
)
|
412 |
+
|
413 |
+
TrainDataset = DatasetLoader("datasets/train.csv")
|
414 |
+
ValDataset = DatasetLoader("datasets/val.csv")
|
415 |
+
TrainDataLoader = DataLoader(TrainDataset, batch_size=batch_size, shuffle=True)
|
416 |
+
ValDataLoader = DataLoader(ValDataset, batch_size=batch_size, shuffle=False)
|
417 |
+
|
418 |
+
# 载入ResNet50模型
|
419 |
+
model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
|
420 |
+
|
421 |
+
# 将全连接层替换为2分类
|
422 |
+
in_features = model.fc.in_features
|
423 |
+
model.fc = torch.nn.Linear(in_features, num_classes)
|
424 |
+
|
425 |
+
model.to(torch.device(device))
|
426 |
+
criterion = torch.nn.CrossEntropyLoss()
|
427 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
428 |
+
|
429 |
+
# 开始训练
|
430 |
+
for epoch in range(1, num_epochs + 1):
|
431 |
+
train(model, device, TrainDataLoader, optimizer, criterion, epoch) # Train for one epoch
|
432 |
+
|
433 |
+
if epoch % 4 == 0: # Test every 4 epochs
|
434 |
+
accuracy = test(model, device, ValDataLoader, epoch)
|
435 |
+
|
436 |
+
# 保存权重文件
|
437 |
+
if not os.path.exists("checkpoint"):
|
438 |
+
os.makedirs("checkpoint")
|
439 |
+
torch.save(model.state_dict(), 'checkpoint/latest_checkpoint.pth')
|
440 |
+
print("Training complete")
|
441 |
+
```
|
442 |
+
|
443 |
+
|
444 |
+
|
445 |
+
## 2.10 开始训练!
|
446 |
+
|
447 |
+
运行`train.py`: ![在这里插入图片描述](readme_files/7.png)
|
448 |
+
|
449 |
+
此时我们打开终端,输入`swanlab watch`开启SwanLab实验看板:
|
450 |
+
|
451 |
+
![在这里插入图片描述](readme_files/8.png)
|
452 |
+
|
453 |
+
点开http:127.0.0.1:5092,将在浏览器中看到实验看板。
|
454 |
+
|
455 |
+
默认页面是Project DashBoard,包含了项目信息和一个对比实验表格:
|
456 |
+
![在这里插入图片描述](readme_files/9.png)
|
457 |
+
我们点开1个进行中的实验,会看到`train_loss`和`test_acc`整体的变化曲线:
|
458 |
+
![在这里插入图片描述](readme_files/10.png)
|
459 |
+
|
460 |
+
切换到**OverView标签页**,这里记录了实验的各种信息,包括**swanlab.init**中的参数、最终的实验指标、实验状态、训练时长、Git仓库链接、主机名、操作系统、Python版本、硬件配置等等。
|
461 |
+
|
462 |
+
**可以看到训练完成的模型在测试集上的准确率是100%。**
|
463 |
+
|
464 |
+
![在这里插入图片描述](readme_files/11.png)
|
465 |
+
至此我们完成了模型的训练和测试,得到了1个表现非常棒的猫狗分类模型,权重保存在了checkpoint目录下。
|
466 |
+
|
467 |
+
接下来,我们就基于训练好的权重,创建1个Demo网页吧~
|
468 |
+
|
469 |
+
# 3. Gradio演示程序
|
470 |
+
|
471 |
+
Gradio是一个开源的Python库,旨在帮助数据科学家、研究人员和从事机器学习领域的开发人员快速创建和共享用于机器学习模型的用户界面。
|
472 |
+
|
473 |
+
在这里我们使用Gradio来构建一个猫狗分类的Demo界面,编写`app.py`程序:
|
474 |
+
|
475 |
+
```python
|
476 |
+
import gradio as gr
|
477 |
+
import torch
|
478 |
+
import torchvision.transforms as transforms
|
479 |
+
import torch.nn.functional as F
|
480 |
+
import torchvision
|
481 |
+
|
482 |
+
# 加载与训练中使用的相同结构的模型
|
483 |
+
def load_model(checkpoint_path, num_classes):
|
484 |
+
# 加载预训练的ResNet50模型
|
485 |
+
try:
|
486 |
+
use_mps = torch.backends.mps.is_available()
|
487 |
+
except AttributeError:
|
488 |
+
use_mps = False
|
489 |
+
|
490 |
+
if torch.cuda.is_available():
|
491 |
+
device = "cuda"
|
492 |
+
elif use_mps:
|
493 |
+
device = "mps"
|
494 |
+
else:
|
495 |
+
device = "cpu"
|
496 |
+
|
497 |
+
model = torchvision.models.resnet50(weights=None)
|
498 |
+
in_features = model.fc.in_features
|
499 |
+
model.fc = torch.nn.Linear(in_features, num_classes)
|
500 |
+
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
|
501 |
+
model.eval() # Set model to evaluation mode
|
502 |
+
return model
|
503 |
+
|
504 |
+
# 加载图像并执行必要的转换的函数
|
505 |
+
def process_image(image, image_size):
|
506 |
+
# Define the same transforms as used during training
|
507 |
+
preprocessing = transforms.Compose([
|
508 |
+
transforms.Resize((image_size, image_size)),
|
509 |
+
transforms.ToTensor(),
|
510 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
511 |
+
])
|
512 |
+
image = preprocessing(image).unsqueeze(0)
|
513 |
+
return image
|
514 |
+
|
515 |
+
|
516 |
+
# 预测图像类别并返回概率的函数
|
517 |
+
def predict(image):
|
518 |
+
classes = {'0': 'cat', '1': 'dog'} # Update or extend this dictionary based on your actual classes
|
519 |
+
image = process_image(image, 256) # Using the image size from training
|
520 |
+
with torch.no_grad():
|
521 |
+
outputs = model(image)
|
522 |
+
probabilities = F.softmax(outputs, dim=1).squeeze() # Apply softmax to get probabilities
|
523 |
+
# Mapping class labels to probabilities
|
524 |
+
class_probabilities = {classes[str(i)]: float(prob) for i, prob in enumerate(probabilities)}
|
525 |
+
return class_probabilities
|
526 |
+
|
527 |
+
|
528 |
+
# 定义到您的模型权重的路径
|
529 |
+
checkpoint_path = 'checkpoint/lastest_checkpoint.pth'
|
530 |
+
num_classes = 2
|
531 |
+
model = load_model(checkpoint_path, num_classes)
|
532 |
+
|
533 |
+
# 定义Gradio Interface
|
534 |
+
iface = gr.Interface(
|
535 |
+
fn=predict,
|
536 |
+
inputs=gr.Image(type="pil"),
|
537 |
+
outputs=gr.Label(num_top_classes=num_classes),
|
538 |
+
title="Cat vs Dog Classifier",
|
539 |
+
)
|
540 |
+
|
541 |
+
if __name__ == "__main__":
|
542 |
+
iface.launch()
|
543 |
+
```
|
544 |
+
|
545 |
+
运行程序后,会出现以下输出:
|
546 |
+
|
547 |
+
![在这里插入图片描述](readme_files/12.png)
|
548 |
+
|
549 |
+
点开链接,出现猫狗分类的Demo网页:
|
550 |
+
|
551 |
+
![在这里插入图片描述](readme_files/13.png)
|
552 |
+
|
553 |
+
用猫和狗的图片试试:
|
554 |
+
|
555 |
+
![在这里插入图片描述](readme_files/14.png)
|
556 |
+
|
557 |
+
![在这里插入图片描述](readme_files/15.png)
|
558 |
+
|
559 |
+
效果很完美!
|
560 |
+
|
561 |
+
至此,我们完成了用PyTorch、SwanLab、Gradio三个开源工具训练1个猫狗分类模型的全部过程,更多想了解的可以参考相关链接或评论此文章。
|
562 |
+
|
563 |
+
如果有帮助,请Star吧~
|
564 |
+
|
565 |
+
# 4. 相关链接
|
566 |
+
|
567 |
+
* SwanLab:<https://github.com/SwanHubX/SwanLab>
|
568 |
+
* 猫狗分类代码:<https://github.com/xiaolin199912/Resnet50-cats_vs_dogs>
|
569 |
+
* 在线Demo:https://swanhub.co/ZeYiLin/Resnet50-cats_vs_dogs/demo
|
570 |
+
* 猫狗分类数据集(300张图像):<https://modelscope.cn/datasets/tany0699/cats_and_dogs/summary>
|
571 |
+
* 百度云下载:链接: <https://pan.baidu.com/s/1qYa13SxFM0AirzDyFMy0mQ> 提取码: 1ybm
|
572 |
+
* 猫狗分类数据集(10k张图像):<https://modelscope.cn/datasets/XCsunny/cat_vs_dog_class/summary>
|
app.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torchvision
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
# 加载与训练中使用的相同结构的模型
|
10 |
+
def load_model(checkpoint_path, num_classes):
|
11 |
+
# 加载预训练的ResNet50模型
|
12 |
+
try:
|
13 |
+
use_mps = torch.backends.mps.is_available()
|
14 |
+
except AttributeError:
|
15 |
+
use_mps = False
|
16 |
+
|
17 |
+
if torch.cuda.is_available():
|
18 |
+
device = "cuda"
|
19 |
+
elif use_mps:
|
20 |
+
device = "mps"
|
21 |
+
else:
|
22 |
+
device = "cpu"
|
23 |
+
|
24 |
+
model = torchvision.models.resnet50(weights=None)
|
25 |
+
in_features = model.fc.in_features
|
26 |
+
model.fc = torch.nn.Linear(in_features, num_classes)
|
27 |
+
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
|
28 |
+
model.eval() # Set model to evaluation mode
|
29 |
+
return model
|
30 |
+
|
31 |
+
|
32 |
+
# 加载图像并执行必要的转换的函数
|
33 |
+
def process_image(image, image_size):
|
34 |
+
# Define the same transforms as used during training
|
35 |
+
preprocessing = transforms.Compose([
|
36 |
+
transforms.Resize((image_size, image_size)),
|
37 |
+
transforms.ToTensor(),
|
38 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
39 |
+
])
|
40 |
+
image = preprocessing(image).unsqueeze(0)
|
41 |
+
return image
|
42 |
+
|
43 |
+
|
44 |
+
# 预测图像类别并返回概率的函数
|
45 |
+
def predict(image):
|
46 |
+
classes = {'0': 'cat', '1': 'dog'} # Update or extend this dictionary based on your actual classes
|
47 |
+
image = process_image(image, 256) # Using the image size from training
|
48 |
+
with torch.no_grad():
|
49 |
+
outputs = model(image)
|
50 |
+
probabilities = F.softmax(outputs, dim=1).squeeze() # Apply softmax to get probabilities
|
51 |
+
# Mapping class labels to probabilities
|
52 |
+
class_probabilities = {classes[str(i)]: float(prob) for i, prob in enumerate(probabilities)}
|
53 |
+
return class_probabilities
|
54 |
+
|
55 |
+
|
56 |
+
# 定义到您的模型权重的路径
|
57 |
+
checkpoint_path = 'checkpoint/latest_checkpoint.pth'
|
58 |
+
num_classes = 2
|
59 |
+
model = load_model(checkpoint_path, num_classes)
|
60 |
+
|
61 |
+
# 定义Gradio Interface
|
62 |
+
iface = gr.Interface(
|
63 |
+
fn=predict,
|
64 |
+
inputs=gr.Image(type="pil"),
|
65 |
+
outputs=gr.Label(num_top_classes=num_classes),
|
66 |
+
title="Cat vs Dog Classifier",
|
67 |
+
examples=["test_images/test_cat.jpg", "test_images/test_dog.jpg"]
|
68 |
+
)
|
69 |
+
|
70 |
+
if __name__ == "__main__":
|
71 |
+
iface.launch()
|
checkpoint/.gitkeep
ADDED
File without changes
|
checkpoint/latest_checkpoint.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1b5a1ebfec3b9a5d20bf36fd75dec3f82005bf0dc0dbffb4e6efb4ecc428e464
|
3 |
+
size 94364890
|
datasets/.gitkeep
ADDED
File without changes
|
load_datasets.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import os
|
3 |
+
from torchvision import transforms
|
4 |
+
from PIL import Image
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
|
7 |
+
|
8 |
+
class DatasetLoader(Dataset):
|
9 |
+
def __init__(self, csv_path):
|
10 |
+
self.csv_file = csv_path
|
11 |
+
with open(self.csv_file, 'r') as file:
|
12 |
+
self.data = list(csv.reader(file))
|
13 |
+
|
14 |
+
self.current_dir = os.path.dirname(os.path.abspath(__file__))
|
15 |
+
|
16 |
+
def preprocess_image(self, image_path):
|
17 |
+
"""
|
18 |
+
Preprocess the image: Read the image, apply transformations, and return the transformed image.
|
19 |
+
"""
|
20 |
+
full_path = os.path.join(self.current_dir, 'datasets', image_path)
|
21 |
+
image = Image.open(full_path)
|
22 |
+
image_transform = transforms.Compose([
|
23 |
+
transforms.Resize((256, 256)),
|
24 |
+
transforms.ToTensor(),
|
25 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
26 |
+
])
|
27 |
+
|
28 |
+
return image_transform(image)
|
29 |
+
|
30 |
+
def __getitem__(self, index):
|
31 |
+
"""
|
32 |
+
Return the preprocessed image and its label at the specified index from the dataset.
|
33 |
+
"""
|
34 |
+
image_path, label = self.data[index]
|
35 |
+
image = self.preprocess_image(image_path)
|
36 |
+
return image, int(label)
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
"""
|
40 |
+
Return the number of items in the dataset.
|
41 |
+
"""
|
42 |
+
return len(self.data)
|
readme_files/1.png
ADDED
readme_files/10.png
ADDED
readme_files/11.png
ADDED
readme_files/12.png
ADDED
readme_files/13.png
ADDED
readme_files/14.png
ADDED
readme_files/15.png
ADDED
readme_files/2.png
ADDED
readme_files/3.png
ADDED
readme_files/4.png
ADDED
readme_files/5.png
ADDED
readme_files/6.gif
ADDED
Git LFS Details
|
readme_files/7.png
ADDED
readme_files/8.png
ADDED
readme_files/9.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=1.12.0
|
2 |
+
torchvision>=0.13.0
|
3 |
+
swanlab>=0.1.2
|
4 |
+
gradio
|
test_images/test_cat.jpg
ADDED
test_images/test_dog.jpg
ADDED
train.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
from torchvision.models import ResNet50_Weights
|
4 |
+
import swanlab
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
from load_datasets import DatasetLoader
|
7 |
+
import os
|
8 |
+
|
9 |
+
|
10 |
+
# Define train function
|
11 |
+
def train(model, device, train_dataloader, optimizer, criterion, epoch):
|
12 |
+
model.train()
|
13 |
+
for iter, (inputs, labels) in enumerate(train_dataloader):
|
14 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
15 |
+
optimizer.zero_grad()
|
16 |
+
outputs = model(inputs)
|
17 |
+
loss = criterion(outputs, labels)
|
18 |
+
loss.backward()
|
19 |
+
optimizer.step()
|
20 |
+
print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, iter + 1, len(TrainDataLoader),
|
21 |
+
loss.item()))
|
22 |
+
swanlab.log({"train_loss": loss.item()})
|
23 |
+
|
24 |
+
|
25 |
+
# Define test function
|
26 |
+
def test(model, device, test_dataloader, epoch):
|
27 |
+
model.eval()
|
28 |
+
correct = 0
|
29 |
+
total = 0
|
30 |
+
with torch.no_grad():
|
31 |
+
for inputs, labels in test_dataloader:
|
32 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
33 |
+
outputs = model(inputs)
|
34 |
+
_, predicted = torch.max(outputs.data, 1)
|
35 |
+
total += labels.size(0)
|
36 |
+
correct += (predicted == labels).sum().item()
|
37 |
+
accuracy = correct / total * 100
|
38 |
+
print('Accuracy: {:.2f}%'.format(accuracy))
|
39 |
+
swanlab.log({"test_acc": accuracy})
|
40 |
+
|
41 |
+
|
42 |
+
if __name__ == "__main__":
|
43 |
+
num_epochs = 20
|
44 |
+
lr = 1e-4
|
45 |
+
batch_size = 16
|
46 |
+
num_classes = 2
|
47 |
+
|
48 |
+
try:
|
49 |
+
use_mps = torch.backends.mps.is_available()
|
50 |
+
except AttributeError:
|
51 |
+
use_mps = False
|
52 |
+
|
53 |
+
if torch.cuda.is_available():
|
54 |
+
device = "cuda"
|
55 |
+
elif use_mps:
|
56 |
+
device = "mps"
|
57 |
+
else:
|
58 |
+
device = "cpu"
|
59 |
+
|
60 |
+
# Initialize swanlab
|
61 |
+
swanlab.init(
|
62 |
+
experiment_name="ResNet50",
|
63 |
+
description="Train ResNet50 for cat and dog classification.",
|
64 |
+
config={
|
65 |
+
"model": "resnet50",
|
66 |
+
"optim": "Adam",
|
67 |
+
"lr": lr,
|
68 |
+
"batch_size": batch_size,
|
69 |
+
"num_epochs": num_epochs,
|
70 |
+
"num_class": num_classes,
|
71 |
+
"device": device,
|
72 |
+
}
|
73 |
+
)
|
74 |
+
|
75 |
+
TrainDataset = DatasetLoader("datasets/train.csv")
|
76 |
+
ValDataset = DatasetLoader("datasets/val.csv")
|
77 |
+
TrainDataLoader = DataLoader(TrainDataset, batch_size=batch_size, shuffle=True)
|
78 |
+
ValDataLoader = DataLoader(ValDataset, batch_size=batch_size, shuffle=False)
|
79 |
+
|
80 |
+
# Load the pre-trained ResNet50 model
|
81 |
+
model = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
|
82 |
+
|
83 |
+
# Replace the last fully connected layer.
|
84 |
+
in_features = model.fc.in_features
|
85 |
+
model.fc = torch.nn.Linear(in_features, num_classes)
|
86 |
+
|
87 |
+
# Train
|
88 |
+
model.to(torch.device(device))
|
89 |
+
criterion = torch.nn.CrossEntropyLoss()
|
90 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
91 |
+
|
92 |
+
|
93 |
+
for epoch in range(1, num_epochs + 1):
|
94 |
+
train(model, device, TrainDataLoader, optimizer, criterion, epoch) # Train for one epoch
|
95 |
+
|
96 |
+
if epoch % 4 == 0: # Test every 4 epochs
|
97 |
+
accuracy = test(model, device, ValDataLoader, epoch)
|
98 |
+
|
99 |
+
if not os.path.exists("checkpoint"):
|
100 |
+
os.makedirs("checkpoint")
|
101 |
+
torch.save(model.state_dict(), 'checkpoint/latest_checkpoint.pth')
|
102 |
+
print("Training complete")
|