import keras
import keras_hub

model_presets = [
    # 7B-9B param models
    "hf://google/gemma-2-instruct-9b-keras",
    "hf://meta-llama/Llama-3.1-8B-Instruct",
    "hf://google/codegemma-7b-it-keras",
    "hf://keras/mistral_instruct_7b_en",
    "hf://keras/vicuna_1.5_7b_en",
    # "keras/gemma_1.1_instruct_7b_en", # won't fit?
    # 1B-3B param models
    "hf://meta-llama/Llama-3.2-1B-Instruct",
    "hf://google/gemma-2b-it-keras",
    "hf://meta-llama/Llama-3.2-3B-Instruct",
]

model_labels = [
    preset.removeprefix("hf://")
    .removeprefix("google/")
    .removeprefix("keras/")
    .removeprefix("meta-llama/")
    for preset in model_presets
]


def preset_to_website_url(preset):
    preset = preset.removeprefix("hf://")
    url = "http://huggingface.co/" + preset
    return url


def get_appropriate_chat_template(preset):
    return "Vicuna" if "vicuna" in preset else "auto"


def get_default_layout_map(preset_name, device_mesh):
    # Llama's default layout map works for mistral and vicuna
    # because their transformer layers have the same names.
    if (
        "Llama" in preset_name
        or "mistral" in preset_name
        or "vicuna" in preset_name
    ):
        layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
        # Default layout map patch:
        # This entry is missing from the default layout map of some Llama
        # models (TODO: fix this in keras_hub).
        layout_map["token_embedding/reverse_embeddings"] = ("batch", "model")
        return layout_map

    elif "gemma" in preset_name:
        layout_map = keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)

        if "gemma-2b-" in preset_name:
            # Default layout map patch:
            # Gemma QKV weights are shaped [NB_HEADS, EMBED_DIM, INNER_DIM],
            # while Llama QKV weights are shaped [EMBED_DIM, NB_HEADS, INNER_DIM].
            # The default Gemma layout map shards QKV weights as
            # (model_dim, data_dim, None), i.e. it shards NB_HEADS on the
            # "model" dimension. But gemma-2b-it-keras has only 1 head there,
            # so it cannot be sharded this way: the layout must be patched.
            # TODO: fix this in the Gemma layout map in keras_hub.
            patch_key = "decoder_block.*attention.*(query|key|value).kernel"
            layout_map.pop(patch_key)
            layout_map[patch_key] = (None, "model", "batch")

        return layout_map


def log_applied_layout_map(model):
    print("Model class:", type(model).__name__)

    if "Gemma" in type(model).__name__:
        transformer_decoder_block_name = "decoder_block_1"
    elif "Llama" in type(model).__name__:  # works for Llama (Vicuna) and Llama3
        transformer_decoder_block_name = "transformer_layer_1"
    elif "Mistral" in type(model).__name__:
        transformer_decoder_block_name = "transformer_layer_1"
    else:
        print("Unknown architecture. Cannot display the applied layout.")
        return

    # See how layer sharding was applied
    embedding_layer = model.backbone.get_layer("token_embedding")
    print(embedding_layer)
    decoder_block = model.backbone.get_layer(transformer_decoder_block_name)
    print(type(decoder_block))
    for variable in embedding_layer.weights + decoder_block.weights:
        print(
            f"{variable.path:<58}  "
            f"{str(variable.shape):<16}  "
            f"{str(variable.value.sharding.spec):<35}  "
            f"{variable.dtype}"
        )


def load_model(preset):
    devices = keras.distribution.list_devices()
    device_mesh = keras.distribution.DeviceMesh(
        shape=(1, len(devices)), axis_names=["batch", "model"], devices=devices
    )
    model_parallel = keras.distribution.ModelParallel(
        layout_map=get_default_layout_map(preset, device_mesh),
        batch_dim_name="batch",
    )

    with model_parallel.scope():
        # These two presets need a workaround to load in bfloat16:
        # build the CausalLM from a backbone instantiated with an explicit dtype.
        if "google/gemma-2-instruct-9b-keras" in preset:
            model = keras_hub.models.GemmaCausalLM(
                backbone=keras_hub.models.GemmaBackbone.from_preset(
                    preset, dtype="bfloat16"
                ),
                preprocessor=keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
                    preset
                ),
            )
        elif "meta-llama/Llama-3.1-8B-Instruct" in preset:
            model = keras_hub.models.Llama3CausalLM(
                backbone=keras_hub.models.Llama3Backbone.from_preset(
                    preset, dtype="bfloat16"
                ),
                preprocessor=keras_hub.models.Llama3CausalLMPreprocessor.from_preset(
                    preset
                ),
            )
        else:
            model = keras_hub.models.CausalLM.from_preset(
                preset, dtype="bfloat16"
            )

    log_applied_layout_map(model)
    return model
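

# Hedged usage sketch (not part of the original script): load one preset and
# run a short generation. The prompt and max_length below are arbitrary
# example values; gated Hugging Face repos (Llama, Gemma) additionally assume
# you are authenticated with an access token that has been granted access.
if __name__ == "__main__":
    model = load_model("hf://meta-llama/Llama-3.2-1B-Instruct")
    print(model.generate("What is model parallelism?", max_length=64))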