glenn-jocher
commited on
Merge pull request #245 from yxNONG/patch-2
Browse filesUnify the check point of single and multi GPU
- train.py +2 -2
- utils/torch_utils.py +13 -7
train.py
CHANGED
@@ -79,7 +79,6 @@ def train(hyp):
|
|
79 |
# Create model
|
80 |
model = Model(opt.cfg).to(device)
|
81 |
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
|
82 |
-
model.names = data_dict['names']
|
83 |
|
84 |
# Image sizes
|
85 |
gs = int(max(model.stride)) # grid size (max stride)
|
@@ -178,6 +177,7 @@ def train(hyp):
|
|
178 |
model.hyp = hyp # attach hyperparameters to model
|
179 |
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
|
180 |
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
|
|
181 |
|
182 |
# Class frequency
|
183 |
labels = np.concatenate(dataset.labels, 0)
|
@@ -294,7 +294,7 @@ def train(hyp):
|
|
294 |
batch_size=batch_size,
|
295 |
imgsz=imgsz_test,
|
296 |
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
|
297 |
-
model=ema.ema,
|
298 |
single_cls=opt.single_cls,
|
299 |
dataloader=testloader)
|
300 |
|
|
|
79 |
# Create model
|
80 |
model = Model(opt.cfg).to(device)
|
81 |
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
|
|
|
82 |
|
83 |
# Image sizes
|
84 |
gs = int(max(model.stride)) # grid size (max stride)
|
|
|
177 |
model.hyp = hyp # attach hyperparameters to model
|
178 |
model.gr = 1.0 # giou loss ratio (obj_loss = 1.0 or giou)
|
179 |
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
180 |
+
model.names = data_dict['names']
|
181 |
|
182 |
# Class frequency
|
183 |
labels = np.concatenate(dataset.labels, 0)
|
|
|
294 |
batch_size=batch_size,
|
295 |
imgsz=imgsz_test,
|
296 |
save_json=final_epoch and opt.data.endswith(os.sep + 'coco.yaml'),
|
297 |
+
model=ema.ema.module if hasattr(model, 'module') else ema.ema,
|
298 |
single_cls=opt.single_cls,
|
299 |
dataloader=testloader)
|
300 |
|
utils/torch_utils.py
CHANGED
@@ -54,6 +54,11 @@ def time_synchronized():
|
|
54 |
return time.time()
|
55 |
|
56 |
|
|
|
|
|
|
|
|
|
|
|
57 |
def initialize_weights(model):
|
58 |
for m in model.modules():
|
59 |
t = type(m)
|
@@ -111,8 +116,8 @@ def model_info(model, verbose=False):
|
|
111 |
|
112 |
try: # FLOPS
|
113 |
from thop import profile
|
114 |
-
|
115 |
-
fs = ', %.1f GFLOPS' % (
|
116 |
except:
|
117 |
fs = ''
|
118 |
|
@@ -185,7 +190,7 @@ class ModelEMA:
|
|
185 |
self.updates += 1
|
186 |
d = self.decay(self.updates)
|
187 |
with torch.no_grad():
|
188 |
-
if
|
189 |
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
|
190 |
else:
|
191 |
msd, esd = model.state_dict(), self.ema.state_dict()
|
@@ -196,7 +201,8 @@ class ModelEMA:
|
|
196 |
v += (1. - d) * msd[k].detach()
|
197 |
|
198 |
def update_attr(self, model):
|
199 |
-
#
|
200 |
-
|
201 |
-
|
202 |
-
|
|
|
|
54 |
return time.time()
|
55 |
|
56 |
|
57 |
+
def is_parallel(model):
|
58 |
+
# is model is parallel with DP or DDP
|
59 |
+
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
60 |
+
|
61 |
+
|
62 |
def initialize_weights(model):
|
63 |
for m in model.modules():
|
64 |
t = type(m)
|
|
|
116 |
|
117 |
try: # FLOPS
|
118 |
from thop import profile
|
119 |
+
flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, 64, 64),), verbose=False)[0] / 1E9 * 2
|
120 |
+
fs = ', %.1f GFLOPS' % (flops * 100) # 640x640 FLOPS
|
121 |
except:
|
122 |
fs = ''
|
123 |
|
|
|
190 |
self.updates += 1
|
191 |
d = self.decay(self.updates)
|
192 |
with torch.no_grad():
|
193 |
+
if is_parallel(model):
|
194 |
msd, esd = model.module.state_dict(), self.ema.module.state_dict()
|
195 |
else:
|
196 |
msd, esd = model.state_dict(), self.ema.state_dict()
|
|
|
201 |
v += (1. - d) * msd[k].detach()
|
202 |
|
203 |
def update_attr(self, model):
|
204 |
+
# Update class attributes
|
205 |
+
ema = self.ema.module if is_parallel(model) else self.ema
|
206 |
+
for k, v in model.__dict__.items():
|
207 |
+
if not k.startswith('_') and k != 'module':
|
208 |
+
setattr(ema, k, v)
|