glenn-jocher committed
Commit 394d1c8 (unverified)
1 Parent(s): ab0db8d

Input channel yaml['ch'] addition (#1741)

Files changed (2):
  1. models/yolo.py  +3 -2
  2. utils/torch_utils.py  +1 -1
models/yolo.py CHANGED
@@ -1,10 +1,10 @@
 import argparse
 import logging
+import math
 import sys
 from copy import deepcopy
 from pathlib import Path
 
-import math
 import torch
 import torch.nn as nn
 
@@ -78,10 +78,11 @@ class Model(nn.Module):
             self.yaml = yaml.load(f, Loader=yaml.FullLoader)  # model dict
 
         # Define model
+        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
         if nc and nc != self.yaml['nc']:
             logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
             self.yaml['nc'] = nc  # override yaml value
-        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist, ch_out
+        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
         self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
         # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
 
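For context, a minimal sketch of what this change enables. It assumes a hypothetical 'yolov5s-4ch.yaml' that copies the stock yolov5s.yaml and adds a top-level `ch: 4` key; the Model constructor signature (cfg, ch=3, nc=None) is taken from models/yolo.py at this commit.

# Sketch only: 'yolov5s-4ch.yaml' is a hypothetical config equal to yolov5s.yaml
# plus a top-level `ch: 4` entry.
import torch
from models.yolo import Model

model = Model(cfg='yolov5s-4ch.yaml')         # yaml['ch'] = 4 overrides the ch=3 default argument
print(model.yaml['ch'])                       # -> 4, written back into the model dict by this commit
x = torch.zeros(1, model.yaml['ch'], 64, 64)  # dummy 4-channel input (64 is divisible by the max stride 32)
y = model(x)                                  # runs because the first conv was built with 4 input channels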
utils/torch_utils.py CHANGED
@@ -196,7 +196,7 @@ def model_info(model, verbose=False, img_size=640):
     try:  # FLOPS
         from thop import profile
         stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
-        img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)  # input
+        img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device)  # input
         flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2  # stride FLOPS
         img_size = img_size if isinstance(img_size, list) else [img_size, img_size]  # expand if int/float
         fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride)  # 640x640 FLOPS
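On the utils/torch_utils.py side, model_info() now builds its dummy profiling input from model.yaml.get('ch', 3), so FLOPS reporting no longer assumes a 3-channel model. A hedged usage example, reusing the hypothetical 4-channel model from the sketch above:

from utils.torch_utils import model_info

# With a yaml that defines ch: 4, the thop profiling input is now (1, 4, 32, 32);
# models whose yaml has no 'ch' key still fall back to 3 channels.
model_info(model, verbose=False, img_size=640)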