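# NOTE(review): this file was captured from a web page; its first lines read
# "Spaces:" / "Runtime error" / "Runtime error" (apparently a Hugging Face Spaces status banner),
# which is page residue rather than part of the module.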
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4B. Semantic to acoustic token modeling.ipynb. | |
# %% auto 0 | |
# Public API of this module (the names re-exported by `from ... import *`).
__all__ = [
    'load_datasets', 'CMLMVisual', 'Rotary', 'rotate_half', 'apply_rotary_pos_emb',
    'ResidualAttentionBlock', 'MultiHeadAttention', 'DelSumDecoder', 'EmbeddingProjector',
    'rand', 'Tunables', 'SADelARTransformer',
]
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 1 | |
import io | |
import time | |
import math | |
import random | |
import dataclasses | |
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 2 | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.profiler import profile, record_function, ProfilerActivity, schedule | |
from fastcore.basics import store_attr | |
from huggingface_hub import hf_hub_download | |
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 3 | |
from pathlib import Path | |
import json | |
from fastprogress import progress_bar, master_bar | |
import webdataset as wds | |
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 4 | |
from .train import * | |
from .modules import * | |
from . import vq_stoks | |
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 8 | |
def rand(start, end):
    """Return a uniformly distributed random float between ``start`` and ``end``.

    Uses one draw from the module-level :mod:`random` generator.
    """
    return start + (end - start) * random.random()
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 9 | |
def random_trunc(random_trunc_p, atoks_len = 2250, stoks_len = 750):
    """Build a pipeline stage that randomly shortens samples.

    With probability ``random_trunc_p`` a sample's ``atoks.npy`` is cut (along its last axis) to a
    random duration between 0.3 and 30 seconds, assuming ``atoks_len`` acoustic tokens span 30 s.
    ``stoks.npy`` is then cut proportionally (``stoks_len`` semantic tokens per ``atoks_len``
    acoustic tokens). Samples are modified in place and yielded unchanged otherwise.
    """
    atoks_per_second = atoks_len / 30

    def _truncate(sample):
        # pick the target duration, trim the acoustic tokens, then trim the semantic tokens to match
        seconds = rand(0.3, 30)
        atoks = sample['atoks.npy'][:, :math.ceil(seconds * atoks_per_second)]
        sample['atoks.npy'] = atoks
        kept_stoks = math.ceil(atoks.shape[-1] / atoks_len * stoks_len)
        sample['stoks.npy'] = sample['stoks.npy'][:kept_stoks]

    def _trunc(samples):
        for sample in samples:
            if random.random() < random_trunc_p:
                _truncate(sample)
            yield sample

    return _trunc
def pad_samples(atoks_len = 2250, stoks_len = 750, stoks_pad_token = 4096):
    """Build a pipeline stage that pads token arrays to fixed lengths and converts them to tensors.

    ``stoks.npy`` is right-padded to ``stoks_len`` with ``stoks_pad_token``; ``atoks.npy`` is
    right-padded along its last axis to ``atoks_len`` with -100 (the index ``F.cross_entropy``
    ignores by default). Samples are modified in place and yielded.
    """
    def _pad_one(array, target_len, fill):
        tensor = torch.tensor(array)
        return F.pad(tensor, (0, target_len - tensor.shape[-1]), value=fill)

    def _pad(samples):
        for sample in samples:
            sample['stoks.npy'] = _pad_one(sample['stoks.npy'], stoks_len, stoks_pad_token)
            sample['atoks.npy'] = _pad_one(sample['atoks.npy'], atoks_len, -100)
            yield sample

    return _pad
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 10 | |
def speaker_id_extractor(speaker_map):
    """Build a pipeline stage that attaches an integer speaker id to every sample.

    The speaker name is the second '/'-separated component of the sample's ``__key__``; it is
    looked up in ``speaker_map`` and stored as a tensor under ``'speaker'``.
    """
    def _extractor(samples):
        for sample in samples:
            speaker_name = sample['__key__'].split("/")[1]
            sample['speaker'] = torch.tensor(speaker_map[speaker_name])
            yield sample

    return _extractor
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 14 | |
def load_datasets(
        input:str,  # webdataset folder
        samples:int,  # samples per epoch
        subsample:float=1,  # use a fraction of the files
        val_samples:int=512,
        random_trunc_p:float=0,  # probability of truncating the input to less than 30 seconds
        stoks_pad_token=4096,
    ):
    """Create the training and validation WebDataset pipelines for S2A training.

    ``input`` is a directory (every ``*-s2a-*.tar.gz`` shard inside is used), a path whose last
    component is a glob pattern, or an explicit list of shard paths. The first shard is used for
    validation and the remaining ones for training. Speaker names are read from a
    ``<shard>.speakers.txt`` file next to each shard and mapped to consecutive integer ids.

    Returns a ``(train, val)`` tuple of pipelines batched by 64 and yielding
    ``(stoks, atoks, speaker)`` tuples; ``speakers`` and ``total_samples`` attributes are set on
    them before the final slicing stage (downstream code reads ``dataset.speakers``).

    NOTE(review): ``subsample`` is accepted but currently not used.

    Raises:
        ValueError: ``input`` has an unsupported type, or no shards were found.
    """
    if isinstance(input, (Path, str)):
        path = Path(input)
        if path.is_dir():
            pattern = '*-s2a-*.tar.gz'
        else:
            pattern = path.name
            path = path.parent
        input = path.glob(pattern)
    elif not isinstance(input, list):
        # (previously raised the nonexistent `ArgumentError`, which surfaced as a NameError)
        raise ValueError("input should be either a list or a path with an optional glob specifier")
    shards = [str(x) for x in input]
    if not shards:
        raise ValueError("no dataset shards found for the given input")

    speaker_names = set()
    for shard in shards:
        with open(shard + '.speakers.txt') as f:
            speaker_names.update(line.strip() for line in f)
    speakers = {name: i for i, name in enumerate(sorted(speaker_names))}

    batch_size = 64

    def _make_split(split_shards, length):
        pipeline = wds.WebDataset(wds.ResampledShards(split_shards)).compose(
            wds.decode(),
            speaker_id_extractor(speakers),
            random_trunc(random_trunc_p) if random_trunc_p > 0 else lambda x: x,
            pad_samples(stoks_pad_token=stoks_pad_token),
            wds.to_tuple('stoks.npy', 'atoks.npy', 'speaker'),
            wds.batched(batch_size),
        )
        pipeline.speakers = speakers
        pipeline.total_samples = length
        n_batches = length // batch_size
        return pipeline.compose(wds.slice(n_batches)).with_epoch(n_batches).with_length(n_batches)

    return (
        _make_split(shards[1:], samples),
        _make_split(shards[:1], val_samples),
    )
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 33 | |
import pylab as plt | |
import fastprogress | |
import IPython | |
import numpy as np | |
class CMLMVisual:
    """Visualize training progress in a Jupyter notebook.

    Shows three stacked plots (train/val loss, per-metric accuracy, learning rate) that are
    refreshed in place, plus an HTML table with the latest metrics from ``model.get_metrics()``.
    Call ``show()`` before ``add_data()``.
    """
    def __init__(self, model, masterbar, total_steps):
        self.model = model
        self.masterbar = masterbar
        self.total_steps = total_steps
        self.epochs = total_steps // masterbar.main_bar.total
        gs = plt.GridSpec(3, 1, height_ratios=[2,2,1])
        graph_fig = plt.figure(figsize=(10,6))
        self.graph_fig = graph_fig
        self.loss_p = graph_fig.add_subplot(gs[0])
        self.acc_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
        self.acc_p.tick_params('x', labelbottom=False)
        self.lr_p = graph_fig.add_subplot(gs[2], sharex=self.loss_p)
        self.lr_p.tick_params('x', labelbottom=False)
        self.graph_out = None
        self.its = []
        self.train_losses = []
        self.val_losses = []
        self.lr_history = []
        self.acc = np.nan
        self.acc_history = []
        self.pacc_history = []

    def show(self):
        """Start the timer, write the table header and create the live-updating outputs."""
        self.start_t = time.time()
        self.masterbar.write(["samples", "train", "val", "time"], table=True)
        # `display` is not imported in this module (it only exists as an IPython builtin),
        # so reference it explicitly through the already-imported IPython package
        self.graph_out = IPython.display.display(self.graph_fig, display_id=True)
        self.acc_out = IPython.display.display(IPython.display.HTML(''), display_id=True)

    def hide(self):
        """Blank out the plot output, if it was shown."""
        if self.graph_out is not None:
            self.graph_out.update(IPython.display.HTML(''))

    def plot(self):
        """Redraw all three plots from the recorded history."""
        loss_p, acc_p, lr_p = self.loss_p, self.acc_p, self.lr_p
        loss_p.clear()
        loss_p.plot(self.its, self.train_losses)
        loss_p.plot(self.its, self.val_losses)
        loss_p.set_xlim(0, self.total_steps)
        loss_p.set_yscale('log')
        acc_p.clear()
        for k in self.acc_history[-1].keys():
            acc_p.plot(self.its, [x[k] for x in self.acc_history], ':')
        lr_p.clear()
        lrs = np.array(self.lr_history)
        lr_p.plot(self.its, lrs)
        self.graph_out.update(self.graph_fig)

    def add_data(self, it, lr, train_loss, val_los):
        """Record one evaluation point (``val_los`` is the validation loss) and refresh the outputs."""
        self.its.append(it)
        self.train_losses.append(train_loss)
        self.val_losses.append(val_los)
        self.lr_history.append(lr)
        metrics = self.model.get_metrics()
        self.acc_history.append(metrics)
        self.acc_out.update(IPython.display.HTML(self._metrics_html(metrics)))
        self.plot()

    @staticmethod
    def _metrics_html(metrics):
        """Render ``{name: fraction}`` metrics as a one-row HTML table of percentages."""
        # (the previous markup used "<td>" where "</td>" was meant, leaving every cell unclosed)
        header = ''.join(f"<th>{k}</th>" for k in metrics)
        row = ''.join(f"<td>{x*100:.1f}%</td>" for x in metrics.values())
        return (f"<h5>Accuracies:</h5><table><thead><tr>{header}</tr></thead>"
                f"<tbody><tr>{row}</tr></tbody></table>")

    def add_table_row(self, it, avg_train_loss, val_loss):
        """Append a row (samples, losses, elapsed time) to the master-bar table."""
        elapsed_t = time.time() - self.start_t
        self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True)

    def on_iter(self, bar, it, avg_train_loss, val_loss):
        """Update the progress-bar comment with the current epoch and losses."""
        epoch = math.ceil(it / self.total_steps * self.epochs)
        bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 34 | |
# modified from https://blog.eleuther.ai/rotary-embeddings/ | |
import torch | |
class Rotary(torch.nn.Module):
    """Rotary position embedding (RoPE) cos/sin tables.

    ``forward(x)`` returns ``(cos, sin)`` shaped ``(1, seq_len, 1, dim)`` (for even ``dim``), where
    ``seq_len = x.shape[seq_dim]``. The tables are cached and recomputed only when the sequence
    length, the device of ``x`` or the dtype of the frequency buffer changes.
    """
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        # Keying the cache on the length alone returned stale tensors after the module was moved
        # to another device or dtype (e.g. .cuda()/.half()); check those as well.
        stale = (seq_len != self.seq_len_cached
                 or self.cos_cached.device != x.device
                 or self.cos_cached.dtype != self.inv_freq.dtype)
        if stale:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, :, None, :]
            self.sin_cached = emb.sin()[None, :, None, :]
        return self.cos_cached, self.sin_cached
# rotary pos emb helpers: | |
def rotate_half(x):
    """Swap the two halves of the last dimension and negate the (new) first half.

    For a last dimension ``[a, b]`` (split at ``size // 2``) the result is ``[-b, a]``.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((second.neg(), first), dim=-1)
#@torch.jit.script | |
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to queries and keys.

    ``cos``/``sin`` are sliced along dim 1 to the query length, so ``q`` may be shorter than
    ``k`` (e.g. in cross-attention). Returns the rotated ``(q, k)``.
    """
    q_len = q.shape[1]
    q_cos, q_sin = cos[:, :q_len], sin[:, :q_len]
    q_rotated = q * q_cos + rotate_half(q) * q_sin
    k_rotated = k * cos + rotate_half(k) * sin
    return q_rotated, k_rotated
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 35 | |
from torch import Tensor, nn | |
import torch.nn.functional as F | |
from typing import Dict, Iterable, Optional | |
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention, then an MLP.

    Every sub-layer reads a LayerNorm-ed copy of the stream and adds its output back residually.
    """
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, rope: bool = False,
                 qk_scale: float = 1, ffn_mult: int = 4):
        super().__init__()
        # submodules are registered in a fixed order; state_dict keys and init order depend on it
        self.attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
        self.attn_ln = LayerNorm(n_state)
        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
            self.cross_attn_ln = LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None
        hidden = n_state * ffn_mult
        self.mlp = nn.Sequential(nn.Linear(n_state, hidden), nn.GELU(), nn.Linear(hidden, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        causal = False,
        kv_cache: Optional[dict] = None,
    ):
        """Run the block on ``x``; ``xa`` is the cross-attention memory (ignored without cross-attention)."""
        self_attended, _ = self.attn(self.attn_ln(x), causal=causal, kv_cache=kv_cache)
        x = x + self_attended
        if self.cross_attn is not None:
            cross_attended, _ = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
            x = x + cross_attended
        return x + self.mlp(self.mlp_ln(x))
class MultiHeadAttention(nn.Module):
    """Multi-head (self- or cross-) attention with optional rotary position embeddings.

    ``qk_scale`` multiplies the attention logits: its square root is applied to both queries and
    keys before the standard scaled dot-product attention.
    """
    def __init__(self, n_state: int, n_head: int, qk_scale: float = 1, rope: bool = False):
        super().__init__()
        self.n_head = n_head
        self.sqrt_qk_scale = math.sqrt(qk_scale)
        self.query = QueryHead(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)
        self.rotary = None
        if rope:
            self.rotary = Rotary(n_state // n_head)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        causal = False,
        kv_cache: Optional[dict] = None,
    ):
        """Attend from ``x`` to itself, or to ``xa`` when given; returns ``(output, None)``."""
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # project keys/values from x (self-attention) or xa (cross-attention)
            source = x if xa is None else xa
            k = self.key(source)
            v = self.value(source)
        else:
            # cross-attention keys/values previously stored in kv_cache (keyed by the projection
            # modules) are reused; presumably filled by caller-installed hooks — not done in this file
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        if self.sqrt_qk_scale != 1:
            # out-of-place on purpose: `k` may be a tensor owned by kv_cache, and an in-place
            # multiply would rescale the cached keys again on every call
            q = q * self.sqrt_qk_scale
            k = k * self.sqrt_qk_scale

        wv, qk = self.qkv_attention_pth20(q, k, v, causal)
        return self.out(wv), qk

    def qkv_attention_pth20(
        self, q: Tensor, k: Tensor, v: Tensor, causal = False
    ):
        """Attention via PyTorch 2.0 ``F.scaled_dot_product_attention``.

        ``q``, ``k``, ``v`` are ``(batch, seq, n_state)``; they are split into ``n_head`` heads.
        Returns ``(wv, None)`` with ``wv`` of shape ``(batch, q_seq, n_state)``; the second value
        (formerly the q@k logits, not computed by this kernel) is kept for API compatibility.
        """
        q = q.view(*q.shape[:2], self.n_head, -1)
        k = k.view(*k.shape[:2], self.n_head, -1)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        if self.rotary is not None:
            # rotary tables are laid out (1, seq, 1, head_dim), matching the (b, seq, head, dim) layout here
            q, k = apply_rotary_pos_emb(q, k, *self.rotary(k))
        k = k.permute(0, 2, 1, 3)
        q = q.permute(0, 2, 1, 3)
        wv = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=causal)
        return wv.permute(0, 2, 1, 3).flatten(start_dim=2), None

    def qkv_attention_xformers(
        self, q: Tensor, k: Tensor, v: Tensor, causal = False
    ):
        """Alternative backend using xFormers' memory-efficient attention (not used by ``forward``).

        NOTE(review): ``xops`` (presumably ``xformers.ops``) is not imported in this file, so calling
        this method raises NameError unless that name is provided elsewhere.
        """
        q = q.view(*q.shape[:2], self.n_head, -1)
        k = k.view(*k.shape[:2], self.n_head, -1)
        v = v.view(*v.shape[:2], self.n_head, -1)
        if self.rotary is not None:
            q, k = apply_rotary_pos_emb(q, k, *self.rotary(k))
        bias = xops.LowerTriangularMask() if causal else None
        wv = xops.memory_efficient_attention(q, k, v, attn_bias=bias)
        # second value kept only for compatibility with callers that unpack two results
        return wv.flatten(start_dim=2), None
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 36 | |
class DelSumDecoder(nn.Module):
    """Causal transformer decoder over several parallel streams of acoustic tokens.

    The streams of all ``quantizers`` are embedded (one table per quantizer) and summed into one
    sequence, with stream ``i`` shifted right by ``i + 1`` positions; the leading positions of
    each stream receive the embedding of the extra code ``codes`` instead. The sequence is
    processed by causal self-attention blocks that also cross-attend to ``xenc``, and per-quantizer
    logits over ``codes + 1`` classes are produced for every position.
    """
    def __init__(self, depth=6, n_head=6, head_width=64, qk_scale=1, ffn_mult=4, length=2250, codes=1024, quantizers=8, linear_heads=True, rope=False, pos_embs=None):
        super().__init__()
        self.length = length
        width = n_head * head_width
        self.width = width
        self.codes = codes
        self.quantizers = quantizers
        self.linear_heads = linear_heads
        # one table per quantizer; index `codes` is the extra start code used in `forward`
        self.embeddings = nn.ModuleList([nn.Embedding(codes+1, width) for _ in range(quantizers)])
        if pos_embs is not None:
            # NOTE(review): registered (so it is saved in the state_dict) but not read in `forward`
            self.register_buffer("positional_embedding", pos_embs)
        self.layers = nn.ModuleList([
            ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, cross_attention=True, rope=rope) for _ in range(math.floor(depth))
        ])
        self.ln_post = LayerNorm(width)
        if self.linear_heads:
            # a single projection producing logits for all quantizers at once
            self.heads = LinearHead(width, (codes+1) * quantizers, bias=False)
        else:
            # expand to one `width`-sized vector per quantizer, then a separate head for each
            self.splitter = nn.Sequential(
                nn.Linear(width, width * quantizers),
                nn.GELU(),
            )
            self.heads = nn.ModuleList([
                LinearHead(width, codes+1, bias=True) for _ in range(quantizers)
            ])

    def forward(self, toks, xenc):
        """Return logits of shape ``(batch, quantizers, newn, codes + 1)``.

        ``toks`` is indexed as ``(batch, quantizer, time)`` with ``n`` time steps; ``xenc`` is the
        cross-attention memory. ``newn = min(n + 1, length)``: one more position than the input,
        so the last position predicts the next step.
        """
        b,_,n = toks.shape
        newn = min(n+1, self.length)
        embs = torch.zeros((b,newn,self.width), dtype=xenc.dtype, device=xenc.device)
        for i in range(self.quantizers):
            # stream i is delayed by i+1 positions: its first i+1 slots hold the start-code embedding...
            embs[:,:i+1] += self.embeddings[i](torch.tensor([self.codes], device=xenc.device))
            if i < n:
                # ...and position p (p > i) holds the embedding of token p-i-1 of stream i.
                # Skipped when i >= n, where the slice lengths would no longer line up.
                embs[:,i+1:] += self.embeddings[i](toks[:,i,:newn-i-1])

        # embs was already allocated with xenc.dtype, so this cast is a no-op kept for safety
        x = embs.to(xenc.dtype)

        for l in self.layers:
            x = l(x, xenc, causal=True)
        x = self.ln_post(x)

        if self.linear_heads:
            logits = self.heads(x).view(b,newn,self.quantizers,self.codes+1).permute(0,2,1,3)
        else:
            split = self.splitter(x).view(b,newn,self.quantizers,self.width)
            logits = torch.stack([self.heads[q](split[:,:,q]) for q in range(self.quantizers)], dim=1)
        return logits
class EmbeddingProjector(nn.Linear):
    """Marker subclass of ``nn.Linear`` used to project speaker embeddings to the model width.

    It behaves exactly like ``nn.Linear``; the distinct type lets
    ``SADelARTransformer.init_transformer`` give it its own initialisation and learning-rate scale.
    """
def rand(start, end):
    """Draw a uniform random float between ``start`` and ``end`` from the global ``random`` generator.

    This later definition replaces the identical helper defined earlier in the module.
    """
    span = end - start
    return random.random() * span + start
@dataclasses.dataclass
class Tunables:
    """Training and initialisation hyper-parameters for ``SADelARTransformer``.

    Must be a dataclass: it is constructed with keyword arguments (``Tunables(**kw)``), serialized
    with ``dataclasses.asdict`` and relies on ``__post_init__``. With ``random=True`` most values
    are re-sampled from the ranges below (for hyper-parameter search).
    """
    init_std :float = 9
    embeddings_std :float = 0.2
    embeddings_lr_scale: float = 10
    output_mult :float = 5.6
    # FIXME: try separate mults for self and cross attention
    query_mult :float = .3
    encoder_depth_ratio :float = 0.25
    linear_heads :bool = False
    rope :bool = True

    lr0 :float = 3e-3
    clip_gradient_norm :float = 2
    weight_decay :float = 1e-3
    warmup_steps :float = 2000

    random :bool = False

    def __post_init__(self):
        # randomize the hyperparams if requested
        # (`random` below is the stdlib module; the `random` field is only reachable via self)
        if self.random:
            self.init_std = 2*10**random.uniform(0,1)
            self.embeddings_std = 10**random.uniform(-1.7,-0.22)
            self.embeddings_lr_scale = 2**random.uniform(2,4)
            self.output_mult = 2**random.uniform(1.5,3)
            self.query_mult = 2**random.uniform(-3,-1.3)
            self.encoder_depth_ratio = random.choice([0.25,0.5])
            self.linear_heads = False
            self.rope = True
            self.lr0 = 3e-3
            self.clip_gradient_norm = 10**random.uniform(-1,1)
            self.warmup_steps = 100*(10**random.uniform(1.18,1.3))

    @staticmethod
    def upgrade(args):
        """Return a copy of a saved tunables dict with old defaults filled in.

        Keys missing from older checkpoints get ``rope=False`` and ``linear_heads=True``;
        the input dict is not modified.
        """
        args = dict(args)
        args.setdefault('rope', False)
        args.setdefault('linear_heads', True)
        return args
class SADelARTransformer(nn.Module):
    """Semantic-to-acoustic token model: a transformer encoder plus a ``DelSumDecoder``.

    Semantic tokens are upsampled to the acoustic frame rate, embedded and encoded; the decoder
    then predicts ``quantizers`` delayed streams of acoustic tokens, conditioned on the encoder
    output plus a per-speaker embedding.
    """
    def __init__(self, depth=3, ctx_n=2250, stoks_len=750, stoks_codes=4097, stoks_width=None, spk_width=None, n_head=3, head_width=64, ffn_mult=4,
                 quantizers=8, speaker_map={"1":0}, tunables=None):
        super().__init__()
        # a fresh Tunables per model instead of one instance shared through the default argument
        if tunables is None:
            tunables = Tunables()
        self.quantizers = quantizers
        width = n_head * head_width
        store_attr("depth,ctx_n,stoks_len,stoks_codes,stoks_width,spk_width,n_head,head_width,ffn_mult,quantizers,speaker_map")
        self.width = width
        self.base_width = 3 * head_width
        self.tunables = tunables
        if stoks_width is None: stoks_width = width
        if spk_width is None: spk_width = width
        self.emb_factor = width != stoks_width
        self.spk_factor = width != spk_width

        # NOTE: submodules/buffers are registered in a fixed order; state_dict keys and the
        # initialisation order in `init_transformer` depend on it
        if tunables.rope:
            self.positional_embeddings = None
        else:
            self.register_buffer('positional_embeddings', sinusoids(ctx_n, width))

        self.speaker_embedding = nn.Embedding(len(speaker_map), width)
        self.semantic_embedding = nn.Embedding(stoks_codes, stoks_width)
        if self.emb_factor:
            self.emb_to_hidden = nn.Linear(stoks_width, width)

        if self.spk_factor:
            self.spk_to_hidden = EmbeddingProjector(spk_width, width)

        qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)

        encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
        decoder_depth = depth * 2 - encoder_depth
        self.encoder = nn.Sequential(*[
            ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(encoder_depth)
        ])
        self.ln_post = LayerNorm(width)

        self.decoder = DelSumDecoder(pos_embs=self.positional_embeddings, qk_scale=qk_scale,
                                     length=ctx_n, n_head=n_head, head_width=head_width, ffn_mult=ffn_mult,
                                     depth=decoder_depth, quantizers=quantizers,
                                     linear_heads=tunables.linear_heads, rope=tunables.rope)

        # validation accuracy accumulators; they follow the module's device (a hard-coded
        # .cuda() here made the model unusable on machines without a GPU)
        self.register_buffer('val_true', torch.zeros(self.quantizers))
        self.register_buffer('val_total', torch.zeros(self.quantizers))
        self.apply(self.init_transformer)

    def setup(self, device):
        pass

    def load_frozen_semantic_embeddings(self, vqmodel):
        """Copy the first RVQ codebook of ``vqmodel`` into the semantic embedding and freeze it (lr_scale 0)."""
        with torch.no_grad():
            self.semantic_embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
            self.semantic_embedding.lr_scale = 0

    def init_transformer(self, m):
        """Per-module initialisation; also tags modules with ``lr_scale``/``no_weight_decay`` hints."""
        # order matters: the specific Linear subclasses must be matched before nn.Linear
        if isinstance(m, LinearHead):
            m.no_weight_decay = True
            torch.nn.init.constant_(m.weight, 0)
        elif isinstance(m, QueryHead):
            m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
            torch.nn.init.constant_(m.weight, 0)
        elif isinstance(m, nn.Embedding):
            m.no_weight_decay = True
            m.lr_scale = self.tunables.embeddings_lr_scale
            std = self.tunables.embeddings_std
            torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
        elif isinstance(m, EmbeddingProjector):
            m.lr_scale = self.tunables.embeddings_lr_scale/2
            std = self.tunables.init_std
            torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
        elif isinstance(m, nn.Linear):
            m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
            std = self.tunables.init_std / m.weight.shape[1]
            torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
            if m.bias is not None:
                torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
        elif isinstance(m, nn.LayerNorm):
            m.no_weight_decay = True
            torch.nn.init.constant_(m.bias, 0)
            torch.nn.init.constant_(m.weight, 1)

    def embed_stoks(self, Stoks):
        """Upsample ``(batch, n)`` semantic tokens to the acoustic frame rate and embed them."""
        b,n = Stoks.shape
        if self.stoks_len == 1500:
            # converts 50 toks/s to 75 toks/s by adding padding between every two tokens:
            # each pair (a, b) becomes (a, 1024, b)
            x = Stoks.reshape(b,n//2,2)
            x = x.repeat_interleave(2, -1)[:,:,:3]
            x[:,:,1] = 1024
            x = x.reshape(b,n//2*3)
        else:
            # it's a lot easier with 25 toks/s: repeat every token three times
            x = Stoks.repeat_interleave(3, -1)
        # embed semantic tokens
        Sembs = self.semantic_embedding(x.to(torch.long))
        if self.emb_factor:
            Sembs = self.emb_to_hidden(Sembs)
        return Sembs

    def forward(self, Stoks, Atoks, speakers, noloss=False):
        """Compute decoder logits and (unless ``noloss``) the training loss.

        Stoks: ``(batch, n)`` semantic tokens; Atoks: ``(batch, quantizers, N)`` acoustic tokens
        with -100 marking padding; speakers: ``(batch,)`` speaker ids.
        Returns ``logits`` if ``noloss`` else ``(logits, loss)``, where the loss is the mean over
        quantizers of the cross-entropy against the delay-aligned targets. In eval mode the
        per-quantizer accuracy counters are also updated.
        """
        Atoks = Atoks.to(torch.long)
        semb = self.embed_stoks(Stoks)
        with record_function("encoder"):
            if self.positional_embeddings is not None: semb = semb + self.positional_embeddings
            xenc = self.ln_post(self.encoder(semb))
        with record_function("decoder"):
            # padding (-100) is fed to the decoder as code 1024
            Atoks_gt = Atoks.clone()
            Atoks_gt[Atoks == -100] = 1024
            spk_embs = self.speaker_embedding(speakers)
            if self.spk_factor: spk_embs = self.spk_to_hidden(spk_embs)
            logits = self.decoder(Atoks_gt, xenc + spk_embs.unsqueeze(1))
            logits *= self.tunables.output_mult / (self.width / self.base_width)

        if noloss:
            return logits

        with record_function("loss"):
            N = Atoks.shape[-1]
            loss = 0
            for i in range(self.quantizers):
                # stream i is delayed by i positions relative to its targets
                loss += F.cross_entropy(logits[:,i,i:].reshape(-1,logits.shape[-1]), Atoks[:,i,:N-i].reshape(-1))
            loss /= self.quantizers

        if not self.training:
            for i in range(self.quantizers):
                Atoks_i = Atoks[:,i,:N-i]
                valid_Atoks = Atoks_i != -100
                self.val_true[i] += (logits[:,i,i:].argmax(-1)[valid_Atoks] == Atoks_i[valid_Atoks]).float().sum()
                self.val_total[i] += valid_Atoks.float().sum()

        return logits, loss

    def get_metrics(self):
        """Return per-quantizer validation accuracy accumulated since the last call, then reset it."""
        metrics = {
            f'acc_{i}':x.item() for i,x in enumerate(self.val_true / self.val_total)
        }
        self.val_true[:] = 0
        self.val_total[:] = 0
        return metrics

    #
    # inference
    #
    @classmethod
    def load_model(cls, repo_id="collabora/whisperspeech", filename="s2a_up_wds.model", local_filename=None):
        """Load a model saved by ``save_model`` (downloaded from the HF Hub unless ``local_filename`` is set).

        The model is returned on the CPU in eval mode. Only load trusted files: ``torch.load``
        unpickles the checkpoint.
        """
        if not local_filename:
            local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
        # map to CPU so checkpoints containing CUDA tensors also load on CPU-only machines
        spec = torch.load(local_filename, map_location='cpu')
        if '_extra_state' not in spec['state_dict']: spec['state_dict']['_extra_state'] = { 'speaker_map': spec['config']['speaker_map'] }
        model = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec['tunables'])))
        model.load_state_dict(spec['state_dict'])
        model.eval()
        return model

    def get_extra_state(self):
        return { 'speaker_map': self.speaker_map }

    def set_extra_state(self, st):
        self.speaker_map = st['speaker_map']

    def load_checkpoint(self, local_filename):
        """Load weights from a PyTorch Lightning checkpoint (strips the 'model.' key prefix)."""
        spec = torch.load(local_filename, map_location='cpu')
        assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
        state_dict = {k.replace('model.', ''):v
                      for k,v in spec['state_dict'].items()}
        self.load_state_dict(state_dict)
        return self

    def save_model(self, fname):
        """Save constructor config, tunables and weights in the format read by ``load_model``."""
        torch.save(dict(config = self.__stored_args__,
                        tunables = dataclasses.asdict(self.tunables),
                        state_dict = self.state_dict()), fname)

    @property
    def device(self):
        """Device of the model parameters."""
        return next(self.parameters()).device

    def generate(self, stoks, speakers, N=None, T=0.7, top_k=None, show_progress_bar=True):
        """Autoregressively sample acoustic tokens for one utterance.

        stoks: 1-D semantic tokens (padded here to ``stoks_len`` with ``stoks_codes - 1``).
        speakers: speaker names looked up in ``speaker_map``.
        N: number of acoustic steps (defaults to the length implied by ``stoks``).
        T: sampling temperature; top_k: if set, only the ``top_k`` most likely codes are sampled.
        Returns a ``(quantizers, n)`` tensor; stops early when quantizer 0 emits code 1024.
        """
        dev = self.device
        if self.stoks_len == 1500:
            N = N or len(stoks) * 3 // 2
        else:
            N = N or len(stoks) * 3
        stoks = F.pad(stoks.to(dev), (0, self.stoks_len - len(stoks)), value=self.stoks_codes-1).unsqueeze(0)
        speakers = torch.tensor([self.speaker_map[spk] for spk in speakers], device=dev)
        toks = torch.zeros((1,self.quantizers,N), dtype=torch.long, device=dev)
        it = range(0,N)
        if show_progress_bar: it = progress_bar(it)
        # sampling never needs gradients; avoid building an autograd graph at every step
        with torch.no_grad():
            for i in it:
                p = self(stoks, toks[:,:,:i], speakers, noloss=True)
                last_p = p[0,:,-1]
                if top_k:
                    last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
                for j,tok in enumerate(torch.multinomial((last_p / float(T)).softmax(-1), 1)):
                    # quantizer j is delayed by j steps, so its sample belongs to position i-j
                    toks[0,j,max(0,i-j)] = tok
                if toks[0,0,i] == 1024: return toks[0,:,:i]
        return toks[0]
# %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 37 | |
def _make_model(size:str, quantizers:int=4, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None, **kwargs):
    """Build a ``SADelARTransformer`` from a named size preset.

    ``dataset`` is required: its ``speakers`` map becomes the model's ``speaker_map``.
    Extra ``kwargs`` are forwarded to the model constructor.

    Raises:
        ValueError: ``dataset`` is missing or ``size`` is not a known preset (previously an
            unknown size silently returned None).
    """
    if dataset is None:
        raise ValueError("a dataset is required (its `speakers` map becomes the model's speaker_map)")
    # only the overrides of SADelARTransformer's defaults are listed (ffn_mult defaults to 4)
    presets = {
        'micro':       dict(depth=4, n_head=3, ffn_mult=2),
        'tiny-narrow': dict(depth=4, n_head=6, ffn_mult=1),
        'tiny':        dict(depth=4, n_head=6),
        'base':        dict(depth=6, n_head=8),
        'base-deep':   dict(depth=9, n_head=8),
        'base-wide':   dict(depth=6, n_head=12),
        'small/2':     dict(depth=9, n_head=12),
        'small':       dict(depth=12, n_head=12),
        'medium':      dict(depth=24, n_head=16),
    }
    if size not in presets:
        raise ValueError(f"unknown model size {size!r}; expected one of: {', '.join(presets)}")
    kwargs = dict(speaker_map=dataset.speakers, quantizers=quantizers, tunables=tunables, **kwargs)
    return SADelARTransformer(**presets[size], **kwargs)
def make_model(size:str, quantizers:int=4, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
    """Create an S2A model of the given ``size``.

    When ``frozen_embeddings_model`` names a VQ semantic-token model, it is loaded and its first
    codebook initialises (and freezes) the semantic embedding; the vocabulary size and embedding
    width are taken from it.
    """
    if not frozen_embeddings_model:
        return _make_model(size, quantizers, tunables, dataset)
    vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
    codebook = vqmodel.rq.layers[0]._codebook.embed[0]
    model = _make_model(size, quantizers, tunables, dataset,
                        stoks_codes=vqmodel.vq_codes + 1, stoks_width=codebook.shape[-1])
    model.load_frozen_semantic_embeddings(vqmodel)
    return model