glenn-jocher commited on
Commit
140d84c
1 Parent(s): ea34f84

comment updates

Browse files
Files changed (2) hide show
  1. train.py +2 -2
  2. utils/utils.py +2 -5
train.py CHANGED
@@ -152,13 +152,13 @@ def train(hyp):
152
  model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
153
 
154
  # Distributed training
155
- if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
156
  dist.init_process_group(backend='nccl', # distributed backend
157
  init_method='tcp://127.0.0.1:9999', # init method
158
  world_size=1, # number of nodes
159
  rank=0) # node rank
 
160
  model = torch.nn.parallel.DistributedDataParallel(model)
161
- # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html
162
 
163
  # Trainloader
164
  dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
 
152
  model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
153
 
154
  # Distributed training
155
+ if device.type != 'cpu' and torch.cuda.device_count() > 1 and dist.is_available():
156
  dist.init_process_group(backend='nccl', # distributed backend
157
  init_method='tcp://127.0.0.1:9999', # init method
158
  world_size=1, # number of nodes
159
  rank=0) # node rank
160
+ # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device) # requires world_size > 1
161
  model = torch.nn.parallel.DistributedDataParallel(model)
 
162
 
163
  # Trainloader
164
  dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
utils/utils.py CHANGED
@@ -503,6 +503,7 @@ def build_targets(p, targets, model):
503
  off = torch.tensor([[1, 0], [0, 1], [-1, 0], [0, -1]], device=targets.device).float() # overlap offsets
504
  at = torch.arange(na).view(na, 1).repeat(1, nt) # anchor tensor, same as .repeat_interleave(nt)
505
 
 
506
  style = 'rect4'
507
  for i in range(det.nl):
508
  anchors = det.anchors[i]
@@ -517,7 +518,6 @@ def build_targets(p, targets, model):
517
  a, t = at[j], t.repeat(na, 1, 1)[j] # filter
518
 
519
  # overlaps
520
- g = 0.5 # offset
521
  gxy = t[:, 2:4] # grid xy
522
  z = torch.zeros_like(gxy)
523
  if style == 'rect2':
@@ -878,10 +878,7 @@ def fitness(x):
878
 
879
 
880
  def output_to_target(output, width, height):
881
- """
882
- Convert a YOLO model output to target format
883
- [batch_id, class_id, x, y, w, h, conf]
884
- """
885
  if isinstance(output, torch.Tensor):
886
  output = output.cpu().numpy()
887
 
 
503
  off = torch.tensor([[1, 0], [0, 1], [-1, 0], [0, -1]], device=targets.device).float() # overlap offsets
504
  at = torch.arange(na).view(na, 1).repeat(1, nt) # anchor tensor, same as .repeat_interleave(nt)
505
 
506
+ g = 0.5 # offset
507
  style = 'rect4'
508
  for i in range(det.nl):
509
  anchors = det.anchors[i]
 
518
  a, t = at[j], t.repeat(na, 1, 1)[j] # filter
519
 
520
  # overlaps
 
521
  gxy = t[:, 2:4] # grid xy
522
  z = torch.zeros_like(gxy)
523
  if style == 'rect2':
 
878
 
879
 
880
  def output_to_target(output, width, height):
881
+ # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
 
 
 
882
  if isinstance(output, torch.Tensor):
883
  output = output.cpu().numpy()
884