Spaces:
Starting
on
T4
Starting
on
T4
File size: 940 Bytes
a277bb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
from torch import nn
class PositionalEncodingsFixed(nn.Module):
    """Fixed (non-learned) 2-D sinusoidal positional encodings.

    Each spatial axis is encoded independently with ``emb_dim // 2``
    sine/cosine channels. The height (row) encoding and the width (column)
    encoding are concatenated along the channel axis, height first.

    Args:
        emb_dim: Requested number of encoding channels. The output has
            ``2 * (emb_dim // 2)`` channels, so an odd ``emb_dim`` yields one
            channel fewer than requested.
        temperature: Base of the geometric frequency progression.
    """

    def __init__(self, emb_dim, temperature=10000):
        super().__init__()
        self.emb_dim = emb_dim
        self.temperature = temperature

    def _1d_pos_enc(self, mask, dim):
        """Encode positions along a single axis of ``mask``.

        Args:
            mask: Boolean tensor; ``True`` entries are not counted as
                positions.
            dim: Axis of ``mask`` along which positions are counted.

        Returns:
            Float32 tensor of shape ``mask.shape + (emb_dim // 2,)``.
            Channel ``i`` holds ``sin(p / f_i)`` for even ``i`` and
            ``cos(p / f_i)`` for odd ``i``. Here ``p`` is the running count of
            unmasked entries along ``dim`` (1-based when nothing is masked)
            and ``f_i = temperature ** (2 * (i // 2) / emb_dim)``.
        """
        n_freqs = self.emb_dim // 2
        # Create the frequency indices directly on the mask's device rather
        # than allocating on the CPU and copying them over.
        idx = torch.arange(n_freqs, dtype=torch.float32, device=mask.device)
        # Each consecutive (sin, cos) channel pair shares one frequency.
        # The exponent is normalised by the full emb_dim, not by n_freqs.
        # Changing that would change every encoding value.
        exponent = 2 * idx.div(2, rounding_mode='floor') / self.emb_dim
        freqs = self.temperature ** exponent
        # cumsum over the inverted mask gives each entry its position among
        # the unmasked entries along `dim`.
        phase = (~mask).cumsum(dim).float().unsqueeze(-1) / freqs
        # Interleave sin (even channels) and cos (odd channels) by slice
        # assignment. Stacking the two halves would fail when n_freqs is odd
        # (emb_dim not a multiple of 4), because the halves differ in length.
        # This form handles both cases and matches the stacked result when
        # n_freqs is even.
        enc = torch.empty_like(phase)
        enc[..., 0::2] = phase[..., 0::2].sin()
        enc[..., 1::2] = phase[..., 1::2].cos()
        return enc

    def forward(self, bs, h, w, device):
        """Return positional encodings for a ``bs x h x w`` grid.

        Args:
            bs: Batch size.
            h: Grid height.
            w: Grid width.
            device: Device on which the encodings are created.

        Returns:
            Float32 tensor of shape ``(bs, 2 * (emb_dim // 2), h, w)``.
            The first ``emb_dim // 2`` channels encode the row position and
            the remaining ones encode the column position. All batch
            elements are identical.
        """
        # An all-False mask counts every cell, so positions run 1..h down
        # the rows and 1..w across the columns.
        mask = torch.zeros(bs, h, w, dtype=torch.bool, device=device)
        x = self._1d_pos_enc(mask, dim=2)  # varies along w
        y = self._1d_pos_enc(mask, dim=1)  # varies along h
        # (bs, h, w, C) -> (bs, C, h, w)
        return torch.cat([y, x], dim=3).permute(0, 3, 1, 2)
|