import keras
import keras_hub
model_presets = [
    # 7-9B param models
    "hf://google/gemma-2-instruct-9b-keras",
    "hf://meta-llama/Llama-3.1-8B-Instruct",
    "hf://google/codegemma-7b-it-keras",
    "hf://keras/mistral_instruct_7b_en",
    "hf://keras/vicuna_1.5_7b_en",
    # "keras/gemma_1.1_instruct_7b_en", # won't fit?
    # 1-3B param models
    "hf://meta-llama/Llama-3.2-1B-Instruct",
    "hf://google/gemma-2b-it-keras",
    "hf://meta-llama/Llama-3.2-3B-Instruct",
]
model_labels = map(lambda s: s.removeprefix("hf://"), model_presets)
model_labels = map(lambda s: s.removeprefix("google/"), model_labels)
model_labels = map(lambda s: s.removeprefix("keras/"), model_labels)
model_labels = map(lambda s: s.removeprefix("meta-llama/"), model_labels)
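# Example result: "hf://google/gemma-2-instruct-9b-keras" -> "gemma-2-instruct-9b-keras"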
def preset_to_website_url(preset):
    preset = preset.removeprefix("hf://")
    url = "http://huggingface.co/" + preset
    return url
def get_appropriate_chat_template(preset):
    return "Vicuna" if "vicuna" in preset else "auto"
def get_default_layout_map(preset_name, device_mesh):
    # Llama's default layout map works for Mistral and Vicuna
    # because their transformer layers have the same names.
    if (
        "Llama" in preset_name
        or "mistral" in preset_name
        or "vicuna" in preset_name
    ):
        layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
        # Default layout map patch:
        # This line is missing for some Llama models (TODO: fix this in keras_hub)
        layout_map["token_embedding/reverse_embeddings"] = ("batch", "model")
        return layout_map
    elif "gemma" in preset_name:
        layout_map = keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
        if "gemma-2b-" in preset_name:
            # Default layout map patch:
            # Gemma QKV weights are shaped [NB_HEADS, EMBED_DIM, INNER_DIM].
            # Llama QKV weights are shaped [EMBED_DIM, NB_HEADS, INNER_DIM].
            # However:
            # The default layout map for QKV weights on Gemma is (model_dim, data_dim, None),
            # which means sharding NB_HEADS on the "model" dimension.
            # But gemma-2b-it-keras has only 1 head, so this won't work: must patch it.
            # TODO: fix this in the Gemma layout map in keras_hub.
            patch_key = "decoder_block.*attention.*(query|key|value).kernel"
            layout_map.pop(patch_key)
            layout_map[patch_key] = (None, "model", "batch")
        return layout_map
def log_applied_layout_map(model):
    print("Model class:", type(model).__name__)
    if "Gemma" in type(model).__name__:
        transformer_decoder_block_name = "decoder_block_1"
    elif "Llama" in type(model).__name__:  # works for Llama (Vicuna) and Llama3
        transformer_decoder_block_name = "transformer_layer_1"
    elif "Mistral" in type(model).__name__:
        transformer_decoder_block_name = "transformer_layer_1"
    else:
        print("Unknown architecture. Cannot display the applied layout.")
        return

    # See how layer sharding was applied
    embedding_layer = model.backbone.get_layer("token_embedding")
    print(embedding_layer)
    decoder_block = model.backbone.get_layer(transformer_decoder_block_name)
    print(type(decoder_block))
    for variable in embedding_layer.weights + decoder_block.weights:
        print(
            f"{variable.path:<58} "
            f"{str(variable.shape):<16} "
            f"{str(variable.value.sharding.spec):<35} "
            f"{variable.dtype}"
        )
def load_model(preset):
    devices = keras.distribution.list_devices()
    device_mesh = keras.distribution.DeviceMesh(
        shape=(1, len(devices)), axis_names=["batch", "model"], devices=devices
    )
    model_parallel = keras.distribution.ModelParallel(
        layout_map=get_default_layout_map(preset, device_mesh),
        batch_dim_name="batch",
    )
    with model_parallel.scope():
        # These two buggy models need this workaround to be loaded in bfloat16
        if "google/gemma-2-instruct-9b-keras" in preset:
            model = keras_hub.models.GemmaCausalLM(
                backbone=keras_hub.models.GemmaBackbone.from_preset(
                    preset, dtype="bfloat16"
                ),
                preprocessor=keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
                    preset
                ),
            )
        elif "meta-llama/Llama-3.1-8B-Instruct" in preset:
            model = keras_hub.models.Llama3CausalLM(
                backbone=keras_hub.models.Llama3Backbone.from_preset(
                    preset, dtype="bfloat16"
                ),
                preprocessor=keras_hub.models.Llama3CausalLMPreprocessor.from_preset(
                    preset
                ),
            )
        else:
            model = keras_hub.models.CausalLM.from_preset(preset, dtype="bfloat16")
    log_applied_layout_map(model)
    return model
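

# Usage sketch (not part of the original Space code, added for illustration):
# loads the smallest preset with model-parallel sharding across all local
# accelerator devices and runs a short generation. Assumes a multi-device TPU
# host (e.g. TPU v5e); the raw string prompt is illustrative only, since the
# Space normally formats prompts with the preset's chat template first.
if __name__ == "__main__":
    preset = "hf://meta-llama/Llama-3.2-1B-Instruct"
    causal_lm = load_model(preset)
    print(causal_lm.generate("Write a haiku about TPUs.", max_length=64))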