DZRobo
committed on
Commit
·
cdc8e19
1
Parent(s):
e06d499
Skip NAG when context key dim mismatches layer
Browse files
Adds a check to bypass NAG cross-attention if the context's last dimension does not match the expected key dimension for the layer, preventing matmul crashes with mismatched context inputs.
- mod/mg_sagpu_attention.py +10 -0
mod/mg_sagpu_attention.py
CHANGED
|
@@ -440,6 +440,16 @@ def _kj_crossattn_forward_nag(self, x, context=None, value=None, mask=None, **kw
|
|
| 440 |
if context is None or not torch.is_tensor(context):
|
| 441 |
return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs)
|
| 442 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
# Expect batch 2 with [uncond, cond]; if not, fall back
|
| 444 |
if context.shape[0] < 2:
|
| 445 |
return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs)
|
|
|
|
| 440 |
if context is None or not torch.is_tensor(context):
|
| 441 |
return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs)
|
| 442 |
|
| 443 |
+
# Sanity: skip NAG when context dim doesn't match this layer's expected key dim
|
| 444 |
+
try:
|
| 445 |
+
exp = getattr(self, "to_k", None)
|
| 446 |
+
exp_dim = int(exp.in_features) if exp is not None else None
|
| 447 |
+
except Exception:
|
| 448 |
+
exp_dim = None
|
| 449 |
+
if (exp_dim is not None) and (int(context.shape[-1]) != exp_dim):
|
| 450 |
+
# Mismatched CLIP/context (e.g., SDXL context 2048 on SD1 ControlNet 768); avoid matmul crash
|
| 451 |
+
return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs)
|
| 452 |
+
|
| 453 |
# Expect batch 2 with [uncond, cond]; if not, fall back
|
| 454 |
if context.shape[0] < 2:
|
| 455 |
return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs)
|