chen459664's picture
Add files using upload-large-folder tool
64ddf8d verified
import gc
import glob
import os
from datetime import datetime
from pathlib import Path
import pandas as pd
import torch
from hqq.core.quantize import BaseQuantizeConfig as HQQQuantConfig
from hqq.utils.optimizer import find_optimal_configs
from lm_quant_toolkit.utils.hub import LLAMA_MODELS, VIT_OPENCLIP_MODELS
# Labelled HQQ quantization configs, ordered from the widest bit-width to the
# narrowest. Each entry is ("b<nbits>g<group_size>", config); every config is
# built with quant_scale=True.
HQQ_CONFIGS = [
    (
        f"b{bits}g{group}",
        HQQQuantConfig(nbits=bits, group_size=group, quant_scale=True),
    )
    for bits, groups in (
        (8, (32, 64, 128)),
        (4, (32, 64, 128)),
        (3, (32, 64, 128)),
        (2, (16, 32, 64)),
    )
    for group in groups
]
def get_mxq_quant_meta_data_file(model_id):
    """Locate the ``fnorm-<model>.csv`` file that belongs to ``model_id``.

    The file is looked up in the ``data`` directory two levels above the
    directory containing this module. ``model_id`` must contain a ``/``; the
    segment right after the first ``/`` is used as the model's short name.

    Returns:
        A ``(exists, path)`` tuple: whether the file is present on disk, and
        its absolute path.
    """
    model_name = model_id.split("/")[1]
    module_dir = os.path.dirname(__file__)
    csv_path = os.path.join(
        module_dir, "..", "..", "data", f"fnorm-{model_name}.csv"
    )
    return os.path.exists(csv_path), os.path.abspath(csv_path)
def persist_progress(df, progress_path):
    """Write ``df`` to ``progress_path`` as CSV so an experiment can be resumed.

    Missing parent directories are created first; the index is not written.
    """
    parent_dir = Path(progress_path).parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    df.to_csv(progress_path, index=False)
def save_partial_metric(experiment_name, algo, model_id, config, metric, result_dir):
    """Store one metric record as a single-row CSV for the given experiment.

    The file is written to
    ``<result_dir>/<experiment_name>/partial-<algo>-<model>-<config>.csv``,
    where ``<model>`` is the segment of ``model_id`` after its first ``/``.
    The experiment directory is created if it does not exist.
    """
    frame = pd.DataFrame([metric])
    experiment_dir = os.path.join(result_dir, experiment_name)
    os.makedirs(experiment_dir, exist_ok=True)
    short_name = model_id.split("/")[1]
    target = f"{experiment_dir}/partial-{algo}-{short_name}-{config}.csv"
    frame.to_csv(target, index=False)
def _dump_cuda_mem_snapshot(experiment_name, model_id, algo, result_dir):
    """Dump a CUDA memory snapshot for one (algorithm, model) pair.

    The snapshot is written with ``torch.cuda.memory._dump_snapshot`` (a
    private PyTorch API) to
    ``<result_dir>/<experiment_name>/mem-snapshot-<algo>-<model>.pickle``,
    where ``<model>`` is the segment of ``model_id`` after its first ``/``.
    """
    mem_fp = f"{result_dir}/{experiment_name}/mem-snapshot-{algo}-{model_id.split('/')[1]}.pickle"
    # Create the experiment directory up front, as the other writers in this
    # module do, so the dump does not fail when it is the first file written.
    Path(mem_fp).parent.mkdir(parents=True, exist_ok=True)
    torch.cuda.memory._dump_snapshot(mem_fp)
def combine_metrics(experiment_name, result_dir):
    """Merge every ``partial-*.csv`` of an experiment into one result CSV.

    Partial files are read from ``<result_dir>/<experiment_name>/`` in sorted
    file-name order, so the row order of the combined file is reproducible.
    The output is written to the same directory as
    ``result-<experiment_name>-<YYYYmmddHHMMSS>.csv``.

    Raises:
        ValueError: if the experiment directory contains no partial files.
    """
    experiment_dir = f"{result_dir}/{experiment_name}"
    # glob returns matches in arbitrary order; sort for a stable result.
    partial_files = sorted(glob.glob(f"{experiment_dir}/partial-*.csv"))
    if not partial_files:
        raise ValueError(f"No partial metric files found in {experiment_dir}")
    combined = pd.concat([pd.read_csv(fp) for fp in partial_files])
    ts_str = datetime.now().strftime("%Y%m%d%H%M%S")
    file_name = f"{experiment_dir}/result-{experiment_name}-{ts_str}.csv"
    Path(file_name).parent.mkdir(parents=True, exist_ok=True)
    combined.to_csv(file_name, index=False)
def cleanup(model):
    """Drop this function's reference to ``model`` and free cached memory.

    ``del model`` only removes the local name; the caller must drop its own
    references for the model to become collectable. Garbage collection runs
    before ``torch.cuda.empty_cache()`` so that memory belonging to objects
    collected here is already unoccupied when the cache is emptied.
    """
    del model
    gc.collect()
    torch.cuda.empty_cache()
def calc_bits(b1, g1, b2=8, g2=128):
    """Return ``b1 + 2 * b2 / g1 + 32 / g1 / g2``.

    Within this module ``b1`` and ``g1`` receive a config's weight ``nbits``
    and ``group_size``. ``b2`` and ``g2`` are presumably the bit-width and
    group size used for the per-group metadata (TODO confirm).
    """
    per_group_overhead = 2 * b2 / g1
    second_level_overhead = 32 / g1 / g2
    return b1 + per_group_overhead + second_level_overhead
def get_mxq_bits(reduction_pcts=(3, 5, 8)):
    """Return the distinct reduced bit budgets derived from ``HQQ_CONFIGS``.

    For every config, the bits-per-parameter value is computed with
    ``calc_bits`` (using 8 and 128 for its last two arguments) and then
    reduced by each percentage in ``reduction_pcts``, rounded to two decimals.

    Args:
        reduction_pcts: Percentages to shave off each config's budget. The
            default is a tuple so no mutable object is shared between calls.

    Returns:
        The unique budgets, sorted from largest to smallest.
    """
    budgets = set()
    for _, quant_config in HQQ_CONFIGS:
        weight_params = quant_config["weight_quant_params"]
        bpp = calc_bits(
            weight_params["nbits"],
            weight_params["group_size"],
            8,
            128,
        )
        budgets.update(round(bpp * (1 - pct / 100), 2) for pct in reduction_pcts)
    return sorted(budgets, reverse=True)
def _reset_peak_memory_stats():
    """Call ``torch.cuda.reset_peak_memory_stats()`` and return its result."""
    outcome = torch.cuda.reset_peak_memory_stats()
    return outcome
def get_memory_metrics():
    """Return ``(max_memory_allocated, max_memory_reserved)`` from ``torch.cuda``."""
    peak_allocated = torch.cuda.max_memory_allocated()
    peak_reserved = torch.cuda.max_memory_reserved()
    return peak_allocated, peak_reserved
def plan_eval_bit_budgets(
    model_arch="ViT",
    points=5,
    step=1,
    bases=(4.51,),
    include_base=False,
):
    """Print the ideal and solvable bit budgets around each base budget.

    For every value in ``bases`` the plan is obtained from ``get_eval_plan``
    and printed between two separator lines, one "ideal/solvable" pair per
    line.

    Args:
        model_arch: Passed through to ``get_eval_plan``.
        points: Number of steps away from each base to evaluate.
        step: Distance between consecutive budgets; may be negative.
        bases: Iterable of base budgets. The default is a tuple so no mutable
            object is shared between calls.
        include_base: Whether the base budget itself is part of the plan.
    """
    separator = "*" * 72
    for base in bases:
        ideals, solvables = get_eval_plan(model_arch, base, points, step, include_base)
        print(separator)
        print(f"base: {base}")
        for ideal, solvable in zip(ideals, solvables):
            print(f"ideal: {ideal:.2f}, solvable: {solvable:.2f}")
        print(separator)
def get_eval_plan(model_arch, base, points, step, include_base):
    """Build parallel lists of ideal and solvable budgets around ``base``.

    The ideal budgets are ``base + k * step`` rounded to two decimals, for
    ``k`` running up to ``points`` and starting at 0 when ``include_base`` is
    truthy, otherwise at 1. Each ideal budget is passed to ``try_solvable``;
    when that returns ``None`` the solvable entry is recorded as ``0.0``.

    Returns:
        ``(ideals, solvables)``, two lists of equal length.
    """
    if include_base:
        first_point = 0
    else:
        first_point = 1
    ideals, solvables = [], []
    for offset in range(first_point, points + 1):
        target = round(base + offset * step, 2)
        ideals.append(target)
        found = try_solvable(model_arch, target, step)
        solvables.append(0.0 if found is None else found)
    return ideals, solvables
def try_solvable(model_arch, bit_budget, step):
    """Find a bit budget near ``bit_budget`` that is solvable for all models.

    The models are the keys of ``VIT_OPENCLIP_MODELS`` when ``model_arch`` is
    ``"ViT"``, otherwise the ``LLAMA_MODELS`` entries whose ``"experiment"``
    flag is truthy. Each model is tried with ``find_optimal_configs``. On a
    ``ValueError`` the candidate budget is nudged by 0.01, in the direction of
    the sign of ``step``, and retried. The candidate is shared across models,
    so nudges made for one model carry over to the next.

    If the models ended up solved at different budgets, those more than 0.01
    away from the final budget are re-checked at the final budget.

    Returns:
        The final budget, or ``None`` if a model failed four attempts in a
        row or a re-check failed.
    """
    if model_arch == "ViT":
        model_ids = VIT_OPENCLIP_MODELS.keys()
    else:
        model_ids = [
            key for key in LLAMA_MODELS if LLAMA_MODELS[key].get("experiment", False)
        ]
    solved_budgets = {}
    feasible_budget = round(bit_budget, 2)
    nudge = -0.01 if step < 0 else 0.01
    for model_id in model_ids:
        attempts = 1
        _, fp = get_mxq_quant_meta_data_file(model_id)
        while True:
            try:
                find_optimal_configs(
                    fp,
                    feasible_budget,
                    time_limit=200,
                    weight_algo="sensi-milp",
                )
            except ValueError:
                print(f"Warning: {feasible_budget:.2f} unsolvable for model {model_id}")
                if attempts > 3:
                    return None
                feasible_budget += nudge
                attempts += 1
            else:
                solved_budgets[model_id] = feasible_budget
                break
    if len(set(solved_budgets.values())) > 1:
        for model_id, budget in solved_budgets.items():
            if abs(budget - feasible_budget) > 0.01:
                # The helper returns (exists, path); only the path goes to the
                # solver. The re-check uses the same weight_algo as the first
                # pass so both passes test the same optimization problem.
                _, fp = get_mxq_quant_meta_data_file(model_id)
                try:
                    find_optimal_configs(
                        fp,
                        feasible_budget,
                        time_limit=200,
                        weight_algo="sensi-milp",
                    )
                except ValueError:
                    print(
                        f"Warning: {feasible_budget:.2f} unsolvable for model {model_id}"
                    )
                    return None
    return feasible_budget
def debug_milp_solvable(model_id, bit_budget):
    """Print what ``find_optimal_configs`` returns for one model and budget.

    The solver is run on the model's ``fnorm`` metadata file with a time limit
    of 200; any exception it raises propagates to the caller.
    """
    meta_path = get_mxq_quant_meta_data_file(model_id)[1]
    print(find_optimal_configs(meta_path, bit_budget, time_limit=200))
def plan_432_bits():
    """Print evaluation plans around the 4-, 3- and 2-bit base budgets.

    The 4-bit bases are explored with 5 points and the 3-/2-bit bases with 3
    points. Each group is planned first with step +0.02 and then with step
    -0.02, always including the base itself.
    """
    schedule = (
        ([4.51, 4.25, 4.13], 5),
        ([3.51, 3.25, 3.13, 2.51, 2.25, 2.13], 3),
    )
    for bases, points in schedule:
        for step in (0.02, -0.02):
            plan_eval_bit_budgets(
                model_arch="llm",
                points=points,
                step=step,
                bases=bases,
                include_base=True,
            )
def plan_567_bits():
    """Print evaluation plans around budgets 10-40% below ``calc_bits(8, 32, 8, 128)``.

    Each base is planned with 5 points, first with step +0.02 and then with
    step -0.02, always including the base itself.
    """
    top_budget = calc_bits(8, 32, 8, 128)
    bases = []
    for saving_pct in (10, 20, 30, 40):
        bases.append(top_budget * (100 - saving_pct) / 100)
    for step in (0.02, -0.02):
        plan_eval_bit_budgets(
            model_arch="llm",
            points=5,
            step=step,
            bases=bases,
            include_base=True,
        )
def fill_budget_gap(start, stop, step=0.02):
    """Return the budgets strictly between ``start`` and ``stop``.

    Candidates are ``start + k * step`` for ``k = 1, 2, ...``, rounded to two
    decimals. Each candidate is computed from ``start`` directly rather than
    by repeated addition, and the rounded value is compared with ``stop``.
    This keeps floating-point drift from admitting a value that rounds to
    ``stop`` itself.

    Args:
        start: Exclusive lower bound.
        stop: Exclusive upper bound.
        step: Positive spacing between consecutive budgets.

    Returns:
        The rounded budgets in ascending order; empty when none fit.

    Raises:
        ValueError: if ``step`` is not positive, since the sequence would
            never reach ``stop``.
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")
    result = []
    index = 1
    value = round(start + index * step, 2)
    while value < stop:
        result.append(value)
        index += 1
        value = round(start + index * step, 2)
    return result
def fill_gaps():
    """Print the budgets filling the 6.01-6.70 gap and plan their evaluation.

    With ``points=0`` and ``include_base=True`` each gap value is evaluated
    on its own, with no neighbouring points.
    """
    # Other candidate ranges (currently disabled):
    #   (3.57, 4.03), (3.29, 3.45), (6.92, 7.56), (4.61, 5.00), (5.20, 5.86)
    #   and the explicit points [4.39, 4.37].
    gaps = list(fill_budget_gap(6.01, 6.70))
    print(gaps)
    plan_eval_bit_budgets(
        model_arch="llm",
        points=0,
        step=0.02,
        bases=gaps,
        include_base=True,
    )
if __name__ == "__main__":
    # fill_gaps() is an alternative entry point; it is not run here.
    for run_plan in (plan_432_bits, plan_567_bits):
        run_plan()