# blingBillie's picture
# Upload 138 files
# d443007
import argparse
import os
import pickle
from platform import system
from typing import List
import tvm
import tvm.testing
from tvm import relax
import mlc_llm
from mlc_llm import utils
from mlc_llm.relax_model import gpt_neox, llama, moss
def _parse_args():
    """Build the command-line parser, parse the arguments and derive extra fields.

    ``utils.argparse_add_common`` registers further options on the parser
    before the ones listed here; the ``model`` and ``dtype`` attributes read
    below are presumably among them (not visible in this file).

    Returns:
        argparse.Namespace: The parsed options with these attributes added
        or rewritten:

        * ``model_path``: ``<artifact-path>/models/<model>``
        * ``artifact_path``: ``<artifact-path>/<model>/<dtype>``
        * ``export_kwargs``: an empty dict
        * ``lib_format``: ``"so"``
        * ``db_path``: the given value, or ``log_db/<model>`` when unset

        ``utils.parse_target`` and ``utils.argparse_postproc_common`` are
        then applied to the namespace before it is returned.

    Raises:
        SystemExit: Raised by argparse for an invalid command line, including
            a ``--max-seq-len`` that is neither ``-1`` nor a positive integer.
    """
    parser = argparse.ArgumentParser()
    utils.argparse_add_common(parser)
    parser.add_argument("--quantization-sym", action="store_true", default=False)
    parser.add_argument(
        "--quantization-mode", type=str, choices=["int4", "int3", "fp4"], default="int4"
    )
    parser.add_argument(
        "--quantization-storage-nbit", type=int, choices=[32, 16, 8], default=32
    )
    parser.add_argument("--no-quantize", action="store_true", default=False)
    parser.add_argument("--max-seq-len", type=int, default=-1)
    parser.add_argument("--target", type=str, default="auto")
    parser.add_argument(
        "--db-path",
        type=str,
        default=None,
        help="Path to log database. Default: ./log_db/{model}",
    )
    parser.add_argument("--artifact-path", type=str, default="dist")
    parser.add_argument(
        "--use-cache",
        type=int,
        default=1,
        help="Whether to use previously pickled IRModule and skip trace.",
    )
    parser.add_argument("--debug-dump", action="store_true", default=False)
    parser.add_argument("--debug-load-script", action="store_true", default=False)
    parser.add_argument(
        "--llvm-mingw",
        type=str,
        default="",
        help="/path/to/llvm-mingw-root, use llvm-mingw to cross compile to windows",
    )
    parser.add_argument("--system-lib", action="store_true", default=False)
    parsed = parser.parse_args()

    # Validate explicitly instead of with `assert`: assertions are removed
    # under `python -O`, and parser.error() prints the usage text for the user.
    if parsed.max_seq_len != -1 and parsed.max_seq_len <= 0:
        parser.error("--max-seq-len must be -1 or a positive integer")

    # model_path is built from the *original* artifact path, so it has to be
    # computed before artifact_path is rewritten just below.
    parsed.model_path = os.path.join(parsed.artifact_path, "models", parsed.model)
    parsed.artifact_path = os.path.join(
        parsed.artifact_path, parsed.model, parsed.dtype
    )
    parsed.export_kwargs = {}
    parsed.lib_format = "so"
    parsed.db_path = parsed.db_path or os.path.join("log_db", parsed.model)
    utils.parse_target(parsed)
    utils.argparse_postproc_common(parsed)
    return parsed
def debug_dump_script(mod, name, args):
    """Write the script text of ``mod`` into the debug folder (debug-dump mode).

    Nothing happens unless ``args.debug_dump`` is truthy. Otherwise the text
    returned by ``mod.script(show_meta=True)`` is written to
    ``<args.artifact_path>/debug/<name>`` and the destination is printed.
    """
    if args.debug_dump:
        target_file = os.path.join(args.artifact_path, "debug", name)
        with open(target_file, "w") as handle:
            handle.write(mod.script(show_meta=True))
        print(f"Dump mod to {target_file}")
def debug_load_script(name, args):
    """Load the ``Module`` object defined by a Python script in the debug folder.

    The file ``<args.artifact_path>/debug/<name>`` is compiled and executed in
    a fresh namespace whose ``__file__`` is set to that path; the object the
    script binds to the name ``Module`` is returned.

    Args:
        name: File name of the script inside the debug folder.
        args: Namespace providing ``artifact_path``.

    Returns:
        The object bound to ``Module`` by the executed script.

    Raises:
        OSError: If the script file cannot be opened.
        KeyError: If the script does not define ``Module``.
    """
    input_path = os.path.join(args.artifact_path, "debug", name)
    # Read via a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(input_path, "rb") as infile:
        source = infile.read()
    lib = {"__file__": input_path}
    # NOTE: this runs arbitrary code from the file; only load trusted scripts.
    exec(compile(source, input_path, "exec"), lib, lib)
    return lib["Module"]
def debug_dump_shader(ex, name, args):
    """Write the source of a built executable into the debug folder (debug-dump mode).

    Nothing happens unless ``args.debug_dump`` is truthy. The file extension
    is picked from the first default key of ``args.target.kind`` (``.txt``
    for kinds not listed below), and the source is taken from
    ``ex.mod.imported_modules[0].imported_modules[0]``.
    """
    if not args.debug_dump:
        return

    backend = args.target.kind.default_keys[0]
    extensions = dict(webgpu=".wgsl", cuda=".cu", metal=".mtl", opencl=".cl")
    extension = extensions[backend] if backend in extensions else ".txt"
    shader_path = os.path.join(args.artifact_path, "debug", name + extension)

    # The source is fetched before the output file is opened.
    nested_module = ex.mod.imported_modules[0].imported_modules[0]
    shader_source = nested_module.get_source()
    with open(shader_path, "w") as handle:
        handle.write(shader_source)
    print(f"Dump shader to {shader_path}")
def mod_transform_before_build(
    mod: tvm.IRModule,
    model_params: List[tvm.nd.NDArray],
    args: argparse.Namespace,
) -> tvm.IRModule:
    """First stage: run the module-level passes, then transform and save params.

    Unless ``args.no_quantize`` is set, ``GroupQuantize`` is applied first.
    The remaining passes run in a fixed order, after which the module is
    split by ``utils.split_transform_deploy_mod``. The transform half is used
    to convert ``model_params`` (saved under ``args.artifact_path``); the
    deploy half is returned.
    """
    entry_functions = [
        "encoding",
        "decoding",
        "create_kv_cache",
        "softmax_with_temperature",
    ]

    if not args.no_quantize:
        # Modes whose name ends in "3" use a group size of 40, all others 32.
        if args.quantization_mode.endswith("3"):
            group_size = 40
        else:
            group_size = 32
        quantize_pass = mlc_llm.transform.GroupQuantize(
            group_size=group_size,
            sym=args.quantization_sym,
            mode=args.quantization_mode,
            storage_nbit=args.quantization_storage_nbit,
            dtype=args.dtype,
        )
        mod = quantize_pass(mod)

    # Each pass is created right before it is applied; the order is unchanged.
    fuse_transpose = mlc_llm.transform.FuseTransposeMatmul()
    mod = fuse_transpose(mod)
    default_pipeline = relax.pipeline.get_pipeline()
    mod = default_pipeline(mod)
    fuse_decode = mlc_llm.transform.FuseDecodeMatmulEwise(args.dtype, args.target_kind)
    mod = fuse_decode(mod)
    eliminate_dead_code = relax.transform.DeadCodeElimination(entry_functions)
    mod = eliminate_dead_code(mod)
    lift_params = relax.transform.LiftTransformParams()
    mod = lift_params(mod)

    split_result = utils.split_transform_deploy_mod(mod, entry_functions)
    transform_part, deploy_part = split_result
    debug_dump_script(transform_part, "mod_lift_params.py", args)

    converted_params = utils.transform_params(transform_part, model_params)
    utils.save_params(converted_params, args.artifact_path)
    return deploy_part
def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
    """Schedule the deploy module, build it with Relax and export the library.

    For every target kind other than ``"cpu"`` the module goes through the
    MetaSchedule database application, the TIR dispatch passes, the default
    GPU schedule and the int32 index narrowing. The built library is written
    to ``<args.artifact_path>/<model>_<target_kind>_<dtype>.<lib_format>``.
    """
    kind = args.target_kind
    debug_dump_script(mod_deploy, "mod_before_build.py", args)

    if kind != "cpu":
        from tvm import meta_schedule

        # Use the on-disk database when its directory exists, otherwise an
        # in-memory one.
        if os.path.exists(args.db_path):
            database = meta_schedule.database.create(work_dir=args.db_path)
        else:
            database = meta_schedule.database.MemoryDatabase()

        # NOTE(review): the target context is hard-coded rather than taken
        # from args.target -- presumably intentional; confirm before changing.
        schedule_target = tvm.target.Target("apple/m1-gpu-restricted")
        with database, schedule_target:
            apply_database = relax.transform.MetaScheduleApplyDatabase()
            mod_deploy = apply_database(mod_deploy)
            if args.target_kind == "android":
                adreno_dispatch = mlc_llm.dispatch.DispatchTIROperatorAdreno()
                mod_deploy = adreno_dispatch(mod_deploy)
            category_dispatch = mlc_llm.dispatch.DispatchTIROperator(args.model_category)
            mod_deploy = category_dispatch(mod_deploy)
            gpu_schedule = tvm.tir.transform.DefaultGPUSchedule()
            mod_deploy = gpu_schedule(mod_deploy)
            narrow_index = tvm.tir.transform.ForceNarrowIndexToInt32()
            mod_deploy = narrow_index(mod_deploy)

    # Optionally replace the module with one loaded from a debug script.
    if args.debug_load_script:
        mod_deploy = debug_load_script("mod_build_stage_debug.py", args)

    debug_dump_script(mod_deploy, "mod_build_stage.py", args)

    executable = relax.build(mod_deploy, args.target, system_lib=args.system_lib)

    library_name = f"{args.model}_{kind}_{args.dtype}.{args.lib_format}"
    debug_dump_shader(executable, f"{args.model}_{kind}_{args.dtype}", args)
    library_path = os.path.join(args.artifact_path, library_name)
    executable.export_library(library_path, **args.export_kwargs)
    print(f"Finish exporting to {library_path}")
def dump_split_tir(mod: tvm.IRModule):
    """Dump the static-shape and dynamic-shape TIR parts of ``mod`` as scripts.

    The module is split with ``utils.split_static_dynamic_tir``; each part is
    written, wrapped in a small header/footer template, to
    ``mod_tir_static.py`` / ``mod_tir_dynamic.py`` under the artifact path of
    the module-level ``ARGS``.
    """
    # Same text as a triple-quoted block: leading newline, two imports, then
    # the content enclosed in "fmt: off" / "fmt: on" markers.
    template = (
        "\n"
        "from tvm.script import ir as I\n"
        "from tvm.script import tir as T\n"
        "# fmt: off\n"
        "{content}\n"
        "# fmt: on\n"
    )

    static_part, dynamic_part = utils.split_static_dynamic_tir(mod)
    static_path = os.path.join(ARGS.artifact_path, "mod_tir_static.py")
    dynamic_path = os.path.join(ARGS.artifact_path, "mod_tir_dynamic.py")

    jobs = (
        (f"Dump static shape TIR to {static_path}", static_path, static_part),
        (f"Dump dynamic shape TIR to {dynamic_path}", dynamic_path, dynamic_part),
    )
    for message, destination, part in jobs:
        print(message)
        with open(destination, "w") as sink:
            sink.write(template.format(content=part.script()))
if __name__ == "__main__":
    # Script entry point: parse options, obtain the deploy module (traced
    # fresh or loaded from the pickle cache), dump its TIR and build it.
    # ARGS stays a module-level name because dump_split_tir reads it.
    ARGS = _parse_args()
    os.makedirs(ARGS.artifact_path, exist_ok=True)
    os.makedirs(os.path.join(ARGS.artifact_path, "debug"), exist_ok=True)
    cache_path = os.path.join(
        ARGS.artifact_path, f"mod_cache_before_build_{ARGS.dtype}.pkl"
    )
    # The cache is only used when requested AND the cache file actually exists.
    use_cache = ARGS.use_cache and os.path.isfile(cache_path)
    if not use_cache:
        if ARGS.model_category == "llama":
            mod, params = llama.get_model(ARGS)
        elif ARGS.model_category == "gpt_neox":
            mod, params = gpt_neox.get_model(ARGS.model, ARGS.model_path, ARGS.dtype)
        elif ARGS.model_category == "moss":
            mod, params = moss.get_model(ARGS)
        else:
            raise ValueError(f"Model {ARGS.model} not supported")
        mod = mod_transform_before_build(mod, params, ARGS)
        with open(cache_path, "wb") as outfile:
            pickle.dump(mod, outfile)
        print(f"Save a cached module to {cache_path}.")
    else:
        print(
            f"Load cached module from {cache_path} and skip tracing. "
            "You can use --use-cache=0 to retrace"
        )
        # Open through a context manager so the handle is closed right after
        # loading instead of being leaked.
        # NOTE: unpickling can execute arbitrary code; the cache file must
        # come from a trusted earlier run of this script.
        with open(cache_path, "rb") as infile:
            mod = pickle.load(infile)
    dump_split_tir(mod)
    build(mod, ARGS)