# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4B. Semantic to acoustic token modeling.ipynb.

# %% auto 0
# Public names exported by `from <this module> import *`. This list is part of the
# autogenerated header (see the first line of the file), so edit the notebook instead.
__all__ = ['load_datasets', 'CMLMVisual', 'Rotary', 'rotate_half', 'apply_rotary_pos_emb', 'ResidualAttentionBlock',
           'MultiHeadAttention', 'DelSumDecoder', 'EmbeddingProjector', 'rand', 'Tunables', 'SADelARTransformer']

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 1
import io
import time
import math
import random
import dataclasses

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity, schedule
from fastcore.basics import store_attr
from huggingface_hub import hf_hub_download

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 3
from pathlib import Path
import json
from fastprogress import progress_bar, master_bar
import webdataset as wds

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 4
from .train import *
from .modules import *
from . import vq_stoks

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 8
def rand(start, end):
    """Return a random float drawn uniformly between `start` and `end`.

    The value is `random.random()` rescaled to the requested interval.
    """
    span = end - start
    return random.random() * span + start

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 9
def random_trunc(random_trunc_p, atoks_len = 2250, stoks_len = 750):
    """Build a sample-stream filter that randomly shortens samples.

    With probability `random_trunc_p` the sample's `atoks.npy` is cut along its
    second axis to a random duration between 0.3 and 30 seconds, where
    `atoks_len` tokens are treated as spanning 30 seconds.  `stoks.npy` is then
    always trimmed to the proportional length (`stoks_len` per `atoks_len`).

    Returns a function that takes an iterable of sample dicts and yields the
    (mutated) samples one by one.
    """
    atoks_rate = atoks_len / 30  # acoustic tokens per second

    def _trunc(samples):
        for sample in samples:
            should_cut = random.random() < random_trunc_p
            if should_cut:
                duration = rand(0.3, 30)
                keep = math.ceil(duration * atoks_rate)
                sample['atoks.npy'] = sample['atoks.npy'][:, :keep]
            # keep the semantic tokens in proportion to the remaining acoustic tokens
            n_atoks = sample['atoks.npy'].shape[-1]
            n_stoks = math.ceil(n_atoks / atoks_len * stoks_len)
            sample['stoks.npy'] = sample['stoks.npy'][:n_stoks]
            yield sample

    return _trunc

def pad_samples(atoks_len = 2250, stoks_len = 750, stoks_pad_token = 4096):
    """Build a sample-stream filter that converts tokens to tensors of fixed length.

    `stoks.npy` is padded at the end of its last axis up to `stoks_len` with
    `stoks_pad_token`; `atoks.npy` is padded up to `atoks_len` with -100.
    Both entries are replaced by `torch` tensors.

    Returns a function that takes an iterable of sample dicts and yields the
    (mutated) samples one by one.
    """
    def _pad(samples):
        for sample in samples:
            raw_stoks = sample['stoks.npy']
            stoks_missing = stoks_len - raw_stoks.shape[-1]
            sample['stoks.npy'] = F.pad(torch.tensor(raw_stoks), (0, stoks_missing), value=stoks_pad_token)

            raw_atoks = sample['atoks.npy']
            atoks_missing = atoks_len - raw_atoks.shape[-1]
            sample['atoks.npy'] = F.pad(torch.tensor(raw_atoks), (0, atoks_missing), value=-100)
            yield sample

    return _pad

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 10
def speaker_id_extractor(speaker_map):
    """Build a sample-stream filter that attaches a numeric speaker id.

    The speaker name is the second "/"-separated component of the sample's
    `__key__`; it is looked up in `speaker_map` and stored as a tensor under
    the `speaker` key.
    """
    def _extractor(samples):
        for sample in samples:
            speaker_name = sample['__key__'].split("/")[1]
            sample['speaker'] = torch.tensor(speaker_map[speaker_name])
            yield sample

    return _extractor

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 14
def load_datasets(
        input:str,             # webdataset folder
        samples:int,           # samples per epoch
        subsample:float=1,     # use a fraction of the files
        val_samples:int=512,
        random_trunc_p:float=0,# probability of truncating the input to less than 30 seconds
        stoks_pad_token=4096,
    ):
    """Create the training and validation webdataset pipelines.

    `input` may be a directory (all `*-s2a-*.tar.gz` files in it are used), a
    path whose last component is a glob pattern, or a list of shard paths.
    Every shard must have a companion `<shard>.speakers.txt` file with one
    speaker name per line; the sorted union of those names defines the
    speaker-name -> index map.

    The first shard is used for validation and all the remaining ones for
    training (in the order produced by the glob / given in the list).

    Returns a `(train_ds, val_ds)` tuple.  `speakers` and `total_samples`
    attributes are set on each pipeline before it is sliced to the epoch length.

    Raises `TypeError` when `input` is neither a path/str nor a list.

    NOTE(review): `subsample` is accepted but not used anywhere in this function.
    """
    if isinstance(input, (Path, str)):
        path = Path(input)
        if path.is_dir():
            glob = '*-s2a-*.tar.gz'
        else:
            # the last path component is treated as a glob pattern
            glob = path.name
            path = path.parent
        input = Path(path).glob(glob)
    elif isinstance(input, list):
        pass
    else:
        # the original raised `ArgumentError`, a name that is not defined in this module
        raise TypeError("input should be either a list or a path with an optional glob specifier")
    shards = [str(x) for x in input]

    # collect the speaker names of all shards and give them stable (sorted) indices
    speaker_names = set()
    for shard in shards:
        with open(shard+'.speakers.txt') as f:
            speaker_names.update(line.strip() for line in f)
    speakers = {name:i for i,name in enumerate(sorted(speaker_names))}

    batch_size = 64

    def make_ds(shards, length):
        n_batches = length // batch_size
        pipeline = wds.WebDataset(wds.ResampledShards(shards)).compose(
            wds.decode(),
            speaker_id_extractor(speakers),
            random_trunc(random_trunc_p) if random_trunc_p > 0 else lambda x: x,
            pad_samples(stoks_pad_token=stoks_pad_token),
            wds.to_tuple('stoks.npy', 'atoks.npy', 'speaker'),
            wds.batched(batch_size),
        )
        pipeline.speakers = speakers
        pipeline.total_samples = length
        return pipeline.compose(wds.slice(n_batches)).with_epoch(n_batches).with_length(n_batches)

    return (
        make_ds(shards[1:], samples),
        make_ds(shards[:1], val_samples),
    )

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 33
import pylab as plt
import fastprogress
import IPython
import numpy as np

class CMLMVisual:
    """Visualize training progress

    Keeps the history of losses, learning rates and the metrics returned by
    `model.get_metrics()` and redraws three stacked plots (loss, metrics,
    learning rate) plus an HTML accuracy table every time data is added.

    NOTE(review): `display` is used in `show` but is not imported in the visible
    part of this file; presumably it is the builtin injected by IPython.
    """
    def __init__ (self, model, masterbar, total_steps):
        self.model = model
        self.masterbar = masterbar
        self.total_steps = total_steps
        self.epochs = total_steps // masterbar.main_bar.total

        # three stacked axes sharing the x axis: loss, accuracy, learning rate
        gs = plt.GridSpec(3, 1, height_ratios=[2,2,1])
        graph_fig = plt.figure(figsize=(10,6))
        self.graph_fig = graph_fig
        self.loss_p = graph_fig.add_subplot(gs[0])
        self.acc_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
        self.acc_p.tick_params('x', labelbottom=False)
        self.lr_p = graph_fig.add_subplot(gs[2], sharex=self.loss_p)
        self.lr_p.tick_params('x', labelbottom=False)
        # display handles, created in `show`
        self.graph_out = None
        self.acc_out = None

        self.its = []
        self.train_losses = []
        self.val_losses = []
        self.lr_history = []
        self.acc = np.nan
        self.acc_history = []
        self.pacc_history = []

    def show(self):
        """Start the timer, write the table header and create the display outputs."""
        self.start_t = time.time()
        self.masterbar.write(["samples", "train", "val", "time"], table=True)
        self.graph_out = display(self.graph_fig, display_id=True)
        self.acc_out = display(IPython.display.HTML(''), display_id=True)

    def hide(self):
        """Replace the displayed figure with empty HTML (if it was ever shown)."""
        if self.graph_out is not None:
            self.graph_out.update(IPython.display.HTML(''))

    def plot(self):
        """Redraw all three plots from the recorded history and refresh the display."""
        loss_p, acc_p, lr_p = self.loss_p, self.acc_p, self.lr_p
        loss_p.clear()
        loss_p.plot(self.its, self.train_losses)
        loss_p.plot(self.its, self.val_losses)
        loss_p.set_xlim(0, self.total_steps)
        loss_p.set_yscale('log')
        acc_p.clear()
        # one dotted line per metric key found in the most recent metrics dict
        for k in self.acc_history[-1].keys():
            acc_p.plot(self.its, [x[k] for x in self.acc_history], ':')
        lr_p.clear()
        lrs = np.array(self.lr_history)
        lr_p.plot(self.its, lrs)
        self.graph_out.update(self.graph_fig)

    @staticmethod
    def _metrics_html(metrics):
        """Render a `{name: fraction}` dict as a one-row HTML table of percentages."""
        header = ''.join(f"<td>{k}</td>" for k in metrics)
        values = ''.join(f"<td>{x*100:.1f}%</td>" for x in metrics.values())
        return (
            "<h5>Accuracies:</h5><table>"
            f"<thead><tr>{header}</tr></thead>"
            f"<tr>{values}</tr>"
            "</table>"
        )

    def add_data(self, it, lr, train_loss, val_los):
        """Record one data point, fetch the model metrics and refresh table and plots."""
        self.its.append(it)
        self.train_losses.append(train_loss)
        self.val_losses.append(val_los)
        self.lr_history.append(lr)
        metrics = self.model.get_metrics()
        self.acc_history.append(metrics)
        self.acc_out.update(IPython.display.HTML(self._metrics_html(metrics)))
        self.plot()

    def add_table_row(self, it, avg_train_loss, val_loss):
        """Append a row (iteration, losses, elapsed time since `show`) to the progress table."""
        elapsed_t = time.time() - self.start_t
        self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True)

    def on_iter(self, bar, it, avg_train_loss, val_loss):
        """Update the progress bar comment with the current epoch and losses."""
        epoch = math.ceil(it / self.total_steps * self.epochs)
        bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 34
# modified from https://blog.eleuther.ai/rotary-embeddings/
import torch

class Rotary(torch.nn.Module):
    """Rotary position embedding tables (cos/sin), cached per sequence length and device.

    `dim` is the size of the last axis the tables are built for and `base`
    controls the geometric progression of the rotation frequencies.
    """
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        """Return `(cos, sin)` tables shaped `(1, seq_len, 1, dim)` for `x.shape[seq_dim]` positions."""
        seq_len = x.shape[seq_dim]
        # The cached tables are plain attributes (not buffers), so they do not follow the
        # module between devices; rebuild them when the input lives on a different device.
        cache_stale = (
            seq_len != self.seq_len_cached
            or self.cos_cached is None
            or self.cos_cached.device != x.device
        )
        if cache_stale:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # the frequencies are repeated so the table covers both halves of the last axis
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, :, None, :]
            self.sin_cached = emb.sin()[None, :, None, :]
        return self.cos_cached, self.sin_cached


# rotary pos emb helpers:
def rotate_half(x):
    """Split the last axis of `x` in two and return `concat(-second_part, first_part)`."""
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)

#@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply the rotary tables to queries and keys and return `(q_rotated, k_rotated)`.

    The keys use the full `cos`/`sin` tables; the queries use only the first
    `q.shape[1]` entries along axis 1 of the tables.
    """
    q_len = q.shape[1]
    q_cos, q_sin = cos[:, :q_len], sin[:, :q_len]
    q_rotated = (q * q_cos) + (rotate_half(q) * q_sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)
    return q_rotated, k_rotated

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 35
from torch import Tensor, nn
import torch.nn.functional as F
from typing import Dict, Iterable, Optional

class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention, then an MLP.

    Every sub-layer is applied to a layer-normalized input and added back to
    the residual stream.  The MLP hidden size is `n_state * ffn_mult`.
    """
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, rope: bool = False,
                 qk_scale: float = 1, ffn_mult: int = 4):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
        self.attn_ln = LayerNorm(n_state)

        # the cross-attention pair exists only when requested
        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
            self.cross_attn_ln = LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        hidden = n_state * ffn_mult
        self.mlp = nn.Sequential(
            nn.Linear(n_state, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_state),
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        causal = False,
        kv_cache: Optional[dict] = None,
    ):
        """Run the block on `x`; `xa` is the cross-attention source, `causal` applies to self-attention only."""
        self_out, _ = self.attn(self.attn_ln(x), causal=causal, kv_cache=kv_cache)
        x = x + self_out
        if self.cross_attn:
            cross_out, _ = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
            x = x + cross_out
        return x + self.mlp(self.mlp_ln(x))
    
class MultiHeadAttention(nn.Module):
    """Multi-head (self- or cross-) attention with optional rotary embeddings.

    `qk_scale` is split evenly between queries and keys (each is multiplied by
    its square root).  When `rope` is true a `Rotary` table sized for one head
    (`n_state // n_head`) is applied to queries and keys.
    """
    def __init__(self, n_state: int, n_head: int, qk_scale: float = 1, rope: bool = False):
        super().__init__()
        self.n_head = n_head
        self.sqrt_qk_scale = math.sqrt(qk_scale)
        self.query = QueryHead(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

        self.rotary = None
        if rope:
            self.rotary = Rotary(n_state // n_head)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        causal = False,
        kv_cache: Optional[dict] = None,
    ):
        """Attend from `x` to `xa` (or to `x` itself when `xa` is None).

        `kv_cache`, when given, is a dict keyed by the `key`/`value` modules;
        cached entries are used for cross-attention instead of recomputing them.
        Returns `(output, None)`; the second value is kept only for compatibility.
        """
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        if self.sqrt_qk_scale != 1:
            # Scale out of place: `k` may be the tensor stored in `kv_cache`, and an
            # in-place multiply would rescale the cached keys again on every call.
            q = q * self.sqrt_qk_scale
            k = k * self.sqrt_qk_scale

        wv, qk = self.qkv_attention_pth20(q, k, v, causal)
#         wv, qk = self.qkv_attention_xformers(q, k, v, causal)

        return self.out(wv), qk

    def qkv_attention_pth20(
        self, q: Tensor, k: Tensor, v: Tensor, causal = False
    ):
        """Attention via `F.scaled_dot_product_attention`; returns `(merged_heads, None)`."""
        # the unpacking doubles as a check that `q` has exactly three axes
        n_batch, n_ctx, n_state = q.shape
        # split the last axis into heads: (batch, seq, n_head, head_width)
        q = q.view(*q.shape[:2], self.n_head, -1)
        k = k.view(*k.shape[:2], self.n_head, -1)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        # rotary embeddings are applied while the sequence axis is still axis 1
        if self.rotary:
            q, k = apply_rotary_pos_emb(q, k, *self.rotary(k))

        k = k.permute(0, 2, 1, 3)
        q = q.permute(0, 2, 1, 3)
        # modified for better performance under PyTorch 2.0
        wv = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=causal)

        # previously we've returned q@k which we don't have now
        # since it's not actually used anywhere else, let's just keep two return values for compatibility
        return wv.permute(0, 2, 1, 3).flatten(start_dim=2), None

    def qkv_attention_xformers(
        self, q: Tensor, k: Tensor, v: Tensor, causal = False
    ):
        """Alternative attention implementation built on `xops.memory_efficient_attention`.

        NOTE(review): `xops` is not imported in the visible part of this file and the
        only call to this method is commented out — confirm before enabling it.
        """
        n_batch, n_ctx, n_state = q.shape
        q = q.view(*q.shape[:2], self.n_head, -1)
        k = k.view(*k.shape[:2], self.n_head, -1)
        v = v.view(*v.shape[:2], self.n_head, -1)

        if self.rotary:
            q, k = apply_rotary_pos_emb(q, k, *self.rotary(k))

        bias = xops.LowerTriangularMask() if causal else None
        wv = xops.memory_efficient_attention(q,k,v, attn_bias=bias)

        # previously we've returned q@k which we don't have now
        # since it's not actually used anywhere else, let's just keep two return values for compatibility
        return wv.flatten(start_dim=2), None

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 36
class DelSumDecoder(nn.Module):
    """Causal decoder over several token streams ("quantizers") that are delayed and summed.

    Each quantizer has its own embedding table with `codes + 1` entries; the
    extra index `codes` is used to fill the leading positions of a stream.
    Stream `i` is shifted right by `i + 1` positions before all streams are
    summed into a single input sequence.  The output is one set of logits over
    `codes + 1` classes per quantizer and position.

    NOTE(review): `pos_embs` is registered as the `positional_embedding` buffer
    but is not read by `forward` in this class.
    """
    def __init__(self, depth=6, n_head=6, head_width=64, qk_scale=1, ffn_mult=4, length=2250, codes=1024, quantizers=8, linear_heads=True, rope=False, pos_embs=None):
        super().__init__()
        self.length = length
        self.width = n_head * head_width
        self.codes = codes
        self.quantizers = quantizers
        self.linear_heads = linear_heads

        width = self.width
        self.embeddings = nn.ModuleList([nn.Embedding(codes+1, width) for _ in range(quantizers)])
        if pos_embs is not None:
            self.register_buffer("positional_embedding", pos_embs)

        n_layers = math.floor(depth)
        self.layers = nn.ModuleList([
            ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, cross_attention=True, rope=rope)
            for _ in range(n_layers)
        ])

        self.ln_post = LayerNorm(width)

        if self.linear_heads:
            # a single projection produces the logits of all quantizers at once
            self.heads = LinearHead(width, (codes+1) * quantizers, bias=False)
        else:
            # expand into one `width`-sized vector per quantizer, then one head per quantizer
            self.splitter = nn.Sequential(
                nn.Linear(width, width * quantizers),
                nn.GELU(),
            )
            self.heads = nn.ModuleList([
                LinearHead(width, codes+1, bias=True) for _ in range(quantizers)
            ])

    def forward(self, toks, xenc):
        """Return logits of shape `(batch, quantizers, positions, codes + 1)`.

        `toks` is unpacked as `(batch, quantizer, time)` and `xenc` is passed to
        every layer as the cross-attention source.  The number of output
        positions is `min(time + 1, self.length)`.
        """
        batch, _, n_toks = toks.shape
        out_len = min(n_toks + 1, self.length)
        summed = torch.zeros((batch, out_len, self.width), dtype=xenc.dtype, device=xenc.device)
        for q in range(self.quantizers):
            embed = self.embeddings[q]
            # the first q+1 positions of stream q carry the embedding of the extra code
            filler = torch.tensor([self.codes], device=xenc.device)
            summed[:, :q+1] += embed(filler)
            if q < n_toks:
                # the real tokens of stream q follow, delayed by q+1 positions
                summed[:, q+1:] += embed(toks[:, q, :out_len-q-1])

        hidden = summed.to(xenc.dtype)

        for layer in self.layers:
            hidden = layer(hidden, xenc, causal=True)
        hidden = self.ln_post(hidden)

        if self.linear_heads:
            flat = self.heads(hidden)
            return flat.view(batch, out_len, self.quantizers, self.codes+1).permute(0, 2, 1, 3)

        per_quantizer = self.splitter(hidden).view(batch, out_len, self.quantizers, self.width)
        return torch.stack(
            [self.heads[q](per_quantizer[:, :, q]) for q in range(self.quantizers)],
            dim=1,
        )
    
class EmbeddingProjector(nn.Linear):
    """An `nn.Linear` with no added behavior; a distinct type so it can be told apart by `isinstance`."""

def rand(start, end):
    """Return a random float drawn uniformly between `start` and `end`.

    The value is `random.random()` rescaled to the requested interval.
    """
    span = end - start
    return random.random() * span + start
    
@dataclasses.dataclass
class Tunables:
    """Hyperparameters of the model and its optimization.

    When `random` is true, `__post_init__` replaces most of the values with
    randomly drawn ones.
    """
    init_std :float = 9
    embeddings_std :float = 0.2
    embeddings_lr_scale: float = 10
    output_mult :float = 5.6
    # FIXME: try separate mults for self and cross attention
    query_mult :float = .3
    encoder_depth_ratio :float = 0.25
    linear_heads :bool = False
    rope :bool = True

    lr0 :float = 3e-3
    clip_gradient_norm :float = 2
    weight_decay :float = 1e-3
    warmup_steps :float = 2000

    # set to True to randomize the hyperparameters on construction
    random :bool = False

    def __post_init__(self):
        """Overwrite the hyperparameters with random draws when `random` is set."""
        if not self.random:
            return

        self.init_std = 2*10**rand(0,1)
        self.embeddings_std = 10**rand(-1.7,-0.22)
        self.embeddings_lr_scale = 2**rand(2,4)
        self.output_mult = 2**rand(1.5,3)
        self.query_mult = 2**rand(-3,-1.3)
        self.encoder_depth_ratio = random.choice([0.25,0.5])
        # these stay fixed even in random mode
        self.linear_heads = False
        self.rope = True

        self.lr0 = 3e-3
        self.clip_gradient_norm = 10**rand(-1,1)
        self.warmup_steps = 100*(10**rand(1.18,1.3))

    @staticmethod
    def upgrade(args):
        """Return a copy of `args` with the old defaults filled in for missing keys.

        Keys absent from `args` get the values `rope=False` and
        `linear_heads=True`; the input mapping is left untouched.
        """
        upgraded = dict(args)
        upgraded.setdefault('rope', False)
        upgraded.setdefault('linear_heads', True)
        return upgraded
            
class SADelARTransformer(nn.Module):
    """Semantic-to-acoustic token model.

    Semantic tokens are embedded and run through a stack of
    `ResidualAttentionBlock`s (the encoder); the result, with a speaker
    embedding added, is the cross-attention source of a `DelSumDecoder` that
    predicts the acoustic tokens of every quantizer.

    `depth * 2` blocks are split between encoder and decoder according to
    `tunables.encoder_depth_ratio`.  The constructor arguments listed in the
    `store_attr` call are kept so that `save_model` can store them as the
    model configuration.
    """
    def __init__(self, depth=3, ctx_n=2250, stoks_len=750, stoks_codes=4097, stoks_width=None, spk_width=None, n_head=3, head_width=64, ffn_mult=4,
                 quantizers=8, speaker_map={"1":0}, tunables=Tunables()):
        super().__init__()
        self.quantizers = quantizers
        width = n_head * head_width
        store_attr("depth,ctx_n,stoks_len,stoks_codes,stoks_width,spk_width,n_head,head_width,ffn_mult,quantizers,speaker_map")
        self.width = width
        # reference width used to rescale learning rates and the output multiplier
        self.base_width = 3 * head_width
        self.tunables = tunables

        if stoks_width is None: stoks_width = width
        if spk_width is None: spk_width = width
        # a projection to the model width is needed only when the embedding width differs
        self.emb_factor = width != stoks_width
        self.spk_factor = width != spk_width

        if tunables.rope:
            self.positional_embeddings = None
        else:
            self.register_buffer('positional_embeddings', sinusoids(ctx_n, width))

        # The speaker table is `spk_width` wide so that it matches the input size of
        # `spk_to_hidden`; with the default `spk_width` (== width) nothing changes.
        self.speaker_embedding = nn.Embedding(len(speaker_map), spk_width)
        self.semantic_embedding = nn.Embedding(stoks_codes, stoks_width)
        if self.emb_factor:
            self.emb_to_hidden = nn.Linear(stoks_width, width)

        if self.spk_factor:
            self.spk_to_hidden = EmbeddingProjector(spk_width, width)

        qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)

        encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
        decoder_depth = depth * 2 - encoder_depth
        self.encoder = nn.Sequential(*[
            ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(encoder_depth)
        ])
        self.ln_post = LayerNorm(width)

        self.decoder = DelSumDecoder(pos_embs=self.positional_embeddings, qk_scale=qk_scale,
                                     length=ctx_n, n_head=n_head, head_width=head_width, ffn_mult=ffn_mult,
                                     depth=decoder_depth, quantizers=quantizers,
                                     linear_heads=tunables.linear_heads, rope=tunables.rope)

        # Per-quantizer counters for the validation accuracy.  They are created on the
        # default device (no hard-coded `.cuda()`), so the model can be built on machines
        # without a GPU; as registered buffers they follow the module when it is moved.
        self.register_buffer('val_true', torch.zeros(self.quantizers))
        self.register_buffer('val_total', torch.zeros(self.quantizers))
        self.apply(self.init_transformer)

    def setup(self, device):
        """No-op hook; `device` is ignored."""
        pass

    def load_frozen_semantic_embeddings(self, vqmodel):
        """Copy the first codebook of `vqmodel` into the semantic embedding and set its `lr_scale` to 0."""
        with torch.no_grad():
            self.semantic_embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
            self.semantic_embedding.lr_scale = 0

    def init_transformer(self, m):
        """Initialize module `m` and tag it with `lr_scale` / `no_weight_decay` attributes.

        Meant to be used with `self.apply`.  The order of the `isinstance`
        checks matters: the specialized layer types are tested before the
        generic `nn.Linear`.
        """
        if isinstance(m, LinearHead):
            m.no_weight_decay = True
            torch.nn.init.constant_(m.weight, 0)
        elif isinstance(m, QueryHead):
            m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
            torch.nn.init.constant_(m.weight, 0)
        elif isinstance(m, nn.Embedding):
            m.no_weight_decay = True
            m.lr_scale = self.tunables.embeddings_lr_scale
            std = self.tunables.embeddings_std
            torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
        elif isinstance(m, EmbeddingProjector):
            m.lr_scale = self.tunables.embeddings_lr_scale/2
            std = self.tunables.init_std
            torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
        elif isinstance(m, nn.Linear):
            # learning rate and init std shrink with the input width of the layer
            m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
            std = self.tunables.init_std / m.weight.shape[1]
            torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
            if m.bias is not None:
                torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
        elif isinstance(m, nn.LayerNorm):
            m.no_weight_decay = True
            torch.nn.init.constant_(m.bias, 0)
            torch.nn.init.constant_(m.weight, 1)

    def embed_stoks(self, Stoks):
        """Upsample the 2-D `(batch, n)` semantic tokens and return their embeddings.

        The embeddings are projected to the model width when `emb_factor` is set.
        """
        b,n = Stoks.shape
        if self.stoks_len == 1500:
            # converts 50 toks/s to 75 toks/s by adding padding between every two tokens
            x = Stoks.reshape(b,n//2,2)
            x = x.repeat_interleave(2, -1)[:,:,:3]
            x[:,:,1] = 1024
            x = x.reshape(b,n//2*3)
        else:
            # it's a lot easier with 25 toks/s
            x = Stoks.repeat_interleave(3, -1)
        # embed semantic tokens
        Sembs = self.semantic_embedding(x.to(torch.long))
        if self.emb_factor:
            Sembs = self.emb_to_hidden(Sembs)
        return Sembs

    def forward(self, Stoks, Atoks, speakers, noloss=False):
        """Compute the logits and, unless `noloss` is set, the training loss.

        `Atoks` is indexed as `(batch, quantizer, time)`; entries equal to -100
        are treated as padding.  Returns `logits` when `noloss` is true and
        `(logits, loss)` otherwise.  Outside of training mode the per-quantizer
        accuracy counters are updated as a side effect.
        """
        Atoks = Atoks.to(torch.long)
        semb = self.embed_stoks(Stoks)
        with record_function("encoder"):
            if self.positional_embeddings is not None: semb = semb + self.positional_embeddings
            xenc = self.ln_post(self.encoder(semb))
        with record_function("decoder"):
            # the decoder input cannot contain -100, so padding is replaced by code 1024
            Atoks_gt = Atoks.clone()
            Atoks_gt[Atoks == -100] = 1024
            # we can randomize speaker ids during validation to measure
            # the importance of the speaker embedding vs. just the acoustic prompt/prefix
#             if not self.training: speakers = speakers[torch.randperm(speakers.nelement())]
            spk_embs = self.speaker_embedding(speakers)
            if self.spk_factor: spk_embs = self.spk_to_hidden(spk_embs)
            logits = self.decoder(Atoks_gt, xenc + spk_embs.unsqueeze(1))
            logits *= self.tunables.output_mult / (self.width / self.base_width)

        if noloss:
            return logits

        with record_function("loss"):
            N = Atoks.shape[-1]
            loss = 0
            for i in range(self.quantizers):
                # the logits of quantizer i are compared from position i onwards
                loss += F.cross_entropy(logits[:,i,i:].reshape(-1,logits.shape[-1]), Atoks[:,i,:N-i].reshape(-1))
            loss /= self.quantizers

        if not self.training:
            for i in range(self.quantizers):
                Atoks_i = Atoks[:,i,:N-i]
                valid_Atoks = Atoks_i != -100
                self.val_true[i] += (logits[:,i,i:].argmax(-1)[valid_Atoks] == Atoks_i[valid_Atoks]).float().sum()
                self.val_total[i] += valid_Atoks.float().sum()

        return logits, loss

    def get_metrics(self):
        """Return the per-quantizer accuracies as `{'acc_<i>': value}` and reset the counters."""
        metrics = {
            f'acc_{i}':x.item() for i,x in enumerate(self.val_true / self.val_total)
        }
        self.val_true[:] = 0
        self.val_total[:] = 0
        return metrics

    #
    # inference
    #
    @classmethod
    def load_model(cls, repo_id="collabora/whisperspeech", filename="s2a_up_wds.model", local_filename=None):
        """Build a model from a file written by `save_model` and put it in eval mode.

        The file is taken from `local_filename` if given, otherwise it is
        downloaded with `hf_hub_download(repo_id, filename)`.
        """
        if not local_filename:
            local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
        # map to CPU (as `load_checkpoint` does) so that files saved from a GPU can also
        # be opened on CPU-only machines; the weights are copied into the new model anyway
        spec = torch.load(local_filename, map_location='cpu')
        if '_extra_state' not in spec['state_dict']: spec['state_dict']['_extra_state'] = { 'speaker_map': spec['config']['speaker_map'] }
        model = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec['tunables'])))
        model.load_state_dict(spec['state_dict'])
        model.eval()
        return model

    def get_extra_state(self):
        """Store the speaker map alongside the weights in the state dict."""
        return { 'speaker_map': self.speaker_map }

    def set_extra_state(self, st):
        """Restore the speaker map from the state dict."""
        self.speaker_map = st['speaker_map']

    def load_checkpoint(self, local_filename):
        """Load the weights from a PyTorch Lightning checkpoint and return `self`.

        Every occurrence of `'model.'` is removed from the state dict keys.
        """
        spec = torch.load(local_filename, map_location='cpu')
        assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
        state_dict = {k.replace('model.', ''):v
                      for k,v in spec['state_dict'].items()}
        self.load_state_dict(state_dict)
        return self

    def save_model(self, fname):
        """Save the constructor config, the tunables and the state dict to `fname`."""
        torch.save(dict(config = self.__stored_args__,
                        tunables = dataclasses.asdict(self.tunables),
                        state_dict = self.state_dict()), fname)

    @property
    def device(self):
        """Device of the first parameter of the model."""
        return next(self.parameters()).device

    @torch.no_grad()
    def generate(self, stoks, speakers, N=None, T=0.7, top_k=None, show_progress_bar=True):
        """Sample acoustic tokens for one sequence of semantic tokens.

        `stoks` is a 1-D token sequence and `speakers` an iterable of keys of
        `speaker_map`.  `N` is the number of steps to generate (by default it
        is derived from `len(stoks)`), `T` the sampling temperature and `top_k`
        optionally restricts sampling to the k most likely codes.  Returns the
        tokens indexed as `(quantizer, time)`; generation stops early when the
        first quantizer emits code 1024.
        """
        dev = self.device
        if self.stoks_len == 1500:
            N = N or len(stoks) * 3 // 2
        else:
            N = N or len(stoks) * 3
        # pad the semantic tokens to the full context with the last code index
        stoks = F.pad(stoks.to(dev), (0, self.stoks_len - len(stoks)), value=self.stoks_codes-1).unsqueeze(0)
        speakers = torch.tensor([self.speaker_map[spk] for spk in speakers], device=dev)
        toks = torch.zeros((1,self.quantizers,N), dtype=torch.long, device=dev)
        it = range(0,N)
        if show_progress_bar: it = progress_bar(it)
        for i in it:
            p = self(stoks, toks[:,:,:i], speakers, noloss=True)
            last_p = p[0,:,-1]
            if top_k:
                last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
            # quantizer j is delayed by j steps, so its sample belongs to position i-j
            for j,tok in enumerate(torch.multinomial((last_p / float(T)).softmax(-1), 1)):
                toks[0,j,max(0,i-j)] = tok
            if toks[0,0,i] == 1024: return toks[0,:,:i]
        return toks[0]

# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 37
def _make_model(size:str, quantizers:int=4, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None, **kwargs):
    """Create a `SADelARTransformer` of one of the named sizes.

    The speaker map is taken from `dataset.speakers`; any extra keyword
    arguments are passed on to the model constructor.

    Raises `ValueError` for an unknown `size` (the original silently returned
    `None`, which only failed later at the first use of the model).
    """
    assert dataset is not None, "a dataset is required (its `speakers` attribute defines the speaker map)"
    kwargs = dict(speaker_map=dataset.speakers, quantizers=quantizers, tunables=tunables, **kwargs)
    # constructor arguments of every named model size
    presets = {
        'micro':       dict(depth=4, n_head=3, ffn_mult=2),
        'tiny-narrow': dict(depth=4, n_head=6, ffn_mult=1),
        'tiny':        dict(depth=4, n_head=6),
        'base':        dict(depth=6, n_head=8),
        'base-deep':   dict(depth=9, n_head=8),
        'base-wide':   dict(depth=6, n_head=12),
        'small/2':     dict(depth=9, n_head=12),
        'small':       dict(depth=12, n_head=12),
        'medium':      dict(depth=24, n_head=16),
    }
    if size not in presets:
        raise ValueError(f"unknown model size {size!r}, expected one of: {', '.join(presets)}")
    return SADelARTransformer(**presets[size], **kwargs)

def make_model(size:str, quantizers:int=4, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
    """Create a model of the given size, optionally with frozen semantic embeddings.

    When `frozen_embeddings_model` is given, the VQ model is loaded from it,
    the semantic embedding is sized after its first codebook (with one extra
    code) and the codebook weights are copied into the new model.
    """
    if not frozen_embeddings_model:
        return _make_model(size, quantizers, tunables, dataset)

    vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
    n_codes = vqmodel.vq_codes+1
    code_width = vqmodel.rq.layers[0]._codebook.embed[0].shape[-1]
    model = _make_model(size, quantizers, tunables, dataset, stoks_codes=n_codes, stoks_width=code_width)
    model.load_frozen_semantic_embeddings(vqmodel)
    return model