|
|
|
|
|
|
|
|
|
|
|
"""Converts hand written assembly (.S files) to C++ files using the JIT. |
|
|
|
Takes a single argument, an assembly file, and prints converted output to |
|
stdout. |
|
""" |
|
|
|
import argparse |
|
import codecs |
|
from collections import defaultdict |
|
import datetime |
|
from enum import Enum |
|
import os |
|
import re |
|
import sys |
|
from typing import List, Tuple, Mapping |
|
|
|
|
|
class PrfmMode(Enum):
  """Selects how prefetch instructions found in the assembly are converted."""

  # Prefetch instructions are not emitted at all.
  NoPrfm = 1
  # Prefetch instructions are emitted behind a runtime `if (prefetch)` check,
  # and the generated `generate()` gains a leading `bool prefetch` parameter.
  PrfmInFileName = 2
  # Prefetch instructions are emitted unconditionally.
  ForcePrfm = 3
|
|
|
# ---------------------------------------------------------------------------
# Regex building blocks. Every fragment is a raw string so that backslash
# sequences such as `\s`, `\d`, `\]` and `\}` reach the regex engine verbatim
# instead of being treated as (invalid) Python string escapes.
# ---------------------------------------------------------------------------
SPACES = r'\s*'
COMMA = r',' + SPACES
# Optional trailing `// ...` comment. Contributes TWO capture groups: the
# outer one is '' when there is no comment, the inner one is None.
COMMENTS = SPACES + r'((//\s+.+)|)$'
# Write-back marker.
WB = r'!'

# Register names: r<n>, h<n>, s<n>, d<n>, q<n>, sp, lr, pc, w<n>, x<n>, or a
# vector register with an arrangement suffix such as v0.4s / v0.s.
REG_NO_GROUP = r'r\d+|h\d+|s\d+|d\d+|q\d+|sp|lr|pc|w\d+|x\d+|(?:v\d+\.(?:\d+)?(?:d|s|h|b))'
REG = r'(' + REG_NO_GROUP + r')'
IMM_NO_GROUP = r'\d+'
IMM = r'(' + IMM_NO_GROUP + r')'
# A register followed by a lane index, e.g. v0.s[1].
REG_LANE_NO_GROUP = r'(?:' + REG_NO_GROUP + r')\[' + IMM_NO_GROUP + r'\]'
REG_OR_IMM = r'(' + REG_LANE_NO_GROUP + r'|' + REG_NO_GROUP + r'|' + IMM_NO_GROUP + r')'
REG_INDEXED = r'(?:' + REG_NO_GROUP + r')\[' + IMM_NO_GROUP + r'\]'

# Register lists: {a-b}, {a, b, c}, {a[], b[]} and {reg[lane]}.
REGLIST_CONSEC = r'\{(\w+)-(\w+)\}' + SPACES
REGLIST_INDIV = r'\{([\w.]+(?:,\s+[\w.]+)*)\}' + SPACES
REGLIST_INDIV_REPLICATE = r'\{(\w+(?:\[\])(,\s*\w+(?:\[\]))*)\}' + SPACES
REGLIST_INDEX = r'\{(' + REG_LANE_NO_GROUP + r')\}' + SPACES

# {reg}[lane]
REGLIST_LANE_INDEX = r'\{' + REG + r'\}' + r'\[(\d+)\]' + SPACES

# reg[lane]
REG_INDEX = REG + r'\[(\d+)\]' + SPACES

APSR = 'APSR_nzcv'
FPSCR = '(FPSCR)'

# Memory operands: [reg], [reg]!, [reg, offset], [reg, offset]!.
MEMOP = r'\[' + SPACES + REG + r'\]' + SPACES
MEMOP_MAYBE_WB = r'\[' + SPACES + REG + r'\]' + f'({WB})?'
MEMOP_OFFSET = r'\[' + REG + COMMA + r'(-?\d+)\]' + SPACES
MEMOP_OFFSET_MAYBE_WB = r'\[' + REG + COMMA + r'(-?\d+)\]' + f'({WB})?' + SPACES

# Numeric local label reference, e.g. 1f / 2b.
B_IMM = r'(\d+)(f|b)'

# Instruction mnemonic (upper case, may contain digits and dots).
INSTR = SPACES + r'([A-Z0-9.]+)' + SPACES

# ---------------------------------------------------------------------------
# Compiled whole-line patterns, one per supported operand shape. The comments
# show the operand shape that follows the mnemonic.
# ---------------------------------------------------------------------------
# Preprocessor conditionals.
IFDEF_RE = re.compile(r'\s*#(ifndef|endif|ifdef).*')

# A line holding only a `//` or `#` comment.
COMMENT_RE = re.compile(SPACES + r'((//|#)\s*.+)')

# label:
LABEL = re.compile(r'(\w+):')

# (no operands)
INSTR_RE = re.compile(INSTR + COMMENTS)

# {a-b}
INSTR_REGLIST_CONSEC_RE = re.compile(INSTR + REGLIST_CONSEC + COMMENTS)

# {a, b, ...}
INSTR_REGLIST_LIST_RE = re.compile(INSTR + REGLIST_INDIV + COMMENTS)

# reg
INSTR_OP_RE = re.compile(INSTR + REG + COMMENTS)

# 1f
INSTR_B_IMM = re.compile(INSTR + B_IMM + COMMENTS)

# reg, imm, 1f
INSTR_B_REG_IMM_IMM = re.compile(INSTR + REG + COMMA + IMM + COMMA + B_IMM +
                                 COMMENTS)

# .p2align N
P2ALIGN_RE = re.compile(SPACES + r'\.p2align\s+(\d+)')

# reg, imm
INSTR_REG_IMM_RE = re.compile(INSTR + REG + COMMA + IMM + COMMENTS)

# reg, [reg]
INSTR_REG_MEMOP_RE = re.compile(INSTR + REG + COMMA + MEMOP + COMMENTS)

# reg, [reg], imm
INSTR_REG_MEMOP_IMM_RE = re.compile(INSTR + REG + COMMA + MEMOP + COMMA + IMM +
                                    COMMENTS)

# reg, [reg, offset](!)
INSTR_REG_MEMOP_OFFSET_RE = re.compile(INSTR + REG + COMMA +
                                       MEMOP_OFFSET_MAYBE_WB + COMMENTS)

# reg, reg, [reg]
INSTR_REG_REG_MEMOP_RE = re.compile(INSTR + REG + COMMA + REG + COMMA + MEMOP +
                                    COMMENTS)

# reg, reg, [reg, offset](!)
INSTR_REG_REG_MEMOP_OFFSET_RE = re.compile(INSTR + REG + COMMA + REG + COMMA +
                                           MEMOP_OFFSET_MAYBE_WB + COMMENTS)

# reg, reg, [reg], imm
INSTR_REG_REG_MEMOP_IMM_RE = re.compile(INSTR + REG + COMMA + REG + COMMA +
                                        MEMOP + COMMA + IMM + COMMENTS)

# [reg]
INSTR_MEMOP_RE = re.compile(INSTR + MEMOP + COMMENTS)

# [reg, offset]
INSTR_MEMOP_OFFSET_RE = re.compile(INSTR + MEMOP_OFFSET + COMMENTS)

# reg, reg|reg[lane]|imm
INSTR_REG_REG_RE = re.compile(INSTR + REG + COMMA + REG_OR_IMM + COMMENTS)

# reg, reg, reg|reg[lane]|imm
INSTR_REG_REG_REG_RE = re.compile(INSTR + REG + COMMA + REG + COMMA +
                                  REG_OR_IMM + COMMENTS)

# reg, reg, reg, imm
INSTR_REG_REG_REG_IMM_RE = re.compile(INSTR + REG + COMMA + REG + COMMA + REG +
                                      COMMA + IMM + COMMENTS)

# {a, b, ...}, [reg], reg
INSTR_REGLIST_INDIV_MEMOP_REG = re.compile(INSTR + REGLIST_INDIV + COMMA +
                                           MEMOP + COMMA + REG + COMMENTS)

# {a-b}, [reg], reg
INSTR_REGLIST_CONSEC_MEMOP_REG = re.compile(INSTR + REGLIST_CONSEC + COMMA +
                                            MEMOP + COMMA + REG + COMMENTS)

# reg, {a-b}
INSTR_REG_REGLIST_CONSECT = re.compile(INSTR + REG + COMMA + REGLIST_CONSEC +
                                       COMMENTS)

# reg!, {a-b}
INSTR_REG_REGLIST_CONSECT_WB = re.compile(INSTR + REG + WB + COMMA +
                                          REGLIST_CONSEC + COMMENTS)

# reg!, {a, b, ...}
INSTR_REG_REGLIST_INDIV_WB = re.compile(INSTR + REG + WB + COMMA +
                                        REGLIST_INDIV + COMMENTS)

# {a, b, ...}, [reg](!)
INSTR_REGLIST_INDIV_MEMOP = re.compile(INSTR + REGLIST_INDIV + COMMA +
                                       MEMOP_MAYBE_WB + COMMENTS)

# {a, b, ...}, [reg], imm
INSTR_REGLIST_INDIV_MEMOP_IMM = re.compile(INSTR + REGLIST_INDIV + COMMA +
                                           MEMOP + COMMA + IMM + COMMENTS)

# {reg}[lane], [reg], imm
INSTR_REGLIST_INDEX_MEMOP_IMM = re.compile(INSTR + REGLIST_LANE_INDEX + COMMA +
                                           MEMOP + COMMA + IMM + COMMENTS)

# reg[lane], reg
INSTR_REG_INDEX_REG = re.compile(INSTR + REG_INDEX + COMMA + REG + COMMENTS)

# {a-b}, [reg](!)
INSTR_REGLIST_CONSEC_MEMOP = re.compile(INSTR + REGLIST_CONSEC + COMMA +
                                        MEMOP_MAYBE_WB + COMMENTS)

# {a[], b[]}, [reg](!)
INSTR_REGLIST_REPLICATE_MEMOP = re.compile(INSTR + REGLIST_INDIV_REPLICATE +
                                           COMMA + MEMOP + r'(!)?' + COMMENTS)

# {reg[lane]}, [reg](!)
INSTR_REGLIST_INDEX_MEMOP = re.compile(INSTR + REGLIST_INDEX + COMMA +
                                       MEMOP_MAYBE_WB + COMMENTS)

# APSR_nzcv|reg, FPSCR
INSTR_REG_FPSCR = re.compile(INSTR + f'({APSR}|{REG_NO_GROUP})' + COMMA +
                             FPSCR + COMMENTS)

# PLDL1KEEP|PSTL1KEEP, [reg]
INSTR_PLD_MEMOP = re.compile(INSTR + r'(PLDL1KEEP|PSTL1KEEP)' + COMMA + MEMOP + COMMENTS)

# PLDL1KEEP, [reg, offset]
INSTR_PLD_MEMOP_OFFSET = re.compile(INSTR + r'(PLDL1KEEP)' + COMMA +
                                    MEMOP_OFFSET + COMMENTS)

# Condition code operand (upper-case letters).
COND = r'([A-Z]+)'

# reg, reg, reg, COND
INSTR_REG_REG_REG_COND_RE = re.compile(INSTR + REG + COMMA + REG + COMMA + REG +
                                       COMMA + COND + COMMENTS)
|
|
|
|
|
def remove_brackets(s: str) -> str:
  """Return `s` with every '[' and ']' character removed."""
  return ''.join(ch for ch in s if ch not in '[]')
|
|
|
|
|
def fix_replicate_instruction(s: str) -> str:
  """Insert an 'r' before the first '_<digits>' suffix of an instruction name.

  For example 'vld1_32' becomes 'vld1r_32'. Names without a '_<digits>' part
  are returned unchanged.
  """
  # `count` is passed by keyword: passing it positionally to re.sub is
  # deprecated since Python 3.13.
  return re.sub(r'_(\d+)', r'r_\1', s, count=1)
|
|
|
|
|
def fix_instr_name(s: str) -> str:
  """Turn an assembly mnemonic into the name used in the generated code.

  The mnemonic is lower-cased and its first two '.' characters become '_'.
  The result 'and' is returned as 'and_'.
  """
  lowered = s.lower().replace('.', '_', 2)
  return 'and_' if lowered == 'and' else lowered
|
|
|
|
|
def fix_comments(s: str) -> str:
  """Replace the first '#' in `s` with '//'; later '#' are left untouched."""
  before, hash_mark, after = s.partition('#')
  if not hash_mark:
    # No '#' present: nothing to rewrite.
    return s
  return f'{before}//{after}'
|
|
|
|
|
def maybe_wb(wb: bool) -> str:
  """Return the '++' write-back suffix when `wb` is truthy, else ''."""
  if wb:
    return '++'
  return ''
|
|
|
|
|
def fix_fn_name(name: str) -> str:
  """Derive the generator function name from a microkernel function name.

  A leading 'xnn_' is dropped, '__asm_' collapses to '__', every 'minmax_' is
  removed, and the result is prefixed with 'xnn_generate_'.
  """
  prefix = 'xnn_'
  stem = name[len(prefix):] if name.startswith(prefix) else name
  stem = stem.replace("__asm_", "__").replace('minmax_', '')
  return 'xnn_generate_' + stem
|
|
|
|
|
def remove_prfm_from_fn_name(name: str) -> str:
  """Return `name` with every '_prfm' removed.

  Asserts that `name` contains '_prfm'.
  """
  marker = '_prfm'
  assert marker in name
  return ''.join(name.split(marker))
|
|
|
|
|
def fix_regs(regs: str) -> str:
  """Rewrite '<reg>.<arrangement>' operands as method calls.

  'v20.4s' becomes 'v20.v4s()' and 'v0.s' becomes 'v0.s()'. Text without a
  '<word>.<word>' shape is returned unchanged.
  """

  def rewrite(match):
    base, lanes, elem = match.group(1), match.group(2), match.group(3)
    # A lane count (the digits after the dot) is spelled with a leading 'v'.
    lane_prefix = f'v{lanes}' if lanes else ''
    return f'{base}{lane_prefix}{elem}()'

  return re.sub(r'(\w+\.)(\d+)?(\w+)', rewrite, regs)
|
|
|
|
|
def get_callee_saved() -> List[str]:
  """Return the register names this tool treats as callee-saved: d8-d15, x19-x22."""
  fp_regs = [f'd{n}' for n in range(8, 16)]
  gp_regs = [f'x{n}' for n in range(19, 23)]
  return fp_regs + gp_regs
|
|
|
|
|
# Prologue lines that fully match one of these patterns (a word starting with
# '.', optionally preceded by whitespace) are dropped by parse_prologue.
IGNORE_LINES = [r'\s*\.\w+']

# Architecture identifiers; also interpolated into generated include paths and
# namespace names by parse_prologue.
AARCH32 = 'aarch32'
AARCH64 = 'aarch64'
# Kernel type identifiers.
GEMM = 'GEMM'
IGEMM = 'IGEMM'

# C++ template for Generator::perform_post_operations on AArch32 when params do
# not need to be reloaded. Accumulator and temporary registers are hard-coded;
# only PARAMS_REG_PLACEHOLDER is substituted by
# get_post_operation_implementation.
AARCH32_POST_OP = """void Generator::perform_post_operations(
  size_t max_mr,
  size_t num_post_operations,
  const xnn_post_operation* post_operations)
{
  if (num_post_operations == 0) {
    return;
  }
  for (size_t i = 0; i < num_post_operations; i++) {
    switch (post_operations[i].op_type) {
      case xnn_post_operation_type_hardswish: {
        const auto sixth = q0;
        const auto three = q1;
        const auto six = q2;
        const auto zero = q3;
        vld3r_32({sixth.low(), three.low(), six.low()}, mem[PARAMS_REG_PLACEHOLDER]++);
        vmov(zero, 0);
        vmov(three.high(), three.low());
        vmov(six.high(), six.low());
        const QRegister accs[] = {q8, q9, q10, q11, q12, q13, q14, q15};
        const QRegister tmps[] = {q4, q5, q6, q7};
        f32_hardswish(sixth, three, six, zero, &accs[0], XNN_COUNT_OF(accs), &tmps[0], XNN_COUNT_OF(tmps));
        break;
      }
      default:
        XNN_LOG_UNREACHABLE("unsupported post operation: %u", post_operations[i].op_type);
    }
  }
}"""

# AArch32 variant that first reloads the params pointer from the stack
# (`ldr(PARAMS_REG_PLACEHOLDER, mem[sp, PARAMS_OFFSET_PLACEHOLDER])`).
# Placeholders: PARAMS_REG_PLACEHOLDER, PARAMS_OFFSET_PLACEHOLDER,
# ACCS_PLACEHOLDER, TMPS_PLACEHOLDER.
AARCH32_POST_OP_RELOAD = """void Generator::perform_post_operations(
  size_t max_mr,
  size_t num_post_operations,
  const xnn_post_operation* post_operations)
{
  if (num_post_operations == 0) {
    return;
  }
  ldr(PARAMS_REG_PLACEHOLDER, mem[sp, PARAMS_OFFSET_PLACEHOLDER]); // params
  for (size_t i = 0; i < num_post_operations; i++) {
    switch (post_operations[i].op_type) {
      case xnn_post_operation_type_hardswish: {
        const auto sixth = q0;
        const auto three = q1;
        const auto six = q2;
        const auto zero = q3;
        vld3r_32({sixth.low(), three.low(), six.low()}, mem[PARAMS_REG_PLACEHOLDER]++);
        vmov(zero, 0);
        vmov(three.high(), three.low());
        vmov(six.high(), six.low());
        const QRegister accs[] = {ACCS_PLACEHOLDER};
        const QRegister tmps[] = {TMPS_PLACEHOLDER};
        f32_hardswish(sixth, three, six, zero, &accs[0], XNN_COUNT_OF(accs), &tmps[0], XNN_COUNT_OF(tmps));
        break;
      }
      default:
        XNN_LOG_UNREACHABLE("unsupported post operation: %u", post_operations[i].op_type);
    }
  }
}"""

# AArch64 template. Placeholders: PARAMS_REG_PLACEHOLDER, ACCS_PLACEHOLDER,
# TMPS_PLACEHOLDER.
AARCH64_POST_OP = """void Generator::perform_post_operations(
  size_t max_mr,
  size_t num_post_operations,
  const xnn_post_operation* post_operations)
{
  if (num_post_operations == 0) {
    return;
  }
  for (size_t i = 0; i < num_post_operations; i++) {
    switch (post_operations[i].op_type) {
      case xnn_post_operation_type_hardswish: {
        // Reuse A pointers (don't use v8-v15 as they are callee saved).
        const auto sixth = v0.v4s();
        const auto three = v1.v4s();
        const auto six = v2.v4s();
        const auto zero = v3.v4s();
        // v4, v5, v6, v7 available for temporaries.
        ld3r({sixth, three, six}, mem[PARAMS_REG_PLACEHOLDER]++);
        movi(zero, 0);
        const VRegister accs[] = {ACCS_PLACEHOLDER
        };
        const VRegister tmps[] = {TMPS_PLACEHOLDER};
        f32_hardswish(sixth, three, six, zero, &accs[0], XNN_COUNT_OF(accs), &tmps[0], XNN_COUNT_OF(tmps));
        break;
      }
      default:
        XNN_LOG_UNREACHABLE("unsupported post operation: %u", post_operations[i].op_type);
    }
  }
}"""

# Values substituted for ACCS_PLACEHOLDER / TMPS_PLACEHOLDER, keyed by
# architecture and MR (see get_post_operation_implementation).
AARCH32_MR1_POST_OP_ACCS = "q8, q9"
AARCH32_MR4_POST_OP_ACCS = "q8, q9, q10, q11, q12, q13, q14, q15"
AARCH32_MR1_POST_OP_TMPS = "q12, q13"
AARCH32_MR4_POST_OP_TMPS = "q4, q5, q6, q7"
AARCH64_MR1_POST_OP_ACCS = """
          v16.v4s(), v17.v4s(),"""
# NOTE(review): not referenced in the visible part of this file; the AArch64
# MR=4 path derives its accumulators via get_post_op_accs instead.
AARCH64_MR4_POST_OP_ACCS = """
          v24.v4s(), v25.v4s(),
          v26.v4s(), v27.v4s(),
          v28.v4s(), v29.v4s(),
          v30.v4s(), v31.v4s(),"""
AARCH64_MR6_POST_OP_ACCS = """
          v20.v4s(), v21.v4s(), v22.v4s(), v23.v4s(),
          v24.v4s(), v25.v4s(), v26.v4s(), v27.v4s(),
          v28.v4s(), v29.v4s(), v30.v4s(), v31.v4s(),"""
AARCH64_MR1_POST_OP_TMPS = """v4.v4s(), v5.v4s()"""
# Used for both the MR=4 and MR=6 AArch64 paths.
AARCH64_MR6_POST_OP_TMPS = """v4.v4s(), v5.v4s(), v6.v4s(), v7.v4s()"""
|
|
|
|
|
def replace_template(template: str, replacements: Mapping[str, str]):
  """Substitute every key of `replacements` in `template` with its value.

  Substitutions are applied one after another in the mapping's iteration
  order, so a later key can match text inserted by an earlier substitution.
  """
  text = template
  for placeholder in replacements:
    text = text.replace(placeholder, replacements[placeholder])
  return text
|
|
|
|
|
def get_post_op_accs(vector_register_usage):
  """Format eight accumulator registers as a C++ initializer-list fragment.

  `vector_register_usage` is an iterable of register-name groups; the groups
  are flattened in order and must contain exactly eight names in total.
  Each name is emitted as `<name>.v4s(),`, two per line.
  """
  flat = []
  for group in vector_register_usage:
    flat.extend(group)
  assert len(flat) == 8
  return f"""
          {flat[0]}.v4s(), {flat[1]}.v4s(),
          {flat[2]}.v4s(), {flat[3]}.v4s(),
          {flat[4]}.v4s(), {flat[5]}.v4s(),
          {flat[6]}.v4s(), {flat[7]}.v4s(),"""
|
|
|
|
|
|
|
|
|
def get_post_operation_implementation(arch, mr: int, params_register: str,
                                      params_offset: str, reload_params: bool,
                                      vector_register_usage):
  """Return the C++ source of perform_post_operations for the given kernel.

  Args:
    arch: AARCH32 or AARCH64; any other value yields None.
    mr: Number of rows of the microkernel. AArch32 distinguishes 1 from
      everything else; AArch64 supports 1, 4 and 6 and exits the process with
      status 1 otherwise.
    params_register: Text substituted for PARAMS_REG_PLACEHOLDER.
    params_offset: Text substituted for PARAMS_OFFSET_PLACEHOLDER.
    reload_params: AArch32 only; selects the template that reloads params.
    vector_register_usage: Mapping whose 'C' entry supplies the accumulator
      register groups for the AArch64 MR=4 case.
  """
  if arch == AARCH32:
    if not reload_params:
      return replace_template(AARCH32_POST_OP, {
          'PARAMS_REG_PLACEHOLDER': params_register,
      })
    if mr == 1:
      accs, tmps = AARCH32_MR1_POST_OP_ACCS, AARCH32_MR1_POST_OP_TMPS
    else:
      accs, tmps = AARCH32_MR4_POST_OP_ACCS, AARCH32_MR4_POST_OP_TMPS
    return replace_template(
        AARCH32_POST_OP_RELOAD, {
            'ACCS_PLACEHOLDER': accs,
            'TMPS_PLACEHOLDER': tmps,
            'PARAMS_REG_PLACEHOLDER': params_register,
            'PARAMS_OFFSET_PLACEHOLDER': params_offset,
        })

  if arch != AARCH64:
    # Unknown architecture: nothing to generate.
    return None

  substitutions = {}
  if mr == 1:
    substitutions['ACCS_PLACEHOLDER'] = AARCH64_MR1_POST_OP_ACCS
    substitutions['TMPS_PLACEHOLDER'] = AARCH64_MR1_POST_OP_TMPS
  elif mr == 4:
    # Accumulators come from the parsed register-usage comments.
    substitutions['ACCS_PLACEHOLDER'] = get_post_op_accs(vector_register_usage['C'])
    substitutions['TMPS_PLACEHOLDER'] = AARCH64_MR6_POST_OP_TMPS
  elif mr == 6:
    substitutions['ACCS_PLACEHOLDER'] = AARCH64_MR6_POST_OP_ACCS
    substitutions['TMPS_PLACEHOLDER'] = AARCH64_MR6_POST_OP_TMPS
  else:
    print(f'unsupported mr {mr} for post operations', file=sys.stderr)
    sys.exit(1)

  substitutions['PARAMS_REG_PLACEHOLDER'] = params_register
  if mr == 4:
    # Only the MR=4 case also passes the params offset substitution.
    substitutions['PARAMS_OFFSET_PLACEHOLDER'] = params_offset
  return replace_template(AARCH64_POST_OP, substitutions)
|
|
|
|
|
def parse_prologue(input_file: str, lines: List[str], arch: str, minmax: bool,
                   kernel_type: str, prfm_mode: PrfmMode, mr: int,
                   post_op: bool) -> Tuple[List[str], Mapping[str, int], Mapping[str, list]]:
  """Convert the assembly prologue into C++ and collect register usage.

  Args:
    input_file: Path of the assembly file; echoed in a "Converted from"
      comment with its first 20 characters dropped.
    lines: Prologue lines of the assembly file.
    arch: Architecture name, used in the assembler include and namespace.
    minmax: When true, `#include <limits>` is emitted.
    kernel_type: GEMM selects the GEMM signature and include; anything else
      selects the IGEMM ones (extra `size_t ks` argument).
    prfm_mode: PrfmInFileName adds a leading `bool prefetch` argument.
    mr: Expected number of A and C pointers (compared after `int(mr)`).
    post_op: When true, the log.h include and the perform_post_operations
      declaration are emitted.

  Returns:
    A 3-tuple `(prologue, vector_register_map, vector_register_usage)`:
    the generated prologue lines, a map from register name to row index, and
    the raw per-operand ('A', 'C', 'clamp') register usage.

  Exits the process with status 1 when a pointer or register-usage comment
  cannot be parsed or the pointer count disagrees with `mr`.
  """
  prologue = []

  # State flags for the comment section currently being read.
  in_autogen = False
  in_a_pointers = False
  in_c_pointers = False
  in_register_usage = False
  a_pointers = []
  c_pointers = []

  # Operand letter -> list of register groups ('clamp' -> list of names).
  vector_register_usage = defaultdict(list)
  # Register name -> row index.
  vector_register_map = {}

  # The argument list of generate() is the same for the declaration and the
  # definition, so build it once.
  params = 'const jit_gemm_params* jit_gemm_params'
  prefetch = 'bool prefetch, ' if prfm_mode == PrfmMode.PrfmInFileName else ''
  if kernel_type == GEMM:
    generate_args = f'{prefetch}size_t max_mr, size_t nc_mod_nr, size_t kc, {params}'
  else:
    generate_args = f'{prefetch}size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, {params}'

  for line in lines:
    if 'Auto-generated file' in line:
      in_autogen = True
      continue
    elif line.startswith('.syntax') or 'LINT.' in line:
      continue
    elif 'BEGIN_FUNCTION' in line:
      # NOTE(review): the fixed 20-character slice presumably strips a known
      # directory prefix from the path - confirm against the caller.
      prologue.append(f'// Converted from: {input_file[20:]}')
      prologue.append(f'void Generator::generate({generate_args})')
      prologue.append('{')
      continue
    elif 'Copyright ' in line:
      in_autogen = False
      prologue.append(line)
      continue
    elif '#include <xnnpack/assembly.h>' in line:
      # Replace the assembly include with the C++ includes and the Generator
      # class declaration.
      prologue.append('#include <cassert>')
      prologue.append('#include <cstddef>')
      if minmax:
        prologue.append('#include <limits>')
      prologue.append('')
      prologue.append('#include <xnnpack.h>')
      prologue.append(f'#include <xnnpack/{arch}-assembler.h>')
      if kernel_type == GEMM:
        prologue.append('#include <xnnpack/gemm.h>')
      else:
        prologue.append('#include <xnnpack/igemm.h>')
      if post_op:
        prologue.append('#include <xnnpack/log.h>')
      prologue.append('#include <xnnpack/memory.h>')
      prologue.append('#include <xnnpack/microparams.h>')
      prologue.append('#include <xnnpack/post-operation.h>')
      prologue.append('')
      prologue.append('namespace xnnpack {')
      prologue.append(f'namespace {arch} {{')
      prologue.append('namespace {')
      prologue.append('class Generator : public MacroAssembler {')
      prologue.append(' using MacroAssembler::MacroAssembler;')
      prologue.append('')
      prologue.append(' public:')
      prologue.append(f' void generate({generate_args});')
      if post_op:
        prologue.append(
            ' void perform_post_operations(size_t max_mr, size_t num_post_operations, const xnn_post_operation* post_operations);'
        )
      prologue.append('};')
      continue
    elif in_a_pointers:
      prologue.append(fix_comments(line.rstrip()))
      if not line.strip():
        # A blank line ends the "A pointers" section.
        in_a_pointers = False
        continue
      # Pick the register name out of the comment, skipping an optional
      # leading "A<n>" label.
      m = re.search(r'(?:#|//)\W+(?:A\d+\W+)?(\w\d+)', line)
      if not m:
        print(f'ERROR expected to find A pointers: {line}', file=sys.stderr)
        sys.exit(1)
      a_pointers.append(m.group(1))
      continue
    elif 'A pointers' in line:
      prologue.append(fix_comments(line.rstrip()))
      in_a_pointers = True
      continue
    elif in_c_pointers:
      prologue.append(fix_comments(line.rstrip()))
      if not line.strip():
        # A blank line ends the "C pointers" section.
        in_c_pointers = False
        continue
      # Same shape as the A pointers, with an optional "C<n>" label.
      m = re.search(r'(?:#|//)\W+(?:C\d+\W+)?(\w\d+)', line)
      if not m:
        print(f'ERROR expected to find C pointers: {line}', file=sys.stderr)
        sys.exit(1)
      c_pointers.append(m.group(1))
      continue
    elif 'C pointers' in line:
      prologue.append(fix_comments(line.rstrip()))
      in_c_pointers = True
      continue
    elif in_register_usage:
      prologue.append(fix_comments(line.rstrip()))
      if line.strip() == '':
        # A blank line ends the register usage section.
        in_register_usage = False
        continue
      if any(word in line.lower()
             for word in ['unused', 'scratch', 'temp']):
        continue

      # Rows describing B are not tracked.
      if re.search(r'(?:#|//)\W+B', line):
        continue

      # Clamp registers are recorded by name only.
      m = re.search(r'(?:#|//)\W+[Cc]lamp\W+\(?([vr]\d+)\)?', line)
      if m:
        vector_register_usage['clamp'].append(m.group(1))
        continue

      # "A<n>"/"C<n>" followed by the registers that hold that row.
      m = re.search(r'(?:#|//)\W+(A|C)\d?\W+(((?:v|d|q|r|x)\d+(?:\[\d+\])?(?:\W*|-))+)', line)
      if not m:
        print(
            f'ERROR failed to parse vector register usage: {line}',
            file=sys.stderr)
        sys.exit(1)
      param_reg = m.group(1)
      all_regs = m.group(2).split()
      # Names starting with r/x are treated as the row's pointer register;
      # everything else is a vector register.
      pointer_regs = [reg for reg in all_regs if reg.startswith('r') or reg.startswith('x')]
      vec_regs = [reg for reg in all_regs if reg not in pointer_regs]

      if len(pointer_regs) > 1:
        print(f'ERROR unexpected pointer registers: {pointer_regs}', file=sys.stderr)
        sys.exit(1)
      if len(pointer_regs) == 1:
        if param_reg.startswith('A'):
          a_pointers.append(pointer_regs[0])
        elif param_reg.startswith('C'):
          c_pointers.append(pointer_regs[0])
        else:
          print(f'ERROR unrecognized param register {param_reg}', file=sys.stderr)
          sys.exit(1)
      vector_register_usage[param_reg].append(vec_regs)
      continue
    elif 'register usage' in line.lower():
      prologue.append(fix_comments(line.rstrip()))
      in_register_usage = True
      continue
    elif any(re.fullmatch(p, line) for p in IGNORE_LINES):
      continue
    elif in_autogen:
      continue
    else:
      prologue.append(fix_comments(line.rstrip()))
      continue

  # When pointers were found, there must be exactly one per row.
  if a_pointers and len(a_pointers) != int(mr):
    print(f'len(a_pointers) {len(a_pointers)} != mr {mr}', file=sys.stderr)
    sys.exit(1)
  if c_pointers and len(c_pointers) != int(mr):
    print('len(c_pointers) != mr', file=sys.stderr)
    sys.exit(1)

  for i, v in enumerate(a_pointers):
    vector_register_map[v] = i
  for i, v in enumerate(c_pointers):
    vector_register_map[v] = i

  # Map every vector register (and its q/d/s/h/b spellings) to its row index.
  # NOTE: 'clamp' entries are plain strings, so iterating them yields single
  # characters; this is kept as-is to leave the returned map unchanged.
  for usage_rows in vector_register_usage.values():
    for i, row in enumerate(usage_rows):
      for entry in row:
        # "v20-v23" style entries name several registers.
        for reg in entry.split('-'):
          vector_register_map[reg] = i
          if '[' in reg:
            # Lane references are only mapped verbatim.
            continue
          for prefix in 'qdshb':
            vector_register_map[reg.replace('v', prefix)] = i

  return prologue, vector_register_map, vector_register_usage
|
|
|
|
|
def emit_prefetch_instruction(instr: str, prfm_mode: PrfmMode,
                              instructions: List[str]) -> None:
  """Append a prefetch instruction to `instructions` according to `prfm_mode`.

  `instr` is the already generated prefetch instruction (not the assembly
  text). PrfmInFileName wraps it in an `if (prefetch)` check, ForcePrfm emits
  it as-is, and any other mode emits nothing.
  """
  if prfm_mode == PrfmMode.PrfmInFileName:
    emitted = f'if (prefetch) {{ {instr} }}'
  elif prfm_mode == PrfmMode.ForcePrfm:
    emitted = f'{instr}'
  else:
    return
  instructions.append(emitted)
|
|
|
|
|
def emit_clamp_instruction(instr: str, instructions: List[str]) -> None:
  """Append `instr`, guarding max/min operations by the matching clamp flag.

  Text containing 'fmax' or 'vmax' is wrapped in `if (clamp_min)`; otherwise
  text containing 'fmin' or 'vmin' is wrapped in `if (clamp_max)`; anything
  else is appended unchanged.
  """
  if any(op in instr for op in ('fmax', 'vmax')):
    guarded = f'if (clamp_min) {{ {instr} }}'
  elif any(op in instr for op in ('fmin', 'vmin')):
    guarded = f'if (clamp_max) {{ {instr} }}'
  else:
    guarded = instr
  instructions.append(guarded)
|
|
|
|
|
def emit_instruction(instr: str,
                     instructions: List[str],
                     vector_register_map: Mapping[str, int],
                     vector_register_usage = None,
                     is_a53: bool = False) -> None:
  """Append `instr` to `instructions`, guarded by `max_mr` where needed.

  Args:
    instr: Generated C++ instruction text, possibly with a trailing comment.
    instructions: Output list that receives the (possibly guarded) text.
    vector_register_map: Register name -> row index. Instructions whose first
      register maps to a row > 0 are wrapped in `if (max_mr > row)`.
    vector_register_usage: Optional mapping whose 'clamp' entry lists the
      clamp registers. When omitted, no register is treated as a clamp
      register.
    is_a53: Accepted for call-site compatibility; not used by this function.
  """
  # A trailing "// A<n>" or "// C<n>" comment names the row directly.
  m = re.search(r'// (?:A|C)(\d+)(?:\W+\w+)?$', instr)
  if m:
    if m[1] == '0':
      instructions.append(instr)
    else:
      instructions.append(f'if (max_mr > {m[1]}) {{ {instr} }}')
    return

  # Otherwise look at the first register operand of a call such as
  # `name(reg`, `name(mem[reg` or `name({reg`.
  pat = re.compile(r'(\w+)\((?:mem\[)?\{?((?:v|q|d|s|h|b|x|w|r)\d+(?:\.d\(\)\[\d\])?)')
  m = re.search(pat, instr)
  if not m:
    instructions.append(instr)
    return

  instr_name = m.group(1)
  reg = m.group(2)

  # `v1.d()[1]` is looked up as `v1[1]`.
  m = re.search(r'(v\d+)\.d\(\)(\[\d\])', reg)
  if m:
    reg = f'{m[1]}{m[2]}'

  # For a lane load `ld1({vN.d()}, lane, mem[xM])` the row is taken from the
  # address register instead of the vector register.
  m = re.search(r'ld1\(\{(v\d+)\.[ds]\(\)\},\W+\d+,\W+mem\[(x\d+)\]', instr)
  if m:
    reg = f'{m[2]}'

  if instr_name in ['fmin', 'fmax', 'vmin_f32', 'vmax_f32']:
    # Clamp operations get both the row guard and the clamp guard.
    max_mr = vector_register_map[reg]
    if (max_mr == 0):
      return emit_clamp_instruction(instr, instructions)
    else:
      return emit_clamp_instruction(
          f'if (max_mr > {max_mr}) {{ {instr.rstrip()} }}', instructions)

  if instr_name == 'ld2r' or instr_name == 'vld1r_32':
    instructions.append(f'if (clamp_min || clamp_max) {{ {instr} }}')
    return

  if ((instr_name == 'stp' or instr_name == 'ldp') and
      reg in get_callee_saved()) and 'mem[sp' in instr:
    # Saving/restoring callee-saved registers on the stack is never guarded.
    instructions.append(instr)
    return

  # A post-indexed `ldp(xA, xB, mem[xN], offset)` of two mapped registers is
  # split: load only xA (advancing by half the offset) when xB's row is the
  # first row not in use, or the full pair when that row is in use.
  if (instr_name == 'ldp'):
    m = re.search(r'ldp\((x\d+), (x\d+), (mem\[x\d+\]), (\d+)', instr)
    if m:
      reg1 = m[1]
      reg2 = m[2]
      mem = m[3]
      offset = m[4]
      if all(r in vector_register_map for r in [reg1, reg2]):
        max_mr = vector_register_map[reg2]
        instructions.append(f'if (max_mr == {max_mr}) {{ ldr({reg1}, {mem}, {int(offset)//2}); }}')
        instructions.append(f'if (max_mr > {max_mr}) {{ ldp({reg1}, {reg2}, {mem}, {offset}); }}')
        return

  # `cmp(x0, N)` / `cmp(r0, N)` is only needed when more than N-1 rows exist.
  cmp_m = re.search(r'cmp\((?:x|r)0, (\d+)\);', instr)
  if cmp_m:
    instructions.append(f'if (max_mr > {int(cmp_m[1])-1}) {{ {instr} }}')
    return

  if instr_name == 'push' or instr_name == 'pop':
    instructions.append(instr)
    return

  # Stack accesses are never guarded.
  if 'mem[sp' in instr:
    instructions.append(instr)
    return

  if '// ks -= MR' in instr or '// a += MR' in instr:
    # Replace the hard-coded immediate with `max_mr * sizeof(void*)`.
    instructions.append(
        re.sub(r'(add|subs)\((\w\d+), (\w\d+), (\d+)\);',
               r'\1(\2, \3, max_mr * sizeof(void*));', instr))
    return

  if instr_name == 'ldr':
    # When the destination is unmapped, fall back to the address register.
    if reg not in vector_register_map:
      ldr_m = re.search(r'mem\[([rx]\d+)', instr)
      if ldr_m:
        reg = ldr_m[1]

  if instr_name == 'dup':
    # A dup whose second vector register is a clamp register is not guarded.
    regs = re.findall(r'(v\d+)[^a-z]', instr)
    # The parameter defaults to None; indexing it would raise TypeError.
    if vector_register_usage is None:
      clamp_reg = []
    else:
      clamp_reg = vector_register_usage['clamp']
    if len(regs) > 1 and regs[1] in clamp_reg:
      instructions.append(instr)
      return

  if reg not in vector_register_map:
    instructions.append(instr)
    return

  max_mr = vector_register_map[reg]
  if (max_mr == 0):
    instructions.append(instr)
    return

  instructions.append(f'if (max_mr > {max_mr}) {{ {instr} }}')
|
|
|
|
|
def parse_microkernel(
    lines: List[str], prfm_mode: PrfmMode, is_a53: bool,
    vector_register_map: Mapping[str, int], vector_register_usage) -> Tuple[List[str], List[str]]:
  """Translate the assembly body of a microkernel into C++ statements.

  Each line is matched against the instruction patterns in a fixed order; the
  first pattern that fully matches decides the emitted text. Parsing stops at
  a line starting with END_FUNCTION. A line that matches nothing is reported
  on stderr and the process exits with status 1.

  Args:
    lines: Assembly lines of the function body.
    prfm_mode: Controls how prefetch instructions are emitted.
    is_a53: Forwarded to emit_instruction for the `reg, [mem...]` forms.
    vector_register_map: Register name -> row index.
    vector_register_usage: Register usage collected from the prologue.

  Returns:
    `(instructions, labels)`: the generated statements and the label names
    encountered, in order.
  """
  labels = []
  sc = ';'
  instructions = []

  def emit(text: str, a53: bool = False) -> None:
    # Keyword arguments keep `vector_register_usage` and `is_a53` from being
    # swapped (emit_instruction declares the usage mapping first).
    emit_instruction(text, instructions, vector_register_map,
                     vector_register_usage=vector_register_usage, is_a53=a53)

  for line in lines:
    # Preprocessor conditionals are dropped.
    if IFDEF_RE.fullmatch(line):
      continue

    # Comment-only line.
    m = COMMENT_RE.fullmatch(line)
    if m:
      emit(m[1])
      continue

    m = LABEL.fullmatch(line)
    if m:
      labels.append(m[1])
      instructions.append(f'bind(l{m[1]}){sc}')
      continue
    m = INSTR_RE.fullmatch(line)
    if m:
      emit(f'{fix_instr_name(m[1])}(){sc} {m[2]}')
      continue
    m = INSTR_OP_RE.fullmatch(line)
    if m:
      emit(f'{fix_instr_name(m[1])}({m[2]}){sc} {m[3]}')
      continue
    m = INSTR_REGLIST_CONSEC_MEMOP_REG.fullmatch(line)
    if m:
      emit(
          f'{fix_instr_name(m[1])}({{{m[2]}-{m[3]}}}, mem[{m[4]}], {m[5]}){sc} {m[6]}')
      continue
    m = INSTR_REGLIST_INDIV_MEMOP_REG.fullmatch(line)
    if m:
      emit(
          f'{fix_instr_name(m[1])}({{{fix_regs(m[2])}}}, mem[{m[3]}], {m[4]}){sc} {m[5]}')
      continue
    m = INSTR_REGLIST_CONSEC_RE.fullmatch(line)
    if m:
      emit(f'{fix_instr_name(m[1])}({{{m[2]}-{m[3]}}}){sc} {m[4]}')
      continue
    m = INSTR_REGLIST_LIST_RE.fullmatch(line)
    if m:
      emit(f'{fix_instr_name(m[1])}({{{m[2]}}}){sc} {m[3]}')
      continue
    m = INSTR_MEMOP_OFFSET_RE.fullmatch(line)
    if m:
      text = f'{fix_instr_name(m[1])}(mem[{m[2]}, {m[3]}]){sc} {m[4]}'
      if m[1].lower() == 'pld':
        emit_prefetch_instruction(text, prfm_mode, instructions)
      else:
        emit(text)
      continue
    m = INSTR_MEMOP_RE.fullmatch(line)
    if m:
      # Group 3 is the (possibly empty) trailing comment. Group 4 is the
      # inner comment group, which is None when the line has no comment.
      text = f'{fix_instr_name(m[1])}(mem[{m[2]}]){sc} {m[3]}'
      if m[1].lower() == 'pld':
        emit_prefetch_instruction(text, prfm_mode, instructions)
      else:
        emit(text)
      continue
    m = INSTR_REG_MEMOP_RE.fullmatch(line)
    if m:
      emit(f'{fix_instr_name(m[1])}({m[2]}, mem[{m[3]}]){sc} {m[4]}',
           a53=is_a53)
      continue
    m = INSTR_REG_MEMOP_IMM_RE.fullmatch(line)
    if m:
      emit(f'{fix_instr_name(m[1])}({m[2]}, mem[{m[3]}], {m[4]}){sc} {m[5]}',
           a53=is_a53)
      continue
    m = INSTR_REG_MEMOP_OFFSET_RE.fullmatch(line)
    if m:
      # Group 5 is the optional write-back marker.
      emit(
          f'{fix_instr_name(m[1])}({m[2]}, mem[{m[3]}, {m[4]}]{maybe_wb(m[5])}){sc} {m[6]}',
          a53=is_a53)
      continue
    m = INSTR_REG_REG_MEMOP_RE.fullmatch(line)
    if m:
      emit(f'{fix_instr_name(m[1])}({m[2]}, {m[3]}, mem[{m[4]}]){sc} {m[5]}')
      continue
    m = INSTR_REG_REG_MEMOP_OFFSET_RE.fullmatch(line)
    if m:
      # Group 6 is the optional write-back marker.
      emit(
          f'{fix_instr_name(m[1])}({m[2]}, {m[3]}, mem[{m[4]}, {m[5]}]{maybe_wb(m[6])}){sc} {m[7]}')
      continue
    m = INSTR_REG_REG_MEMOP_IMM_RE.fullmatch(line)
    if m:
      emit(
          f'{fix_instr_name(m[1])}({m[2]}, {m[3]}, mem[{m[4]}], {m[5]}){sc} {m[6]}')
      continue
    m = INSTR_REG_IMM_RE.fullmatch(line)
    if m:
      emit(f'{fix_instr_name(m[1])}({fix_regs(m[2])}, {m[3]}){sc} {m[4]}')
      continue
    m = INSTR_REG_REG_REG_RE.fullmatch(line)
    if m:
      emit(
          f'{fix_instr_name(m[1])}({fix_regs(m[2])}, {fix_regs(m[3])}, {fix_regs(m[4])}){sc} {m[5]}')
      continue
    m = INSTR_REG_REG_REG_IMM_RE.fullmatch(line)
    if m:
      emit(
          f'{fix_instr_name(m[1])}({m[2]}, {m[3]}, {m[4]}, {m[5]}){sc} {m[6]}')
      continue
    m = INSTR_REG_REG_RE.fullmatch(line)
    if m:
      emit(
          f'{fix_instr_name(m[1])}({fix_regs(m[2])}, {fix_regs(m[3])}){sc} {m[4]}')
      continue
    m = INSTR_REG_REGLIST_CONSECT.fullmatch(line)
    if m:
      emit(
          f'{fix_instr_name(m[1])}(mem[{m[2]}], {{{m[3]}-{m[4]}}}){sc} {m[5]}')
      continue
    m = INSTR_REG_REGLIST_CONSECT_WB.fullmatch(line)
    if m:
      emit(
          f'{fix_instr_name(m[1])}(mem[{m[2]}]++, {{{m[3]}-{m[4]}}}){sc} {m[5]}')
      continue
    m = INSTR_REG_REGLIST_INDIV_WB.fullmatch(line)
    if m:
      emit(f'{fix_instr_name(m[1])}(mem[{m[2]}]++, {{{m[3]}}}){sc} {m[4]}')
      continue
    m = INSTR_B_IMM.fullmatch(line)
    if m:
      emit(f'{fix_instr_name(m[1])}(l{m[2]}){sc} {m[4]}')
      continue
    m = INSTR_B_REG_IMM_IMM.fullmatch(line)
    if m:
      # Appended directly, without any max_mr guard.
      instructions.append(
          f'{fix_instr_name(m[1])}({m[2]}, {m[3]}, l{m[4]}){sc} {m[6]}')
      continue
    m = INSTR_REGLIST_INDIV_MEMOP.fullmatch(line)
    if m:
      emit(
          f'{fix_instr_name(m[1])}({{{fix_regs(m[2])}}}, mem[{m[3]}]{maybe_wb(m[4])}){sc} {m[5]}')
      continue
    m = INSTR_REGLIST_INDIV_MEMOP_IMM.fullmatch(line)
    if m:
      emit(
          f'{fix_instr_name(m[1])}({{{fix_regs(m[2])}}}, mem[{m[3]}], {m[4]}){sc} {m[5]}')
      continue
    m = INSTR_REGLIST_INDEX_MEMOP_IMM.fullmatch(line)
    if m:
      emit(
          f'{fix_instr_name(m[1])}({{{m[2]}()}}, {m[3]}, mem[{m[4]}], {m[5]}){sc} {m[6]}')
      continue
    m = INSTR_REG_INDEX_REG.fullmatch(line)
    if m:
      emit(f'{fix_instr_name(m[1])}({m[2]}()[{m[3]}], {m[4]}){sc} {m[5]}')
      continue
    m = INSTR_REGLIST_CONSEC_MEMOP.fullmatch(line)
    if m:
      emit(
          f'{fix_instr_name(m[1])}({{{m[2]}-{m[3]}}}, mem[{m[4]}]{maybe_wb(m[5])}){sc} {m[6]}')
      continue
    m = INSTR_REGLIST_REPLICATE_MEMOP.fullmatch(line)
    if m:
      # Group 5 is the optional write-back marker.
      emit(
          f'{fix_replicate_instruction(fix_instr_name(m[1]))}({{{remove_brackets(m[2])}}}, mem[{m[4]}]{maybe_wb(m[5])}){sc} {m[6]}')
      continue
    m = INSTR_REGLIST_INDEX_MEMOP.fullmatch(line)
    if m:
      emit(
          f'{fix_instr_name(m[1])}({{{m[2]}}}, mem[{m[3]}]{maybe_wb(m[4])}){sc} {m[5]}')
      continue
    m = P2ALIGN_RE.fullmatch(line)
    if m:
      # .p2align N aligns to 2**N bytes.
      instructions.append(f'align({1 << int(m[1])}){sc}')
      continue
    m = INSTR_REG_FPSCR.fullmatch(line)
    if m:
      instructions.append(f'{fix_instr_name(m[1])}({m[2]}, {m[3]}){sc} {m[4]}')
      continue
    m = INSTR_PLD_MEMOP.fullmatch(line)
    if m:
      emit_prefetch_instruction(
          f'{fix_instr_name(m[1])}(k{m[2]}, mem[{m[3]}]){sc} {m[4]}', prfm_mode,
          instructions)
      continue
    m = INSTR_PLD_MEMOP_OFFSET.fullmatch(line)
    if m:
      emit_prefetch_instruction(
          f'{fix_instr_name(m[1])}(k{m[2]}, mem[{m[3]}, {m[4]}]){sc} {m[5]}',
          prfm_mode, instructions)
      continue
    m = INSTR_REG_REG_REG_COND_RE.fullmatch(line)
    if m:
      emit(
          f'{fix_instr_name(m[1])}({m[2]}, {m[3]}, {m[4]}, k{m[5]}){sc} {m[6]}')
      continue

    # Blank lines are kept to preserve the grouping of the output.
    if line.strip() == '':
      instructions.append('')
      continue

    # Remaining assembler directives are dropped.
    if line.strip().startswith('.'):
      continue

    if line.startswith('END_FUNCTION'):
      break

    # Nothing matched: report the line and abort.
    print(f'ERROR: {line}', file=sys.stderr)
    sys.exit(1)

  return instructions, labels
|
|
|
|
|
def emit_instructions_with_same_check(check: str, instrs: List[str],
                                      output: List[str]) -> None:
  """Emits a run of one-line checked instructions as a single `if` block.

  Every entry of `instrs` is either a `//` comment line or a one-line check of
  the form `if (<check>) { <body> }`. One `if (<check>) {` header is appended
  to `output`, then the body of every check (the text between its first `{`
  and its last `}`) and every comment, and finally the closing brace.

  Args:
    check: Condition shared by `instrs`. The condition parsed from `instrs[0]`
      takes precedence; `check` is the fallback when that entry cannot be
      parsed (previously the argument was ignored and `if () {` was emitted).
    instrs: Entries to merge. Nothing is emitted when the list is empty.
    output: List the merged lines are appended to; mutated in place.

  Raises:
    ValueError: If a non-comment entry contains no `{` or no `}`.
  """
  if not instrs:
    return

  # The indentation and condition of the merged block come from the first
  # entry; the leading group captures whatever non-word prefix precedes `if`.
  header = re.search(r'(\W*)if \(([^)]+)\) \{', instrs[0])
  if header:
    indent, check = header[1], header[2]
  else:
    indent = ''
    check = check or ''

  output.append(f'{indent}if ({check}) {{')
  for instr in instrs:
    if instr.strip().startswith('//'):
      # Comments are carried over verbatim, one level deeper.
      output.append(f' {instr}')
      continue
    # Strip the `if (...) {` wrapper and the trailing `}` to keep the body.
    start = instr.index('{')
    end = instr.rindex('}')
    output.append(f' {indent}{instr[start+1:end].strip()}')
  output.append(f'{indent}}}')
|
|
|
|
|
def merge_consecutive_checks(instructions: List[str]) -> List[str]:
  """Merges consecutive one-line checks on the same condition into one block.

  Each instruction carries its own check, which leads to an excessive number
  of checks, e.g.

    if (clamp) { fmin(v0) }
    if (clamp) { fmin(v1) }
    ...
    if (clamp) { fmin(v10) }

  This walks the instruction stream, collects consecutive checks for the same
  condition, and emits them as:

    if (clamp) {
      fmin(v0);
      fmin(v1);
      ...
      fmin(v10);
    }

  A check is only recognized when it is on the same line as the instruction,
  i.e. `if (clamp) { ... }`. Comment lines that appear inside a run are kept
  inside the merged block, and a comment directly preceding a run is moved
  into it.

  Args:
    instructions: The instruction stream, one line per entry.

  Returns:
    A new list with runs of same-condition checks merged.
  """
  output = []
  previous_check = None
  # Checks (and interleaved comments) collected for `previous_check`.
  pending = []

  for instr in instructions:
    m = re.search(r'if \(([^)]+)\) \{ .+', instr)
    if m:
      current_check = m.group(1)
      if current_check == previous_check:
        pending.append(instr)
        continue

      # A different condition starts a new run. A comment immediately before
      # it belongs to the new run, so take it back from wherever it landed.
      comment = ''
      if output and output[-1].strip().startswith('//'):
        comment = output.pop()
      elif pending and pending[-1].strip().startswith('//'):
        comment = pending.pop()
      emit_instructions_with_same_check(previous_check, pending, output)
      previous_check = current_check
      pending = [instr]
      if comment:
        # Wrap the comment in the same check so that it is emitted as the
        # first line inside the merged block.
        pending.insert(0, f'if ({current_check}) {{ {comment} }}')
    elif re.fullmatch(r'\W*//.+', instr) and pending:
      # A comment in the middle of a run stays with the run.
      pending.append(instr)
    else:
      # Any other line ends the current run.
      emit_instructions_with_same_check(previous_check, pending, output)
      previous_check = None
      output.append(instr)
      pending = []

  # Flush the run still pending when the stream ends on a check; without this
  # the trailing checked instructions would be dropped from the output.
  emit_instructions_with_same_check(previous_check, pending, output)
  return output
|
|
|
|
|
def insert_post_operations(instructions: List[str]) -> List[str]:
  """Inserts the `perform_post_operations` call ahead of the first full store.

  The first entry containing `'Store full '` is located; the call is inserted
  just before the blank line that must precede that entry.

  Args:
    instructions: The instruction stream; mutated in place.

  Returns:
    The same `instructions` list, for convenience.

  Raises:
    ValueError: If no entry contains `'Store full '`. Previously this case
      fell through with index 0 and could insert the call before the last
      entry of the stream.
    AssertionError: If the entry before the marker is missing or not blank.
  """
  index = next(
      (i for i, instr in enumerate(instructions) if 'Store full ' in instr),
      None)
  if index is None:
    raise ValueError(
        "Unable to insert post operations: no 'Store full ' marker found")
  # Raised explicitly (rather than via `assert`) so the invariant is still
  # enforced under `python -O`. `index == 0` is rejected as well, because
  # `instructions[-1]` would otherwise wrap around to the end of the list.
  if index == 0 or instructions[index - 1].strip() != '':
    raise AssertionError(
        "Expected a blank line before the 'Store full ' marker")
  instructions.insert(
      index - 1,
      'perform_post_operations(max_mr, num_post_operations, post_operations);')
  return instructions
|
|
|
|
|
def find_params_offset_and_register(lines: List[str]) -> Tuple[str, str]:
  """Finds the stack offset and register of the params argument.

  Only lines mentioning `params` are considered. The offset is the number in
  an `sp + <n>` expression on such a line, and the register is the first
  `r<n>`/`x<n>` name found in the line's last whitespace-separated token.

  Args:
    lines: Lines to scan, in order.

  Returns:
    `(offset, register)` as strings, taken from the first line where both are
    present, or `(None, None)` when no line provides both.
  """
  mentioning_params = (text for text in lines if 'params' in text)
  for text in mentioning_params:
    offset_match = re.search(r'sp\W+\+\W+(\d+)', text)
    register_match = re.search(r'((?:r|x)\d+)', text.split()[-1])
    if offset_match is None or register_match is None:
      continue
    return offset_match.group(1), register_match.group(1)
  return None, None
|
|
|
|
|
def convert(input_file: str, post_op: bool, reload_params: bool, debug: bool,
            force_prfm: bool) -> List[str]:
  """Converts one assembly microkernel file into the lines of its JIT C++ file.

  Everything about the kernel is derived from `input_file`: the element C type
  from the base-name prefix (`f16-`, `f32-`, `qs8-`/`qc8-`, `qu8-`), the
  architecture from `aarch32`/`aarch64`, IGEMM vs. GEMM from `igemm`, minmax
  support from `minmax`, the prefetch mode from `prfm`, and the MRxNR tile
  from the first `<digits>x<digits>` occurrence in the path.

  Args:
    input_file: Path of the assembly file to convert.
    post_op: Whether to emit post-operation support (the
      `perform_post_operations` call and the post operation implementation).
    reload_params: Forwarded to `get_post_operation_implementation`.
    debug: Print the vector register map and usage returned by
      `parse_prologue`.
    force_prfm: Select `PrfmMode.ForcePrfm`. Not allowed when the file name
      itself contains `prfm`.

  Returns:
    The generated C++ source, one list entry per line.

  Exits the process with status 1 (after printing a message to stderr) when
  the C type or architecture cannot be determined, when `force_prfm` is
  combined with a `prfm` file, when no `BEGIN_FUNCTION` line exists, when the
  params register cannot be found, or when a minmax kernel has a C type
  without min/max params.
  """
  base_filename = os.path.basename(input_file)
  if base_filename.startswith('f16-'):
    ctype = 'uint16_t'
  elif base_filename.startswith('f32-'):
    ctype = 'float'
  elif base_filename.startswith(('qs8-', 'qc8-')):
    ctype = 'int8_t'
  elif base_filename.startswith('qu8-'):
    ctype = 'uint8_t'
  else:
    print('ERROR: unknown ctype', file=sys.stderr)
    sys.exit(1)

  if 'aarch32' in input_file:
    arch = AARCH32
  elif 'aarch64' in input_file:
    arch = AARCH64
  else:
    print('ERROR: unknown architecture', file=sys.stderr)
    sys.exit(1)

  kernel_type = IGEMM if 'igemm' in input_file else GEMM
  minmax = 'minmax' in input_file

  prfm_mode = PrfmMode.NoPrfm
  if 'prfm' in input_file:
    # Reported explicitly instead of via `assert`, which is stripped under
    # `python -O` and would let --force-prfm silently win.
    if force_prfm:
      print('ERROR: force_prfm cannot be used with a prfm input file',
            file=sys.stderr)
      sys.exit(1)
    prfm_mode = PrfmMode.PrfmInFileName
  if force_prfm:
    prfm_mode = PrfmMode.ForcePrfm

  mr = 0
  nr = 0
  tile = re.search(r'(\d+)x(\d+)', input_file)
  if tile:
    mr = int(tile[1])
    nr = int(tile[2])

  with open(input_file, 'r', encoding='utf-8') as f:
    lines = f.read().splitlines()

  begin_function_index = next(
      (i for i, line in enumerate(lines) if 'BEGIN_FUNCTION' in line), None)
  if begin_function_index is None:
    # Previously this fell back to line 0 and produced a bogus function name
    # (or an IndexError).
    print(f'ERROR: no BEGIN_FUNCTION found in {input_file}', file=sys.stderr)
    sys.exit(1)

  fn_name = lines[begin_function_index].split()[1]

  # Everything up to and including BEGIN_FUNCTION is the prologue; the rest is
  # the microkernel body.
  prologue_lines = lines[:begin_function_index + 1]
  microkernel_body = lines[begin_function_index + 1:]

  prologue, vector_register_map, vector_register_usage = parse_prologue(
      input_file, prologue_lines, arch, minmax, kernel_type, prfm_mode, mr,
      post_op)
  if debug:
    print(vector_register_map)
    print(vector_register_usage)

  params_offset, params_register = find_params_offset_and_register(
      prologue_lines)
  if not params_register:
    print(fn_name, file=sys.stderr)
    print('Unable to find params register', file=sys.stderr)
    sys.exit(1)

  is_a53 = 'cortex_a53' in fn_name
  instructions, labels = parse_microkernel(microkernel_body, prfm_mode, is_a53,
                                           vector_register_map,
                                           vector_register_usage)

  # NOTE(review): the merge pass is deliberately kept at two applications, as
  # in the original; presumably the second pass picks up merges enabled by the
  # first - confirm before changing.
  instructions = merge_consecutive_checks(instructions)
  instructions = merge_consecutive_checks(instructions)
  if post_op:
    instructions = insert_post_operations(instructions)

  # NOTE(review): the leading whitespace inside the emitted string literals
  # below is reproduced exactly as it appears in this copy of the file.
  output = list(prologue)

  labels_str = ', '.join(f'l{l}' for l in labels)
  output.append(f' assert(max_mr <= {mr});')
  output.append(f' assert(nc_mod_nr < {nr});')
  output.append(' assert(kc != 0);')
  output.append(f' assert(kc % sizeof({ctype}) == 0);')
  if kernel_type == IGEMM:
    output.append(' assert(ks != 0);')
  output.append('')
  output.append(f' Label {labels_str};')
  output.append(
      ' const size_t num_post_operations = jit_gemm_params->num_post_operations;'
  )
  if post_op:
    output.append(
        ' const xnn_post_operation* post_operations = jit_gemm_params->post_operations;'
    )
  else:
    output.append(' (void) num_post_operations; // Silence unused warning.')

  if minmax:
    if ctype == 'float':
      output.append(' const float min = jit_gemm_params->f32_minmax.min;')
      output.append(' const float max = jit_gemm_params->f32_minmax.max;')
      output.append(' const bool clamp_min = min != -std::numeric_limits<float>::infinity();')
      output.append(' const bool clamp_max = max != +std::numeric_limits<float>::infinity();')
    elif ctype == 'uint16_t':
      output.append(' const uint16_t min = jit_gemm_params->f16_minmax.min;')
      output.append(' const uint16_t max = jit_gemm_params->f16_minmax.max;')
      output.append(' const bool clamp_min = min != UINT16_C(0xFC00); // -Inf.')
      output.append(' const bool clamp_max = max != UINT16_C(0x7C00); // Inf.')
    else:
      print('ERROR: unknown ctype for min/max params', file=sys.stderr)
      sys.exit(1)

    # NOTE(review): kept under `if minmax` because clamp_min/clamp_max are
    # only declared by the branches above - confirm the intended nesting.
    output.append(
        ' assert(num_post_operations == 0 || (!clamp_min && !clamp_max));')

  indent = ' '
  for instr in instructions:
    stripped = instr.strip()
    if stripped.startswith('#'):
      output.append(indent + fix_comments(instr))
    elif stripped.startswith('//'):
      output.append(indent + instr)
    elif stripped == '':
      output.append('')
    else:
      output.append(indent + instr.rstrip())
  if arch == AARCH32:
    output.append(indent + 'align(16);')
  else:
    output.append(indent + 'align(16, AlignInstruction::kHlt);')

  output.append('}')

  if post_op:
    output.append('')
    output.append(
        get_post_operation_implementation(arch, mr, params_register,
                                          params_offset, reload_params,
                                          vector_register_usage))
  # NOTE(review): this blank line is emitted regardless of post_op - confirm
  # the intended nesting.
  output.append('')
  output.append('} // namespace')
  output.append(f'}} // namespace {arch}')
  output.append('} // namespace xnnpack')
  output.append('')
  if prfm_mode == PrfmMode.PrfmInFileName:
    # A prfm file yields two generators: one under the name with prfm removed
    # that passes prefetch=false, and one under the original name that passes
    # prefetch=true.
    print_generator_definition(
        output,
        kernel_type,
        remove_prfm_from_fn_name(fn_name),
        arch,
        minmax,
        prefetch='false, ')
    output.append('')
    print_generator_definition(
        output, kernel_type, fn_name, arch, minmax, prefetch='true, ')
  else:
    print_generator_definition(output, kernel_type, fn_name, arch, minmax)

  return output
|
|
|
|
|
def print_generator_definition(output,
                               kernel_type,
                               fn_name,
                               arch,
                               minmax,
                               prefetch=''):
  """Appends the C++ entry point that runs the generator to `output`.

  Args:
    output: List of output lines; the definition is appended in place.
    kernel_type: `GEMM` selects the signature and `generate` call without
      `ks`; any other value selects the variants with `ks`.
    fn_name: Function name, passed through `fix_fn_name` for the signature.
    arch: Interpolated into the `using namespace xnnpack::...` line.
    minmax: When true, params is asserted non-null and cast to
      `jit_gemm_params`; otherwise `nullptr` is passed to `generate`.
    prefetch: Text placed verbatim in front of the `generate` arguments.
  """
  is_gemm = kernel_type == GEMM

  if is_gemm:
    signature = f'xnn_status_t {fix_fn_name(fn_name)}(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params) {{'
  else:
    signature = f'xnn_status_t {fix_fn_name(fn_name)}(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params) {{'
  output.append(signature)
  output.append(f' using namespace xnnpack::{arch};')
  output.append(' Generator g(code);')
  if minmax:
    output.append(' assert(params != nullptr);')

  if is_gemm and minmax:
    generate_call = f' g.generate({prefetch}max_mr, nc_mod_nr, kc, static_cast<const jit_gemm_params*>(params));'
  elif is_gemm:
    generate_call = f' g.generate({prefetch}max_mr, nc_mod_nr, kc, nullptr);'
  elif minmax:
    generate_call = f' g.generate({prefetch}max_mr, nc_mod_nr, kc, ks, static_cast<const jit_gemm_params*>(params));'
  else:
    generate_call = f' g.generate({prefetch}max_mr, nc_mod_nr, kc, ks, nullptr);'
  output.append(generate_call)

  # Fixed epilogue: finalize, report a generator error, otherwise succeed.
  output.extend([
      ' g.finalize();',
      ' if (g.error() != xnnpack::Error::kNoError) {',
      ' return xnn_status_invalid_state;',
      ' }',
      ' return xnn_status_success;',
      '}',
  ])
|
|
|
|
|
def main(sys_args):
  """Parses the command line, converts the input and writes the output file.

  The output file is only rewritten when its content would change, so an
  up-to-date file is left untouched.

  Args:
    sys_args: Command-line arguments, without the program name.
  """
  parser = argparse.ArgumentParser(
      description='Convert assembly to JIT C++, writes to the output file.')
  parser.add_argument(
      '-i',
      '--input',
      metavar='input_file',
      help='Input assembly filename',
      required=True)
  parser.add_argument(
      '-o',
      '--output',
      metavar='output_file',
      help='Output cc filename',
      required=True)
  parser.add_argument(
      '--post-op',
      help='Should support post operation',
      default=True,
      action=argparse.BooleanOptionalAction)
  parser.add_argument(
      '--reload-params',
      help='Should reload params pointer before post operation',
      default=False,
      action=argparse.BooleanOptionalAction)
  parser.add_argument(
      '--debug',
      help='Output debugging information',
      default=False,
      action=argparse.BooleanOptionalAction)
  parser.add_argument(
      '--force-prfm',
      help='Force PRFM instructions in output',
      default=False,
      action=argparse.BooleanOptionalAction)
  args = parser.parse_args(sys_args)

  converted = convert(args.input, args.post_op, args.reload_params,
                      args.debug, args.force_prfm)
  output = '\n'.join(converted) + '\n'

  output_name = args.output
  # The builtin open() replaces the deprecated codecs.open(); newline=''
  # disables newline translation, as codecs.open() did, so the comparison and
  # the written bytes are unchanged.
  existing = None
  if os.path.exists(output_name):
    with open(output_name, 'r', encoding='utf-8', newline='') as output_file:
      existing = output_file.read()
  if existing != output:
    with open(output_name, 'w', encoding='utf-8', newline='') as output_file:
      output_file.write(output)


if __name__ == '__main__':
  main(sys.argv[1:])
|
|