File size: 3,622 Bytes
e1640d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, QuantoConfig, GenerationConfig
import torch
import argparse

"""
Usage:
    export SAFETENSORS_FAST_GPU=1
    python main.py --quant_type int8 --world_size 8 --model_id <model_path>
"""

def generate_quanto_config(hf_config: "AutoConfig", quant_type: str):
    """Return the quantization config matching ``quant_type``.

    The config is only constructed for the branch that is actually requested,
    so asking for ``"default"`` does no quantization-related work at all.

    Args:
        hf_config: Model configuration; only ``num_hidden_layers`` is read,
            and only on the ``"int8"`` path.
        quant_type: ``"default"`` (no quantization) or ``"int8"``
            (int8 weights).

    Returns:
        ``None`` for ``"default"``; a ``QuantoConfig`` with ``weights="int8"``
        for ``"int8"``.

    Raises:
        KeyError: If ``quant_type`` is not one of the supported values.
    """
    if quant_type == "default":
        return None

    if quant_type == "int8":
        layer_ids = range(hf_config.num_hidden_layers)
        # Modules listed here are passed as ``modules_to_not_convert``: the
        # output head, the token embeddings, and each layer's ``coefficient``
        # and MoE ``gate`` entries.
        excluded_modules = ["lm_head", "embed_tokens"]
        excluded_modules += [f"model.layers.{i}.coefficient" for i in layer_ids]
        excluded_modules += [f"model.layers.{i}.block_sparse_moe.gate" for i in layer_ids]
        return QuantoConfig(
            weights="int8",
            modules_to_not_convert=excluded_modules,
        )

    # Same exception type as the former dict lookup, but with a usable message.
    raise KeyError(
        f"unsupported quant_type {quant_type!r}; expected 'default' or 'int8'"
    )


def parse_args():
    """Parse the command-line options of this script.

    Returns:
        An ``argparse.Namespace`` with ``quant_type`` (``"default"`` or
        ``"int8"``, defaulting to ``"default"``), ``model_id`` (required
        string) and ``world_size`` (required integer).
    """
    cli = argparse.ArgumentParser()
    # Registration order is kept, so the namespace lists the options in the
    # same order as before.
    option_specs = (
        ("--quant_type", {"type": str, "default": "default", "choices": ["default", "int8"]}),
        ("--model_id", {"type": str, "required": True}),
        ("--world_size", {"type": int, "required": True}),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


def check_params(args, hf_config: "AutoConfig"):
    """Validate the parsed arguments against the model configuration.

    Args:
        args: Parsed arguments; ``quant_type`` and ``world_size`` are read.
        hf_config: Model configuration; ``num_hidden_layers`` is read.

    Raises:
        AssertionError: If ``world_size`` is not positive, if ``"int8"`` is
            requested with ``world_size`` below 8, or if
            ``num_hidden_layers`` is not divisible by ``world_size``.
    """
    world_size = args.world_size
    num_layers = hf_config.num_hidden_layers

    # The errors are raised explicitly instead of via ``assert`` so the checks
    # still run under ``python -O``; the exception type stays AssertionError.
    if world_size <= 0:
        # Without this guard, 0 fails with ZeroDivisionError in the modulo
        # below and a negative value can slip through the divisibility check.
        raise AssertionError(f"world_size({world_size}) must be a positive integer")

    if args.quant_type == "int8" and world_size < 8:
        raise AssertionError("int8 weight-only quantization requires at least 8 GPUs")

    if num_layers % world_size != 0:
        raise AssertionError(
            f"num_hidden_layers({num_layers}) must be divisible by world_size({world_size})"
        )


@torch.no_grad()
def main():
    """Load the model across several CUDA devices and run one chat generation.

    Steps, in order: parse and echo the CLI arguments, load and validate the
    model config, build a per-module device map, render and tokenize a fixed
    chat prompt, load the (optionally quantized) model, generate up to 20 new
    tokens, then print the raw ids and the decoded reply.
    """
    args = parse_args()
    print("\n=============== Argument ===============")
    for name, value in vars(args).items():
        print(f"{name}: {value}")
    print("========================================")

    model_id = args.model_id
    world_size = args.world_size

    hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    check_params(args, hf_config)
    quantization_config = generate_quanto_config(hf_config, args.quant_type)

    # Embeddings go on the first device, the final norm and output head on the
    # last one; decoder layers are spread in equal contiguous groups.
    last_device = f'cuda:{world_size - 1}'
    device_map = {
        'model.embed_tokens': 'cuda:0',
        'model.norm': last_device,
        'lm_head': last_device,
    }
    layers_per_device = hf_config.num_hidden_layers // world_size
    device_map.update({
        f'model.layers.{device_idx * layers_per_device + offset}': f'cuda:{device_idx}'
        for device_idx in range(world_size)
        for offset in range(layers_per_device)
    })

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    prompt = "Hello!"
    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": "You are a helpful assistant created by Minimax based on MiniMax-Text-01 model."}],
    }
    user_turn = {
        "role": "user",
        "content": [{"type": "text", "text": prompt}],
    }
    text = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer(text, return_tensors="pt").to("cuda")

    load_options = dict(
        torch_dtype="bfloat16",
        device_map=device_map,
        quantization_config=quantization_config,
        trust_remote_code=True,
        offload_buffers=True,
    )
    quantized_model = AutoModelForCausalLM.from_pretrained(model_id, **load_options)

    # NOTE(review): 200020 is presumably the model's end-of-sequence token id;
    # confirm against the tokenizer.
    generation_config = GenerationConfig(max_new_tokens=20, eos_token_id=200020, use_cache=True)
    generated_ids = quantized_model.generate(**model_inputs, generation_config=generation_config)
    print(f"generated_ids: {generated_ids}")

    # Drop the leading prompt-length slice of every output row so only the
    # continuation is decoded.
    continuations = []
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids):
        continuations.append(output_ids[len(input_ids):])
    response = tokenizer.batch_decode(continuations, skip_special_tokens=True)[0]
    print(response)

# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()