using xnnpack_operator = at::native::xnnpack::Operator;

namespace at {
namespace native {
namespace xnnp_utils {

/*
 * Return shape in the same order as the memory format
 * e.g. channels_last will return NHWC instead of NCHW
 */
std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in);
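
// Example usage (a minimal sketch; `t` is a hypothetical 4-D tensor):
//
//   at::Tensor t_nhwc = t.contiguous(c10::MemoryFormat::ChannelsLast);
//   // For a logical NCHW shape {N, C, H, W}, this returns {N, H, W, C}.
//   std::vector<size_t> shape = get_mem_format_aware_shape(t_nhwc);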

/*
 * Input is always int8_t, output can be [int8_t, uint8_t].
 * input + offset = output
 * int8_t + 128 = uint8_t
 * int8_t + 0 = int8_t
 */
template <typename PT>
void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out);
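
// Worked example of the offset rule above (values are illustrative):
// an int8_t weight of -3 copied to a uint8_t output becomes -3 + 128 = 125,
// while an int8_t output keeps -3 + 0 = -3.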

template <int kSpatialDim>
Tensor convert_conv_weights_to_channel_last_tensor(
    const at::Tensor& src,
    int groups,
    bool transpose);
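
// Example usage (a minimal sketch; `weight` and `groups` are hypothetical and
// would come from the packed conv parameters):
//
//   at::Tensor w_nhwc = convert_conv_weights_to_channel_last_tensor<2>(
//       weight, groups, /*transpose=*/false);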

/*
 * Series of create wrapper functions to call xnn_create_[de]conv* functions.
 */
C10_ALWAYS_INLINE
enum xnn_status xnnp_create_convolution2d_nhwc(
    uint32_t pad_top,
    uint32_t pad_right,
    uint32_t pad_bottom,
    uint32_t pad_left,
    uint32_t kernel_h,
    uint32_t kernel_w,
    uint32_t stride_h,
    uint32_t stride_w,
    uint32_t dilation_h,
    uint32_t dilation_w,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t ip_chan_stride,
    size_t op_chan_stride,
    int8_t izp,
    float ip_scale,
    int8_t kzp,
    const float* k_scales,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t ozp,
    float op_scale,
    int8_t op_min,
    int8_t op_max,
    uint32_t flags,
    xnn_operator_t* op,
    bool per_channel,
    bool transpose) {
  /* Symmetric quantization forces kzp = 0 */
  TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expect the kernel zero point to be zero. "
      "But got: ", kzp);

  if (transpose) {
    TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
    return xnn_create_deconvolution2d_nhwc_qs8(
        pad_top, /* uint32_t output_padding_top */
        pad_right, /* uint32_t output_padding_right */
        pad_bottom, /* uint32_t output_padding_bottom */
        pad_left, /* uint32_t output_padding_left */
        kernel_h, /* uint32_t kernel_height */
        kernel_w, /* uint32_t kernel_width */
        stride_h, /* uint32_t stride_height */
        stride_w, /* uint32_t stride_width */
        dilation_h, /* uint32_t dilation_height */
        dilation_w, /* uint32_t dilation_width */
        groups, /* uint32_t groups */
        group_input_channels, /* size_t group_input_channels */
        group_output_channels, /* size_t group_output_channels */
        ip_chan_stride, /* size_t input_pixel_stride */
        op_chan_stride, /* size_t output_pixel_stride */
        izp, /* int8_t input_zero_point */
        ip_scale, /* float input_scale */
        k_scales[0], /* float kernel_scale */
        kernel, /* const int8_t* kernel */
        bias, /* const int32_t* bias */
        ozp, /* int8_t output_zero_point */
        op_scale, /* float output_scale */
        op_min, /* int8_t output_min */
        op_max, /* int8_t output_max */
        flags, /* uint32_t flags */
        op); /* xnn_operator_t* deconvolution_op_out */
  }

  if (!per_channel) {
    return xnn_create_convolution2d_nhwc_qs8(
        pad_top, /* uint32_t input_padding_top */
        pad_right, /* uint32_t input_padding_right */
        pad_bottom, /* uint32_t input_padding_bottom */
        pad_left, /* uint32_t input_padding_left */
        kernel_h, /* uint32_t kernel_height */
        kernel_w, /* uint32_t kernel_width */
        stride_h, /* uint32_t subsampling_height */
        stride_w, /* uint32_t subsampling_width */
        dilation_h, /* uint32_t dilation_height */
        dilation_w, /* uint32_t dilation_width */
        groups, /* uint32_t groups */
        group_input_channels, /* size_t group_input_channels */
        group_output_channels, /* size_t group_output_channels */
        ip_chan_stride, /* size_t input_channel_stride */
        op_chan_stride, /* size_t output_channel_stride */
        izp, /* int8_t input_zero_point */
        ip_scale, /* float input_scale */
        k_scales[0], /* float kernel_scale */
        kernel, /* const int8_t* kernel */
        bias, /* const int32_t* bias */
        ozp, /* int8_t output_zero_point */
        op_scale, /* float output_scale */
        op_min, /* int8_t output_min */
        op_max, /* int8_t output_max */
        flags, /* uint32_t flags */
        op); /* xnn_operator_t* convolution_op_out */
  } else { /* per_channel */
    return xnn_create_convolution2d_nhwc_qc8(
        pad_top, /* uint32_t input_padding_top */
        pad_right, /* uint32_t input_padding_right */
        pad_bottom, /* uint32_t input_padding_bottom */
        pad_left, /* uint32_t input_padding_left */
        kernel_h, /* uint32_t kernel_height */
        kernel_w, /* uint32_t kernel_width */
        stride_h, /* uint32_t subsampling_height */
        stride_w, /* uint32_t subsampling_width */
        dilation_h, /* uint32_t dilation_height */
        dilation_w, /* uint32_t dilation_width */
        groups, /* uint32_t groups */
        group_input_channels, /* size_t group_input_channels */
        group_output_channels, /* size_t group_output_channels */
        ip_chan_stride, /* size_t input_channel_stride */
        op_chan_stride, /* size_t output_channel_stride */
        izp, /* int8_t input_zero_point */
        ip_scale, /* float input_scale */
        k_scales, /* const float* kernel_scale */
        kernel, /* const int8_t* kernel */
        bias, /* const int32_t* bias */
        ozp, /* int8_t output_zero_point */
        op_scale, /* float output_scale */
        op_min, /* int8_t output_min */
        op_max, /* int8_t output_max */
        flags, /* uint32_t flags */
        op); /* xnn_operator_t* convolution_op_out */
  }
}
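
// Example usage (a minimal sketch; the padding/stride/quantization values and
// the `w_ptr`/`b_ptr` pointers are hypothetical, and XNNPACK is assumed to have
// been brought up via xnn_initialize() beforehand):
//
//   xnn_operator_t conv_op = nullptr;
//   const float w_scale = 0.02f;
//   const enum xnn_status status = xnnp_create_convolution2d_nhwc(
//       /*pad*/ 1, 1, 1, 1,
//       /*kernel*/ 3, 3,
//       /*stride*/ 1, 1,
//       /*dilation*/ 1, 1,
//       /*groups*/ 1,
//       /*group_input_channels*/ 32,
//       /*group_output_channels*/ 64,
//       /*ip_chan_stride*/ 32,
//       /*op_chan_stride*/ 64,
//       /*izp*/ 0, /*ip_scale*/ 0.5f,
//       /*kzp*/ 0, /*k_scales*/ &w_scale,
//       w_ptr, b_ptr,
//       /*ozp*/ 0, /*op_scale*/ 0.25f,
//       /*op_min*/ -128, /*op_max*/ 127,
//       /*flags*/ 0, &conv_op,
//       /*per_channel*/ false, /*transpose*/ false);
//   TORCH_CHECK(status == xnn_status_success, "failed to create conv operator");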

/*
 * Series of setup wrapper functions to call xnn_setup_[de]conv* functions.
 */
C10_ALWAYS_INLINE
enum xnn_status xnnp_setup_convolution2d_nhwc(
    xnn_operator_t op,
    size_t batch,
    size_t in_h,
    size_t in_w,
    const int8_t* inp,
    int8_t* outp,
    pthreadpool_t pt_pool,
    bool per_channel = false,
    bool transpose = false,
    uint32_t adj_h = 0,
    uint32_t adj_w = 0) {
  if (transpose) {
    TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
    return xnn_setup_deconvolution2d_nhwc_qs8(
        op, /* xnn_operator_t deconvolution_op */
        batch, /* size_t batch_size */
        in_h, /* size_t input_height */
        in_w, /* size_t input_width */
        adj_h, /* uint32_t adjustment_height */
        adj_w, /* uint32_t adjustment_width */
        inp, /* const int8_t* input */
        outp, /* int8_t* output */
        pt_pool); /* pthreadpool_t threadpool */
  }

  if (!per_channel) {
    return xnn_setup_convolution2d_nhwc_qs8(
        op, /* xnn_operator_t convolution_op */
        batch, /* size_t batch_size */
        in_h, /* size_t input_height */
        in_w, /* size_t input_width */
        inp, /* const int8_t* input */
        outp, /* int8_t* output */
        pt_pool); /* pthreadpool_t threadpool */
  } else { /* per_channel */
    return xnn_setup_convolution2d_nhwc_qc8(
        op, /* xnn_operator_t convolution_op */
        batch, /* size_t batch_size */
        in_h, /* size_t input_height */
        in_w, /* size_t input_width */
        inp, /* const int8_t* input */
        outp, /* int8_t* output */
        pt_pool); /* pthreadpool_t threadpool */
  }
}
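
// Example usage following the create call above (a minimal sketch; `conv_op`,
// the shapes, and the `in_ptr`/`out_ptr` buffers are hypothetical, and
// caffe2::pthreadpool_() is assumed as the thread pool source):
//
//   enum xnn_status status = xnnp_setup_convolution2d_nhwc(
//       conv_op,
//       /*batch*/ 1, /*in_h*/ 56, /*in_w*/ 56,
//       in_ptr, out_ptr,
//       caffe2::pthreadpool_());
//   TORCH_CHECK(status == xnn_status_success, "failed to setup conv operator");
//   status = xnn_run_operator(conv_op, caffe2::pthreadpool_());
//   TORCH_CHECK(status == xnn_status_success, "failed to run conv operator");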

/*
 * Series of wrapper functions to call xnn_create* and xnn_setup*
 * functions for linear
 */
C10_ALWAYS_INLINE
enum xnn_status xnnp_create_fully_connected_nc(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    int8_t input_zero_point,
    float input_scale,
    int8_t kernel_zero_point,
    float kernel_scale,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* fully_connected_op_out) {
  /* Symmetric quantization forces kzp = 0 */
  TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects the kernel zero point to be zero. "
      "But got: ", kernel_zero_point);
  return xnn_create_fully_connected_nc_qs8(
      input_channels, /* size_t input_channels */
      output_channels, /* size_t output_channels */
      input_stride, /* size_t input_stride */
      output_stride, /* size_t output_stride */
      input_zero_point, /* int8_t input_zero_point */
      input_scale, /* float input_scale */
      kernel_scale, /* float kernel_scale */
      kernel, /* const int8_t* kernel */
      bias, /* const int32_t* bias */
      output_zero_point, /* int8_t output_zero_point */
      output_scale, /* float output_scale */
      output_min, /* int8_t output_min */
      output_max, /* int8_t output_max */
      flags, /* uint32_t flags */
      fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */
}

C10_ALWAYS_INLINE
enum xnn_status xnnp_setup_fully_connected_nc(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool) {
  return xnn_setup_fully_connected_nc_qs8(
      fully_connected_op, /* xnn_operator_t fully_connected_op */
      batch_size, /* size_t batch_size */
      input, /* const int8_t* input */
      output, /* int8_t* output */
      threadpool); /* pthreadpool_t threadpool */
}
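
// Example usage for linear (a minimal sketch; the channel counts, quantization
// parameters, and the `w_ptr`/`b_ptr`/`in_ptr`/`out_ptr` buffers are hypothetical):
//
//   xnn_operator_t fc_op = nullptr;
//   enum xnn_status status = xnnp_create_fully_connected_nc(
//       /*input_channels*/ 128, /*output_channels*/ 64,
//       /*input_stride*/ 128, /*output_stride*/ 64,
//       /*input_zero_point*/ 0, /*input_scale*/ 0.5f,
//       /*kernel_zero_point*/ 0, /*kernel_scale*/ 0.02f,
//       w_ptr, b_ptr,
//       /*output_zero_point*/ 0, /*output_scale*/ 0.25f,
//       /*output_min*/ -128, /*output_max*/ 127,
//       /*flags*/ 0, &fc_op);
//   TORCH_CHECK(status == xnn_status_success, "failed to create linear operator");
//   status = xnnp_setup_fully_connected_nc(
//       fc_op, /*batch_size*/ 1, in_ptr, out_ptr, caffe2::pthreadpool_());
//   TORCH_CHECK(status == xnn_status_success, "failed to setup linear operator");
//   status = xnn_run_operator(fc_op, caffe2::pthreadpool_());
//   TORCH_CHECK(status == xnn_status_success, "failed to run linear operator");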

} // namespace xnnp_utils
} // namespace native
} // namespace at