WangA commited on
Commit
f7c1ebf
1 Parent(s): 14b2abc

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +9 -12
train.py CHANGED
@@ -96,7 +96,7 @@ class BaseDset(object):
96
  self.__test_keys = []
97
 
98
  def load(self, base_path):
99
- """加载数据集,将类别和路径存储"""
100
  self.__base_path = base_path
101
  train_dir = os.path.join(self.__base_path, 'train')
102
  test_dir = os.path.join(self.__base_path, 'test')
@@ -107,17 +107,14 @@ class BaseDset(object):
107
  self.__test_keys = []
108
 
109
  for class_id in os.listdir(train_dir):
110
- # 对于train_dir里的每个文件夹名字 classi
111
  class_dir = os.path.join(train_dir, class_id)
112
- # 为其在训练集合中创建一个文件夹
113
- # 在类别集合中,即train_keys中添加类别classi
114
  self.__train_set[class_id] = []
115
  self.__train_keys.append(class_id)
116
- # 对于每个类别内的数据,将其路径添加到集合中
117
  for img_name in os.listdir(class_dir):
118
  img_path = os.path.join(class_dir, img_name)
119
  self.__train_set[class_id].append(img_path)
120
- # 同理对于测试集合也一样
121
  for class_id in os.listdir(test_dir):
122
  class_dir = os.path.join(test_dir, class_id)
123
  self.__test_set[class_id] = []
@@ -128,7 +125,7 @@ class BaseDset(object):
128
 
129
  return len(self.__train_keys), len(self.__test_keys)
130
 
131
- # 获取三元组 !!!
132
  def getTriplet(self, split='train'):
133
  # 默认选取训练集
134
  if split == 'train':
@@ -167,11 +164,11 @@ def train(data, model, criterion, optimizer, epoch):
167
  total_loss = 0
168
  model.train()
169
  for batch_idx, img_triplet in enumerate(data):
170
- # 提取数据
171
  anchor_img, pos_img, neg_img = img_triplet
172
  anchor_img, pos_img, neg_img = anchor_img.to(device), pos_img.to(device), neg_img.to(device)
173
  anchor_img, pos_img, neg_img = Variable(anchor_img), Variable(pos_img), Variable(neg_img)
174
- # 分别获得三个编码
175
  E1, E2, E3 = model(anchor_img, pos_img, neg_img)
176
  # 计算二者之间的欧式距离
177
  dist_E1_E2 = F.pairwise_distance(E1, E2, 2)
@@ -180,14 +177,14 @@ def train(data, model, criterion, optimizer, epoch):
180
  target = torch.FloatTensor(dist_E1_E2.size()).fill_(-1)
181
  target = target.to(device)
182
  target = Variable(target)
183
- # 大小如何?
184
  loss = criterion(dist_E1_E2, dist_E1_E3, target)
185
  total_loss += loss
186
 
187
  optimizer.zero_grad()
188
  loss.backward()
189
  optimizer.step()
190
- # 打印一波损失
191
  log_step = args.train_log_step
192
  if (batch_idx % log_step == 0) and (batch_idx != 0):
193
  print('Train Epoch: {} [{}/{}] \t Loss: {:.4f}'.format(epoch, batch_idx, len(data), total_loss / log_step))
@@ -225,7 +222,7 @@ def test(data, model, criterion):
225
  accuracies[i] += batch_acc
226
  print('Test Loss: {}'.format(total_loss / len(data)))
227
  for i in range(len(accuracies)):
228
- # 0%等价于准确率其余是更严格的指标
229
  print(
230
  'Test Accuracy with diff = {}% of margin: {:.4f}'.format(acc_threshes[i] * 100,
231
  accuracies[i] / len(data)))
 
96
  self.__test_keys = []
97
 
98
  def load(self, base_path):
99
+ """加载训练和测试数据集,将类别和路径存储"""
100
  self.__base_path = base_path
101
  train_dir = os.path.join(self.__base_path, 'train')
102
  test_dir = os.path.join(self.__base_path, 'test')
 
107
  self.__test_keys = []
108
 
109
  for class_id in os.listdir(train_dir):
 
110
  class_dir = os.path.join(train_dir, class_id)
 
 
111
  self.__train_set[class_id] = []
112
  self.__train_keys.append(class_id)
113
+
114
  for img_name in os.listdir(class_dir):
115
  img_path = os.path.join(class_dir, img_name)
116
  self.__train_set[class_id].append(img_path)
117
+
118
  for class_id in os.listdir(test_dir):
119
  class_dir = os.path.join(test_dir, class_id)
120
  self.__test_set[class_id] = []
 
125
 
126
  return len(self.__train_keys), len(self.__test_keys)
127
 
128
+ # 获取三元组
129
  def getTriplet(self, split='train'):
130
  # 默认选取训练集
131
  if split == 'train':
 
164
  total_loss = 0
165
  model.train()
166
  for batch_idx, img_triplet in enumerate(data):
167
+ # 提取三元组数据
168
  anchor_img, pos_img, neg_img = img_triplet
169
  anchor_img, pos_img, neg_img = anchor_img.to(device), pos_img.to(device), neg_img.to(device)
170
  anchor_img, pos_img, neg_img = Variable(anchor_img), Variable(pos_img), Variable(neg_img)
171
+ # 分别获得三个编码,表示原始样本、正样本、负样本
172
  E1, E2, E3 = model(anchor_img, pos_img, neg_img)
173
  # 计算二者之间的欧式距离
174
  dist_E1_E2 = F.pairwise_distance(E1, E2, 2)
 
177
  target = torch.FloatTensor(dist_E1_E2.size()).fill_(-1)
178
  target = target.to(device)
179
  target = Variable(target)
180
+
181
  loss = criterion(dist_E1_E2, dist_E1_E3, target)
182
  total_loss += loss
183
 
184
  optimizer.zero_grad()
185
  loss.backward()
186
  optimizer.step()
187
+ # 打印损失
188
  log_step = args.train_log_step
189
  if (batch_idx % log_step == 0) and (batch_idx != 0):
190
  print('Train Epoch: {} [{}/{}] \t Loss: {:.4f}'.format(epoch, batch_idx, len(data), total_loss / log_step))
 
222
  accuracies[i] += batch_acc
223
  print('Test Loss: {}'.format(total_loss / len(data)))
224
  for i in range(len(accuracies)):
225
+ # 0%等价于准确率,其余是更严格的指标
226
  print(
227
  'Test Accuracy with diff = {}% of margin: {:.4f}'.format(acc_threshes[i] * 100,
228
  accuracies[i] / len(data)))