// Copyright 2019 Google LLC // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. $assert CHANNEL_TILE >= 1 $assert KERNEL_TILE >= 2 $assert ACCUMULATORS >= 1 $assert ACTIVATION in ["LINEAR", "MINMAX"] $assert ACTIVATION != "LINEAR" or not WASM $ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" #include #include #include $MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32" $MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32" $SUFFIX = {"LINEAR": "", "MINMAX": "_minmax"}[ACTIVATION] $PARAMS = {"LINEAR": "xnn_f32_default_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION] void xnn_f32_dwconv${SUFFIX}_ukernel_${KERNEL_TILE}p${CHANNEL_TILE}c__${"wasm" if WASM else "scalar"}${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}( size_t channels, size_t output_width, const float** input, const float* weights, float* output, intptr_t input_stride, size_t output_increment, size_t input_offset, const float* zero, const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)]) { assert(channels != 0); assert(output_width != 0); $if ACTIVATION == "MINMAX": const float vmin = params->scalar.min; const float vmax = params->scalar.max; do { $for K in range(KERNEL_TILE): const float* i${K} = input[${K}]; assert(i${K} != NULL); if XNN_UNPREDICTABLE(i${K} != zero) { i${K} = (const float*) ((uintptr_t) i${K} + input_offset); } input = (const float**) ((uintptr_t) input + input_stride); size_t c = channels; const float* w = weights; $if CHANNEL_TILE > 1: for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) { $for C in range(CHANNEL_TILE): float vacc${C}p0 = w[${C}]; $for K in range(KERNEL_TILE): $for C in range(CHANNEL_TILE): const float vi${K}x${C} = i${K}[${C}]; i${K} += ${CHANNEL_TILE}; $for C in range(CHANNEL_TILE): const float vk${K}x${C} = w[${(K + 1) * CHANNEL_TILE + C}]; $if 1 <= K < ACCUMULATORS: float vacc${C}p${K} = vi${K}x${C} * vk${K}x${C}; $else: vacc${C}p${K % ACCUMULATORS} = math_muladd_f32(vi${K}x${C}, vk${K}x${C}, vacc${C}p${K % ACCUMULATORS}); w += ${(KERNEL_TILE + 1) * CHANNEL_TILE}; $if ACCUMULATORS > 1: // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0 $ACC_SLICE = 1 $while ACC_SLICE < ACCUMULATORS: $for A in range(0, ACCUMULATORS, ACC_SLICE * 2): $if A + ACC_SLICE < ACCUMULATORS: $for C in range(CHANNEL_TILE): vacc${C}p${A} = vacc${C}p${A} + vacc${C}p${A + ACC_SLICE}; $ACC_SLICE *= 2 $if ACTIVATION == "MINMAX": $for C in range(CHANNEL_TILE): float vacc${C} = ${MAX_F32}(vacc${C}p0, vmin); $for C in range(CHANNEL_TILE): vacc${C} = ${MIN_F32}(vacc${C}, vmax); $for C in range(CHANNEL_TILE): output[${C}] = vacc${C}; $else: $for C in range(CHANNEL_TILE): output[${C}] = vacc${C}p0; output += ${CHANNEL_TILE}; } for (; c >= 1; c -= 1) { float vacc0p0 = *w++; $for K in range(KERNEL_TILE): const float vi${K} = *i${K}++; const float vk${K} = w[${(K + 1) * CHANNEL_TILE - 1}]; $if 1 <= K < ACCUMULATORS: float vacc0p${K} = vi${K} * vk${K}; $else: vacc0p${K % ACCUMULATORS} = math_muladd_f32(vi${K}, vk${K}, vacc0p${K % ACCUMULATORS}); $if ACCUMULATORS > 1: // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0 $ACC_SLICE = 1 $while ACC_SLICE < ACCUMULATORS: $for A in range(0, ACCUMULATORS, ACC_SLICE * 2): $if A + ACC_SLICE < ACCUMULATORS: vacc0p${A} = vacc0p${A} + vacc0p${A + ACC_SLICE}; $ACC_SLICE *= 2 $if ACTIVATION == "MINMAX": float vacc0 = ${MAX_F32}(vacc0p0, vmin); vacc0 = ${MIN_F32}(vacc0, vmax); *output++ = vacc0; $else: *output++ = vacc0p0; } $else: do { float vacc0p0 = w[0]; $for K in range(KERNEL_TILE): const float vi${K} = *i${K}++; const float vk${K} = w[${K+1}]; $if 1 <= K < ACCUMULATORS: float vacc0p${K} = vi${K} * vk${K}; $else: vacc0p${K % ACCUMULATORS} = math_muladd_f32(vi${K}, vk${K}, vacc0p${K % ACCUMULATORS}); w += ${KERNEL_TILE + 1}; $ACC_STEP = 1 $while ACC_STEP < ACCUMULATORS: $for A in range(0, ACCUMULATORS, ACC_STEP * 2): $if A + ACC_STEP < ACCUMULATORS: vacc0p${A} += vacc0p${A + ACC_STEP}; $ACC_STEP *= 2 $if ACTIVATION == "MINMAX": float vacc0 = ${MAX_F32}(vacc0p0, vmin); vacc0 = ${MIN_F32}(vacc0, vmax); *output++ = vacc0; $else: *output++ = vacc0p0; } while (--c != 0); output = (float*) ((uintptr_t) output + output_increment); } while (--output_width != 0); }