Whisper-like ASR model with some experimental ideas. Full script -- just install the dependencies and run. The included model is *not* trained: it's a blank (tabula rasa), freshly initialized model at the script's "medium" size. I'm experimenting with some of the new ideas from the vision-LLM people, but applied to audio. Here is a super cool paper: https://www.frontiersin.org/journals/neuroscience/articles/10.3389/fnins.2022.949142/full

Updated. I was having some issues with the hybrid attention and tensor sharing; fixed!

Drop-in enhanced Givens rotary block -- it's like a Rubik's cube of embeddings :)

Think of regular RoPE embeddings as a rotating 3D block in space. Now add rows and columns that rotate, and then rotate the faces of each resulting cube.
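To see the building block in isolation first: a Givens rotation G(i, j, theta) is just an identity matrix with a 2D rotation embedded in the (i, j) plane, so it mixes exactly two coordinates and leaves all the others alone. A minimal standalone sketch (the `givens` helper here is illustrative, not part of the script):

    import torch

    def givens(n, i, j, theta):
        # Identity everywhere except the (i, j) plane, which gets a 2D rotation.
        G = torch.eye(n)
        c, s = torch.cos(theta), torch.sin(theta)
        G[i, i], G[j, j] = c, c
        G[i, j], G[j, i] = -s, s
        return G

    G = givens(3, 0, 1, torch.tensor(torch.pi / 2))
    v = torch.tensor([1.0, 0.0, 2.0])
    print(G @ v)    # ~[0., 1., 2.]: coords 0 and 1 rotate, coord 2 is untouched
    print(G @ G.T)  # ~identity: Givens rotations are orthogonal (norm-preserving)

The full drop-in block below chains several of these learnable rotations, follows them with a learned global rotation, and then applies standard RoPE: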

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class CombinedRotaryEmbedding(nn.Module):
        def __init__(self, n_state, n_head, num_rotations, base=10000, checkpointing=False):
            super().__init__()
            self.n_state = n_state
            self.n_head = n_head
            self.h_dim = n_state // n_head
            self.num_rotations = num_rotations
            self.base = base
            self.checkpointing = checkpointing

            # Learnable Givens angles plus the (i, j) plane each rotation acts in.
            self.thetas = nn.Parameter(torch.zeros(num_rotations))
            self.rotation_pairs = nn.Parameter(torch.rand(num_rotations, 2) * self.h_dim)
            self.theta_scale = nn.Parameter(torch.ones(1))
            # Learnable global rotation applied after the Givens sequence.
            self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))
            # Standard RoPE inverse frequencies.
            self.inv_freq = nn.Parameter(1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)))

        def givens_rotation_matrix(self, n_state, i, j, theta):
            # Identity matrix with a 2D rotation embedded in the (i, j) plane.
            # torch.cos/torch.sin (not math.cos/math.sin) keep theta differentiable.
            G = torch.eye(n_state, device=theta.device)
            G[i, i] = torch.cos(theta)
            G[i, j] = -torch.sin(theta)
            G[j, i] = torch.sin(theta)
            G[j, j] = torch.cos(theta)
            return G

        def update_base(self, new_base):
            self.base = float(new_base)
            self.inv_freq = nn.Parameter(1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)))

        def reset_parameters(self):
            nn.init.orthogonal_(self.rotation_matrix)
            nn.init.zeros_(self.thetas)

        def forward(self, x):
            if self.checkpointing:
                return checkpoint(self._forward, x)
            return self._forward(x)

        def _forward(self, x):
            if x.dim() not in (3, 4):
                raise ValueError(f"Expected input tensor to be 3D or 4D, but got {x.dim()}D")

            if x.dim() == 3:
                batch_size, seq_len, n_state = x.size()
                x = x.view(batch_size, seq_len, self.n_head, self.h_dim)
            else:
                batch_size, seq_len, n_head, h_dim = x.size()
                if n_head != self.n_head or h_dim != self.h_dim:
                    raise ValueError(f"Expected n_head {self.n_head} and h_dim {self.h_dim}, but got n_head {n_head} and h_dim {h_dim}")

            # Flatten to (batch * seq * heads, h_dim) and apply the Givens sequence.
            x = x.reshape(-1, self.h_dim)
            for k in range(self.num_rotations):
                i, j = self.rotation_pairs[k].long()
                theta = self.thetas[k] * self.theta_scale
                G = self.givens_rotation_matrix(self.h_dim, i, j, theta)
                x = torch.matmul(x, G)

            # Learned global rotation, then reshape back to heads.
            x = torch.matmul(x, self.rotation_matrix)
            x = x.view(batch_size, seq_len, self.n_head, self.h_dim)

            # Standard RoPE: rotate even/odd channel pairs by position-dependent angles.
            sinusoid_inp = torch.einsum('i,j->ij', torch.arange(seq_len, device=x.device).float(), self.inv_freq.to(x.device))
            sin = sinusoid_inp.sin()[None, :, None, :]
            cos = sinusoid_inp.cos()[None, :, None, :]

            x1, x2 = x[..., ::2], x[..., 1::2]
            x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
            x = x.view(batch_size, seq_len, self.n_state)
            return x
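A quick smoke test, assuming "medium"-ish dims (n_state=1024, n_head=16; num_rotations=8 is just a placeholder choice, not a value from the script):

    import torch

    rotary = CombinedRotaryEmbedding(n_state=1024, n_head=16, num_rotations=8)
    x = torch.randn(2, 100, 1024)  # (batch, seq_len, n_state)
    out = rotary(x)
    print(out.shape)  # torch.Size([2, 100, 1024]) -- same shape out as in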