| # -*- coding: utf-8 -*- | |
| import torch | |
def normalize_output(q, k, o, eps=1e-10):
    """Divide ``o`` by a normalizer built from ``q`` and the running sum of ``k``.

    The normalizer is ``z = sum(q * cumsum(k, dim=-2), dim=-1, keepdim=True)``
    and the result is ``o / (z + eps)``.

    Args:
        q: Tensor multiplied elementwise with the cumulative sum of ``k``;
            must broadcast against ``k``.
        k: Tensor accumulated along its second-to-last dimension
            (presumably the sequence/time axis -- confirm against callers).
        o: Tensor to normalize; must broadcast against ``z``, whose last
            dimension has size 1.
        eps: Constant added to the normalizer before dividing so that a zero
            normalizer does not yield a division by zero. Defaults to
            ``1e-10``, the value that was previously hard-coded.

    Returns:
        ``o / (z + eps)``.
    """
    # Running (prefix) sum of k along dim -2; the caller's tensor is untouched
    # because cumsum returns a new tensor.
    k_cum = k.cumsum(-2)
    # Reduce the last dimension but keep it as size 1 so z broadcasts over o.
    z = (q * k_cum).sum(-1, keepdim=True)
    # NOTE(review): with float16 inputs the default 1e-10 likely rounds to
    # zero when added to z; pass a larger eps in that case.
    return o / (z + eps)
| # -*- coding: utf-8 -*- | |
| import torch | |
def normalize_output(q, k, o, eps=1e-10):
    """Divide ``o`` by a normalizer built from ``q`` and the running sum of ``k``.

    The normalizer is ``z = sum(q * cumsum(k, dim=-2), dim=-1, keepdim=True)``
    and the result is ``o / (z + eps)``.

    Args:
        q: Tensor multiplied elementwise with the cumulative sum of ``k``;
            must broadcast against ``k``.
        k: Tensor accumulated along its second-to-last dimension
            (presumably the sequence/time axis -- confirm against callers).
        o: Tensor to normalize; must broadcast against ``z``, whose last
            dimension has size 1.
        eps: Constant added to the normalizer before dividing so that a zero
            normalizer does not yield a division by zero. Defaults to
            ``1e-10``, the value that was previously hard-coded.

    Returns:
        ``o / (z + eps)``.
    """
    # Running (prefix) sum of k along dim -2; the caller's tensor is untouched
    # because cumsum returns a new tensor.
    k_cum = k.cumsum(-2)
    # Reduce the last dimension but keep it as size 1 so z broadcasts over o.
    z = (q * k_cum).sum(-1, keepdim=True)
    # NOTE(review): with float16 inputs the default 1e-10 likely rounds to
    # zero when added to z; pass a larger eps in that case.
    return o / (z + eps)