|
import math
import time
from functools import partial

import torch
import torch.nn.functional as F
from torch import Tensor
|
|
|
|
|
def load_what_you_can(checkpoint: dict, model: torch.nn.Module):
    """
    Load as many weights as possible from a checkpoint state dict into `model`:

    - Parameters with matching shapes are copied as-is.
    - Parameters whose shapes differ but share the same rank are partially
      loaded: the overlapping region (the element-wise minimum of the two
      shapes) is copied.
    - Parameters missing from the model, or with a different rank, are skipped.
    """
    model_state_dict = model.state_dict()
    checkpoint_state_dict = checkpoint

    for name, param in checkpoint_state_dict.items():
        if name not in model_state_dict:
            print(f"Ignoring parameter '{name}' because it is not found in the model")
            continue

        model_state = model_state_dict[name]
        mshape = model_state.shape
        pshape = param.shape

        if pshape == mshape:
            model_state.copy_(param)
            continue

        if len(pshape) != len(mshape):
            # Different ranks: there is no meaningful overlap to copy.
            continue

        min_shape = [min(pshape[i], mshape[i]) for i in range(len(pshape))]
        print(name, "model:", mshape, "chkpt:", pshape, "loading:", min_shape)
        # Slice-based indexing returns a view, so copy_ writes into
        # model_state in place. (Advanced indexing with meshgrid index
        # tensors would return a copy and silently drop the update.)
        slices = tuple(slice(0, s) for s in min_shape)
        model_state[slices].copy_(param[slices])

    return model.load_state_dict(model_state_dict)
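# Example (hypothetical checkpoint layout): resume from a checkpoint whose
# vocabulary is smaller than the current model's; overlapping rows are
# loaded, the rest keep their fresh initialization.
#
#   ckpt = torch.load("checkpoint.pt", map_location="cpu", weights_only=True)
#   load_what_you_can(ckpt, model)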
|
|
|
|
|
def multimap(
    items: list, func: callable, workers=4, desc=None, thread=False, chunk_size=128
) -> list:
    """
    Quick-and-dirty parallel map over `items`, dropping any results that are
    None. Uses processes by default, threads if `thread=True`.
    """
    from tqdm.contrib.concurrent import process_map, thread_map

    m = thread_map if thread else process_map
    length = None
    try:
        length = len(items)
    except Exception as e:
        print(f"Could not determine length of items: {e}")

    results = m(
        func,
        items,
        leave=False,
        desc=desc,
        max_workers=workers,
        total=length,
        chunksize=chunk_size,
    )
    return list(filter(lambda x: x is not None, results))
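# Example: with thread=False, `func` goes through pickle, so it must be a
# top-level function (or a functools.partial of one); lambdas are fine for
# threads.
#
#   squares = multimap(list(range(100)), lambda x: x * x, thread=True, desc="squares")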
|
|
|
|
|
def round_up(num: float, factor: int):
    return factor * math.ceil(num / factor)
|
|
|
|
|
def left_padding_mask(lengths, max_len, device=None, dtype=None):
    """Additive causal attention masks for a batch padded on the left."""
    masks = []
    if not max_len:
        max_len = max(lengths)
    for l in lengths:
        # Causal mask: -inf strictly above the diagonal, 0 on and below it.
        mask = torch.empty(l, l, device=device, dtype=dtype).fill_(-torch.inf).triu_(1)
        # Pad rows/columns on the left and top with -inf so the padded
        # positions are never attended to.
        diff = max_len - l
        mask = F.pad(mask, (diff, 0, diff, 0), value=-torch.inf)
        masks.append(mask)

    # (batch, 1, max_len, max_len), broadcastable over attention heads.
    masks = torch.stack(masks)
    return masks[:, None]
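# Example: lengths [3, 5] padded to max_len=5 give a (2, 1, 5, 5) additive
# mask ready to be added to attention scores.
#
#   mask = left_padding_mask([3, 5], max_len=5)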
|
|
|
|
|
def seed_all(seed: int):
    import random

    import numpy as np

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
|
|
|
|
|
def split_bucket_path(url: str) -> tuple[str, str]:
    # Strip any supported scheme prefix, then split "<bucket>/<key>".
    url = url.replace("s3://", "")
    url = url.replace("sj://", "")
    url = url.replace("r2://", "")
    bucket = url.split("/")[0]
    path = "/".join(url.split("/")[1:])
    return bucket, path
|
|
|
|
|
def prob_mask_like(shape, prob: float, device):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
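# Example: a (4, 128) boolean mask where each element is independently True
# with probability 0.3 (useful e.g. for classifier-free-guidance dropout).
#
#   keep = prob_mask_like((4, 128), 0.3, device="cpu")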
|
|
|
|
|
def round_up_to_multiple(n: int, multiple: int) -> int:
    if n % multiple != 0:
        n += multiple - (n % multiple)

    return n
|
|
|
|
|
def warmup_then_cosine_decay(
    step: int, *, warmup_steps: int, steps: int, min_lr: float, max_lr: float
):
    """Linear warmup, cosine decay, then a linear cooldown towards zero over
    the final `warmup_steps` steps."""
    eps = 1e-9
    cooldown_steps = warmup_steps
    if step < warmup_steps:
        # Linear warmup from min_lr to max_lr.
        return min_lr + step * (max_lr - min_lr) / warmup_steps
    elif step > steps:
        return min_lr
    elif step < steps - cooldown_steps:
        # Cosine decay from max_lr back down to min_lr.
        decay_ratio = (step - warmup_steps) / (steps - warmup_steps - cooldown_steps)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (max_lr - min_lr)
    else:
        # Final cooldown: linear decay from min_lr towards (almost) zero.
        return min_lr * (steps - step) / cooldown_steps + eps
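# Example (hypothetical hyperparameters): schedule for a 10_000-step run.
#
#   lr = warmup_then_cosine_decay(
#       step, warmup_steps=500, steps=10_000, min_lr=1e-5, max_lr=3e-4
#   )
#   for group in optimizer.param_groups:
#       group["lr"] = lr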
|
|
|
|
|
def decay_to_zero(step: int, *, decay_steps: int, steps: int, max_lr: float):
    if step > steps:
        return 0.0
    else:
        # Linear decay from max_lr at step 0 to zero at decay_steps, clamped
        # so the LR never goes negative when decay_steps < steps.
        gradient = -max_lr / decay_steps
        return max(0.0, max_lr + gradient * step)
|
|
|
|
|
def cross_entropy_loss(logits, mask, targets):
    """
    Per-codebook cross entropy, averaged over the Q codebooks.

    logits: (B, Q, T, vocab); targets/mask: (B, Q, T). `mask` selects which
    positions contribute to the loss.
    """
    B, Q, T, _ = logits.size()
    assert logits.shape[:-1] == targets.shape
    assert mask.shape == targets.shape
    loss = torch.zeros([], device=targets.device)
    codebook_losses = []
    for q in range(Q):
        logits_q = logits[:, q, ...].contiguous().view(-1, logits.size(-1))
        targets_q = targets[:, q, ...].contiguous().view(-1)
        # Cast to bool so this is boolean-mask indexing, not index selection.
        mask_q = mask[:, q, ...].contiguous().view(-1).bool()
        ce_targets = targets_q[mask_q]
        ce_logits = logits_q[mask_q]
        q_ce = F.cross_entropy(ce_logits, ce_targets)
        loss += q_ce
        codebook_losses.append(q_ce.detach())

    loss = loss / Q
    return loss, codebook_losses
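# Example shapes (hypothetical): B=2, Q=4 codebooks, T=100 frames, and a
# 1024-entry codebook vocabulary.
#
#   logits = torch.randn(2, 4, 100, 1024)
#   targets = torch.randint(0, 1024, (2, 4, 100))
#   mask = torch.ones(2, 4, 100, dtype=torch.bool)
#   loss, per_codebook = cross_entropy_loss(logits, mask, targets)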
|
|
|
|
|
def build_optimizer(
    module, *, weight_decay: float, lr: float, betas: tuple[float, float]
):
    param_dict = {pn: p for pn, p in module.named_parameters() if p.requires_grad}

    # Apply weight decay to every tensor with 2+ dims (matmul weights,
    # embeddings); leave biases and norm parameters undecayed.
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ]

    # fused=True expects the parameters to live on CUDA.
    optimizer = torch.optim.AdamW(optim_groups, lr=lr, betas=betas, fused=True)

    return optimizer
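# Example (hypothetical hyperparameters):
#
#   opt = build_optimizer(model, weight_decay=0.1, lr=3e-4, betas=(0.9, 0.95))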
|
|
|
|
|
def pad_or_cut_right(t: Tensor, padlen: int, value=0) -> Tensor:
    """Pad (with `value`) or truncate the last dimension of `t` to `padlen`."""
    current_len = t.shape[-1]

    if current_len == padlen:
        return t

    if current_len < padlen:
        pad_size = (0, padlen - current_len)
        return F.pad(t, pad_size, value=value)

    # Truncate along the last dimension, matching the axis used for padding
    # above (t[:padlen] would cut along dim 0 instead).
    return t[..., :padlen]
|
|
|
|
|
def pad_or_cut_left(t: Tensor, length: int) -> Tensor:
    """Pad (on the left) or truncate the first dimension of `t` to `length`."""
    dims = t.ndim
    current_len = t.shape[0]

    if current_len == length:
        return t

    if current_len < length:
        # F.pad's tuple is ordered from the last dimension backwards, so the
        # (left, right) pair for dim 0 goes at the end.
        pad_size = (0,) * (2 * (dims - 1)) + (length - current_len, 0)
        return F.pad(t, pad_size)

    # Keep the rightmost `length` entries.
    return t[-length:]
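# Examples:
#
#   pad_or_cut_right(torch.ones(2, 3), 5).shape  # (2, 5), zero-padded
#   pad_or_cut_left(torch.ones(6, 3), 4).shape   # (4, 3), keeps last 4 rows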
|
|
|
|
|
def dl_pt(orig: str):
    import io
    from os.path import exists

    from vui.storage import s3

    if not orig.endswith(".pt"):
        orig = orig + ".pt"

    load = partial(torch.load, weights_only=True)
    if exists(orig):
        return load(orig)
    url = "/data/" + orig

    if exists(url):
        return load(url)
    url = "s3://fluxions/" + orig

    bucket, key = split_bucket_path(url)
    response = s3.get_object(Bucket=bucket, Key=key)
    # torch.load needs a seekable file-like object, which the raw botocore
    # StreamingBody is not, so buffer it in memory first.
    return load(io.BytesIO(response["Body"].read()))
|
|
|
|
|
def dl_ogg(url: str, start=0, end=-1, sr=None):
    import io
    import re
    from os.path import exists

    import soundfile as sf

    # Paths like "22050/foo.ogg" encode the sample rate in the first
    # directory component.
    search_sr = re.search(r"(\d+)/", url)
    if search_sr:
        sr = int(search_sr.group(1))

    local_file = exists(url)

    if exists("/data/audio/" + url):
        local_file = True
        url = "/data/audio/" + url

    if not local_file:
        from vui.storage import s3

        url = "s3://fluxions/" + url
        b, p = split_bucket_path(url)
        # soundfile needs a seekable file-like object, so buffer the S3 body
        # in memory rather than reading the stream directly.
        url = io.BytesIO(s3.get_object(Bucket=b, Key=p)["Body"].read())

    if sr is None:
        sr = sf.info(url).samplerate
        if not local_file:
            url.seek(0)  # rewind after sf.info consumed the header

    start_frame = int(start * sr)
    # With the default end=-1 this is negative, which soundfile treats as
    # "read to the end of the file".
    num_frames = int(end * sr) - start_frame
    wav, _ = sf.read(url, frames=num_frames, start=start_frame, always_2d=True)
    wav = torch.from_numpy(wav).float()
    # (frames, channels) -> mono (1, frames) by averaging channels.
    wav = wav.T.mean(0, keepdim=True)
    return wav, sr
|
|
|
|
|
class timer:
    def __init__(self, name=""):
        self.name = name

    def __enter__(self):
        self.t = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.perf_counter() - self.t
        print(f"{self.name} {elapsed:.4f}s")
|
|
|
|
|
@torch.inference_mode()
def decode_audio_from_indices(model, indices, chunk_size=64):
    """
    Decodes audio from codec indices in chunks to avoid memory issues.

    Args:
        model: Codec
        indices: Tensor of shape (1, n_quantizers, sequence_length)
        chunk_size: Number of codebook frames to decode at once

    Returns:
        Tensor of reconstructed audio
    """
    device = model.device
    indices = indices.to(device)
    _, _, seq_len = indices.shape
    chunks = seq_len // chunk_size + (1 if seq_len % chunk_size != 0 else 0)

    audio_chunks = []
    for i in range(chunks):
        start_idx = i * chunk_size
        end_idx = min(start_idx + chunk_size, seq_len)
        chunk_indices = indices[:, :, start_idx:end_idx]
        chunk_audio = model.from_indices(chunk_indices)
        # Move each chunk to CPU immediately to keep GPU memory flat.
        audio_chunks.append(chunk_audio.cpu())

    full_audio = torch.cat(audio_chunks, dim=-1)
    return full_audio.flatten()
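# Example (hypothetical codec and indices): decode (1, 4, 2048) indices in
# 64-frame chunks.
#
#   audio = decode_audio_from_indices(codec, indices, chunk_size=64)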
|
|
|
|
|
def normalize_loudness(waveform, sample_rate: int, lufs: float = -12.0):
    """
    Normalize the loudness of an audio tensor using torchaudio.transforms.Loudness.

    Args:
        waveform (torch.Tensor): Input audio tensor of shape (channels, samples)
        sample_rate (int): Sampling rate of the audio
        lufs (float): Target loudness in LUFS (default: -12.0 LUFS)

    Returns:
        torch.Tensor: Loudness-normalized audio tensor
    """
    import torchaudio

    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)

    loudness_transform = torchaudio.transforms.Loudness(sample_rate)
    current_loudness = loudness_transform(waveform)

    # Convert the dB gain needed to reach the target into a linear factor.
    gain_db = lufs - current_loudness
    gain_linear = torch.pow(10, gain_db / 20)

    normalized_audio = waveform * gain_linear

    return normalized_audio
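# Example ("22050/example.ogg" is a hypothetical path):
#
#   wav, sr = dl_ogg("22050/example.ogg")
#   wav = normalize_loudness(wav, sr, lufs=-12.0)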
|
|
|
|
|
def get_basename_without_extension(file_path):
    from pathlib import Path

    p = Path(file_path)
    return p.stem
|
|
|
|
|
def ollama(prompt, MODEL=None):
    import os

    import requests

    # Respect OLLAMA_HOST if set (Ollama's standard env var).
    OLLAMA_HOST = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
    API = f"{OLLAMA_HOST}/api/generate"

    if MODEL is None:
        MODEL = os.environ.get("OLLAMA_MODEL", "gemma:1b")

    payload = {
        "model": MODEL,
        "prompt": prompt,
        "stream": False,
        # Ollama's generate options use num_predict (not max_tokens) to cap
        # the number of generated tokens.
        "options": {"temperature": 0.9, "top_p": 0.9, "num_predict": 1000},
    }

    try:
        response = requests.post(API, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
        return result.get("response", "")
    except requests.exceptions.RequestException as e:
        print(f"Error calling Ollama API: {e}")
        return ""
|
|
|
|
|
def decompile_state_dict(state_dict):
    # Strip the prefixes added by torch.compile ("_orig_mod.") and
    # DistributedDataParallel ("module.") so the weights load into a plain,
    # uncompiled module.
    state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

    return {k.replace("module.", ""): v for k, v in state_dict.items()}
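# Example (the "model" key is hypothetical):
#
#   ckpt = torch.load("checkpoint.pt", map_location="cpu", weights_only=True)
#   model.load_state_dict(decompile_state_dict(ckpt["model"]))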
|
|