martin-gorner HF staff commited on
Commit
38f8411
1 Parent(s): d20c9b9
Files changed (1) hide show
  1. models.py +2 -1
models.py CHANGED
@@ -60,7 +60,8 @@ def get_default_layout_map(preset_name, device_mesh):
60
  patch_key = "decoder_block.*attention.*(query|key|value).kernel"
61
  layout_map.pop(patch_key)
62
  layout_map[patch_key] = (None, "model", "batch")
63
- return layout_map
 
64
 
65
 
66
  def log_applied_layout_map(model):
 
60
  patch_key = "decoder_block.*attention.*(query|key|value).kernel"
61
  layout_map.pop(patch_key)
62
  layout_map[patch_key] = (None, "model", "batch")
63
+
64
+ return layout_map
65
 
66
 
67
  def log_applied_layout_map(model):