File size: 1,320 Bytes
ea5c647
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import argparse
import os
import torch
from safetensors.torch import load_file


def main(file):
    print(f"loading: {file}")
    if os.path.splitext(file)[1] == ".safetensors":
        sd = load_file(file)
    else:
        sd = torch.load(file, map_location="cpu")

    values = []

    keys = list(sd.keys())
    for key in keys:
        if "lora_up" in key or "lora_down" in key:
            values.append((key, sd[key]))
    print(f"number of LoRA modules: {len(values)}")

    if args.show_all_keys:
        for key in [k for k in keys if k not in values]:
            values.append((key, sd[key]))
        print(f"number of all modules: {len(values)}")

    for key, value in values:
        value = value.to(torch.float32)
        print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
    parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する")

    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()

    main(args.file)