A Whisper-like ASR model with some more advanced ideas. Experimental. This is the full script: just install the dependencies and run it. The included model is -not- trained. It's a blank (tabula rasa), newly initialized, "medium"-sized version built from the script. I'm experimenting with some of the new ideas from the vision-LLM community, but applied to audio. Here is a super cool paper: https://www.frontiersin.org/journals/neuroscience/articles/10.3389/fnins.2022.949142/full
Updated: there were some issues with the hybrid attention and tensor sharing. These are now fixed.
Drop-in enhanced Givens rotary block -- it's like a Rubik's cube of embeddings :)
Think of regular RoPE embeddings as a rotating 3D block in space. Now add columns and rows that rotate, and then rotate the faces of each resulting cube:
class CombinedRotaryEmbedding(nn.Module):
    """Rotary position embedding preceded by learned Givens rotations.

    Each head's feature vector (size ``h_dim = n_state // n_head``) is
    transformed in three stages:

    1. ``num_rotations`` Givens matrices, each acting on a fixed pair of
       feature indices with a learned angle ``thetas[k] * theta_scale``.
    2. A learned dense ``h_dim x h_dim`` matrix (initialised to identity).
    3. A position-dependent sinusoidal rotation driven by ``inv_freq``.

    Args:
        n_state: Total feature size of the input (all heads together).
        n_head: Number of attention heads; must divide ``n_state``.
        num_rotations: Number of Givens rotations applied per forward pass.
        base: Base of the geometric progression used for ``inv_freq``.
        checkpointing: If true, ``forward`` runs through ``checkpoint``.

    Raises:
        ValueError: If ``n_state`` is not divisible by ``n_head`` or the
            resulting per-head size is odd (the sinusoidal stage splits
            features into even/odd halves of equal length).
    """

    def __init__(self, n_state, n_head, num_rotations, base=10000, checkpointing=False):
        super().__init__()
        if n_head <= 0 or n_state % n_head != 0:
            raise ValueError(
                f"n_state ({n_state}) must be divisible by n_head ({n_head})"
            )
        self.n_state = n_state
        self.n_head = n_head
        self.h_dim = n_state // n_head
        if self.h_dim % 2 != 0:
            raise ValueError(
                f"Per-head dimension must be even, but got {self.h_dim}"
            )
        self.num_rotations = num_rotations
        self.base = base
        self.checkpointing = checkpointing

        self.thetas = nn.Parameter(torch.zeros(num_rotations))
        # Index pairs are cast with .long() before use, so no gradient can ever
        # reach them. They stay a Parameter (same state_dict key) but are
        # marked non-trainable.
        # NOTE(review): a pair may come out with i == j, in which case the
        # "rotation" just scales that feature by cos(theta).
        self.rotation_pairs = nn.Parameter(
            torch.rand(num_rotations, 2) * self.h_dim, requires_grad=False
        )
        self.theta_scale = nn.Parameter(torch.ones(1))
        self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))
        self.inv_freq = nn.Parameter(self._compute_inv_freq())

    def _compute_inv_freq(self, device=None):
        """Return ``1 / base ** (2k / h_dim)`` for ``k = 0 .. h_dim/2 - 1``."""
        exponents = torch.arange(0, self.h_dim, 2, device=device).float() / self.h_dim
        return 1.0 / (self.base ** exponents)

    def givens_rotation_matrix(self, n_state, i, j, theta):
        """Build an ``n_state x n_state`` Givens rotation in the (i, j) plane.

        Args:
            n_state: Size of the square matrix.
            i: First feature index.
            j: Second feature index.
            theta: Single-element tensor holding the rotation angle.

        Returns:
            Identity matrix with entries (i,i), (i,j), (j,i), (j,j) replaced
            by cos, -sin, sin, cos of ``theta``. Built with torch ops so the
            result stays differentiable with respect to ``theta``.
        """
        theta = theta.reshape(())
        G = torch.eye(n_state, device=theta.device, dtype=theta.dtype)
        # torch.cos/sin (not math.cos/sin) keep theta attached to the graph.
        cos_t = torch.cos(theta)
        sin_t = torch.sin(theta)
        G[i, i] = cos_t
        G[i, j] = -sin_t
        G[j, i] = sin_t
        G[j, j] = cos_t
        return G

    def update_base(self, new_base):
        """Set a new frequency base and recompute ``inv_freq`` in place.

        The existing parameter object is overwritten rather than replaced so
        it keeps its device, dtype and any optimizer reference.
        """
        self.base = float(new_base)
        with torch.no_grad():
            self.inv_freq.copy_(self._compute_inv_freq(device=self.inv_freq.device))

    def reset_parameters(self):
        """Re-initialise the dense matrix (orthogonal) and the angles (zero)."""
        nn.init.orthogonal_(tensor=self.rotation_matrix)
        nn.init.zeros_(tensor=self.thetas)

    def forward(self, x):
        """Apply the embedding, optionally through ``checkpoint``."""
        if self.checkpointing:
            return checkpoint(self._forward, x)
        return self._forward(x)

    def _forward(self, x):
        """Transform ``x`` and return it as ``(batch, seq_len, n_state)``.

        Args:
            x: Either ``(batch, seq_len, n_state)`` or
               ``(batch, seq_len, n_head, h_dim)``.

        Raises:
            ValueError: If ``x`` is not 3D/4D, or a 4D input's last two
                dimensions do not match ``n_head`` and ``h_dim``.
        """
        if x.dim() not in (3, 4):
            raise ValueError(f"Expected input tensor to be 3D or 4D, but got {x.dim()}D")

        if x.dim() == 3:
            batch_size, seq_len, _ = x.size()
            # reshape (not view) so non-contiguous inputs are accepted.
            x = x.reshape(batch_size, seq_len, self.n_head, self.h_dim)
        else:
            batch_size, seq_len, n_head, h_dim = x.size()
            if n_head != self.n_head or h_dim != self.h_dim:
                raise ValueError(f"Expected n_head {self.n_head} and h_dim {self.h_dim}, but got n_head {n_head} and h_dim {h_dim}")

        # Flatten so every head vector is one row for the matrix products.
        x = x.reshape(-1, self.h_dim)

        # Convert the index pairs to Python ints once, outside the loop.
        index_pairs = self.rotation_pairs.long().tolist()
        for k, (i, j) in enumerate(index_pairs):
            theta = self.thetas[k] * self.theta_scale
            G = self.givens_rotation_matrix(n_state=self.h_dim, i=i, j=j, theta=theta)
            x = torch.matmul(x, G)

        x = torch.matmul(x, self.rotation_matrix)
        x = x.view(batch_size, seq_len, self.n_head, self.h_dim)

        # Positions are created in inv_freq's dtype to avoid an int/float mix.
        inv_freq = self.inv_freq.to(device=x.device)
        positions = torch.arange(seq_len, device=x.device, dtype=inv_freq.dtype)
        sinusoid_inp = positions[:, None] * inv_freq[None, :]
        sin = sinusoid_inp.sin()[None, :, None, :]
        cos = sinusoid_inp.cos()[None, :, None, :]

        x1, x2 = x[..., ::2], x[..., 1::2]
        x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        return x.reshape(batch_size, seq_len, self.n_state)
- Downloads last month
- 12