martin-gorner HF staff commited on
Commit
a2b7758
1 Parent(s): b637f0b

bug fixes: logging of model loading

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. models.py +15 -14
app.py CHANGED
@@ -17,7 +17,7 @@ from models import (
17
 
18
  model_labels_list = list(model_labels)
19
 
20
- # lod a warm up (compile) all the models
21
  models = []
22
  for preset in model_presets:
23
  model = load_model(preset)
@@ -32,7 +32,7 @@ for preset in model_presets:
32
  # model = keras_hub.models.Llama3CausalLM.from_preset(
33
  # "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
34
  # )
35
- # models = [model, model]
36
 
37
 
38
  def chat_turn_assistant_1(
@@ -170,7 +170,7 @@ with gr.Blocks(fill_width=True, title="Keras demo") as demo:
170
  gr.HTML(
171
  "<H2> Battle of the Keras chatbots on TPU</H2>"
172
  + "All the models are loaded into the TPU memory. "
173
- + "You can call them at will and compare their answers. <br/>"
174
  + "The entire chat history is fed to the models at every submission."
175
  + "This demno is runnig on a Google TPU v5e 2x4 (8 cores).",
176
  )
 
17
 
18
  model_labels_list = list(model_labels)
19
 
20
+ # load and warm up (compile) all the models
21
  models = []
22
  for preset in model_presets:
23
  model = load_model(preset)
 
32
  # model = keras_hub.models.Llama3CausalLM.from_preset(
33
  # "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
34
  # )
35
+ # models = [model, model, model, model, model]
36
 
37
 
38
  def chat_turn_assistant_1(
 
170
  gr.HTML(
171
  "<H2> Battle of the Keras chatbots on TPU</H2>"
172
  + "All the models are loaded into the TPU memory. "
173
+ + "You can call any of them and compare their answers. <br/>"
174
  + "The entire chat history is fed to the models at every submission."
175
  + "This demno is runnig on a Google TPU v5e 2x4 (8 cores).",
176
  )
models.py CHANGED
@@ -39,24 +39,25 @@ def get_default_layout_map(preset_name, device_mesh):
39
 
40
 
41
  def log_applied_layout_map(model):
42
- if "Gemma" in type(model):
43
  transformer_decoder_block_name = "decoder_block_1"
44
- elif "Llama3" in type(model) or "Mistral" in type(model):
45
  transformer_decoder_block_name = "transformer_layer_1"
46
  else:
47
  assert (0, "Model type not recognized. Cannot display model layout.")
48
- # See how layer sharding was applied
49
- embedding_layer = model.backbone.get_layer("token_embedding")
50
- print(embedding_layer)
51
- decoder_block = model.backbone.get_layer(transformer_decoder_block_name)
52
- print(type(decoder_block))
53
- for variable in embedding_layer.weights + decoder_block.weights:
54
- print(
55
- f"{variable.path:<58} \
56
- {str(variable.shape):<16} \
57
- {str(variable.value.sharding.spec):<35} \
58
- {str(variable.dtype)}"
59
- )
 
60
 
61
 
62
  def load_model(preset):
 
39
 
40
 
41
  def log_applied_layout_map(model):
42
+ if "Gemma" in type(model).__name__:
43
  transformer_decoder_block_name = "decoder_block_1"
44
+ elif "Llama3" in type(model).__name__ or "Mistral" in type(model).__name__:
45
  transformer_decoder_block_name = "transformer_layer_1"
46
  else:
47
  assert (0, "Model type not recognized. Cannot display model layout.")
48
+
49
+ # See how layer sharding was applied
50
+ embedding_layer = model.backbone.get_layer("token_embedding")
51
+ print(embedding_layer)
52
+ decoder_block = model.backbone.get_layer(transformer_decoder_block_name)
53
+ print(type(decoder_block))
54
+ for variable in embedding_layer.weights + decoder_block.weights:
55
+ print(
56
+ f"{variable.path:<58} \
57
+ {str(variable.shape):<16} \
58
+ {str(variable.value.sharding.spec):<35} \
59
+ {str(variable.dtype)}"
60
+ )
61
 
62
 
63
  def load_model(preset):