# Hugging Face Space: running on TPU v5e
import keras | |
import keras_hub | |
# Keras Hub presets, given as "hf://" handles on the Hugging Face Hub.
model_presets = [
    "hf://google/gemma-2-instruct-9b-keras",
    "hf://meta-llama/Llama-3.1-8B-Instruct",
    "hf://google/codegemma-7b-it-keras",
    "hf://keras/mistral_instruct_7b_en",
    "hf://keras/vicuna_1.5_7b_en",
]


def _preset_to_label(preset):
    """Return a short display label for `preset`.

    Strips the "hf://" scheme, then the "google/", "keras/" and
    "meta-llama/" prefixes in that order (each at most once), e.g.
    "hf://keras/vicuna_1.5_7b_en" -> "vicuna_1.5_7b_en".
    """
    label = preset.removeprefix("hf://")
    for org_prefix in ("google/", "keras/", "meta-llama/"):
        label = label.removeprefix(org_prefix)
    return label


# Display labels, index-aligned with `model_presets`. Built as a list rather
# than a lazy `map` iterator so it can be iterated more than once.
model_labels = [_preset_to_label(preset) for preset in model_presets]
def preset_to_website_url(preset):
    """Return the Hugging Face Hub page URL for a preset handle.

    A leading "hf://" scheme, if present, is stripped and the remaining
    "<org>/<model>" path is appended to the Hub's HTTPS base URL.

    Args:
        preset: Preset handle such as "hf://keras/vicuna_1.5_7b_en".

    Returns:
        The URL string, e.g. "https://huggingface.co/keras/vicuna_1.5_7b_en".
    """
    repo_id = preset.removeprefix("hf://")
    # Link over HTTPS directly; a plain-http link forces a redirect and is a
    # mixed-content link when rendered inside an HTTPS page.
    return f"https://huggingface.co/{repo_id}"
def get_appropriate_chat_template(preset):
    """Pick the chat template name for `preset`.

    Presets whose name contains the lowercase substring "vicuna" get the
    "Vicuna" template; every other preset gets "auto". The match is
    case-sensitive.
    """
    if "vicuna" not in preset:
        return "auto"
    return "Vicuna"
def get_default_layout_map(preset_name, device_mesh):
    """Return the default Keras Hub layout map for `preset_name`.

    The backbone family is chosen by case-sensitive substring match on the
    preset name:

    - "Llama", "mistral" or "vicuna" -> `Llama3Backbone.get_layout_map`
    - "gemma" (including "codegemma") -> `GemmaBackbone.get_layout_map`

    Args:
        preset_name: Preset handle, e.g. "hf://keras/mistral_instruct_7b_en".
        device_mesh: The `keras.distribution.DeviceMesh` to build the layout
            map for.

    Returns:
        Whatever the selected backbone's `get_layout_map(device_mesh)` returns.

    Raises:
        ValueError: If the preset name matches none of the families above.
    """
    # Llama's default layout map works for mistral and vicuna
    # because their transformer layers have the same names.
    if (
        "Llama" in preset_name
        or "mistral" in preset_name
        or "vicuna" in preset_name
    ):
        return keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
    if "gemma" in preset_name:
        return keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
    # Fail loudly instead of implicitly returning None: load_model passes the
    # result straight to ModelParallel as its layout_map.
    raise ValueError(
        f"No default layout map known for preset {preset_name!r}; expected a "
        "Llama, Mistral, Vicuna or Gemma preset."
    )
def log_applied_layout_map(model):
    """Print how the model's variables were sharded by the layout map.

    Prints the "token_embedding" layer and the type of the block at index 1
    of the transformer stack. Then, for every weight of those two layers, it
    prints one row with the variable path, shape, sharding spec and dtype.

    Args:
        model: A Keras Hub CausalLM whose class name contains "Gemma",
            "Llama" or "Mistral". Each variable is read through
            `variable.value.sharding.spec`, which presumably requires the JAX
            backend (TODO confirm).

    Raises:
        ValueError: If the model class is not one of the recognized families.
    """
    model_type_name = type(model).__name__
    if "Gemma" in model_type_name:
        transformer_decoder_block_name = "decoder_block_1"
    elif "Llama" in model_type_name or "Mistral" in model_type_name:
        # "Llama" (not "Llama3") also covers LlamaCausalLM, which the Vicuna
        # preset presumably loads as. Its blocks are assumed to be named
        # "transformer_layer_{i}" like Llama3/Mistral (TODO confirm).
        transformer_decoder_block_name = "transformer_layer_1"
    else:
        # Raise explicitly: an assert on a (non-empty, hence truthy) tuple
        # never fires, and asserts are stripped under `python -O` anyway.
        raise ValueError(
            "Model type not recognized. Cannot display model layout."
        )
    # See how layer sharding was applied
    embedding_layer = model.backbone.get_layer("token_embedding")
    print(embedding_layer)
    decoder_block = model.backbone.get_layer(transformer_decoder_block_name)
    print(type(decoder_block))
    for variable in embedding_layer.weights + decoder_block.weights:
        # Adjacent f-strings with explicit separators. Backslash-newline
        # continuations inside a single literal would leak the source's
        # indentation into the printed row.
        print(
            f"{variable.path:<58} "
            f"{variable.shape!s:<16} "
            f"{variable.value.sharding.spec!s:<35} "
            f"{variable.dtype!s}"
        )
def load_model(preset):
    """Load `preset` as a bfloat16 CausalLM sharded across all devices.

    A (1, N) device mesh with axes ("batch", "model") is built over every
    device returned by `keras.distribution.list_devices()`. The model is
    created inside a `ModelParallel` scope that uses the layout map from
    `get_default_layout_map`, and the resulting sharding is printed with
    `log_applied_layout_map`.

    Args:
        preset: Preset handle, e.g. "hf://keras/vicuna_1.5_7b_en".

    Returns:
        The loaded Keras Hub CausalLM.
    """
    all_devices = keras.distribution.list_devices()
    mesh = keras.distribution.DeviceMesh(
        shape=(1, len(all_devices)),
        axis_names=["batch", "model"],
        devices=all_devices,
    )
    layout_map = get_default_layout_map(preset, mesh)
    strategy = keras.distribution.ModelParallel(
        layout_map=layout_map,
        batch_dim_name="batch",
    )

    def _assemble_split(causal_lm_cls, backbone_cls, preprocessor_cls):
        # Load the backbone (in bfloat16) and the preprocessor separately,
        # then wrap them in the task class.
        bf16_backbone = backbone_cls.from_preset(preset, dtype="bfloat16")
        text_preprocessor = preprocessor_cls.from_preset(preset)
        return causal_lm_cls(
            backbone=bf16_backbone, preprocessor=text_preprocessor
        )

    with strategy.scope():
        # These two buggy models need this workaround to be loaded in bfloat16
        if "google/gemma-2-instruct-9b-keras" in preset:
            model = _assemble_split(
                keras_hub.models.GemmaCausalLM,
                keras_hub.models.GemmaBackbone,
                keras_hub.models.GemmaCausalLMPreprocessor,
            )
        elif "meta-llama/Llama-3.1-8B-Instruct" in preset:
            model = _assemble_split(
                keras_hub.models.Llama3CausalLM,
                keras_hub.models.Llama3Backbone,
                keras_hub.models.Llama3CausalLMPreprocessor,
            )
        else:
            model = keras_hub.models.CausalLM.from_preset(
                preset, dtype="bfloat16"
            )

    log_applied_layout_map(model)
    return model
# Some smaller models, kept here for reference:
# model1 = keras_hub.models.CausalLM.from_preset("hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16") | |
# model2 = keras_hub.models.CausalLM.from_preset("hf://google/gemma-2b-it-keras", dtype="bfloat16") | |
# model3 = keras_hub.models.CausalLM.from_preset("hf://meta-llama/Llama-3.2-3B-Instruct", dtype="bfloat16") | |
# keras/gemma_1.1_instruct_7b_en | |