Spaces:
Running
on
TPU v5e
Running
on
TPU v5e
Commit
•
38f8411
1
Parent(s):
d20c9b9
bug fix
Browse files
models.py
CHANGED
@@ -60,7 +60,8 @@ def get_default_layout_map(preset_name, device_mesh):
|
|
60 |
patch_key = "decoder_block.*attention.*(query|key|value).kernel"
|
61 |
layout_map.pop(patch_key)
|
62 |
layout_map[patch_key] = (None, "model", "batch")
|
63 |
-
|
|
|
64 |
|
65 |
|
66 |
def log_applied_layout_map(model):
|
|
|
60 |
patch_key = "decoder_block.*attention.*(query|key|value).kernel"
|
61 |
layout_map.pop(patch_key)
|
62 |
layout_map[patch_key] = (None, "model", "batch")
|
63 |
+
|
64 |
+
return layout_map
|
65 |
|
66 |
|
67 |
def log_applied_layout_map(model):
|