# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/A. Neural modules.ipynb.

# %% auto 0
__all__ = ['LayerNorm', 'LinearHead', 'QueryHead', 'init_transformer', 'sinusoids', 'MultiHeadAttention',
           'ResidualAttentionBlock', 'BaseDecoder', 'EmbeddingProjector', 'FlexEmbeddings']

# %% ../nbs/A. Neural modules.ipynb 2
import torch
import numpy as np
import math

from torch import Tensor, nn
import torch.nn.functional as F
from typing import Dict, Iterable, Optional

# import xformers.ops as xops

# %% ../nbs/A. Neural modules.ipynb 3
# Code in this file is mostly borrowed from
# https://github.com/openai/whisper/blob/main/whisper/model.py
# and is under the MIT License

class LayerNorm(nn.LayerNorm):
    """`nn.LayerNorm` that always normalizes in float32.

    The input is upcast with `.float()` before the parent implementation
    runs, and the result is cast back to the dtype the caller passed in.
    """

    def forward(self, x):
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.to(input_dtype)

# Used in μP to initialize the weights and configure the optimizer
# These two layers map the transformer width into a fixed dimension
class LinearHead(nn.Linear):
    """Marker subclass of `nn.Linear`; adds no behavior of its own.

    The distinct type lets other code single these layers out with
    `isinstance` (the file comment mentions μP weight initialization and
    optimizer configuration; that code is not part of this file).
    """

class QueryHead(nn.Linear):
    """Marker subclass of `nn.Linear` used for attention query projections.

    Behaves exactly like `nn.Linear`; the separate type only makes these
    layers identifiable (the file comment mentions μP initialization and
    optimizer configuration; that code is not part of this file).
    """

# based on https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L163
def init_transformer(m):
    """Initialize a single module in place, minGPT style.

    * `nn.Linear` / `nn.Embedding`: weight drawn from a truncated normal with
      `std=0.02`; a `nn.Linear` bias, when present, is zeroed.
    * `nn.LayerNorm`: bias set to 0 and weight set to 1, when they exist.
    * Any other module is left untouched.

    Takes one module at a time, so it can be handed to `nn.Module.apply`
    (usage is not shown in this file).
    """
    if isinstance(m, (nn.Linear, nn.Embedding)):
        torch.nn.init.trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        # A LayerNorm built with elementwise_affine=False (or without a bias)
        # has None in place of these parameters; skip them instead of crashing.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            torch.nn.init.constant_(m.weight, 1.0)

# %% ../nbs/A. Neural modules.ipynb 4
def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding.

    Builds a `(length, channels)` table: the first `channels // 2` columns are
    sines and the remaining columns are cosines of `position * inv_timescale`,
    where the inverse timescales are spaced geometrically from 1 down to
    `1 / max_timescale`.

    `channels` must be even (checked with an assert).
    """
    assert channels % 2 == 0
    half = channels // 2
    # With a single frequency (channels == 2) the spacing divisor `half - 1`
    # is zero, which used to turn the whole table into NaNs. Clamping to 1 is
    # harmless there: the only timescale index is 0, so the increment is
    # multiplied by zero and the single inverse timescale is exactly 1.
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

# %% ../nbs/A. Neural modules.ipynb 5
class MultiHeadAttention(nn.Module):
    """Multi-head attention with optional rotary embeddings and a KV cache.

    The module has separate query/key/value projections. `convert_for_eval`
    can additionally build fused projections (`qkv` for self-attention,
    `q` + `kv` for cross-attention) with the `sqrt(qk_scale)` factor folded
    into the weights; `forward` uses them whenever they exist.

    Args:
        n_state: model width; every projection maps `n_state -> n_state`.
        n_head: number of heads; the per-head size is `n_state // n_head`.
        qk_scale: extra scale on the attention logits. It is applied as
            `sqrt(qk_scale)` on the queries and again on the keys.
        rope: when true, rotary position embeddings are applied to q and k.
        cross: marks the module as cross-attention; it only affects which
            fused layers `convert_for_eval` builds.
    """

    def __init__(self, n_state: int, n_head: int, qk_scale: float = 1, rope: bool = False, cross=False):
        super().__init__()
        self.n_state = n_state
        self.n_head = n_head
        self.sqrt_qk_scale = math.sqrt(qk_scale)
        self.query = QueryHead(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)
        self.cross = cross
        # Multipliers applied to the positions before the rotary lookup.
        self.query_subsampling = 1
        self.key_subsampling = 1

        # `forward` skips recomputing k/v when it is handed this exact object.
        # Nothing in this class assigns it, so it stays None unless set from
        # outside.
        self.cached_kvx = None
        self.register_buffer('k_cache', None)
        self.register_buffer('v_cache', None)

        self.rotary = None
        if rope:
            self.rotary = Rotary(n_state // n_head)
        # Fused projections, created by `convert_for_eval`.
        self.qkv = None
        self.kv = None

    def setup_kv_cache(self, max_batch_size, max_seq_len, dtype=torch.float32):
        """Allocate zero-filled key/value caches.

        Both caches have shape
        `(max_batch_size, n_head, max_seq_len, n_state // n_head)` and live on
        the device of the corresponding projection weight.
        """
        cache_shape = (max_batch_size, self.n_head, max_seq_len, self.n_state//self.n_head)
        self.k_cache = torch.zeros(cache_shape, dtype=dtype, device=self.key.weight.device)
        self.v_cache = torch.zeros(cache_shape, dtype=dtype, device=self.value.weight.device)

    def merge_linears(self, layers, mults):
        """Fuse several same-shaped `nn.Linear` layers into one.

        The returned layer's output is the concatenation (along the last
        dimension) of `layer(x) * mult` for each `(layer, mult)` pair. Layers
        without a bias contribute zeros to the fused bias; if none of the
        layers has a bias the fused layer has none either.

        Args:
            layers: `nn.Linear` modules sharing the same weight shape.
            mults: one scalar multiplier per layer, folded into its weight
                and bias.
        """
        biases = [layer.bias for layer in layers if layer.bias is not None]
        # nn.Linear stores its weight as (out_features, in_features).
        dout, din = layers[0].weight.shape
        merged = nn.Linear(din, len(layers) * dout, bias=bool(biases)).to(layers[0].weight.device)
        with torch.no_grad():
            merged.weight[:] = torch.cat([layer.weight * m for layer, m in zip(layers, mults)])
            if biases:
                zero_bias = torch.zeros_like(biases[0])
                merged.bias[:] = torch.cat([
                    zero_bias if layer.bias is None else layer.bias * m
                    for layer, m in zip(layers, mults)
                ])
        return merged

    def convert_for_eval(self):
        """Build the fused projections used by `forward`.

        Self-attention gets a single `qkv` layer; cross-attention gets `q`
        and `kv`. The `sqrt(qk_scale)` factor is folded into the query and
        key parts.

        Raises:
            AttributeError: if the module was already converted.
        """
        if self.qkv is not None or self.kv is not None: raise AttributeError("already converted")

        # Output width of a single projection; used to split the fused output.
        self.odim = self.key.weight.shape[0]
        if self.cross:
            self.q = self.merge_linears([self.query], [self.sqrt_qk_scale])
            self.kv = self.merge_linears([self.key, self.value],
                                         [self.sqrt_qk_scale, 1])
        else:
            self.qkv = self.merge_linears([self.query, self.key, self.value],
                                          [self.sqrt_qk_scale, self.sqrt_qk_scale, 1])

    def split_heads(self, x, x_positions, rope=False, subsampling=1):
        """Reshape `(batch, seq, n_state)` into `(batch, n_head, seq, head_dim)`.

        When `rope` is truthy the rotary embedding is applied first, looked up
        at `x_positions * subsampling`.
        """
        x = x.view(*x.shape[:2], self.n_head, -1)
        if rope:
            x = rope_rotate(x, x_positions * subsampling, *self.rotary(x))
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        qx,
        q_positions,
        kvx,
        kv_positions,
        causal = False,
        mask=None,
    ):
        """Attend from `qx` to `kvx` and apply the output projection.

        Args:
            qx: query-side input, `(batch, q_len, n_state)`.
            q_positions: positions of the query tokens; used for the rotary
                lookup and to pick the matching rows of `mask`.
            kvx: key/value-side input. With the fused `qkv` layer only `qx`
                is projected, so that path assumes `kvx` carries the same
                content as `qx`.
            kv_positions: positions of the key/value tokens; used for the
                rotary lookup and as the slots written in the KV cache.
            causal: forwarded as `is_causal` to
                `F.scaled_dot_product_attention`.
            mask: optional attention mask indexed by position along its first
                dimension; rows `mask[q_positions]` are passed as `attn_mask`.
                NOTE(review): PyTorch documents passing both `attn_mask` and
                `is_causal=True` as an error — confirm callers use only one.

        Returns:
            Tensor of shape `(batch, q_len, n_state)`.
        """
        if self.qkv is not None:
            q,k,v = self.qkv(qx).split(self.odim, dim=-1)
        elif self.kv is not None:
            q = self.q(qx)
            k,v = self.kv(kvx).split(self.odim, dim=-1)
        else:
            q,k,v = None,None,None

        if q is None: q = self.query(qx) * self.sqrt_qk_scale
        q = self.split_heads(q, q_positions, rope = self.rotary, subsampling = self.query_subsampling)

        # Keys/values are (re)computed unless the caller passed the very
        # object stored in `cached_kvx`.
        if kvx is not self.cached_kvx:
            if k is None: k = self.key(kvx) * self.sqrt_qk_scale
            k = self.split_heads(k, kv_positions, rope = self.rotary, subsampling = self.key_subsampling)
            if v is None: v = self.value(kvx)
            v = self.split_heads(v, kv_positions)
            if self.k_cache is not None:
                self.k_cache[:,:,kv_positions] = k
                self.v_cache[:,:,kv_positions] = v

        # With a cache allocated, attention always runs over the whole cache
        # (every slot), not only over the freshly computed keys/values.
        if self.k_cache is not None:
            k, v = self.k_cache, self.v_cache

        if mask is not None:
            mask = mask[q_positions]

        wv = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0, is_causal=causal)

        # (batch, n_head, q_len, head_dim) -> (batch, q_len, n_state)
        return self.out(wv.permute(0, 2, 1, 3).flatten(start_dim=2))

# %% ../nbs/A. Neural modules.ipynb 6
# modified from https://blog.eleuther.ai/rotary-embeddings/

import torch

class Rotary(torch.nn.Module):
    """Cached cos/sin tables for rotary position embeddings.

    Args:
        dim: size of the dimension being rotated; `dim // 2` inverse
            frequencies `base ** (-2i / dim)` are generated.
        base: base of the geometric frequency progression.
    """

    # Lower bound on the number of cached positions. The tables are indexed
    # by position values in `rope_rotate`, and those can be larger than the
    # length of the current input, so at least this many rows are always
    # generated.
    min_seq_len_cached = 2500

    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        """Return `(cos, sin)` tables of shape `(1, n_positions, 1, dim)`.

        `n_positions` is at least `min_seq_len_cached` and at least
        `x.shape[seq_dim]`. The tables are rebuilt when the input is longer
        than the cache or lives on a different device than the cache.
        """
        seq_len = x.shape[seq_dim]
        needs_rebuild = (
            not self.seq_len_cached
            or seq_len > self.seq_len_cached
            or self.cos_cached.device != x.device
        )
        if needs_rebuild:
            # Cover the requested length; the previous fixed size of 2500 left
            # longer inputs with a too-short table that was rebuilt every call.
            self.seq_len_cached = max(seq_len, self.min_seq_len_cached)

            t = torch.arange(self.seq_len_cached, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Duplicate the frequencies so the table spans the full `dim`,
            # matching the two halves handled by `rotate_half`.
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, :, None, :]
            self.sin_cached = emb.sin()[None, :, None, :]
        return self.cos_cached, self.sin_cached


# rotary pos emb helpers:
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For `x = [a, b]` split at `x.shape[-1] // 2`, returns `[-b, a]`.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)

def rope_rotate(x, positions, cos, sin):
    """Apply the rotary embedding to `x` at the given positions.

    `cos` and `sin` are indexed along their second dimension with
    `positions`; the result is `x * cos + rotate_half(x) * sin`.
    """
    cos_at_pos = cos[:, positions]
    sin_at_pos = sin[:, positions]
    return x * cos_at_pos + rotate_half(x) * sin_at_pos

# %% ../nbs/A. Neural modules.ipynb 7
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention, MLP.

    Each sub-layer is applied as `x = x + sublayer(layer_norm(x))`.

    Args:
        n_state: model width.
        n_head: number of attention heads.
        cross_attention: when true, a cross-attention sub-layer (with its own
            layer norm) is added between self-attention and the MLP.
        rope: forwarded to the attention modules (rotary embeddings).
        qk_scale: forwarded to the attention modules.
        ffn_mult: the MLP hidden width is `n_state * ffn_mult`.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, rope: bool = False,
                 qk_scale: float = 1, ffn_mult: int = 4):
        super().__init__()
        self.attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope, cross=True) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * ffn_mult
        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def setup_kv_cache(self, max_batch_size, max_seq_len, max_cross_seq_len=None, dtype=torch.float32):
        """Allocate KV caches for the self- and (if present) cross-attention.

        Args:
            max_batch_size: batch dimension of the caches.
            max_seq_len: number of slots in the self-attention cache.
            max_cross_seq_len: number of slots in the cross-attention cache;
                must be given when the block has cross-attention.
            dtype: dtype of the caches. Defaults to float32, which is what the
                caches used before this parameter existed.
        """
        self.attn.setup_kv_cache(max_batch_size, max_seq_len, dtype=dtype)
        if self.cross_attn is not None:
            self.cross_attn.setup_kv_cache(max_batch_size, max_cross_seq_len, dtype=dtype)

    def forward(
        self,
        x: Tensor,
        x_positions: Optional[Tensor] = None,
        xa: Optional[Tensor] = None,
        xa_positions: Optional[Tensor] = None,
        causal = False,
        mask=None,
    ):
        """Run the block on `x`.

        Args:
            x: input hidden states.
            x_positions: positions of the tokens in `x`.
            xa: cross-attention source; only used when the block has
                cross-attention.
            xa_positions: positions of the tokens in `xa`.
            causal: forwarded to the self-attention only.
            mask: forwarded to the self-attention only.

        Returns:
            Hidden states with the same shape as `x`.
        """
        lnx = self.attn_ln(x)
        x = x + self.attn(lnx, x_positions, lnx, x_positions, causal=causal, mask=mask)
        if self.cross_attn is not None:
            # Cross-attention is neither masked nor causal.
            lnx = self.cross_attn_ln(x)
            x = x + self.cross_attn(lnx, x_positions, xa, xa_positions)
        x = x + self.mlp(self.mlp_ln(x))
        return x

# %% ../nbs/A. Neural modules.ipynb 8
class BaseDecoder(nn.Module):
    """Stack of cross-attending `ResidualAttentionBlock`s with a final layer norm.

    Args:
        depth: number of blocks (`math.floor(depth)` are created).
        n_head: attention heads per block.
        width: model width.
        qk_scale: forwarded to every block.
        ffn_mult: forwarded to every block.
        length: side of the square attention mask.
        rope: forwarded to every block.
    """

    def __init__(self, depth=6, n_head=6, width=384, qk_scale=1, ffn_mult=4, length=2250, rope=False):
        super().__init__()
        self.length = length
        self.width = width

        n_blocks = math.floor(depth)
        blocks = []
        for _ in range(n_blocks):
            blocks.append(
                ResidualAttentionBlock(
                    self.width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, cross_attention=True, rope=rope
                )
            )
        self.layers = nn.ModuleList(blocks)

        self.ln_post = LayerNorm(width)

        # -inf strictly above the diagonal, 0 elsewhere.
        upper_neg_inf = torch.empty(length, length)
        upper_neg_inf.fill_(-torch.inf)
        upper_neg_inf.triu_(1)
        self.register_buffer("mask", upper_neg_inf, persistent=False)

    def forward(self, x, x_positions, xenc, xenc_positions):
        """Pass `x` through every block (attending to `xenc`), then `ln_post`.

        Every block is called with `causal=False` and the precomputed mask.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, x_positions, xenc, xenc_positions, causal=False, mask=self.mask)
        return self.ln_post(hidden)

# %% ../nbs/A. Neural modules.ipynb 9
class EmbeddingProjector(nn.Linear):
    """Marker subclass of `nn.Linear`; adds no behavior of its own.

    `FlexEmbeddings` uses it for the projections between the embedding
    width and the model width.
    """

class FlexEmbeddings(nn.Module):
    """Token embeddings with an optional narrower main table and special tokens.

    Token ids `0 .. codes-1` are looked up in `main` (width `frozen_width`)
    and, when `frozen_width != width`, projected to `width` by
    `emb_to_hidden`. Token ids `>= codes` are looked up in `special`
    (offset by `codes`), which is `width` wide.

    Args:
        codes: number of regular tokens.
        width: model width.
        special_codes: number of special tokens; falsy disables them.
        frozen_width: width of the main table; defaults to `width`.
        special_embedding: embedding module to use for the special tokens
            instead of creating a new one.
        unembed: when false, no `hidden_to_emb` projection is created.
    """

    def __init__(self, codes, width, special_codes=None, frozen_width=None, special_embedding=None, unembed=True):
        super().__init__()
        self.codes = codes
        self.special_codes = special_codes
        if frozen_width is None: frozen_width = width

        self.main = nn.Embedding(codes, frozen_width or width)
        self.emb_to_hidden = EmbeddingProjector(frozen_width, width) if frozen_width != width else None
        self.hidden_to_emb = EmbeddingProjector(width, frozen_width) if unembed and frozen_width != width else None
        if special_codes:
            self.special = special_embedding or nn.Embedding(special_codes, width)

        # Filled in by `convert_for_eval`; used only when not training.
        self.register_buffer('merged_in', None)
        self.register_buffer('merged_out', None)
        self.register_buffer('bias_out', None)

    def set_frozen_embeddings(self, values):
        """Copy `values` into the main table and set `main.lr_scale = 0`.

        `lr_scale` is not read in this file; presumably the optimizer setup
        uses it to keep the table fixed.
        """
        with torch.no_grad():
            self.main.weight[:] = values
            self.main.lr_scale = 0

    @torch.no_grad()
    def convert_for_eval(self):
        """Precompute merged tables for the eval-mode fast paths.

        Does nothing without special tokens. Otherwise:

        * `merged_in` becomes one embedding table holding the (projected)
          main rows followed by the special rows.
        * `merged_out` / `bias_out` become the weight and bias of the single
          `F.linear` that reproduces `unembed`. They are left as None when no
          output projection exists and the main table is not as wide as the
          special table, because the two cannot be stacked then (this is the
          `unembed=False`, `frozen_width != width` configuration).
        """
        if not self.special_codes: return
        # in
        main_w = self.main.weight
        if self.emb_to_hidden is not None: main_w = self.emb_to_hidden(main_w)
        weight = torch.cat([main_w, self.special.weight], dim=0)
        self.merged_in = nn.Embedding(*weight.shape, _weight=weight)

        # out
        weight = self.main.weight
        if self.hidden_to_emb is None and weight.shape[1] != self.special.weight.shape[1]:
            # No way to map hidden states back onto the main table: skip the
            # output merge instead of failing on mismatched widths.
            self.merged_out = None
            self.bias_out = None
            return
        # (h @ W.T + b) @ main.T == h @ (main @ W).T + b @ main.T
        if self.hidden_to_emb is not None: weight = weight @ self.hidden_to_emb.weight
        # Rows are output classes, as expected by F.linear.
        self.merged_out = torch.cat([weight, self.special.weight], dim=0).contiguous()
        if self.hidden_to_emb is not None:
            self.bias_out = torch.cat([
                self.hidden_to_emb.bias @ self.main.weight.T,
                torch.zeros(self.special.weight.shape[0], device=weight.device, dtype=weight.dtype)
            ], dim=0)
        else:
            self.bias_out = None

    def forward(self, toks):
        """Embed token ids; returns a tensor of shape `toks.shape + (width,)`.

        Outside training, the merged table from `convert_for_eval` is used
        when it exists.
        """
        if not self.training and self.merged_in is not None:
            return self.merged_in(toks)

        if self.special_codes:
            special_mask = toks >= self.codes
            # Special ids are out of range for `main`; look up row 0 for them
            # as a placeholder that is overwritten below.
            embs = self.main(torch.where(special_mask, 0, toks))
        else:
            embs = self.main(toks)

        if self.emb_to_hidden is not None: embs = self.emb_to_hidden(embs)

        if self.special_codes:
            embs[special_mask] = self.special(toks[special_mask] - self.codes).to(embs.dtype)

        return embs

    def unembed(self, embs):
        """Map hidden states to logits over the main (and special) tokens.

        Main logits are dot products with the main table (after
        `hidden_to_emb`, when present); special logits are dot products of
        the unprojected hidden states with the special table. The slow path
        returns float32 logits.
        """
        if not self.training and self.merged_out is not None:
            return F.linear(embs, self.merged_out, self.bias_out) # embs @ self.merged_out + self.bias_out

        orig_embs = embs
        if self.hidden_to_emb is not None: embs = self.hidden_to_emb(embs)

        main_logits = (embs @ self.main.weight.to(embs.dtype).T).float()

        if not self.special_codes:
            return main_logits

        special_logits = (orig_embs @ self.special.weight.to(orig_embs.dtype).T).float()
        return torch.cat([main_logits, special_logits], dim=-1)