| import mxnet as mx | |
| import numpy as np | |
| from nowcasting.config import cfg | |
class ParamsReg(object):
    """Registry that caches mx.sym.Variable objects by name.

    Layer helpers in this module call get() so that layers sharing a name
    also share the exact same symbolic parameter.
    """

    def __init__(self):
        # name -> mx.sym.Variable currently in use
        self._params = {}
        # archive of superseded registries, one dict appended per reset()
        self._old_params = []

    def get(self, name, **kwargs):
        """Return the cached variable for `name`, creating it on first use.

        kwargs are forwarded to mx.sym.Variable (e.g. init, lr_mult) and
        only take effect the first time a given name is requested.
        """
        try:
            return self._params[name]
        except KeyError:
            var = mx.sym.Variable(name, dtype=np.float32, **kwargs)
            self._params[name] = var
            return var

    def get_inner(self):
        """Expose the live name -> variable mapping."""
        return self._params

    def reset(self):
        """Archive the current mapping and start from an empty one."""
        self._old_params.append(self._params)
        self._params = {}
# Module-level registry shared by every layer helper in this file.
_params = ParamsReg()


def reset_regs():
    """Archive the shared parameter registry and start a fresh one."""
    global _params
    _params.reset()
def activation(data, act_type, name=None):
    """Apply the requested activation to `data`.

    "leaky"    -> mx.sym.LeakyReLU with slope 0.2
    "identity" -> pass-through: `data` is returned unchanged
    otherwise  -> mx.sym.Activation of the given act_type

    When `name` is given, the operator is named '<name>_<act_type>'.
    """
    if act_type == "identity":
        return data
    op_name = None if name is None else '%s_%s' % (name, act_type)
    if act_type == "leaky":
        if op_name is None:
            return mx.sym.LeakyReLU(data=data, slope=0.2)
        return mx.sym.LeakyReLU(data=data, slope=0.2, name=op_name)
    if op_name is None:
        return mx.sym.Activation(data=data, act_type=act_type)
    return mx.sym.Activation(data=data, act_type=act_type, name=op_name)
def conv2d(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), dilate=(1, 1), no_bias=False,
           name=None, **kwargs):
    """2D convolution whose weight/bias symbols come from the shared registry.

    kwargs are forwarded to ParamsReg.get (i.e. to mx.sym.Variable), NOT to
    the Convolution operator itself. The bias variable is created with
    wd_mult=0.0 so it is excluded from weight decay.
    """
    assert name is not None
    global _params
    weight = _params.get('%s_weight' % name, **kwargs)
    common = dict(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                  pad=pad, dilate=dilate, weight=weight, name=name, workspace=256)
    if no_bias:
        return mx.sym.Convolution(no_bias=True, **common)
    bias = _params.get('%s_bias' % name, wd_mult=0.0, **kwargs)
    return mx.sym.Convolution(bias=bias, no_bias=False, **common)
def conv2d_bn_act(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), dilate=(1, 1),
                  no_bias=False, act_type="relu", momentum=0.9, eps=1e-5 + 1e-12, fix_gamma=True,
                  name=None, use_global_stats=False, **kwargs):
    """conv2d -> BatchNorm -> activation.

    BN parameters are registered under '<name>_bn_*'; kwargs go to
    ParamsReg.get for every created variable.
    """
    assert name is not None
    global _params
    conv = conv2d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                  pad=pad, dilate=dilate, no_bias=no_bias, name=name, **kwargs)
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    # The original if/else branches differed only in the fix_gamma literal.
    bn = mx.sym.BatchNorm(data=conv, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=bool(fix_gamma), momentum=momentum, eps=eps,
                          name='%s_bn' % name, use_global_stats=use_global_stats)
    return activation(bn, act_type=act_type, name=name)
def conv2d_act(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), dilate=(1, 1),
               no_bias=False, act_type="relu", name=None, **kwargs):
    """conv2d followed immediately by an activation (no batch norm)."""
    out = conv2d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                 pad=pad, dilate=dilate, no_bias=no_bias, name=name, **kwargs)
    return activation(out, act_type=act_type, name=name)
def deconv2d(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), adj=(0, 0), no_bias=True,
             target_shape=None, name="deconv2d", **kwargs):
    """2D transposed convolution with registry-backed weight/bias.

    kwargs go to ParamsReg.get / mx.sym.Variable, not to the operator.
    target_shape is only forwarded when explicitly given; the bias variable
    (if used) is created with wd_mult=0.0, matching conv2d.
    """
    global _params
    assert name is not None
    args = dict(data=data, num_filter=num_filter, kernel=kernel, adj=adj,
                stride=stride, pad=pad, name=name,
                weight=_params.get('%s_weight' % name, **kwargs))
    if target_shape is not None:
        args['target_shape'] = target_shape
    if no_bias:
        return mx.sym.Deconvolution(no_bias=True, **args)
    args['bias'] = _params.get('%s_bias' % name, wd_mult=0.0, **kwargs)
    return mx.sym.Deconvolution(no_bias=False, **args)
def deconv2d_bn_act(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), adj=(0, 0),
                    no_bias=True, target_shape=None, act_type="relu",
                    momentum=0.9, eps=1e-5 + 1e-12, fix_gamma=True,
                    name="deconv2d", use_global_stats=False, **kwargs):
    """deconv2d -> BatchNorm -> activation.

    BN parameters are registered under '<name>_bn_*'.
    """
    global _params
    deconv = deconv2d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                      pad=pad, adj=adj, target_shape=target_shape, no_bias=no_bias,
                      name=name, **kwargs)
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    # The original if/else branches differed only in the fix_gamma literal.
    bn = mx.sym.BatchNorm(data=deconv, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=bool(fix_gamma), momentum=momentum, eps=eps,
                          use_global_stats=use_global_stats, name='%s_bn' % name)
    return activation(bn, act_type=act_type, name=name)
def deconv2d_act(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), adj=(0, 0),
                 no_bias=True, target_shape=None, act_type="relu", name="deconv2d", **kwargs):
    """deconv2d followed immediately by an activation (no batch norm)."""
    out = deconv2d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                   pad=pad, adj=adj, target_shape=target_shape, no_bias=no_bias,
                   name=name, **kwargs)
    return activation(out, act_type=act_type, name=name)
def conv3d(data, num_filter, kernel=(1, 1, 1), stride=(1, 1, 1), pad=(0, 0, 0), dilate=(1, 1, 1), no_bias=False,
           name=None, **kwargs):
    """3D convolution.

    mx.sym.Convolution infers dimensionality from the 3-element kernel, so
    this simply delegates to conv2d with 3-tuples for the spatial arguments.
    """
    return conv2d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                  pad=pad, dilate=dilate, no_bias=no_bias, name=name, **kwargs)
def conv3d_bn_act(data, num_filter, height, width, kernel=(1, 1, 1), stride=(1, 1, 1), pad=(0, 0, 0),
                  dilate=(1, 1, 1), no_bias=False, act_type="relu", momentum=0.9, eps=1e-5 + 1e-12,
                  fix_gamma=True, name=None, use_global_stats=False, **kwargs):
    """conv3d -> BatchNorm -> activation.

    BatchNorm runs on a flattened 4D view, reshape(0, 0, -1, width), of the
    5D conv output and the result is reshaped back using the caller-supplied
    `height`/`width` of the conv output.
    """
    assert name is not None
    global _params
    conv = conv3d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                  pad=pad, dilate=dilate, no_bias=no_bias, name=name, **kwargs)
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    flat = mx.symbol.reshape(conv, shape=(0, 0, -1, width))
    # The original if/else branches differed only in the fix_gamma literal.
    bn = mx.sym.BatchNorm(data=flat, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=bool(fix_gamma), momentum=momentum, eps=eps,
                          use_global_stats=use_global_stats, name='%s_bn' % name)
    bn = mx.symbol.reshape(bn, shape=(0, 0, -1, height, width))
    return activation(bn, act_type=act_type, name=name)
def conv3d_act(data, num_filter, kernel=(1, 1, 1), stride=(1, 1, 1), pad=(0, 0, 0), dilate=(1, 1, 1),
               no_bias=False, act_type="relu", name=None, **kwargs):
    """conv3d followed immediately by an activation (no batch norm)."""
    out = conv3d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                 pad=pad, dilate=dilate, no_bias=no_bias, name=name, **kwargs)
    return activation(out, act_type=act_type, name=name)
def deconv3d(data, num_filter, kernel=(1, 1, 1), stride=(1, 1, 1), pad=(0, 0, 0), adj=(0, 0, 0), no_bias=True,
             target_shape=None, name=None, **kwargs):
    """3D transposed convolution.

    Dimensionality is inferred from the 3-element kernel, so this delegates
    straight to deconv2d.
    """
    return deconv2d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                    pad=pad, adj=adj, no_bias=no_bias, target_shape=target_shape,
                    name=name, **kwargs)
def deconv3d_bn_act(data, num_filter, height, width, kernel=(1, 1, 1), stride=(1, 1, 1), pad=(0, 0, 0),
                    adj=(0, 0, 0), no_bias=True, target_shape=None, act_type="relu",
                    momentum=0.9, eps=1e-5 + 1e-12, fix_gamma=True, name=None, use_global_stats=False, **kwargs):
    """deconv3d -> BatchNorm -> activation.

    Like conv3d_bn_act, BatchNorm is applied on a flattened 4D view of the
    5D output and then reshaped back with the caller-supplied height/width.
    """
    global _params
    deconv = deconv3d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                      pad=pad, adj=adj, target_shape=target_shape, no_bias=no_bias,
                      name=name, **kwargs)
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    flat = mx.symbol.reshape(deconv, shape=(0, 0, -1, width))
    # The original if/else branches differed only in the fix_gamma literal.
    bn = mx.sym.BatchNorm(data=flat, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=bool(fix_gamma), momentum=momentum, eps=eps,
                          use_global_stats=use_global_stats, name='%s_bn' % name)
    bn = mx.symbol.reshape(bn, shape=(0, 0, -1, height, width))
    return activation(bn, act_type=act_type, name=name)
def deconv3d_act(data, num_filter, kernel=(1, 1, 1), stride=(1, 1, 1), pad=(0, 0, 0), adj=(0, 0, 0),
                 no_bias=True, target_shape=None, act_type="relu", name=None, **kwargs):
    """deconv3d followed immediately by an activation (no batch norm)."""
    out = deconv3d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                   pad=pad, adj=adj, target_shape=target_shape, no_bias=no_bias,
                   name=name, **kwargs)
    return activation(out, act_type=act_type, name=name)
def fc_layer(data, num_hidden, no_bias=False, name="fc", **kwargs):
    """Fully-connected layer with registry-backed weight/bias.

    Fix: kwargs are Variable options (e.g. init, lr_mult) and are forwarded
    to ParamsReg.get only. The original also forwarded them to
    mx.sym.FullyConnected, which does not accept Variable-only options, so
    any caller using kwargs the way conv2d/deconv2d allow would crash; they
    are no longer passed to the operator.

    NOTE(review): unlike conv2d/deconv2d, the bias here is created without
    wd_mult=0.0 — confirm whether that asymmetry is intentional before
    changing it.
    """
    assert name is not None
    global _params
    weight = _params.get('%s_weight' % name, **kwargs)
    if no_bias:
        return mx.sym.FullyConnected(data=data, weight=weight,
                                     num_hidden=num_hidden, no_bias=True, name=name)
    bias = _params.get('%s_bias' % name, **kwargs)
    return mx.sym.FullyConnected(data=data, weight=weight, bias=bias,
                                 num_hidden=num_hidden, no_bias=False, name=name)
def fc_layer_act(data, num_hidden, no_bias=False, act_type="relu", name="fc", **kwargs):
    """Fully-connected layer followed immediately by an activation."""
    out = fc_layer(data=data, num_hidden=num_hidden, no_bias=no_bias, name=name, **kwargs)
    return activation(data=out, act_type=act_type, name=name)
def fc_layer_bn_act(data, num_hidden, no_bias=False, act_type="relu",
                    momentum=0.9, eps=1e-5 + 1e-12, fix_gamma=True, name=None,
                    use_global_stats=False, **kwargs):
    """fc_layer -> BatchNorm -> activation.

    BN parameters are registered under '<name>_bn_*'.
    """
    assert name is not None
    global _params
    fc = fc_layer(data=data, num_hidden=num_hidden, no_bias=no_bias, name=name, **kwargs)
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    # The original if/else branches differed only in the fix_gamma literal.
    bn = mx.sym.BatchNorm(data=fc, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=bool(fix_gamma), momentum=momentum, eps=eps,
                          name='%s_bn' % name, use_global_stats=use_global_stats)
    return activation(bn, act_type=act_type, name=name)
def downsample_module(data, num_filter, kernel, stride, pad, b_h_w, name, aggre_type=None):
    """Concatenate a list of inputs along the batch axis and downsample them
    with a strided convolution plus the configured CNN activation.

    `b_h_w` and `aggre_type` are accepted for interface compatibility but
    are not used by this implementation.
    """
    assert isinstance(data, list)
    stacked = mx.sym.concat(*data, dim=0)
    return conv2d_act(data=stacked, num_filter=num_filter, kernel=kernel,
                      stride=stride, pad=pad,
                      act_type=cfg.MODEL.CNN_ACT_TYPE, name=name + "_conv")
def upsample_module(data, num_filter, kernel, stride, pad, b_h_w, name, aggre_type=None):
    """Concatenate a list of inputs along the batch axis and upsample them
    with a strided deconvolution plus the configured CNN activation.

    `b_h_w` and `aggre_type` are accepted for interface compatibility but
    are not used by this implementation.
    """
    assert isinstance(data, list)
    stacked = mx.sym.concat(*data, dim=0)
    return deconv2d_act(data=stacked, num_filter=num_filter, kernel=kernel,
                        stride=stride, pad=pad,
                        act_type=cfg.MODEL.CNN_ACT_TYPE, name=name + "_deconv")