File size: 2,279 Bytes
64ddf8d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 | import pandas as pd
from hqq.utils.optimizer import find_optimal_configs
from lm_quant_toolkit.eval.common import get_mxq_quant_meta_data_file
def dump_mxq_objectives(model_ids, bit_budgets, csv_fp="mxq-objectives.csv"):
    """Write the solver objective for every (model, bit budget) pair to a CSV.

    For each model the quantization metadata file is located with
    ``get_mxq_quant_meta_data_file`` and ``find_optimal_configs`` is run once
    per bit budget with ``time_limit=200``. Only the objective value returned
    by the solver is kept; it is stored in the ``fnorm`` column.

    Args:
        model_ids: Iterable of model identifiers, normally ``"org/name"``.
        bit_budgets: Iterable of bit budgets handed to the solver. Each one is
            stored in the ``bpp`` column.
        csv_fp: Path of the CSV file to write.

    The CSV has the columns ``model``, ``bpp`` and ``fnorm``. The header row
    is written even when there are no data rows.
    """
    # Pin the columns so an empty result still yields a CSV with a header.
    columns = ["model", "bpp", "fnorm"]
    rows = []
    for model_id in model_ids:
        # Keep the segment after the first "/" ("org/name" -> "name"); an id
        # without a "/" is used unchanged instead of raising IndexError.
        parts = model_id.split("/")
        short_id = parts[1] if len(parts) > 1 else model_id
        _, fp = get_mxq_quant_meta_data_file(model_id)
        for bit_budget in bit_budgets:
            _, objective = find_optimal_configs(fp, bit_budget, time_limit=200)
            rows.append(
                {
                    "model": short_id,
                    "bpp": bit_budget,
                    "fnorm": objective,
                }
            )
    pd.DataFrame(rows, columns=columns).to_csv(csv_fp, index=False)
def dump_mxq_configs(
    model_ids,
    bit_budgets,
    csv_fp="mxq-configs.csv",
    weight_algo=None,
    factor=None,
):
    """Write the per-module configurations chosen by the solver to a CSV.

    For each model the quantization metadata file is located with
    ``get_mxq_quant_meta_data_file`` and ``find_optimal_configs`` is run once
    per bit budget with ``time_limit=200``. Every entry of the returned
    configuration mapping becomes one CSV row.

    Args:
        model_ids: Iterable of model identifiers, normally ``"org/name"``.
        bit_budgets: Iterable of bit budgets handed to the solver.
        csv_fp: Path of the CSV file to write.
        weight_algo: Forwarded to ``find_optimal_configs`` as ``weight_algo``.
            ``None`` (the default) leaves the argument out so the solver
            applies its own default.
        factor: Forwarded to ``find_optimal_configs`` as ``factor``. ``None``
            (the default) leaves the argument out.

    A ``ValueError`` raised by ``find_optimal_configs`` is reported as a
    warning on stdout and that bit budget is skipped for the model.

    The CSV has the columns ``model``, ``module``, ``layer``, ``memmb``,
    ``bit_budget``, ``b1``, ``g1``, ``b2`` and ``g2``. The header row is
    written even when there are no data rows.
    """
    # Pin the columns so an empty result still yields a CSV with a header.
    columns = [
        "model",
        "module",
        "layer",
        "memmb",
        "bit_budget",
        "b1",
        "g1",
        "b2",
        "g2",
    ]
    # Forward only the options the caller actually supplied.
    solver_kwargs = {}
    if weight_algo is not None:
        solver_kwargs["weight_algo"] = weight_algo
    if factor is not None:
        solver_kwargs["factor"] = factor

    rows = []
    for model_id in model_ids:
        # Keep the segment after the first "/" ("org/name" -> "name"); an id
        # without a "/" is used unchanged instead of raising IndexError.
        parts = model_id.split("/")
        short_id = parts[1] if len(parts) > 1 else model_id
        # The metadata file depends only on the model, so look it up once.
        _, fp = get_mxq_quant_meta_data_file(model_id)
        for bit_budget in bit_budgets:
            try:
                configs, _ = find_optimal_configs(
                    fp,
                    bit_budget,
                    time_limit=200,
                    **solver_kwargs,
                )
            except ValueError:
                print(f"Warning: {bit_budget:.2f} unsolvable for model {model_id}")
                continue
            for name, cfg in configs.items():
                # Keys look like "<layer>.<module>"; split on the first dot.
                comps = name.split(".", 1)
                layer, module = comps[0], comps[1]
                rows.append(
                    {
                        "model": short_id,
                        "module": module,
                        "layer": layer,
                        "memmb": 0,  # to work with the plot program
                        "bit_budget": bit_budget,
                        # NOTE(review): presumably bits / group sizes of the
                        # two quantization levels - confirm against the solver.
                        "b1": cfg[0],
                        "g1": cfg[1],
                        "b2": cfg[2],
                        "g2": cfg[3],
                    }
                )
    pd.DataFrame(rows, columns=columns).to_csv(csv_fp, index=False)
if __name__ == "__main__":
    # Script entry point: dump the MXQ configurations of one Llama-2-7b
    # checkpoint for a handful of bit budgets.
    # NOTE(review): only model_ids and bit_budgets are passed here; confirm
    # that dump_mxq_configs has defaults for csv_fp, weight_algo and factor,
    # otherwise this call raises TypeError.
    model_ids = ["meta-llama/Llama-2-7b-hf"]
    bit_budgets = [4.51, 4.25, 4.13]
    dump_mxq_configs(model_ids, bit_budgets)
|