glenn-jocher
commited on
Commit
•
5bee686
1
Parent(s):
715cb08
model definition update
Browse files- models/yolo.py +4 -2
models/yolo.py
CHANGED
@@ -45,13 +45,15 @@ class Detect(nn.Module):
|
|
45 |
|
46 |
|
47 |
class Model(nn.Module):
|
48 |
-
def __init__(self, model_yaml='yolov5s.yaml'): #
|
49 |
super(Model, self).__init__()
|
50 |
with open(model_yaml) as f:
|
51 |
self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
|
|
|
|
52 |
|
53 |
# Define model
|
54 |
-
self.model, self.save, ch = parse_model(self.md, ch=[
|
55 |
# print([x.shape for x in self.forward(torch.zeros(1, 3, 64, 64))])
|
56 |
|
57 |
# Build strides, anchors
|
|
|
45 |
|
46 |
|
47 |
class Model(nn.Module):
|
48 |
+
def __init__(self, model_yaml='yolov5s.yaml', ch=3, nc=None): # model, input channels, number of classes
|
49 |
super(Model, self).__init__()
|
50 |
with open(model_yaml) as f:
|
51 |
self.md = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
52 |
+
if nc:
|
53 |
+
self.md['nc'] = nc # override yaml value
|
54 |
|
55 |
# Define model
|
56 |
+
self.model, self.save, ch = parse_model(self.md, ch=[ch]) # model, savelist, ch_out
|
57 |
# print([x.shape for x in self.forward(torch.zeros(1, 3, 64, 64))])
|
58 |
|
59 |
# Build strides, anchors
|