glenn-jocher
commited on
Commit
•
2b64b45
1
Parent(s):
fb7fa5b
Fix TFDWConv() `c1 == c2` check (#7842)
Browse files- models/tf.py +2 -2
models/tf.py
CHANGED
@@ -88,10 +88,10 @@ class TFConv(keras.layers.Layer):
|
|
88 |
|
89 |
class TFDWConv(keras.layers.Layer):
|
90 |
# Depthwise convolution
|
91 |
-
def __init__(self, c1, c2, k=1, s=1, p=None,
|
92 |
# ch_in, ch_out, weights, kernel, stride, padding, groups
|
93 |
super().__init__()
|
94 |
-
assert
|
95 |
conv = keras.layers.DepthwiseConv2D(
|
96 |
kernel_size=k,
|
97 |
strides=s,
|
|
|
88 |
|
89 |
class TFDWConv(keras.layers.Layer):
|
90 |
# Depthwise convolution
|
91 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None):
|
92 |
# ch_in, ch_out, weights, kernel, stride, padding, groups
|
93 |
super().__init__()
|
94 |
+
assert c1 == c2, f'TFDWConv() input={c1} must equal output={c2} channels'
|
95 |
conv = keras.layers.DepthwiseConv2D(
|
96 |
kernel_size=k,
|
97 |
strides=s,
|