# keras-chatbot-arena / models.py

import keras
import keras_hub
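
# Model presets, loaded from the Hugging Face Hub through keras_hub.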
model_presets = [
"hf://google/gemma-2-instruct-9b-keras",
"hf://meta-llama/Llama-3.1-8B-Instruct",
"hf://google/codegemma-7b-it-keras",
"hf://keras/mistral_instruct_7b_en",
"hf://keras/vicuna_1.5_7b_en",
]
# Short labels for display: strip the "hf://" scheme and the org prefix.
model_labels = [
    preset.removeprefix("hf://")
    .removeprefix("google/")
    .removeprefix("keras/")
    .removeprefix("meta-llama/")
    for preset in model_presets
]

def preset_to_website_url(preset):
preset = preset.removeprefix("hf://")
url = "http://huggingface.co/" + preset
return url
def get_appropriate_chat_template(preset):
return "Vicuna" if "vicuna" in preset else "auto"
def get_default_layout_map(preset_name, device_mesh):
# Llama's default layout map works for mistral and vicuna
# because their transformer layers have the same names.
if (
"Llama" in preset_name
or "mistral" in preset_name
or "vicuna" in preset_name
):
return keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
elif "gemma" in preset_name:
return keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
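
# Note: get_layout_map() returns a keras.distribution.LayoutMap keyed by
# variable path patterns, which is why models whose layers share the same
# names can reuse the same map.
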
def log_applied_layout_map(model):
if "Gemma" in type(model).__name__:
transformer_decoder_block_name = "decoder_block_1"
elif "Llama3" in type(model).__name__ or "Mistral" in type(model).__name__:
transformer_decoder_block_name = "transformer_layer_1"
else:
assert (0, "Model type not recognized. Cannot display model layout.")
# See how layer sharding was applied
embedding_layer = model.backbone.get_layer("token_embedding")
print(embedding_layer)
decoder_block = model.backbone.get_layer(transformer_decoder_block_name)
print(type(decoder_block))
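    # One line per weight: variable path, shape, sharding spec (JAX), dtype.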
for variable in embedding_layer.weights + decoder_block.weights:
        print(
            f"{variable.path:<58} "
            f"{str(variable.shape):<16} "
            f"{str(variable.value.sharding.spec):<35} "
            f"{variable.dtype}"
        )
def load_model(preset):
devices = keras.distribution.list_devices()
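    # 1 x N device mesh: the "model" axis spans all devices, so each model's
    # weights are sharded across every available accelerator.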
device_mesh = keras.distribution.DeviceMesh(
shape=(1, len(devices)), axis_names=["batch", "model"], devices=devices
)
model_parallel = keras.distribution.ModelParallel(
layout_map=get_default_layout_map(preset, device_mesh),
batch_dim_name="batch",
)
with model_parallel.scope():
        # These two presets are buggy and need this workaround to load in bfloat16.
if "google/gemma-2-instruct-9b-keras" in preset:
model = keras_hub.models.GemmaCausalLM(
backbone=keras_hub.models.GemmaBackbone.from_preset(
preset, dtype="bfloat16"
),
preprocessor=keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
preset
),
)
elif "meta-llama/Llama-3.1-8B-Instruct" in preset:
model = keras_hub.models.Llama3CausalLM(
backbone=keras_hub.models.Llama3Backbone.from_preset(
preset, dtype="bfloat16"
),
preprocessor=keras_hub.models.Llama3CausalLMPreprocessor.from_preset(
preset
),
)
else:
model = keras_hub.models.CausalLM.from_preset(
preset, dtype="bfloat16"
)
log_applied_layout_map(model)
return model
# Some small models too
# model1 = keras_hub.models.CausalLM.from_preset("hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16")
# model2 = keras_hub.models.CausalLM.from_preset("hf://google/gemma-2b-it-keras", dtype="bfloat16")
# model3 = keras_hub.models.CausalLM.from_preset("hf://meta-llama/Llama-3.2-3B-Instruct", dtype="bfloat16")
# keras/gemma_1.1_instruct_7b_en
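
# Example usage (sketch, not part of the app): load one preset and generate.
# Assumes accelerators are visible to keras.distribution.list_devices().
if __name__ == "__main__":
    demo_model = load_model(model_presets[3])  # keras/mistral_instruct_7b_en
    print(demo_model.generate("What is JAX?", max_length=128))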