Facepalm0 committed (verified)
Commit cdf5f1c · 1 Parent(s): 0efb5b8

Upload train.py with huggingface_hub

Files changed (1)
  1. train.py +277 -0
train.py ADDED
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
import numpy as np
import wandb
from PIL import Image
from models.resnet import resnet18, resnet34, resnet50
from models.openmax import OpenMax
# from models.metamax import MetaMax
from utils.data_stats import calculate_dataset_stats, load_dataset_stats
from utils.eval_utils import evaluate_known_classes, evaluate_openmax, evaluate_metamax
from pprint import pprint
import math

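# models.resnet, models.openmax, utils.data_stats, utils.eval_utils, and
# post_train are local modules of this repo. The resnet18/34/50 factories
# appear to be custom variants: unlike torchvision's, they accept a
# dropout_rate argument.
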
class GameDataset(Dataset):
    def __init__(self, data_dir, num_labels=20, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.image_paths = []

        if not os.path.exists(data_dir):
            raise ValueError(f"Data directory {data_dir} does not exist")

        # Walk the data directory and collect images with their labels
        for class_dir in range(num_labels):  # training set: classes 0-19; validation set: classes 0-20
            class_path = os.path.join(data_dir, f"{class_dir:02d}")
            if os.path.exists(class_path):
                for img_name in os.listdir(class_path):
                    if img_name.endswith('.png'):
                        img_path = os.path.join(class_path, img_name)
                        try:
                            # Read the PNG, keeping only the first three (RGB) channels
                            img = np.array(Image.open(img_path))[:, :, :3]
                            if img.shape != (50, 50, 3):
                                print(f"Skipping {img_path} due to invalid shape: {img.shape}")
                                continue

                            self.images.append(img)
                            self.labels.append(class_dir)
                            self.image_paths.append(img_path)
                        except Exception as e:
                            print(f"Error loading {img_path}: {e}")
                            continue

        self.images = np.array(self.images)
        self.labels = np.array(self.labels)
        print(f"Loaded {len(self.images)} images from {data_dir}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        path = self.image_paths[idx]

        if self.transform:
            image = self.transform(image)

        return image, label, path

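# The dataset is held fully in memory as uint8 HWC arrays, so ToTensor() in
# the transforms below performs the float conversion and the scaling to
# [0, 1]. __getitem__ also returns the image path, which the training loop
# ignores but downstream evaluation code can use to trace failures.
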
def train(num_epochs=20, batch_size=256, learning_rate=0.001, dropout_rate=0.3, patience=10, model_type='resnet34'):
    from post_train import collect_features
    os.makedirs('models', exist_ok=True)
    os.makedirs('wandb_logs', exist_ok=True)
    images_path = os.path.join('jk_zfls', 'round0_train')
    # Try to load precomputed dataset statistics; recompute them if missing
    try:
        mean, std = load_dataset_stats()
        print("Loaded pre-calculated dataset statistics")
    except FileNotFoundError:
        print("Stats file not found, calculating dataset statistics...")
        mean, std = calculate_dataset_stats(images_path)
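    # load_dataset_stats / calculate_dataset_stats are assumed to return
    # per-channel (mean, std) with values in [0, 1], which is what both
    # Normalize and the affine fill value below expect.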

    wandb.init(
        project="jk_zfls",
        name=f"{model_type}-training",
        config={
            "learning_rate": learning_rate,
            "batch_size": batch_size,
            "epochs": num_epochs,
            "model": model_type,
            "num_classes": 20
        },
        dir="./wandb_logs"
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fill value for the padding introduced by RandomAffine: the dataset mean.
    # ToTensor() runs first, so the fill must stay in [0, 1] rather than be
    # scaled to [0, 255].
    fill_value = [float(x) for x in mean]

    # Data augmentation for training
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomAffine(
            degrees=15,
            translate=(0.1, 0.1),
            scale=(0.9, 1.1),
            fill=fill_value  # pad with the dataset mean
        ),
        transforms.Normalize(mean=mean, std=std)
    ])
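    # Note on ordering: because ToTensor() precedes RandomAffine, the affine
    # transform operates on a float tensor in [0, 1] and the fill value must
    # be in the same range. A PIL-based pipeline (augmenting before ToTensor)
    # would instead need integer fill values in [0, 255].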

    # No augmentation for validation
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # Load the datasets
    train_dataset = GameDataset('jk_zfls/round0_train', num_labels=20, transform=transform)
    val_dataset = GameDataset('jk_zfls/round0_eval', num_labels=21, transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
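    # The eval split carries one extra class (label 20), used as the "unknown"
    # category for the open-set evaluation below; evaluate_known_classes is
    # assumed to score only the 20 known labels.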

    # Build the selected backbone
    if model_type == 'resnet18':
        model = resnet18(num_classes=20, dropout_rate=dropout_rate)
    elif model_type == 'resnet34':
        model = resnet34(num_classes=20, dropout_rate=dropout_rate)
    elif model_type == 'resnet50':
        model = resnet50(num_classes=20, dropout_rate=dropout_rate)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    # Optionally resume from an existing checkpoint
    # checkpoint = torch.load('models/best_model_99.75.pth')
    # model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    # Loss and optimizer, using a reduced base learning rate
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate * 0.1, weight_decay=1e-3)

    # optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    # The cosine schedule above is disabled in favor of warmup plus
    # plateau-based decay
    num_training_steps = len(train_loader) * num_epochs  # unused with the cosine schedule disabled
    num_warmup_steps = len(train_loader) * 5  # 5 epochs of warmup

    # Linear warmup scheduler and a ReduceLROnPlateau scheduler
    warmup_scheduler = optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.1,  # start at 0.1x the base learning rate
        end_factor=1.0,    # ramp up to the full learning rate
        total_iters=num_warmup_steps
    )

    reduce_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',    # maximize validation accuracy
        factor=0.5,
        patience=5,
        verbose=True,  # deprecated in newer PyTorch releases; prints LR changes
        min_lr=1e-6
    )
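    # Scheduler interplay: LinearLR is stepped once per batch during the
    # warmup window, while ReduceLROnPlateau is stepped once per epoch on
    # validation accuracy. Both act on the same optimizer, so after warmup
    # only the plateau scheduler adjusts the learning rate.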

    patience_counter = 0  # consecutive epochs without improvement
    best_params = {
        'epoch': None,
        'model_state_dict': None,
        'optimizer_state_dict': None,
        'loss': None,
        'best_val_acc': 0
    }
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_loss = 0

        for batch_idx, (images, labels, paths) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if batch_idx % 10 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

            # Step the warmup scheduler while still inside the warmup window
            if epoch * len(train_loader) + batch_idx < num_warmup_steps:
                warmup_scheduler.step()

        train_loss = total_loss / len(train_loader)

        # Validation phase (known classes only)
        val_loss, val_acc, val_errors = evaluate_known_classes(model, val_loader, criterion, device)

        # Log to wandb
        wandb.log({
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_accuracy': val_acc
        })

        print(f'Epoch {epoch}:')
        print(f'Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}, Val Accuracy = {val_acc:.2f}%')

        # Step ReduceLROnPlateau after validation
        reduce_scheduler.step(val_acc)

        # Print the current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        print(f'Current learning rate: {current_lr:.2e}')

        # Track the best model so far (by validation accuracy)
        if val_acc > best_params['best_val_acc']:
            patience_counter = 0  # reset the counter
            best_params.update({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': val_loss,
                'best_val_acc': val_acc
            })
        else:
            patience_counter += 1
            print(f'Validation accuracy did not improve. Patience: {patience_counter}/{patience}')

        # Early-stopping check
        if patience_counter >= patience:
            print(f"\nEarly stopping triggered! No improvement for {patience} consecutive epochs.")
            break

        if val_acc >= 100:
            print(f'Achieved 100% accuracy at epoch {epoch}')
            break

    # After training, save the best model's parameters
    print("Saving best model parameters...")
    torch.save(best_params, f'models/{model_type}_{best_params["best_val_acc"]:.2f}.pth')

    # Collect features using the best model
    print("Collecting features from best model for OpenMax/MetaMax training...")
    model.load_state_dict(best_params['model_state_dict'])
    model.eval()
    features, labels = collect_features(model, train_loader, device, return_logits=False)

    # Fit OpenMax (MetaMax is currently disabled)
    openmax = OpenMax(num_classes=20)
    openmax.fit(features, labels)
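    # OpenMax (Bendale & Boult, 2016) models the distance between each class's
    # training features and the class mean activation vector with a Weibull
    # fit, then recalibrates logits at test time so that probability mass can
    # be assigned to an "unknown" class. The models.openmax implementation is
    # assumed to follow that formulation.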

    # metamax = MetaMax(num_classes=20)
    # metamax.fit(features, labels)

    # Save the fitted model
    torch.save(openmax, 'models/openmax.pth')
    # torch.save(metamax, 'models/metamax.pth')
    print("OpenMax model saved")

    # Evaluate OpenMax on the validation split (including the unknown class)
    print("Evaluating OpenMax...")
    val_features, val_logits, val_labels = collect_features(model, val_loader, device, return_logits=True)

    overall_acc, known_acc, unknown_acc = evaluate_openmax(openmax, val_features, val_logits, val_labels, multiplier=0.5)
    print(f"Multiplier: 0.5, Overall Accuracy: {overall_acc:.2f}%")
    # evaluate_metamax(metamax, val_features, val_labels, device)
    wandb.finish()


if __name__ == '__main__':
    train(num_epochs=100, batch_size=64, learning_rate=0.001, dropout_rate=0.3, patience=20, model_type='resnet50')
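
# Usage sketch (assumed layout): jk_zfls/round0_train holds class folders
# 00-19 and jk_zfls/round0_eval holds folders 00-20, each containing 50x50
# PNG images. Then simply run:
#   python train.py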