glenn-jocher commited on
Commit
5bee686
1 Parent(s): 715cb08

model definition update

Browse files
Files changed (1) hide show
  1. models/yolo.py +4 -2
models/yolo.py CHANGED
@@ -45,13 +45,15 @@ class Detect(nn.Module):
45
 
46
 
47
  class Model(nn.Module):
48
- def __init__(self, model_yaml='yolov5s.yaml'): # cfg, number of classes, depth-width gains
49
  super(Model, self).__init__()
50
  with open(model_yaml) as f:
51
  self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
 
 
52
 
53
  # Define model
54
- self.model, self.save, ch = parse_model(self.md, ch=[3]) # model, savelist, ch_out
55
  # print([x.shape for x in self.forward(torch.zeros(1, 3, 64, 64))])
56
 
57
  # Build strides, anchors
 
45
 
46
 
47
  class Model(nn.Module):
48
+ def __init__(self, model_yaml='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
49
  super(Model, self).__init__()
50
  with open(model_yaml) as f:
51
  self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
52
+ if nc:
53
+ self.md['nc'] = nc # override yaml value
54
 
55
  # Define model
56
+ self.model, self.save, ch = parse_model(self.md, ch=[ch]) # model, savelist, ch_out
57
  # print([x.shape for x in self.forward(torch.zeros(1, 3, 64, 64))])
58
 
59
  # Build strides, anchors