|
import argparse |
|
import os |
|
import pickle |
|
from platform import system |
|
from typing import List |
|
|
|
import tvm |
|
import tvm.testing |
|
from tvm import relax |
|
|
|
import mlc_llm |
|
from mlc_llm import utils |
|
from mlc_llm.relax_model import gpt_neox, llama, moss |
|
|
|
|
|
def _parse_args():
    """Parse the command line and derive the build configuration.

    On top of the options registered by ``utils.argparse_add_common`` (which
    is expected to provide at least ``model`` and ``dtype``, both read below),
    this adds the quantization, target, database, cache and debug flags.

    After parsing, the namespace is extended with:

    * ``model_path``: ``<artifact-path>/models/<model>``
    * ``artifact_path``: rewritten to ``<artifact-path>/<model>/<dtype>``
    * ``export_kwargs``: an empty dict
    * ``lib_format``: ``"so"``
    * ``db_path``: the given value, or ``log_db/<model>`` when unset

    ``utils.parse_target`` and ``utils.argparse_postproc_common`` then
    post-process the namespace in place.

    Returns:
        argparse.Namespace: the parsed and post-processed arguments.

    Raises:
        SystemExit: raised by argparse on invalid input, including a
            ``--max-seq-len`` that is neither ``-1`` nor positive.
    """
    args = argparse.ArgumentParser()
    utils.argparse_add_common(args)
    args.add_argument("--quantization-sym", action="store_true", default=False)
    args.add_argument(
        "--quantization-mode", type=str, choices=["int4", "int3", "fp4"], default="int4"
    )
    args.add_argument(
        "--quantization-storage-nbit", type=int, choices=[32, 16, 8], default=32
    )
    args.add_argument("--no-quantize", action="store_true", default=False)
    args.add_argument("--max-seq-len", type=int, default=-1)
    args.add_argument("--target", type=str, default="auto")
    args.add_argument(
        "--db-path",
        type=str,
        default=None,
        help="Path to log database. Default: ./log_db/{model}",
    )
    args.add_argument("--artifact-path", type=str, default="dist")
    args.add_argument(
        "--use-cache",
        type=int,
        default=1,
        help="Whether to use previously pickled IRModule and skip trace.",
    )
    args.add_argument("--debug-dump", action="store_true", default=False)
    args.add_argument("--debug-load-script", action="store_true", default=False)

    args.add_argument(
        "--llvm-mingw",
        type=str,
        default="",
        help="/path/to/llvm-mingw-root, use llvm-mingw to cross compile to windows",
    )
    args.add_argument("--system-lib", action="store_true", default=False)

    parsed = args.parse_args()
    # Validate explicitly instead of with ``assert``: assertions are removed
    # under ``python -O`` and would otherwise surface as a bare traceback.
    if parsed.max_seq_len != -1 and parsed.max_seq_len <= 0:
        args.error("--max-seq-len must be -1 or a positive integer")

    # model_path must be computed before artifact_path is rewritten, because
    # both are derived from the original --artifact-path value.
    parsed.model_path = os.path.join(parsed.artifact_path, "models", parsed.model)
    parsed.artifact_path = os.path.join(
        parsed.artifact_path, parsed.model, parsed.dtype
    )
    parsed.export_kwargs = {}
    parsed.lib_format = "so"
    parsed.db_path = parsed.db_path or os.path.join("log_db", parsed.model)
    utils.parse_target(parsed)
    utils.argparse_postproc_common(parsed)
    return parsed
|
|
|
|
|
def debug_dump_script(mod, name, args):
    """Write the script form of ``mod`` to the debug folder when dumping is on.

    Does nothing unless ``args.debug_dump`` is truthy. Otherwise the text
    returned by ``mod.script(show_meta=True)`` is written to
    ``<args.artifact_path>/debug/<name>`` and the destination is printed.
    """
    if args.debug_dump:
        destination = os.path.join(args.artifact_path, "debug", name)
        with open(destination, "w") as handle:
            handle.write(mod.script(show_meta=True))
        print(f"Dump mod to {destination}")
|
|
|
|
|
def debug_load_script(name, args):
    """Execute a script from the debug folder and return its ``Module``.

    The file ``<args.artifact_path>/debug/<name>`` is compiled and executed in
    a fresh namespace whose ``__file__`` is set to that path; the object the
    script binds to the name ``Module`` is returned.

    Raises:
        OSError: if the file cannot be opened or read.
        KeyError: if the script does not define ``Module``.
    """
    input_path = os.path.join(args.artifact_path, "debug", name)
    lib = {"__file__": input_path}
    # Read through a context manager so the handle is closed deterministically.
    with open(input_path, "rb") as infile:
        source = infile.read()
    # NOTE(security): this runs arbitrary code from the file. It is only
    # appropriate for debug scripts the developer wrote or dumped themselves.
    exec(compile(source, input_path, "exec"), lib, lib)
    return lib["Module"]
|
|
|
|
|
def debug_dump_shader(ex, name, args):
    """Write the source of the built module to the debug folder when enabled.

    Does nothing unless ``args.debug_dump`` is truthy. The file extension is
    chosen from the first default key of ``args.target.kind`` (``.txt`` for
    keys not listed below), and the text comes from
    ``ex.mod.imported_modules[0].imported_modules[0].get_source()``.
    """
    if args.debug_dump:
        kind_key = args.target.kind.default_keys[0]
        extensions = dict(webgpu=".wgsl", cuda=".cu", metal=".mtl", opencl=".cl")
        extension = extensions[kind_key] if kind_key in extensions else ".txt"
        destination = os.path.join(args.artifact_path, "debug", name + extension)
        first_import = ex.mod.imported_modules[0]
        shader_text = first_import.imported_modules[0].get_source()
        with open(destination, "w") as handle:
            handle.write(shader_text)
        print(f"Dump shader to {destination}")
|
|
|
|
|
def mod_transform_before_build(
    mod: tvm.IRModule,
    model_params: List[tvm.nd.NDArray],
    args: argparse.Namespace,
) -> tvm.IRModule:
    """Run the pre-build passes and persist the transformed parameters.

    Unless ``args.no_quantize`` is set, group quantization is applied first
    (group size 40 when the quantization mode ends in ``"3"``, otherwise 32).
    The module then goes through the fusion passes, the default relax
    pipeline, dead-code elimination and parameter lifting, and is split into a
    parameter-transform module and a deploy module. The transform module is
    used to convert ``model_params``, which are saved under
    ``args.artifact_path``.

    Returns:
        The deploy module produced by the split.
    """
    entry_names = ["encoding", "decoding", "create_kv_cache", "softmax_with_temperature"]

    if not args.no_quantize:
        if args.quantization_mode.endswith("3"):
            group_size = 40
        else:
            group_size = 32
        quantize = mlc_llm.transform.GroupQuantize(
            group_size=group_size,
            sym=args.quantization_sym,
            mode=args.quantization_mode,
            storage_nbit=args.quantization_storage_nbit,
            dtype=args.dtype,
        )
        mod = quantize(mod)
    fuse_transpose = mlc_llm.transform.FuseTransposeMatmul()
    mod = fuse_transpose(mod)

    default_pipeline = relax.pipeline.get_pipeline()
    mod = default_pipeline(mod)
    fuse_decode = mlc_llm.transform.FuseDecodeMatmulEwise(args.dtype, args.target_kind)
    mod = fuse_decode(mod)
    eliminate_dead = relax.transform.DeadCodeElimination(entry_names)
    mod = eliminate_dead(mod)
    lift_params = relax.transform.LiftTransformParams()
    mod = lift_params(mod)
    transform_part, deploy_part = utils.split_transform_deploy_mod(mod, entry_names)

    debug_dump_script(transform_part, "mod_lift_params.py", args)

    converted_params = utils.transform_params(transform_part, model_params)
    utils.save_params(converted_params, args.artifact_path)
    return deploy_part
|
|
|
|
|
def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
    """Schedule, compile and export the deploy module as a library.

    For every target kind except ``"cpu"`` the module is first run through the
    MetaSchedule database (loaded from ``args.db_path`` when that path exists,
    otherwise an empty in-memory database), the TIR operator dispatch passes,
    the default GPU schedule and the int32 index narrowing pass. The result is
    built with ``relax.build`` and exported to
    ``<artifact_path>/<model>_<target_kind>_<dtype>.<lib_format>``.

    When ``args.debug_load_script`` is set, the module is replaced by the one
    loaded from ``mod_build_stage_debug.py`` just before building.
    """
    target_kind = args.target_kind
    debug_dump_script(mod_deploy, "mod_before_build.py", args)
    if target_kind != "cpu":
        from tvm import meta_schedule as ms

        database = (
            ms.database.create(work_dir=args.db_path)
            if os.path.exists(args.db_path)
            else ms.database.MemoryDatabase()
        )
        with database, tvm.target.Target("apple/m1-gpu-restricted"):
            apply_database = relax.transform.MetaScheduleApplyDatabase()
            mod_deploy = apply_database(mod_deploy)
            if args.target_kind == "android":
                adreno_dispatch = mlc_llm.dispatch.DispatchTIROperatorAdreno()
                mod_deploy = adreno_dispatch(mod_deploy)
            category_dispatch = mlc_llm.dispatch.DispatchTIROperator(args.model_category)
            mod_deploy = category_dispatch(mod_deploy)
            mod_deploy = tvm.tir.transform.DefaultGPUSchedule()(mod_deploy)
            mod_deploy = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_deploy)

    if args.debug_load_script:
        mod_deploy = debug_load_script("mod_build_stage_debug.py", args)

    debug_dump_script(mod_deploy, "mod_build_stage.py", args)

    executable = relax.build(mod_deploy, args.target, system_lib=args.system_lib)

    # The same stem names both the shader dump and the exported library.
    stem = f"{args.model}_{target_kind}_{args.dtype}"
    library_name = f"{stem}.{args.lib_format}"

    debug_dump_shader(executable, stem, args)
    lib_path = os.path.join(args.artifact_path, library_name)
    executable.export_library(lib_path, **args.export_kwargs)
    print(f"Finish exporting to {lib_path}")
|
|
|
|
|
def dump_split_tir(mod: tvm.IRModule, args=None):
    """Dump the static-shape and dynamic-shape TIR of ``mod`` as two scripts.

    ``utils.split_static_dynamic_tir`` separates ``mod`` into two modules;
    their script text is written to ``mod_tir_static.py`` and
    ``mod_tir_dynamic.py`` inside the artifact directory, wrapped in a small
    header that imports the TVMScript namespaces and disables formatting.

    Args:
        mod: the module to split and dump.
        args: namespace providing ``artifact_path``. When omitted, the
            module-level ``ARGS`` set by the script entry point is used, which
            keeps the original one-argument call working.

    Raises:
        NameError: if ``args`` is omitted and no module-level ``ARGS`` exists
            (for example when this file is imported rather than executed).
    """
    if args is None:
        # Fall back to the global populated under ``if __name__ == "__main__"``.
        args = ARGS
    template = """
from tvm.script import ir as I
from tvm.script import tir as T

# fmt: off
{content}
# fmt: on
"""
    mod_static, mod_dynamic = utils.split_static_dynamic_tir(mod)
    static_path = os.path.join(args.artifact_path, "mod_tir_static.py")
    dynamic_path = os.path.join(args.artifact_path, "mod_tir_dynamic.py")
    print(f"Dump static shape TIR to {static_path}")
    with open(static_path, "w") as o_f:
        o_f.write(template.format(content=mod_static.script()))
    print(f"Dump dynamic shape TIR to {dynamic_path}")
    with open(dynamic_path, "w") as o_f:
        o_f.write(template.format(content=mod_dynamic.script()))
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse options, obtain the pre-build module (freshly
    # traced or loaded from the pickle cache), dump its TIR and build it.
    ARGS = _parse_args()
    os.makedirs(ARGS.artifact_path, exist_ok=True)
    os.makedirs(os.path.join(ARGS.artifact_path, "debug"), exist_ok=True)
    cache_path = os.path.join(
        ARGS.artifact_path, f"mod_cache_before_build_{ARGS.dtype}.pkl"
    )
    # The cache is used only when requested AND a cache file already exists.
    use_cache = ARGS.use_cache and os.path.isfile(cache_path)
    if not use_cache:
        if ARGS.model_category == "llama":
            mod, params = llama.get_model(ARGS)
        elif ARGS.model_category == "gpt_neox":
            mod, params = gpt_neox.get_model(ARGS.model, ARGS.model_path, ARGS.dtype)
        elif ARGS.model_category == "moss":
            mod, params = moss.get_model(ARGS)
        else:
            raise ValueError(f"Model {ARGS.model} not supported")
        mod = mod_transform_before_build(mod, params, ARGS)
        with open(cache_path, "wb") as outfile:
            pickle.dump(mod, outfile)
        print(f"Save a cached module to {cache_path}.")
    else:
        print(
            f"Load cached module from {cache_path} and skip tracing. "
            "You can use --use-cache=0 to retrace"
        )
        # NOTE(security): unpickling executes code embedded in the file; the
        # cache must only ever be a file this script wrote itself.
        # Open via a context manager so the handle is closed after loading.
        with open(cache_path, "rb") as infile:
            mod = pickle.load(infile)
    dump_split_tir(mod)
    build(mod, ARGS)
|
|