import os
import math
import numpy as np
import matplotlib.pyplot as plt
import pyworld as pw
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torchaudio
from torch import Tensor
from torch.nn.functional import scaled_dot_product_attention
from datasets import load_dataset, Audio
from dataclasses import dataclass
from typing import Any, List, Dict, Optional, Union, Tuple
from transformers import GenerationConfig  # required by get_generation_config below

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))
    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift
class GELU(nn.Module):
    """Tanh approximation of GELU."""
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))
def sinusoids(ctx, dims, max_tscale=10000):
    assert dims % 2 == 0
    pos = torch.log(torch.tensor(float(max_tscale))) / (dims // 2 - 1)
    tscales = torch.exp(-pos * torch.arange(dims // 2, device=device, dtype=torch.float32))
    scaled = torch.arange(ctx, device=device, dtype=torch.float32).unsqueeze(1) * tscales.unsqueeze(0)
    position = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1)
    positional_embedding = nn.Parameter(position, requires_grad=True)
    return positional_embedding
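# Sketch of what sinusoids() produces: a (ctx, dims) table with the sin half
# concatenated before the cos half along the feature axis. Uses the module-level
# `device` set above.
def _example_sinusoids():
    pe = sinusoids(ctx=100, dims=64)
    assert pe.shape == (100, 64)
    return pe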
def get_activation(act: str) -> nn.Module:
    """Get activation function by name."""
    act_map = {
        "gelu": nn.GELU(),
        "relu": nn.ReLU(),
        "sigmoid": nn.Sigmoid(),
        "tanh": nn.Tanh(),
        "swish": nn.SiLU(),
        "tanhshrink": nn.Tanhshrink(),
        "softplus": nn.Softplus(),
        "softshrink": nn.Softshrink(),
        "leaky_relu": nn.LeakyReLU(),
        "elu": nn.ELU()
    }
    return act_map.get(act, nn.GELU())
def cos_sim(q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
    q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
    k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
    qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
    qk_cosine = qk_cosine + mask
    weights = F.softmax(qk_cosine, dim=-1)
    out = torch.matmul(weights, v)
    return out
def taylor_softmax_2nd_order(x):
    # second-order Taylor approximation of exp(x); note the numerator can go
    # negative for x < -1, so this is only a softmax-like surrogate
    exp_approx = 1 + x + (x**2) / 2
    return exp_approx / torch.sum(exp_approx, dim=-1, keepdim=True)

def taylor_softmax_approximation(x, order=2):
    if order == 0:
        return torch.ones_like(x) / x.size(-1)
    elif order == 1:
        numerator = 1 + x
    elif order == 2:
        numerator = 1 + x + 0.5 * x**2
    else:
        raise NotImplementedError("Higher orders are not implemented yet.")
    denominator = torch.sum(numerator, dim=-1, keepdim=True)
    return numerator / denominator
def rbf_scores(q, k, rbf_sigma=1.0, rbf_ratio=0.0):
    # blend dot-product scores with a Gaussian (RBF) kernel over q/k distances
    dot_scores = torch.matmul(q, k.transpose(-1, -2))
    if rbf_ratio <= 0.0:
        return dot_scores
    q_norm = q.pow(2).sum(dim=-1, keepdim=True)
    k_norm = k.pow(2).sum(dim=-1, keepdim=True)
    dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * dot_scores
    rbf = torch.exp(-dist_sq / (2 * rbf_sigma**2))
    return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf
def sliding_window_mask(q_len, k_len, window, device):
    # mask[i, j] = 1 if j in [i-window+1, i], else 0
    idxs = torch.arange(q_len, device=device).unsqueeze(1)
    jdxs = torch.arange(k_len, device=device).unsqueeze(0)
    mask = (jdxs >= (idxs - window + 1)) & (jdxs <= idxs)
    return mask.float()  # shape: (q_len, k_len)

def mask_win(text_ctx, aud_ctx):
    mask = torch.tril(torch.ones(text_ctx, text_ctx, device=device, dtype=dtype), diagonal=0)
    audio_mask = torch.tril(torch.ones(text_ctx, aud_ctx - text_ctx, device=device, dtype=dtype))
    full_mask = torch.cat([mask, audio_mask], dim=-1)
    return full_mask

def maskc(ctx, device):
    return torch.tril(torch.ones(ctx, ctx, device=device, dtype=dtype), diagonal=0)
def create_attention_mask(batch_size, ctx, is_causal=True, padding_mask=None, device=None):
    # convention: 1.0 marks positions that should be masked out
    if is_causal:
        # diagonal=1 so each position can still attend to itself
        mask = torch.triu(torch.ones((ctx, ctx), device=device), diagonal=1)
        mask = mask.expand(batch_size, 1, ctx, ctx)
    else:
        mask = torch.zeros((batch_size, 1, ctx, ctx), device=device)
    if padding_mask is not None:
        padding_mask = padding_mask.unsqueeze(1).unsqueeze(2).bool()
        mask = (mask.bool() | (~padding_mask)).float()
    return mask
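# Sketch of the mask convention above: a causal mask for batch 2, ctx 4, with
# the last position of the first example padded out; 1.0 marks masked slots.
def _example_attention_mask():
    pad = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
    mask = create_attention_mask(batch_size=2, ctx=4, is_causal=True, padding_mask=pad, device=device)
    assert mask.shape == (2, 1, 4, 4)
    return mask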
def calculate_attention(q, k, v, mask=None, temp=1.0):
    # SDPA only consumes the *presence* of a mask here: a mask plus a
    # multi-step query (dim 2 of a (batch, head, ctx, head_dim) tensor)
    # toggles causal attention
    scaled_q = q
    if temp != 1.0 and temp > 0:
        scaled_q = q * (1.0 / temp)**.5
    out = scaled_dot_product_attention(scaled_q, k, v, is_causal=mask is not None and q.shape[2] > 1)
    return out
def calculate_attentionb(q_norm, k_norm, v_iter, mask=None, temp=1.0):
    d_k = q_norm.size(-1)
    # divide the scores by the temperature so softmax sharpens as temp -> 0
    scores = torch.matmul(q_norm, k_norm.transpose(-2, -1)) / (math.sqrt(d_k) * temp)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, v_iter)
    return output
class LocalOut(nn.Module):
    def __init__(self, dims: int, head: int):
        super().__init__()
        self.head_dim = dims // head
        self.dims = dims
        self.q_module = nn.Linear(self.head_dim, self.head_dim)
        self.k_module = nn.Linear(self.head_dim, self.head_dim)
        self.v_module = nn.Linear(self.head_dim, self.head_dim)
        self.o_proj = nn.Linear(self.head_dim, self.head_dim)
    def _reshape_to_output(self, attn_output: Tensor) -> Tensor:
        batch, _, ctx, _ = attn_output.shape
        return attn_output.transpose(1, 2).contiguous().view(batch, ctx, self.dims)
def qkv_init(dims, head):
    head_dim = dims // head
    q = nn.Linear(dims, dims)
    k = nn.Linear(dims, dims)
    v = nn.Linear(dims, dims)
    o = nn.Linear(dims, dims)
    lna = nn.LayerNorm(dims)
    lnb = nn.LayerNorm(dims)
    lnc = nn.LayerNorm(head_dim)
    lnd = nn.LayerNorm(head_dim)
    return q, k, v, o, lna, lnb, lnc, lnd
def shape(dims, head, q, k, v):
    batch_size = q.shape[0]
    seq_len_q = q.shape[1]
    seq_len_kv = k.shape[1]
    head_dim = dims // head
    q = q.view(batch_size, seq_len_q, head, head_dim).transpose(1, 2)
    k = k.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
    v = v.view(batch_size, seq_len_kv, head, head_dim).transpose(1, 2)
    return q, k, v

def create_qkv(dims, head, q, k, v, x, xa):
    head_dim = dims // head
    scale = head_dim ** -0.25
    q = q(x) * scale
    k = k(xa) * scale
    v = v(xa)
    batch = x.shape[0]
    def _shape(tensor):
        # infer the sequence length from the tensor itself, so q (from x) and
        # k/v (from xa) may have different lengths
        return tensor.view(batch, tensor.shape[1], head, head_dim).transpose(1, 2).contiguous()
    return _shape(q), _shape(k), _shape(v)
# class attentiona(nn.Module):
#     def __init__(self, dims: int, head: int, max_iter: int = 3, threshold: float = 0.01, factor: float = 0.1, dropout: float = 0.1, temp = 1.0):
#         super(attentiona, self).__init__()
#         self.q, self.k, self.v, self.o, self.lna, self.lnb = qkv_init(dims, head)
#         self.dims = dims
#         self.head = head
#         self.head_dim = dims // head
#         self.max_iter = max_iter
#         self.threshold = nn.Parameter(torch.tensor(threshold))
#         self.temp = nn.Parameter(torch.tensor(temp), requires_grad=True)
#         self.factor = nn.Parameter(torch.tensor(factor))
#         self.lnc = nn.LayerNorm(self.head_dim, bias=False)
#         self.lnd = nn.LayerNorm(self.head_dim, bias=False)
#         self.attn_local = LocalOut(self.head_dim)
#     def _focus(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None):
#         z = default(xa, x)
#         q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, self.lna(x), self.lna(z))
#         iteration = 0
#         temp = self.temp.item()
#         prev_out = torch.zeros_like(q)
#         attn_out = torch.zeros_like(q)
#         threshold = self.threshold.item()
#         factor = self.factor.item()
#         qcur = q
#         while iteration < self.max_iter:
#             eff_span = min(qcur.shape[1], k.shape[1])
#             if xa is not None:
#                 eff_span = min(eff_span, xa.shape[1])
#             if eff_span == 0:
#                 break
#             qiter = qcur[:, :, :eff_span, :]
#             kiter = k[:, :, :eff_span, :]
#             viter = v[:, :, :eff_span, :]
#             q = self.attn_local.query_module(qiter)
#             k = self.attn_local.key_module(kiter)
#             v = self.attn_local.value_module(viter)
#             iter_mask = None
#             if mask is not None:
#                 if mask.dim() == 4:
#                     iter_mask = mask[:, :, :eff_span, :eff_span]
#                 elif mask.dim() == 2:
#                     iter_mask = mask[:eff_span, :eff_span]
#             attn_iter = calculate_attention(
#                 self.lnc(q), self.lnd(k), v,
#                 mask=iter_mask, temp=temp)
#             iter_out = torch.zeros_like(qcur)
#             iter_out[:, :, :eff_span, :] = attn_iter
#             diff = torch.abs(iter_out - prev_out).mean()
#             dthresh = threshold + factor * diff
#             if diff < dthresh and iteration > 0:
#                 attn_out = iter_out
#                 break
#             prev_out = iter_out.clone()
#             qcur = qcur + iter_out
#             attn_out = iter_out
#             iteration += 1
#             temp += 0.005
#         output = attn_out.permute(0, 2, 1, 3).flatten(start_dim=2)
#         return self.o(output), None
#     def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, mask: Optional[Tensor] = None) -> Tensor:
#         batch, ctx, dims = x.shape
#         output = torch.zeros_like(x)
#         num_win = (ctx + win_size - 1) // win_size
#         for i in range(num_win):
#             qstart = i * win_size
#             qend = min(qstart + win_size, ctx)
#             win_qlen = qend - qstart
#             if win_qlen == 0:
#                 continue
#             kstart = max(0, qend - span_len)
#             kend = qend
#             qwin = x[:, qstart:qend, :]
#             kwin = x[:, kstart:kend, :]
#             win_mask = None
#             if mask is not None:
#                 if mask.dim() == 4:
#                     win_mask = mask[:, :, qstart:qend, kstart:kend]
#                 elif mask.dim() == 2:
#                     win_mask = mask[qstart:qend, kstart:kend]
#             attn_out, _ = self._focus(x=qwin, xa=kwin, mask=win_mask)
#             output[:, qstart:qend, :] = attn_out
#         return output
#     def forward(self, x: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None,
#                 use_sliding_win: bool = False, win_size: int = 512, span_len: int = 1024) -> Tensor:
#         if use_sliding_win:
#             return self._slide_win_local(x, win_size, span_len, mask)
#         else:
#             output, _ = self._focus(x, xa, mask)
#             return output
class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val  # pyright: ignore[reportIndexIssue]
        v_out[:, :, input_pos] = v_val  # pyright: ignore[reportIndexIssue]
        return k_out, v_out
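# Sketch of a single decode step against the cache: write the new k/v at
# position 5 and get back the full (B, H, max_seq, D) buffers.
def _example_kv_cache():
    cache = KVCache(max_batch_size=1, max_seq_length=16, n_heads=4, head_dim=32, dtype=torch.float32)
    k = torch.randn(1, 4, 1, 32)
    v = torch.randn(1, 4, 1, 32)
    k_all, v_all = cache.update(input_pos=torch.tensor([5]), k_val=k, v_val=v)
    assert k_all.shape == (1, 4, 16, 32)
    return k_all, v_all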
def mel_scale_scalar(freq: float) -> float:
    return 1127.0 * math.log(1.0 + freq / 700.0)

def mel_scale(freq: Tensor) -> Tensor:
    return 1127.0 * (1.0 + freq / 700.0).log()
def trace_x(func):
    def wrapper(*args, **kwargs):
        print(f"Calling {func.__name__}")
        result = func(*args, **kwargs)
        if isinstance(result, torch.Tensor):
            print(f"  {func.__name__} returned shape: {result.shape}")
        return result
    return wrapper

def track_x(new_x, operation=""):
    """ track_x(x, "x") """
    if new_x is None:
        return new_x
    # keep the last seen id on the function itself; a per-call local would
    # always compare the tensor against itself and report REUSE
    current_id = id(new_x)
    last_id = getattr(track_x, "_last_id", None)
    if last_id is not None and current_id != last_id:
        print(f"x FLOW: {last_id} → {current_id} in {operation}")
    else:
        print(f"x REUSE: {current_id} in {operation}")
    track_x._last_id = current_id
    return new_x

def track_xa(new_xa, operation=""):
    """ track_xa(xa, "xa - decoder") """
    if new_xa is None:
        return new_xa
    current_id = id(new_xa)
    last_id = getattr(track_xa, "_last_id", None)
    if last_id is not None and current_id != last_id:
        print(f"xa FLOW: {last_id} → {current_id} in {operation}")
    else:
        print(f"xa REUSE: {current_id} in {operation}")
    track_xa._last_id = current_id
    return new_xa
def get_generation_config(param):
    # Whisper-style keys such as is_multilingual/return_timestamps ride along
    # via GenerationConfig's extra-kwargs mechanism
    return GenerationConfig(
        max_length=param.text_ctx,
        pad_token_id=getattr(param, "pad_token_id", 0),
        bos_token_id=getattr(param, "bos_token_id", 1),
        eos_token_id=getattr(param, "eos_token_id", 2),
        do_sample=False,
        num_beams=1,
        early_stopping=False,
        length_penalty=1.0,
        no_repeat_ngram_size=0,
        repetition_penalty=1.0,
        temperature=1.0,
        decoder_start_token_id=1,
        is_multilingual=False,
        use_cache=False,
        return_timestamps=False)
class feature_encoder(nn.Module):
    """Feature encoder for audio processing (spectrogram, waveform, pitch, phase, harmonics, aperiodic)."""
    def __init__(self, mels, input_dims, dims, head, layer, act, features, feature=None, use_rope=False, spec_shape=None, debug=[], attend_feature=False, target_length=None):
        super().__init__()
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.attend_feature = attend_feature
        self.target_length = target_length
        self.feature = feature
        self.debug = debug
        act_fn = get_activation(act)
        if self.attend_feature:
            # self.q, self.k, self.v, self.o, _, _, _, _ = qkv_init(dims, head)
            self.mlp = nn.Sequential(nn.Linear(dims, dims), nn.ReLU(), nn.Linear(dims, dims))
        else:
            self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
            self.mlp = None
        self.spectrogram = nn.Sequential(
            Conv1d(mels, dims, kernel_size=3), act_fn,
            Conv1d(dims, dims, kernel_size=3), act_fn,
            Conv1d(dims, dims, kernel_size=3, groups=dims), act_fn)
        self.waveform = nn.Sequential(
            Conv1d(1, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
            Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
            Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
        self.pitch = nn.Sequential(
            Conv1d(1, dims, kernel_size=7, stride=1, padding=3), act_fn,
            Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        if use_rope:
            # `rotary` is assumed to be defined elsewhere in this project
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)  # type: ignore # noqa: F821
        else:
            self.rope = None
        self.norm = RMSNorm(dims)

    def apply_rope(self, x, xa=None, mask=None, feats=None, feature=None, layer=None):
        # renamed from `rope` so the method no longer shadows the `self.rope` module
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def mel_scalar(self, freq: float) -> float:
        return 1127.0 * math.log(1.0 + freq / 700.0)

    def forward(self, x, xa=None, mask=None, feats=None, feature=None, layer=None, max_tscale=36000):
        target_length = x.shape[1] if self.target_length is None else self.target_length
        if feature == "pitch":
            # stash the raw f0 track so downstream layers can condition on it
            enc_dict = dict(feats) if feats is not None else {}
            enc_dict["f0"] = x.clone()
            feats = enc_dict
        if x.dim() == 2:
            x = x.unsqueeze(0)
        if feature in ("pitch", "phase"):
            x = self.pitch(x).permute(0, 2, 1)
        elif feature == "waveform":
            x = self.waveform(x).permute(0, 2, 1)
            if target_length and x.shape[1] != target_length:
                x = F.adaptive_avg_pool1d(x.transpose(1, 2), target_length).transpose(1, 2)
        elif feature in ("harmonics", "aperiodic", "spectrogram"):
            x = self.spectrogram(x).permute(0, 2, 1)
        if self.use_rope:
            x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
            x = self.apply_rope(x=x, xa=None, mask=None, feats=feats, feature=feature, layer=layer)
        else:
            max_tscale = x.shape[1] * 1000 if max_tscale is None else max_tscale
            x = x + self.positional(x.shape[1], x.shape[-1], max_tscale).to(device, dtype)
        # if self.attend_feature:
        #     xa = feats[feature]
        #     if xa is not None:
        #         q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, x=xa, xa=x)
        #         out = calculate_attention(q, k, v, mask=None, temp=1.0)
        #         x = x + out
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        return x
class OneShot(nn.Module):
    def __init__(self, dims: int, head: int, scale: float = 0.3, features: Optional[List[str]] = None):
        super().__init__()
        if features is None:
            features = ["spectrogram", "waveform", "pitch", "aperiodic", "harmonics"]
        self.head = head
        self.head_dim = dims // head
        # true division: 1.0 // len(features) would floor to zero for more than one feature
        self.scale = 1.0 / len(features) if features else scale
        self.q = Linear(dims, dims)
        self.k = Linear(dims, dims)
    def forward(self, x: Tensor, xa: Tensor, feature=None) -> Tensor:
        B, L, D = x.shape
        K = xa.size(1)
        q = self.q(x).view(B, L, self.head, self.head_dim).transpose(1, 2)
        k = self.k(xa).view(B, K, self.head, self.head_dim).transpose(1, 2)
        # additive attention bias of shape (B, head, L, K)
        bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.head_dim)
        return bias
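# Sketch: OneShot yields an additive attention *bias* between two streams, not
# a weighted sum; shapes below are illustrative.
def _example_one_shot_bias(dims=64, head=4):
    m = OneShot(dims=dims, head=head)
    x, xa = torch.randn(2, 10, dims), torch.randn(2, 20, dims)
    bias = m(x, xa)
    assert bias.shape == (2, head, 10, 20)
    return bias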
class curiosity(nn.Module):
    def __init__(self, d, h, bias=True):
        super().__init__()
        self.h = h
        self.dh = d // h
        self.qkv = nn.Linear(d, d * 3, bias=bias)
        self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
        self.o = nn.Linear(d, d, bias=bias)
        self.g = nn.Parameter(torch.zeros(h))
    def split(self, x):
        b, t, _ = x.shape
        return x.view(b, t, self.h, self.dh).transpose(1, 2)
    def merge(self, x):
        b, h, t, dh = x.shape
        return x.transpose(1, 2).contiguous().view(b, t, h * dh)
    def forward(self, x, xa, mask=None):
        q, k, v = self.qkv(x).chunk(3, -1)
        qa, ka, va = self.qkv_aux(xa).chunk(3, -1)
        q, k, v = map(self.split, (q, k, v))
        qa, ka, va = map(self.split, (qa, ka, va))
        dots = (q @ k.transpose(-2, -1)) / self.dh**0.5
        dots_aux = (q @ ka.transpose(-2, -1)) / self.dh**0.5
        if mask is not None:
            dots = dots.masked_fill(mask, -9e15)
        p = dots.softmax(-1)
        pa = dots_aux.softmax(-1)
        h_main = p @ v
        h_aux = pa @ va
        # per-head learned gate mixes the main and auxiliary streams
        g = torch.sigmoid(self.g).view(1, -1, 1, 1)
        out = self.merge(h_main * (1 - g) + h_aux * g)
        return self.o(out)
class PositionalEncoding(nn.Module):
    def __init__(self, dims, ctx):
        super(PositionalEncoding, self).__init__()
        self.dims = dims
        self.ctx = ctx
        self.pe = self.get_positional_encoding(max_ctx=ctx)
    def get_positional_encoding(self, max_ctx):
        pe = torch.zeros(max_ctx, self.dims)
        position = torch.arange(0, max_ctx, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dims, 2, dtype=torch.float32)
            * (-math.log(10000.0) / self.dims)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        return pe.to(device)
    def forward(self, x):
        ctx = x.size(1)
        pe = self.pe[:, :ctx, :]
        x = x * math.sqrt(self.dims)
        x = x + pe
        return x
def valid(default_value, *items):
    for item in items:
        if item is not None:
            return item
    return default_value

def dict_to(d, device, dtype=dtype):
    return {k: v.to(device, dtype) if isinstance(v, torch.Tensor) else v
            for k, v in d.items()}

def exists(v):
    return v is not None

def default(v, b):
    return v if exists(v) else b
class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))

class Conv2d(nn.Conv2d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias) -> Tensor:
        return super()._conv_forward(x, weight.to(x.device, x.dtype), None if bias is None else bias.to(x.device, x.dtype))
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)
    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)
class RMSNorm(nn.Module):
    def __init__(self, dims: Union[int, Tensor, List, Tuple],
                 eps=1e-8, elementwise_affine=True):
        super(RMSNorm, self).__init__()
        if isinstance(dims, int):
            self.normalized_shape = (dims,)
        else:
            self.normalized_shape = tuple(dims)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape))  # type: ignore
            init.ones_(self.weight)
        else:
            self.register_parameter("weight", None)
    def forward(self, x):
        return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)  # type: ignore
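# Sketch: RMSNorm rescales by the root-mean-square with no mean subtraction.
# Assumes a torch version where F.rms_norm exists (as used in forward above).
def _example_rms_norm():
    norm = RMSNorm(8)
    x = torch.randn(2, 4, 8)
    y = norm(x)
    assert y.shape == x.shape
    return y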
def layer_norm(x: Tensor, normalized_shape: Union[int, Tensor, List, Tuple],
               weight: Optional[Tensor] = None, bias: Optional[Tensor] = None,
               eps: float = 1e-5) -> Tensor:
    # renamed from `LayerNorm` so it no longer shadows the LayerNorm module above
    return F.layer_norm(x, normalized_shape, weight, bias, eps)  # type: ignore
def get_device():
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def get_dtype():
    return torch.float32 if torch.cuda.is_available() else torch.float64

def tox():
    return {"device": get_device(), "dtype": get_dtype()}
class SelfCriticalRL(nn.Module):
    """Self-critical sequence training: the greedy decode is the baseline for sampled sequences."""
    def __init__(self, model, tokenizer, reward_fn):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.reward_fn = reward_fn
    def forward(self, input_ids, features, labels=None, max_len=128, feature_name="spectrogram"):
        with torch.no_grad():
            greedy_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len)
            greedy_text = [self.tokenizer.decode(ids) for ids in greedy_ids]
            sampled_ids = self.model.generate(input_ids=input_ids, **{feature_name: features}, max_length=max_len, do_sample=True, top_k=5)
            sampled_text = [self.tokenizer.decode(ids) for ids in sampled_ids]
        rewards = []
        baseline = []
        for s, g, ref in zip(sampled_text, greedy_text, labels):  # type: ignore
            ref_text = self.tokenizer.decode(ref)
            rewards.append(self.reward_fn(s, ref_text))
            baseline.append(self.reward_fn(g, ref_text))
        rewards = torch.tensor(rewards, device=device, dtype=torch.float)
        baseline = torch.tensor(baseline, device=device, dtype=torch.float)
        advantage = rewards - baseline
        logits = self.model(input_ids=sampled_ids, **{feature_name: features})["logits"]  # [batch, sampled_seq_len, vocab_size]
        log_probs = F.log_softmax(logits, dim=-1)
        log_probs_seq = torch.gather(log_probs, 2, sampled_ids.unsqueeze(-1)).squeeze(-1)
        log_probs_sum = log_probs_seq.sum(dim=1)
        # REINFORCE with the greedy baseline: push up sequences that beat it
        loss = -(advantage * log_probs_sum).mean()
        return loss
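# Hedged sketch of wiring SelfCriticalRL into a training step. Assumes `model`
# and `tokenizer` expose HF-style generate()/decode(), and uses wer_reward
# (defined below) as the scalar reward.
def _example_self_critical_step(model, tokenizer, batch):
    scrl = SelfCriticalRL(model, tokenizer, reward_fn=wer_reward)
    loss = scrl(input_ids=batch["input_ids"], features=batch["spectrogram"], labels=batch["labels"])
    loss.backward()
    return loss.item()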
class SelfTrainingModule(nn.Module):
    def __init__(self, model, tokenizer, quality_fn=None, threshold=0.8):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.quality_fn = quality_fn
        self.threshold = threshold
    def generate_pseudo_labels(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"):
        with torch.no_grad():
            pred_ids = self.model.generate(input_ids=unlabeled_batch, **{feature_name: features}, max_length=max_len)
        if self.quality_fn is not None:
            quality_scores = self.quality_fn(pred_ids, self.model, features)
            mask = quality_scores > self.threshold
            # note: filtering here assumes the caller drops the same rows from
            # unlabeled_batch, or the cross entropy shapes will not line up
            pred_ids = pred_ids[mask]
        return pred_ids
    def forward(self, unlabeled_batch, features, max_len=128, feature_name="spectrogram"):
        pseudo_labels = self.generate_pseudo_labels(unlabeled_batch, features, max_len, feature_name=feature_name)
        logits = self.model(input_ids=unlabeled_batch, **{feature_name: features}, labels=pseudo_labels)["logits"]
        loss = nn.functional.cross_entropy(
            logits.view(-1, logits.shape[-1]), pseudo_labels.view(-1), ignore_index=0)
        return loss
def confidence_indicator(pred_ids, model, features):
    with torch.no_grad():
        logits = model(input_ids=pred_ids, **features)["logits"]
    probs = torch.softmax(logits, dim=-1)
    max_probs, _ = probs.max(dim=-1)
    return max_probs.mean(dim=1)
def wer_reward(hyp, ref):
    # word-level Levenshtein distance via dynamic programming
    hyp_words = hyp.split()
    ref_words = ref.split()
    d = [[0] * (len(ref_words)+1) for _ in range(len(hyp_words)+1)]
    for i in range(len(hyp_words)+1):
        d[i][0] = i
    for j in range(len(ref_words)+1):
        d[0][j] = j
    for i in range(1, len(hyp_words)+1):
        for j in range(1, len(ref_words)+1):
            if hyp_words[i-1] == ref_words[j-1]:
                d[i][j] = d[i-1][j-1]
            else:
                d[i][j] = 1 + min(d[i-1][j], d[i][j-1], d[i-1][j-1])
    wer = d[-1][-1] / max(1, len(ref_words))
    return -wer  # negative WER as reward
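# Worked example: one substitution over a three-word reference gives
# WER = 1/3, hence a reward of -1/3.
def _example_wer_reward():
    reward = wer_reward("the cat sat", "the dog sat")
    assert abs(reward + 1.0 / 3.0) < 1e-9
    return reward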
def clean_ids(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
    if isinstance(ids, torch.Tensor):
        ids = ids.tolist()
    return [int(id) for id in ids if id not in (-100, pad_token_id, bos_token_id, eos_token_id)]

def clean_batch(batch_ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
    return [clean_ids(seq, pad_token_id, bos_token_id, eos_token_id) for seq in batch_ids]
def setup_tokenizer(dir: str):
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(f"{dir}")
    orig_encode = tokenizer.encode
    orig_decode = tokenizer.decode
    def enc(text, add_special_tokens=True):
        ids = orig_encode(text).ids
        if not add_special_tokens:
            sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
            ids = [id for id in ids if id not in sp_ids]
        return ids
    def bdec(ids_list, pad_token_id=0, bos_token_id=1, eos_token_id=2, skip_special_tokens=True):
        results = []
        if isinstance(ids_list, torch.Tensor):
            ids_list = ids_list.tolist()
        elif isinstance(ids_list, np.ndarray):
            ids_list = ids_list.tolist()
        for ids in ids_list:
            ids = [int(id) for id in ids if id not in (pad_token_id, bos_token_id, eos_token_id, -100)]
            results.append(orig_decode(ids))
        return results
    def dec(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2):
        ids = [int(id) for id in ids if id not in (pad_token_id, bos_token_id, eos_token_id, -100)]
        return orig_decode(ids)
    def save_pretrained(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        tokenizer.save(f"{save_dir}/tokenizer.json")
    tokenizer.encode = enc
    tokenizer.batch_decode = bdec
    tokenizer.decode = dec
    tokenizer.save_pretrained = save_pretrained
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    return tokenizer
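# Hedged usage sketch: `path` below is a hypothetical location for a
# `tokenizers` JSON file; the patched encode/decode strip special tokens.
def _example_setup_tokenizer(path="./tokenizer.json"):
    tokenizer = setup_tokenizer(path)
    ids = tokenizer.encode("hello world")
    return tokenizer.decode(ids)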
def tokenize_pitch(pitch_features, target_length):
    pitch_len = pitch_features.shape[-1]
    if pitch_len > target_length:
        pitch_tokens = F.adaptive_avg_pool1d(pitch_features, target_length)
    else:
        pitch_tokens = F.interpolate(pitch_features, size=target_length)
    return pitch_tokens
def load_wave(wave_data, sample_rate=16000):
    if isinstance(wave_data, str):
        waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
    elif isinstance(wave_data, dict):
        waveform = torch.tensor(data=wave_data["array"]).float()
        sr = wave_data["sampling_rate"]
    else:
        raise TypeError("Invalid wave_data format.")
    if sr != sample_rate:
        # resample to the requested rate so downstream hop/frame math stays consistent
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    return waveform
def world_to_mel(sp, ap, sample_rate=16000, n_mels=128):
    import librosa
    mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=1024, n_mels=n_mels)
    mel_basis = torch.from_numpy(mel_basis).float()
    sp_mel = torch.matmul(sp, mel_basis.T)  # (frames, n_mels)
    ap_mel = torch.matmul(ap, mel_basis.T)  # (frames, n_mels)
    return sp_mel, ap_mel
def extract_features(batch, tokenizer, waveform=False, spec=False, pitch_tokens=False, pitch=False, harmonics=False, sample_rate=16000, hop_length=256, mode="mean", debug=False, phase_mod=False, crepe=False, aperiodics=False, dummy=False):
    torch_windows = {
        'hann': torch.hann_window,
        'hamming': torch.hamming_window,
        'blackman': torch.blackman_window,
        'bartlett': torch.bartlett_window,
        'ones': torch.ones,
        None: torch.ones,
    }
    if dummy:
        # short-circuit with zero tensors of representative shapes
        return {
            "spectrogram": torch.zeros((1, 128, 100)),
            "f0": torch.zeros((1, 100)),
            "pitch_tokens": torch.zeros((1, 100)),
            "pitch": torch.zeros((1, 100)),
            "harmonic": torch.zeros((1, 128, 100)),
            "aperiodic": torch.zeros((1, 128, 100)),
            "crepe_time": None,
            "crepe_frequency": None,
            "crepe_confidence": None,
            "crepe_activation": None,
        }
    audio = batch["audio"]
    sample_rate = audio["sampling_rate"]
    labels = tokenizer.encode(batch["transcription"])
    wav = load_wave(wave_data=audio, sample_rate=sample_rate)

    def crepe_predict(wav, sample_rate, viterbi=False):
        # the (time, frequency, confidence, activation) return signature below
        # matches the `crepe` package, not `torchcrepe`
        import crepe
        wav_np = wav.numpy().astype(np.float32)
        time, frequency, confidence, activation = crepe.predict(
            wav_np, sample_rate, viterbi=viterbi)
        return (torch.from_numpy(time), torch.from_numpy(frequency),
                torch.from_numpy(confidence), torch.from_numpy(activation))

    if crepe:
        crepe_time, crepe_frequency, crepe_confidence, crepe_activation = crepe_predict(wav, sample_rate, viterbi=True)
    else:
        crepe_time = crepe_frequency = crepe_confidence = crepe_activation = None

    def spectrogram(wav, sample_rate, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
        if isinstance(window_fn, str):
            window_fn = torch_windows[window_fn]
        # torchaudio expects a window *tensor*, so materialize callables first
        if window_fn is None:
            window = torch.ones(n_fft)
        elif isinstance(window_fn, torch.Tensor):
            window = window_fn
        else:
            window = window_fn(n_fft)
        return torchaudio.functional.spectrogram(
            wav, pad=0, window=window.to(wav.device), n_fft=n_fft, hop_length=hop_length,
            win_length=n_fft, power=1.0, normalized=False, center=True, pad_mode="reflect")

    def mel_spectrogram(wav, sample_rate):
        spectrogram_config = {
            "hop_length": 256,
            "f_min": 150,
            "f_max": 2000,
            "n_mels": 128,
            "n_fft": 1024,
            "sample_rate": 16000,
            "pad_mode": "constant",
            "center": True,
            "power": 1.0,
            "window_fn": torch.hann_window,
            "mel_scale": "htk",
            "norm": None,
            "normalized": False,
        }
        transform = torchaudio.transforms.MelSpectrogram(**spectrogram_config)
        mel = transform(wav)
        # Whisper-style log-mel normalization into roughly [-1, 1]
        log_mel = torch.clamp(mel, min=1e-10).log10()
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        return (log_mel + 4.0) / 4.0

    spectrogram_tensor = None
    if spec:
        spectrogram_tensor = mel_spectrogram(wav, sample_rate)

    def mfcc(wav, sample_rate, n_mels=128, n_fft=1024, hop_length=256, window_fn=torch.hann_window):
        transform = torchaudio.transforms.MFCC(
            sample_rate=sample_rate,
            n_mfcc=n_mels,
            melkwargs={
                "n_fft": n_fft,
                "hop_length": hop_length,
                "window_fn": window_fn,
                "n_mels": n_mels,
                "center": True,
                "pad_mode": "reflect",
                "norm": None,
                "mel_scale": "htk",
            }
        )
        return transform(wav)

    pitch_tokens_tensor = None
    phase = None
    p_tensor = None
    harmonic_tensor = None
    aperiodic_tensor = None
    if pitch or pitch_tokens or harmonics or aperiodics or phase_mod:
        wavnp = wav.numpy().astype(np.float64)
        f0_np, t = pw.dio(wavnp, sample_rate, frame_period=hop_length / sample_rate * 1000)
        f0_np = pw.stonemask(wavnp, f0_np, t, sample_rate)
        t2 = torch.from_numpy(t)
        if pitch_tokens:
            # assign each token a pitch value by averaging f0 over its time span
            audio_duration = len(wavnp) / sample_rate
            T = len(labels)
            tok_dur_sec = audio_duration / T
            token_starts = torch.arange(T) * tok_dur_sec
            token_ends = token_starts + tok_dur_sec
            start_idx = torch.searchsorted(t2, token_starts, side="left")
            end_idx = torch.searchsorted(t2, token_ends, side="right")
            pitch_tok = torch.zeros(T, dtype=torch.float32)
            for i in range(T):
                lo = int(start_idx[i])
                hi = max(lo + 1, int(end_idx[i]))
                segment = torch.from_numpy(f0_np[lo:hi])
                if mode == "mean":
                    pitch_tok[i] = segment.mean()
                elif mode == "median":
                    pitch_tok[i] = torch.median(segment)
                else:
                    pitch_tok[i] = segment[-1]
            pitch_tok[pitch_tok < 100.0] = 0.0
            bos_pitch = pitch_tok[0] if len(pitch_tok) > 0 else 0.0
            pitch_tokens_tensor = torch.cat([torch.tensor([bos_pitch]), pitch_tok])
            # scale voiced frames into [0, 1] over an assumed 71-500 Hz range
            pitch_tokens_tensor = torch.where(pitch_tokens_tensor == 0.0, torch.zeros_like(pitch_tokens_tensor), (pitch_tokens_tensor - 71.0) / (500.0 - 71.0))
        if phase_mod:
            # integrate instantaneous frequency into a wrapped phase track
            f0_tensor = torch.from_numpy(f0_np)
            tframe = torch.mean(t2[1:] - t2[:-1])
            phi0 = 0.0
            omega = 2 * torch.pi * f0_tensor
            dphi = omega * tframe
            phi = torch.cumsum(dphi, dim=0) + phi0
            phase = torch.remainder(phi, 2 * torch.pi)
        if pitch:
            p_tensor = torchaudio.functional.detect_pitch_frequency(wav, sample_rate)
        if harmonics or aperiodics:
            spnp = pw.cheaptrick(wavnp, f0_np, t, sample_rate, fft_size=256)
            apnp = pw.d4c(wavnp, f0_np, t, sample_rate, fft_size=256)
            harmonic_tensor = torch.from_numpy(spnp)[:, :128].contiguous().T
            aperiodic_tensor = torch.from_numpy(apnp)[:, :128].contiguous().T
    wave_tensor = wav if waveform else None
    dummy_tensor = None  # the dummy path returns early above
    if debug:
        print(f"['pitch_tokens']: {pitch_tokens_tensor.shape if pitch_tokens_tensor is not None else None}")
        print(f"['harmonic']: {harmonic_tensor.shape if harmonic_tensor is not None else None}")
        print(f"['aperiodic']: {aperiodic_tensor.shape if aperiodic_tensor is not None else None}")
        print(f"['spectrogram']: {spectrogram_tensor.shape if spectrogram_tensor is not None else None}")
        print(f"['waveform']: {wave_tensor.shape if wave_tensor is not None else None}")
        print(f"['labels']: {len(labels) if labels else None}")
        print(f"['phase']: {phase.shape if phase is not None else None}")
        print(f"['pitch']: {p_tensor.shape if p_tensor is not None else None}")
        print(f"['crepe_time']: {crepe_time.shape if crepe_time is not None else None}")
        print(f"['crepe_frequency']: {crepe_frequency.shape if crepe_frequency is not None else None}")
        print(f"['crepe_confidence']: {crepe_confidence.shape if crepe_confidence is not None else None}")
        print(f"['crepe_activation']: {crepe_activation.shape if crepe_activation is not None else None}")
    return {
        "waveform": wave_tensor,
        "spectrogram": spectrogram_tensor,
        "pitch_tokens": pitch_tokens_tensor,
        "pitch": p_tensor,
        "harmonic": harmonic_tensor,
        "aperiodic": aperiodic_tensor,
        "labels": labels,
        "phase": phase,
        "crepe_time": crepe_time,
        "crepe_frequency": crepe_frequency,
        "crepe_confidence": crepe_confidence,
        "crepe_activation": crepe_activation,
        "dummy": dummy_tensor,
    }
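# Sketch: dummy=True returns early with zero tensors of representative shapes,
# which is handy for smoke-testing the collator without touching audio.
def _example_extract_features(tokenizer):
    feats = extract_features(batch=None, tokenizer=tokenizer, dummy=True)
    assert feats["spectrogram"].shape == (1, 128, 100)
    return feats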
def plot_waveform(waveform, sr, title="Waveform", ax=None):
    # plots the first channel only; multi-channel input contributes channel 0
    waveform = waveform.numpy()
    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sr
    if ax is None:
        _, ax = plt.subplots(1, 1)
    ax.plot(time_axis, waveform[0], linewidth=1)
    ax.grid(True)
    ax.set_xlim([0, time_axis[-1]])
    ax.set_title(title)
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
    import librosa
    if ax is None:
        _, ax = plt.subplots(1, 1)
    if title is not None:
        ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.imshow(librosa.power_to_db(specgram), origin="lower", aspect="auto", interpolation="nearest")

def plot_fbank(fbank, title=None):
    fig, axs = plt.subplots(1, 1)
    axs.set_title(title or "Filter bank")
    axs.imshow(fbank, aspect="auto")
    axs.set_ylabel("frequency bin")
    axs.set_xlabel("mel bin")

def plot_pitch(waveform, sr, pitch):
    figure, axis = plt.subplots(1, 1)
    axis.set_title("Pitch Feature")
    axis.grid(True)
    end_time = waveform.shape[1] / sr
    time_axis = torch.linspace(0, end_time, waveform.shape[1])
    axis.plot(time_axis, waveform[0], linewidth=1, color="gray", alpha=0.3)
    axis2 = axis.twinx()
    time_axis = torch.linspace(0, end_time, pitch.shape[1])
    axis2.plot(time_axis, pitch[0], linewidth=2, label="Pitch", color="green")
    axis2.legend(loc=0)
def prepare_datasets(tokenizer, token, sanity_check=False, sample_rate=16000, streaming=False,
                     load_saved=False, save_dataset=False, cache_dir=None, extract_args=None, max_ctx=2048):
    if extract_args is None:
        # keys mirror the keyword arguments of extract_features
        extract_args = {
            "waveform": False,
            "spec": False,
            "pitch_tokens": False,
            "pitch": False,
            "harmonics": False,
            "aperiodics": False,
            "sample_rate": 16000,
            "hop_length": 256,
            "mode": "mean",
            "debug": False,
            "phase_mod": False,
            "crepe": False,
            "dummy": False,
        }
    if cache_dir is None:
        cache_dir = "./processed_datasets"
    os.makedirs(cache_dir, exist_ok=True)
    cache_file_train = os.path.join(cache_dir, "train.arrow")
    cache_file_test = os.path.join(cache_dir, "test.arrow")
    if load_saved and os.path.exists(cache_file_train) and os.path.exists(cache_file_test):
        from datasets import Dataset
        train_dataset = Dataset.load_from_disk(cache_file_train)
        test_dataset = Dataset.load_from_disk(cache_file_test)
        return train_dataset, test_dataset
    if sanity_check:
        test = load_dataset(
            "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True, streaming=streaming).cast_column("audio", Audio(sampling_rate=sample_rate)).take(1)
        dataset = test.map(lambda x: extract_features(x, tokenizer, **extract_args), remove_columns=test.column_names)
        return dataset, dataset
    else:
        def filter_func(x):
            return (0 < len(x["transcription"]) < max_ctx and
                    len(x["audio"]["array"]) > 0 and
                    len(x["audio"]["array"]) < max_ctx * 160)
        # raw_train = load_dataset("mozilla-foundation/common_voice_17_0", "en", token=token, split="train", trust_remote_code=True, streaming=True).rename_column("sentence", "transcription")
        # raw_test = load_dataset("mozilla-foundation/common_voice_17_0", "en", token=token, split="test", trust_remote_code=True, streaming=True).rename_column("sentence", "transcription").take(1000)
        raw_train = load_dataset(
            "google/fleurs", "en_us", token=token, split="train", trust_remote_code=True, streaming=streaming).take(1000)
        raw_test = load_dataset(
            "google/fleurs", "en_us", token=token, split="test", trust_remote_code=True, streaming=streaming).take(100)
        raw_train = raw_train.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
        raw_test = raw_test.filter(filter_func).cast_column("audio", Audio(sampling_rate=sample_rate))
        train_dataset = raw_train.map(lambda x: extract_features(x, tokenizer, **extract_args), remove_columns=raw_train.column_names)
        test_dataset = raw_test.map(lambda x: extract_features(x, tokenizer, **extract_args), remove_columns=raw_test.column_names)
        if save_dataset:
            # save_to_disk applies to non-streaming datasets
            train_dataset.save_to_disk(cache_file_train)
            test_dataset.save_to_disk(cache_file_test)
        return train_dataset, test_dataset
class tgate(nn.Module):
    def __init__(self, dims, num_types=4):
        super().__init__()
        self.gates = nn.ModuleList([nn.Sequential(Linear(dims, 1), nn.Sigmoid()) for _ in range(num_types)])
        self.classifier = nn.Sequential(Linear(dims, num_types), nn.Softmax(dim=-1))
    def forward(self, x):
        # soft-select among num_types gates using the classifier distribution
        types = self.classifier(x)
        gates = torch.stack([gate(x) for gate in self.gates], dim=-1)
        cgate = torch.sum(gates * types.unsqueeze(2), dim=-1)
        return cgate
def get_feature_encoder(feature: str, mels: int, input_dims: int, dims: int, head: int, layer: int, act=None, features=None) -> nn.Module:
    if feature == "spectrogram":
        return FEncoder(mels=mels, input_dims=input_dims, dims=dims, head=head, layer=layer, act=act, feature=feature, features=features)
    elif feature == "waveform":
        return WEncoder(input_dims=input_dims, dims=dims, head=head, layer=layer, act=act)
    elif feature == "pitch":
        return PEncoder(input_dims=input_dims, dims=dims, head=head, layer=layer, act=act)
    else:
        raise ValueError(f"Unknown feature type: {feature}")
class FEncoder(nn.Module):
    def __init__(self, mels, input_dims, dims, head, layer, act, feature, features, use_rope=False, spec_shape=None, debug=[]):
        super().__init__()
        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims
        self.debug = debug
        self.feature = feature
        self.mels = mels
        self.input_dims = input_dims
        act_fn = get_activation(act)
        self.encoder = nn.Sequential(
            Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
        if use_rope and spec_shape is not None:
            # `rotary` is assumed to be defined elsewhere in this project
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)  # type: ignore # noqa: F821
        else:
            # fall back to sinusoidal positions when rope is unavailable
            self.rope = None
        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        self.norm = RMSNorm(dims)
    def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x
    def forward(self, x, xa=None, mask=None, feats=None, feature="audio", layer="FEncoder"):
        x = self.encoder(x).permute(0, 2, 1)
        if self.use_rope and self.rope is not None:
            x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
        else:
            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        if "fencoder" in self.debug:
            print(f"feature encoder: {x.shape} {feature}")
        x = self.norm(x)
        return x
class WEncoder(nn.Module):  # waveform encoder
    def __init__(self, input_dims, dims, head, layer, act, use_rope=False, debug=[], spec_shape=None):
        super().__init__()
        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims
        self.debug = debug
        act_fn = get_activation(act)
        self.target_length = None
        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
            Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
            Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)
        if use_rope and spec_shape is not None:
            # `rotary` is assumed to be defined elsewhere in this project
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)  # type: ignore # noqa: F821
        else:
            self.rope = None
        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        self.norm = RMSNorm(dims)
    def apply_rope_to_features(self, x, xa=None, mask=None, feats=None, feature="waveform", layer="WEncoder"):
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x
    def forward(self, x, xa=None, mask=None, feats=None, feature="waveform", layer="WEncoder"):
        x = self.encoder(x).permute(0, 2, 1)  # (batch, time, dims)
        if self.target_length and x.shape[1] != self.target_length:
            x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
        if self.use_rope and self.rope is not None:
            x = self.apply_rope_to_features(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
        else:
            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        if "fencoder" in self.debug:
            print(f"waveform encoder: {x.shape} {feature}")
        return self.norm(x)
class PEncoder(nn.Module):  # pitch encoder
    def __init__(self, input_dims, dims, head, layer, act, use_rope=False, debug=[], one_shot=False, spec_shape=None):
        super().__init__()
        self.head = head
        self.head_dim = dims // head
        self.dims = dims
        self.dropout = 0.01
        self.use_rope = use_rope
        self.debug = debug
        act_fn = get_activation(act)
        self.attend_pitch = False
        if self.attend_pitch:
            # qkv_init returns (q, k, v, o, lna, lnb, lnc, lnd)
            self.q, self.k, self.v, self.o, _, _, _, _ = qkv_init(dims, head)
            self.mlp = nn.Sequential(
                nn.Linear(dims, dims),
                nn.ReLU(),
                nn.Linear(dims, dims),
            )
        else:
            self.q, self.k, self.v, self.o = None, None, None, None
            self.mlp = None
        self.pitch_encoder = nn.Sequential(
            Conv1d(input_dims, dims, kernel_size=7, stride=1, padding=3), act_fn,
            Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
        if use_rope:
            # `rotary` is assumed to be defined elsewhere in this project
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)  # type: ignore # noqa: F821
        else:
            self.rope = None
        self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        self.norm = RMSNorm(dims)
    def rope_to_feature(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"):
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, feats=feats, feature=feature, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x
    def forward(self, x, xa=None, mask=None, feats=None, feature="pitch", layer="PEncoder"):
        if x.dim() == 2:
            x = x.unsqueeze(0)
        x = self.pitch_encoder(x).permute(0, 2, 1)
        if self.use_rope:
            x = self.rope_to_feature(x, xa=xa, mask=mask, feats=feats, feature=feature, layer=layer)
        x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        if self.mlp is not None:
            x = self.mlp(x)
        if self.attend_pitch and xa is not None:
            # cross-attend from the conditioning stream xa into the pitch
            # stream; assumes x and xa share the same sequence length
            q, k, v = create_qkv(self.dims, self.head, self.q, self.k, self.v, x=xa, xa=x)
            out = calculate_attention(q, k, v, mask=None, temp=1.0)
            out = out.permute(0, 2, 1, 3).flatten(start_dim=2)
            x = x + out
        x = self.norm(x)
        if "fencoder" in self.debug:
            print(f"Pitch encoder: {x.shape} {feature}")
        return x
class DataCollator: | |
tokenizer: Any | |
def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: | |
all_keys = set() | |
for f in features: | |
all_keys.update(f.keys()) | |
batch = {} | |
pad_token_id = getattr(self.tokenizer, 'pad_token_id', 0) | |
bos_token_id = getattr(self.tokenizer, 'bos_token_id', 1) | |
eos_token_id = getattr(self.tokenizer, 'eos_token_id', 2) | |
for key in all_keys: | |
if key == "labels": | |
labels_list = [f["labels"] for f in features] | |
max_len = max(len(l) for l in labels_list) # noqa: E741 | |
all_ids, all_labels = [], [] | |
for label in labels_list: | |
label_list = label.tolist() if isinstance(label, torch.Tensor) else label | |
decoder_input = [bos_token_id] + label_list | |
label_eos = label_list + [eos_token_id] | |
input_len = max_len + 1 - len(decoder_input) | |
label_len = max_len + 1 - len(label_eos) | |
padded_input = decoder_input + [pad_token_id] * input_len | |
padded_labels = label_eos + [pad_token_id] * label_len | |
all_ids.append(padded_input) | |
all_labels.append(padded_labels) | |
batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long) | |
batch["labels"] = torch.tensor(all_labels, dtype=torch.long) | |
elif key in ["spectrogram", "waveform", "pitch", "harmonic", "aperiodic", "pitch_tokens", "f0", "phase", "crepe_time", "crepe_frequency", "crepe_confidence", "crepe_activation", "dummy"]: | |
items = [f[key] for f in features if key in f] | |
items = [item for item in items if item is not None] | |
if not items: | |
continue | |
items = [torch.tensor(item) if not isinstance(item, torch.Tensor) else item for item in items] | |
max_len = max(item.shape[-1] for item in items) | |
padded = [] | |
for item in items: | |
pad_width = max_len - item.shape[-1] | |
if pad_width > 0: | |
                        pad_item = F.pad(item, (0, pad_width), mode='constant', value=pad_token_id)  # pad trailing frames (pad_token_id defaults to 0)
else: | |
pad_item = item | |
padded.append(pad_item) | |
batch[key] = torch.stack(padded) | |
# if key == "spectrogram": | |
# batch["spectrogram"] = batch[key] | |
return batch | |
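# Example of what the collator produces (hypothetical tokenizer stub; any object
# exposing pad/bos/eos token ids works, since the collator only uses getattr):
# from types import SimpleNamespace
# tok = SimpleNamespace(pad_token_id=0, bos_token_id=1, eos_token_id=2)
# collator = DataCollator(tokenizer=tok)
# batch = collator([
#     {"labels": torch.tensor([5, 6, 7]), "pitch": torch.randn(1, 300)},
#     {"labels": torch.tensor([8, 9]), "pitch": torch.randn(1, 250)},
# ])
# batch["input_ids"].shape  # (2, 4): [bos] + labels, right-padded
# batch["labels"].shape     # (2, 4): labels + [eos], right-padded
# batch["pitch"].shape      # (2, 1, 300): padded along the last axis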
# import tiktoken | |
# import torch | |
# from torch.utils.data import Dataset, DataLoader | |
# class tokenize(Dataset): | |
# def __init__(self, txt, tokenizer, max_length, stride): | |
# self.input_ids = [] | |
# self.labels = [] | |
#         token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
# for i in range(0, len(token_ids) - max_length, stride): | |
# input_chunk = token_ids[i:i + max_length] | |
# target_chunk = token_ids[i + 1: i + max_length + 1] | |
# self.input_ids.append(torch.tensor(input_chunk)) | |
# self.labels.append(torch.tensor(target_chunk)) | |
# def __len__(self): | |
# return len(self.input_ids) | |
# def __getitem__(self, idx): | |
# return self.input_ids[idx], self.labels[idx] | |
# def create_dataloader_v1(txt, batch_size, max_length, stride, shuffle=True, drop_last=True, num_workers=0): | |
# tokenizer = tiktoken.get_encoding("gpt2") | |
# dataset = tokenize(txt, tokenizer, max_length, stride) | |
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers) | |
# return dataloader | |
# def custom_collate_fn(batch, tokenizer_pad_token_id): | |
# max_len_in_batch = max(len(seq) for seq in batch) | |
# padded_input_ids = [] | |
# attention_masks = [] | |
# for seq in batch: | |
# padded_seq = F.pad(seq, (0, max_len_in_batch - len(seq)), value=tokenizer_pad_token_id) | |
# attention_mask = torch.ones(max_len_in_batch) | |
# attention_mask[len(seq):] = 0 | |
# padded_input_ids.append(padded_seq) | |
# attention_masks.append(attention_mask) | |
# input_ids_tensor = torch.stack(padded_input_ids) | |
# attention_mask_tensor = torch.stack(attention_masks) | |
# labels_tensor = input_ids_tensor.clone() | |
# return { | |
# 'input_ids': input_ids_tensor, | |
# 'attention_mask': attention_mask_tensor, | |
# 'labels': labels_tensor | |
# } | |
# with open("the-verdict.txt", "r", encoding="utf-8") as f: | |
# raw_text = f.read() | |
# vocab_size = 50257 | |
# output_dim = 256 | |
# context_length = 1024 | |
# token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim) | |
# pos_embedding_layer = torch.nn.Embedding(context_length, output_dim) | |
# batch_size = 8 | |
# max_length = 4 | |
# dataloader = create_dataloader_v1( | |
# raw_text, | |
# batch_size=batch_size, | |
# max_length=max_length, | |
# stride=max_length | |
# ) | |
def levenshtein(reference_words, hypothesis_words): | |
m, n = len(reference_words), len(hypothesis_words) | |
dist_matrix = [[0 for _ in range(n+1)] for _ in range(m+1)] | |
for i in range(m+1): | |
dist_matrix[i][0] = i | |
for j in range(n+1): | |
dist_matrix[0][j] = j | |
for i in range(1, m+1): | |
for j in range(1, n+1): | |
if reference_words[i-1] == hypothesis_words[j-1]: | |
dist_matrix[i][j] = dist_matrix[i-1][j-1] | |
else: | |
substitution = dist_matrix[i-1][j-1] + 1 | |
insertion = dist_matrix[i][j-1] + 1 | |
deletion = dist_matrix[i-1][j] + 1 | |
dist_matrix[i][j] = min(substitution, insertion, deletion) | |
return dist_matrix[m][n] | |
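# Word-level edit distance, e.g.:
# levenshtein("the cat sat".split(), "the cat sat down".split())  # -> 1 (one insertion)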
def wer_batch(references, hypotheses): | |
total_errors = 0 | |
total_words = 0 | |
for ref, hyp in zip(references, hypotheses): | |
ref_words = ref.lower().split() | |
errors = levenshtein(ref_words, hyp.lower().split()) | |
total_errors += errors | |
total_words += len(ref_words) | |
return (total_errors / total_words) * 100 if total_words > 0 else 0.0 | |
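# WER is returned as a percentage of reference words, e.g.:
# wer_batch(["the cat sat"], ["the cat sat down"])  # -> 33.33 (1 error / 3 ref words)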
def compute_metrics(pred, tokenizer=None, model=None, print_pred=False, num_samples=0, logits=None): | |
def clean(ids, pad_token_id=0, bos_token_id=1, eos_token_id=2): | |
if isinstance(ids, torch.Tensor): | |
ids = ids.tolist() | |
if isinstance(ids[0], (list, torch.Tensor, np.ndarray)): | |
return [[int(i) for i in seq if i not in (-100, pad_token_id, bos_token_id, eos_token_id)] for seq in ids] | |
else: | |
return [int(i) for i in ids if i not in (-100, pad_token_id, bos_token_id, eos_token_id)] | |
pred_ids = pred.predictions | |
label_ids = pred.label_ids | |
if isinstance(pred_ids, tuple): | |
pred_ids = pred_ids[0] | |
if not isinstance(pred_ids, torch.Tensor): | |
pred_ids = torch.tensor(pred_ids) | |
label_ids = clean(label_ids) | |
pred_ids = clean(pred_ids) | |
pred_str = tokenizer.batch_decode(pred_ids) | |
label_str = tokenizer.batch_decode(label_ids) | |
if print_pred: | |
for i in range(min(num_samples, len(pred_ids))): | |
print(f"Pred tokens: {pred_ids[i]}") | |
print(f"Label tokens: {label_ids[i]}") | |
print(f"Pred: '{pred_str[i]}'") | |
print(f"Label: '{label_str[i]}'") | |
print("-" * 40) | |
wer = wer_batch(label_str, pred_str) | |
if model is not None: | |
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000 | |
efficiency_score = (100 - wer) / trainable_params if trainable_params > 0 else 0.0 | |
else: | |
trainable_params = 0.0 | |
efficiency_score = 0.0 | |
return { | |
"wer": float(wer), | |
"efficiency_score": float(efficiency_score), | |
} | |
def preprocess_logits_for_metrics(logits, labels): | |
pred_ids = torch.argmax(logits, dim=-1) | |
return pred_ids, labels | |
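# These two hooks are shaped for a Hugging Face Trainer (a sketch; the other
# Trainer arguments are assumed to be defined elsewhere):
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     compute_metrics=lambda p: compute_metrics(p, tokenizer=tokenizer, model=model),
#     preprocess_logits_for_metrics=preprocess_logits_for_metrics,
# )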
def hilbert_transform(x): | |
N = x.shape[-1] | |
xf = torch.fft.rfft(x) | |
h = torch.zeros(N // 2 + 1, device=x.device, dtype=x.dtype) | |
if N % 2 == 0: | |
h[0] = h[N//2] = 1 | |
h[1:N//2] = 2 | |
else: | |
h[0] = 1 | |
h[1:(N+1)//2] = 2 | |
return torch.fft.irfft(xf * h, n=N) | |
def analytic_signal(x): | |
return x + 1j * hilbert_transform(x) | |
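# Sanity check (a sketch): the analytic signal of a pure tone has a near-constant
# envelope of ~1.0 away from the frame edges.
# t = torch.linspace(0, 1, 1024)
# tone = torch.sin(2 * math.pi * 50.0 * t)
# envelope = torch.abs(analytic_signal(tone))  # ~1.0 except near the boundaries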
def hilbert_transform_2d(x, dim=-1):
    N = x.shape[dim]
    xf = torch.fft.rfft(x, dim=dim)
    # Build the 1-D Hilbert multiplier once, then reshape it so it broadcasts
    # along `dim`; this makes the transform valid for any axis, not just the last.
    h1d = torch.zeros(N // 2 + 1, device=x.device, dtype=x.dtype)
    if N % 2 == 0:
        h1d[0] = h1d[-1] = 1
        h1d[1:-1] = 2
    else:
        h1d[0] = 1
        h1d[1:] = 2
    h_shape = [1] * x.dim()
    h_shape[dim] = N // 2 + 1
    h = h1d.view(h_shape)
    return torch.fft.irfft(xf * h, n=N, dim=dim)
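# With the broadcastable multiplier the transform works along any axis, e.g.:
# x2d = torch.randn(8, 512)
# ht_time = hilbert_transform_2d(x2d, dim=-1)  # along frames
# ht_mel = hilbert_transform_2d(x2d, dim=0)    # along the first axis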
def hilbert_transform_true_2d(x):
    xf = torch.fft.rfft2(x)
    # Frequency grids must match the rfft2 output shape: full fftfreq along the
    # second-to-last axis, half-spectrum rfftfreq along the last axis.
    h1, h2 = torch.meshgrid(
        torch.fft.fftfreq(x.shape[-2], device=x.device),
        torch.fft.rfftfreq(x.shape[-1], device=x.device),
        indexing='ij')
    h = -1j / (math.pi * (h1 + 1j * h2))
    h[0, 0] = 0  # remove the DC singularity
    return torch.fft.irfft2(xf * h, s=x.shape[-2:])
def process_spectrogram_with_hilbert(spec):
    # Envelope/phase decomposition via the analytic signal along the last axis.
    analytic = analytic_signal(spec)
    envelope = torch.abs(analytic)
    phase = torch.angle(analytic)
    return envelope, phase
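# Applied to a spectrogram, the decomposition runs along the frame axis, e.g.:
# spec = torch.randn(128, 300)  # (mels, frames)
# env, phase = process_spectrogram_with_hilbert(spec)
# env.shape, phase.shape        # both (128, 300)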
# import torch | |
# import torch.nn as nn | |
# import torch.nn.functional as F | |
# from torch import Tensor | |
# from typing import Optional, Tuple | |
# import numpy as np | |
# from torch.nn.functional import scaled_dot_product_attention | |
# from torch.cuda.amp import autocast | |
# from torch.nn import LayerNorm, Linear | |
# import logging | |
# logging.basicConfig(level=logging.WARNING) | |
# log = logging.getLogger(__name__) | |
# class ProjectionModule(nn.Module): | |
# """ | |
# Projects input embeddings into query, key, or value representations | |
# for multi-head attention, handling scaling for Q/K. | |
# """ | |
# def __init__(self, dims: int, head: int, proj_type: str = "query", use_bias: bool = True): | |
# """ | |
# Args: | |
# dims: Input and output dimension. | |
# head: Number of attention heads. | |
# proj_type: Type of projection ("query", "key", "value"). | |
# use_bias: Whether to use bias in the linear layer. | |
# """ | |
# super().__init__() | |
# assert dims % head == 0, f"dims ({dims}) must be divisible by head ({head})" | |
# assert proj_type in ["query", "key", "value"], \ | |
# f"proj_type must be 'query', 'key', or 'value', got {proj_type}" | |
# self.dims = dims | |
# self.head = head | |
# self.head_dim = dims // head | |
# self.proj_type = proj_type | |
# self.scale = self.head_dim ** -0.5 if proj_type != "value" else 1.0 | |
# self.proj = Linear(in_features=dims, out_features=dims, bias=use_bias) | |
# self.init_weights() | |
# def init_weights(self): | |
# """Initialize projection weights.""" | |
# nn.init.normal_(tensor=self.proj.weight, std=0.02) | |
# if self.proj.bias is not None: | |
# nn.init.zeros_(tensor=self.proj.bias) | |
# def forward(self, x: Tensor) -> Tensor: | |
# """ | |
# Applies projection, scaling (for Q/K), and reshapes for multi-head attention. | |
# Args: | |
# x: Input tensor of shape (batch, seq_len, dims). | |
# Returns: | |
# Projected tensor of shape (batch, head, seq_len, head_dim). | |
# """ | |
# batch, seq_len, _ = x.shape | |
# proj = self.proj(x) | |
# proj = proj.view(batch, seq_len, self.head, self.head_dim).permute(0, 2, 1, 3) | |
# if self.proj_type in ["query", "key"]: | |
# proj = proj * self.scale | |
# return proj | |
# def calculate_attention( | |
# q: Tensor, | |
# k: Tensor, | |
# v: Tensor, | |
# mask: Optional[Tensor] = None, | |
# temperature: float = 1.0, | |
# use_sdpa: bool = True, | |
# is_causal: bool = False, | |
# dropout_p: float = 0.0 | |
# ) -> Tuple[Tensor, Optional[Tensor]]: | |
# """ | |
# Calculates scaled dot-product attention. | |
# Uses torch.nn.functional.scaled_dot_product_attention if use_sdpa is True | |
# and inputs are compatible, otherwise falls back to manual implementation. | |
# Args: | |
# q: Query tensor (Batch, Heads, SeqLen_Q, HeadDim). Already scaled if needed. | |
# k: Key tensor (Batch, Heads, SeqLen_K, HeadDim). Already scaled if needed. | |
# v: Value tensor (Batch, Heads, SeqLen_K, HeadDim). | |
# mask: Attention mask. Can be boolean (True means ignore) or float (-inf means ignore). | |
# Shape should be broadcastable to (Batch, Heads, SeqLen_Q, SeqLen_K). | |
# temperature: Softmax temperature scaling. Applied *before* softmax. | |
# use_sdpa: Flag to attempt using PyTorch's optimized SDPA implementation. | |
# is_causal: If True, applies a causal mask (for decoder self-attention). | |
# Used only if mask is None and use_sdpa is True. | |
# dropout_p: Dropout probability for attention weights. | |
# Returns: | |
# A tuple containing: | |
# - attn_output: Attention output tensor (Batch, Heads, SeqLen_Q, HeadDim). | |
# - attn_weights: Attention weights tensor (Batch, Heads, SeqLen_Q, SeqLen_K), | |
# or None if SDPA implementation doesn't return them or if fallback used. | |
# *Note: SDPA's default doesn't return weights, requires specific backend support.* | |
# *Manual path always returns weights.* | |
# """ | |
# batch_size, num_heads, q_len, head_dim = q.shape | |
# k_len = k.size(2) | |
# temp_scale = 1.0 / temperature if temperature > 0 else 1.0 | |
# attn_output, attn_weights = None, None | |
# if use_sdpa: | |
# try: | |
# if temperature != 1.0: | |
# raise NotImplementedError("SDPA does not directly support temperature scaling. Use manual path or scale Q.") | |
# attn_output = scaled_dot_product_attention( | |
# q, k, v, | |
# attn_mask=mask, | |
# dropout_p=dropout_p, | |
# is_causal=is_causal and mask is None | |
# ) | |
# attn_weights = None | |
# return attn_output, attn_weights | |
# except (RuntimeError, NotImplementedError) as e: | |
# log.warning(f"SDPA failed or not used ({e}), falling back to manual attention.") | |
# attn_scores = torch.matmul(q, k.transpose(-2, -1)) * temp_scale | |
# if mask is not None: | |
# if mask.dim() == 2: | |
# mask = mask.unsqueeze(0).unsqueeze(0) | |
# elif mask.dim() == 3: | |
# mask = mask.unsqueeze(1) | |
# expected_mask_shape = (batch_size, num_heads, q_len, k_len) | |
# if mask.shape != expected_mask_shape: | |
# try: | |
# mask = mask.expand(expected_mask_shape) | |
# except RuntimeError: | |
# raise ValueError(f"Mask shape {mask.shape} is not compatible with attention scores shape {expected_mask_shape}") | |
# if mask.dtype == torch.bool: | |
# attn_scores = attn_scores.masked_fill(mask, float("-inf")) | |
# else: | |
# attn_scores = attn_scores + mask | |
# attn_weights = F.softmax(attn_scores, dim=-1) | |
# if dropout_p > 0.0: | |
# attn_weights = F.dropout(attn_weights, p=dropout_p) | |
# attn_output = torch.matmul(attn_weights, v) | |
# return attn_output, attn_weights | |
# class BaseAttention(nn.Module): | |
# """Base class for attention mechanisms with common functionality.""" | |
# use_sdpa = True | |
# def __init__(self, dims: int, head: int, max_dist: int = 512, dropout: float = 0.0): | |
# """ | |
# Args: | |
# dims: Embedding dimension. | |
# head: Number of attention heads. | |
# max_dist: Maximum attention distance (used by some subclasses). | |
# dropout: Dropout probability for attention weights. | |
# """ | |
# super().__init__() | |
# assert dims % head == 0, f"dims ({dims}) must be divisible by head ({head})" | |
# self.dims = dims | |
# self.head = head | |
# self.head_dim = dims // head | |
# self.max_dist = max_dist | |
# self.dropout = dropout | |
# def _shape(self, tensor: torch.Tensor) -> torch.Tensor: | |
# """ | |
# Reshape tensor from (batch, seq_len, dims) to | |
# (batch, head, seq_len, head_dim) for multi-head attention. | |
# """ | |
# batch, seq_len, _ = tensor.shape | |
# return tensor.view(batch, seq_len, self.head, self.head_dim).transpose(1, 2).contiguous() | |
# def _reshape_to_output(self, attn_output: Tensor) -> Tensor: | |
# """ | |
# Reshape attention output from (batch, head, seq_len, head_dim) | |
# back to (batch, seq_len, dims). | |
# """ | |
# batch, _, seq_len, _ = attn_output.shape | |
# return attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.dims) | |
# class AttentionCombiner(BaseAttention): | |
# """ | |
# Computes attention given Q, K, V projections and applies an output projection. | |
# Assumes Q, K, V inputs are already projected and appropriately shaped/scaled. | |
# """ | |
# def __init__(self, dims: int, head: int, use_bias: bool = True, dropout: float = 0.0): | |
# """ | |
# Args: | |
# dims: Embedding dimension. | |
# head: Number of attention heads. | |
# use_bias: Whether to use bias in the output projection. | |
# dropout: Dropout probability for attention weights. | |
# """ | |
# super().__init__(dims, head, dropout=dropout) | |
# self.out = Linear(in_features=dims, out_features=dims, bias=use_bias) | |
# self._init_weights() | |
# def _init_weights(self): | |
# """Initialize output projection weights.""" | |
# nn.init.normal_(tensor=self.out.weight, std=0.02) | |
# if self.out.bias is not None: | |
# nn.init.zeros_(tensor=self.out.bias) | |
# # @autocast('cuda', enabled=torch.cuda.is_available()) | |
# def forward(self, q: Tensor, k: Tensor, v: Tensor, | |
# mask: Optional[Tensor] = None, is_causal: bool = False) -> Tensor: | |
# """ | |
# Processes Q, K, V through attention and output projection. | |
# Args: | |
# q: Query tensor (Batch, Heads, SeqLen_Q, HeadDim). Assumed scaled. | |
# k: Key tensor (Batch, Heads, SeqLen_K, HeadDim). Assumed scaled. | |
# v: Value tensor (Batch, Heads, SeqLen_K, HeadDim). | |
# mask: Attention mask. | |
# is_causal: Whether to apply causal masking (if mask is None). | |
# Returns: | |
# Output tensor (Batch, SeqLen_Q, Dims). | |
# """ | |
# attn_output, _ = calculate_attention( | |
# q, k, v, mask=mask, | |
# temperature=1.0, | |
# use_sdpa=BaseAttention.use_sdpa, | |
# is_causal=is_causal, | |
# dropout_p = self.dropout | |
# ) | |
# output = self._reshape_to_output(attn_output) | |
# return self.out(output) | |
# class AdaptiveUpdateAttention(BaseAttention): | |
# """ | |
# Attention implementation where Key and Value representations are cached | |
# and only updated based on content-dependent predictors. Suitable for | |
# encoder layers or cross-attention where K/V context changes less frequently. | |
# Note: Current implementation focuses on conditional update based on *current* | |
# input, not standard auto-regressive KV caching for generation. | |
# """ | |
# def __init__(self, dims: int, head: int, max_dist: int = 512, update_threshold: float = 0.5, dropout: float = 0.0): | |
# """ | |
# Args: | |
# dims: Embedding dimension. | |
# head: Number of attention heads. | |
# max_dist: Maximum attention distance (inherited, may not be directly used here). | |
# update_threshold: Threshold for sigmoid output of predictors to trigger update. | |
# dropout: Dropout probability for attention weights. | |
# """ | |
# super().__init__(dims, head, max_dist, dropout=dropout) | |
# self.query_module = ProjectionModule(dims, head, "query") | |
# self.key_module = ProjectionModule(dims, head, "key") | |
# self.value_module = ProjectionModule(dims, head, "value") | |
# self.combiner = AttentionCombiner(dims, head, dropout=dropout) | |
# self.key_update_predictor = nn.Sequential( | |
# Linear(dims, dims // 4), nn.ReLU(), Linear(dims // 4, 1), nn.Sigmoid()) | |
# self.value_update_predictor = nn.Sequential( | |
# Linear(dims, dims // 4), nn.ReLU(), Linear(dims // 4, 1), nn.Sigmoid()) | |
# self.update_threshold = update_threshold | |
# self.stored_key_cache: Optional[Tensor] = None | |
# self.stored_value_cache: Optional[Tensor] = None | |
# self.reset_cache_on_forward = True | |
# def _should_update(self, x: torch.Tensor, predictor: nn.Module) -> torch.Tensor: | |
# """Predict whether K or V should be updated based on content.""" | |
# avg_rep = x.mean(dim=1) | |
# update_prob = predictor(avg_rep) | |
# return update_prob > self.update_threshold | |
# def forward(self, x: Tensor, xa: Optional[Tensor] = None, | |
# mask: Optional[Tensor] = None, | |
# is_causal: bool = False) -> Tensor: | |
# """ | |
# Process inputs with adaptive K/V update mechanism. | |
# Args: | |
# x: Input tensor for queries (Batch, SeqLen_Q, Dims). | |
# xa: Optional input tensor for keys/values (for cross-attention). | |
# If None, uses x for self-attention (Batch, SeqLen_KV, Dims). | |
# mask: Attention mask. | |
# is_causal: Whether attention should be causal. | |
# Returns: | |
# Output tensor (Batch, SeqLen_Q, Dims). | |
# """ | |
# if self.reset_cache_on_forward: | |
# self.stored_key_cache = None | |
# self.stored_value_cache = None | |
# batch, ctx_q, _ = x.shape | |
# q = self.query_module(x) | |
# kv_input = xa if xa is not None else x | |
# ctx_kv = kv_input.size(1) | |
# update_k_batch = self._should_update(kv_input, self.key_update_predictor) | |
# update_v_batch = self._should_update(kv_input, self.value_update_predictor) | |
# if self.stored_key_cache is None or self.stored_key_cache.shape[2] != ctx_kv or self.stored_key_cache.shape[0] != batch: | |
# k = self.key_module(kv_input) | |
# self.stored_key_cache = k | |
# elif update_k_batch.any(): | |
# new_k = self.key_module(kv_input) | |
# update_mask_k = update_k_batch.view(-1, 1, 1, 1).expand_as(self.stored_key_cache) | |
# k = torch.where(update_mask_k, new_k, self.stored_key_cache) | |
# self.stored_key_cache = k | |
# else: | |
# k = self.stored_key_cache | |
# if self.stored_value_cache is None or self.stored_value_cache.shape[2] != ctx_kv or self.stored_value_cache.shape[0] != batch: | |
# v = self.value_module(kv_input) | |
# self.stored_value_cache = v | |
# elif update_v_batch.any(): | |
# new_v = self.value_module(kv_input) | |
# update_mask_v = update_v_batch.view(-1, 1, 1, 1).expand_as(self.stored_value_cache) | |
# v = torch.where(update_mask_v, new_v, self.stored_value_cache) | |
# self.stored_value_cache = v | |
# else: | |
# v = self.stored_value_cache | |
# output = self.combiner(q, k, v, mask=mask, is_causal=is_causal) | |
# return output | |
# class Refiner: | |
# """ | |
# Q-learning based agent to refine parameters (e.g., attention span). | |
# Operates outside the standard backpropagation loop. | |
# """ | |
# def __init__(self, states: int, actions: int, alpha: float = 0.1, gamma: float = 0.9, epsilon: float = 0.1): | |
# self.states = states | |
# self.actions = actions | |
# self.R = {} | |
# self.alpha = alpha | |
# self.gamma = gamma | |
# self.epsilon = epsilon | |
# self.default_value = 0.0 | |
# def get_value(self, state: int, action: int) -> float: | |
# """Get Q-value for state-action pair.""" | |
# return self.R.get((state, action), self.default_value) | |
# def set_value(self, state: int, action: int, value: float): | |
# """Set Q-value for state-action pair.""" | |
# self.R[(state, action)] = value | |
# def choose_action(self, state: int) -> int: | |
# """Choose action using epsilon-greedy strategy.""" | |
# if np.random.random() < self.epsilon: | |
# return np.random.randint(self.actions) | |
# else: | |
# action_values = [self.get_value(state, a) for a in range(self.actions)] | |
# return np.argmax(action_values).item() | |
# def update(self, state: int, action: int, reward: float, next_state: int): | |
# """Update Q-value using the Q-learning rule.""" | |
# next_values = [self.get_value(next_state, a) for a in range(self.actions)] | |
# best_next_value = max(next_values) if next_values else self.default_value | |
# old_value = self.get_value(state, action) | |
# td_target = reward + self.gamma * best_next_value | |
# td_error = td_target - old_value | |
# new_value = old_value + self.alpha * td_error | |
# self.set_value(state, action, new_value) | |
# class Predictor(nn.Module): | |
# """Neural predictor for estimating a scale value (e.g., for adaptive span).""" | |
# def __init__(self, dims: int): | |
# super().__init__() | |
# self.linear = Linear(in_features=dims, out_features=1) | |
# self._init_weights() | |
# def _init_weights(self): | |
# """Initialize predictor weights.""" | |
# nn.init.xavier_normal_(self.linear.weight) | |
# if self.linear.bias is not None: | |
# nn.init.zeros_(self.linear.bias) | |
# def forward(self, x: Tensor) -> Tensor: | |
# """ | |
# Predicts a scale factor (0-1) from input features. | |
# Args: | |
# x: Input tensor (Batch, SeqLen, Dims) or (Batch, Dims). | |
# Returns: | |
# Scale tensor (Batch, 1). | |
# """ | |
# if x.dim() > 2: | |
# x = x.mean(dim=1) | |
# scale = torch.sigmoid(self.linear(x)) | |
# return scale | |
# class AdaptiveSpanAttention(BaseAttention): | |
# """ | |
# Attention mechanism where the span is dynamically adjusted based on a | |
# learnable parameter or predicted scale. This version focuses on slicing | |
# the input sequence to the effective span. | |
# Note: This implementation attends only to the *first* `eff_span` tokens. | |
# For attending to a *relative* window, different logic (e.g., sliding window | |
# or masking) would be needed in `calculate_attention`. | |
# """ | |
# def __init__(self, dims: int, head: int, max_dist: int = 512, | |
# initial_span_scale: float = 1.0, learnable_scale: bool = True, | |
# sharpen: bool = True, temp_scale: float = 0.01, dropout: float = 0.0): | |
# """ | |
# Args: | |
# dims, head, max_dist, dropout: Standard BaseAttention params. | |
# initial_span_scale: Initial value for the span scale. | |
# learnable_scale: If True, span_scale is an nn.Parameter. | |
# sharpen, temp_scale: Parameters for dynamic temperature adjustment. | |
# """ | |
# super().__init__(dims, head, max_dist, dropout=dropout) | |
# self.sharpen = sharpen | |
# self.temp_scale = temp_scale | |
# if learnable_scale: | |
# self.span_scale = nn.Parameter(torch.tensor(initial_span_scale)) | |
# else: | |
# self.register_buffer("span_scale", torch.tensor(initial_span_scale)) | |
# self.query_module = ProjectionModule(dims, head, "query") | |
# self.key_module = ProjectionModule(dims, head, "key") | |
# self.value_module = ProjectionModule(dims, head, "value") | |
# self.out_proj = Linear(dims, dims) | |
# @autocast('cuda', enabled=torch.cuda.is_available()) | |
# def forward(self, x: Tensor, xa: Optional[Tensor] = None, | |
# mask: Optional[Tensor] = None, | |
# span_scale_override: Optional[Tensor] = None, | |
# is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]: | |
# """ | |
# Computes attention over an adaptively determined span. | |
# Args: | |
# x: Input tensor for Q (Batch, SeqLen_Q, Dims). | |
# xa: Optional input for K/V (Batch, SeqLen_KV, Dims). If None, use x. | |
# mask: External attention mask. | |
# span_scale_override: Optional tensor (Batch, 1) or scalar to override internal span_scale. | |
# is_causal: Whether to apply causal masking. | |
# Returns: | |
# Tuple of (output tensor (Batch, SeqLen_Q, Dims), attention weights (optional)). | |
# """ | |
# kv_input = xa if xa is not None else x | |
# batch, ctx_q, _ = x.shape | |
# ctx_kv = kv_input.size(1) | |
# current_span_scale = span_scale_override if span_scale_override is not None else self.span_scale | |
# if isinstance(current_span_scale, nn.Parameter): | |
# span_scale_val = current_span_scale.sigmoid() | |
# elif current_span_scale.numel() == 1: | |
# span_scale_val = current_span_scale.expand(batch, 1) | |
# else: | |
# span_scale_val = current_span_scale | |
# span_mean = span_scale_val.mean().item() | |
# max_span_len = ctx_kv | |
# target_span_len = max(1, int(max_span_len * span_mean)) | |
# eff_span = min(target_span_len, self.max_dist, ctx_q, ctx_kv) | |
# if eff_span == 0: | |
# return (torch.zeros_like(x), None) | |
# q_span = x[:, :eff_span, :] | |
# k_span = kv_input[:, :eff_span, :] | |
# v_span = kv_input[:, :eff_span, :] | |
# q_proj = self.query_module(q_span) | |
# k_proj = self.key_module(k_span) | |
# v_proj = self.value_module(v_span) | |
# temperature = (1.0 + self.temp_scale * (1.0 - span_mean) | |
# if self.sharpen | |
# else 0.5 + self.temp_scale * span_mean) | |
# temperature = max(temperature, 1e-3) | |
# span_mask = None | |
# if mask is not None: | |
# if mask.dim() == 4: | |
# span_mask = mask[:, :, :eff_span, :eff_span] | |
# elif mask.dim() == 2: | |
# span_mask = mask[:eff_span, :eff_span] | |
# attn_output_span, attn_weights = calculate_attention( | |
# q_proj, k_proj, v_proj, | |
# mask=span_mask, | |
# temperature=temperature, | |
# use_sdpa=BaseAttention.use_sdpa, | |
# is_causal=is_causal, | |
# dropout_p=self.dropout | |
# ) | |
# output_span = self._reshape_to_output(attn_output_span) | |
# projected_output_span = self.out_proj(output_span) | |
# output = torch.zeros_like(x) | |
# output[:, :eff_span, :] = projected_output_span | |
# return output, attn_weights | |
# class MyelinatedLayer(BaseAttention): | |
# """ | |
# A complex Transformer layer featuring: | |
# - Integrated local/global attention (via IntegratedAttention). | |
# - Optional adapters within sub-layers. | |
# - Node importance prediction for sparsity. | |
# - MLP block. | |
# - Working memory component. | |
# - Potential layer skipping ("jumping") based on a learned policy. | |
# (This version assumes IntegratedAttention is the core attention mechanism). | |
# """ | |
# def __init__(self, dims: int, head: int, num_layers: int = 3, | |
# sparsity_threshold: float = 0.1, max_dist: int = 512, | |
# dropout: float = 0.1, mlp_ratio: int = 4): | |
# super().__init__(dims, head, max_dist, dropout) | |
# self.num_layers = num_layers | |
# self.sparsity_threshold = sparsity_threshold | |
# self.attention = IntegratedAttention(dims, head, max_dist=max_dist, dropout=dropout) | |
# self.sub_layers = nn.ModuleList() | |
# self.node_predictors = nn.ModuleList([ | |
# nn.Sequential(LayerNorm(dims), Linear(dims, 1), nn.Sigmoid()) | |
# for _ in range(num_layers)]) | |
# for i in range(num_layers): | |
# self.sub_layers.append(nn.ModuleDict({ | |
# 'ln': LayerNorm(dims), | |
# 'gate': nn.Sequential(Linear(dims, 1), nn.Sigmoid()), | |
# 'adapter': Linear(dims, dims) if i % 2 == 0 else None | |
# })) | |
# self.policy_net = nn.Sequential(Linear(dims, 128), nn.ReLU(), Linear(128, num_layers)) | |
# self.jump_weights = nn.Parameter(torch.tensor([0.1, 0.05, 0.01])) | |
# n_mlp = dims * mlp_ratio | |
# self.mlp_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid()) | |
# self.mlp = nn.Sequential(Linear(dims, n_mlp), nn.GELU(), Linear(n_mlp, dims), nn.Dropout(dropout)) | |
# self.mlp_ln = LayerNorm(dims) | |
# self.working_memory = nn.Parameter(torch.zeros(1, 1, dims)) | |
# self.memory_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid()) | |
# self.last_memory_gate_values: Optional[Tensor] = None | |
# def predict_node_importance(self, x: Tensor, layer_idx: int) -> Tensor: | |
# """Predict token importance mask (0.0 or 1.0) for sparsity.""" | |
# importance = self.node_predictors[layer_idx](x) | |
# is_important = (importance > self.sparsity_threshold).float() | |
# return is_important | |
# def forward(self, x: Tensor, xa: Optional[Tensor] = None, | |
# mask: Optional[Tensor] = None, kv_cache: Optional[Tensor] = None, | |
# is_causal: bool = True) -> Tensor: | |
# batch, ctx, _ = x.shape | |
# working_memory = self.working_memory.expand(batch, 1, -1).to(x.device) | |
# original_x = x | |
# pooled_representation = x.mean(dim=1) | |
# policy_logits = self.policy_net(pooled_representation) | |
# policy = F.softmax(policy_logits, dim=-1) | |
# jump_history = [] | |
# i = 0 | |
# last_processed_output = x | |
# while i < self.num_layers: | |
# layer = self.sub_layers[i] | |
# node_importance_mask = self.predict_node_importance(x, i) | |
# if node_importance_mask.mean() < 0.2 and i > 0: | |
# i += 1 | |
# jump_history.append(f"skip_low_imp->{i}") | |
# continue | |
# norm_x = layer['ln'](x) | |
# current_attn_mask = node_importance_mask.permute(0, 2, 1) | |
# if mask is not None: | |
# pass | |
# attn_output = self.attention( | |
# norm_x * node_importance_mask, | |
# xa=xa, | |
# mask=mask, | |
# kv_cache=kv_cache, | |
# is_causal=is_causal | |
# ) | |
# if layer['adapter'] is not None: | |
# attn_output = layer['adapter'](attn_output) | |
# gate_value = layer['gate'](norm_x) | |
# x = x + gate_value * attn_output * node_importance_mask | |
# last_processed_output = x | |
# memory_gate = self.memory_gate(x.mean(dim=1, keepdim=True)) | |
# current_mean_x = x.mean(dim=1, keepdim=True) | |
# working_memory = memory_gate * working_memory + (1 - memory_gate) * current_mean_x | |
# if i < self.num_layers - 1: | |
# jump_prob_dist = policy[:, 1:] | |
# jump_prob = jump_prob_dist.sum(dim=-1) | |
# should_jump_batch = torch.rand_like(jump_prob) < jump_prob | |
# if should_jump_batch.any(): | |
# jump_len_probs = policy[should_jump_batch, :self.num_layers-i] | |
# sampled_jump_len = torch.multinomial(jump_len_probs, 1)[:, 0] + 1 | |
# jump_length = sampled_jump_len.max().item() | |
# i_next = min(i + jump_length, self.num_layers) | |
# skip_weight_idx = min(jump_length - 1, len(self.jump_weights) - 1) | |
# skip_weight = self.jump_weights[skip_weight_idx] | |
# x = skip_weight * original_x + (1 - skip_weight) * working_memory.expand_as(x) + x * (1-skip_weight) | |
# jump_history.append(f"jump_{jump_length} S:{skip_weight.item():.2f} ->{i_next}") | |
# i = i_next | |
# continue | |
# i += 1 | |
# mlp_input = last_processed_output | |
# norm_mlp_input = self.mlp_ln(mlp_input) | |
# mlp_output = self.mlp(norm_mlp_input) | |
# mlp_gate_value = self.mlp_gate(norm_mlp_input) | |
# final_output = mlp_input + mlp_gate_value * mlp_output | |
# if 'memory_gate' in locals(): | |
# self.last_memory_gate_values = memory_gate.detach().clone() | |
# return final_output | |
# class IntegratedAttention(BaseAttention): | |
# """ | |
# Integrates multiple attention strategies: | |
# - Local attention (sliding window or adaptive span via AdaptiveSpanAttention). | |
# - Global attention (potentially with adaptive updates via AdaptiveUpdateAttention). | |
# - Cross-attention capability. | |
# - RL-based refinement (`Refiner`) of the local attention span. | |
# - Iterative refinement (`_focus`) within local attention. | |
# """ | |
# def __init__(self, dims: int, head: int, max_dist: int = 512, | |
# win_size: int = 256, max_span: int = 384, temp_scale: float = 0.01, | |
# dropout: float = 0.1, | |
# rl_states: int = 10000, rl_actions: int = 10, rl_alpha: float = 0.1, | |
# rl_gamma: float = 0.9, rl_epsilon: float = 0.1): | |
# super().__init__(dims, head, max_dist, dropout=dropout) | |
# self.max_span = max_span | |
# self.sliding_window = win_size | |
# self.temp_scale = temp_scale | |
# self.sharpen = True | |
# self.refiner = Refiner( | |
# states=rl_states, actions=rl_actions, alpha=rl_alpha, | |
# gamma=rl_gamma, epsilon=rl_epsilon) | |
# self.span_pred = Predictor(dims=dims) | |
# self.attn_local = AdaptiveSpanAttention( | |
# dims=dims, head=head, max_dist=max_dist, sharpen=self.sharpen, | |
# temp_scale=temp_scale, learnable_scale=False, | |
# dropout=dropout) | |
# self.attn_global = AdaptiveUpdateAttention( | |
# dims=dims, head=head, max_dist=max_dist, dropout=dropout) | |
# self.cross_attn = AttentionCombiner(dims=dims, head=head, dropout=dropout) | |
# self.self_projection = Linear(in_features=2 * dims, out_features=dims) | |
# self.global_cross_projection = Linear(in_features=dims, out_features=dims) | |
# self.ln_local_in = LayerNorm(normalized_shape=dims) | |
# self.ln_global_in = LayerNorm(normalized_shape=dims) | |
# self.ln_cross_in = LayerNorm(normalized_shape=dims) | |
# self.register_buffer("threshold", torch.tensor(1e-4), persistent=False) | |
# self.register_buffer("s_factor", torch.tensor(0.1), persistent=False) | |
# def forward(self, x: Tensor, xa: Optional[Tensor] = None, | |
# mask: Optional[Tensor] = None, kv_cache: Optional[Tensor] = None, | |
# is_causal: bool = True) -> Tensor: | |
# """ | |
# Main forward pass distributing to cross or self-attention pathways. | |
# Args: | |
# x: Primary input tensor (Batch, SeqLen_Q, Dims). | |
# xa: Context tensor for cross-attention (Batch, SeqLen_KV, Dims). | |
# mask: Attention mask (padding or causal). | |
# kv_cache: Key/Value cache for generation (specific usage depends on sub-modules). | |
# is_causal: Flag for causal masking in self-attention. | |
# Returns: | |
# Output tensor (Batch, SeqLen_Q, Dims). | |
# """ | |
# batch, ctx_q, _ = x.shape | |
# if xa is not None: | |
# q_norm = self.ln_cross_in(x) | |
# k_cross = self.attn_global.key_module(xa) | |
# v_cross = self.attn_global.value_module(xa) | |
# q_cross = self.attn_global.query_module(q_norm) | |
# cross_out = self.cross_attn(q=q_cross, k=k_cross, v=v_cross, mask=mask, is_causal=False) | |
# return self.global_cross_projection(cross_out) | |
# local_input = self.ln_local_in(x) | |
# global_input = self.ln_global_in(x) | |
# globe_out_raw = self.attn_global( | |
# global_input, | |
# xa=None, | |
# mask=mask, | |
# is_causal=is_causal | |
# ) | |
# globe_out = self.global_cross_projection(globe_out_raw) | |
# base_freq_scale = self.span_pred(globe_out) | |
# state = self._extract_rl_state(local_input) | |
# action = self.refiner.choose_action(state=state) | |
# refinement_scale = self._action_to_scale(action=action) | |
# final_span_scale = torch.clamp(base_freq_scale * refinement_scale.expand_as(base_freq_scale), min=0.0, max=1.0) | |
# span_mean = final_span_scale.mean().item() | |
# with torch.no_grad(): | |
# current_win_size = max(1, int(self.sliding_window * span_mean)) | |
# current_span_len = max(1, int(self.max_span * span_mean)) | |
# local_out_raw = self._slide_win_local( | |
# x=local_input, | |
# win_size=current_win_size, | |
# span_len=current_span_len, | |
# span_scale=final_span_scale, | |
# mask=mask, | |
# is_causal=is_causal | |
# ) | |
# with torch.no_grad(): | |
# reward = self._calculate_rl_reward(output=local_out_raw) | |
# next_state = self._extract_rl_state(local_out_raw) | |
# self.refiner.update(state=state, action=action, reward=reward, next_state=next_state) | |
# combined = torch.cat([local_out_raw, globe_out], dim=-1) | |
# output = self.self_projection(combined) | |
# return output | |
# def _calculate_rl_reward(self, output: Tensor) -> float: | |
# """Calculate quality metric (reward) for reinforcement learning.""" | |
# with torch.no_grad(): | |
# output_probs = torch.softmax(output, dim=-1) | |
# safe_probs = torch.clamp(output_probs, min=1e-10) | |
# entropy = -(safe_probs * torch.log(safe_probs)).sum(-1).mean() | |
# coverage = (output.abs() > 0.01).float().mean() | |
# reward = float(coverage - 0.1 * entropy) | |
# return reward | |
# def _extract_rl_state(self, x: Tensor) -> int: | |
# """Extract discrete state features for RL agent from tensor.""" | |
# with torch.no_grad(): | |
# pooled = x.mean(dim=1) | |
# mean_state = pooled[0].mean() | |
# var_state = pooled[0].var(unbiased=False) | |
# state_features = torch.stack([mean_state, var_state]).cpu().numpy() | |
# state_id = self._discretize_state(state_features) | |
# return state_id | |
# def _discretize_state(self, state: np.ndarray) -> int: | |
# """Convert continuous state numpy array to a discrete state ID.""" | |
# bins = np.linspace(-1, 1, num=10) | |
# state_discrete = np.digitize(state, bins) | |
# state_hash = sum(val * (10**i) for i, val in enumerate(state_discrete)) | |
# state_id = int(state_hash % self.refiner.states) | |
# return state_id | |
# def _action_to_scale(self, action: int) -> Tensor: | |
# """Convert discrete RL action index to a continuous scale factor [0, 1].""" | |
# span_value = action / (self.refiner.actions - 1) | |
# scale_tensor = torch.tensor([span_value], device=self.span_pred.linear.weight.device, dtype=torch.float) | |
# return scale_tensor | |
# def _focus(self, query: Tensor, key: Tensor, value: Tensor, | |
# span_scale: Tensor, mask: Optional[Tensor] = None, | |
# is_causal: bool = False) -> Tuple[Tensor, Optional[Tensor]]: | |
# """ | |
# Iterative attention refinement. Applies attention multiple times, | |
# adding the output back to the query. Uses manual attention calculation. | |
# Args: | |
# query, key, value: Input tensors (B, SeqLen_Window, D). | |
# span_scale: Scale factor (scalar or B, 1) influencing effective span. | |
# mask: Attention mask for the window. | |
# is_causal: Apply causal masking within the window. | |
# Returns: | |
# Tuple (refined_output (B, SeqLen_Window, D), attention_weights (optional, None here)). | |
# """ | |
# max_iterations = 5 | |
# iteration = 0 | |
# prev_attn_out = torch.zeros_like(query) | |
# attn_out = torch.zeros_like(query) | |
# threshold = self.threshold.item() | |
# s_factor = self.s_factor.item() | |
# q_current = query | |
# while iteration < max_iterations: | |
# span_mean = span_scale.mean().item() | |
# target_span_len = max(1, int(self.max_span * span_mean)) | |
# eff_span = min(target_span_len, self.max_dist, q_current.size(1), key.size(1)) | |
# if eff_span == 0: break | |
# q_iter = q_current[:, :eff_span, :] | |
# k_iter = key[:, :eff_span, :] | |
# v_iter = value[:, :eff_span, :] | |
# q_proj = self.attn_local.query_module(q_iter) | |
# k_proj = self.attn_local.key_module(k_iter) | |
# v_proj = self.attn_local.value_module(v_iter) | |
# temperature = (1.0 + self.temp_scale * (1.0 - span_mean) | |
# if self.sharpen | |
# else 0.5 + self.temp_scale * span_mean) | |
# temperature = max(temperature, 1e-3) | |
# iter_mask = None | |
# if mask is not None: | |
# if mask.dim() == 4: iter_mask = mask[:, :, :eff_span, :eff_span] | |
# elif mask.dim() == 2: iter_mask = mask[:eff_span, :eff_span] | |
# attn_output_iter, _ = calculate_attention( | |
# q_proj, k_proj, v_proj, | |
# mask=iter_mask, | |
# temperature=temperature, | |
# use_sdpa=False, | |
# is_causal=is_causal, | |
# dropout_p=self.dropout | |
# ) | |
# attn_out_span = self.attn_local._reshape_to_output(attn_output_iter) | |
# projected_attn_out_span = self.attn_local.out_proj(attn_out_span) | |
# current_iter_out = torch.zeros_like(q_current) | |
# current_iter_out[:, :eff_span, :] = projected_attn_out_span | |
# diff = torch.abs(current_iter_out - prev_attn_out).mean() | |
# dynamic_threshold = threshold + s_factor * diff | |
# if diff < dynamic_threshold and iteration > 0: | |
# attn_out = current_iter_out | |
# break | |
# prev_attn_out = current_iter_out.clone() | |
# q_current = q_current + current_iter_out | |
# attn_out = current_iter_out | |
# iteration += 1 | |
# return attn_out, None | |
# @autocast('cuda', enabled=torch.cuda.is_available()) | |
# def _slide_win_local(self, x: Tensor, win_size: int, span_len: int, | |
# span_scale: Tensor, mask: Optional[Tensor] = None, | |
# is_causal: bool = False) -> Tensor: | |
# """ | |
# Process input with sliding window attention, using `_focus` for each window. | |
# Args: | |
# x: Input tensor (Batch, SeqLen, Dims). | |
# win_size: Size of the attention window for queries. | |
# span_len: Max length of keys/values relative to query window start. | |
# span_scale: Span scale tensor (Batch, 1 or scalar) passed to _focus. | |
# mask: Full attention mask. | |
# is_causal: Apply causal masking within windows. | |
# Returns: | |
# Output tensor (Batch, SeqLen, Dims). | |
# """ | |
# batch, ctx, dims = x.size() | |
# output = torch.zeros_like(x) | |
# num_windows = (ctx + win_size - 1) // win_size | |
# for i in range(num_windows): | |
# q_start = i * win_size | |
# q_end = min(q_start + win_size, ctx) | |
# current_window_q_len = q_end - q_start | |
# if current_window_q_len == 0: continue | |
# kv_start = max(0, q_end - span_len) | |
# kv_end = q_end | |
# query_win = x[:, q_start:q_end, :] | |
# key_win = x[:, kv_start:kv_end, :] | |
# value_win = x[:, kv_start:kv_end, :] | |
# window_mask = None | |
# if mask is not None: | |
# if mask.dim() == 4: | |
# window_mask = mask[:, :, q_start:q_end, kv_start:kv_end] | |
# elif mask.dim() == 2: | |
# window_mask = mask[q_start:q_end, kv_start:kv_end] | |
# attn_out_win, _ = self._focus( | |
# query=query_win, | |
# key=key_win, | |
# value=value_win, | |
# span_scale=span_scale, | |
# mask=window_mask, | |
# is_causal=is_causal | |
# ) | |
# output[:, q_start:q_end, :] = attn_out_win | |
# return output | |
# class CTCDecoder(nn.Module): | |
# def __init__(self, input_dim: int, vocab_size: int, dims: int = 256, num_layers: int = 2, dropout: float = 0.1): | |
# super().__init__() | |
# self.input_dim = input_dim | |
# self.vocab_size = vocab_size | |
# self.dims = dims | |
# self.projection = nn.Linear(input_dim, dims) | |
# self.lstm = nn.LSTM(dims, dims, num_layers, dropout=dropout if num_layers > 1 else 0, batch_first=True, bidirectional=True) | |
# self.output = nn.Linear(dims * 2, vocab_size + 1) # +1 for CTC blank token | |
# self.dropout = nn.Dropout(dropout) | |
# def forward(self, x: Tensor) -> Tensor: | |
# x = self.projection(x) # (batch, seq_len, dims) | |
# x = self.dropout(x) | |
# x, _ = self.lstm(x) # (batch, seq_len, dims * 2) | |
# x = self.dropout(x) | |
# logits = self.output(x) # (batch, seq_len, vocab_size + 1) | |
# return logits | |
# class CTCWrapper(nn.Module): | |
# def __init__(self, model: Model, vocab_size: int, dims: int = 256, num_layers: int = 2): | |
# super().__init__() | |
# self.model = model | |
# self.ctc_decoder = CTCDecoder( | |
# input_dim=model.param.dims, | |
# vocab_size=vocab_size, | |
# dims=dims, | |
# num_layers=num_layers | |
# ) | |
# def forward(self, input_ids=None, pitch=None, labels=None, input_lengths=None, label_lengths=None): | |
# outputs = self.model(input_ids=input_ids, pitch=pitch) | |
# x = outputs["logits"] # (batch, seq_len, vocab_size) | |
# ctc_logits = self.ctc_decoder(x) # (batch, seq_len, vocab_size + 1) | |
# loss = None | |
# if labels is not None and input_lengths is not None and label_lengths is not None: | |
# log_probs = torch.log_softmax(ctc_logits, dim=-1) | |
# log_probs = log_probs.transpose(0, 1) | |
# loss = torch.nn.functional.ctc_loss( | |
# log_probs, | |
# labels, | |
# input_lengths, | |
# label_lengths, | |
# blank=0, | |
# reduction='mean' | |
# ) | |
# return { | |
# "logits": ctc_logits, | |
# "loss": loss, | |
# "out": x | |
# } | |
# def decode(self, logits: Tensor, input_lengths: Optional[Tensor] = None) -> List[List[int]]: | |
# probs = torch.softmax(logits, dim=-1) # (batch, seq_len, vocab_size + 1) | |
# predictions = torch.argmax(probs, dim=-1) # (batch, seq_len) | |
# decoded_sequences = [] | |
# for i, pred in enumerate(predictions): | |
# seq = [] | |
# prev_token = None | |
# for j, token in enumerate(pred): | |
# if input_lengths is not None and j >= input_lengths[i]: | |
# break | |
# if token != 0 and token != prev_token: | |
# seq.append(token.item()) | |
# prev_token = token | |
# decoded_sequences.append(seq) | |
# return decoded_sequences | |
# # ctc_model = CTCWrapper(model, vocab_size=40000, dims=256, num_layers=2) | |
# # outputs = ctc_model( | |
# # input_ids=input_ids, | |
# # pitch=pitch, | |
# # labels=labels, | |
# # input_lengths=input_lengths, # Length of each audio sequence | |
# # label_lengths=label_lengths # Length of each text sequence | |
# # ) | |
# # loss = outputs["loss"] | |
# # outputs = ctc_model(input_ids=input_ids, pitch=pitch) | |
# # logits = outputs["logits"] | |
# # # Decode to text | |
# # decoded_sequences = ctc_model.decode(logits, input_lengths=input_lengths) | |
# # ctc_model = CTCWrapper(model, vocab_size=param.vocab, dims=256, num_layers=2).to('cuda') | |
# # print(f"CTC model parameters: {sum(p.numel() for p in ctc_model.parameters() if p.requires_grad):,}") | |
# # from tensorboard import program | |
# # log_dir = "D:/newmodel/output/logs" | |
# # tb = program.TensorBoard() | |
# # tb.configure(argv=[None, '--logdir', log_dir]) | |
# # url = tb.launch() | |
# # print(f"TensorBoard started at {url}") | |
# def compute_metricsB(pred, tokenizer): | |
# pred_ids = pred["predictions"] | |
# label_ids = pred["label_ids"] | |
#     if isinstance(pred_ids, tuple):
#         pred_ids = pred_ids[0]
# if pred_ids.ndim == 3: | |
# pred_ids = np.argmax(pred_ids, axis=-1) | |
# label_ids[label_ids == -100] = tokenizer.pad_token_id | |
# pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) | |
# label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True) | |
# metrics = evaluate.load(path="wer") | |
# wer = metrics.compute(predictions=pred_str, references=label_str) | |
# return {"wer": wer} | |
# def train_and_evaluate( | |
# model, | |
# tokenizer, | |
# train_loader, | |
# eval_loader, | |
# optimizer, | |
# scheduler, | |
# loss_fn, | |
# max_steps=10000, | |
# device="cuda", | |
# accumulation_steps=1, | |
# clear_cache=True, | |
# log_interval=10, | |
# eval_interval=100, | |
# save_interval=1000, | |
# checkpoint_dir="checkpoint_dir", | |
# log_dir="log_dir", | |
# ): | |
# model.to(device) | |
# global_step = 0 | |
# scaler = torch.GradScaler() | |
# writer = SummaryWriter(log_dir=log_dir) | |
# train_iterator = iter(train_loader) | |
# total_loss = 0 | |
# step_in_report = 0 | |
# dataset_epochs = 0 | |
# progress_bar = tqdm( | |
# total=max_steps, desc="Training Progress", leave=True, colour="green" | |
# ) | |
# model.train() | |
# optimizer.zero_grad() | |
# while global_step < max_steps: | |
# try: | |
# batch = next(train_iterator) | |
# except StopIteration: | |
# train_iterator = iter(train_loader) | |
# batch = next(train_iterator) | |
# dataset_epochs += 1 | |
# print(f"Starting dataset epoch {dataset_epochs}") | |
# if step_in_report > 0: | |
# avg_loss = total_loss / step_in_report | |
# logging.info( | |
# f"Dataset iteration complete - Steps: {global_step}, Avg Loss: {avg_loss:.4f}" | |
# ) | |
# total_loss = 0 | |
# step_in_report = 0 | |
# start_time = time.time() | |
# input_features = batch["input_features"].to(device) | |
# input_ids = batch["input_ids"].to(device) | |
# labels = batch["labels"].long().to(device) | |
# with torch.autocast(device_type="cuda"): | |
# input_features_encoded = model.encoder(input_features) | |
# decoder_output = model.decoder(input_ids, input_features_encoded) | |
# logits = decoder_output.view(-1, decoder_output.size(-1)) | |
# active_logits = logits.view(-1, decoder_output.size(-1)) | |
# active_labels = labels.view(-1) | |
# active_mask = active_labels != -100 | |
# active_logits = active_logits[active_mask] | |
# active_labels = active_labels[active_mask] | |
# loss = loss_fn(active_logits, active_labels) | |
# # model.adjust_freq(loss=loss.item()) | |
# total_loss += loss.item() | |
# loss = loss / accumulation_steps | |
# scaler.scale(loss).backward() | |
# if (global_step + 1) % accumulation_steps == 0: | |
# scaler.unscale_(optimizer) | |
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) | |
# scaler.step(optimizer) | |
# scaler.update() | |
# optimizer.zero_grad() | |
# if clear_cache: | |
# torch.cuda.empty_cache() | |
# end_time = time.time() | |
# samples_per_sec = len(batch["input_features"]) / (end_time - start_time) | |
# if global_step % log_interval == 0: | |
# writer.add_scalar( | |
# tag="Loss/train", | |
# scalar_value=total_loss / (global_step + 1), | |
# global_step=global_step, | |
# ) | |
# lr = scheduler.get_last_lr()[0] | |
# writer.add_scalar( | |
# tag="LearningRate", scalar_value=lr, global_step=global_step | |
# ) | |
# writer.add_scalar( | |
# tag="SamplesPerSec", | |
# scalar_value=samples_per_sec, | |
# global_step=global_step, | |
# ) | |
# if global_step % eval_interval == 0: | |
# model.eval() | |
# eval_start_time = time.time() | |
# eval_loss = 0 | |
# all_predictions = [] | |
# all_labels = [] | |
# batch_count = 0 | |
# total_samples = 0 | |
# with torch.no_grad(): | |
# for eval_batch in eval_loader: | |
# # for eval_batch in tqdm(eval_loader, desc=f"Evaluating (Step {global_step})", leave=True, colour='green'): | |
# input_features = eval_batch["input_features"].to(device) | |
# input_ids = eval_batch["input_ids"].to(device) | |
# labels = eval_batch["labels"].long().to(device) | |
# batch = input_features.size(0) | |
# total_samples += batch | |
# input_features_encoded = model.encoder(input_features) | |
# decoder_output = model.decoder(input_ids, input_features_encoded) | |
# logits = decoder_output.view(-1, decoder_output.size(-1)) | |
# loss = loss_fn(logits, labels.view(-1)) | |
# eval_loss += loss.item() | |
# all_predictions.extend( | |
# torch.argmax(decoder_output, dim=-1).cpu().numpy().tolist() | |
# ) | |
# all_labels.extend(labels.cpu().numpy().tolist()) | |
# batch_count += 1 | |
# eval_time = time.time() - eval_start_time | |
# loss_avg = eval_loss / batch_count if batch_count > 0 else 0 | |
# predictions = { | |
# "predictions": np.array(all_predictions, dtype=object), | |
# "label_ids": np.array(all_labels, dtype=object), | |
# } | |
# metrics = compute_metrics(pred=predictions, tokenizer=tokenizer) | |
# writer.add_scalar("Loss/eval", loss_avg, global_step) | |
# writer.add_scalar("WER", metrics["wer"], global_step) | |
# writer.add_scalar("EvalSamples", total_samples, global_step) | |
# writer.add_scalar("EvalTimeSeconds", eval_time, global_step) | |
# lr = scheduler.get_last_lr()[0] | |
# print( | |
# f"• STEP:{global_step} • samp:{samples_per_sec:.1f} • WER:{metrics['wer']:.2f}% • Loss:{loss_avg:.4f} • LR:{lr:.8f}" | |
# ) | |
# logging.info( | |
# f"EVALUATION STEP {global_step} - WER: {metrics['wer']:.2f}%, Loss: {loss_avg:.4f}, LR: {lr:.8f}" | |
# ) | |
# # scheduler.step() | |
# model.train() | |
# if global_step % save_interval == 0: | |
# checkpoint_path = os.path.join( | |
# checkpoint_dir, f"checkpoint_step_{global_step}.pt" | |
# ) | |
# torch.save(model.state_dict(), checkpoint_path) | |
# # print(f"Model saved at step {global_step} to {checkpoint_path}") | |
# logging.info(f"Model saved at step {global_step} to {checkpoint_path}") | |
# lr = scheduler.get_last_lr()[0] | |
# scheduler.step() | |
# global_step += 1 | |
# step_in_report += 1 | |
# avg_loss = total_loss / (global_step + 1) | |
# postfix_dict = { | |
# "loss": f"{avg_loss:.4f}", | |
# "lr": f"{lr:.6f}", | |
# "samp": f"{samples_per_sec:.1f}", | |
# } | |
# progress_bar.set_postfix(postfix_dict, refresh=True) | |
# progress_bar.update(1) | |
# final_model_path = os.path.join(checkpoint_dir, "final_model.pt") | |
# torch.save(model.state_dict(), final_model_path) | |
# print( | |
# f"Training completed after {global_step} steps. Final model saved to {final_model_path}" | |
# ) | |
# writer.close() | |
# progress_bar.close() | |
# def mainB(): | |
# checkpoint_dir = "./output/checkpoints" | |
# os.makedirs(checkpoint_dir, exist_ok=True) | |
# log_dir = os.path.join("./output/logs", datetime.now().strftime(format="%m-%d_%H")) | |
# os.makedirs(name=log_dir, exist_ok=True) | |
# logging.basicConfig( | |
# filename=os.path.join(log_dir, "training.log"), | |
# filemode="w", | |
# format="%(asctime)s - %(levelname)s - %(message)s", | |
# level=logging.INFO, | |
# ) | |
# token = "" | |
# dataset = IterableDatasetDict() | |
# dataset["train"] = load_dataset( | |
# path="google/fleurs", | |
# name="en_us", | |
# split="train", | |
# streaming=True, | |
# token=token, | |
# trust_remote_code=True, | |
# ).select_columns(column_names=["audio", "transcription"]) | |
# dataset["test"] = load_dataset( | |
# path="google/fleurs", | |
# name="en_us", | |
# split="test", | |
# streaming=True, | |
# token=token, | |
# trust_remote_code=True, | |
# ).select_columns(column_names=["audio", "transcription"]) | |
# debug = None | |
# param = Dimensions( | |
# mels=128, | |
# audio_ctx=1500, | |
# audio_head=4, | |
# encoder_idx=4, | |
# audio_dims=512, | |
# vocab=51865, | |
# text_ctx=512, | |
# text_head=4, | |
# decoder_idx=4, | |
# text_dims=512, | |
# decoder_start_token_id=0, | |
# pad_token_id=0, | |
# eos_token_id=0, | |
# act="gelu", | |
# ) | |
# model = model | |
# Collator = DataCollatorB( | |
# tokenizer=tokenizer, | |
# audio_ctx=param.audio_ctx, | |
# text_ctx=param.text_ctx, | |
# mels=param.mels, | |
# ) | |
# train_dataloader = DataLoader( | |
# dataset=dataset["train"], batch_size=1, collate_fn=Collator, num_workers=0 | |
# ) | |
# eval_dataloader = DataLoader( | |
# dataset=dataset["test"], batch_size=1, collate_fn=Collator, num_workers=0 | |
# ) | |
# optimizer = torch.optim.AdamW( | |
# model.parameters(), lr=5e-4, weight_decay=0.01, eps=1e-6, betas=(0.9, 0.98) | |
# ) | |
# scheduler = torch.optim.lr_scheduler.LinearLR( | |
# optimizer, start_factor=0.25, total_iters=10000, last_epoch=-1 | |
# ) | |
# loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100) | |
# train_and_evaluate( | |
# model=model, | |
# tokenizer=tokenizer, | |
# train_loader=train_dataloader, | |
# eval_loader=eval_dataloader, | |
# optimizer=optimizer, | |
# scheduler=scheduler, | |
# loss_fn=loss_fn, | |
# max_steps=10000, | |
# device="cuda", | |
# accumulation_steps=1, | |
# clear_cache=False, | |
# log_interval=10, | |
# eval_interval=500, | |
# save_interval=10000, | |
# checkpoint_dir=checkpoint_dir, | |
# log_dir=log_dir, | |
# ) | |
# def train_and_evaluate( | |
# model, tokenizer, train_loader, eval_loader, optimizer, scheduler, loss_fn, | |
# max_steps=10000, device='cuda', accumulation_steps=1, clear_cache=True, | |
# log_interval=10, eval_interval=100, save_interval=1000, | |
# checkpoint_dir="checkpoint_dir", log_dir="log_dir" | |
# ): | |
# model.to(device) | |
# global_step = 0 | |
# scaler = torch.GradScaler() | |
# writer = SummaryWriter(log_dir=log_dir) | |
# train_iterator = iter(train_loader) | |
# total_loss = 0 | |
# step_in_report = 0 | |
# dataset_epochs = 0 | |
# progress_bar = tqdm(total=max_steps, desc="Training Progress", leave=True, colour='green') | |
# model.train() | |
# optimizer.zero_grad() | |
#     while global_step < max_steps:
#         try:
#             batch = next(train_iterator)
#         except StopIteration:
#             train_iterator = iter(train_loader)
#             batch = next(train_iterator)
#             dataset_epochs += 1
#             print(f"Starting dataset epoch {dataset_epochs}")
#             if step_in_report > 0:
#                 avg_loss = total_loss / step_in_report
#                 logging.info(f"Dataset iteration complete - Steps: {global_step}, Avg Loss: {avg_loss:.4f}")
#                 total_loss = 0
#                 step_in_report = 0
#         start_time = time.time()
#         batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
#         with torch.autocast(device_type="cuda"):
#             output = model(**batch)  # nn.Module is always callable; the original hasattr(model, '__call__') branch was dead code
#             logits = output["logits"] if isinstance(output, dict) and "logits" in output else output
#             labels = batch["labels"]
#             active_logits = logits.view(-1, logits.size(-1))
#             active_labels = labels.view(-1)
#             active_mask = active_labels != 0  # drop padding positions (pad id 0)
#             active_logits = active_logits[active_mask]
#             active_labels = active_labels[active_mask]
#             loss = loss_fn(active_logits, active_labels)
#         total_loss += loss.item()
#         loss = loss / accumulation_steps
#         scaler.scale(loss).backward()
#         if (global_step + 1) % accumulation_steps == 0:
#             scaler.unscale_(optimizer)
#             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#             scaler.step(optimizer)
#             scaler.update()
#             optimizer.zero_grad()
#             if clear_cache:
#                 torch.cuda.empty_cache()
#         end_time = time.time()
#         samples_per_sec = batch["spectrogram"].size(0) / (end_time - start_time)
#         if global_step % log_interval == 0:
#             # average over steps since the last reset, not since step 0 (the original divided by global_step + 1, which drifts once total_loss is reset at epoch boundaries)
#             writer.add_scalar(tag='Loss/train', scalar_value=total_loss / (step_in_report + 1), global_step=global_step)
#             lr = scheduler.get_last_lr()[0]
#             writer.add_scalar(tag='LearningRate', scalar_value=lr, global_step=global_step)
#             writer.add_scalar(tag='SamplesPerSec', scalar_value=samples_per_sec, global_step=global_step)
#         if global_step % eval_interval == 0:
#             model.eval()
#             eval_start_time = time.time()
#             eval_loss = 0
#             all_predictions = []
#             all_labels = []
#             batch_count = 0
#             total_samples = 0
#             with torch.no_grad():
#                 for eval_batch in eval_loader:
#                     eval_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in eval_batch.items()}
#                     output = model(**eval_batch)
#                     logits = output["logits"] if isinstance(output, dict) and "logits" in output else output
#                     labels = eval_batch["labels"]
#                     batch_size = logits.size(0)
#                     total_samples += batch_size
#                     loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))  # pads handled by ignore_index only; no explicit mask here, unlike training
#                     eval_loss += loss.item()
#                     all_predictions.extend(torch.argmax(logits, dim=-1).cpu().numpy().tolist())
#                     all_labels.extend(labels.cpu().numpy().tolist())
#                     batch_count += 1
#             eval_time = time.time() - eval_start_time
#             loss_avg = eval_loss / batch_count if batch_count > 0 else 0
#             predictions = {"predictions": np.array(all_predictions, dtype=object), "label_ids": np.array(all_labels, dtype=object)}
#             metrics = compute_metrics(pred=predictions, tokenizer=tokenizer)
#             writer.add_scalar('Loss/eval', loss_avg, global_step)
#             writer.add_scalar('WER', metrics['wer'], global_step)
#             writer.add_scalar('EvalSamples', total_samples, global_step)
#             writer.add_scalar('EvalTimeSeconds', eval_time, global_step)
#             lr = scheduler.get_last_lr()[0]
#             print(f"• STEP:{global_step} • samp:{samples_per_sec:.1f} • WER:{metrics['wer']:.2f}% • Loss:{loss_avg:.4f} • LR:{lr:.8f}")
#             logging.info(f"EVALUATION STEP {global_step} - WER: {metrics['wer']:.2f}%, Loss: {loss_avg:.4f}, LR: {lr:.8f}")
#             model.train()
#         if global_step % save_interval == 0:
#             checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_step_{global_step}.pt')
#             torch.save(model.state_dict(), checkpoint_path)
#             logging.info(f"Model saved at step {global_step} to {checkpoint_path}")
#         lr = scheduler.get_last_lr()[0]
#         scheduler.step()
#         global_step += 1
#         step_in_report += 1
#         avg_loss = total_loss / step_in_report  # step_in_report was just incremented, so it is >= 1 here
#         postfix_dict = {
#             'loss': f'{avg_loss:.4f}',
#             'lr': f'{lr:.6f}',
#             'samp': f'{samples_per_sec:.1f}'
#         }
#         progress_bar.set_postfix(postfix_dict, refresh=True)
#         progress_bar.update(1)
#     final_model_path = os.path.join(checkpoint_dir, 'final_model.pt')
#     torch.save(model.state_dict(), final_model_path)
#     print(f"Training completed after {global_step} steps. Final model saved to {final_model_path}")
#     writer.close()
#     progress_bar.close()
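# The AMP + gradient-accumulation pattern above, reduced to a self-contained
# toy (a sketch, not the training loop itself; the Linear model and random
# batches are stand-ins, and it assumes a CUDA device and PyTorch >= 2.3):
# import torch
# model = torch.nn.Linear(8, 2).cuda()
# opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
# scaler = torch.amp.GradScaler("cuda")
# accum = 4
# for step in range(8):
#     x = torch.randn(16, 8, device="cuda")
#     y = torch.randint(0, 2, (16,), device="cuda")
#     with torch.autocast(device_type="cuda"):
#         loss = torch.nn.functional.cross_entropy(model(x), y) / accum
#     scaler.scale(loss).backward()
#     if (step + 1) % accum == 0:
#         scaler.unscale_(opt)                                   # unscale before clipping
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#         scaler.step(opt)
#         scaler.update()
#         opt.zero_grad()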
# def get_optimizer(model, lr=5e-4, weight_decay=0.01):
#     return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, eps=1e-6, betas=(0.9, 0.98))
# def get_scheduler(optimizer, total_steps=10000):
#     return torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.25, total_iters=total_steps, last_epoch=-1)
# def get_loss_fn():
#     return torch.nn.CrossEntropyLoss(ignore_index=0)
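# Note: this variant ignores label id 0 directly (the pad id in the earlier
# Dimensions block), whereas the loop above used ignore_index=-100 plus an
# explicit mask; the two conventions are interchangeable if the collator pads
# consistently.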
# def mainc():
#     # (assumes elsewhere in the full script: from datetime import datetime;
#     #  setup_tokenizer, prepare_datasets, create_model, DataCollator)
#     token = ""
#     log_dir = os.path.join('./output/logs', datetime.now().strftime('%m-%d_%H_%M_%S'))
#     os.makedirs(log_dir, exist_ok=True)
#     tokenizer = setup_tokenizer(token)
#     param = Dimensions(
#         mels=128, aud_ctx=1500, aud_head=4, aud_dims=512, aud_idx=4,
#         vocab=40000, text_ctx=512, text_head=4, text_dims=512, text_idx=4,
#         act="swish", debug={}, cross_attn=True, features=["spectrogram"]
#     )
#     dataset_config = {
#         "spectrogram": True, "waveforms": False, "pitch": False, "downsamples": False,
#         "frequency": True, "hilbert": False, "hop_length": 128, "fmin": 150, "fmax": 2000,
#         "n_mels": 128, "n_fft": 1024, "sampling_rate": 16000, "pad_mode": "constant",
#         "center": True, "power": 2.0, "window_fn": torch.hann_window, "mel_scale": "htk",
#         "norm": None, "normalized": False
#     }
#     model = create_model(param)
#     train_dataset, test_dataset = prepare_datasets(
#         tokenizer=tokenizer, token=token, sanity_check=False, dataset_config=dataset_config
#     )
#     collator = DataCollator(tokenizer=tokenizer)
#     train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=collator, num_workers=0)
#     eval_loader = DataLoader(test_dataset, batch_size=1, collate_fn=collator, num_workers=0)
#     optimizer = get_optimizer(model)
#     scheduler = get_scheduler(optimizer)
#     loss_fn = get_loss_fn()
#     train_and_evaluate(
#         model=model,
#         tokenizer=tokenizer,
#         train_loader=train_loader,
#         eval_loader=eval_loader,
#         optimizer=optimizer,
#         scheduler=scheduler,
#         loss_fn=loss_fn,
#         max_steps=10000,
#         device='cuda',
#         accumulation_steps=1,
#         clear_cache=False,
#         log_interval=10,
#         eval_interval=500,
#         save_interval=10000,
#         checkpoint_dir="./checkpoints",
#         log_dir=log_dir
#     )
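# Entry point for this scratch variant (kept commented out with the rest of
# the block):
# if __name__ == "__main__":
#     mainc()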
# class attention(nn.Module):
#     def __init__(self, dims: int, head: int):
#         super().__init__()
#         self.dims = dims
#         self.head = head
#         self.head_dim = dims // head
#         self.q = nn.Linear(dims, dims)
#         self.k = nn.Linear(dims, dims, bias=False)
#         self.v = nn.Linear(dims, dims)
#         self.o = nn.Linear(dims, dims)
#         self.lna = nn.LayerNorm(dims, bias=False)
#         self.lnb = nn.LayerNorm(dims, bias=False)
#         self.lnc = nn.LayerNorm(self.head_dim, bias=False)
#         self.lnd = nn.LayerNorm(self.head_dim, bias=False)
#         # attributes the iterative _forward below relies on; the original
#         # __init__ omitted them (the values here are placeholder assumptions)
#         self.threshold = nn.Parameter(torch.tensor(0.01), requires_grad=False)
#         self.s_factor = nn.Parameter(torch.tensor(0.1), requires_grad=False)
#         self.temp_scale = 0.5
#         self.sharpen = True
#     def _forward(self, x: Tensor, xa=None, mask=None):
#         # Iterative-refinement attention: re-attend with the query nudged by
#         # the previous output until the output stops changing (or until
#         # max_iterations). The original body referenced projections from a
#         # different module (self.attn_local.*); this reconstruction uses the
#         # class's own q/k/v heads instead.
#         z = x if xa is None else xa  # guard: the original indexed xa even when it was None
#         q = self.q(self.lna(x))
#         k = self.k(self.lnb(z))
#         v = self.v(self.lnb(z))
#         query = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)  # (B, H, Sq, Dh)
#         key = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)    # (B, H, Sk, Dh)
#         value = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
#         max_iterations = 5
#         threshold = self.threshold.item()
#         s_factor = self.s_factor.item()
#         prev_attn_out = torch.zeros_like(query)
#         attn_out = torch.zeros_like(query)
#         q_current = query
#         for iteration in range(max_iterations):
#             # after the permute, sequence length lives on dim 2, not dim 1
#             eff_span = min(q_current.size(2), key.size(2))
#             if eff_span == 0:
#                 break
#             q_iter = q_current[:, :, :eff_span, :]
#             k_iter = key[:, :, :eff_span, :]
#             v_iter = value[:, :, :eff_span, :]
#             temperature = (1.0 + self.temp_scale * (1.0 - z.mean())
#                            if self.sharpen
#                            else 0.5 + self.temp_scale * z.mean())
#             temperature = max(temperature, 1e-3)
#             iter_mask = None
#             if mask is not None:
#                 if mask.dim() == 4:
#                     iter_mask = mask[:, :, :eff_span, :eff_span]
#                 elif mask.dim() == 2:
#                     iter_mask = mask[:eff_span, :eff_span]
#             scale = q_iter.shape[-1] ** -0.25
#             scores = (q_iter * scale) @ (k_iter * scale).transpose(-1, -2) / temperature
#             if iter_mask is not None:
#                 scores = scores + iter_mask
#             out_span = F.softmax(scores, dim=-1) @ v_iter  # (B, H, eff_span, Dh)
#             current_iter_out = torch.zeros_like(q_current)
#             current_iter_out[:, :, :eff_span, :] = out_span
#             diff = torch.abs(current_iter_out - prev_attn_out).mean()
#             dynamic_threshold = threshold + s_factor * diff
#             attn_out = current_iter_out
#             if diff < dynamic_threshold and iteration > 0:
#                 break
#             prev_attn_out = current_iter_out.clone()
#             q_current = q_current + current_iter_out
#         out = attn_out.permute(0, 2, 1, 3).flatten(2)  # merge heads -> (B, Sq, dims)
#         return self.o(out), None
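#     # Quick shape check for the module above (a sketch; assumes the
#     # placeholder attributes added in __init__):
#     # attn = attention(dims=512, head=4)
#     # x = torch.randn(2, 100, 512)
#     # out, _ = attn._forward(x)
#     # print(out.shape)  # torch.Size([2, 100, 512])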
#     def _slide_win_local(self, x: Tensor, win_size: int, span_len: int,
#                          span_scale: Tensor, mask: Optional[Tensor] = None,
#                          is_causal: bool = False) -> Tensor:
#         """
#         Process input with sliding-window attention, using `_focus` for each window.
#         Args:
#             x: Input tensor (Batch, SeqLen, Dims).
#             win_size: Size of the attention window for queries.
#             span_len: Max length of keys/values, measured back from the query window end.
#             span_scale: Span scale tensor (Batch, 1 or scalar) passed to `_focus`.
#             mask: Full attention mask.
#             is_causal: Apply causal masking within windows.
#         Returns:
#             Output tensor (Batch, SeqLen, Dims).
#         """
#         batch, ctx, dims = x.size()
#         output = torch.zeros_like(x)
#         num_windows = (ctx + win_size - 1) // win_size
#         for i in range(num_windows):
#             q_start = i * win_size
#             q_end = min(q_start + win_size, ctx)
#             if q_end - q_start == 0:
#                 continue
#             kv_start = max(0, q_end - span_len)
#             kv_end = q_end
#             query_win = x[:, q_start:q_end, :]
#             key_win = x[:, kv_start:kv_end, :]
#             value_win = x[:, kv_start:kv_end, :]
#             window_mask = None
#             if mask is not None:
#                 if mask.dim() == 4:
#                     window_mask = mask[:, :, q_start:q_end, kv_start:kv_end]
#                 elif mask.dim() == 2:
#                     window_mask = mask[q_start:q_end, kv_start:kv_end]
#             attn_out_win, _ = self._focus(
#                 query=query_win,
#                 key=key_win,
#                 value=value_win,
#                 span_scale=span_scale,
#                 mask=window_mask,
#                 is_causal=is_causal,
#             )
#             output[:, q_start:q_end, :] = attn_out_win
#         return output
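#     # A concrete trace of the window arithmetic above (hypothetical sizes,
#     # kept commented out like the rest of this block):
#     # ctx, win_size, span_len = 10, 4, 6
#     # for i in range((ctx + win_size - 1) // win_size):
#     #     q_start, q_end = i * win_size, min(i * win_size + win_size, ctx)
#     #     print((q_start, q_end), (max(0, q_end - span_len), q_end))
#     # # -> (0, 4) (0, 4) / (4, 8) (2, 8) / (8, 10) (4, 10)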