Spaces:
Sleeping
Sleeping
""" | |
TODO: train a linear probe | |
usage: | |
python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_gtzan /path/to/gtzan/genres_original --output_dir /path/to/output | |
""" | |
from pathlib import Path | |
from typing import List | |
import audiotools as at | |
from audiotools import AudioSignal | |
import argbind | |
import torch | |
import numpy as np | |
import zipfile | |
import json | |
from vampnet.interface import Interface | |
import tqdm | |
# bind the Interface to argbind
# (exposes Interface's constructor kwargs as CLI/yml options, e.g. --Interface.device)
Interface = argbind.bind(Interface)

# when True, main() prints a line for every cache hit while embedding
DEBUG = False
def smart_plotly_export(fig, save_path):
    """Export a plotly figure, dispatching on the extension of ``save_path``.

    parameters:
        fig: a plotly figure
        save_path (str): destination path; the extension picks the format:
            'html'  -> interactive page written to disk
            'bytes' -> returns PNG bytes, writes nothing
            'numpy' -> returns the rendered figure as an array, writes nothing
            'jpeg' / 'png' / 'webp' -> static image written to disk
    returns:
        bytes for 'bytes', np.ndarray for 'numpy', otherwise None.
    raises:
        ValueError: if the extension is none of the formats above.
    """
    img_format = save_path.split('.')[-1]

    if img_format == 'html':
        fig.write_html(save_path)
    elif img_format == 'bytes':
        return fig.to_image(format='png')
    # TODO: come back and make this prettier
    elif img_format == 'numpy':
        import io
        from PIL import Image

        def plotly_fig2array(fig):
            # render to PNG bytes, then decode into a numpy array
            fig_bytes = fig.to_image(format="png", width=1200, height=700)
            buf = io.BytesIO(fig_bytes)
            img = Image.open(buf)
            return np.asarray(img)

        return plotly_fig2array(fig)
    elif img_format in ('jpeg', 'png', 'webp'):
        # BUG FIX: was `img_format == 'jpeg' or 'png' or 'webp'`, which is
        # always truthy ('png' is a non-empty string), so every unknown
        # extension fell into write_image and the ValueError was unreachable.
        fig.write_image(save_path)
    else:
        raise ValueError("invalid image format")
def dim_reduce(emb, labels, save_path, n_components=3, method='tsne', title=''):
    """
    dimensionality reduction for visualization!
    saves an html plotly figure to save_path

    parameters:
        emb (np.ndarray): the samples to be reduced with shape (samples, features)
        labels (list): list of labels for embedding
        save_path (str): path where u wanna save ur figure
        n_components (int): 2 or 3 output dimensions
        method (str): umap, tsne, or pca
        title (str): title for ur figure
    returns:
        whatever smart_plotly_export returns for save_path's extension
        (None for '.html', which is how main() calls this).
    raises:
        ValueError: for an unknown method or n_components not in (2, 3).
    """
    import pandas as pd
    import plotly.express as px

    if method == 'umap':
        from umap import UMAP
        # BUG FIX: was `umap.UMAP(...)` — a NameError, since only the UMAP
        # class (not the umap module) was imported.
        reducer = UMAP(n_components=n_components)
    elif method == 'tsne':
        from sklearn.manifold import TSNE
        reducer = TSNE(n_components=n_components)
    elif method == 'pca':
        from sklearn.decomposition import PCA
        reducer = PCA(n_components=n_components)
    else:
        # BUG FIX: was a bare `raise ValueError` with no message
        raise ValueError(f"unknown method: {method}")

    proj = reducer.fit_transform(emb)

    if n_components == 2:
        df = pd.DataFrame(dict(
            x=proj[:, 0],
            y=proj[:, 1],
            instrument=labels
        ))
        # NOTE: only the 2d path appends the method to the title (kept as-is)
        fig = px.scatter(df, x='x', y='y', color='instrument',
                         title=title + f"_{method}")
    elif n_components == 3:
        df = pd.DataFrame(dict(
            x=proj[:, 0],
            y=proj[:, 1],
            z=proj[:, 2],
            instrument=labels
        ))
        fig = px.scatter_3d(df, x='x', y='y', z='z',
                            color='instrument',
                            title=title)
    else:
        raise ValueError("cant plot more than 3 components")

    # outline the markers so overlapping points stay distinguishable
    fig.update_traces(marker=dict(size=6,
                                  line=dict(width=1,
                                            color='DarkSlateGrey')),
                      selector=dict(mode='markers'))
    return smart_plotly_export(fig, save_path)
# per JukeMIR, we want the embeddings from the middle layer?
def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
    """Per-layer, time-pooled activations of the coarse vampnet model.

    parameters:
        sig (AudioSignal): audio to embed; must have batch size 1.
        interface (Interface): vampnet interface providing codec + coarse model.
        layer (int): only range-checked here — ALL layers are returned and the
            caller picks the one(s) it wants.
    returns:
        torch.Tensor with shape (n_layers, n_dims): activations mean-pooled
        over time.
    """
    with torch.inference_mode():
        # preprocess the signal
        sig = interface.preprocess(sig)

        # get the coarse vampnet model
        vampnet = interface.coarse

        # get the tokens, keeping only the codebooks the coarse model uses
        z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
        z_latents = vampnet.embedding.from_codes(z, interface.codec)

        # do a forward pass through the model, get the activations
        _z, embeddings = vampnet(z_latents, return_activations=True)
        # embeddings: [layer, batch, time, n_dims], e.g. [20, 1, ~600, 768]

        # squeeze the batch dim (dim 1, because layer is dim 0)
        # BUG FIX: the message previously reported shape[0] while the check
        # was on shape[1]; report the dimension actually being checked.
        assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[1]}"
        embeddings = embeddings.squeeze(1)

        num_layers = embeddings.shape[0]
        assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"

        # do meanpooling over the time dimension -> [n_layers, n_dims]
        embeddings = embeddings.mean(dim=-2)

        # return the embeddings
        return embeddings
from dataclasses import dataclass, fields


# BUG FIX: the @dataclass decorator was missing, so Embedding(genre=..., ...)
# raised TypeError (object() takes no args) and fields(self) raised in save().
@dataclass
class Embedding:
    """One GTZAN clip's embedding plus its genre label and source filename.

    Serialized as a zip archive containing 'embedding.npy' (the array) and
    'data.json' (all remaining fields).
    """
    # genre label, taken from the GTZAN folder name
    genre: str
    # source audio filename (basename only)
    filename: str
    # activations, shape (n_layers, n_dims) as produced by vampnet_embed
    embedding: np.ndarray

    def save(self, path):
        """Save the Embedding object to a given path as a zip file."""
        with zipfile.ZipFile(path, 'w') as archive:
            # Save numpy array
            with archive.open('embedding.npy', 'w') as f:
                np.save(f, self.embedding)

            # Save non-numpy data as json
            non_numpy_data = {
                f.name: getattr(self, f.name)
                for f in fields(self) if f.name != 'embedding'
            }
            with archive.open('data.json', 'w') as f:
                f.write(json.dumps(non_numpy_data).encode('utf-8'))

    # BUG FIX: @classmethod was missing, so Embedding.load(path) passed the
    # path as `cls` and crashed.
    @classmethod
    def load(cls, path):
        """Load the Embedding object from a given zip path."""
        with zipfile.ZipFile(path, 'r') as archive:
            # Load numpy array
            with archive.open('embedding.npy') as f:
                embedding = np.load(f)

            # Load non-numpy data from json
            with archive.open('data.json') as f:
                data = json.loads(f.read().decode('utf-8'))

        return cls(embedding=embedding, **data)
def main(
    path_to_gtzan: str = None,
    cache_dir: str = "./.gtzan_emb_cache",
    output_dir: str = "./gtzan_vampnet_embeddings",
    layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
):
    """Embed every GTZAN clip with vampnet, cache the results, and write a
    2d t-SNE html figure per requested layer.

    parameters:
        path_to_gtzan (str): GTZAN root with one subfolder per genre. Required.
        cache_dir (str): directory for per-file Embedding zip caches.
        output_dir (str): where the html figures are written.
        layers (List[int]): activation layers to project and plot. (The
            mutable default is never modified here, so sharing it is safe.)
    raises:
        ValueError: if path_to_gtzan is not provided.
    """
    # BUG FIX: with the default None, Path(None) raised an opaque TypeError;
    # fail with an explicit message instead.
    if path_to_gtzan is None:
        raise ValueError("path_to_gtzan is required (pass --path_to_gtzan)")
    path_to_gtzan = Path(path_to_gtzan)
    assert path_to_gtzan.exists(), f"{path_to_gtzan} does not exist"

    cache_dir = Path(cache_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # load our interface
    # argbind will automatically load the default config,
    interface = Interface()

    # gtzan should have a folder for each genre, so let's get the list of genres
    genres = [Path(x).name for x in path_to_gtzan.iterdir() if x.is_dir()]
    print(f"Found {len(genres)} genres")
    print(f"genres: {genres}")

    # collect audio files, genres, and embeddings
    data = []
    for genre in genres:
        audio_files = list(at.util.find_audio(path_to_gtzan / genre))
        print(f"Found {len(audio_files)} audio files for genre {genre}")

        for audio_file in tqdm.tqdm(audio_files, desc=f"embedding genre {genre}"):
            # check if we have a cached embedding for this file
            cached_path = (cache_dir / f"{genre}_{audio_file.stem}.emb")
            if cached_path.exists():
                # if so, load it
                if DEBUG:
                    print(f"loading cached embedding for {cached_path.stem}")
                embedding = Embedding.load(cached_path)
            else:
                try:
                    sig = AudioSignal(audio_file)
                except Exception as e:
                    # best-effort: skip unreadable files instead of aborting the run
                    print(f"failed to load {audio_file.name} with error {e}")
                    print(f"skipping {audio_file.name}")
                    continue

                # gets the embedding
                emb = vampnet_embed(sig, interface).cpu().numpy()

                # create an embedding we can save/load
                embedding = Embedding(
                    genre=genre,
                    filename=audio_file.name,
                    embedding=emb
                )

                # cache the embeddings
                cached_path.parent.mkdir(exist_ok=True, parents=True)
                embedding.save(cached_path)

            data.append(embedding)

    # now, let's do a dim reduction on the embeddings
    # and visualize them.

    # collect a list of embeddings and labels
    embeddings = [d.embedding for d in data]
    labels = [d.genre for d in data]

    # convert the embeddings to a numpy array: (samples, layers, dims)
    embeddings = np.stack(embeddings)

    # do dimensionality reduction for each layer we're given
    for layer in tqdm.tqdm(layers, desc="dim reduction"):
        dim_reduce(
            embeddings[:, layer, :], labels,
            save_path=str(output_dir / f'vampnet-gtzan-layer={layer}.html'),
            n_components=2, method='tsne',
            title=f'vampnet-gtzan-layer={layer}'
        )
if __name__ == "__main__": | |
args = argbind.parse_args() | |
with argbind.scope(args): | |
main() |