File size: 2,279 Bytes
64ddf8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import pandas as pd
from hqq.utils.optimizer import find_optimal_configs

from lm_quant_toolkit.eval.common import get_mxq_quant_meta_data_file


def dump_mxq_objectives(model_ids, bit_budgets, csv_fp="mxq-objectives.csv"):
    """Write the solver objective of every (model, bit budget) pair to a CSV.

    For each model the quantization meta-data file is looked up once via
    ``get_mxq_quant_meta_data_file``; ``find_optimal_configs`` is then run
    once per bit budget with a time limit of 200. Every solved pair yields
    one row with the columns ``model``, ``bpp`` and ``fnorm``.

    Args:
        model_ids: Model identifiers such as ``"meta-llama/Llama-2-7b-hf"``.
            The segment after the first ``/`` becomes the ``model`` column;
            an identifier without ``/`` is used unchanged.
        bit_budgets: Bit budgets passed to the solver and stored in the
            ``bpp`` column. It is iterated once per model, so pass a
            re-iterable collection (list/tuple), not a one-shot generator.
        csv_fp: Path of the CSV file to write (without the index column).

    Notes:
        * A ``ValueError`` raised by the solver is reported as an unsolvable
          budget and skipped, the same way ``dump_mxq_configs`` handles it,
          so one bad budget does not discard the rows already computed.
        * The header line is always written, even when no row was
          produced, so the file stays readable by ``pandas.read_csv``.
    """
    columns = ["model", "bpp", "fnorm"]
    rows = []
    for model_id in model_ids:
        # Keep the historical "second path segment" naming, but do not crash
        # on identifiers that carry no organisation prefix.
        id_parts = model_id.split("/")
        short_id = id_parts[1] if len(id_parts) > 1 else model_id
        _, fp = get_mxq_quant_meta_data_file(model_id)
        for bit_budget in bit_budgets:
            try:
                _, objective = find_optimal_configs(fp, bit_budget, time_limit=200)
            except ValueError:
                print(f"Warning: {bit_budget:.2f} unsolvable for model {model_id}")
                continue
            rows.append(
                {
                    "model": short_id,
                    "bpp": bit_budget,
                    # NOTE(review): stored under "fnorm"; presumably a
                    # Frobenius-norm based objective -- confirm with solver.
                    "fnorm": objective,
                }
            )

    # Explicit columns keep the header present when ``rows`` is empty.
    pd.DataFrame(rows, columns=columns).to_csv(csv_fp, index=False)


def dump_mxq_configs(model_ids, bit_budgets, csv_fp, weight_algo, factor):
    """Write the per-module configurations chosen by the solver to a CSV.

    For each model the quantization meta-data file is looked up once via
    ``get_mxq_quant_meta_data_file``; ``find_optimal_configs`` is then run
    once per bit budget with a time limit of 200 and the given
    ``weight_algo`` / ``factor``. Each entry of the returned mapping
    becomes one row with the columns ``model``, ``module``, ``layer``,
    ``memmb``, ``bit_budget``, ``b1``, ``g1``, ``b2`` and ``g2``.

    Args:
        model_ids: Model identifiers such as ``"meta-llama/Llama-2-7b-hf"``.
            The segment after the first ``/`` becomes the ``model`` column;
            an identifier without ``/`` is used unchanged.
        bit_budgets: Bit budgets passed to the solver. It is iterated once
            per model, so pass a re-iterable collection (list/tuple).
        csv_fp: Path of the CSV file to write (without the index column).
        weight_algo: Forwarded unchanged to ``find_optimal_configs``.
        factor: Forwarded unchanged to ``find_optimal_configs``.

    Notes:
        * A ``ValueError`` raised by the solver is reported as an unsolvable
          budget and that budget is skipped.
        * The header line is always written, even when no row was
          produced, so the file stays readable by ``pandas.read_csv``.
    """
    columns = [
        "model",
        "module",
        "layer",
        "memmb",
        "bit_budget",
        "b1",
        "g1",
        "b2",
        "g2",
    ]
    solver_kwargs = {"weight_algo": weight_algo, "factor": factor}
    rows = []
    for model_id in model_ids:
        # Keep the historical "second path segment" naming, but do not crash
        # on identifiers that carry no organisation prefix.
        id_parts = model_id.split("/")
        short_id = id_parts[1] if len(id_parts) > 1 else model_id
        # The meta-data file depends only on the model, so resolve it once
        # per model and outside the ``try``: a failure here is not an
        # "unsolvable budget" and must not be reported as one.
        _, fp = get_mxq_quant_meta_data_file(model_id)
        for bit_budget in bit_budgets:
            try:
                configs, _ = find_optimal_configs(
                    fp,
                    bit_budget,
                    time_limit=200,
                    **solver_kwargs,
                )
            except ValueError:
                print(f"Warning: {bit_budget:.2f} unsolvable for model {model_id}")
                continue
            for key, cfg in configs.items():
                # Keys are split at the first "." into layer and module.
                key_parts = key.split(".", 1)
                layer, module = key_parts[0], key_parts[1]
                rows.append(
                    {
                        "model": short_id,
                        "module": module,
                        "layer": layer,
                        "memmb": 0,  # to work with the plot program
                        "bit_budget": bit_budget,
                        # NOTE(review): presumably (bits, group size) pairs
                        # for the two quantization levels -- confirm.
                        "b1": cfg[0],
                        "g1": cfg[1],
                        "b2": cfg[2],
                        "g2": cfg[3],
                    }
                )

    # Explicit columns keep the header present when ``rows`` is empty.
    pd.DataFrame(rows, columns=columns).to_csv(csv_fp, index=False)


if __name__ == "__main__":
    bit_budgets = [4.51, 4.25, 4.13]
    # dump_mxq_configs declares csv_fp, weight_algo and factor without
    # defaults, so they have to be supplied here; calling it with only the
    # model list and the budgets raises TypeError before any work is done.
    # NOTE(review): the output path and the None placeholders are
    # assumptions -- confirm the weight_algo / factor values that
    # find_optimal_configs accepts for a plain run.
    dump_mxq_configs(
        ["meta-llama/Llama-2-7b-hf"],
        bit_budgets,
        csv_fp="mxq-configs.csv",
        weight_algo=None,
        factor=None,
    )