DZRobo committed
Commit cdc8e19 · 1 Parent(s): e06d499

Skip NAG when context key dim mismatches layer


Adds a check to bypass NAG cross-attention if the context's last dimension does not match the expected key dimension for the layer, preventing matmul crashes with mismatched context inputs.

Files changed (1): mod/mg_sagpu_attention.py  (+10 -0)
mod/mg_sagpu_attention.py CHANGED
@@ -440,6 +440,16 @@ def _kj_crossattn_forward_nag(self, x, context=None, value=None, mask=None, **kw
     if context is None or not torch.is_tensor(context):
         return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs)
 
+    # Sanity: skip NAG when context dim doesn't match this layer's expected key dim
+    try:
+        exp = getattr(self, "to_k", None)
+        exp_dim = int(exp.in_features) if exp is not None else None
+    except Exception:
+        exp_dim = None
+    if (exp_dim is not None) and (int(context.shape[-1]) != exp_dim):
+        # Mismatched CLIP/context (e.g., SDXL context 2048 on SD1 ControlNet 768); avoid matmul crash
+        return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs)
+
     # Expect batch 2 with [uncond, cond]; if not, fall back
     if context.shape[0] < 2:
         return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs)
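
For readers who want to see the guard's behaviour in isolation, the sketch below mirrors the added check outside the patch. The helper would_skip_nag and the FakeCrossAttn stand-in are hypothetical names used only for illustration, not code from this repo: the sketch reads to_k.in_features as the layer's expected key dimension and reports a fallback whenever the context's last dimension differs, e.g. an SDXL-style 2048-dim context hitting an SD1-style 768-dim layer.

# Illustrative sketch only: reproduces the guard's decision logic with a
# stand-in module that exposes a to_k projection, as a patched layer would.
import torch
import torch.nn as nn

def would_skip_nag(layer, context):
    # Mirror of the added check: compare the context's last dimension
    # against the key projection's expected input width.
    try:
        to_k = getattr(layer, "to_k", None)
        exp_dim = int(to_k.in_features) if to_k is not None else None
    except Exception:
        exp_dim = None
    return (exp_dim is not None) and (int(context.shape[-1]) != exp_dim)

class FakeCrossAttn(nn.Module):
    # Minimal stand-in exposing only the attribute the guard inspects.
    def __init__(self, context_dim, inner_dim=320):
        super().__init__()
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)

sd1_layer = FakeCrossAttn(context_dim=768)                  # SD1-style layer: keys expect 768-dim context
print(would_skip_nag(sd1_layer, torch.randn(2, 77, 768)))   # False -> dims match, NAG path runs
print(would_skip_nag(sd1_layer, torch.randn(2, 77, 2048)))  # True  -> mismatch, fall back to original cross-attention

Because the check only consults the layer's own to_k projection, the same patched forward can sit on SD1 and SDXL cross-attention modules alike and quietly hand mismatched contexts back to the original implementation instead of crashing in the key matmul.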