glenn-jocher commited on
Commit
2b64b45
1 Parent(s): fb7fa5b

Fix TFDWConv() `c1 == c2` check (#7842)

Browse files
Files changed (1) hide show
  1. models/tf.py +2 -2
models/tf.py CHANGED
@@ -88,10 +88,10 @@ class TFConv(keras.layers.Layer):
88
 
89
  class TFDWConv(keras.layers.Layer):
90
  # Depthwise convolution
91
- def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
92
  # ch_in, ch_out, weights, kernel, stride, padding, groups
93
  super().__init__()
94
- assert g == c1 == c2, f'TFDWConv() groups={g} must equal input={c1} and output={c2} channels'
95
  conv = keras.layers.DepthwiseConv2D(
96
  kernel_size=k,
97
  strides=s,
 
88
 
89
  class TFDWConv(keras.layers.Layer):
90
  # Depthwise convolution
91
+ def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None):
92
  # ch_in, ch_out, weights, kernel, stride, padding, groups
93
  super().__init__()
94
+ assert c1 == c2, f'TFDWConv() input={c1} must equal output={c2} channels'
95
  conv = keras.layers.DepthwiseConv2D(
96
  kernel_size=k,
97
  strides=s,