Spaces:
Running
on
TPU v5e
Running
on
TPU v5e
Commit
•
a2b7758
1
Parent(s):
b637f0b
bug fixes: logging of model loading
Browse files
app.py
CHANGED
@@ -17,7 +17,7 @@ from models import (
|
|
17 |
|
18 |
model_labels_list = list(model_labels)
|
19 |
|
20 |
-
#
|
21 |
models = []
|
22 |
for preset in model_presets:
|
23 |
model = load_model(preset)
|
@@ -32,7 +32,7 @@ for preset in model_presets:
|
|
32 |
# model = keras_hub.models.Llama3CausalLM.from_preset(
|
33 |
# "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
|
34 |
# )
|
35 |
-
# models = [model, model]
|
36 |
|
37 |
|
38 |
def chat_turn_assistant_1(
|
@@ -170,7 +170,7 @@ with gr.Blocks(fill_width=True, title="Keras demo") as demo:
|
|
170 |
gr.HTML(
|
171 |
"<H2> Battle of the Keras chatbots on TPU</H2>"
|
172 |
+ "All the models are loaded into the TPU memory. "
|
173 |
-
+ "You can call
|
174 |
+ "The entire chat history is fed to the models at every submission."
|
175 |
+ "This demno is runnig on a Google TPU v5e 2x4 (8 cores).",
|
176 |
)
|
|
|
17 |
|
18 |
model_labels_list = list(model_labels)
|
19 |
|
20 |
+
# load and warm up (compile) all the models
|
21 |
models = []
|
22 |
for preset in model_presets:
|
23 |
model = load_model(preset)
|
|
|
32 |
# model = keras_hub.models.Llama3CausalLM.from_preset(
|
33 |
# "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
|
34 |
# )
|
35 |
+
# models = [model, model, model, model, model]
|
36 |
|
37 |
|
38 |
def chat_turn_assistant_1(
|
|
|
170 |
gr.HTML(
|
171 |
"<H2> Battle of the Keras chatbots on TPU</H2>"
|
172 |
+ "All the models are loaded into the TPU memory. "
|
173 |
+
+ "You can call any of them and compare their answers. <br/>"
|
174 |
+ "The entire chat history is fed to the models at every submission."
|
175 |
+ "This demno is runnig on a Google TPU v5e 2x4 (8 cores).",
|
176 |
)
|
models.py
CHANGED
@@ -39,24 +39,25 @@ def get_default_layout_map(preset_name, device_mesh):
|
|
39 |
|
40 |
|
41 |
def log_applied_layout_map(model):
|
42 |
-
if "Gemma" in type(model):
|
43 |
transformer_decoder_block_name = "decoder_block_1"
|
44 |
-
elif "Llama3" in type(model) or "Mistral" in type(model):
|
45 |
transformer_decoder_block_name = "transformer_layer_1"
|
46 |
else:
|
47 |
assert (0, "Model type not recognized. Cannot display model layout.")
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
60 |
|
61 |
|
62 |
def load_model(preset):
|
|
|
39 |
|
40 |
|
41 |
def log_applied_layout_map(model):
|
42 |
+
if "Gemma" in type(model).__name__:
|
43 |
transformer_decoder_block_name = "decoder_block_1"
|
44 |
+
elif "Llama3" in type(model).__name__ or "Mistral" in type(model).__name__:
|
45 |
transformer_decoder_block_name = "transformer_layer_1"
|
46 |
else:
|
47 |
assert (0, "Model type not recognized. Cannot display model layout.")
|
48 |
+
|
49 |
+
# See how layer sharding was applied
|
50 |
+
embedding_layer = model.backbone.get_layer("token_embedding")
|
51 |
+
print(embedding_layer)
|
52 |
+
decoder_block = model.backbone.get_layer(transformer_decoder_block_name)
|
53 |
+
print(type(decoder_block))
|
54 |
+
for variable in embedding_layer.weights + decoder_block.weights:
|
55 |
+
print(
|
56 |
+
f"{variable.path:<58} \
|
57 |
+
{str(variable.shape):<16} \
|
58 |
+
{str(variable.value.sharding.spec):<35} \
|
59 |
+
{str(variable.dtype)}"
|
60 |
+
)
|
61 |
|
62 |
|
63 |
def load_model(preset):
|