glenn-jocher
commited on
Input channel yaml['ch'] addition (#1741)
Browse files- models/yolo.py +3 -2
- utils/torch_utils.py +1 -1
models/yolo.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
import argparse
|
2 |
import logging
|
|
|
3 |
import sys
|
4 |
from copy import deepcopy
|
5 |
from pathlib import Path
|
6 |
|
7 |
-
import math
|
8 |
import torch
|
9 |
import torch.nn as nn
|
10 |
|
@@ -78,10 +78,11 @@ class Model(nn.Module):
|
|
78 |
self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
79 |
|
80 |
# Define model
|
|
|
81 |
if nc and nc != self.yaml['nc']:
|
82 |
logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
|
83 |
self.yaml['nc'] = nc # override yaml value
|
84 |
-
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
|
85 |
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
|
86 |
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
|
87 |
|
|
|
1 |
import argparse
|
2 |
import logging
|
3 |
+
import math
|
4 |
import sys
|
5 |
from copy import deepcopy
|
6 |
from pathlib import Path
|
7 |
|
|
|
8 |
import torch
|
9 |
import torch.nn as nn
|
10 |
|
|
|
78 |
self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
79 |
|
80 |
# Define model
|
81 |
+
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
82 |
if nc and nc != self.yaml['nc']:
|
83 |
logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc))
|
84 |
self.yaml['nc'] = nc # override yaml value
|
85 |
+
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
|
86 |
self.names = [str(i) for i in range(self.yaml['nc'])] # default names
|
87 |
# print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
|
88 |
|
utils/torch_utils.py
CHANGED
@@ -196,7 +196,7 @@ def model_info(model, verbose=False, img_size=640):
|
|
196 |
try: # FLOPS
|
197 |
from thop import profile
|
198 |
stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
|
199 |
-
img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device) # input
|
200 |
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS
|
201 |
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
|
202 |
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
|
|
|
196 |
try: # FLOPS
|
197 |
from thop import profile
|
198 |
stride = int(model.stride.max()) if hasattr(model, 'stride') else 32
|
199 |
+
img = torch.zeros((1, model.yaml.get('ch', 3), stride, stride), device=next(model.parameters()).device) # input
|
200 |
flops = profile(deepcopy(model), inputs=(img,), verbose=False)[0] / 1E9 * 2 # stride FLOPS
|
201 |
img_size = img_size if isinstance(img_size, list) else [img_size, img_size] # expand if int/float
|
202 |
fs = ', %.1f GFLOPS' % (flops * img_size[0] / stride * img_size[1] / stride) # 640x640 FLOPS
|