# NOTE: the following header was a file-hosting upload-page artifact, not code:
# sqfoo's picture / Upload 99 files / 6021dd1 verified
import mxnet as mx
import numpy as np
from nowcasting.config import cfg
class ParamsReg(object):
    """Registry that caches ``mx.sym.Variable`` objects by name.

    Layers built through the helpers in this module look their parameters up
    here, so constructing the same layer name twice yields shared variables
    (weight sharing across unrolled graphs).
    """

    def __init__(self):
        # name -> mx.sym.Variable for the graph currently being built.
        self._params = {}
        # Every mapping archived by reset(), oldest first.
        self._old_params = []

    def get(self, name, **kwargs):
        """Return the cached Variable for ``name``, creating a float32 one on first use."""
        try:
            return self._params[name]
        except KeyError:
            var = mx.sym.Variable(name, dtype=np.float32, **kwargs)
            self._params[name] = var
            return var

    def get_inner(self):
        """Expose the live name -> Variable mapping."""
        return self._params

    def reset(self):
        """Archive the current mapping (kept in ``_old_params``) and start an empty one."""
        self._old_params.append(self._params)
        self._params = {}
# Module-level registry shared by every layer helper below.
_params = ParamsReg()


def reset_regs():
    """Archive the current parameter registry and begin a fresh, empty one."""
    _params.reset()
def activation(data, act_type, name=None):
    """Apply an activation to ``data``.

    "leaky"    -> LeakyReLU with slope 0.2;
    "identity" -> ``data`` unchanged;
    anything else is forwarded to ``mx.sym.Activation`` (e.g. "relu", "tanh").
    When ``name`` is given, the symbol is named '<name>_<act_type>'.
    """
    if act_type == "identity":
        return data
    if act_type == "leaky":
        if name is None:
            return mx.sym.LeakyReLU(data=data, slope=0.2)
        return mx.sym.LeakyReLU(data=data, slope=0.2, name='%s_%s' % (name, act_type))
    if name is None:
        return mx.sym.Activation(data=data, act_type=act_type)
    return mx.sym.Activation(data=data, act_type=act_type, name='%s_%s' % (name, act_type))
def conv2d(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), dilate=(1, 1), no_bias=False,
           name=None, **kwargs):
    """2-D convolution whose weight (and optional bias) come from the shared registry.

    ``kwargs`` are forwarded only to the Variable lookups (e.g. initializer
    attributes).  The bias Variable is created with wd_mult=0.0, excluding it
    from weight decay.  ``name`` is mandatory since it keys the registry.
    """
    assert name is not None
    common = dict(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                  dilate=dilate, pad=pad, name=name, workspace=256,
                  weight=_params.get('%s_weight' % name, **kwargs))
    if no_bias:
        return mx.sym.Convolution(no_bias=True, **common)
    bias_var = _params.get('%s_bias' % name, wd_mult=0.0, **kwargs)
    return mx.sym.Convolution(no_bias=no_bias, bias=bias_var, **common)
def conv2d_bn_act(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), dilate=(1, 1),
                  no_bias=False, act_type="relu", momentum=0.9, eps=1e-5 + 1e-12, fix_gamma=True,
                  name=None, use_global_stats=False, **kwargs):
    """conv2d -> BatchNorm -> activation.

    BatchNorm parameters (gamma/beta/moving stats) are pulled from the shared
    registry under '<name>_bn_*', so rebuilding the same layer name reuses the
    same variables.  ``kwargs`` are forwarded to conv2d and to the Variable
    lookups.  ``name`` is mandatory (it keys the registry).
    """
    conv = conv2d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                  pad=pad, dilate=dilate, no_bias=no_bias, name=name, **kwargs)
    assert name is not None
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    # The original duplicated the whole BatchNorm call for fix_gamma True/False;
    # the flag can simply be passed through.
    bn = mx.sym.BatchNorm(data=conv, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=fix_gamma, momentum=momentum, eps=eps,
                          name='%s_bn' % name, use_global_stats=use_global_stats)
    act = activation(bn, act_type=act_type, name=name)
    return act
def conv2d_act(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), dilate=(1, 1),
               no_bias=False, act_type="relu", name=None, **kwargs):
    """conv2d followed directly by an activation (no batch norm)."""
    feat = conv2d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                  pad=pad, dilate=dilate, no_bias=no_bias, name=name, **kwargs)
    return activation(feat, act_type=act_type, name=name)
def deconv2d(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), adj=(0, 0), no_bias=True,
             target_shape=None, name="deconv2d", **kwargs):
    """2-D transposed convolution with registry-managed weight (and optional bias).

    ``target_shape`` is only forwarded when given.  ``kwargs`` go to the
    Variable lookups; the bias is created with wd_mult=0.0.
    NOTE(review): the default name "deconv2d" would make two no-name calls
    share one registered weight — callers appear to always pass an explicit
    name; confirm before relying on the default.
    """
    assert name is not None
    common = dict(data=data, num_filter=num_filter, kernel=kernel, adj=adj,
                  stride=stride, pad=pad, name=name,
                  weight=_params.get('%s_weight' % name, **kwargs))
    if target_shape is not None:
        common['target_shape'] = target_shape
    if no_bias:
        return mx.sym.Deconvolution(no_bias=True, **common)
    bias_var = _params.get('%s_bias' % name, wd_mult=0.0, **kwargs)
    return mx.sym.Deconvolution(no_bias=no_bias, bias=bias_var, **common)
def deconv2d_bn_act(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), adj=(0, 0),
                    no_bias=True, target_shape=None, act_type="relu",
                    momentum=0.9, eps=1e-5 + 1e-12, fix_gamma=True,
                    name="deconv2d", use_global_stats=False, **kwargs):
    """deconv2d -> BatchNorm -> activation.

    BatchNorm parameters come from the shared registry under '<name>_bn_*'.
    ``kwargs`` are forwarded to deconv2d and the Variable lookups.
    """
    deconv = deconv2d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                      pad=pad, adj=adj, target_shape=target_shape, no_bias=no_bias, name=name, **kwargs)
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    # Collapsed the duplicated fix_gamma True/False branches into one call.
    bn = mx.sym.BatchNorm(data=deconv, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=fix_gamma, momentum=momentum, eps=eps,
                          use_global_stats=use_global_stats, name='%s_bn' % name)
    act = activation(bn, act_type=act_type, name=name)
    return act
def deconv2d_act(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), adj=(0, 0),
                 no_bias=True, target_shape=None, act_type="relu", name="deconv2d", **kwargs):
    """deconv2d followed directly by an activation (no batch norm)."""
    feat = deconv2d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                    pad=pad, adj=adj, target_shape=target_shape, no_bias=no_bias, name=name, **kwargs)
    return activation(feat, act_type=act_type, name=name)
def conv3d(data, num_filter, kernel=(1, 1, 1), stride=(1, 1, 1), pad=(0, 0, 0), dilate=(1, 1, 1), no_bias=False,
           name=None, **kwargs):
    """3-D convolution: identical registry plumbing to conv2d, with 3-tuple geometry."""
    return conv2d(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad,
                  dilate=dilate, no_bias=no_bias, name=name, **kwargs)
def conv3d_bn_act(data, num_filter, height, width, kernel=(1, 1, 1), stride=(1, 1, 1), pad=(0, 0, 0),
                  dilate=(1, 1, 1), no_bias=False, act_type="relu", momentum=0.9, eps=1e-5 + 1e-12,
                  fix_gamma=True, name=None, use_global_stats=False, **kwargs):
    """conv3d -> BatchNorm (applied on a 4-D view) -> activation.

    The 5-D conv output is reshaped to (N, C, -1, width) before BatchNorm and
    restored to (N, C, -1, height, width) afterwards — presumably to run the
    4-D BatchNorm over the flattened depth*height axis (TODO confirm intent).
    BatchNorm parameters come from the registry under '<name>_bn_*'.
    """
    conv = conv3d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                  pad=pad, dilate=dilate, no_bias=no_bias, name=name, **kwargs)
    assert name is not None
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    conv = mx.symbol.reshape(conv, shape=(0, 0, -1, width))
    # Collapsed the duplicated fix_gamma True/False branches into one call.
    bn = mx.sym.BatchNorm(data=conv, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=fix_gamma, momentum=momentum, eps=eps,
                          use_global_stats=use_global_stats, name='%s_bn' % name)
    bn = mx.symbol.reshape(bn, shape=(0, 0, -1, height, width))
    act = activation(bn, act_type=act_type, name=name)
    return act
def conv3d_act(data, num_filter, kernel=(1, 1, 1), stride=(1, 1, 1), pad=(0, 0, 0), dilate=(1, 1, 1),
               no_bias=False, act_type="relu", name=None, **kwargs):
    """conv3d followed directly by an activation (no batch norm)."""
    feat = conv3d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                  pad=pad, dilate=dilate, no_bias=no_bias, name=name, **kwargs)
    return activation(feat, act_type=act_type, name=name)
def deconv3d(data, num_filter, kernel=(1, 1, 1), stride=(1, 1, 1), pad=(0, 0, 0), adj=(0, 0, 0), no_bias=True,
             target_shape=None, name=None, **kwargs):
    """3-D transposed convolution: delegates to deconv2d with 3-tuple geometry."""
    return deconv2d(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad,
                    adj=adj, target_shape=target_shape, no_bias=no_bias, name=name, **kwargs)
def deconv3d_bn_act(data, num_filter, height, width, kernel=(1, 1, 1), stride=(1, 1, 1), pad=(0, 0, 0),
                    adj=(0, 0, 0), no_bias=True, target_shape=None, act_type="relu",
                    momentum=0.9, eps=1e-5 + 1e-12, fix_gamma=True, name=None, use_global_stats=False, **kwargs):
    """deconv3d -> BatchNorm (applied on a 4-D view) -> activation.

    The 5-D deconv output is reshaped to (N, C, -1, width) before BatchNorm
    and restored to (N, C, -1, height, width) afterwards — same 4-D BatchNorm
    trick as conv3d_bn_act (TODO confirm intent).  BatchNorm parameters come
    from the registry under '<name>_bn_*'.
    """
    deconv = deconv3d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                      pad=pad, adj=adj, target_shape=target_shape, no_bias=no_bias, name=name, **kwargs)
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    deconv = mx.symbol.reshape(deconv, shape=(0, 0, -1, width))
    # Collapsed the duplicated fix_gamma True/False branches into one call.
    bn = mx.sym.BatchNorm(data=deconv, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=fix_gamma, momentum=momentum, eps=eps,
                          use_global_stats=use_global_stats, name='%s_bn' % name)
    bn = mx.symbol.reshape(bn, shape=(0, 0, -1, height, width))
    act = activation(bn, act_type=act_type, name=name)
    return act
def deconv3d_act(data, num_filter, kernel=(1, 1, 1), stride=(1, 1, 1), pad=(0, 0, 0), adj=(0, 0, 0),
                 no_bias=True, target_shape=None, act_type="relu", name=None, **kwargs):
    """deconv3d followed directly by an activation (no batch norm)."""
    feat = deconv3d(data=data, num_filter=num_filter, kernel=kernel, stride=stride,
                    pad=pad, adj=adj, target_shape=target_shape, no_bias=no_bias, name=name, **kwargs)
    return activation(feat, act_type=act_type, name=name)
def fc_layer(data, num_hidden, no_bias=False, name="fc", **kwargs):
    """Fully-connected layer whose weight (and optional bias) come from the registry.

    Fix: ``kwargs`` are intended for the Variable lookups (as in conv2d and
    deconv2d, which forward them only to ``_params.get``) and are no longer
    forwarded to ``mx.sym.FullyConnected``, where an initializer argument such
    as ``init=...`` does not belong.
    NOTE(review): unlike conv2d, the bias here is created without wd_mult=0.0,
    so it still receives weight decay — confirm whether that is intended.
    """
    assert name is not None
    weight = _params.get('%s_weight' % name, **kwargs)
    if no_bias:
        return mx.sym.FullyConnected(data=data, weight=weight,
                                     num_hidden=num_hidden, no_bias=True, name=name)
    bias = _params.get('%s_bias' % name, **kwargs)
    return mx.sym.FullyConnected(data=data, weight=weight, bias=bias,
                                 num_hidden=num_hidden, no_bias=False, name=name)
def fc_layer_act(data, num_hidden, no_bias=False, act_type="relu", name="fc", **kwargs):
    """fc_layer followed directly by an activation (no batch norm)."""
    hidden = fc_layer(data=data, num_hidden=num_hidden, no_bias=no_bias, name=name, **kwargs)
    return activation(data=hidden, act_type=act_type, name=name)
def fc_layer_bn_act(data, num_hidden, no_bias=False, act_type="relu",
                    momentum=0.9, eps=1e-5 + 1e-12, fix_gamma=True, name=None,
                    use_global_stats=False, **kwargs):
    """fc_layer -> BatchNorm -> activation.

    BatchNorm parameters come from the shared registry under '<name>_bn_*'.
    ``name`` is mandatory (it keys the registry).
    """
    fc = fc_layer(data=data, num_hidden=num_hidden, no_bias=no_bias, name=name, **kwargs)
    assert name is not None
    gamma = _params.get('%s_bn_gamma' % name, **kwargs)
    beta = _params.get('%s_bn_beta' % name, **kwargs)
    moving_mean = _params.get('%s_bn_moving_mean' % name, **kwargs)
    moving_var = _params.get('%s_bn_moving_var' % name, **kwargs)
    # Collapsed the duplicated fix_gamma True/False branches into one call.
    bn = mx.sym.BatchNorm(data=fc, beta=beta, gamma=gamma,
                          moving_mean=moving_mean, moving_var=moving_var,
                          fix_gamma=fix_gamma, momentum=momentum, eps=eps,
                          name='%s_bn' % name, use_global_stats=use_global_stats)
    act = activation(bn, act_type=act_type, name=name)
    return act
def downsample_module(data, num_filter, kernel, stride, pad, b_h_w, name, aggre_type=None):
    """Concatenate a list of symbols along axis 0 and downsample with a strided conv.

    ``b_h_w`` and ``aggre_type`` are accepted for interface compatibility but
    are unused in this implementation.
    """
    assert isinstance(data, list)
    stacked = mx.sym.concat(*data, dim=0)
    return conv2d_act(data=stacked, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad,
                      act_type=cfg.MODEL.CNN_ACT_TYPE, name=name + "_conv")
def upsample_module(data, num_filter, kernel, stride, pad, b_h_w, name, aggre_type=None):
    """Concatenate a list of symbols along axis 0 and upsample with a transposed conv.

    ``b_h_w`` and ``aggre_type`` are accepted for interface compatibility but
    are unused in this implementation.
    """
    assert isinstance(data, list)
    stacked = mx.sym.concat(*data, dim=0)
    return deconv2d_act(data=stacked, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad,
                        act_type=cfg.MODEL.CNN_ACT_TYPE, name=name + "_deconv")