diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..533264f69d41465d4f322ed6e7fadac8ee60a16d --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.venv +.env +.cache +__pycache__ +data/audio/*.wav \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..b54f9f40e9105b3d97c2dce2dbb66759d293af12 --- /dev/null +++ b/app.py @@ -0,0 +1,52 @@ +import streamlit as st +from streamlit import session_state as session +from src.config.configs import ProjectPaths +import numpy as np +from src.laion_clap.inference import AudioEncoder + + +@st.cache_data(persist=True, show_spinner=False) +def load_data(): + vectors = np.load(ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.npy")) + return vectors + + +recommender = AudioEncoder() +audio_vectors = load_data() + +dataframe = None + +st.title(""" +Curate me a Playlist. + """) + +st.text("") +st.text("") +st.text("") +st.text("") + +session.text_input = st.text_input(label="Describe a playlist") + +st.text("") +st.text("") + +session.slider_count = st.slider(label="Track count", min_value=5, max_value=50) + +st.text("") +st.text("") + +buffer1, col1, buffer2 = st.columns([1.45, 1, 1]) + +is_clicked = col1.button(label="Curate") + +if is_clicked: + text_embed = recommender.get_text_embedding(session.text_input) + + +st.text("") +st.text("") +st.text("") +st.text("") + +if dataframe is not None: + st.table(dataframe) \ No newline at end of file diff --git a/data/.DS_Store b/data/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..add7523c63fc1770d3b377991b1f9c2bb3e70725 Binary files /dev/null and b/data/.DS_Store differ diff --git a/data/audio/.gitkeep b/data/audio/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data/json/saved_tracks.json b/data/json/saved_tracks.json new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data/vectors/audio_representations.npy b/data/vectors/audio_representations.npy new file mode 100644 index 0000000000000000000000000000000000000000..84544edd82ab9a414b723384bc1e89ab995038e7 --- /dev/null +++ b/data/vectors/audio_representations.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe4a3ff8cfd2a6b13407352868f3f74fb290ebc11e8473e7132dd4bf947108da +size 1290368 diff --git a/model_checkpoints/.gitkeep b/model_checkpoints/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_checkpoints/music_audioset_epoch_15_esc_90.14.pt b/model_checkpoints/music_audioset_epoch_15_esc_90.14.pt new file mode 100644 index 0000000000000000000000000000000000000000..09274ba1b6f219de82e4265777848c7a41747e9e --- /dev/null +++ b/model_checkpoints/music_audioset_epoch_15_esc_90.14.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fae3e9c087f2909c28a09dc31c8dfcdacbc42ba44c70e972b58c1bd1caf6dedd +size 2352471003 diff --git a/notebooks/notebook.ipynb b/notebooks/notebook.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..ae577ada192cbc7241641753552e272c2cf98f27 --- /dev/null +++ b/notebooks/notebook.ipynb @@ -0,0 +1,788 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": 
[ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import librosa\n", + "import torch\n", + "from src import laion_clap\n", + "from glob import glob\n", + "import pandas as pd\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight']\n", + "- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Load the specified checkpoint music_audioset_epoch_15_esc_90.14.pt from users.\n", + "Load Checkpoint...\n", + "logit_scale_a \t Loaded\n", + "logit_scale_t \t Loaded\n", + "audio_branch.spectrogram_extractor.stft.conv_real.weight \t Loaded\n", + "audio_branch.spectrogram_extractor.stft.conv_imag.weight \t Loaded\n", + "audio_branch.logmel_extractor.melW \t Loaded\n", + "audio_branch.bn0.weight \t Loaded\n", + "audio_branch.bn0.bias \t Loaded\n", + "audio_branch.patch_embed.proj.weight \t Loaded\n", + "audio_branch.patch_embed.proj.bias \t Loaded\n", + "audio_branch.patch_embed.norm.weight \t Loaded\n", + "audio_branch.patch_embed.norm.bias \t Loaded\n", + "audio_branch.layers.0.blocks.0.norm1.weight \t Loaded\n", + "audio_branch.layers.0.blocks.0.norm1.bias \t Loaded\n", + "audio_branch.layers.0.blocks.0.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.0.blocks.0.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.0.blocks.0.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.0.blocks.0.attn.proj.weight \t Loaded\n", + "audio_branch.layers.0.blocks.0.attn.proj.bias \t Loaded\n", + "audio_branch.layers.0.blocks.0.norm2.weight \t Loaded\n", + "audio_branch.layers.0.blocks.0.norm2.bias \t Loaded\n", + "audio_branch.layers.0.blocks.0.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.0.blocks.0.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.0.blocks.0.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.0.blocks.0.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.0.blocks.1.norm1.weight \t Loaded\n", + "audio_branch.layers.0.blocks.1.norm1.bias \t Loaded\n", + "audio_branch.layers.0.blocks.1.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.0.blocks.1.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.0.blocks.1.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.0.blocks.1.attn.proj.weight \t Loaded\n", 
+ "audio_branch.layers.0.blocks.1.attn.proj.bias \t Loaded\n", + "audio_branch.layers.0.blocks.1.norm2.weight \t Loaded\n", + "audio_branch.layers.0.blocks.1.norm2.bias \t Loaded\n", + "audio_branch.layers.0.blocks.1.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.0.blocks.1.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.0.blocks.1.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.0.blocks.1.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.0.downsample.reduction.weight \t Loaded\n", + "audio_branch.layers.0.downsample.norm.weight \t Loaded\n", + "audio_branch.layers.0.downsample.norm.bias \t Loaded\n", + "audio_branch.layers.1.blocks.0.norm1.weight \t Loaded\n", + "audio_branch.layers.1.blocks.0.norm1.bias \t Loaded\n", + "audio_branch.layers.1.blocks.0.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.1.blocks.0.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.1.blocks.0.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.1.blocks.0.attn.proj.weight \t Loaded\n", + "audio_branch.layers.1.blocks.0.attn.proj.bias \t Loaded\n", + "audio_branch.layers.1.blocks.0.norm2.weight \t Loaded\n", + "audio_branch.layers.1.blocks.0.norm2.bias \t Loaded\n", + "audio_branch.layers.1.blocks.0.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.1.blocks.0.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.1.blocks.0.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.1.blocks.0.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.1.blocks.1.norm1.weight \t Loaded\n", + "audio_branch.layers.1.blocks.1.norm1.bias \t Loaded\n", + "audio_branch.layers.1.blocks.1.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.1.blocks.1.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.1.blocks.1.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.1.blocks.1.attn.proj.weight \t Loaded\n", + "audio_branch.layers.1.blocks.1.attn.proj.bias \t Loaded\n", + "audio_branch.layers.1.blocks.1.norm2.weight \t Loaded\n", + "audio_branch.layers.1.blocks.1.norm2.bias \t Loaded\n", + "audio_branch.layers.1.blocks.1.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.1.blocks.1.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.1.blocks.1.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.1.blocks.1.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.1.downsample.reduction.weight \t Loaded\n", + "audio_branch.layers.1.downsample.norm.weight \t Loaded\n", + "audio_branch.layers.1.downsample.norm.bias \t Loaded\n", + "audio_branch.layers.2.blocks.0.norm1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.0.norm1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.0.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.2.blocks.0.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.2.blocks.0.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.2.blocks.0.attn.proj.weight \t Loaded\n", + "audio_branch.layers.2.blocks.0.attn.proj.bias \t Loaded\n", + "audio_branch.layers.2.blocks.0.norm2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.0.norm2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.0.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.0.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.0.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.0.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.1.norm1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.1.norm1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.1.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.2.blocks.1.attn.qkv.weight \t Loaded\n", + 
"audio_branch.layers.2.blocks.1.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.2.blocks.1.attn.proj.weight \t Loaded\n", + "audio_branch.layers.2.blocks.1.attn.proj.bias \t Loaded\n", + "audio_branch.layers.2.blocks.1.norm2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.1.norm2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.1.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.1.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.1.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.1.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.2.norm1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.2.norm1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.2.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.2.blocks.2.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.2.blocks.2.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.2.blocks.2.attn.proj.weight \t Loaded\n", + "audio_branch.layers.2.blocks.2.attn.proj.bias \t Loaded\n", + "audio_branch.layers.2.blocks.2.norm2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.2.norm2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.2.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.2.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.2.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.2.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.3.norm1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.3.norm1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.3.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.2.blocks.3.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.2.blocks.3.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.2.blocks.3.attn.proj.weight \t Loaded\n", + "audio_branch.layers.2.blocks.3.attn.proj.bias \t Loaded\n", + "audio_branch.layers.2.blocks.3.norm2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.3.norm2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.3.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.3.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.3.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.3.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.4.norm1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.4.norm1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.4.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.2.blocks.4.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.2.blocks.4.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.2.blocks.4.attn.proj.weight \t Loaded\n", + "audio_branch.layers.2.blocks.4.attn.proj.bias \t Loaded\n", + "audio_branch.layers.2.blocks.4.norm2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.4.norm2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.4.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.4.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.4.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.4.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.5.norm1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.5.norm1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.5.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.2.blocks.5.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.2.blocks.5.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.2.blocks.5.attn.proj.weight \t Loaded\n", + "audio_branch.layers.2.blocks.5.attn.proj.bias \t Loaded\n", + "audio_branch.layers.2.blocks.5.norm2.weight \t Loaded\n", + 
"audio_branch.layers.2.blocks.5.norm2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.5.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.5.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.5.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.5.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.6.norm1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.6.norm1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.6.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.2.blocks.6.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.2.blocks.6.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.2.blocks.6.attn.proj.weight \t Loaded\n", + "audio_branch.layers.2.blocks.6.attn.proj.bias \t Loaded\n", + "audio_branch.layers.2.blocks.6.norm2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.6.norm2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.6.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.6.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.6.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.6.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.7.norm1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.7.norm1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.7.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.2.blocks.7.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.2.blocks.7.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.2.blocks.7.attn.proj.weight \t Loaded\n", + "audio_branch.layers.2.blocks.7.attn.proj.bias \t Loaded\n", + "audio_branch.layers.2.blocks.7.norm2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.7.norm2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.7.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.7.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.7.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.7.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.8.norm1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.8.norm1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.8.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.2.blocks.8.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.2.blocks.8.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.2.blocks.8.attn.proj.weight \t Loaded\n", + "audio_branch.layers.2.blocks.8.attn.proj.bias \t Loaded\n", + "audio_branch.layers.2.blocks.8.norm2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.8.norm2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.8.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.8.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.8.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.8.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.9.norm1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.9.norm1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.9.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.2.blocks.9.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.2.blocks.9.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.2.blocks.9.attn.proj.weight \t Loaded\n", + "audio_branch.layers.2.blocks.9.attn.proj.bias \t Loaded\n", + "audio_branch.layers.2.blocks.9.norm2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.9.norm2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.9.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.9.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.9.mlp.fc2.weight \t Loaded\n", + 
"audio_branch.layers.2.blocks.9.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.10.norm1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.10.norm1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.10.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.2.blocks.10.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.2.blocks.10.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.2.blocks.10.attn.proj.weight \t Loaded\n", + "audio_branch.layers.2.blocks.10.attn.proj.bias \t Loaded\n", + "audio_branch.layers.2.blocks.10.norm2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.10.norm2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.10.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.10.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.10.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.10.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.11.norm1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.11.norm1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.11.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.2.blocks.11.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.2.blocks.11.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.2.blocks.11.attn.proj.weight \t Loaded\n", + "audio_branch.layers.2.blocks.11.attn.proj.bias \t Loaded\n", + "audio_branch.layers.2.blocks.11.norm2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.11.norm2.bias \t Loaded\n", + "audio_branch.layers.2.blocks.11.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.2.blocks.11.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.2.blocks.11.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.2.blocks.11.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.2.downsample.reduction.weight \t Loaded\n", + "audio_branch.layers.2.downsample.norm.weight \t Loaded\n", + "audio_branch.layers.2.downsample.norm.bias \t Loaded\n", + "audio_branch.layers.3.blocks.0.norm1.weight \t Loaded\n", + "audio_branch.layers.3.blocks.0.norm1.bias \t Loaded\n", + "audio_branch.layers.3.blocks.0.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.3.blocks.0.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.3.blocks.0.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.3.blocks.0.attn.proj.weight \t Loaded\n", + "audio_branch.layers.3.blocks.0.attn.proj.bias \t Loaded\n", + "audio_branch.layers.3.blocks.0.norm2.weight \t Loaded\n", + "audio_branch.layers.3.blocks.0.norm2.bias \t Loaded\n", + "audio_branch.layers.3.blocks.0.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.3.blocks.0.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.3.blocks.0.mlp.fc2.weight \t Loaded\n", + "audio_branch.layers.3.blocks.0.mlp.fc2.bias \t Loaded\n", + "audio_branch.layers.3.blocks.1.norm1.weight \t Loaded\n", + "audio_branch.layers.3.blocks.1.norm1.bias \t Loaded\n", + "audio_branch.layers.3.blocks.1.attn.relative_position_bias_table \t Loaded\n", + "audio_branch.layers.3.blocks.1.attn.qkv.weight \t Loaded\n", + "audio_branch.layers.3.blocks.1.attn.qkv.bias \t Loaded\n", + "audio_branch.layers.3.blocks.1.attn.proj.weight \t Loaded\n", + "audio_branch.layers.3.blocks.1.attn.proj.bias \t Loaded\n", + "audio_branch.layers.3.blocks.1.norm2.weight \t Loaded\n", + "audio_branch.layers.3.blocks.1.norm2.bias \t Loaded\n", + "audio_branch.layers.3.blocks.1.mlp.fc1.weight \t Loaded\n", + "audio_branch.layers.3.blocks.1.mlp.fc1.bias \t Loaded\n", + "audio_branch.layers.3.blocks.1.mlp.fc2.weight \t Loaded\n", + 
"audio_branch.layers.3.blocks.1.mlp.fc2.bias \t Loaded\n", + "audio_branch.norm.weight \t Loaded\n", + "audio_branch.norm.bias \t Loaded\n", + "audio_branch.tscam_conv.weight \t Loaded\n", + "audio_branch.tscam_conv.bias \t Loaded\n", + "audio_branch.head.weight \t Loaded\n", + "audio_branch.head.bias \t Loaded\n", + "text_branch.embeddings.word_embeddings.weight \t Loaded\n", + "text_branch.embeddings.position_embeddings.weight \t Loaded\n", + "text_branch.embeddings.token_type_embeddings.weight \t Loaded\n", + "text_branch.embeddings.LayerNorm.weight \t Loaded\n", + "text_branch.embeddings.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.0.attention.self.query.weight \t Loaded\n", + "text_branch.encoder.layer.0.attention.self.query.bias \t Loaded\n", + "text_branch.encoder.layer.0.attention.self.key.weight \t Loaded\n", + "text_branch.encoder.layer.0.attention.self.key.bias \t Loaded\n", + "text_branch.encoder.layer.0.attention.self.value.weight \t Loaded\n", + "text_branch.encoder.layer.0.attention.self.value.bias \t Loaded\n", + "text_branch.encoder.layer.0.attention.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.0.attention.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.0.attention.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.0.attention.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.0.intermediate.dense.weight \t Loaded\n", + "text_branch.encoder.layer.0.intermediate.dense.bias \t Loaded\n", + "text_branch.encoder.layer.0.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.0.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.0.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.0.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.1.attention.self.query.weight \t Loaded\n", + "text_branch.encoder.layer.1.attention.self.query.bias \t Loaded\n", + "text_branch.encoder.layer.1.attention.self.key.weight \t Loaded\n", + "text_branch.encoder.layer.1.attention.self.key.bias \t Loaded\n", + "text_branch.encoder.layer.1.attention.self.value.weight \t Loaded\n", + "text_branch.encoder.layer.1.attention.self.value.bias \t Loaded\n", + "text_branch.encoder.layer.1.attention.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.1.attention.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.1.attention.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.1.attention.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.1.intermediate.dense.weight \t Loaded\n", + "text_branch.encoder.layer.1.intermediate.dense.bias \t Loaded\n", + "text_branch.encoder.layer.1.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.1.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.1.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.1.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.2.attention.self.query.weight \t Loaded\n", + "text_branch.encoder.layer.2.attention.self.query.bias \t Loaded\n", + "text_branch.encoder.layer.2.attention.self.key.weight \t Loaded\n", + "text_branch.encoder.layer.2.attention.self.key.bias \t Loaded\n", + "text_branch.encoder.layer.2.attention.self.value.weight \t Loaded\n", + "text_branch.encoder.layer.2.attention.self.value.bias \t Loaded\n", + "text_branch.encoder.layer.2.attention.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.2.attention.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.2.attention.output.LayerNorm.weight \t Loaded\n", 
+ "text_branch.encoder.layer.2.attention.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.2.intermediate.dense.weight \t Loaded\n", + "text_branch.encoder.layer.2.intermediate.dense.bias \t Loaded\n", + "text_branch.encoder.layer.2.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.2.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.2.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.2.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.3.attention.self.query.weight \t Loaded\n", + "text_branch.encoder.layer.3.attention.self.query.bias \t Loaded\n", + "text_branch.encoder.layer.3.attention.self.key.weight \t Loaded\n", + "text_branch.encoder.layer.3.attention.self.key.bias \t Loaded\n", + "text_branch.encoder.layer.3.attention.self.value.weight \t Loaded\n", + "text_branch.encoder.layer.3.attention.self.value.bias \t Loaded\n", + "text_branch.encoder.layer.3.attention.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.3.attention.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.3.attention.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.3.attention.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.3.intermediate.dense.weight \t Loaded\n", + "text_branch.encoder.layer.3.intermediate.dense.bias \t Loaded\n", + "text_branch.encoder.layer.3.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.3.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.3.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.3.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.4.attention.self.query.weight \t Loaded\n", + "text_branch.encoder.layer.4.attention.self.query.bias \t Loaded\n", + "text_branch.encoder.layer.4.attention.self.key.weight \t Loaded\n", + "text_branch.encoder.layer.4.attention.self.key.bias \t Loaded\n", + "text_branch.encoder.layer.4.attention.self.value.weight \t Loaded\n", + "text_branch.encoder.layer.4.attention.self.value.bias \t Loaded\n", + "text_branch.encoder.layer.4.attention.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.4.attention.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.4.attention.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.4.attention.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.4.intermediate.dense.weight \t Loaded\n", + "text_branch.encoder.layer.4.intermediate.dense.bias \t Loaded\n", + "text_branch.encoder.layer.4.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.4.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.4.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.4.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.5.attention.self.query.weight \t Loaded\n", + "text_branch.encoder.layer.5.attention.self.query.bias \t Loaded\n", + "text_branch.encoder.layer.5.attention.self.key.weight \t Loaded\n", + "text_branch.encoder.layer.5.attention.self.key.bias \t Loaded\n", + "text_branch.encoder.layer.5.attention.self.value.weight \t Loaded\n", + "text_branch.encoder.layer.5.attention.self.value.bias \t Loaded\n", + "text_branch.encoder.layer.5.attention.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.5.attention.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.5.attention.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.5.attention.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.5.intermediate.dense.weight \t 
Loaded\n", + "text_branch.encoder.layer.5.intermediate.dense.bias \t Loaded\n", + "text_branch.encoder.layer.5.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.5.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.5.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.5.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.6.attention.self.query.weight \t Loaded\n", + "text_branch.encoder.layer.6.attention.self.query.bias \t Loaded\n", + "text_branch.encoder.layer.6.attention.self.key.weight \t Loaded\n", + "text_branch.encoder.layer.6.attention.self.key.bias \t Loaded\n", + "text_branch.encoder.layer.6.attention.self.value.weight \t Loaded\n", + "text_branch.encoder.layer.6.attention.self.value.bias \t Loaded\n", + "text_branch.encoder.layer.6.attention.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.6.attention.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.6.attention.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.6.attention.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.6.intermediate.dense.weight \t Loaded\n", + "text_branch.encoder.layer.6.intermediate.dense.bias \t Loaded\n", + "text_branch.encoder.layer.6.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.6.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.6.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.6.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.7.attention.self.query.weight \t Loaded\n", + "text_branch.encoder.layer.7.attention.self.query.bias \t Loaded\n", + "text_branch.encoder.layer.7.attention.self.key.weight \t Loaded\n", + "text_branch.encoder.layer.7.attention.self.key.bias \t Loaded\n", + "text_branch.encoder.layer.7.attention.self.value.weight \t Loaded\n", + "text_branch.encoder.layer.7.attention.self.value.bias \t Loaded\n", + "text_branch.encoder.layer.7.attention.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.7.attention.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.7.attention.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.7.attention.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.7.intermediate.dense.weight \t Loaded\n", + "text_branch.encoder.layer.7.intermediate.dense.bias \t Loaded\n", + "text_branch.encoder.layer.7.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.7.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.7.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.7.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.8.attention.self.query.weight \t Loaded\n", + "text_branch.encoder.layer.8.attention.self.query.bias \t Loaded\n", + "text_branch.encoder.layer.8.attention.self.key.weight \t Loaded\n", + "text_branch.encoder.layer.8.attention.self.key.bias \t Loaded\n", + "text_branch.encoder.layer.8.attention.self.value.weight \t Loaded\n", + "text_branch.encoder.layer.8.attention.self.value.bias \t Loaded\n", + "text_branch.encoder.layer.8.attention.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.8.attention.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.8.attention.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.8.attention.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.8.intermediate.dense.weight \t Loaded\n", + "text_branch.encoder.layer.8.intermediate.dense.bias \t Loaded\n", + "text_branch.encoder.layer.8.output.dense.weight \t Loaded\n", + 
"text_branch.encoder.layer.8.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.8.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.8.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.9.attention.self.query.weight \t Loaded\n", + "text_branch.encoder.layer.9.attention.self.query.bias \t Loaded\n", + "text_branch.encoder.layer.9.attention.self.key.weight \t Loaded\n", + "text_branch.encoder.layer.9.attention.self.key.bias \t Loaded\n", + "text_branch.encoder.layer.9.attention.self.value.weight \t Loaded\n", + "text_branch.encoder.layer.9.attention.self.value.bias \t Loaded\n", + "text_branch.encoder.layer.9.attention.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.9.attention.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.9.attention.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.9.attention.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.9.intermediate.dense.weight \t Loaded\n", + "text_branch.encoder.layer.9.intermediate.dense.bias \t Loaded\n", + "text_branch.encoder.layer.9.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.9.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.9.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.9.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.10.attention.self.query.weight \t Loaded\n", + "text_branch.encoder.layer.10.attention.self.query.bias \t Loaded\n", + "text_branch.encoder.layer.10.attention.self.key.weight \t Loaded\n", + "text_branch.encoder.layer.10.attention.self.key.bias \t Loaded\n", + "text_branch.encoder.layer.10.attention.self.value.weight \t Loaded\n", + "text_branch.encoder.layer.10.attention.self.value.bias \t Loaded\n", + "text_branch.encoder.layer.10.attention.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.10.attention.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.10.attention.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.10.attention.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.10.intermediate.dense.weight \t Loaded\n", + "text_branch.encoder.layer.10.intermediate.dense.bias \t Loaded\n", + "text_branch.encoder.layer.10.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.10.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.10.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.10.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.11.attention.self.query.weight \t Loaded\n", + "text_branch.encoder.layer.11.attention.self.query.bias \t Loaded\n", + "text_branch.encoder.layer.11.attention.self.key.weight \t Loaded\n", + "text_branch.encoder.layer.11.attention.self.key.bias \t Loaded\n", + "text_branch.encoder.layer.11.attention.self.value.weight \t Loaded\n", + "text_branch.encoder.layer.11.attention.self.value.bias \t Loaded\n", + "text_branch.encoder.layer.11.attention.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.11.attention.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.11.attention.output.LayerNorm.weight \t Loaded\n", + "text_branch.encoder.layer.11.attention.output.LayerNorm.bias \t Loaded\n", + "text_branch.encoder.layer.11.intermediate.dense.weight \t Loaded\n", + "text_branch.encoder.layer.11.intermediate.dense.bias \t Loaded\n", + "text_branch.encoder.layer.11.output.dense.weight \t Loaded\n", + "text_branch.encoder.layer.11.output.dense.bias \t Loaded\n", + "text_branch.encoder.layer.11.output.LayerNorm.weight 
\t Loaded\n", + "text_branch.encoder.layer.11.output.LayerNorm.bias \t Loaded\n", + "text_branch.pooler.dense.weight \t Loaded\n", + "text_branch.pooler.dense.bias \t Loaded\n", + "text_transform.sequential.0.weight \t Loaded\n", + "text_transform.sequential.0.bias \t Loaded\n", + "text_transform.sequential.3.weight \t Loaded\n", + "text_transform.sequential.3.bias \t Loaded\n", + "text_projection.0.weight \t Loaded\n", + "text_projection.0.bias \t Loaded\n", + "text_projection.2.weight \t Loaded\n", + "text_projection.2.bias \t Loaded\n", + "audio_transform.sequential.0.weight \t Loaded\n", + "audio_transform.sequential.0.bias \t Loaded\n", + "audio_transform.sequential.3.weight \t Loaded\n", + "audio_transform.sequential.3.bias \t Loaded\n", + "audio_projection.0.weight \t Loaded\n", + "audio_projection.0.bias \t Loaded\n", + "audio_projection.2.weight \t Loaded\n", + "audio_projection.2.bias \t Loaded\n" + ] + } + ], + "source": [ + "model = laion_clap.CLAP_Module(enable_fusion=False, amodel= 'HTSAT-base')\n", + "model.load_ckpt(ckpt=\"music_audioset_epoch_15_esc_90.14.pt\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def load_music_file(file_name):\n", + " audio_data, _ = librosa.load(file_name, sr=48000) # sample rate should be 48000\n", + " audio_data = audio_data.reshape(1, -1) # Make it (1,T) or (N,T)\n", + " # audio_data = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float() # quantize before send it in to the model\n", + " with torch.no_grad():\n", + " audio_embed = model.get_audio_embedding_from_data(x = audio_data, use_tensor=False)\n", + " return audio_embed\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "music_files = glob(\"/Users/berkayg/Codes/music-project/AudioCLIP/data/downloaded_tracks/*.wav\")[:100]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/sr/r72219hj06x_1xvw7hhd517h0000gn/T/ipykernel_18860/3009710654.py:2: UserWarning: PySoundFile failed. 
Trying audioread instead.\n", + " audio_data, _ = librosa.load(file_name, sr=48000) # sample rate should be 48000\n", + "/Users/berkayg/miniforge3/envs/playlist-curator/lib/python3.10/site-packages/librosa/core/audio.py:183: FutureWarning: librosa.core.audio.__audioread_load\n", + "\tDeprecated as of librosa version 0.10.0.\n", + "\tIt will be removed in librosa version 1.0.\n", + " y, sr_native = __audioread_load(path, offset, duration, dtype)\n" + ] + } + ], + "source": [ + "music_data = np.zeros((len(music_files), 512), dtype=np.float32)\n", + "for m in range(music_data.shape[0]):\n", + " music_data[m] = load_music_file(music_files[m])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1, 512)\n" + ] + } + ], + "source": [ + "text_data = [\"This audio is a romantic song\"] \n", + "text_embed = model.get_text_embedding(text_data)\n", + "print(text_embed.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "song_names = [k.split(\"/\")[-1] for k in music_files]" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([100, 1])\n" + ] + } + ], + "source": [ + "with torch.no_grad():\n", + " ranking = torch.tensor(music_data) @ torch.tensor(text_embed).t()\n", + " ranking = ranking[:, 0].reshape(-1, 1)\n", + "print(ranking.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " This audio is a romantic song\n", + "Coldplay - Charlie Brown.wav 0.400684\n", + "Sam Smith - I'm Not The Only One.wav 0.373561\n", + "Pink Floyd - The Great Gig In The Sky - 2011 Re... 0.371584\n", + "Christina Aguilera - You Lost Me.wav 0.370390\n", + "Lana Del Rey - Yayo.wav 0.370379\n", + "Queen - It's A Hard Life - Remastered 2011.wav 0.348699\n", + "Teoman - Haziran.wav 0.331220\n", + "John Lennon - Imagine - Remastered 2010.wav 0.330397\n", + "Sleeping At Last - Mars.wav 0.328770\n", + "Adele - Someone Like You.wav 0.325650\n", + "Coldplay - What If.wav 0.315717\n", + "Adamlar - Orda Ortada.wav 0.306465\n", + "Eric Clapton - Autumn Leaves.wav 0.305451\n", + "Premiata Forneria Marconi - Impressioni di sett... 0.295878\n", + "Guthrie Govan - Lost in Rio.wav 0.284883" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame(ranking, columns=[text_data[0]], index=song_names).nlargest(15, text_data[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "playlist-curator", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/orchestrate_audio_data.py b/orchestrate_audio_data.py new file mode 100644 index 0000000000000000000000000000000000000000..6299850291979f9347623673dc703d5966218090 --- /dev/null +++ b/orchestrate_audio_data.py @@ -0,0 +1,8 @@ +from src.data.spotify import list_personal_saved_tracks +from src.data.get_yt_links import collect_youtube_links +from src.data.pytuber import start_download_process + +if __name__ == "__main__": + list_personal_saved_tracks() + collect_youtube_links() + start_download_process() diff --git a/recommender.py b/recommender.py new file mode 100644 index 0000000000000000000000000000000000000000..357d8f96a63c2ffcda29249df6eefa1db4cb8911 --- /dev/null +++ b/recommender.py @@ -0,0 +1,11 @@ +from src.laion_clap.inference import AudioEncoder +from src.config.configs import ProjectPaths +from glob import glob + +recommender = AudioEncoder() +# audio = recommender.extract_bulk_audio_representaions(save=False) +result = recommender.get_text_embedding("This audio is a romantic song") +music_files = glob(str(ProjectPaths.DATA_DIR.joinpath("audio", "*.wav"))) +song_names = [k.split("/")[-1] for k in music_files] +print(result) +pass \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..76d8dbd8c27b89b4466374154b4b88b2d7d7c114 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,89 @@ +altair==5.1.2 +anyio==4.0.0 +appdirs==1.4.4 +async-timeout==4.0.3 +attrs==23.1.0 +audioread==3.0.1 +blinker==1.7.0 +braceexpand==0.1.7 +cachetools==5.3.2 +certifi==2023.7.22 +cffi==1.16.0 +charset-normalizer==3.3.2 +click==8.1.7 +docker-pycreds==0.4.0 +filelock==3.13.1 +fsspec==2023.10.0 +ftfy==6.1.1 +gitdb==4.0.11 +GitPython==3.1.40 +google-api-python-client==2.105.0 +google-auth-httplib2==0.1.1 +h11==0.14.0 +h5py==3.10.0 +httpcore==1.0.2 +httplib2==0.22.0 +httpx==0.25.1 +huggingface-hub==0.19.4 +idna==3.4 +Jinja2==3.1.2 +joblib==1.3.2 +jsonschema==4.20.0 
+jsonschema-specifications==2023.11.1 +lazy_loader==0.3 +librosa==0.10.1 +llvmlite==0.41.1 +markdown-it-py==3.0.0 +MarkupSafe==2.1.3 +mdurl==0.1.2 +msgpack==1.0.7 +numba==0.58.1 +numpy==1.23.5 +pandas==2.1.3 +Pillow==10.1.0 +pooch==1.8.0 +progressbar==2.5 +protobuf==3.20.1 +pyarrow==14.0.1 +pycparser==2.21 +pydeck==0.8.1b0 +pytube==15.0.0 +pytz==2023.3.post1 +PyYAML==6.0.1 +redis==5.0.1 +referencing==0.31.0 +regex==2023.10.3 +requests==2.31.0 +rich==13.7.0 +rpds-py==0.13.0 +safetensors==0.4.0 +scikit-learn==1.3.2 +scipy==1.11.3 +sentry-sdk==1.35.0 +setproctitle==1.3.3 +smmap==5.0.1 +sniffio==1.3.0 +soundfile==0.12.1 +soxr==0.3.7 +spotipy==2.23.0 +streamlit==1.28.2 +tenacity==8.2.3 +threadpoolctl==3.2.0 +tokenizers==0.13.3 +toml==0.10.2 +toolz==0.12.0 +torch==1.11.0 +torchaudio==0.11.0 +torchlibrosa==0.1.0 +torchvision==0.12.0 +tqdm==4.66.1 +transformers==4.30.2 +tzdata==2023.3 +tzlocal==5.2 +uritemplate==4.1.1 +urllib3==2.1.0 +validators==0.22.0 +wandb==0.16.0 +webdataset==0.2.77 +wget==3.2 +youtube-search-python==1.6.6 \ No newline at end of file diff --git a/src/config/__init__.py b/src/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/config/configs.py b/src/config/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..1e43f2327e4f40b62381b8703c8332cfe33d40e4 --- /dev/null +++ b/src/config/configs.py @@ -0,0 +1,16 @@ +from pathlib import Path +from dataclasses import dataclass +from os import getenv + + +@dataclass +class ProjectPaths: + ROOT: Path = Path(__file__).parents[2] + DATA_DIR: Path = ROOT.joinpath("data") + MODEL_PATH: Path = ROOT.joinpath("model_checkpoints", "music_audioset_epoch_15_esc_90.14.pt") + + +@dataclass +class Credentials: + SPOTIFY_CLIENT_ID: str = getenv("SPOTIFY_CLIENT_ID") + SPOTIFY_SECRET_ID: str = getenv("SPOTIFY_SECRET_ID") diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/data/get_yt_links.py b/src/data/get_yt_links.py new file mode 100644 index 0000000000000000000000000000000000000000..c7658f4b6d53c97559d9942ae31dd34bbeadb713 --- /dev/null +++ b/src/data/get_yt_links.py @@ -0,0 +1,52 @@ +from youtubesearchpython import VideosSearch +import json +import time +from src.config.configs import ProjectPaths +from tqdm import tqdm + + +def read_json_data(): + with open(ProjectPaths.DATA_DIR.joinpath("json", "saved_tracks.json"), "r") as rd: + data = json.load(rd) + return data + + +def get_track_link(artist_name, track_name): + search_result = VideosSearch(f'{artist_name} - {track_name}', limit=1) + result = search_result.result()["result"][0] + data = { + "artist_name": artist_name, + "track_name": track_name, + "duration": result.get("duration"), + "published_time": result.get("publishedTime"), + "title": result.get("title"), + "view_count": result.get("viewCount").get("text"), + "link": result.get("link") + } + return data + + +def save_youtube_data(data): + with open(ProjectPaths.DATA_DIR.joinpath("json", "youtube_data.json"), "w") as wr: + json.dump(data, wr, indent=4) + + +def collect_youtube_links(): + data = read_json_data() + youtube_data = [] + for track_data in tqdm(data): + yt_data = get_track_link(track_data["artist"], track_data["track"]) + youtube_data.append(yt_data) + time.sleep(0.2) + save_youtube_data(youtube_data) + + +if __name__ == "__main__": + data = read_json_data() + youtube_data = [] + 
for track_data in tqdm(data): + yt_data = get_track_link(track_data["artist"], track_data["track"]) + youtube_data.append(yt_data) + time.sleep(0.2) + pass + save_youtube_data(youtube_data) diff --git a/src/data/pytuber.py b/src/data/pytuber.py new file mode 100644 index 0000000000000000000000000000000000000000..46d6081b459c668d228d9a02b85269fa2e811b22 --- /dev/null +++ b/src/data/pytuber.py @@ -0,0 +1,35 @@ +import os +from src.config.configs import ProjectPaths +import json +import pytube +from tqdm import tqdm +from pytube.exceptions import AgeRestrictedError + + +def read_youtube_data(): + input_data = ProjectPaths.DATA_DIR.joinpath("json", "youtube_data.json") + with open(input_data, "r") as rd: + return json.load(rd) + + +def download_mp3(link, track_full_name): + data_dir = ProjectPaths.DATA_DIR.joinpath("audio") + try: + mp3 = pytube.YouTube(link, use_oauth=True, allow_oauth_cache=True).streams.filter(only_audio=True).first() + mp3.download(data_dir) + + new_file = track_full_name + '.wav' + os.rename(data_dir.joinpath(mp3.default_filename), data_dir.joinpath(new_file)) + except AgeRestrictedError: + pass + + +def start_download_process(): + input_data = read_youtube_data() + done_pieces = os.listdir(ProjectPaths.DATA_DIR.joinpath("audio")) + for i in tqdm(input_data): + link = i["link"] + full_name = f'{i["artist_name"]} - {i["track_name"]}'.replace("/", "_") + if full_name + ".wav" in done_pieces: + continue + download_mp3(link, full_name) diff --git a/src/data/spotify.py b/src/data/spotify.py new file mode 100644 index 0000000000000000000000000000000000000000..eb98bc311ee79d18f0a69b5708589bde431a59c5 --- /dev/null +++ b/src/data/spotify.py @@ -0,0 +1,24 @@ +import spotipy +from spotipy.oauth2 import SpotifyOAuth +from ..config.configs import Credentials, ProjectPaths +import json + + +def list_personal_saved_tracks(): + scope = "user-library-read" + auth = SpotifyOAuth(client_id=Credentials.SPOTIFY_CLIENT_ID, client_secret=Credentials.SPOTIFY_SECRET_ID, scope=scope, redirect_uri="https://localhost:5000") + sp = spotipy.Spotify(auth_manager=auth) + + tracks = [] + offset_count = 0 + for _ in range(50): + results = sp.current_user_saved_tracks(limit=50, offset=offset_count) + for idx, item in enumerate(results['items']): + track = item['track'] + data = {"artist": track['artists'][0]['name'], "track": track['name']} + tracks.append(data) + print(idx, track['artists'][0]['name'], " - ", track['name']) + offset_count += 50 + + with open(ProjectPaths.DATA_DIR.joinpath("json", "saved_tracks.json"), "w", encoding="UTF-8") as wr: + json.dump(tracks, wr, indent=4) diff --git a/src/laion_clap/__init__.py b/src/laion_clap/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..96d4b618dcb091479e2c9092ea2b807527f239de --- /dev/null +++ b/src/laion_clap/__init__.py @@ -0,0 +1,5 @@ +import os +import sys +dir_path = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(dir_path) +from .hook import CLAP_Module \ No newline at end of file diff --git a/src/laion_clap/clap_module/__init__.py b/src/laion_clap/clap_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b585be6540fe21eef8bc6594375baee5017877ef --- /dev/null +++ b/src/laion_clap/clap_module/__init__.py @@ -0,0 +1,8 @@ +from .factory import list_models, create_model, create_model_and_transforms, add_model_config +from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics +from .model import CLAP, CLAPTextCfg, CLAPVisionCfg, 
CLAPAudioCfp, convert_weights_to_fp16, trace_model +from .openai import load_openai_model, list_openai_models +from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\ + get_pretrained_url, download_pretrained +from .tokenizer import SimpleTokenizer, tokenize +from .transform import image_transform \ No newline at end of file diff --git a/src/laion_clap/clap_module/bert.py b/src/laion_clap/clap_module/bert.py new file mode 100644 index 0000000000000000000000000000000000000000..005e72dec67e4b1c05063dbd4d024166344fd2c4 --- /dev/null +++ b/src/laion_clap/clap_module/bert.py @@ -0,0 +1,32 @@ +from transformers import BertTokenizer, BertModel +tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') +model = BertModel.from_pretrained("bert-base-uncased") +text = "Replace me by any text you'd like." + +def bert_embeddings(text): + # text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors='pt') + output = model(**encoded_input) + return output + +from transformers import RobertaTokenizer, RobertaModel + +tokenizer = RobertaTokenizer.from_pretrained('roberta-base') +model = RobertaModel.from_pretrained('roberta-base') +text = "Replace me by any text you'd like." +def Roberta_embeddings(text): + # text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors='pt') + output = model(**encoded_input) + return output + +from transformers import BartTokenizer, BartModel + +tokenizer = BartTokenizer.from_pretrained('facebook/bart-base') +model = BartModel.from_pretrained('facebook/bart-base') +text = "Replace me by any text you'd like." +def bart_embeddings(text): + # text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors='pt') + output = model(**encoded_input) + return output \ No newline at end of file diff --git a/src/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz b/src/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/src/laion_clap/clap_module/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/src/laion_clap/clap_module/factory.py b/src/laion_clap/clap_module/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..5e581ebcef11017eb5f48dd20682a066fd1a02fc --- /dev/null +++ b/src/laion_clap/clap_module/factory.py @@ -0,0 +1,263 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path +from packaging import version + +import torch +import transformers + +from .model import CLAP, convert_weights_to_fp16 +from .openai import load_openai_model +from .pretrained import get_pretrained_url, download_pretrained +from .transform import image_transform + +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = (".json",) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext 
in config_ext: + config_files.extend(config_path.glob(f"*{ext}")) + + for cf in config_files: + with open(cf, "r") as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = { + k: v + for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])) + } + + +_rescan_model_configs() # initial populate of model config registry + + +def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + if skip_params: + if next(iter(state_dict.items()))[0].startswith("module"): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # removing position_ids to maintain compatibility with latest transformers update + if version.parse(transformers.__version__) >= version.parse("4.31.0"): + del state_dict["text_branch.embeddings.position_ids"] + # for k in state_dict: + # if k.startswith('transformer'): + # v = state_dict.pop(k) + # state_dict['text_branch.' + k[12:]] = v + return state_dict + + +def create_model( + amodel_name: str, + tmodel_name: str, + pretrained: str = "", + precision: str = "fp32", + device: torch.device = torch.device("cpu"), + jit: bool = False, + force_quick_gelu: bool = False, + openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"), + skip_params=True, + pretrained_audio: str = "", + pretrained_text: str = "", + enable_fusion: bool = False, + fusion_type: str = 'None' + # pretrained_image: bool = False, +): + amodel_name = amodel_name.replace( + "/", "-" + ) # for callers using old naming with / in ViT names + pretrained_orig = pretrained + pretrained = pretrained.lower() + if pretrained == "openai": + if amodel_name in _MODEL_CONFIGS: + logging.info(f"Loading {amodel_name} model config.") + model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) + else: + logging.error( + f"Model config for {amodel_name} not found; available models {list_models()}." + ) + raise RuntimeError(f"Model config for {amodel_name} not found.") + + logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.") + # Hard Code in model name + model_cfg["text_cfg"]["model_type"] = tmodel_name + model = load_openai_model( + "ViT-B-16", + model_cfg, + device=device, + jit=jit, + cache_dir=openai_model_cache_dir, + enable_fusion=enable_fusion, + fusion_type=fusion_type + ) + # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 + if precision == "amp" or precision == "fp32": + model = model.float() + else: + if amodel_name in _MODEL_CONFIGS: + logging.info(f"Loading {amodel_name} model config.") + model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) + else: + logging.error( + f"Model config for {amodel_name} not found; available models {list_models()}." 
+ ) + raise RuntimeError(f"Model config for {amodel_name} not found.") + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + # if pretrained_image: + # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}): + # # pretrained weight loading for timm models set via vision_cfg + # model_cfg['vision_cfg']['timm_model_pretrained'] = True + # else: + # assert False, 'pretrained image towers currently only supported for timm models' + model_cfg["text_cfg"]["model_type"] = tmodel_name + model_cfg["enable_fusion"] = enable_fusion + model_cfg["fusion_type"] = fusion_type + model = CLAP(**model_cfg) + + if pretrained: + checkpoint_path = "" + url = get_pretrained_url(amodel_name, pretrained) + if url: + checkpoint_path = download_pretrained(url, root=openai_model_cache_dir) + elif os.path.exists(pretrained_orig): + checkpoint_path = pretrained_orig + if checkpoint_path: + logging.info(f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained}).") + ckpt = load_state_dict(checkpoint_path, skip_params=True) + model.load_state_dict(ckpt) + param_names = [n for n, p in model.named_parameters()] + for n in param_names: + print(n, "\t", "Loaded" if n in ckpt else "Unloaded") + else: + logging.warning( + f"Pretrained weights ({pretrained}) not found for model {amodel_name}." + ) + raise RuntimeError( + f"Pretrained weights ({pretrained}) not found for model {amodel_name}." + ) + + if pretrained_audio: + if amodel_name.startswith('PANN'): + if 'Cnn14_mAP' in pretrained_audio: # official checkpoint + audio_ckpt = torch.load(pretrained_audio, map_location='cpu') + audio_ckpt = audio_ckpt['model'] + keys = list(audio_ckpt.keys()) + for key in keys: + if 'spectrogram_extractor' not in key and 'logmel_extractor' not in key: + v = audio_ckpt.pop(key) + audio_ckpt['audio_branch.' + key] = v + elif os.path.basename(pretrained_audio).startswith('PANN'): # checkpoint trained via HTSAT codebase + audio_ckpt = torch.load(pretrained_audio, map_location='cpu') + audio_ckpt = audio_ckpt['state_dict'] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith('sed_model'): + v = audio_ckpt.pop(key) + audio_ckpt['audio_branch.' + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith('finetuned'): # checkpoint trained via linear probe codebase + audio_ckpt = torch.load(pretrained_audio, map_location='cpu') + else: + raise ValueError('Unknown audio checkpoint') + elif amodel_name.startswith('HTSAT'): + if 'HTSAT_AudioSet_Saved' in pretrained_audio: # official checkpoint + audio_ckpt = torch.load(pretrained_audio, map_location='cpu') + audio_ckpt = audio_ckpt['state_dict'] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith('sed_model') and ('spectrogram_extractor' not in key + and 'logmel_extractor' not in key): + v = audio_ckpt.pop(key) + audio_ckpt['audio_branch.' + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith('HTSAT'): # checkpoint trained via HTSAT codebase + audio_ckpt = torch.load(pretrained_audio, map_location='cpu') + audio_ckpt = audio_ckpt['state_dict'] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith('sed_model'): + v = audio_ckpt.pop(key) + audio_ckpt['audio_branch.' 
+ key[10:]] = v + elif os.path.basename(pretrained_audio).startswith('finetuned'): # checkpoint trained via linear probe codebase + audio_ckpt = torch.load(pretrained_audio, map_location='cpu') + else: + raise ValueError('Unknown audio checkpoint') + else: + raise f'this audio encoder pretrained checkpoint is not support' + + model.load_state_dict(audio_ckpt, strict=False) + logging.info(f"Loading pretrained {amodel_name} weights ({pretrained_audio}).") + param_names = [n for n, p in model.named_parameters()] + for n in param_names: + print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded") + + model.to(device=device) + if precision == "fp16": + assert device.type != "cpu" + convert_weights_to_fp16(model) + + if jit: + model = torch.jit.script(model) + + return model, model_cfg + + +def create_model_and_transforms( + model_name: str, + pretrained: str = "", + precision: str = "fp32", + device: torch.device = torch.device("cpu"), + jit: bool = False, + force_quick_gelu: bool = False, + # pretrained_image: bool = False, +): + model = create_model( + model_name, + pretrained, + precision, + device, + jit, + force_quick_gelu=force_quick_gelu, + # pretrained_image=pretrained_image + ) + preprocess_train = image_transform(model.visual.image_size, is_train=True) + preprocess_val = image_transform(model.visual.image_size, is_train=False) + return model, preprocess_train, preprocess_val + + +def list_models(): + """enumerate available model architectures based on config files""" + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """add model config path or file and update registry""" + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() diff --git a/src/laion_clap/clap_module/feature_fusion.py b/src/laion_clap/clap_module/feature_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..c2419516b76931f0aa801d78e1b5f04a92a909e6 --- /dev/null +++ b/src/laion_clap/clap_module/feature_fusion.py @@ -0,0 +1,193 @@ +''' +Feature Fusion for Varible-Length Data Processing +AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py +According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021 +''' + +import torch +import torch.nn as nn + + +class DAF(nn.Module): + ''' + 直接相加 DirectAddFuse + ''' + + def __init__(self): + super(DAF, self).__init__() + + def forward(self, x, residual): + return x + residual + + +class iAFF(nn.Module): + ''' + 多特征融合 iAFF + ''' + + def __init__(self, channels=64, r=4, type='2D'): + super(iAFF, self).__init__() + inter_channels = int(channels // r) + + if type == '1D': + # 本地注意力 + self.local_att = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + + # 全局注意力 + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + + # 第二次本地注意力 + self.local_att2 = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + 
nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + # 第二次全局注意力 + self.global_att2 = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + elif type == '2D': + # 本地注意力 + self.local_att = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + # 全局注意力 + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + # 第二次本地注意力 + self.local_att2 = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + # 第二次全局注意力 + self.global_att2 = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + else: + raise f'the type is not supported' + + self.sigmoid = nn.Sigmoid() + + def forward(self, x, residual): + flag = False + xa = x + residual + if xa.size(0) == 1: + xa = torch.cat([xa,xa],dim=0) + flag = True + xl = self.local_att(xa) + xg = self.global_att(xa) + xlg = xl + xg + wei = self.sigmoid(xlg) + xi = x * wei + residual * (1 - wei) + + xl2 = self.local_att2(xi) + xg2 = self.global_att(xi) + xlg2 = xl2 + xg2 + wei2 = self.sigmoid(xlg2) + xo = x * wei2 + residual * (1 - wei2) + if flag: + xo = xo[0].unsqueeze(0) + return xo + + +class AFF(nn.Module): + ''' + 多特征融合 AFF + ''' + + def __init__(self, channels=64, r=4, type='2D'): + super(AFF, self).__init__() + inter_channels = int(channels // r) + + if type == '1D': + self.local_att = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + elif type == '2D': + self.local_att = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + 
nn.BatchNorm2d(channels), + ) + else: + raise f'the type is not supported.' + + self.sigmoid = nn.Sigmoid() + + def forward(self, x, residual): + flag = False + xa = x + residual + if xa.size(0) == 1: + xa = torch.cat([xa,xa],dim=0) + flag = True + xl = self.local_att(xa) + xg = self.global_att(xa) + xlg = xl + xg + wei = self.sigmoid(xlg) + xo = 2 * x * wei + 2 * residual * (1 - wei) + if flag: + xo = xo[0].unsqueeze(0) + return xo + diff --git a/src/laion_clap/clap_module/htsat.py b/src/laion_clap/clap_module/htsat.py new file mode 100644 index 0000000000000000000000000000000000000000..bb8e7cf5f2307c57e094a122121f3ca7f527436a --- /dev/null +++ b/src/laion_clap/clap_module/htsat.py @@ -0,0 +1,1031 @@ +# Ke Chen +# knutchen@ucsd.edu +# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION +# Some layers designed on the model +# below codes are based and referred from https://github.com/microsoft/Swin-Transformer +# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf + +import torch +import torch.nn as nn +import torch.nn.functional as F +from itertools import repeat +import collections.abc +import math +import warnings + +from torch.nn.init import _calculate_fan_in_and_fan_out +import torch.utils.checkpoint as checkpoint + +import random + +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +from itertools import repeat +from .utils import do_mixup, interpolate + +from .feature_fusion import iAFF, AFF, DAF + +# from PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
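+    Keeps a sample's residual branch with probability 1 - drop_prob and rescales the kept
+    activations by 1 / (1 - drop_prob), so the expected output matches evaluation mode
+    (see the functional drop_path above).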
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16, + enable_fusion=False, fusion_type='None'): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patch_stride = to_2tuple(patch_stride) + self.img_size = img_size + self.patch_size = patch_size + self.patch_stride = patch_stride + self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2) + + if (self.enable_fusion) and (self.fusion_type == 'channel_map'): + self.proj = nn.Conv2d(in_chans*4, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding) + else: + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']): + self.mel_conv2d = nn.Conv2d(in_chans, embed_dim, kernel_size=(patch_size[0], patch_size[1]*3), stride=(patch_stride[0], patch_stride[1] * 3), padding=padding) + if self.fusion_type == 'daf_2d': + self.fusion_model = DAF() + elif self.fusion_type == 'aff_2d': + self.fusion_model = AFF(channels=embed_dim, type='2D') + elif self.fusion_type == 'iaff_2d': + self.fusion_model = iAFF(channels=embed_dim, type='2D') + def forward(self, x, longer_idx = None): + if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']): + global_x = x[:,0:1,:,:] + + + # global processing + B, C, H, W = global_x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + global_x = self.proj(global_x) + TW = global_x.size(-1) + if len(longer_idx) > 0: + # local processing + local_x = x[longer_idx,1:,:,:].contiguous() + B, C, H, W = local_x.shape + local_x = local_x.view(B*C,1,H,W) + local_x = self.mel_conv2d(local_x) + local_x = local_x.view(B,C,local_x.size(1),local_x.size(2),local_x.size(3)) + local_x = local_x.permute((0,2,3,1,4)).contiguous().flatten(3) + TB,TC,TH,_ = local_x.size() + if local_x.size(-1) < TW: + local_x = torch.cat([local_x, torch.zeros((TB,TC,TH,TW-local_x.size(-1)), device=global_x.device)], dim=-1) + else: + local_x = local_x[:,:,:,:TW] + + global_x[longer_idx] = self.fusion_model(global_x[longer_idx],local_x) + x = global_x + else: + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ x = self.proj(x) + + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == 'fan_in': + denom = fan_in + elif mode == 'fan_out': + denom = fan_out + elif mode == 'fan_avg': + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) + elif distribution == "normal": + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + def extra_repr(self): + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + +# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.norm_before_mlp = norm_before_mlp + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + if self.norm_before_mlp == 'ln': + self.norm2 = nn.LayerNorm(dim) + elif self.norm_before_mlp == 'bn': + self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2) + else: + raise NotImplementedError + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + # pdb.set_trace() + H, W = self.input_resolution + # print("H: ", H) + # print("W: ", W) + # pdb.set_trace() + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # 
merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x, attn + + def extra_repr(self): + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self): + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
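+        norm_before_mlp (str): Normalization used before the MLP in each block, 'ln' for LayerNorm or 'bn' for BatchNorm1d. Default: 'ln'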
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + norm_before_mlp='ln'): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, norm_before_mlp=norm_before_mlp) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + attns = [] + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x, attn = blk(x) + if not self.training: + attns.append(attn.unsqueeze(0)) + if self.downsample is not None: + x = self.downsample(x) + if not self.training: + attn = torch.cat(attns, dim = 0) + attn = torch.mean(attn, dim = 0) + return x, attn + + def extra_repr(self): + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +# The Core of HTSAT +class HTSAT_Swin_Transformer(nn.Module): + r"""HTSAT based on the Swin Transformer + Args: + spec_size (int | tuple(int)): Input Spectrogram size. Default 256 + patch_size (int | tuple(int)): Patch size. Default: 4 + path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4 + in_chans (int): Number of input image channels. Default: 1 (mono) + num_classes (int): Number of classes for classification head. Default: 527 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 8 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False + config (module): The configuration Module from config.py + """ + + def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4), + in_chans=1, num_classes=527, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32], + window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, patch_norm=True, + use_checkpoint=False, norm_before_mlp='ln', config = None, + enable_fusion = False, fusion_type = 'None', **kwargs): + super(HTSAT_Swin_Transformer, self).__init__() + + self.config = config + self.spec_size = spec_size + self.patch_stride = patch_stride + self.patch_size = patch_size + self.window_size = window_size + self.embed_dim = embed_dim + self.depths = depths + self.ape = ape + self.in_chans = in_chans + self.num_classes = num_classes + self.num_heads = num_heads + self.num_layers = len(self.depths) + self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1)) + + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + + self.qkv_bias = qkv_bias + self.qk_scale = None + + self.patch_norm = patch_norm + self.norm_layer = norm_layer if self.patch_norm else None + self.norm_before_mlp = norm_before_mlp + self.mlp_ratio = mlp_ratio + + self.use_checkpoint = use_checkpoint + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # process mel-spec ; used only once + self.freq_ratio = self.spec_size // self.config.mel_bins + window = 'hann' + center = True + pad_mode = 'reflect' + ref = 1.0 + amin = 1e-10 + top_db = None + self.interpolate_ratio = 32 # Downsampled ratio + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size, + win_length=config.window_size, window=window, center=center, pad_mode=pad_mode, + freeze_parameters=True) + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size, + n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db, + freeze_parameters=True) + # Spec augmenter + self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, + freq_drop_width=8, freq_stripes_num=2) # 2 2 + self.bn0 = nn.BatchNorm2d(self.config.mel_bins) + + + # split spctrogram into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans, + embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride, + enable_fusion=self.enable_fusion, fusion_type=self.fusion_type + ) + + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.grid_size + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=self.drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=self.depths[i_layer], + num_heads=self.num_heads[i_layer], + 
window_size=self.window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, qk_scale=self.qk_scale, + drop=self.drop_rate, attn_drop=self.attn_drop_rate, + drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])], + norm_layer=self.norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + norm_before_mlp=self.norm_before_mlp) + self.layers.append(layer) + + self.norm = self.norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.maxpool = nn.AdaptiveMaxPool1d(1) + + SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio + self.tscam_conv = nn.Conv2d( + in_channels = self.num_features, + out_channels = self.num_classes, + kernel_size = (SF,3), + padding = (0,1) + ) + self.head = nn.Linear(num_classes, num_classes) + + if (self.enable_fusion) and (self.fusion_type in ['daf_1d','aff_1d','iaff_1d']): + self.mel_conv1d = nn.Sequential( + nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), + nn.BatchNorm1d(64) + ) + if self.fusion_type == 'daf_1d': + self.fusion_model = DAF() + elif self.fusion_type == 'aff_1d': + self.fusion_model = AFF(channels=64, type='1D') + elif self.fusion_type == 'iaff_1d': + self.fusion_model = iAFF(channels=64, type='1D') + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + + def forward_features(self, x, longer_idx = None): + # A deprecated optimization for using a hierarchical output from different blocks + + frames_num = x.shape[2] + x = self.patch_embed(x, longer_idx = longer_idx) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + for i, layer in enumerate(self.layers): + x, attn = layer(x) + # for x + x = self.norm(x) + B, N, C = x.shape + SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] + ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1] + x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST) + B, C, F, T = x.shape + # group 2D CNN + c_freq_bin = F // self.freq_ratio + x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T) + x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1) + # get latent_output + fine_grained_latent_output = torch.mean(x, dim = 2) + fine_grained_latent_output = interpolate(fine_grained_latent_output.permute(0,2,1).contiguous(), 8 * self.patch_stride[1]) + + latent_output = self.avgpool(torch.flatten(x,2)) + latent_output = torch.flatten(latent_output, 1) + + # display the attention map, if needed + + x = self.tscam_conv(x) + x = torch.flatten(x, 2) # B, C, T + + fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1]) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + output_dict = { + 'framewise_output': fpx, # already sigmoided + 'clipwise_output': torch.sigmoid(x), + 'fine_grained_embedding': fine_grained_latent_output, + 'embedding': latent_output + } + + return output_dict + + def crop_wav(self, x, crop_size, spe_pos = None): + time_steps = x.shape[2] + tx = torch.zeros(x.shape[0], x.shape[1], crop_size, 
x.shape[3]).to(x.device) + for i in range(len(x)): + if spe_pos is None: + crop_pos = random.randint(0, time_steps - crop_size - 1) + else: + crop_pos = spe_pos + tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:] + return tx + + # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model + def reshape_wav2img(self, x): + B, C, T, F = x.shape + target_T = int(self.spec_size * self.freq_ratio) + target_F = self.spec_size // self.freq_ratio + assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size" + # to avoid bicubic zero error + if T < target_T: + x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True) + if F < target_F: + x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True) + x = x.permute(0,1,3,2).contiguous() + x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio) + # print(x.shape) + x = x.permute(0,1,3,2,4).contiguous() + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4]) + return x + + # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model + def repeat_wat2img(self, x, cur_pos): + B, C, T, F = x.shape + target_T = int(self.spec_size * self.freq_ratio) + target_F = self.spec_size // self.freq_ratio + assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size" + # to avoid bicubic zero error + if T < target_T: + x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True) + if F < target_F: + x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True) + x = x.permute(0,1,3,2).contiguous() # B C F T + x = x[:,:,:,cur_pos:cur_pos + self.spec_size] + x = x.repeat(repeats = (1,1,4,1)) + return x + + def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False, device=None):# out_feat_keys: List[str] = None): + + if self.enable_fusion and x["longer"].sum() == 0: + # if no audio is longer than 10s, then randomly select one audio to be longer + if self.training: + x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True + else: + x = x["mel_fusion"].to(device=device, non_blocking=True) + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x, longer_idx=[]) + return output_dict + + if not self.enable_fusion: + x = x["waveform"].to(device=device, non_blocking=True) + x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + if self.training: + x = self.spec_augmenter(x) + + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x) + else: + longer_list = x["longer"].to(device=device, non_blocking=True) + x = x["mel_fusion"].to(device=device, non_blocking=True) + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + longer_list_idx = torch.where(longer_list)[0] + if self.fusion_type in ['daf_1d','aff_1d','iaff_1d']: + new_x = x[:,0:1,:,:].clone().contiguous() + if len(longer_list_idx) > 0: + # local processing + fusion_x_local = x[longer_list_idx,1:,:,:].clone().contiguous() + FB,FC,FT,FF = fusion_x_local.size() + fusion_x_local = fusion_x_local.view(FB * FC, FT, FF) + 
fusion_x_local = torch.permute(fusion_x_local, (0,2,1)).contiguous() + fusion_x_local = self.mel_conv1d(fusion_x_local) + fusion_x_local = fusion_x_local.view(FB,FC,FF,fusion_x_local.size(-1)) + fusion_x_local = torch.permute(fusion_x_local, (0,2,1,3)).contiguous().flatten(2) + if fusion_x_local.size(-1) < FT: + fusion_x_local = torch.cat([fusion_x_local, torch.zeros((FB,FF,FT- fusion_x_local.size(-1)), device=device)], dim=-1) + else: + fusion_x_local = fusion_x_local[:,:,:FT] + # 1D fusion + new_x = new_x.squeeze(1).permute((0,2,1)).contiguous() + new_x[longer_list_idx] = self.fusion_model(new_x[longer_list_idx], fusion_x_local) + x = new_x.permute((0,2,1)).contiguous()[:,None,:,:] + else: + x = new_x + + elif self.fusion_type in ['daf_2d','aff_2d','iaff_2d','channel_map']: + x = x # no change + + if self.training: + x = self.spec_augmenter(x) + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x, longer_idx = longer_list_idx) + + # if infer_mode: + # # in infer mode. we need to handle different length audio input + # frame_num = x.shape[2] + # target_T = int(self.spec_size * self.freq_ratio) + # repeat_ratio = math.floor(target_T / frame_num) + # x = x.repeat(repeats=(1,1,repeat_ratio,1)) + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # else: + # if x.shape[2] > self.freq_ratio * self.spec_size: + # if self.training: + # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size) + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # else: + # # Change: Hard code here + # overlap_size = (x.shape[2] - 1) // 4 + # output_dicts = [] + # crop_size = (x.shape[2] - 1) // 2 + # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size): + # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos) + # tx = self.reshape_wav2img(tx) + # output_dicts.append(self.forward_features(tx)) + # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device) + # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device) + # for d in output_dicts: + # clipwise_output += d["clipwise_output"] + # framewise_output += d["framewise_output"] + # clipwise_output = clipwise_output / len(output_dicts) + # framewise_output = framewise_output / len(output_dicts) + # output_dict = { + # 'framewise_output': framewise_output, + # 'clipwise_output': clipwise_output + # } + # else: # this part is typically used, and most easy one + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # x = self.head(x) + + # We process the data in the dataloader part, in that here we only consider the input_T < fixed_T + + + + return output_dict + +def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type='None'): + try: + + assert audio_cfg.model_name in ["tiny", "base", "large"], "model name for HTS-AT is wrong!" 
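+        # The three presets below share spec_size=256, patch_size=4, patch_stride=(4,4),
+        # window_size=8 and num_heads=[4,8,16,32]; they differ only in capacity:
+        # tiny -> embed_dim=96, depths=[2,2,6,2]; base -> embed_dim=128, depths=[2,2,12,2];
+        # large -> embed_dim=256, depths=[2,2,12,2].
+        # Minimal usage sketch (illustrative only; assumes audio_cfg exposes the attributes
+        # read by HTSAT_Swin_Transformer: mel_bins, window_size, hop_size, sample_rate,
+        # fmin, fmax, class_num — the example values here are placeholders, not the
+        # project's actual config):
+        #   from types import SimpleNamespace
+        #   cfg = SimpleNamespace(model_name="base", class_num=527, sample_rate=48000,
+        #                         window_size=1024, hop_size=480, mel_bins=64, fmin=50, fmax=14000)
+        #   htsat = create_htsat_model(cfg, enable_fusion=False)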
+ if audio_cfg.model_name == "tiny": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4,4), + num_classes=audio_cfg.class_num, + embed_dim=96, + depths=[2,2,6,2], + num_heads=[4,8,16,32], + window_size=8, + config = audio_cfg, + enable_fusion = enable_fusion, + fusion_type = fusion_type + ) + elif audio_cfg.model_name == "base": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4,4), + num_classes=audio_cfg.class_num, + embed_dim=128, + depths=[2,2,12,2], + num_heads=[4,8,16,32], + window_size=8, + config = audio_cfg, + enable_fusion = enable_fusion, + fusion_type = fusion_type + ) + elif audio_cfg.model_name == "large": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4,4), + num_classes=audio_cfg.class_num, + embed_dim=256, + depths=[2,2,12,2], + num_heads=[4,8,16,32], + window_size=8, + config = audio_cfg, + enable_fusion = enable_fusion, + fusion_type = fusion_type + ) + + return model + except: + raise RuntimeError(f'Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough.') + \ No newline at end of file diff --git a/src/laion_clap/clap_module/linear_probe.py b/src/laion_clap/clap_module/linear_probe.py new file mode 100644 index 0000000000000000000000000000000000000000..bb2841dd4e28201db8b5bd4a215e1b8b9a60d25a --- /dev/null +++ b/src/laion_clap/clap_module/linear_probe.py @@ -0,0 +1,63 @@ +import numpy as np +import torch.nn.functional as F +from torch import nn +from .model import MLPLayers + + +class LinearProbe(nn.Module): + def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None): + """ + Args: + model: nn.Module + mlp: bool, if True, then use the MLP layer as the linear probe module + freeze: bool, if Ture, then freeze all the CLAP model's layers when training the linear probe + in_ch: int, the output channel from CLAP model + out_ch: int, the output channel from linear probe (class_num) + act: torch.nn.functional, the activation function before the loss function + """ + super().__init__() + in_ch = 512 + self.clap_model = model + self.clap_model.text_branch = None # to save memory + self.freeze = freeze + if mlp: + self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch]) + else: + self.lp_layer = nn.Linear(in_ch, out_ch) + + if self.freeze: + for param in self.clap_model.parameters(): + param.requires_grad = False + + if act == 'None': + self.act = None + elif act == 'relu': + self.act = nn.ReLU() + elif act == 'elu': + self.act = nn.ELU() + elif act == 'prelu': + self.act = nn.PReLU(num_parameters=in_ch) + elif act == 'softmax': + self.act = nn.Softmax(dim=-1) + elif act == 'sigmoid': + self.act = nn.Sigmoid() + + def forward(self, x, mix_lambda=None, device=None): + """ + Args: + x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list + mix_lambda: torch.tensor [batch], the mixup lambda + Returns: + class_prob: torch.tensor [batch, class_num] + + """ + # batchnorm cancel grandient + if self.freeze: + self.clap_model.eval() + + x = self.clap_model.audio_projection( + self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)["embedding"]) + out = self.lp_layer(x) + if self.act is not None: + out = self.act(out) + return out diff --git a/src/laion_clap/clap_module/loss.py b/src/laion_clap/clap_module/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..53bbedd959813b072b146c16c14cd96df6cada14 --- /dev/null +++ b/src/laion_clap/clap_module/loss.py @@ -0,0 +1,307 @@ +from 
multiprocessing.sharedctypes import Value +import torch +import torch.distributed.nn +from torch import distributed as dist, nn as nn +from torch.nn import functional as F +import numpy as np +from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def gather_features( + audio_features, + text_features, + audio_features_mlp=None, + text_features_mlp=None, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, + use_horovod=False, + mlp_loss=False +): + if use_horovod: + assert hvd is not None, 'Please install horovod' + if gather_with_grad: + all_audio_features = hvd.allgather(audio_features) + all_text_features = hvd.allgather(text_features) + if mlp_loss: + all_audio_features_mlp = hvd.allgather(audio_features_mlp) + all_text_features_mlp = hvd.allgather(text_features_mlp) + else: + with torch.no_grad(): + all_audio_features = hvd.allgather(audio_features) + all_text_features = hvd.allgather(text_features) + if mlp_loss: + all_audio_features_mlp = hvd.allgather(audio_features_mlp) + all_text_features_mlp = hvd.allgather(text_features_mlp) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_audio_features = list(all_audio_features.chunk(world_size, dim=0)) + gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) + gathered_audio_features[rank] = audio_features + gathered_text_features[rank] = text_features + all_audio_features = torch.cat(gathered_audio_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + if mlp_loss: + gathered_audio_features_mlp = list(all_audio_features_mlp.chunk(world_size, dim=0)) + gathered_text_features_mlp = list(all_text_features_mlp.chunk(world_size, dim=0)) + gathered_audio_features_mlp[rank] = audio_features_mlp + gathered_text_features_mlp[rank] = text_features_mlp + all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0) + all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) + else: + # We gather tensors from all gpus + if gather_with_grad: + all_audio_features = torch.cat(torch.distributed.nn.all_gather(audio_features), dim=0) + all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) + if mlp_loss: + all_audio_features_mlp = torch.cat(torch.distributed.nn.all_gather(audio_features_mlp), dim=0) + all_text_features_mlp = torch.cat(torch.distributed.nn.all_gather(text_features_mlp), dim=0) + else: + gathered_audio_features = [torch.zeros_like(audio_features) for _ in range(world_size)] + gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] + dist.all_gather(gathered_audio_features, audio_features) + dist.all_gather(gathered_text_features, text_features) + if mlp_loss: + gathered_audio_features_mlp = [torch.zeros_like(audio_features_mlp) for _ in range(world_size)] + gathered_text_features_mlp = [torch.zeros_like(text_features_mlp) for _ in range(world_size)] + dist.all_gather(gathered_audio_features_mlp, audio_features_mlp) + dist.all_gather(gathered_text_features_mlp, text_features_mlp) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_audio_features[rank] = audio_features + gathered_text_features[rank] = text_features + if mlp_loss: + gathered_audio_features_mlp[rank] = audio_features_mlp + gathered_text_features_mlp[rank] = text_features_mlp + + all_audio_features = 
torch.cat(gathered_audio_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + if mlp_loss: + all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0) + all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) + if mlp_loss: + return all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp + else: + return all_audio_features, all_text_features + +class ClipLoss(nn.Module): + + def __init__( + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + mlp_loss=False, + weight_loss_kappa=0, + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.use_horovod = use_horovod + self.mlp_loss = mlp_loss + self.weighted_loss = bool(weight_loss_kappa!=0) + self.weight_loss_kappa = weight_loss_kappa + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def forward(self, audio_features, text_features, logit_scale_a, logit_scale_t=None, audio_features_mlp=None, text_features_mlp=None): + device = audio_features.device + if self.mlp_loss: + if self.world_size > 1: + all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp = gather_features( + audio_features=audio_features,text_features=text_features, + audio_features_mlp=audio_features_mlp,text_features_mlp=text_features_mlp, + local_loss=self.local_loss,gather_with_grad=self.gather_with_grad, + rank=self.rank,world_size=self.world_size,use_horovod=self.use_horovod, + mlp_loss=self.mlp_loss + ) + if self.local_loss: + a_logits_per_audio = logit_scale_a * audio_features @ all_text_features_mlp.T + a_logits_per_text = logit_scale_a * text_features_mlp @ all_audio_features.T + t_logits_per_audio = logit_scale_t * audio_features_mlp @ all_text_features.T + t_logits_per_text = logit_scale_t * text_features @ all_audio_features_mlp.T + else: + a_logits_per_audio = logit_scale_a * all_audio_features @ all_text_features_mlp.T + a_logits_per_text = a_logits_per_audio.T + t_logits_per_audio = logit_scale_t * all_audio_features_mlp @ all_text_features.T + t_logits_per_text = t_logits_per_audio.T + else: + a_logits_per_audio = logit_scale_a * audio_features @ text_features_mlp.T + a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T + t_logits_per_audio = logit_scale_t * audio_features_mlp @ text_features.T + t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T + + # calculated ground-truth and cache if enabled + num_logits = a_logits_per_audio.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + + if not self.weighted_loss: + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels) + + F.cross_entropy(a_logits_per_text, labels) + + F.cross_entropy(t_logits_per_audio, labels) + + F.cross_entropy(t_logits_per_text, labels) + ) / 4 + else: + audio_weight = (audio_features@audio_features.T).detach() + audio_weight = (torch.exp(torch.sum(audio_weight, axis=1)/(self.weight_loss_kappa*len(audio_weight)))).detach() + text_weight = (text_features@text_features.T).detach() + text_weight = 
(torch.exp(torch.sum(text_weight, axis=1)/(self.weight_loss_kappa*len(text_features)))).detach() + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight) + + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight) + + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight) + + F.cross_entropy(t_logits_per_text, labels, weight=text_weight) + ) / 4 + else: + if self.world_size > 1: + all_audio_features, all_text_features = gather_features( + audio_features=audio_features,text_features=text_features, + local_loss=self.local_loss,gather_with_grad=self.gather_with_grad, + rank=self.rank,world_size=self.world_size,use_horovod=self.use_horovod, + mlp_loss=self.mlp_loss + ) + + if self.local_loss: + logits_per_audio = logit_scale_a * audio_features @ all_text_features.T + logits_per_text = logit_scale_a * text_features @ all_audio_features.T + else: + logits_per_audio = logit_scale_a * all_audio_features @ all_text_features.T + logits_per_text = logits_per_audio.T + else: + logits_per_audio = logit_scale_a * audio_features @ text_features.T + logits_per_text = logit_scale_a * text_features @ audio_features.T + + # calculated ground-truth and cache if enabled + num_logits = logits_per_audio.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + if not self.weighted_loss: + total_loss = ( + F.cross_entropy(logits_per_audio, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + else: + audio_weight = (all_audio_features@all_audio_features.T).detach() + audio_weight = (torch.exp(torch.sum(audio_weight, axis=1)/(self.weight_loss_kappa*len(all_audio_features)))).detach() + text_weight = (all_text_features@all_text_features.T).detach() + text_weight = (torch.exp(torch.sum(text_weight, axis=1)/(self.weight_loss_kappa*len(all_text_features)))).detach() + total_loss = ( + F.cross_entropy(logits_per_audio, labels, weight=text_weight) + + F.cross_entropy(logits_per_text, labels, weight=audio_weight) + ) / 2 + return total_loss + +def lp_gather_features( + pred, + target, + world_size=1, + use_horovod=False +): + if use_horovod: + assert hvd is not None, 'Please install horovod' + with torch.no_grad(): + all_preds = hvd.allgather(pred) + all_targets = hvd.allgath(target) + else: + gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)] + gathered_targets = [torch.zeros_like(target) for _ in range(world_size)] + + dist.all_gather(gathered_preds, pred) + dist.all_gather(gathered_targets, target) + all_preds = torch.cat(gathered_preds, dim=0) + all_targets = torch.cat(gathered_targets, dim=0) + + return all_preds, all_targets + + +def get_map(pred, target): + pred = torch.sigmoid(pred).numpy() + target = target.numpy() + return np.mean(average_precision_score(target, pred, average=None)) + +def get_acc(pred, target): + pred = torch.argmax(pred,1).numpy() + target = torch.argmax(target,1).numpy() + return accuracy_score(target, pred) + +def get_mauc(pred, target): + pred = torch.sigmoid(pred).numpy() + target = target.numpy() + return np.mean(roc_auc_score(target, pred, average=None)) + + +class LPMetrics(object): + def __init__(self, metric_names = ['map','acc','mauc']): + self.metrics = [] + for name in metric_names: + 
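+            # resolve each requested name ('map', 'acc', 'mauc') to its metric function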
self.metrics.append(self.get_metric(name)) + self.metric_names = metric_names + + def get_metric(self,name): + if name == 'map': + return get_map + elif name == 'acc': + return get_acc + elif name == 'mauc': + return get_mauc + else: + raise ValueError(f'the metric should be at least one of [map, acc, mauc]') + + def evaluate_mertics(self, pred, target): + metric_dict = {} + for i in range(len(self.metric_names)): + metric_dict[self.metric_names[i]] = self.metrics[i](pred, target) + return metric_dict + + +def calc_celoss(pred, target): + target = torch.argmax(target, 1).long() + return nn.CrossEntropyLoss()(pred, target) + + +class LPLoss(nn.Module): + + def __init__(self, loss_name): + super().__init__() + if loss_name == 'bce': + self.loss_func = nn.BCEWithLogitsLoss() + elif loss_name == 'ce': + self.loss_func = calc_celoss + elif loss_name == 'mse': + self.loss_func = nn.MSELoss() + else: + raise ValueError(f'the loss func should be at least one of [bce, ce, mse]') + + def forward(self, pred, target): + loss = self.loss_func(pred, target) + return loss + \ No newline at end of file diff --git a/src/laion_clap/clap_module/model.py b/src/laion_clap/clap_module/model.py new file mode 100644 index 0000000000000000000000000000000000000000..60663ec2658cb302f093625c5ce02fc843e6a5bc --- /dev/null +++ b/src/laion_clap/clap_module/model.py @@ -0,0 +1,892 @@ +""" CLAP Model + +Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +Adapted to the Audio Task. +""" + +from collections import OrderedDict +from dataclasses import dataclass +from email.mime import audio +from typing import Tuple, Union, Callable, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from .timm_model import TimmModel +import logging +from .utils import freeze_batch_norm_2d + +from .pann_model import create_pann_model +from .htsat import create_htsat_model +from transformers import BertModel, RobertaModel, BartModel +from transformers.tokenization_utils_base import BatchEncoding + + +class MLPLayers(nn.Module): + def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1): + super(MLPLayers, self).__init__() + self.nonlin = nonlin + self.dropout = dropout + + sequence = [] + for u0, u1 in zip(units[:-1], units[1:]): + sequence.append(nn.Linear(u0, u1)) + sequence.append(self.nonlin) + sequence.append(nn.Dropout(self.dropout)) + sequence = sequence[:-2] + + self.sequential = nn.Sequential(*sequence) + + def forward(self, X): + X = self.sequential(X) + return X + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ( + "0", + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False, + ), + ), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__( + self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None + ): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5 + ) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( + 2, 0, 1 + ) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] + ), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert ( + unlocked_groups == 0 + ), "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + def stem(self, x): + for conv, bn in [ + (self.conv1, self.bn1), + (self.conv2, self.bn2), + (self.conv3, self.bn3), + ]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", act_layer()), 
+ ("c_proj", nn.Linear(d_model * 4, d_model)), + ] + ) + ) + self.ln_2 = LayerNorm(d_model) + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock(width, heads, act_layer=act_layer) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + return x + + +class VisualTransformer(nn.Module): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + act_layer: Callable = nn.GELU, + ): + super().__init__() + self.image_size = image_size + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn((image_size // patch_size) ** 2 + 1, width) + ) + self.ln_pre = LayerNorm(width) + + self.text_branch = Transformer(width, layers, heads, act_layer=act_layer) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert ( + unlocked_groups == 0 + ), "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [ + self.class_embedding.to(x.dtype) + + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device + ), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_branch(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +@dataclass +class CLAPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + timm_model_name: str = ( + None # a valid model name overrides layers, width, patch_size + ) + timm_model_pretrained: bool = ( + False # use (imagenet) pretrained weights for named model + ) + timm_pool: str = ( + "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + ) + timm_proj: str = ( + "linear" # linear projection for timm model output ('linear', 'mlp', '') + ) + + +# Audio Config Class +@dataclass +class CLAPAudioCfp: + model_type: str = "PANN" + model_name: str = "Cnn14" + sample_rate: int = 48000 + # Param + audio_length: int = 1024 + window_size: int = 1024 + hop_size: int = 1024 + fmin: int = 50 + fmax: int = 14000 + class_num: int = 527 + mel_bins: int = 64 + 
clip_samples: int = 480000 + + +@dataclass +class CLAPTextCfg: + context_length: int + vocab_size: int + width: int + heads: int + layers: int + model_type: str + + +class CLAP(nn.Module): + def __init__( + self, + embed_dim: int, + audio_cfg: CLAPAudioCfp, + text_cfg: CLAPTextCfg, + quick_gelu: bool = False, + enable_fusion: bool = False, + fusion_type: str = 'None', + joint_embed_shape: int = 512, + mlp_act: str = 'relu', + ): + super().__init__() + if isinstance(audio_cfg, dict): + audio_cfg = CLAPAudioCfp(**audio_cfg) + if isinstance(text_cfg, dict): + text_cfg = CLAPTextCfg(**text_cfg) + + self.audio_cfg = audio_cfg + self.text_cfg = text_cfg + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + self.joint_embed_shape = joint_embed_shape + self.mlp_act = mlp_act + + + self.context_length = text_cfg.context_length + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if mlp_act == 'relu': + mlp_act_layer = nn.ReLU() + elif mlp_act == 'gelu': + mlp_act_layer = nn.GELU() + else: + raise NotImplementedError + + # audio branch + # audio branch parameters + if audio_cfg.model_type == "PANN": + self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type) + elif audio_cfg.model_type == "HTSAT": + self.audio_branch = create_htsat_model(audio_cfg, enable_fusion, fusion_type) + else: + logging.error(f"Model config for {audio_cfg.model_type} not found") + raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.") + + # text branch + # text branch parameters + if text_cfg.model_type == "transformer": + self.text_branch = Transformer( + width=text_cfg.width, + layers=text_cfg.layers, + heads=text_cfg.heads, + act_layer=act_layer, + ) + self.vocab_size = text_cfg.vocab_size + self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, text_cfg.width) + ) + self.ln_final = LayerNorm(text_cfg.width) + self.text_transform = MLPLayers(units=[self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape], dropout=0.1) + self.text_projection = nn.Sequential( + nn.Linear(text_cfg.width, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape) + ) + elif text_cfg.model_type == "bert": + self.text_branch = BertModel.from_pretrained("bert-base-uncased") + self.text_transform = MLPLayers(units=[self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape], dropout=0.1) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape) + ) + elif text_cfg.model_type == "roberta": + self.text_branch = RobertaModel.from_pretrained('roberta-base') + self.text_transform = MLPLayers(units=[self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape], dropout=0.1) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape) + ) + elif text_cfg.model_type == "bart": + self.text_branch = BartModel.from_pretrained('facebook/bart-base') + self.text_transform = MLPLayers(units=[self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape], dropout=0.1) + self.text_projection 
= nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape) + ) + else: + logging.error(f"Model config for {text_cfg.model_type} not found") + raise RuntimeError(f"Model config for {text_cfg.model_type} not found.") + self.text_branch_type = text_cfg.model_type + # text branch parameters + + # audio branch parameters + self.audio_transform = MLPLayers(units=[self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape], dropout=0.1) + + # below here is text branch parameters + + # ============================================================================================================ + self.audio_projection = nn.Sequential( + nn.Linear(embed_dim, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape) + ) + + self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False) + + self.init_text_branch_parameters() + + def init_text_branch_parameters(self): + if self.text_branch_type == "transformer": + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + proj_std = (self.text_branch.width**-0.5) * ( + (2 * self.text_branch.layers) ** -0.5 + ) + attn_std = self.text_branch.width**-0.5 + fc_std = (2 * self.text_branch.width) ** -0.5 + for block in self.text_branch.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + if self.text_branch_type == "bert" or self.text_branch_type == "roberta": + width = self.text_branch.embeddings.word_embeddings.weight.shape[-1] + elif self.text_branch_type == "bart": + width = self.text_branch.shared.weight.shape[-1] + else: + width = self.text_branch.width + nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07)) + nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07)) + + # deprecated + # if hasattr(self.visual, 'init_parameters'): + # self.visual.init_parameters() + + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def encode_audio(self, audio, device): + return self.audio_branch(audio, mixup_lambda=None, device=device) # mix lambda needs to add + + # def list_of_dict_of_tensor2dict_of_tensor(self, x, device): + # tmp = {} + # for k in x[0].keys(): + # tmp[k] = [] + # for i in range(len(x)): + # tmp[k].append(x[i][k][:77]) + # for k in x[0].keys(): + # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True) + # return tmp + + def encode_text(self, text, device): + if self.text_branch_type == "transformer": + text = text.to(device=device, non_blocking=True) + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_branch(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, 
n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)]) + elif self.text_branch_type == "bert": + # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device) + # text = BatchEncoding(text) + x = self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + token_type_ids=text["token_type_ids"].to( + device=device, non_blocking=True + ), + )["pooler_output"] + x = self.text_projection(x) + elif self.text_branch_type == "roberta": + x = self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + )["pooler_output"] + x = self.text_projection(x) + elif self.text_branch_type == "bart": + x = torch.mean(self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + )["encoder_last_hidden_state"],axis=1) + x = self.text_projection(x) + else: + logging.error(f"Model type {self.text_branch_type} not found") + raise RuntimeError(f"Model type {self.text_branch_type} not found.") + return x + + def forward(self, audio, text, device=None): + """Forward audio and text into the CLAP + + Parameters + ---------- + audio: torch.Tensor (batch_size, audio_length) + the time-domain audio input / the batch of mel_spec and longer list. + text: torch.Tensor () // need to add + the text token input + """ + if device is None: + if audio is not None: + device = audio.device + elif text is not None: + device = text.device + if audio is None and text is None: + # a hack to get the logit scale + return self.logit_scale_a.exp(), self.logit_scale_t.exp() + elif audio is None: + return self.encode_text(text, device=device) + elif text is None: + return self.audio_projection(self.encode_audio(audio, device=device)["embedding"]) + audio_features = self.audio_projection(self.encode_audio(audio, device=device)["embedding"]) + audio_features = F.normalize(audio_features, dim=-1) + + text_features = self.encode_text( + text, device=device + ) + # print("text_features", text_features) + # print("text_features.shape", text_features.shape) + # print("text_features.type", type(text_features)) + text_features = F.normalize(text_features, dim=-1) + + audio_features_mlp = self.audio_transform(audio_features) + text_features_mlp = self.text_transform(text_features) + # Four outputs: audio features (basic & MLP), text features (basic & MLP) + return ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + self.logit_scale_a.exp(), + self.logit_scale_t.exp(), + ) + + def get_logit_scale(self): + return self.logit_scale_a.exp(), self.logit_scale_t.exp() + + def get_text_embedding(self, data): + """Get the text embedding from the model + + Parameters + ---------- + data: torch.Tensor + a tensor of text embedding + + Returns + ---------- + text_embed: torch.Tensor + a tensor of text_embeds (N, D) + + """ + device = next(self.parameters()).device + for k in data: + data[k] = data[k].to(device) + text_embeds = self.encode_text(data, device=device) + text_embeds = F.normalize(text_embeds, dim=-1) + + return text_embeds + + def get_audio_embedding(self, data): + """Get the audio embedding from the model + + Parameters + ---------- + data: a list of dict 
+ the audio input dict list from 'get_audio_feature' method + + Returns + ---------- + audio_embed: torch.Tensor + a tensor of audio_embeds (N, D) + + """ + device = next(self.parameters()).device + input_dict = {} + keys = data[0].keys() + for k in keys: + input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(device) + audio_embeds = self.encode_audio(input_dict, device=device)["embedding"] + audio_embeds = self.audio_projection(audio_embeds) + audio_embeds = F.normalize(audio_embeds, dim=-1) + return audio_embeds + + + + def audio_infer(self, audio, hopsize=None, device=None): + """Forward one audio and produce the audio embedding + + Parameters + ---------- + audio: (audio_length) + the time-domain audio input, notice that it must be only one input + hopsize: int + the overlap hopsize as the sliding window + + Returns + ---------- + output_dict: { + key: [n, (embedding_shape)] if "HTS-AT" + or + key: [(embedding_shape)] if "PANN" + } + the list of key values of the audio branch + + """ + + assert not self.training, "the inference mode must be run at eval stage" + output_dict = {} + # PANN + if self.audio_cfg.model_type == "PANN": + audio_input = audio.unsqueeze(dim=0) + output_dict[key] = self.encode_audio(audio_input, device=device)[key].squeeze(dim=0) + elif self.audio_cfg.model_type == "HTSAT": + # repeat + audio_len = len(audio) + k = self.audio_cfg.clip_samples // audio_len + if k > 1: + audio = audio.repeat(k) + audio_len = len(audio) + + if hopsize is None: + hopsize = min(hopsize, audio_len) + + if audio_len > self.audio_cfg.clip_samples: + audio_input = [ + audio[pos : pos + self.audio_cfg.clip_samples].clone() + for pos in range( + 0, audio_len - self.audio_cfg.clip_samples, hopsize + ) + ] + audio_input.append(audio[-self.audio_cfg.clip_samples :].clone()) + audio_input = torch.stack(audio_input) + output_dict[key] = self.encode_audio(audio_input, device=device)[key] + else: + audio_input = audio.unsqueeze(dim=0) + output_dict[key] = self.encode_audio(audio_input, device=device)[key].squeeze(dim=0) + + return output_dict + + +def convert_weights_to_fp16(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [ + *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], + "in_proj_bias", + "bias_k", + "bias_v", + ]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +# Ignore the state dict of the vision part +def build_model_from_openai_state_dict(state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = 'None'): + + embed_dim = model_cfg["embed_dim"] + audio_cfg = model_cfg["audio_cfg"] + text_cfg = model_cfg["text_cfg"] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split(".")[2] + for k in state_dict + if k.startswith(f"transformer.resblocks") + ) + ) + + audio_cfg = CLAPAudioCfp(**audio_cfg) + text_cfg = CLAPTextCfg(**text_cfg) + + 
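+    # transformer_heads above follows the CLIP convention of one attention head per 64 channels.
+    # Below, the shared "logit_scale" is duplicated into separate audio/text scales, all "visual.*"
+    # keys are dropped (only the text tower of the OpenAI checkpoint is reused), and strict=False
+    # lets the freshly initialized audio branch keep its own weights.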
model = CLAP( + embed_dim, + audio_cfg=audio_cfg, + text_cfg=text_cfg, + quick_gelu=True, # OpenAI models were trained with QuickGELU + enable_fusion=enable_fusion, + fusion_type=fusion_type + ) + state_dict["logit_scale_a"] = state_dict["logit_scale"] + state_dict["logit_scale_t"] = state_dict["logit_scale"] + pop_keys = list(state_dict.keys())[::] + # pop the visual branch saved weights + for key in pop_keys: + if key.startswith("visual."): + state_dict.pop(key, None) + + for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + # not use fp16 + # convert_weights_to_fp16(model) + model.load_state_dict(state_dict, strict=False) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device("cpu")): + model.eval() + audio_length = model.audio_cfg.audio_length + example_audio = torch.ones((batch_size, audio_length), device=device) + example_text = torch.zeros( + (batch_size, model.context_length), dtype=torch.int, device=device + ) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_audio, example_text), + encode_text=(example_text,), + encode_image=(example_audio,), + ), + ) + model.audio_cfg.audio_length = audio_length # Question: what does this do? + return model diff --git a/src/laion_clap/clap_module/model_configs/HTSAT-base.json b/src/laion_clap/clap_module/model_configs/HTSAT-base.json new file mode 100644 index 0000000000000000000000000000000000000000..6cef625a89daf4431f1c9f72e10bc9640eef2ba8 --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/HTSAT-base.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 1024, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "base" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/HTSAT-large.json b/src/laion_clap/clap_module/model_configs/HTSAT-large.json new file mode 100644 index 0000000000000000000000000000000000000000..699cdb1b16855582606551e4196b24aba2ffd871 --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/HTSAT-large.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "large" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/HTSAT-tiny-win-1536.json b/src/laion_clap/clap_module/model_configs/HTSAT-tiny-win-1536.json new file mode 100644 index 0000000000000000000000000000000000000000..73e42990fe8361a0df502e7f93d29f19f58c9ecb --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/HTSAT-tiny-win-1536.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 768, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1536, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "tiny" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ 
No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/HTSAT-tiny.json b/src/laion_clap/clap_module/model_configs/HTSAT-tiny.json new file mode 100644 index 0000000000000000000000000000000000000000..a6e7821163d9afa81c27345a1e472475b92af169 --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/HTSAT-tiny.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 768, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "tiny" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/PANN-10.json b/src/laion_clap/clap_module/model_configs/PANN-10.json new file mode 100644 index 0000000000000000000000000000000000000000..954ddf62921aed7dde9c37ffffec98a2e96a4ee7 --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/PANN-10.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 1024, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn10" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/PANN-14-fmax-18k.json b/src/laion_clap/clap_module/model_configs/PANN-14-fmax-18k.json new file mode 100644 index 0000000000000000000000000000000000000000..b7989bc0cd95d0d39049b7524eba508b3e386439 --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/PANN-14-fmax-18k.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 18000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/PANN-14-fmax-8k-20s.json b/src/laion_clap/clap_module/model_configs/PANN-14-fmax-8k-20s.json new file mode 100644 index 0000000000000000000000000000000000000000..56bdb56bedc304ffa52d8bf5988cea2c1d82d14e --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/PANN-14-fmax-8k-20s.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 960000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 360, + "fmin": 50, + "fmax": 8000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/PANN-14-tiny-transformer.json b/src/laion_clap/clap_module/model_configs/PANN-14-tiny-transformer.json new file mode 100644 index 0000000000000000000000000000000000000000..5756e3bebc97cc985f512cb081930fee4e49bec1 --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/PANN-14-tiny-transformer.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + 
"mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 4 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/PANN-14-win-1536.json b/src/laion_clap/clap_module/model_configs/PANN-14-win-1536.json new file mode 100644 index 0000000000000000000000000000000000000000..5a9e7e208b661619d5e26625e849da1adda8a475 --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/PANN-14-win-1536.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1536, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/PANN-14.json b/src/laion_clap/clap_module/model_configs/PANN-14.json new file mode 100644 index 0000000000000000000000000000000000000000..39a5134cde1d8c50f4758377c952ef22f07bab41 --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/PANN-14.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/PANN-6.json b/src/laion_clap/clap_module/model_configs/PANN-6.json new file mode 100644 index 0000000000000000000000000000000000000000..21ebc344326de260c386ba77e0ad63cf9b04febf --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/PANN-6.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 512, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn6" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/RN101-quickgelu.json b/src/laion_clap/clap_module/model_configs/RN101-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/RN101.json b/src/laion_clap/clap_module/model_configs/RN101.json new file mode 100644 index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32 --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/RN101.json @@ -0,0 
+1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/RN50-quickgelu.json b/src/laion_clap/clap_module/model_configs/RN50-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/src/laion_clap/clap_module/model_configs/RN50.json b/src/laion_clap/clap_module/model_configs/RN50.json new file mode 100644 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/RN50x16.json b/src/laion_clap/clap_module/model_configs/RN50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/RN50x16.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/RN50x4.json b/src/laion_clap/clap_module/model_configs/RN50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/ViT-B-16.json b/src/laion_clap/clap_module/model_configs/ViT-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/ViT-B-32-quickgelu.json b/src/laion_clap/clap_module/model_configs/ViT-B-32-quickgelu.json new file mode 100644 index 
0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5 --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/ViT-B-32.json b/src/laion_clap/clap_module/model_configs/ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/model_configs/ViT-L-14.json b/src/laion_clap/clap_module/model_configs/ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/src/laion_clap/clap_module/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/src/laion_clap/clap_module/openai.py b/src/laion_clap/clap_module/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..9911b6e135e51970177fcac067c12192b0b57c1c --- /dev/null +++ b/src/laion_clap/clap_module/openai.py @@ -0,0 +1,129 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import os +import warnings +from typing import Union, List + +import torch + +from .model import build_model_from_openai_state_dict +from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained + +__all__ = ["list_openai_models", "load_openai_model"] + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_tag_models('openai') + + +def load_openai_model( + name: str, + model_cfg, + device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", + jit=True, + cache_dir=os.path.expanduser("~/.cache/clip"), + enable_fusion: bool = False, + fusion_type: str = 'None' +): + """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
+ + Returns + ------- + model : torch.nn.Module + The CLAP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if get_pretrained_url(name, 'openai'): + model_path = download_pretrained(get_pretrained_url(name, 'openai'), root=cache_dir) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + try: + model = build_model_from_openai_state_dict(state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type).to(device) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd, model_cfg, enable_fusion, fusion_type).to(device) + + if str(device) == "cpu": + model.float() + return model + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_audio) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_audio) + patch_float(model.encode_text) + model.float() + + model.audio_branch.audio_length = model.audio_cfg.audio_length + return model diff --git a/src/laion_clap/clap_module/pann_model.py b/src/laion_clap/clap_module/pann_model.py new file mode 100644 index 0000000000000000000000000000000000000000..109db5f418a0bad32cae2452742589ff52a19b85 --- /dev/null +++ b/src/laion_clap/clap_module/pann_model.py @@ -0,0 +1,543 @@ +# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition +# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn +# Some layers are re-designed for CLAP +import os +os.environ['NUMBA_CACHE_DIR'] = '/tmp/' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from 
torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +from .utils import do_mixup, interpolate, pad_framewise_output +from .feature_fusion import iAFF, AFF, DAF + + +def init_layer(layer): + """Initialize a Linear or Convolutional layer. """ + nn.init.xavier_uniform_(layer.weight) + + if hasattr(layer, 'bias'): + if layer.bias is not None: + layer.bias.data.fill_(0.) + + +def init_bn(bn): + """Initialize a Batchnorm layer. """ + bn.bias.data.fill_(0.) + bn.weight.data.fill_(1.) + + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + + super(ConvBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), stride=(1, 1), + padding=(1, 1), bias=False) + + self.conv2 = nn.Conv2d(in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), stride=(1, 1), + padding=(1, 1), bias=False) + + self.bn1 = nn.BatchNorm2d(out_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_layer(self.conv2) + init_bn(self.bn1) + init_bn(self.bn2) + + + def forward(self, input, pool_size=(2, 2), pool_type='avg'): + + x = input + x = F.relu_(self.bn1(self.conv1(x))) + x = F.relu_(self.bn2(self.conv2(x))) + if pool_type == 'max': + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == 'avg': + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == 'avg+max': + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception('Incorrect argument!') + + return x + + +class ConvBlock5x5(nn.Module): + def __init__(self, in_channels, out_channels): + + super(ConvBlock5x5, self).__init__() + + self.conv1 = nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(5, 5), stride=(1, 1), + padding=(2, 2), bias=False) + + self.bn1 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_bn(self.bn1) + + + def forward(self, input, pool_size=(2, 2), pool_type='avg'): + + x = input + x = F.relu_(self.bn1(self.conv1(x))) + if pool_type == 'max': + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == 'avg': + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == 'avg+max': + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception('Incorrect argument!') + + return x + + +class AttBlock(nn.Module): + def __init__(self, n_in, n_out, activation='linear', temperature=1.): + super(AttBlock, self).__init__() + + self.activation = activation + self.temperature = temperature + self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) + self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) + + self.bn_att = nn.BatchNorm1d(n_out) + self.init_weights() + + def init_weights(self): + init_layer(self.att) + init_layer(self.cla) + init_bn(self.bn_att) + + def forward(self, x): + # x: (n_samples, n_in, n_time) + norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) + cla = self.nonlinear_transform(self.cla(x)) + x = torch.sum(norm_att * cla, dim=2) + return x, norm_att, cla + + def nonlinear_transform(self, x): + if self.activation == 'linear': + return x + elif self.activation == 'sigmoid': + return torch.sigmoid(x) + + +class Cnn14(nn.Module): 
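+    # PANN Cnn14 backbone reused as the CLAP audio branch: Spectrogram + LogmelFilterBank
+    # front end, SpecAugmentation during training, six ConvBlocks growing from 64 to 2048
+    # channels, and an fc1 projection whose output is returned as the clip-level "embedding";
+    # the optional 1D/2D fusion branches (DAF/AFF/iAFF) merge extra mel slices for inputs
+    # marked as "longer" than a single clip when enable_fusion is set.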
+ def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, + fmax, classes_num, enable_fusion=False, fusion_type='None'): + + super(Cnn14, self).__init__() + + window = 'hann' + center = True + pad_mode = 'reflect' + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, + win_length=window_size, window=window, center=center, pad_mode=pad_mode, + freeze_parameters=True) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, + n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, + freeze_parameters=True) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, + freq_drop_width=8, freq_stripes_num=2) + + self.bn0 = nn.BatchNorm2d(64) + + if (self.enable_fusion) and (self.fusion_type == 'channel_map'): + self.conv_block1 = ConvBlock(in_channels=4, out_channels=64) + else: + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + if (self.enable_fusion) and (self.fusion_type in ['daf_1d','aff_1d','iaff_1d']): + self.mel_conv1d = nn.Sequential( + nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), + nn.BatchNorm1d(64) # No Relu + ) + if self.fusion_type == 'daf_1d': + self.fusion_model = DAF() + elif self.fusion_type == 'aff_1d': + self.fusion_model = AFF(channels=64, type='1D') + elif self.fusion_type == 'iaff_1d': + self.fusion_model = iAFF(channels=64, type='1D') + + if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']): + self.mel_conv2d = nn.Sequential( + nn.Conv2d(1, 64, kernel_size=(5,5), stride=(6, 2), padding=(2,2)), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True) + ) + + if self.fusion_type == 'daf_2d': + self.fusion_model = DAF() + elif self.fusion_type == 'aff_2d': + self.fusion_model = AFF(channels=64, type='2D') + elif self.fusion_type == 'iaff_2d': + self.fusion_model = iAFF(channels=64, type='2D') + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + if self.enable_fusion and input["longer"].sum() == 0: + # if no audio is longer than 10s, then randomly select one audio to be longer + input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True + + if not self.enable_fusion: + x = self.spectrogram_extractor(input['waveform'].to(device=device, non_blocking=True)) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + else: + longer_list = input["longer"].to(device=device, non_blocking=True) + x = input["mel_fusion"].to(device=device, non_blocking=True) + longer_list_idx = torch.where(longer_list)[0] + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + if self.fusion_type in 
['daf_1d','aff_1d','iaff_1d']: + new_x = x[:,0:1,:,:].clone().contiguous() + # local processing + if len(longer_list_idx) > 0: + fusion_x_local = x[longer_list_idx,1:,:,:].clone().contiguous() + FB,FC,FT,FF = fusion_x_local.size() + fusion_x_local = fusion_x_local.view(FB * FC, FT, FF) + fusion_x_local = torch.permute(fusion_x_local, (0,2,1)).contiguous() + fusion_x_local = self.mel_conv1d(fusion_x_local) + fusion_x_local = fusion_x_local.view(FB,FC,FF,fusion_x_local.size(-1)) + fusion_x_local = torch.permute(fusion_x_local, (0,2,1,3)).contiguous().flatten(2) + if fusion_x_local.size(-1) < FT: + fusion_x_local = torch.cat([fusion_x_local, torch.zeros((FB,FF,FT- fusion_x_local.size(-1)), device=device)], dim=-1) + else: + fusion_x_local = fusion_x_local[:,:,:FT] + # 1D fusion + new_x = new_x.squeeze(1).permute((0,2,1)).contiguous() + new_x[longer_list_idx] = self.fusion_model(new_x[longer_list_idx], fusion_x_local) + x = new_x.permute((0,2,1)).contiguous()[:,None,:,:] + else: + x = new_x + elif self.fusion_type in ['daf_2d','aff_2d','iaff_2d','channel_map']: + x = x # no change + + if self.training: + x = self.spec_augmenter(x) + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + if (self.enable_fusion) and (self.fusion_type in ['daf_2d','aff_2d','iaff_2d']): + global_x = x[:,0:1,:,:] + + # global processing + B, C, H, W = global_x.shape + global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type='avg') + if len(longer_list_idx) > 0: + local_x = x[longer_list_idx,1:,:,:].contiguous() + TH = global_x.size(-2) + # local processing + B, C, H, W = local_x.shape + local_x = local_x.view(B*C,1,H,W) + local_x = self.mel_conv2d(local_x) + local_x = local_x.view(B,C,local_x.size(1),local_x.size(2),local_x.size(3)) + local_x = local_x.permute((0,2,1,3,4)).contiguous().flatten(2,3) + TB,TC,_,TW = local_x.size() + if local_x.size(-2) < TH: + local_x = torch.cat([local_x, torch.zeros((TB,TC,TH-local_x.size(-2),TW), device=global_x.device)], dim=-2) + else: + local_x = local_x[:,:,:TH,:] + + global_x[longer_list_idx] = self.fusion_model(global_x[longer_list_idx],local_x) + x = global_x + else: + x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') + + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 32) + + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding, 'fine_grained_embedding': latent_output} 
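+        # clipwise_output: sigmoid class probabilities from fc_audioset,
+        # embedding: 2048-d clip-level feature taken after fc1 (with dropout),
+        # fine_grained_embedding: frame-level features upsampled by interpolate()
+        # (ratio 32) to compensate for the temporal pooling of the conv blocks.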
+ return output_dict + + +class Cnn6(nn.Module): + def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, + fmax, classes_num, enable_fusion=False, fusion_type='None'): + + super(Cnn6, self).__init__() + + window = 'hann' + center = True + pad_mode = 'reflect' + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, + win_length=window_size, window=window, center=center, pad_mode=pad_mode, + freeze_parameters=True) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, + n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, + freeze_parameters=True) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, + freq_drop_width=8, freq_stripes_num=2) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512) + + self.fc1 = nn.Linear(512, 512, bias=True) + self.fc_audioset = nn.Linear(512, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 16) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding, 'fine_grained_embedding': latent_output} + + return output_dict + + +class Cnn10(nn.Module): + def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, + fmax, classes_num, enable_fusion=False, fusion_type='None'): + + super(Cnn10, self).__init__() + + window = 'hann' + center = True + pad_mode = 'reflect' + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram 
extractor + self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, + win_length=window_size, window=window, center=center, pad_mode=pad_mode, + freeze_parameters=True) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, + n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, + freeze_parameters=True) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, + freq_drop_width=8, freq_stripes_num=2) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + + self.fc1 = nn.Linear(1024, 1024, bias=True) + self.fc_audioset = nn.Linear(1024, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 32) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = {'clipwise_output': clipwise_output, 'embedding': embedding, 'fine_grained_embedding': latent_output} + + return output_dict + + +def create_pann_model(audio_cfg, enable_fusion=False, fusion_type='None'): + try: + ModelProto = eval(audio_cfg.model_name) + model = ModelProto( + sample_rate = audio_cfg.sample_rate, + window_size = audio_cfg.window_size, + hop_size =audio_cfg.hop_size, + mel_bins = audio_cfg.mel_bins, + fmin = audio_cfg.fmin, + fmax = audio_cfg.fmax, + classes_num = audio_cfg.class_num, + enable_fusion = enable_fusion, + fusion_type = fusion_type + ) + return model + except: + raise RuntimeError(f'Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough.') + diff --git 
a/src/laion_clap/clap_module/pretrained.py b/src/laion_clap/clap_module/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..723619a9fd511cf8619def49c4631ec701891b93 --- /dev/null +++ b/src/laion_clap/clap_module/pretrained.py @@ -0,0 +1,147 @@ +import hashlib +import os +import urllib +import warnings + +from tqdm import tqdm + +_RN50 = dict( + openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" +) + +_RN50_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" +) + +_RN101 = dict( + openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" +) + +_RN101_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" +) + +_RN50x4 = dict( + openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", +) + +_RN50x16 = dict( + openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", +) + +_RN50x64 = dict( + openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", +) + +_VITB32 = dict( + openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", +) + +_VITB32_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", +) + +_VITB16 = dict( + openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", +) + +_VITL14 = dict( + 
openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", +) + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "ViT-B-32": _VITB32, + "ViT-B-32-quickgelu": _VITB32_quickgelu, + "ViT-B-16": _VITB16, + "ViT-L-14": _VITL14, +} + + +def list_pretrained(as_str: bool = False): + """ returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_tag_models(tag: str): + """ return all models having the specified pretrain tag """ + models = [] + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_model_tags(model: str): + """ return all pretrain tags for the specified model architecture """ + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def get_pretrained_url(model: str, tag: str): + if model not in _PRETRAINED: + return '' + model_pretrained = _PRETRAINED[model] + if tag not in model_pretrained: + return '' + return model_pretrained[tag] + + +def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + else: + expected_sha256 = '' + + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target diff --git a/src/laion_clap/clap_module/timm_model.py b/src/laion_clap/clap_module/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..071dd148c772f398e87ecbfc836dcfa4a3ae01af --- /dev/null +++ b/src/laion_clap/clap_module/timm_model.py @@ -0,0 +1,106 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
+""" +from collections import OrderedDict + +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d +except ImportError as e: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """ timm model adapter + # FIXME this adapter is a work in progress, may change in ways that break weight compat + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool='avg', + proj='linear', + drop=0., + pretrained=False): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + + self.image_size = to_2tuple(image_size) + self.trunk = timm.create_model(model_name, pretrained=pretrained) + feat_size = self.trunk.default_cfg.get('pool_size', None) + feature_ndim = 1 if not feat_size else 2 + if pool in ('abs_attn', 'rot_attn'): + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool='') + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + if pool == 'abs_attn': + head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) + prev_chs = embed_dim + elif pool == 'rot_attn': + head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + else: + assert proj, 'projection layer needed if non-attention pooling is used.' 
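+            # with plain (non-attention) pooling the trunk returns a flat feature
+            # vector, so the linear or MLP projection below is needed to reach embed_dim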
+ + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == 'linear': + head_layers['drop'] = nn.Dropout(drop) + head_layers['proj'] = nn.Linear(prev_chs, embed_dim) + elif proj == 'mlp': + head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """ lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/src/laion_clap/clap_module/tokenizer.py b/src/laion_clap/clap_module/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5b4a238b987ce66f2932b11451d916e40816b8a3 --- /dev/null +++ b/src/laion_clap/clap_module/tokenizer.py @@ -0,0 +1,180 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). 
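+    Example: get_pairs(('h', 'e', 'l', 'l', 'o')) == {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}.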
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + if not special_tokens: + special_tokens = ['', ''] + else: + special_tokens = ['', ''] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t:t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + +_tokenizer = SimpleTokenizer() + + +def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = 
[[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + result[i, :len(tokens)] = torch.tensor(tokens) + + return result diff --git a/src/laion_clap/clap_module/transform.py b/src/laion_clap/clap_module/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..7014c926f153a351d2256c869c67c02d57b30913 --- /dev/null +++ b/src/laion_clap/clap_module/transform.py @@ -0,0 +1,30 @@ +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ + CenterCrop + + +def _convert_to_rgb(image): + return image.convert('RGB') + + +def image_transform( + image_size: int, + is_train: bool, + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711) +): + normalize = Normalize(mean=mean, std=std) + if is_train: + return Compose([ + RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), + _convert_to_rgb, + ToTensor(), + normalize, + ]) + else: + return Compose([ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + _convert_to_rgb, + ToTensor(), + normalize, + ]) diff --git a/src/laion_clap/clap_module/utils.py b/src/laion_clap/clap_module/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ee10b9310944e9de9f8e1db1a4104defd0423744 --- /dev/null +++ b/src/laion_clap/clap_module/utils.py @@ -0,0 +1,389 @@ +import numpy as np +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d +import logging +import h5py +from tqdm import tqdm +import random +import json +import os +import pathlib + +# TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later. 
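+# Maps each dataset name to the webdataset splits available for it; consumed by
+# exist() and get_tar_path_from_dataset_name() below when building local or S3 tar paths.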
+dataset_split = { + "audiocaps": ["train", "valid", "test"], + "audioset": ["balanced_train", "unbalanced_train", "eval"], + "BBCSoundEffects": ["train", "test"], + "Clotho": ["train", "test", "valid"], + "free_to_use_sounds": ["train", "test"], + "paramount_motion": ["train", "test"], + "sonniss_game_effects": ["train", "test"], + "wesoundeffects": ["train", "test"], + "MACS": ["train", "test"], + "freesound": ["train", "test"], + "FSD50K": ["train", "test", "valid"], + "fsd50k_class_label": ["train", "test", "valid"], + "esc50": ["train", "test"], + "ESC50_1": ["train", "test"], + "ESC50_2": ["train", "test"], + "ESC50_3": ["train", "test"], + "ESC50_4": ["train", "test"], + "ESC50_5": ["train", "test"], + "audiostock": ["train", "test"], + "freesound_no_overlap_noesc50": ["train", "test"], + "epidemic_sound_effects": ["train", "test"], + "VGGSound": ["train", "test"], + "urbansound8k_class_label": ["train", "test"], + "audioset_t5": ["balanced_train", "unbalanced_train", "eval"], + "audioset_t5_debiased": ["balanced_train", "unbalanced_train", "eval"], + "epidemic_sound_effects_t5": ["train", "test"], + "epidemic_sound_effects_t5_debiased": ["train", "test"], + "WavText5K": ["train", "test"], + "esc50_no_overlap": ["train", "test"], + "usd8k_no_overlap": ["train", "test"], + "fsd50k_200_class_label": ["train", "test", "valid"], + "fma_full": ["train", "test"], + "Genius": ["train", "test"], + "Jamendo": ["train", "test"], + "juno": ["train", "test"], + "CMU_Arctic": ["train", "test"], + "ravdess": ["train", "test"], + "Europarl-st": ["train", "test"], + "common_voice": ["train", "test"], + "Jamendo_16bit": ["train", "test"], + "genius_16bit_128": ["train", "test"], + "juno_16bit": ["train", "test"], + "fma_full_16bit_128": ["train", "test"], + "GTZAN": ["train", "test"], + } + + +def freeze_batch_norm_2d(module, module_match={}, name=""): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. 
+ module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance( + module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) + ): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = ".".join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +def exist(dataset_name, dataset_type): + """ + Check if dataset exists + """ + if dataset_type in dataset_split[dataset_name]: + return True + else: + return False + + +def get_tar_path_from_dataset_name( + dataset_names, + dataset_types, + islocal, + dataset_path, + proportion=1, + full_dataset=None +): + """ + Get tar path from dataset name and type + """ + output = [] + for n in dataset_names: + if full_dataset is not None and n in full_dataset: + current_dataset_types = dataset_split[n] + else: + current_dataset_types = dataset_types + for s in current_dataset_types: + tmp = [] + if islocal: + sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json" + if not os.path.exists(sizefilepath_): + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + else: + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + if not os.path.exists(sizefilepath_): + continue + sizes = json.load(open(sizefilepath_, "r")) + for k in sizes.keys(): + if islocal: + tmp.append(f"{dataset_path}/{n}/{s}/{k}") + else: + tmp.append( + f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -" + ) + if proportion != 1: + tmp = random.sample(tmp, int(proportion * len(tmp))) + output.append(tmp) + return sum(output, []) + + +def get_tar_path_from_txts(txt_path, islocal, proportion=1): + """ + Get tar path from txt path + """ + if isinstance(txt_path, (list, tuple)): + return sum( + [ + get_tar_path_from_txts( + txt_path[i], islocal=islocal, proportion=proportion + ) + for i in range(len(txt_path)) + ], + [], + ) + if isinstance(txt_path, str): + with open(txt_path) as f: + lines = f.readlines() + if islocal: + lines = [ + lines[i] + .split("\n")[0] + .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/") + for i in range(len(lines)) + ] + else: + lines = [ + lines[i].split("\n")[0].replace(".tar", ".tar -") + for i in range(len(lines)) + ] + if proportion != 1: + print("Sampling tars with proportion of {}".format(proportion)) + lines = random.sample(lines, int(proportion * len(lines))) + return lines + + +def get_mix_lambda(mixup_alpha, batch_size): + mixup_lambdas = [ + np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size) + ] + return np.array(mixup_lambdas).astype(np.float32) + + +def do_mixup(x, mixup_lambda): + """ + Args: + x: (batch_size , ...) + mixup_lambda: (batch_size,) + Returns: + out: (batch_size, ...) 
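+        Each sample is mixed with its mirror in the batch:
+        out[i] = mixup_lambda[i] * x[i] + (1 - mixup_lambda[i]) * x[batch_size - 1 - i]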
+ """ + out = ( + x.transpose(0, -1) * mixup_lambda + + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda) + ).transpose(0, -1) + return out + + +def interpolate(x, ratio): + """Interpolate data in time domain. This is used to compensate the + resolution reduction in downsampling of a CNN. + + Args: + x: (batch_size, time_steps, classes_num) + ratio: int, ratio to interpolate + Returns: + upsampled: (batch_size, time_steps * ratio, classes_num) + """ + (batch_size, time_steps, classes_num) = x.shape + upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) + upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) + return upsampled + + +def pad_framewise_output(framewise_output, frames_num): + """Pad framewise_output to the same length as input frames. The pad value + is the same as the value of the last frame. + Args: + framewise_output: (batch_size, frames_num, classes_num) + frames_num: int, number of frames to pad + Outputs: + output: (batch_size, frames_num, classes_num) + """ + pad = framewise_output[:, -1:, :].repeat( + 1, frames_num - framewise_output.shape[1], 1 + ) + """tensor for padding""" + + output = torch.cat((framewise_output, pad), dim=1) + """(batch_size, frames_num, classes_num)""" + + +def process_ipc(index_path, classes_num, filename): + # load data + logging.info("Load Data...............") + ipc = [[] for _ in range(classes_num)] + with h5py.File(index_path, "r") as f: + for i in tqdm(range(len(f["target"]))): + t_class = np.where(f["target"][i])[0] + for t in t_class: + ipc[t].append(i) + print(ipc) + np.save(filename, ipc) + logging.info("Load Data Succeed...............") + + +def save_to_dict(s, o_={}): + sp = s.split(": ") + o_.update({sp[0]: float(sp[1])}) + return o_ + + +def get_data_from_log(txt_path): + """ + Output dictionary from out.txt log file + """ + with open(txt_path) as f: + lines = f.readlines() + val_data = {} + train_data = {} + train_losses = [] + train_losses_epoch = [] + for i in range(len(lines)): + if "| INFO |" in lines[i]: + if "Eval Epoch" in lines[i]: + if "val_loss" in lines[i]: + # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", "")) + line = lines[i].split("Eval Epoch: ")[-1] + num_epoch = int(line.split(" ")[0].split(" ")[0]) + d = { + line.split(" ")[0] + .split(" ")[1] + .replace(":", ""): float(line.split(" ")[0].split(" ")[-1]) + } + for i in range(1, len(line.split(" "))): + d = save_to_dict(line.split(" ")[i], d) + val_data[num_epoch] = d + elif "Train Epoch" in lines[i]: + num_epoch = int(lines[i].split("Train Epoch: ")[1][0]) + loss = float(lines[i].split("Loss: ")[-1].split(" (")[0]) + train_losses.append(loss) + train_losses_epoch.append(num_epoch) + for i in range(len(train_losses)): + train_data[i] = { + "num_epoch": train_losses_epoch[i], + "train_loss": train_losses[i], + } + return train_data, val_data + + +def save_p(obj, filename): + import pickle + + try: + from deepdiff import DeepDiff + except: + os.system("pip install deepdiff") + from deepdiff import DeepDiff + with open(filename, "wb") as file: + pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol + with open(filename, "rb") as file: + z = pickle.load(file) + assert ( + DeepDiff(obj, z, ignore_string_case=True) == {} + ), "there is something wrong with the saving process" + return + + +def load_p(filename): + import pickle + + with open(filename, "rb") as file: + z = pickle.load(file) + return z + + +def save_json(data, name="data.json"): + import json + with open(name, 'w') as fp: + json.dump(data, 
fp) + return + + +def load_json(name): + import json + with open(name, 'r') as fp: + data = json.load(fp) + return data + + +from multiprocessing import Process, Manager +from multiprocessing import Process, Value, Array +from ctypes import c_wchar + + +def load_class_label(path): + # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing + # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array + out = None + if path is not None: + if pathlib.Path(path).suffix in [".pkl", ".pickle"]: + out = load_p(path) + elif pathlib.Path(path).suffix in [".json", ".txt"]: + out = load_json(path) + elif pathlib.Path(path).suffix in [".npy", ".npz"]: + out = np.load(path) + elif pathlib.Path(path).suffix in [".csv"]: + import pandas as pd + out = pd.read_csv(path) + return out + # if out is None: + # return None + # else: + # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False) + # val = Array('i', out.values(), lock=False) + # return (key, val) + + +from torch import optim + + +def get_optimizer(params, lr, betas, eps, momentum, optimizer_name): + if optimizer_name.lower() == "adamw": + optimizer = optim.AdamW( + params, lr=lr, betas=betas, eps=eps + ) + elif optimizer_name.lower() == "sgd": + optimizer = optim.SGD( + params, lr=lr, momentum=momentum + ) + elif optimizer_name.lower() == "adam": + optimizer = optim.Adam( + params, lr=lr, betas=betas, eps=eps + ) + else: + raise ValueError("optimizer name is not correct") + return optimizer diff --git a/src/laion_clap/clap_module/version.py b/src/laion_clap/clap_module/version.py new file mode 100644 index 0000000000000000000000000000000000000000..fc79d63d5430b972ac6ec1c4bfea9af80922da4d --- /dev/null +++ b/src/laion_clap/clap_module/version.py @@ -0,0 +1 @@ +__version__ = '0.2.1' diff --git a/src/laion_clap/evaluate/__init__.py b/src/laion_clap/evaluate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/laion_clap/evaluate/eval_dcase.py b/src/laion_clap/evaluate/eval_dcase.py new file mode 100644 index 0000000000000000000000000000000000000000..c615651f2d96f7e34d109e9c3dbb8abc7275065f --- /dev/null +++ b/src/laion_clap/evaluate/eval_dcase.py @@ -0,0 +1,150 @@ +import torch +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +from open_clip import create_model +from open_clip import tokenize +import glob +import json +import librosa +from tqdm import tqdm +import numpy as np +import os +from laion_clap.training.params import parse_args + + +def get_output_from_single_audio(audio, text, model, device): + + # audio_embedding = model.audio_infer(audio, hopsize=5 * 48000, key="embedding", device=device)['embedding'] + # if audio_embedding.ndim > 1: + # audio_embedding = audio_embedding.mean(dim=0, keepdim=True) + # else: + # audio_embedding = audio_embedding.unsqueeze(0) + audio_features = model(audio, None, device) + audio_features = F.normalize(audio_features, dim=-1) + text_features = model(None, text, device=device) + text_features = F.normalize(text_features, dim=-1) + + # CHANGE: before normalize or after + audio_features_mlp = model.audio_transform(audio_features) + text_features_mlp = model.text_transform(text_features) + return audio_features, text_features, audio_features_mlp, text_features_mlp, model.logit_scale_a.exp(), model.logit_scale_t.exp() + + +def get_metrics(text_to_audio_logits): + metrics = {} + + # repeat 
ground truth 5 times because Clotho has 5 text for 1 audio + ground_truth = torch.repeat_interleave(torch.arange(len(text_features) // 5), 5).view(-1, 1) + + ranking = torch.argsort(text_to_audio_logits, descending=True) + preds = torch.where(ranking == ground_truth)[1] # (yusong) this line is slow because it uses single thread + preds = preds.detach().cpu().numpy() + metrics[f"mean_rank"] = preds.mean() + 1 + metrics[f"median_rank"] = np.floor(np.median(preds)) + 1 + for k in [1, 5, 10]: + metrics[f"R@{k}"] = np.mean(preds < k) + # map@10 + metrics[f"mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0)) + return metrics + + +if __name__ == '__main__': + args = parse_args() + + model_path = args.pretrained + + clotho_test_preprocessed_dir = "/fsx/yusong/clotho_test_set/test" + + cudnn.benchmark = True + cudnn.deterministic = False + + audio_features_ensemble_all = [] + text_features_ensemble_all = [] + audio_features_mlp_ensemble_all = [] + text_features_mlp_ensemble_all = [] + logit_scale_a_ensemble_all = [] + logit_scale_t_ensemble_all = [] + + + device = torch.device('cuda') + model, clap_model_cfg = create_model( + args.amodel, + args.tmodel, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=False + ) + + # load model + checkpoint = torch.load(model_path, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module."):]: v for k, v in sd.items()} + model.load_state_dict(sd) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + + model.to(device) + model.eval() + for param in model.parameters(): + param.requires_grad = False + + # take every 5th file because clotho has 5 texts for 1 audio + test_file_list = sorted(glob.glob(f"{clotho_test_preprocessed_dir}/*.flac")) + + audio_features_all = [] + text_features_all = [] + audio_features_mlp_all = [] + text_features_mlp_all = [] + logit_scale_a_all = [] + logit_scale_t_all = [] + + with torch.no_grad(): + for file_path in tqdm(test_file_list): + json_path = file_path.replace(".flac", ".json") + with open(json_path, "r") as f: + json_data = json.load(f) + audio, sr = librosa.load(file_path, sr=48000, mono=True) + audio = torch.from_numpy(audio).to(device) + audio = {'waveform': audio.unsqueeze(0), 'sample_rate': sr} + text = json_data["text"] + + if args.tmodel == "transformer": + from open_clip import tokenize + text = tokenize(text) + else: + from laion_clap.training.data import tokenizer + text = tokenizer(text, tmodel=args.tmodel) # 5 texts for each audio + + audio_features, text_features, audio_features_mlp, text_features_mlp, logit_scale_a, logit_scale_t = \ + get_output_from_single_audio(audio, text, model, device) + + audio_features_all.append(audio_features.detach().cpu()) + text_features_all.append(text_features.detach().cpu()) + audio_features_mlp_all.append(audio_features_mlp.detach().cpu()) + text_features_mlp_all.append(text_features_mlp.detach().cpu()) + logit_scale_a_all.append(logit_scale_a.detach().cpu()) + logit_scale_t_all.append(logit_scale_t.detach().cpu()) + + audio_features = torch.cat(audio_features_all) + text_features = torch.cat(text_features_all) + logit_scale_a = 
logit_scale_a_all[0] + + logits_per_audio = (logit_scale_a * audio_features @ text_features.t()).detach().cpu() + logits_per_text = logits_per_audio.t().detach().cpu() + + metrics = get_metrics( + logits_per_text + ) + + print(metrics) diff --git a/src/laion_clap/evaluate/eval_linear_probe.py b/src/laion_clap/evaluate/eval_linear_probe.py new file mode 100644 index 0000000000000000000000000000000000000000..e5123451b503de7c54fac58e99b3507439992171 --- /dev/null +++ b/src/laion_clap/evaluate/eval_linear_probe.py @@ -0,0 +1,515 @@ +''' +Evalute the linear probe performance on different checkpoints +''' +import logging +import os +import random +from datetime import datetime +import copy +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch.cuda.amp import GradScaler +import glob + +try: + import wandb +except ImportError: + wandb = None + +try: + import torch.utils.tensorboard as tensorboard +except ImportError: + tensorboard = None + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from clap_module import create_model_and_transforms, trace_model, create_model +from training.data import get_data +from training.params import parse_args +from training.distributed import is_master, init_distributed_device, world_info_from_env +from training.logger import setup_logging +from training.scheduler import cosine_lr +from training.lp_main import config_lp_optimizer +from training.lp_train import train_one_epoch, evaluate +from clap_module.utils import get_tar_path_from_dataset_name, dataset_split +from clap_module.utils import load_p, load_class_label +from clap_module.linear_probe import LinearProbe + +def maintain_ckpts(args, startidx, all_idx_len): + for i in reversed(range(startidx, all_idx_len)): + if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")): + os.rename( + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"), + ) + if os.path.exists( + os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt") + ): + os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")) + return + + +def update_top_k_performance( + new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True, pretrain_epoch=0 +): + """ + Record the top-k performance of the current epoch. 
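+    When a new result enters the top-k, the existing epoch_top_{i} checkpoints on disk
+    are shifted down one slot by maintain_ckpts() before the new checkpoint is written.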
+ current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...} + """ + if isinstance(new_metrics_inputs, (list, tuple)): + new_metrics_inputs = np.mean(new_metrics_inputs) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + pretrain_epoch=pretrain_epoch + ) + elif isinstance(new_metrics_inputs, dict): + new_metrics_inputs = np.mean(list(new_metrics_inputs.values())) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + pretrain_epoch=pretrain_epoch + ) + elif isinstance(new_metrics_inputs, (float, int)): + update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()} + sorted_keys = sorted(current_top_k_ckpt_metrics.keys()) + sorted_values = sorted( + current_top_k_ckpt_metrics.values(), reverse=bignumbetter + ) + sorted_values_ = copy.deepcopy(sorted_values) + sorted_values.append(new_metrics_inputs) + sorted_values = sorted(sorted_values, reverse=bignumbetter) + sorted_values = sorted_values[:-1] + + if sorted_values == sorted_values_: + return current_top_k_ckpt_metrics, new_metrics_inputs + else: + for i in range(len(sorted_keys)): + if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]: + current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i] + update_flag[sorted_keys[i]] = True + for i in range(len(update_flag)): + if update_flag[i]: + maintain_ckpts(args, i, len(sorted_keys)) + torch.save( + ckpt, + os.path.join(args.checkpoint_path, f"pretrain_epoch_{pretrain_epoch}_lp_epoch_top_{i}.pt"), + ) + break + return current_top_k_ckpt_metrics, new_metrics_inputs + + +# def updateifNone(a, b): +# a = b if None else a +# return a + + +def is_pretrained_params(n): + return ( + n.startswith("clap_model.transformer") + or n in ["clap_model.positional_embedding", "clap_model.text_projection"] + or n.startswith("clap_model.token_embedding") + or n.startswith("clap_model.ln_final") + or n.startswith("clap_model.logit_scale_t") + ) + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + +def main(): + args = parse_args() + # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? 
+ args.amodel = args.amodel.replace("/", "-") + + pretrained_ckpts = sorted(glob.glob(os.path.join(args.pretrained, "*.pt")), key=os.path.getmtime) + + if args.name is None: + args.name = "-".join( + [ + datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), + f"linear_probe" + f"model_{args.amodel}", + f"lr_{args.lr}", + f"b_{args.batch_size}", + f"j_{args.workers}", + f"p_{args.precision}", + ] + ) + + # discover initial world args early so we can log properly + args.distributed = False + args.local_rank, args.rank, args.world_size = world_info_from_env() + + if args.remotedata and is_master(args): + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + args.log_path = None + if is_master(args, local=args.log_local): + log_base_path = os.path.join(args.logs, args.name) + os.makedirs(log_base_path, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path, log_filename) + + # avoid log dir in same name: + postfix = 0 + while os.path.exists(args.log_path): + postfix += 1 + log_base_path_new = log_base_path+'-'+str(postfix) + os.makedirs(log_base_path_new, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path_new, log_filename) + # print( + # "Error. Experiment already exists. Use --name {} to specify a new experiment." + # ) + # return -1 + + # Set logger + args.log_level = logging.DEBUG if args.debug else logging.INFO + setup_logging(args.log_path, args.log_level) + + # fully initialize distributed device environment + device = init_distributed_device(args) + + args.wandb = "wandb" in args.report_to or "all" in args.report_to + args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to + if is_master(args): + args.tensorboard_path = ( + os.path.join(args.logs, args.name, "tensorboard") + if args.tensorboard + else "" + ) + args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") + for dirname in [args.tensorboard_path, args.checkpoint_path]: + if dirname: + os.makedirs(dirname, exist_ok=True) + else: + args.tensorboard_path = "" + args.checkpoint_path = "" + + if args.copy_codebase: + copy_codebase(args) + + assert args.precision in ["amp", "fp16", "fp32"] + if args.precision == "fp16": + logging.warning( + "It is recommended to use AMP mixed-precision instead of FP16. " + "FP16 support needs further verification and tuning, especially for train." + ) + + if args.horovod: + logging.info( + f"Running in horovod mode with multiple processes / nodes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + elif args.distributed: + logging.info( + f"Running in distributed mode with multiple processes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + else: + logging.info(f"Running with a single process. Device {args.device}.") + + logging.info(f'openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}') + + # determine if this worker should save logs and checkpoints. 
only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + if args.wandb and is_master(args): + assert wandb is not None, "Please install wandb." + logging.debug("Starting wandb.") + # you will have to configure this for your project! + wandb.init( + project="clap", + notes=args.wandb_notes, + name=args.wandb_notes, + tags=[], + config=vars(args), + ) + logging.debug("Finished loading wandb.") + + for idx, f in enumerate(pretrained_ckpts): + logging.info(f"pretrained on {f}") + args.pretrained = f + ckpt = torch.load(f, map_location='cpu') + pretrain_epoch = 0 + if 'epoch' in ckpt: + pretrain_epoch = ckpt['epoch'] + # train + best_metrics = lp_main(args, device, writer, pretrain_epoch, idx) + + if args.wandb and is_master(args): + assert wandb is not None, "Please install wandb." + for name, val in best_metrics.items(): + wandb.log({f"val/summary/{name}": val, "epoch": pretrain_epoch}) + + if args.wandb and is_master(args): + wandb.finish() + +def update_metric(best_metric, new_metric): + for key in new_metric: + if key not in best_metric: + best_metric[key] = new_metric[key] + else: + best_metric[key] = max(best_metric[key], new_metric[key]) + return best_metric + +def lp_main(args, device, writer, pretrain_epoch, idx): + + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + args.class_index_dict = load_class_label(args.class_label_path) + + + # Create CLAP model + clap_model, clap_model_cfg = create_model( + args.amodel, + args.tmodel, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=False, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type + ) + + args.lp_out_ch = len(list(args.class_index_dict.keys())) + # Linear Probe + if idx == 0: + logging.info(f"linear probe using mlp: {args.lp_mlp}") + logging.info(f"linear probe using freeze: {args.lp_freeze}") + logging.info(f"linear probe act layer: {args.lp_act}") + logging.info(f"linear probe out ch: {args.lp_out_ch}") + logging.info(f"linear probe learning rate (if applicable): {args.lp_lr}") + logging.info(f"linear probe loss func: {args.lp_loss}") + logging.info(f"linear probe lp_metrics: {args.lp_metrics}") + + model = LinearProbe( + clap_model, + mlp=args.lp_mlp, freeze=args.lp_freeze, + in_ch=512, out_ch=args.lp_out_ch, + act=args.lp_act + ) # in_ch is fixed (i.e., 512) + model = model.to(device) + + if args.horovod: + with torch.no_grad(): + for param in model.parameters(): + param.set_(param.contiguous()) + + if args.trace: + model = trace_model(model, batch_size=args.batch_size, device=device) + + if is_master(args) and idx == 0: + logging.info("Linear Probe CLAP Model:") + logging.info(f"{str(clap_model)}") + logging.info("Params:") + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + ddp_args = {} + if 
args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args["static_graph"] = True + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True, **ddp_args + ) + + data = get_data(args, clap_model_cfg) + assert len(data), "At least one train or eval dataset must be specified." + if args.trace: + assert "train" not in data, "Cannot train with traced model" + + optimizer, scheduler, text_freeze_parameters = config_lp_optimizer(model, data, args) + + scaler = GradScaler() if args.precision == "amp" else None + + # optionally resume from a checkpoint + start_epoch = 0 + if args.resume is not None: + if os.path.isfile(args.resume): + checkpoint = torch.load(args.resume, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module.") :]: v for k, v in sd.items()} + model.load_state_dict(sd) + if args.split_opt: + if optimizer is not None: + for k, o_ in optimizer.items(): + o_.load_state_dict(checkpoint[k + "_" + "optimizer"]) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and "scaler" in checkpoint: + scaler.load_state_dict(checkpoint["scaler"]) + logging.info( + f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})" + ) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info( + f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})" + ) + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + else: + logging.info("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + cudnn.deterministic = False + + if args.wandb and is_master(args): + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + if args.debug: + wandb.watch(model, log="all") + if idx == 0: + wandb.save(params_file) + + best_metrics = {} + + if "train" not in data: + metric = evaluate(model, data, start_epoch, args, writer, extra_suffix="_pe@" + str(pretrain_epoch)) + if is_master(args): + best_metrics = update_metric(best_metrics, metric) + return + elif start_epoch == 0 and "val" in data and not args.no_eval: + metric = evaluate(model, data, 0, args, writer, extra_suffix="_pe@" + str(pretrain_epoch)) + if is_master(args): + best_metrics = update_metric(best_metrics, metric) + if args.save_top_performance: + current_top_k_ckpt_metrics = { + i: 0 for i in range(args.save_top_performance) + } # initialize the top-k metric for ckpts to 0 + + for epoch in range(start_epoch, args.epochs): + # freeze the text param after (include) args.freeze_text_after, this is -1 by default + if epoch == args.freeze_text_after: + print("Text pretrained parameters are freezed since this epoch.") + for k in text_freeze_parameters: + k.requires_grad = False + if is_master(args): + logging.info(f"Start epoch {epoch}") + + train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer, extra_suffix="_pe@" + str(pretrain_epoch)) + completed_epoch = epoch + 1 + + if any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) and not args.no_eval: + metric = evaluate(model, data, completed_epoch, args, writer, 
extra_suffix="_pe@" + str(pretrain_epoch)) + if is_master(args): + best_metrics = update_metric(best_metrics, metric) + if args.save_top_performance: + top_k_dataset = args.top_k_checkpoint_select_dataset + top_k_metric = args.top_k_checkpoint_select_metric + filtered_metrics = [ + v + for k, v in metric.items() + if top_k_metric in k and top_k_dataset in k + ] # check all R@10 metrics (all dataset) and use it to update the ckpt + # Saving checkpoints. + if args.save_logs: + opt_dict = { + k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items() + } + checkpoint_dict = { + "epoch": completed_epoch, + "pretrain_epoch": pretrain_epoch, + "name": args.name, + "state_dict": model.state_dict(), + } + checkpoint_dict.update(opt_dict) + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"pretrain_epoch_{pretrain_epoch}_lp_epoch_{completed_epoch}.pt"), + ) + if args.save_most_recent: + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"pretrain_epoch_{pretrain_epoch}_lp_epoch_latest.pt"), + ) + if args.save_top_performance and not args.no_eval: + update_top_k_performance( + filtered_metrics, + current_top_k_ckpt_metrics, + args, + checkpoint_dict, + bignumbetter=True, + pretrain_epoch=pretrain_epoch + ) + del clap_model + return best_metrics + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." + ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree( + current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb") + ) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main() + + diff --git a/src/laion_clap/evaluate/eval_retrieval.py b/src/laion_clap/evaluate/eval_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..739734cee63588ec647dcb9189520af6294764f2 --- /dev/null +++ b/src/laion_clap/evaluate/eval_retrieval.py @@ -0,0 +1,192 @@ +import os.path +import glob +import random +import numpy as np +import logging +import wandb +import torch +import torch.backends.cudnn as cudnn +from laion_clap import create_model +from laion_clap.training.logger import setup_logging +from laion_clap.training.data import get_data +from laion_clap.training.train import evaluate +from laion_clap.utils import get_tar_path_from_dataset_name, dataset_split +from laion_clap.training.params import parse_args + + +def find_params_value(file, key): + # find value of params in params_file + with open(file, 'r') as f: + for line in f: + if key + ': ' in line: + return line.split(': ')[1].strip() + return None + + +if __name__ == '__main__': + # (yusong) repeated run might have different metric results. + # This is because we randomly select crop 10s for each audio. 
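For context on the helper defined just above: the linear-probe trainer earlier in this diff writes `params.txt` as one `name: value` pair per line, and `find_params_value` reads a setting back by splitting on the first `": "`. A minimal, self-contained sketch of that round trip (file contents are hypothetical):

```python
# Minimal sketch (hypothetical values) of the params.txt round trip:
# training writes "name: value" lines; evaluation recovers a value by
# taking the text after the first ": " occurrence.
def find_params_value(file, key):
    with open(file, "r") as f:
        for line in f:
            if key + ": " in line:
                return line.split(": ")[1].strip()
    return None

with open("params.txt", "w") as f:
    f.write("amodel: HTSAT-tiny\n")
    f.write("tmodel: roberta\n")

assert find_params_value("params.txt", "amodel") == "HTSAT-tiny"
assert find_params_value("params.txt", "tmodel") == "roberta"
```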
+ args = parse_args() + + if os.path.isdir(args.pretrained): + log_dir = os.path.dirname(args.pretrained) + else: + log_dir = os.path.dirname(os.path.dirname(args.pretrained)) + + args.log_level = logging.DEBUG if args.debug else logging.INFO + log_path = os.path.join(log_dir, 'out.log') + setup_logging(log_path, args.log_level) + params_file = os.path.join(log_dir, 'params.txt') + + seed = 3407 + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + + cudnn.benchmark = True + cudnn.deterministic = False + pretrained = 'openai' + amodel = find_params_value(params_file, 'amodel') + tmodel = find_params_value(params_file, 'tmodel') + + if amodel is None or tmodel is None: + raise ValueError('model type not found in params file') + + # set up dummy values for args + args.parallel_eval = False + args.rank = 0 + args.local_rank = 0 + args.world_size = 1 + args.val_frequency = 1 + args.epochs = 1 + args.precision = 'fp32' + args.save_logs = True + args.wandb = True + args.class_index_dict = None + + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + args.device = device + + if args.remotedata: + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + + if args.datasetinfos is None: + args.datasetinfos = ["train", "unbalanced_train", "balanced_train"] + if args.dataset_type == "webdataset": + args.train_data = get_tar_path_from_dataset_name( + args.datasetnames, + args.datasetinfos, + islocal=not args.remotedata, + proportion=args.dataset_proportion, + dataset_path=args.datasetpath, + ) + args.val_data = get_tar_path_from_dataset_name( + args.datasetnames, + ["valid", "test", "eval"], + islocal=not args.remotedata, + proportion=1, + dataset_path=args.datasetpath, + ) + model, model_cfg = create_model( + amodel, + tmodel, + pretrained, + precision='fp32', + device=device, + jit=False, + force_quick_gelu=False, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=False, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type + ) # a hack to get model_cfg + + data = get_data(args, model_cfg=model_cfg) # (yusong): hack: no model_cfg needed to get data + + writer = None # if use tensorboard, initalize writer here + + if args.wandb: + assert wandb is not None, "Please install wandb." + + # # find the line with "wandb_notes" and get the value + # wandb_notes = find_params_value(params_file, 'wandb_notes') + # if wandb_notes is None: + # print(f'wandb_notes not found in params file: {params_file}, set to timestamp.') + # wandb_notes = f'experiment_{time.strftime("%Y%m%d-%H%M%S")}' + # wandb_notes = wandb_notes + '-eval-retrieval' + wandb_notes = args.wandb_notes + + logging.debug("Starting wandb.") + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! 
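As the comment above notes, the wandb setup that follows is project-specific. A hedged sketch of how it might be adapted; the `entity` and `project` values are placeholders, not names from this repo:

```python
import wandb

# Placeholder workspace/project names -- substitute your own.
run = wandb.init(
    project="my-clap-eval",      # your wandb project
    entity="my-team",            # your wandb team/user (optional)
    notes="eval-retrieval run",  # or reuse the wandb_notes string built above
    tags=["eval", "retrieval"],
)
run.finish()
```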
+ if args.wandb_id is not None: + wandb.init( + project="clap", + id=args.wandb_id, + resume=True + ) + else: + wandb.init( + project="clap", + notes=wandb_notes, + name=wandb_notes, + tags=[], + config=vars(args), + ) + logging.debug("Finished loading wandb.") + + if os.path.isdir(args.pretrained): + all_model_checkpoints = sorted(glob.glob(os.path.join(log_dir, 'checkpoints', '*.pt')), key=os.path.getmtime) + else: + all_model_checkpoints = [args.pretrained] + for model_path in all_model_checkpoints: + args.checkpoint_path = os.path.dirname(model_path) + model, model_cfg = create_model( + amodel, + tmodel, + pretrained, + precision='fp32', + device=device, + jit=False, + force_quick_gelu=False, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=False, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type + ) + + # load model + checkpoint = torch.load(model_path, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module."):]: v for k, v in sd.items()} + model.load_state_dict(sd) + logging.info( + f"=> resuming checkpoint '{model_path}' (epoch {start_epoch})" + ) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + start_epoch = 0 + + model.to(device) + model.eval() + for param in model.parameters(): + param.requires_grad = False + + evaluate(model, data, start_epoch, args, writer) diff --git a/src/laion_clap/evaluate/eval_retrieval_main.py b/src/laion_clap/evaluate/eval_retrieval_main.py new file mode 100644 index 0000000000000000000000000000000000000000..edfa65fdbf19377c1a0ba5c2e8c4fdc6f0d64e96 --- /dev/null +++ b/src/laion_clap/evaluate/eval_retrieval_main.py @@ -0,0 +1,257 @@ +import os.path +import glob +import random +import numpy as np +import logging +import wandb +import torch +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +from clap_module import create_model +from clap_module import tokenize +from training.logger import setup_logging +from training.data import get_data +from training.train import evaluate +from clap_module.utils import get_tar_path_from_dataset_name, dataset_split +from training.params import parse_args + + +def find_params_value(file, key): + # find value of params in params_file + with open(file, 'r') as f: + for line in f: + if key + ': ' in line: + return line.split(': ')[1].strip() + return None + + +def evaluate_zeroshot(model, data, start_epoch, args, writer): + dataloader = data["val"].dataloader + metrics = {} + device = torch.device(args.device) + model.eval() + metrics.update({"epoch": start_epoch}) + + all_audio_features = [] + all_class_labels = [] + with torch.no_grad(): + for i, batch in enumerate(dataloader): + audios = batch # contains mel_spec, wavform, and longer list + audio_features = model(audios, None, device) + audio_features = F.normalize(audio_features, dim=-1) + all_audio_features.append(audio_features.detach().cpu()) + all_class_labels.append(torch.argmax(batch["class_label"], 1).long()) + all_audio_features = torch.cat(all_audio_features, dim=0) + all_class_labels = torch.cat(all_class_labels, dim=0) + metrics["num_samples"] = all_audio_features.shape[0] + + # get text features + all_texts = ["This is a sound of " + t for t in args.class_index_dict.keys()] + # (yusong): a hack, can make it better + if 
args.tmodel == "transformer": + from clap_module.tokenizer import tokenize + all_texts = tokenize(all_texts) + else: + from training.data import tokenizer + all_texts = tokenizer(all_texts) + all_text_features = model(None, all_texts, device) + all_text_features = F.normalize(all_text_features, dim=-1).detach().cpu() + + # compute similarity + logit_scale_a, logit_scale_t = model(None, None, device) + logit_scale_a = logit_scale_a.cpu() + + logits_per_audio = (logit_scale_a * all_audio_features @ all_text_features.t()).detach().cpu() + logits_per_text = logits_per_audio.t().detach().cpu() + + ground_truth = all_class_labels.view(-1, 1) + logit = logits_per_audio + + ranking = torch.argsort(logit, descending=True) + preds = torch.where(ranking == ground_truth)[1] # (yusong) this line is slow because it uses single thread + preds = preds.detach().cpu().numpy() + metrics[f"{args.datasetnames[0]}_mean_rank"] = preds.mean() + 1 + metrics[f"{args.datasetnames[0]}_median_rank"] = np.floor(np.median(preds)) + 1 + for k in [1, 5, 10]: + metrics[f"{args.datasetnames[0]}_R@{k}"] = np.mean(preds < k) + # map@10 + metrics[f"{args.datasetnames[0]}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0)) + + logging.info( + f"Eval Epoch: {start_epoch} " + + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) + ) + + if args.wandb: + assert wandb is not None, "Please install wandb." + for name, val in metrics.items(): + wandb.log({f"val/{name}": val, "epoch": start_epoch}) + + +if __name__ == '__main__': + # (yusong) repeated run might have different metric results. + # This is because we randomly select crop 10s for each audio. + args = parse_args() + + if os.path.isdir(args.pretrained): + log_dir = os.path.dirname(args.pretrained) + else: + log_dir = os.path.dirname(os.path.dirname(args.pretrained)) + + args.log_level = logging.DEBUG if args.debug else logging.INFO + log_path = os.path.join(log_dir, 'out.log') + setup_logging(log_path, args.log_level) + params_file = os.path.join(log_dir, 'params.txt') + + seed = 3407 + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + + cudnn.benchmark = True + cudnn.deterministic = False + pretrained = 'openai' + amodel = find_params_value(params_file, 'amodel') + tmodel = find_params_value(params_file, 'tmodel') + + if amodel is None or tmodel is None: + raise ValueError('model type not found in params file') + + # set up dummy values for args + args.parallel_eval = False + args.rank = 0 + args.local_rank = 0 + args.world_size = 1 + args.val_frequency = 1 + args.epochs = 1 + args.precision = 'fp32' + args.save_logs = True + args.wandb = args.report_to == 'wandb' + args.class_index_dict = None + + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + args.device = device + + if args.remotedata: + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + + if args.datasetinfos is None: + args.datasetinfos = ["train", "unbalanced_train", "balanced_train"] + if args.dataset_type == "webdataset": + args.train_data = get_tar_path_from_dataset_name( + args.datasetnames, + args.datasetinfos, + islocal=not args.remotedata, + proportion=args.dataset_proportion, + 
dataset_path=args.datasetpath, + ) + args.val_data = get_tar_path_from_dataset_name( + args.datasetnames, + ["valid", "test", "eval"], + islocal=not args.remotedata, + proportion=1, + dataset_path=args.datasetpath, + ) + model, model_cfg = create_model( + amodel, + tmodel, + pretrained, + precision='fp32', + device=device, + jit=False, + force_quick_gelu=False, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=False, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type + ) # a hack to get model_cfg + + data = get_data(args, model_cfg=model_cfg) # (yusong): hack: no model_cfg needed to get data + + writer = None # if use tensorboard, initalize writer here + + if args.wandb: + assert wandb is not None, "Please install wandb." + + # # find the line with "wandb_notes" and get the value + # wandb_notes = find_params_value(params_file, 'wandb_notes') + # if wandb_notes is None: + # print(f'wandb_notes not found in params file: {params_file}, set to timestamp.') + # wandb_notes = f'experiment_{time.strftime("%Y%m%d-%H%M%S")}' + # wandb_notes = wandb_notes + '-eval-retrieval' + wandb_notes = args.wandb_notes + + logging.debug("Starting wandb.") + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! + if args.wandb_id is not None: + wandb.init( + project="clap", + id=args.wandb_id, + resume=True + ) + else: + wandb.init( + project="clap", + notes=wandb_notes, + name=wandb_notes, + tags=[], + config=vars(args), + ) + logging.debug("Finished loading wandb.") + + if os.path.isdir(args.pretrained): + all_model_checkpoints = sorted(glob.glob(os.path.join(log_dir, 'checkpoints', '*.pt')), key=os.path.getmtime) + else: + all_model_checkpoints = [args.pretrained] + for model_path in all_model_checkpoints: + args.checkpoint_path = os.path.dirname(model_path) + model, model_cfg = create_model( + amodel, + tmodel, + pretrained, + precision='fp32', + device=device, + jit=False, + force_quick_gelu=False, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=False, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type + ) + + # load model + checkpoint = torch.load(model_path, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module."):]: v for k, v in sd.items()} + model.load_state_dict(sd) + logging.info( + f"=> resuming checkpoint '{model_path}' (epoch {start_epoch})" + ) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + start_epoch = 0 + + model.to(device) + model.eval() + for param in model.parameters(): + param.requires_grad = False + + evaluate_zeroshot(model, data, start_epoch, args, writer) diff --git a/src/laion_clap/evaluate/eval_zeroshot_classification.py b/src/laion_clap/evaluate/eval_zeroshot_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..577cb91125bacb7f7dc7fb9841c2cd819478a736 --- /dev/null +++ b/src/laion_clap/evaluate/eval_zeroshot_classification.py @@ -0,0 +1,261 @@ +import os.path +import glob +import random +import numpy as np +import logging +import wandb +import torch +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +from clap_module 
import create_model +from clap_module import tokenize +from training.logger import setup_logging +from training.data import get_data +from training.train import evaluate +from clap_module.utils import get_tar_path_from_dataset_name, dataset_split +from training.params import parse_args + + +def find_params_value(file, key): + # find value of params in params_file + with open(file, 'r') as f: + for line in f: + if key + ': ' in line: + return line.split(': ')[1].strip() + return None + + +def evaluate_zeroshot(model, data, start_epoch, args, writer): + dataloader = data["val"].dataloader + metrics = {} + device = torch.device(args.device) + model.eval() + metrics.update({"epoch": start_epoch}) + + all_audio_features = [] + all_class_labels = [] + with torch.no_grad(): + for i, batch in enumerate(dataloader): + audios = batch # contains mel_spec, wavform, and longer list + audio_features = model(audios, None, device) + audio_features = F.normalize(audio_features, dim=-1) + all_audio_features.append(audio_features.detach().cpu()) + all_class_labels.append(torch.argmax(batch["class_label"], 1).long()) + all_audio_features = torch.cat(all_audio_features, dim=0) + all_class_labels = torch.cat(all_class_labels, dim=0) + metrics["num_samples"] = all_audio_features.shape[0] + + # get text features + if args.val_dataset_names == ['GTZAN']: + all_texts = [f"This is a {t} song." for t in args.class_index_dict.keys()] + else: + all_texts = [f"This is a sound of {t}." for t in args.class_index_dict.keys()] + logging.info(f'class label prompts: {all_texts}') + # (yusong): a hack, can make it better + if args.tmodel == "transformer": + from clap_module.tokenizer import tokenize + all_texts = tokenize(all_texts) + else: + from training.data import tokenizer + all_texts = tokenizer(all_texts) + all_text_features = model(None, all_texts, device) + all_text_features = F.normalize(all_text_features, dim=-1).detach().cpu() + + # compute similarity + logit_scale_a, logit_scale_t = model(None, None, device) + logit_scale_a = logit_scale_a.cpu() + + logits_per_audio = (logit_scale_a * all_audio_features @ all_text_features.t()).detach().cpu() + logits_per_text = logits_per_audio.t().detach().cpu() + + ground_truth = all_class_labels.view(-1, 1) + logit = logits_per_audio + + ranking = torch.argsort(logit, descending=True) + preds = torch.where(ranking == ground_truth)[1] # (yusong) this line is slow because it uses single thread + preds = preds.detach().cpu().numpy() + metrics[f"{args.datasetnames[0]}_mean_rank"] = preds.mean() + 1 + metrics[f"{args.datasetnames[0]}_median_rank"] = np.floor(np.median(preds)) + 1 + for k in [1, 5, 10]: + metrics[f"{args.datasetnames[0]}_R@{k}"] = np.mean(preds < k) + # map@10 + metrics[f"{args.datasetnames[0]}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0)) + + logging.info( + f"Eval Epoch: {start_epoch} " + + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) + ) + + if args.wandb: + assert wandb is not None, "Please install wandb." + for name, val in metrics.items(): + wandb.log({f"val/{name}": val, "epoch": start_epoch}) + + +if __name__ == '__main__': + # (yusong) repeated run might have different metric results. + # This is because we randomly select crop 10s for each audio. 
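To make the rank-based metrics computed in `evaluate_zeroshot` above concrete, here is a tiny self-contained example with made-up logits for three clips and four classes; it applies the same formulas (mean rank, R@k, mAP@10):

```python
import numpy as np
import torch

# Made-up logits: 3 audio clips x 4 candidate classes.
logits_per_audio = torch.tensor([
    [0.9, 0.1, 0.3, 0.2],   # true class 0 -> rank position 0
    [0.2, 0.4, 0.8, 0.1],   # true class 1 -> rank position 1 (beaten by class 2)
    [0.1, 0.2, 0.3, 0.7],   # true class 3 -> rank position 0
])
ground_truth = torch.tensor([0, 1, 3]).view(-1, 1)

# Same recipe as evaluate_zeroshot: sort classes by score per clip, then find
# where the true class ended up in that ordering.
ranking = torch.argsort(logits_per_audio, descending=True)
preds = torch.where(ranking == ground_truth)[1].numpy()          # array([0, 1, 0])

mean_rank = preds.mean() + 1                                     # (0+1+0)/3 + 1 = 1.33
r_at_1 = np.mean(preds < 1)                                      # 2/3
map_at_10 = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))  # (1 + 0.5 + 1)/3 = 0.83
```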
+ args = parse_args() + + if os.path.isdir(args.pretrained): + log_dir = os.path.dirname(args.pretrained) + else: + log_dir = os.path.dirname(os.path.dirname(args.pretrained)) + + args.log_level = logging.DEBUG if args.debug else logging.INFO + log_path = os.path.join(log_dir, 'out.log') + setup_logging(log_path, args.log_level) + params_file = os.path.join(log_dir, 'params.txt') + + seed = 3407 + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + + cudnn.benchmark = True + cudnn.deterministic = False + pretrained = 'openai' + amodel = find_params_value(params_file, 'amodel') + tmodel = find_params_value(params_file, 'tmodel') + + if amodel is None or tmodel is None: + raise ValueError('model type not found in params file') + + # set up dummy values for args + args.parallel_eval = False + args.rank = 0 + args.local_rank = 0 + args.world_size = 1 + args.val_frequency = 1 + args.epochs = 1 + args.precision = 'fp32' + args.save_logs = True + args.wandb = args.report_to == 'wandb' + args.class_index_dict = None + + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + args.device = device + + if args.remotedata: + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + + if args.datasetinfos is None: + args.datasetinfos = ["train", "unbalanced_train", "balanced_train"] + if args.dataset_type == "webdataset": + args.train_data = get_tar_path_from_dataset_name( + args.datasetnames, + args.datasetinfos, + islocal=not args.remotedata, + proportion=args.dataset_proportion, + dataset_path=args.datasetpath, + ) + args.val_data = get_tar_path_from_dataset_name( + args.datasetnames, + ["valid", "test", "eval"], + islocal=not args.remotedata, + proportion=1, + dataset_path=args.datasetpath, + ) + model, model_cfg = create_model( + amodel, + tmodel, + pretrained, + precision='fp32', + device=device, + jit=False, + force_quick_gelu=False, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=False, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type + ) # a hack to get model_cfg + + data = get_data(args, model_cfg=model_cfg) # (yusong): hack: no model_cfg needed to get data + + writer = None # if use tensorboard, initalize writer here + + if args.wandb: + assert wandb is not None, "Please install wandb." + + # # find the line with "wandb_notes" and get the value + # wandb_notes = find_params_value(params_file, 'wandb_notes') + # if wandb_notes is None: + # print(f'wandb_notes not found in params file: {params_file}, set to timestamp.') + # wandb_notes = f'experiment_{time.strftime("%Y%m%d-%H%M%S")}' + # wandb_notes = wandb_notes + '-eval-retrieval' + wandb_notes = args.wandb_notes + + logging.debug("Starting wandb.") + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! 
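The prompt templates in `evaluate_zeroshot` above iterate over `args.class_index_dict.keys()`, and the webdataset preprocessing later in this diff builds one-hot labels from the same keys. A small illustration with a hypothetical class map (the real one comes from `load_class_label(args.class_label_path)`):

```python
import numpy as np

# Hypothetical class map: class name -> index.
class_index_dict = {"dog": 0, "rain": 1, "rock": 2, "jazz": 3}

# Zero-shot prompts (GTZAN gets the music-flavoured template instead).
prompts = [f"This is a sound of {t}." for t in class_index_dict.keys()]
# -> ["This is a sound of dog.", "This is a sound of rain.", ...]

# One-hot class labels, mirroring preprocess_single in training/data.py.
tags = ["rain", "jazz"]                      # tags attached to one sample
class_labels = np.zeros(len(class_index_dict))
class_labels[np.in1d(list(class_index_dict.keys()), tags)] = 1
# -> array([0., 1., 0., 1.])
```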
+ if args.wandb_id is not None: + wandb.init( + project="clap", + id=args.wandb_id, + resume=True + ) + else: + wandb.init( + project="clap", + notes=wandb_notes, + name=wandb_notes, + tags=[], + config=vars(args), + ) + logging.debug("Finished loading wandb.") + + if os.path.isdir(args.pretrained): + all_model_checkpoints = sorted(glob.glob(os.path.join(log_dir, 'checkpoints', '*.pt')), key=os.path.getmtime) + else: + all_model_checkpoints = [args.pretrained] + for model_path in all_model_checkpoints: + args.checkpoint_path = os.path.dirname(model_path) + model, model_cfg = create_model( + amodel, + tmodel, + pretrained, + precision='fp32', + device=device, + jit=False, + force_quick_gelu=False, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=False, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type + ) + + # load model + checkpoint = torch.load(model_path, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module."):]: v for k, v in sd.items()} + model.load_state_dict(sd) + logging.info( + f"=> resuming checkpoint '{model_path}' (epoch {start_epoch})" + ) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + start_epoch = 0 + + model.to(device) + model.eval() + for param in model.parameters(): + param.requires_grad = False + + evaluate_zeroshot(model, data, start_epoch, args, writer) diff --git a/src/laion_clap/hook.py b/src/laion_clap/hook.py new file mode 100644 index 0000000000000000000000000000000000000000..d86e942eaf44dfbe0d598bf789e21950e9974778 --- /dev/null +++ b/src/laion_clap/hook.py @@ -0,0 +1,219 @@ +""" +Contrastive Language-Audio Pretraining Model from LAION +-------------------------------------------------------- +Paper: https://arxiv.org/abs/2211.06687 +Authors (equal contributions): Ke Chen, Yusong Wu, Tianyu Zhang, Yuchen Hui +Support: LAION +""" +import os +import torch +import librosa +from clap_module import create_model +from training.data import get_audio_features +from training.data import int16_to_float32, float32_to_int16 + +from transformers import RobertaTokenizer +import wget +from clap_module.factory import load_state_dict + + +class CLAP_Module(torch.nn.Module): + def __init__(self, enable_fusion=False, device=None, amodel='HTSAT-tiny', tmodel='roberta') -> None: + """Initialize CLAP Model + + Parameters + ---------- + enable_fusion: bool + if true, it will create the fusion clap model, otherwise non-fusion clap model (default: false) + device: str + if None, it will automatically detect the device (gpu or cpu) + amodel: str + audio encoder architecture, default: HTSAT-tiny + tmodel: str + text encoder architecture, default: roberta + """ + super(CLAP_Module, self).__init__() + if device is None: + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + + precision = 'fp32' + + if enable_fusion: + fusion_type = 'aff_2d' + model, model_cfg = create_model( + amodel, + tmodel, + precision=precision, + device=device, + enable_fusion=enable_fusion, + fusion_type=fusion_type + ) + else: + model, model_cfg = create_model( + amodel, + tmodel, + precision=precision, + device=device, + enable_fusion=enable_fusion + ) + self.enable_fusion = enable_fusion + self.model = model + self.model_cfg = model_cfg + self.tokenize = 
RobertaTokenizer.from_pretrained('roberta-base') + + def tokenizer(self, text): + result = self.tokenize( + text, + padding="max_length", + truncation=True, + max_length=77, + return_tensors="pt", + ) + return result + + def load_ckpt(self, ckpt = None, model_id = -1, verbose = True): + """Load the pretrained checkpoint of CLAP model + + Parameters + ---------- + ckpt: str + if ckpt is specified, the model will load this ckpt, otherwise the model will download the ckpt from zenodo. \n + For fusion model, it will download the 630k+audioset fusion model (id=3). For non-fusion model, it will download the 630k+audioset model (id=1). + model_id: + if model_id is specified, you can download our best ckpt, as: + id = 0 --> 630k non-fusion ckpt \n + id = 1 --> 630k+audioset non-fusion ckpt \n + id = 2 --> 630k fusion ckpt \n + id = 3 --> 630k+audioset fusion ckpt \n + Note that if your model is specied as non-fusion model but you download a fusion model ckpt, you will face an error. + """ + download_link = 'https://huggingface.co/lukewys/laion_clap/resolve/main/' + download_names = [ + '630k-best.pt', + '630k-audioset-best.pt', + '630k-fusion-best.pt', + '630k-audioset-fusion-best.pt' + ] + if ckpt is not None: + print(f'Load the specified checkpoint {ckpt} from users.') + else: + print(f'Load our best checkpoint in the paper.') + if model_id == -1: + model_id = 3 if self.enable_fusion else 1 + package_dir = os.path.dirname(os.path.realpath(__file__)) + weight_file_name = download_names[model_id] + ckpt = os.path.join(package_dir, weight_file_name) + if os.path.exists(ckpt): + print(f'The checkpoint is already downloaded') + else: + print('Downloading laion_clap weight files...') + ckpt = wget.download(download_link + weight_file_name, os.path.dirname(ckpt)) + print('Download completed!') + print('Load Checkpoint...') + ckpt = load_state_dict(ckpt, skip_params=True) + self.model.load_state_dict(ckpt) + if verbose: + param_names = [n for n, p in self.model.named_parameters()] + for n in param_names: + print(n, "\t", "Loaded" if n in ckpt else "Unloaded") + + def get_audio_embedding_from_filelist(self, x, use_tensor=False): + """get audio embeddings from the audio file list + + Parameters + ---------- + x: List[str] (N,): + an audio file list to extract features, audio files can have different lengths (as we have the feature fusion machanism) + use_tensor: boolean: + if True, it will return the torch tensor, preserving the gradient (default: False). 
+ Returns + ---------- + audio_embed : numpy.darray | torch.Tensor (N,D): + audio embeddings that extracted from audio files + """ + self.model.eval() + audio_input = [] + for f in x: + # load the waveform of the shape (T,), should resample to 48000 + audio_waveform, _ = librosa.load(f, sr=48000) + # quantize + audio_waveform = int16_to_float32(float32_to_int16(audio_waveform)) + audio_waveform = torch.from_numpy(audio_waveform).float() + temp_dict = {} + temp_dict = get_audio_features( + temp_dict, audio_waveform, 480000, + data_truncating='fusion' if self.enable_fusion else 'rand_trunc', + data_filling='repeatpad', + audio_cfg=self.model_cfg['audio_cfg'], + require_grad=audio_waveform.requires_grad + ) + audio_input.append(temp_dict) + audio_embed = self.model.get_audio_embedding(audio_input) + if not use_tensor: + audio_embed = audio_embed.detach().cpu().numpy() + return audio_embed + + + def get_audio_embedding_from_data(self, x, use_tensor=False): + """get audio embeddings from the audio data + + Parameters + ---------- + x: np.darray | torch.Tensor (N,T): + audio data, must be mono audio tracks. + use_tensor: boolean: + if True, x should be the tensor input and the output will be the tesnor, preserving the gradient (default: False). + Note that if 'use tensor' is set to True, it will not do the quantize of the audio waveform (otherwise the gradient will not be preserved). + Returns + ---------- + audio embed: numpy.darray | torch.Tensor (N,D): + audio embeddings that extracted from audio files + """ + self.model.eval() + audio_input = [] + for audio_waveform in x: + # quantize + if not use_tensor: + audio_waveform = int16_to_float32(float32_to_int16(audio_waveform)) + audio_waveform = torch.from_numpy(audio_waveform).float() + temp_dict = {} + temp_dict = get_audio_features( + temp_dict, audio_waveform, 480000, + data_truncating='fusion' if self.enable_fusion else 'rand_trunc', + data_filling='repeatpad', + audio_cfg=self.model_cfg['audio_cfg'], + require_grad=audio_waveform.requires_grad + ) + audio_input.append(temp_dict) + audio_embed = self.model.get_audio_embedding(audio_input) + if not use_tensor: + audio_embed = audio_embed.detach().cpu().numpy() + return audio_embed + + def get_text_embedding(self, x, tokenizer = None, use_tensor = False): + """get text embeddings from texts + + Parameters + ---------- + x: List[str] (N,): + text list + tokenizer: func: + the tokenizer function, if not provided (None), will use the default Roberta tokenizer. + use_tensor: boolean: + if True, the output will be the tesnor, preserving the gradient (default: False). 
+ Returns + ---------- + text_embed : numpy.darray | torch.Tensor (N,D): + text embeddings that extracted from texts + """ + self.model.eval() + if tokenizer is not None: + text_input = tokenizer(x) + else: + text_input = self.tokenizer(x) + text_embed = self.model.get_text_embedding(text_input) + if not use_tensor: + text_embed = text_embed.detach().cpu().numpy() + return text_embed + + diff --git a/src/laion_clap/inference.py b/src/laion_clap/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..759c8e32ab9cbd4925de0b0aaad10766f010da56 --- /dev/null +++ b/src/laion_clap/inference.py @@ -0,0 +1,41 @@ +import numpy as np +import librosa +import torch +from src import laion_clap +from glob import glob +import pandas as pd +from ..config.configs import ProjectPaths +import pickle + + +class AudioEncoder(laion_clap.CLAP_Module): + def __init__(self) -> None: + super().__init__(enable_fusion=False, amodel='HTSAT-base') + self.load_ckpt(ckpt=ProjectPaths.MODEL_PATH) + + def extract_audio_representaion(self, file_name): + audio_data, _ = librosa.load(file_name, sr=48000) + audio_data = audio_data.reshape(1, -1) + with torch.no_grad(): + audio_embed = self.get_audio_embedding_from_data(x=audio_data, use_tensor=False) + return audio_embed + + def extract_bulk_audio_representaions(self, save=False): + music_files = glob(str(ProjectPaths.DATA_DIR.joinpath("audio", "*.wav"))) + song_names = [k.split("/")[-1] for k in music_files] + music_data = np.zeros((len(music_files), 512), dtype=np.float32) + for m in range(music_data.shape[0]): + music_data[m] = self.extract_audio_representaion(music_files[m]) + + if not save: + return music_data, song_names + + else: + np.save(ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.npy")) + with open(ProjectPaths.DATA_DIR.joinpath("vectors", "song_names.pkl", "rb")) as writer: + pickle.dump(song_names, writer) + + def extract_text_representation(self, text): + text_data = [text] + text_embed = self.get_text_embedding(text_data) + return text_embed \ No newline at end of file diff --git a/src/laion_clap/training/__init__.py b/src/laion_clap/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/laion_clap/training/audioset_textmap.npy b/src/laion_clap/training/audioset_textmap.npy new file mode 100644 index 0000000000000000000000000000000000000000..3da4c92d3819aaec11e5f576464a9973a6df811b --- /dev/null +++ b/src/laion_clap/training/audioset_textmap.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b +size 84448 diff --git a/src/laion_clap/training/data.py b/src/laion_clap/training/data.py new file mode 100644 index 0000000000000000000000000000000000000000..fad90621575f45388feb5c015cfff5d1f5fe2146 --- /dev/null +++ b/src/laion_clap/training/data.py @@ -0,0 +1,895 @@ +import ast +import json +import logging +import math +import os +import random +import h5py +from dataclasses import dataclass +import braceexpand +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +import torchvision.datasets as datasets +import torchvision.transforms +import webdataset as wds +from PIL import Image +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler +from torch.utils.data.distributed import DistributedSampler +from functools import partial +from pathlib import Path +import wget +import tempfile +import copy 
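Taken together, `hook.py` and `inference.py` above give the embedding API the playlist app builds on. Before moving into the data pipeline, a hedged end-to-end sketch (file paths are placeholders, and cosine ranking is one reasonable way to match a description to tracks, not necessarily exactly what the app does):

```python
import numpy as np
import laion_clap

# Non-fusion HTSAT-base model, as AudioEncoder configures it; the checkpoint
# path is a placeholder for ProjectPaths.MODEL_PATH.
model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-base")
model.load_ckpt(ckpt="path/to/clap_checkpoint.pt")

# (N, 512) audio embeddings for local tracks (placeholder paths, 48 kHz mono wavs).
audio_embed = model.get_audio_embedding_from_filelist(
    ["data/audio/track_01.wav", "data/audio/track_02.wav"], use_tensor=False
)

# (1, 512) text embedding for a playlist description.
text_embed = model.get_text_embedding(["upbeat jazz for a rainy evening"])

# Rank tracks by cosine similarity (normalize explicitly; the embeddings
# returned here are not guaranteed to be unit-norm).
a = audio_embed / np.linalg.norm(audio_embed, axis=1, keepdims=True)
t = text_embed / np.linalg.norm(text_embed, axis=1, keepdims=True)
ranked = np.argsort(-(a @ t.T)[:, 0])    # indices of best-matching tracks first
```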
+from contextlib import suppress + +from clap_module.utils import get_tar_path_from_dataset_name, dataset_split +from clap_module.utils import load_p, load_class_label +from clap_module import tokenize as clip_tokenizer +from transformers import BertTokenizer +from transformers import RobertaTokenizer +from transformers import BartTokenizer + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +try: + import torchaudio +except ImportError: + torchaudio = None + +bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") +roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base") +bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") + +def tokenizer(text, tmodel="roberta", max_length=77): + """tokenizer for different models + tmodel is default to roberta as it is the best model for our task + max_length is default to 77 from the OpenAI CLIP parameters + We assume text to be a single string, but it can also be a list of strings + """ + if tmodel == "transformer": + return clip_tokenizer(text).squeeze(0) + + elif tmodel == "bert": + result = bert_tokenizer( + text, + padding="max_length", + truncation=True, + max_length=max_length, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} + + elif tmodel == "roberta": + result = roberta_tokenizer( + text, + padding="max_length", + truncation=True, + max_length=max_length, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} + + elif tmodel == "bart": + result = bart_tokenizer( + text, + padding="max_length", + truncation=True, + max_length=max_length, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} + + +# initizlied the audioset map +_AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy") +_AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True) + + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def float32_to_int16(x): + x = np.clip(x, a_min=-1., a_max=1.) + return (x * 32767.).astype(np.int16) + + +def int16_to_float32_torch(x): + return (x / 32767.0).type(torch.float32) + + +def float32_to_int16_torch(x): + x = torch.clamp(x, min=-1., max=1.) 
+ return (x * 32767.).type(torch.int16) + + +# For Toy Dataset +class ToyDataset(Dataset): + def __init__(self, index_path, ipc, config, eval_mode=False): + """Toy Dataset for testing the audioset input with text labels + Parameters + ---------- + index_path: str + the link to the h5 file of each audio + idc: str + the link to the npy file, the number of samples in each class + config: dict + the audio cfg file + eval_model (bool): to indicate if the dataset is a testing dataset + """ + self.audio_cfg = config["audio_cfg"] + self.text_cfg = config["text_cfg"] + self.fp = h5py.File(index_path, "r") + self.ipc = np.load(ipc, allow_pickle=True) + self.total_size = len(self.fp["audio_name"]) + self.classes_num = self.audio_cfg["class_num"] + self.eval_mode = eval_mode + + if not eval_mode: + self.generate_queue() + else: + self.queue = [] + for i in range(self.total_size): + target = self.fp["target"][i] + if np.sum(target) > 0: + self.queue.append(i) + self.total_size = len(self.queue) + logging.info("total dataset size: %d" % (self.total_size)) + logging.info("class num: %d" % (self.classes_num)) + + def time_shifting(self, x): + frame_num = len(x) + shift_len = random.randint(0, frame_num - 1) + new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0) + return new_sample + + def generate_queue(self): + self.queue = [] + while len(self.queue) < self.total_size: + class_set = [*range(self.classes_num)] + random.shuffle(class_set) + self.queue += [ + self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set + ] + self.queue = self.queue[: self.total_size] + + logging.info("queue regenerated:%s" % (self.queue[-5:])) + + def crop_wav(self, x): + crop_size = self.audio_cfg["crop_size"] + crop_pos = random.randint(0, len(x) - crop_size - 1) + return x[crop_pos: crop_pos + crop_size] + + def prompt_text(self, target): + events = _AUDIOSET_MAP[np.where(target > 0)] + event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1] + text = tokenizer(event_text)[0] + return text + + def __getitem__(self, index): + """Load waveform, text, and target of an audio clip + + Parameters + ---------- + index: int + the index number + Return + ------ + output: dict { + "hdf5_path": str, + "index_in_hdf5": int, + "audio_name": str, + "waveform": list (audio_length,), + "target": list (class_num, ), + "text": torch.tensor (context_length,) + } + the output dictionary + """ + s_index = self.queue[index] + + audio_name = self.fp["audio_name"][s_index].decode() + # Hardcode here CHANGE + hdf5_path = ( + self.fp["hdf5_path"][s_index] + .decode() + .replace( + "../workspace", + "/home/la/kechen/Research/ke_zsasp/workspace", + ) + ) + r_idx = self.fp["index_in_hdf5"][s_index] + target = self.fp["target"][s_index].astype(np.float32) + text = self.prompt_text(target) + with h5py.File(hdf5_path, "r") as f: + waveform = int16_to_float32(f["waveform"][r_idx])[ + : self.audio_cfg["clip_samples"] + ] + assert ( + len(waveform) == self.audio_cfg["clip_samples"] + ), "The sample length is not match" + # Time shift + # if (self.config.enable_time_shift) and (not self.eval_mode): + # waveform = self.time_shifting(waveform) + # # Label Enhance + # if (self.config.crop_size is not None) and (not self.eval_mode): + # waveform = self.crop_wav(waveform) + # # the label enhance rate is fixed 0.5 + # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5: + # kidx = np.where(target)[0] + # for k in kidx: + # for add_key in self.class_map[k][1]: + # target[add_key] 
= 1.0 + # if len(self.class_map[k][2]) > 0: + # add_key = random.choice(self.class_map[k][2]) + # target[add_key] = 1.0 + + # missing the text input + mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :] + mel_spec = torch.cat([mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0).cpu().numpy() + longer = random.choice([True, False]) + if longer == False: + mel_spec[1:, :, :] = 0.0 + data_dict = { + "hdf5_path": hdf5_path, + "index_in_hdf5": r_idx, + "audio_name": audio_name, + "waveform": waveform, + "class_label": target, + "text": text, + "longer": longer, + "mel_fusion": mel_spec + } + return data_dict + + def __len__(self): + return self.total_size + +@dataclass +class DataInfo: + dataloader: DataLoader + sampler: DistributedSampler + + +def get_dataset_size(shards, sizefilepath_=None, is_local=True): + if isinstance(shards, list): + size_list = [] + for s in shards: + size_list.append( + get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0] + ) + else: + if not is_local: + for n in dataset_split.keys(): + if n in shards.split("/"): + break + for s in dataset_split[n]: + if s in shards.split("/"): + break + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + shards_list = list(braceexpand.braceexpand(shards)) + dir_path = os.path.dirname(shards) + if sizefilepath_ is not None: + sizes = json.load(open(sizefilepath_, "r")) + total_size = sum( + [ + int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))]) + for shard in shards_list + ] + ) + else: + sizes_filename = os.path.join(dir_path, "sizes.json") + len_filename = os.path.join(dir_path, "__len__") + if os.path.exists(sizes_filename): + sizes = json.load(open(sizes_filename, "r")) + total_size = sum( + [int(sizes[os.path.basename(shard)]) for shard in shards_list] + ) + elif os.path.exists(len_filename): + # FIXME this used to be eval(open(...)) but that seemed rather unsafe + total_size = ast.literal_eval(open(len_filename, "r").read()) + else: + raise Exception( + f"Cannot find sizes file for dataset {shards}. Please specify the path to the file." + ) + # total_size = None # num samples undefined + # some common dataset sizes (at time of authors last download) + # cc3m-train: 2905954 + # cc12m: 10968539 + # LAION-400m: 407332084 + num_shards = len(shards_list) + if isinstance(shards, list): + return sum(size_list), len(shards) + else: + return total_size, num_shards + + +def count_samples(dataloader): + os.environ["WDS_EPOCH"] = "0" + n_elements, n_batches = 0, 0 + for images, texts in dataloader: + n_batches += 1 + n_elements += len(images) + assert len(images) == len(texts) + return n_elements, n_batches + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" + logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.") + return True + + +_SHARD_SHUFFLE_SIZE = 2000 +_SHARD_SHUFFLE_INITIAL = 500 +_SAMPLE_SHUFFLE_SIZE = 5000 +_SAMPLE_SHUFFLE_INITIAL = 1000 + + +def sample_prop(sizefile, inputs, proportion, is_local=True): + """ + Sample a proportion of the data. 
+ """ + file_path_dict = { + os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0] + for i in range(len(inputs)) + } + sampled_filepath_dict = {} + sampled_size_dict = {} + if not is_local: + if os.path.exists("sizes.json"): + os.remove("sizes.json") + wget.download(sizefile, "sizes.json") + sizefile = "sizes.json" + with open(sizefile, "r", encoding="UTF-8") as f: + load_dict = json.load(f) + L = int(len(file_path_dict) * proportion) + subkeys = random.sample(file_path_dict.keys(), L) + for k in subkeys: + sampled_size_dict[k] = load_dict[k] + sampled_filepath_dict[k] = file_path_dict[k] + return ( + sum(sampled_size_dict.values()), + L, + [os.path.join(v, k) for k, v in sampled_filepath_dict.items()], + sampled_size_dict, + ) + + +def get_mel(audio_data, audio_cfg): + # mel shape: (n_mels, T) + mel_tf = torchaudio.transforms.MelSpectrogram( + sample_rate=audio_cfg['sample_rate'], + n_fft=audio_cfg['window_size'], + win_length=audio_cfg['window_size'], + hop_length=audio_cfg['hop_size'], + center=True, + pad_mode="reflect", + power=2.0, + norm=None, + onesided=True, + n_mels=audio_cfg['mel_bins'], + f_min=audio_cfg['fmin'], + f_max=audio_cfg['fmax'] + ).to(audio_data.device) + + mel = mel_tf(audio_data) + # Align to librosa: + # librosa_melspec = librosa.feature.melspectrogram( + # waveform, + # sr=audio_cfg['sample_rate'], + # n_fft=audio_cfg['window_size'], + # hop_length=audio_cfg['hop_size'], + # win_length=audio_cfg['window_size'], + # center=True, + # pad_mode="reflect", + # power=2.0, + # n_mels=audio_cfg['mel_bins'], + # norm=None, + # htk=True, + # f_min=audio_cfg['fmin'], + # f_max=audio_cfg['fmax'] + # ) + # we use log mel spectrogram as input + mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel) + return mel.T # (T, n_mels) + + +def get_audio_features(sample, audio_data, max_len, data_truncating, data_filling, audio_cfg, require_grad=False): + """ + Calculate and add audio features to sample. + Sample: a dict containing all the data of current sample. + audio_data: a tensor of shape (T) containing audio data. + max_len: the maximum length of audio data. + data_truncating: the method of truncating data. + data_filling: the method of filling data. + audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg']. + require_grad: whether to require gradient for audio data. + This is useful when we want to apply gradient-based classifier-guidance. + """ + grad_fn = suppress if require_grad else torch.no_grad + with grad_fn(): + if len(audio_data) > max_len: + if data_truncating == "rand_trunc": + longer = torch.tensor([True]) + elif data_truncating == "fusion": + # fusion + mel = get_mel(audio_data, audio_cfg) + # split to three parts + chunk_frames = max_len // audio_cfg['hop_size'] + 1 # the +1 related to how the spectrogram is computed + total_frames = mel.shape[0] + if chunk_frames == total_frames: + # there is a corner case where the audio length is + # larger than max_len but smaller than max_len+hop_size. + # In this case, we just use the whole audio. 
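Before the corner-case handling continues below, a worked example of the frame arithmetic this fusion branch relies on; `max_len=480000` is the 10-second default used elsewhere in this diff, while `hop_size=480` is an assumed illustrative value (the real one comes from `audio_cfg['hop_size']`):

```python
import numpy as np

max_len = 480000        # 10 s of 48 kHz audio (default max_len in this repo)
hop_size = 480          # assumed for illustration; taken from audio_cfg in practice
total_frames = 2001     # e.g. the mel frame count of a ~20 s clip

chunk_frames = max_len // hop_size + 1                      # 1001 frames per chunk
starts = list(range(0, total_frames - chunk_frames + 1))    # 1001 valid start frames

# Split the valid starts into front / middle / back thirds and draw one start
# from each third, as get_audio_features does for the three shifted mel chunks.
ranges = np.array_split(starts, 3)
idx_front, idx_middle, idx_back = (int(np.random.choice(r)) for r in ranges)
```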
+ mel_fusion = torch.stack([mel, mel, mel, mel], dim=0) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([False]) + else: + ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3) + # print('total_frames-chunk_frames:', total_frames-chunk_frames, + # 'len(audio_data):', len(audio_data), + # 'chunk_frames:', chunk_frames, + # 'total_frames:', total_frames) + if len(ranges[1]) == 0: + # if the audio is too short, we just use the first chunk + ranges[1] = [0] + if len(ranges[2]) == 0: + # if the audio is too short, we just use the first chunk + ranges[2] = [0] + # randomly choose index for each part + idx_front = np.random.choice(ranges[0]) + idx_middle = np.random.choice(ranges[1]) + idx_back = np.random.choice(ranges[2]) + # select mel + mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :] + mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :] + mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :] + + # shrink the mel + mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, audio_cfg['mel_bins']])(mel[None])[0] + # logging.info(f"mel_shrink.shape: {mel_shrink.shape}") + + # stack + mel_fusion = torch.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([True]) + else: + raise NotImplementedError( + f"data_truncating {data_truncating} not implemented" + ) + # random crop to max_len (for compatibility) + overflow = len(audio_data) - max_len + idx = np.random.randint(0, overflow + 1) + audio_data = audio_data[idx: idx + max_len] + + else: # padding if too short + if len(audio_data) < max_len: # do nothing if equal + if data_filling == "repeatpad": + n_repeat = int(max_len / len(audio_data)) + audio_data = audio_data.repeat(n_repeat) + # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0) + # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0] + audio_data = F.pad( + audio_data, + (0, max_len - len(audio_data)), + mode="constant", + value=0, + ) + elif data_filling == "pad": + audio_data = F.pad( + audio_data, + (0, max_len - len(audio_data)), + mode="constant", + value=0, + ) + elif data_filling == "repeat": + n_repeat = int(max_len / len(audio_data)) + audio_data = audio_data.repeat(n_repeat + 1)[:max_len] + else: + raise NotImplementedError( + f"data_filling {data_filling} not implemented" + ) + if data_truncating == 'fusion': + mel = get_mel(audio_data, audio_cfg) + mel_fusion = torch.stack([mel, mel, mel, mel], dim=0) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([False]) + + sample["longer"] = longer + sample["waveform"] = audio_data + + return sample + + +def select_text(json_dict_raw, text_augment_selection): + # For selecting augmented text from dataset + if text_augment_selection is None or text_augment_selection == "none": + texts = json_dict_raw["text"] + elif text_augment_selection == "all": + if "text_augment_all" in json_dict_raw.keys(): + texts = json_dict_raw["text_augment_all"] + else: + texts = json_dict_raw["text"] + elif text_augment_selection == "augment_only": + if "text_augment_all" in json_dict_raw.keys(): + if json_dict_raw["text_augment_t5"] is None: + texts = json_dict_raw["text"] + else: + texts = json_dict_raw["text_augment_t5"] + else: + texts = json_dict_raw["text"] + else: + raise NotImplementedError( + f"text_augment_selection {text_augment_selection} not implemented" + ) + return texts + + +def preprocess_single( + sample, + audio_ext, + text_ext, + max_len, + 
audio_cfg, + tmodel, + class_index_dict, + data_filling, + data_truncating, + text_augment_selection, +): + """ + Preprocess a single sample for wdsdataloader. + """ + audio_data, orig_sr = sample[audio_ext] + audio_data = int16_to_float32_torch(float32_to_int16_torch(audio_data[0])) + + sample = get_audio_features(sample, audio_data, max_len, data_truncating, data_filling, audio_cfg) + del sample[audio_ext] + + json_dict_raw = sample[text_ext] + + texts = select_text(json_dict_raw, text_augment_selection) + sample["full_text"] = texts + + if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1: + texts = random.choice(texts) + sample["raw_text"] = texts + sample["text"] = tokenizer(texts, tmodel=tmodel) # text shape: [num_token] + if class_index_dict is not None: + # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing + # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array + + # in case the re-written version is wrong, here is the old version: + # sample["class_label"] = np.zeros(len(class_index_dict.keys())) + # for x in json_dict_raw["tag"]: + # sample["class_label"][class_index_dict[x]] = 1 + # sample["class_label"] = torch.tensor(sample["class_label"]).float() + + class_labels = np.zeros(len(class_index_dict)) + class_labels[np.in1d(list(class_index_dict.keys()), json_dict_raw["tag"])] = 1 + sample["class_label"] = torch.tensor(class_labels).float() + + del sample[text_ext] + sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext + sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext + sample["audio_orig_sr"] = orig_sr + return sample + + +def collate_fn_with_preprocess(batch, + audio_ext, + text_ext, + max_len, + audio_cfg, + args, + ): + """ + Collate function for wdsdataloader. + batch: a list of dict, each dict is a sample + """ + + class_index_dict = copy.deepcopy(args.class_index_dict) # To avoid deadlock in multiprocessing + data_filling = args.data_filling + data_truncating = args.data_truncating + text_augment_selection = args.text_augment_selection + tmodel = args.tmodel + + # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend. + data_preprocessed = [] + + for sample in batch: + data_preprocessed.append( + preprocess_single(sample, audio_ext, text_ext, max_len, audio_cfg, tmodel, class_index_dict, data_filling, + data_truncating, text_augment_selection)) + + batch_dict = {} + for k in data_preprocessed[0].keys(): + if isinstance(data_preprocessed[0][k], dict): # dealwith bert tokenizer output + batch_dict[k] = {} + for kk in data_preprocessed[0][k].keys(): + tmp = [] + for i in range(len(data_preprocessed)): + tmp.append(data_preprocessed[i][k][kk]) + batch_dict[k][kk] = torch.vstack(tmp) + elif isinstance(data_preprocessed[0][k], torch.Tensor): + batch_dict[k] = torch.stack([sample[k] for sample in data_preprocessed]) + elif isinstance(data_preprocessed[0][k], np.ndarray): + batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in data_preprocessed])) + else: + batch_dict[k] = [sample[k] for sample in data_preprocessed] + del data_preprocessed + return batch_dict + + +def get_wds_dataset( + args, + model_cfg, + is_train, + audio_ext="flac", + text_ext="json", + max_len=480000, + proportion=1.0, + sizefilepath_=None, + is_local=None, +): + """ + Get a dataset for wdsdataloader. 
+ """ + if is_local is None and (not args.remotedata is None): + is_local = not args.remotedata + + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + + if not sizefilepath_ is None: + sizefilepath = sizefilepath_ + else: + sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json") + + if proportion != 1.0: + num_samples, num_shards, input_shards, _ = sample_prop( + sizefilepath, input_shards, proportion, is_local=is_local + ) + else: + num_samples, num_shards = get_dataset_size( + input_shards, sizefilepath_=sizefilepath_, is_local=is_local + ) + + if not num_samples: + if is_train: + num_samples = args.train_num_samples + if not num_samples: + raise RuntimeError( + "Currently, number of dataset samples must be specified for training dataset. " + "Please specify via `--train-num-samples` if no dataset length info present." + ) + else: + num_samples = ( + args.val_num_samples or 0 + ) # eval will just exhaust the iterator if not specified + + pipeline = [wds.SimpleShardList(input_shards)] + # at this point we have an iterator over all the shards + # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node + if is_train or args.parallel_eval: + pipeline.extend( + [ + wds.detshuffle( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + ), + wds.split_by_node, + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker at each node + wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + rng=random.Random(args.seed), + ), + # wds.repeatedly, # FIXME determine if this is beneficial + ] + ) + else: + pipeline.extend( + [ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ] + ) + + pipeline.append( + wds.decode(wds.torch_audio), + ) + + pipeline.append( + wds.batched( + args.batch_size, + partial=not (is_train or args.parallel_eval), + collation_fn=partial(collate_fn_with_preprocess, + audio_ext=audio_ext, + text_ext=text_ext, + max_len=max_len, + audio_cfg=model_cfg['audio_cfg'], + args=args, + ), + + ) + ) + + dataset = wds.DataPipeline(*pipeline) + if is_train or args.parallel_eval: + # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples. + # (yusong): See comments below. 
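The epoch-sizing block that follows rounds the sample count up so every node and every dataloader worker iterates over the same number of full batches. A worked example under assumed values (not taken from the repo):

```python
import math

# Assumed run configuration, for illustration only.
num_samples, batch_size, world_size, workers = 100_000, 64, 8, 4

global_batch_size = batch_size * world_size                     # 512
num_batches = math.ceil(num_samples / global_batch_size)        # 196
num_workers = max(1, workers)
num_worker_batches = math.ceil(num_batches / num_workers)       # 49 per dataloader worker
num_batches = num_worker_batches * num_workers                  # 196
num_samples = num_batches * global_batch_size                   # 100_352 after roll-over
```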
+ # roll over and repeat a few samples to get same number of full batches on each node + global_batch_size = args.batch_size * args.world_size + num_batches = math.ceil(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = math.ceil( + num_batches / num_workers + ) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch( + num_worker_batches + ) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + kwargs = {} + if args.horovod: # multi-node training on summit + kwargs["multiprocessing_context"] = "forkserver" + + if is_train: + if args.prefetch_factor: + prefetch_factor = args.prefetch_factor + else: + prefetch_factor = max(2, args.batch_size // args.workers) + else: + prefetch_factor = 2 + + dataloader = wds.WebLoader( + dataset, + batch_size=None, + shuffle=False, + num_workers=args.workers, + pin_memory=True, + prefetch_factor=prefetch_factor, + **kwargs + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? + # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader, None) + + +def wds_batch_list2dict( + batch, + keys=[ + "__url__", + "__key__", + "waveform", + "text", + "raw_text", + "audio_name", + "text_name", + "audio_orig_sr", + ], +): + """ + Return a dictionary of the batch, with keys as the names of the fields. 
+ """ + assert len(keys) == len( + batch + ), "batch must have same number of keys as keys argument" + return {keys[i]: batch[i] for i in range(len(batch))} + + + +def get_toy_dataset(args, model_cfg, is_train): + index_path = args.train_data if is_train else args.val_data + ipc_path = args.train_ipc if is_train else args.val_ipc + assert index_path and ipc_path + eval_mode = not is_train + dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode) + + num_samples = len(dataset) + sampler = ( + DistributedSampler(dataset, shuffle=False) + if args.distributed and is_train + else None + ) + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +def get_dataset_fn(dataset_type): + if dataset_type == "webdataset": + return get_wds_dataset + elif dataset_type == "toy": + return get_toy_dataset + else: + raise ValueError(f"Unsupported dataset type: {dataset_type}") + + +def get_data(args, model_cfg): + data = {} + + args.class_index_dict = load_class_label(args.class_label_path) + + if args.datasetinfos is None: + args.datasetinfos = ["train", "unbalanced_train", "balanced_train"] + if args.dataset_type == "webdataset": + args.train_data = get_tar_path_from_dataset_name( + args.datasetnames, + args.datasetinfos, + islocal=not args.remotedata, + proportion=args.dataset_proportion, + dataset_path=args.datasetpath, + full_dataset=args.full_train_dataset, + ) + + if args.full_train_dataset is None: + args.full_train_dataset = [] + if args.exclude_eval_dataset is None: + args.exclude_eval_dataset = [] + excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset + + val_dataset_names = [n for n in args.datasetnames if n not in excluded_eval_datasets] \ + if excluded_eval_datasets else args.datasetnames + args.val_dataset_names = val_dataset_names + args.val_data = get_tar_path_from_dataset_name( + val_dataset_names, + ["valid", "test", "eval"], + islocal=not args.remotedata, + proportion=1, + dataset_path=args.datasetpath, + full_dataset=None, + ) + + if args.train_data: + data["train"] = get_dataset_fn(args.dataset_type)( + args, model_cfg, is_train=True + ) + + if args.val_data: + data["val"] = get_dataset_fn(args.dataset_type)( + args, model_cfg, is_train=False + ) + + return data diff --git a/src/laion_clap/training/distributed.py b/src/laion_clap/training/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..adb0e927a64dbe7fc83fecf65be054ac6bd28a94 --- /dev/null +++ b/src/laion_clap/training/distributed.py @@ -0,0 +1,139 @@ +import os + +import torch +import socket + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def is_global_master(args): + return args.rank == 0 + + +def is_local_master(args): + return args.local_rank == 0 + + +def is_master(args, local=False): + return is_local_master(args) if local else is_global_master(args) + + +def is_using_horovod(): + # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set + # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 
+ ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] + pmi_vars = ["PMI_RANK", "PMI_SIZE"] + if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): + return True + else: + return False + + +def is_using_distributed(): + if 'WORLD_SIZE' in os.environ: + return int(os.environ['WORLD_SIZE']) > 1 + if 'SLURM_NTASKS' in os.environ: + return int(os.environ['SLURM_NTASKS']) > 1 + return False + + +def world_info_from_env(): + local_rank = 0 + for v in ('SLURM_LOCALID', 'MPI_LOCALRANKID', 'OMPI_COMM_WORLD_LOCAL_RANK', 'LOCAL_RANK'): + if v in os.environ: + local_rank = int(os.environ[v]) + break + global_rank = 0 + for v in ('SLURM_PROCID', 'PMI_RANK', 'OMPI_COMM_WORLD_RANK', 'RANK'): + if v in os.environ: + global_rank = int(os.environ[v]) + break + world_size = 1 + for v in ('SLURM_NTASKS', 'PMI_SIZE', 'OMPI_COMM_WORLD_SIZE', 'WORLD_SIZE'): + if v in os.environ: + world_size = int(os.environ[v]) + break + + return local_rank, global_rank, world_size + + +def init_distributed_device(args): + # Distributed training = training on more than one GPU. + # Works in both single and multi-node scenarios. + args.distributed = False + args.world_size = 1 + args.rank = 0 # global rank + args.local_rank = 0 + if args.horovod: + assert hvd is not None, "Horovod is not installed" + hvd.init() + world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + world_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + args.local_rank = local_rank + args.rank = world_rank + args.world_size = world_size + # args.local_rank = int(hvd.local_rank()) + # args.rank = hvd.rank() + # args.world_size = hvd.size() + args.distributed = True + os.environ['LOCAL_RANK'] = str(args.local_rank) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + print(f"Distributed training: local_rank={args.local_rank}, " + f"rank={args.rank}, world_size={args.world_size}, " + f"hostname={socket.gethostname()}, pid={os.getpid()}") + elif is_using_distributed(): + if 'SLURM_PROCID' in os.environ: + # DDP via SLURM + args.local_rank, args.rank, args.world_size = world_info_from_env() + # SLURM var -> torch.distributed vars in case needed + os.environ['LOCAL_RANK'] = str(args.local_rank) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + elif 'OMPI_COMM_WORLD_SIZE' in os.environ: # using Summit cluster + world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + world_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + args.local_rank = local_rank + args.rank = world_rank + args.world_size = world_size + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + else: + # DDP via torchrun, torch.distributed.launch + args.local_rank, _, _ = world_info_from_env() + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url) + args.world_size = torch.distributed.get_world_size() + args.rank = torch.distributed.get_rank() + args.distributed = True + print(f"Distributed training: local_rank={args.local_rank}, " + f"rank={args.rank}, world_size={args.world_size}, " + f"hostname={socket.gethostname()}, pid={os.getpid()}") + + if torch.cuda.is_available(): 
+ if args.distributed and not args.no_set_device_rank: + device = 'cuda:%d' % args.local_rank + else: + device = 'cuda:0' + torch.cuda.set_device(device) + else: + device = 'cpu' + args.device = device + device = torch.device(device) + return device diff --git a/src/laion_clap/training/imagenet_zeroshot_data.py b/src/laion_clap/training/imagenet_zeroshot_data.py new file mode 100644 index 0000000000000000000000000000000000000000..a78987448805afc228b2941302a2894818cac497 --- /dev/null +++ b/src/laion_clap/training/imagenet_zeroshot_data.py @@ -0,0 +1,254 @@ +# NOTE: This script is currently not supported for CLAP. + +imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", + "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", + "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", + "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", + "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", + "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", + "box turtle", "banded gecko", "green iguana", "Carolina anole", + "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", + "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", + "American alligator", "triceratops", "worm snake", "ring-necked snake", + "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", + "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", + "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", + "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", + "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", + "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", + "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", + "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", + "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", + "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", + "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", + "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", + "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", + "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", + "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", + "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", + "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", + "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", + "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", + "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", + "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", + "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire 
Terrier", + "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", + "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", + "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", + "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", + "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", + "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", + "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", + "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", + "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", + "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", + "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", + "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", + "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", + "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", + "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", + "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", + "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", + "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", + "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", + "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", + "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", + "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", + "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", + "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", + "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", + "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", + "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", + "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", + "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", + "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", + "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", + "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", + "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", + "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", + "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", + "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", + "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", + "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", + "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", + "giant panda", "snoek fish", 
"eel", "silver salmon", "rock beauty fish", "clownfish", + "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", + "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", + "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", + "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", + "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", + "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", + "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", + "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", + "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", + "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", + "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", + "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", + "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", + "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", + "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", + "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", + "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", + "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", + "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", + "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", + "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", + "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", + "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", + "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", + "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", + "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", + "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", + "freight car", "French horn", "frying pan", "fur coat", "garbage truck", + "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", + "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", + "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", + "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", + "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", + "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", + "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", + "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", + "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", + "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", 
"messenger bag", + "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", + "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", + "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", + "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", + "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", + "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", + "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", + "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", + "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", + "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", + "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", + "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", + "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", + "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", + "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", + "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", + "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", + "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", + "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", + "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", + "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", + "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", + "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", + "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", + "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", + "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", + "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", + "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", + "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", + "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", + "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", + "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", + "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", + "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", + "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", + "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", + "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", + "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", + 
"velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", + "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", + "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", + "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", + "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", + "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", + "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", + "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", + "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", + "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", + "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", + "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", + "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", + "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", + "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", + "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", + "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] + + + + + +openai_imagenet_template = [ + lambda c: f'a bad photo of a {c}.', + lambda c: f'a photo of many {c}.', + lambda c: f'a sculpture of a {c}.', + lambda c: f'a photo of the hard to see {c}.', + lambda c: f'a low resolution photo of the {c}.', + lambda c: f'a rendering of a {c}.', + lambda c: f'graffiti of a {c}.', + lambda c: f'a bad photo of the {c}.', + lambda c: f'a cropped photo of the {c}.', + lambda c: f'a tattoo of a {c}.', + lambda c: f'the embroidered {c}.', + lambda c: f'a photo of a hard to see {c}.', + lambda c: f'a bright photo of a {c}.', + lambda c: f'a photo of a clean {c}.', + lambda c: f'a photo of a dirty {c}.', + lambda c: f'a dark photo of the {c}.', + lambda c: f'a drawing of a {c}.', + lambda c: f'a photo of my {c}.', + lambda c: f'the plastic {c}.', + lambda c: f'a photo of the cool {c}.', + lambda c: f'a close-up photo of a {c}.', + lambda c: f'a black and white photo of the {c}.', + lambda c: f'a painting of the {c}.', + lambda c: f'a painting of a {c}.', + lambda c: f'a pixelated photo of the {c}.', + lambda c: f'a sculpture of the {c}.', + lambda c: f'a bright photo of the {c}.', + lambda c: f'a cropped photo of a {c}.', + lambda c: f'a plastic {c}.', + lambda c: f'a photo of the dirty {c}.', + lambda c: f'a jpeg corrupted photo of a {c}.', + lambda c: f'a blurry photo of the {c}.', + lambda c: f'a photo of the {c}.', + lambda c: f'a good photo of the {c}.', + lambda c: f'a rendering of the {c}.', + lambda c: f'a {c} in a video game.', + lambda c: f'a photo of one {c}.', + lambda c: f'a doodle of a {c}.', + lambda c: f'a close-up photo of the {c}.', + lambda c: f'a photo of a {c}.', + lambda c: f'the origami {c}.', + lambda c: f'the {c} in a video game.', + lambda c: f'a sketch of a {c}.', + lambda c: f'a doodle of the {c}.', + lambda c: f'a origami {c}.', + lambda c: f'a low resolution photo of a {c}.', + lambda c: f'the toy {c}.', + lambda c: f'a rendition of the {c}.', + lambda c: f'a photo of 
the clean {c}.', + lambda c: f'a photo of a large {c}.', + lambda c: f'a rendition of a {c}.', + lambda c: f'a photo of a nice {c}.', + lambda c: f'a photo of a weird {c}.', + lambda c: f'a blurry photo of a {c}.', + lambda c: f'a cartoon {c}.', + lambda c: f'art of a {c}.', + lambda c: f'a sketch of the {c}.', + lambda c: f'a embroidered {c}.', + lambda c: f'a pixelated photo of a {c}.', + lambda c: f'itap of the {c}.', + lambda c: f'a jpeg corrupted photo of the {c}.', + lambda c: f'a good photo of a {c}.', + lambda c: f'a plushie {c}.', + lambda c: f'a photo of the nice {c}.', + lambda c: f'a photo of the small {c}.', + lambda c: f'a photo of the weird {c}.', + lambda c: f'the cartoon {c}.', + lambda c: f'art of the {c}.', + lambda c: f'a drawing of the {c}.', + lambda c: f'a photo of the large {c}.', + lambda c: f'a black and white photo of a {c}.', + lambda c: f'the plushie {c}.', + lambda c: f'a dark photo of a {c}.', + lambda c: f'itap of a {c}.', + lambda c: f'graffiti of the {c}.', + lambda c: f'a toy {c}.', + lambda c: f'itap of my {c}.', + lambda c: f'a photo of a cool {c}.', + lambda c: f'a photo of a small {c}.', + lambda c: f'a tattoo of the {c}.', +] diff --git a/src/laion_clap/training/infer_demo.py b/src/laion_clap/training/infer_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..b39f3f43fc9f295b244bc1ca2b444f4969488cba --- /dev/null +++ b/src/laion_clap/training/infer_demo.py @@ -0,0 +1,92 @@ +import torch +import librosa +from clap_module import create_model +from training.data import get_audio_features +from training.data import int16_to_float32, float32_to_int16 +from transformers import RobertaTokenizer + +tokenize = RobertaTokenizer.from_pretrained('roberta-base') +def tokenizer(text): + result = tokenize( + text, + padding="max_length", + truncation=True, + max_length=77, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} + +def infer_text(): + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + precision = 'fp32' + amodel = 'HTSAT-tiny' # or 'PANN-14' + tmodel = 'roberta' # the best text encoder in our training + enable_fusion = True # False if you do not want to use the fusion model + fusion_type = 'aff_2d' + pretrained = "/home/la/kechen/Research/KE_CLAP/ckpt/fusion_best.pt" # the checkpoint name, the unfusion model can also be loaded. + + model, model_cfg = create_model( + amodel, + tmodel, + pretrained, + precision=precision, + device=device, + enable_fusion=enable_fusion, + fusion_type=fusion_type + ) + # load the text, can be a list (i.e. batch size) + text_data = ["I love the contrastive learning", "I love the pretrain model"] + # tokenize for roberta, if you want to tokenize for another text encoder, please refer to data.py#L43-90 + text_data = tokenizer(text_data) + model.eval() + text_embed = model.get_text_embedding(text_data) + text_embed = text_embed.detach().cpu().numpy() + print(text_embed) + print(text_embed.shape) + +def infer_audio(): + + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + precision = 'fp32' + amodel = 'HTSAT-tiny' # or 'PANN-14' + tmodel = 'roberta' # the best text encoder in our training + enable_fusion = True # False if you do not want to use the fusion model + fusion_type = 'aff_2d' + pretrained = "/home/la/kechen/Research/KE_CLAP/ckpt/fusion_best.pt" # the checkpoint name, the unfusion model can also be loaded. 
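This inference path (like the wds preprocessing earlier) runs the waveform through an int16 round trip so that inference sees the same quantization the model saw at training time. A rough self-contained sketch of what that round trip does; the helper bodies here are assumptions for illustration, not code copied from this repo:

```python
import numpy as np

def float32_to_int16(x):
    # Clip to [-1, 1] and scale to the int16 range (assumed implementation).
    return (np.clip(x, -1.0, 1.0) * 32767.0).astype(np.int16)

def int16_to_float32(x):
    # Map back to float32 in [-1, 1]; small values lose precision, which is the point.
    return (x / 32767.0).astype(np.float32)

wave = np.random.uniform(-1.0, 1.0, size=48000).astype(np.float32)  # 1 s of fake audio at 48 kHz
quantized = int16_to_float32(float32_to_int16(wave))
print(np.abs(wave - quantized).max())  # on the order of 3e-5, i.e. one int16 step
```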
+ + model, model_cfg = create_model( + amodel, + tmodel, + pretrained, + precision=precision, + device=device, + enable_fusion=enable_fusion, + fusion_type=fusion_type + ) + + # load the waveform of the shape (T,), should resample to 48000 + audio_waveform, sr = librosa.load('/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav', sr=48000) + # quantize + audio_waveform = int16_to_float32(float32_to_int16(audio_waveform)) + audio_waveform = torch.from_numpy(audio_waveform).float() + audio_dict = {} + + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + audio_dict = get_audio_features( + audio_dict, audio_waveform, 480000, + data_truncating='fusion', + data_filling='repeatpad', + audio_cfg=model_cfg['audio_cfg'] + ) + model.eval() + # can send a list to the model, to process many audio tracks in one time (i.e. batch size) + audio_embed = model.get_audio_embedding([audio_dict]) + audio_embed = audio_embed.detach().cpu().numpy() + print(audio_embed) + print(audio_embed.shape) + + + +if __name__ == "__main__": + infer_text() + # infer_audio() diff --git a/src/laion_clap/training/logger.py b/src/laion_clap/training/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9abed92568d459cbc8d6094ae3901935d89621 --- /dev/null +++ b/src/laion_clap/training/logger.py @@ -0,0 +1,26 @@ +import logging + + +def setup_logging(log_file, level, include_host=False): + if include_host: + import socket + hostname = socket.gethostname() + formatter = logging.Formatter( + f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') + else: + formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') + + logging.root.setLevel(level) + loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] + for logger in loggers: + logger.setLevel(level) + + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter) + logging.root.addHandler(stream_handler) + + if log_file: + file_handler = logging.FileHandler(filename=log_file) + file_handler.setFormatter(formatter) + logging.root.addHandler(file_handler) + diff --git a/src/laion_clap/training/lp_main.py b/src/laion_clap/training/lp_main.py new file mode 100644 index 0000000000000000000000000000000000000000..1e88f7356950ef03c78ca4d88681eb78ff1b4f6a --- /dev/null +++ b/src/laion_clap/training/lp_main.py @@ -0,0 +1,643 @@ +import logging +import os +import random +from datetime import datetime +import copy +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch.cuda.amp import GradScaler +import time + +try: + import wandb +except ImportError: + wandb = None + +try: + import torch.utils.tensorboard as tensorboard +except ImportError: + tensorboard = None + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from clap_module import create_model_and_transforms, trace_model, create_model +from training.data import get_data +from training.params import parse_args +from training.distributed import is_master, init_distributed_device, world_info_from_env +from training.logger import setup_logging +from training.scheduler import cosine_lr +from training.lp_train import train_one_epoch, evaluate +from clap_module.utils import get_tar_path_from_dataset_name, dataset_split, get_optimizer +from clap_module.utils import load_p, load_class_label +from clap_module.linear_probe import LinearProbe + + +def maintain_ckpts(args, startidx, all_idx_len): + for i in 
reversed(range(startidx, all_idx_len)): + if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")): + os.rename( + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"), + ) + if os.path.exists( + os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt") + ): + os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")) + return + + +def update_top_k_performance( + new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True +): + """ + Record the top-k performance of the current epoch. + current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...} + """ + if isinstance(new_metrics_inputs, (list, tuple)): + new_metrics_inputs = np.mean(new_metrics_inputs) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, dict): + new_metrics_inputs = np.mean(list(new_metrics_inputs.values())) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, (float, int)): + update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()} + sorted_keys = sorted(current_top_k_ckpt_metrics.keys()) + sorted_values = sorted( + current_top_k_ckpt_metrics.values(), reverse=bignumbetter + ) + sorted_values_ = copy.deepcopy(sorted_values) + sorted_values.append(new_metrics_inputs) + sorted_values = sorted(sorted_values, reverse=bignumbetter) + sorted_values = sorted_values[:-1] + + if sorted_values == sorted_values_: + return current_top_k_ckpt_metrics, new_metrics_inputs + else: + for i in range(len(sorted_keys)): + if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]: + current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i] + update_flag[sorted_keys[i]] = True + for i in range(len(update_flag)): + if update_flag[i]: + maintain_ckpts(args, i, len(sorted_keys)) + torch.save( + ckpt, + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + ) + break + return current_top_k_ckpt_metrics, new_metrics_inputs + + +# def updateifNone(a, b): +# a = b if None else a +# return a + + +def is_pretrained_params(n): + return ( + n.startswith("clap_model.transformer") + or n in ["clap_model.positional_embedding", "clap_model.text_projection"] + or n.startswith("clap_model.token_embedding") + or n.startswith("clap_model.ln_final") + or n.startswith("clap_model.logit_scale_t") + ) + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + +def config_lp_optimizer(model, data, args): + # set wd-related params to 0 if use adam optimizer + if args.optimizer == "adam": + args.wd = 0 + args.wd_pretrained = 0 + args.wd_new = 0 + + in_clap = ( + lambda n, p: n.startswith("clap_model") + ) + + named_parameters = list(model.named_parameters()) + + optimizer = {} + scheduler = {} + + # freeze text encoder + text_freeze_parameters = [ + p + for n, p in named_parameters + if n.startswith("clap_model.transformer") + or n in ["clap_model.positional_embedding", "clap_model.text_projection"] + or n.startswith("clap_model.token_embedding") + or n.startswith("clap_model.ln_final") + ] + + if args.freeze_text: + logging.info("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + + if not args.lp_freeze: + exclude = ( 
+ lambda n, p: p.ndim < 2 + or "bn" in n + or "ln" in n + or "bias" in n + or "logit_scale" in n + ) + include = lambda n, p: not exclude(n, p) + + # (yusong): we do not split the learning rate anymore + # p for n, p in named_parameters if in_clap(n,p) and exclude(n, p) and p.requires_grad + gain_or_bias_params = [ + p for n, p in named_parameters if exclude(n, p) and p.requires_grad + ] + # rest_params = [p for n, p in named_parameters if in_clap(n,p) and include(n, p) and p.requires_grad] + rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] + + if args.train_data is None: + optimizer = None + scheduler = None + else: + total_steps = data["train"].dataloader.num_batches * args.epochs + + if args.split_opt: + for x in ["lr", "beta1", "beta2", "eps", "wd"]: + for y in ["_new", "_pretrained"]: + if getattr(args, x + y) is None: + setattr(args, x + y, getattr(args, x)) + + gain_or_bias_pretrained_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + rest_pretrained_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + gain_or_bias_new_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and (not is_pretrained_params(n)) + ] + rest_new_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and (not is_pretrained_params(n)) + ] + + pretrained_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0}, + { + "params": rest_pretrained_params, + "weight_decay": args.wd_pretrained, + }, + ], + lr=args.lr_pretrained, + betas=(args.beta1_pretrained, args.beta2_pretrained), + eps=args.eps_pretrained, + momentum=args.momentum_pretrained, + optimizer_name=args.optimizer, + ) + pretrained_params_scheduler = cosine_lr( + pretrained_params_optimizer, + args.lr_pretrained, + args.warmup, + total_steps, + ) + + new_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_new_params, "weight_decay": 0.0}, + {"params": rest_new_params, "weight_decay": args.wd_new}, + ], + lr=args.lr_new, + betas=(args.beta1_new, args.beta2_new), + eps=args.eps_new, + momentum=args.momentum_new, + optimizer_name=args.optimizer, + ) + new_params_scheduler = cosine_lr( + new_params_optimizer, args.lr_new, args.warmup, total_steps + ) + + optimizer["text"] = pretrained_params_optimizer + optimizer["audio"] = new_params_optimizer + scheduler["text"] = pretrained_params_scheduler + scheduler["audio"] = new_params_scheduler + + if args.horovod: + pretrained_params_optimizer = hvd.DistributedOptimizer( + pretrained_params_optimizer, + named_parameters=model.named_parameters(), + ) + new_params_optimizer = hvd.DistributedOptimizer( + new_params_optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(pretrained_params_optimizer, root_rank=0) + hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0) + else: + + optimizer["clap"] = get_optimizer( + [ + {"params": gain_or_bias_params, "weight_decay": 0.0}, + {"params": rest_params, "weight_decay": args.wd}, + ], + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + momentum=args.momentum, + optimizer_name=args.optimizer, + ) + scheduler["clap"] = cosine_lr(optimizer["clap"], args.lr, args.warmup, total_steps) + + if args.horovod: + optimizer["clap"] = hvd.DistributedOptimizer( + optimizer["clap"], 
named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer["clap"], root_rank=0) + + # linear probe optimizer + else: + lp_params = [p for n, p in named_parameters if (not in_clap(n, p)) and p.requires_grad] + lp_optim = get_optimizer(lp_params, lr=args.lp_lr, betas=(args.beta1, args.beta2), eps=args.eps, momentum=0.9, + optimizer_name=args.optimizer) + optimizer["lp"] = lp_optim + + return optimizer, scheduler, text_freeze_parameters + + +def main(): + args = parse_args() + + time.sleep(args.sleep) + + # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? + args.amodel = args.amodel.replace("/", "-") + # download sizes.json file + + # (yusong): the below two lines are for debug + # print("setting up faulthandler") + # faulthandler.register(10) + + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + args.class_index_dict = load_class_label(args.class_label_path) + + # get the name of the experiments + if args.name is None: + args.name = "-".join( + [ + datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), + f"linear_probe" + f"model_{args.amodel}", + f"lr_{args.lr}", + f"b_{args.batch_size}", + f"j_{args.workers}", + f"p_{args.precision}", + ] + ) + + # discover initial world args early so we can log properly + args.distributed = False + args.local_rank, args.rank, args.world_size = world_info_from_env() + + if args.remotedata and is_master(args): + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + + args.log_path = None + if is_master(args, local=args.log_local): + log_base_path = os.path.join(args.logs, args.name) + os.makedirs(log_base_path, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path, log_filename) + + # avoid log dir in same name: + postfix = 0 + while os.path.exists(args.log_path): + postfix += 1 + log_base_path_new = log_base_path+'-'+str(postfix) + os.makedirs(log_base_path_new, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path_new, log_filename) + # print( + # "Error. Experiment already exists. Use --name {} to specify a new experiment." 
+ # ) + # return -1 + + # Set logger + args.log_level = logging.DEBUG if args.debug else logging.INFO + setup_logging(args.log_path, args.log_level) + + # fully initialize distributed device environment + device = init_distributed_device(args) + + args.wandb = "wandb" in args.report_to or "all" in args.report_to + args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to + if is_master(args): + args.tensorboard_path = ( + os.path.join(args.logs, args.name, "tensorboard") + if args.tensorboard + else "" + ) + args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") + for dirname in [args.tensorboard_path, args.checkpoint_path]: + if dirname: + os.makedirs(dirname, exist_ok=True) + else: + args.tensorboard_path = "" + args.checkpoint_path = "" + + if args.copy_codebase: + copy_codebase(args) + + assert args.precision in ["amp", "fp16", "fp32"] + if args.precision == "fp16": + logging.warning( + "It is recommended to use AMP mixed-precision instead of FP16. " + "FP16 support needs further verification and tuning, especially for train." + ) + + if args.horovod: + logging.info( + f"Running in horovod mode with multiple processes / nodes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + elif args.distributed: + logging.info( + f"Running in distributed mode with multiple processes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + else: + logging.info(f"Running with a single process. Device {args.device}.") + + logging.info(f'openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}') + + # Create CLAP model + clap_model, clap_model_cfg = create_model( + args.amodel, + args.tmodel, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=False, + pretrained_audio=args.pretrained_audio, + pretrained_text=args.pretrained_text, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type + ) + + args.lp_out_ch = len(list(args.class_index_dict.keys())) + # Linear Probe + logging.info(f"linear probe using mlp: {args.lp_mlp}") + logging.info(f"linear probe using freeze: {args.lp_freeze}") + logging.info(f"linear probe act layer: {args.lp_act}") + logging.info(f"linear probe out ch: {args.lp_out_ch}") + logging.info(f"linear probe learning rate (if applicable): {args.lp_lr}") + logging.info(f"linear probe loss func: {args.lp_loss}") + logging.info(f"linear probe lp_metrics: {args.lp_metrics}") + + model = LinearProbe( + clap_model, + mlp=args.lp_mlp, freeze=args.lp_freeze, + in_ch=512, out_ch=args.lp_out_ch, + act=args.lp_act + ) # in_ch is fixed (i.e., 512) + model = model.to(device) + + if args.horovod: + with torch.no_grad(): + for param in model.parameters(): + param.set_(param.contiguous()) + + if args.trace: + model = trace_model(model, batch_size=args.batch_size, device=device) + + if is_master(args): + logging.info("Linear Probe CLAP Model:") + logging.info(f"{str(clap_model)}") + logging.info("Params:") + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 
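The LinearProbe constructed above wraps the CLAP encoder and trains only a small head on top of its 512-dimensional embeddings, optionally with the backbone frozen. A hypothetical, much-simplified stand-in that shows the shape of that wrapper; the class, the toy encoder, and the class count are invented for illustration:

```python
import torch
import torch.nn as nn

class TinyLinearProbe(nn.Module):
    # Hypothetical stand-in for the LinearProbe wrapper: a frozen encoder plus one trainable head.
    def __init__(self, encoder: nn.Module, in_ch: int = 512, out_ch: int = 10, freeze: bool = True):
        super().__init__()
        self.encoder = encoder
        if freeze:
            for p in self.encoder.parameters():
                p.requires_grad = False
        self.head = nn.Linear(in_ch, out_ch)

    def forward(self, x):
        feats = self.encoder(x)   # (batch, in_ch) embedding
        return self.head(feats)   # (batch, out_ch) logits

toy_encoder = nn.Sequential(nn.Linear(64, 512), nn.ReLU())   # stand-in for the CLAP audio branch
probe = TinyLinearProbe(toy_encoder, in_ch=512, out_ch=527)  # 527 is an illustrative class count
logits = probe(torch.randn(4, 64))
print(logits.shape)  # torch.Size([4, 527])
```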
+ ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args["static_graph"] = True + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True, **ddp_args + ) + + data = get_data(args, clap_model_cfg) + assert len(data), "At least one train or eval dataset must be specified." + if args.trace: + assert "train" not in data, "Cannot train with traced model" + + + optimizer, scheduler, text_freeze_parameters = config_lp_optimizer(model, data, args) + + + scaler = GradScaler() if args.precision == "amp" else None + + # optionally resume from a checkpoint + start_epoch = 0 + if args.resume is not None: + if os.path.isfile(args.resume): + checkpoint = torch.load(args.resume, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module.") :]: v for k, v in sd.items()} + model.load_state_dict(sd) + if args.split_opt: + if optimizer is not None: + for k, o_ in optimizer.items(): + o_.load_state_dict(checkpoint[k + "_" + "optimizer"]) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and "scaler" in checkpoint: + scaler.load_state_dict(checkpoint["scaler"]) + logging.info( + f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})" + ) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info( + f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})" + ) + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + else: + logging.info("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + cudnn.deterministic = False + + # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + if args.wandb and is_master(args): + assert wandb is not None, "Please install wandb." + logging.debug("Starting wandb.") + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! 
+ wandb.init( + project="clap", + notes=args.wandb_notes, + name=args.wandb_notes, + tags=[], + config=vars(args), + ) + if args.debug: + wandb.watch(model, log="all") + wandb.save(params_file) + logging.debug("Finished loading wandb.") + + if "train" not in data: + evaluate(model, data, start_epoch, args, writer) + return + elif start_epoch == 0 and "val" in data and not args.no_eval: + evaluate(model, data, 0, args, writer) + if args.save_top_performance: + current_top_k_ckpt_metrics = { + i: 0 for i in range(args.save_top_performance) + } # initialize the top-k metric for ckpts to 0 + + for epoch in range(start_epoch, args.epochs): + # freeze the text param after (include) args.freeze_text_after, this is -1 by default + if epoch == args.freeze_text_after: + print("Text pretrained parameters are freezed since this epoch.") + for k in text_freeze_parameters: + k.requires_grad = False + if is_master(args): + logging.info(f"Start epoch {epoch}") + + train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) + completed_epoch = epoch + 1 + + if any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) and not args.no_eval: + metrics = evaluate(model, data, completed_epoch, args, writer) + if args.save_top_performance: + top_k_dataset = args.top_k_checkpoint_select_dataset + top_k_metric = args.top_k_checkpoint_select_metric + filtered_metrics = [ + v + for k, v in metrics.items() + if top_k_metric in k and top_k_dataset in k + ] # check all R@10 metrics (all dataset) and use it to update the ckpt + # Saving checkpoints. + if args.save_logs: + opt_dict = { + k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items() + } + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": model.state_dict(), + } + checkpoint_dict.update(opt_dict) + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.save_most_recent: + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_latest.pt"), + ) + if args.save_top_performance and not args.no_eval: + update_top_k_performance( + filtered_metrics, + current_top_k_ckpt_metrics, + args, + checkpoint_dict, + bignumbetter=True, + ) + + if args.wandb and is_master(args): + wandb.finish() + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." 
+ ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree( + current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb") + ) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main() diff --git a/src/laion_clap/training/lp_train.py b/src/laion_clap/training/lp_train.py new file mode 100644 index 0000000000000000000000000000000000000000..d686336c824dd34b45f056c414d761540141a46f --- /dev/null +++ b/src/laion_clap/training/lp_train.py @@ -0,0 +1,292 @@ +import json +import logging +import math +import os +import time +from contextlib import suppress + +import numpy as np +import torch +import torch.nn.functional as F + +try: + import wandb +except ImportError: + wandb = None + +from clap_module import LPLoss, LPMetrics, lp_gather_features +from clap_module.utils import do_mixup, get_mix_lambda +from .distributed import is_master +from .zero_shot import zero_shot_eval + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def unwrap_model(model): + if hasattr(model, "module"): + return model.module + else: + return model + + +def train_one_epoch( + model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None, extra_suffix="" +): + device = torch.device(args.device) + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + model.train() + loss = LPLoss(args.lp_loss) + + dataloader, sampler = data["train"].dataloader, data["train"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + num_batches_per_epoch = dataloader.num_batches + sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + + # for toy dataset + if args.dataset_type == "toy": + dataloader.dataset.generate_queue() + + loss_m = AverageMeter() + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for i, batch in enumerate(dataloader): + step = num_batches_per_epoch * epoch + i + + if isinstance(scheduler, dict): + for s in scheduler.values(): + s(step) + else: + scheduler(step) + + audio = batch # contains mel_spec, wavform, and longer list + class_label = batch['class_label'] + # audio = audio.to(device=device, non_blocking=True) + class_label = class_label.to(device=device, non_blocking=True) + + if args.mixup: + # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146 + mix_lambda = torch.from_numpy(get_mix_lambda(0.5, len(audio["waveform"]))).to(device) + class_label = do_mixup(class_label, mix_lambda) + else: + mix_lambda = None + + data_time_m.update(time.time() - end) + if isinstance(optimizer, dict): + for o_ in optimizer.values(): + o_.zero_grad() + else: + optimizer.zero_grad() + + with autocast(): + pred = model(audio, mix_lambda=mix_lambda, device=device) + total_loss = loss(pred, class_label) + + if isinstance(optimizer, dict): + if scaler is not None: + scaler.scale(total_loss).backward() + for o_ in optimizer.values(): + if args.horovod: + o_.synchronize() + scaler.unscale_(o_) + with o_.skip_synchronize(): + scaler.step(o_) + else: + scaler.step(o_) + scaler.update() + else: + total_loss.backward() + for o_ in 
optimizer.values(): + o_.step() + else: + if scaler is not None: + scaler.scale(total_loss).backward() + if args.horovod: + optimizer.synchronize() + scaler.unscale_(optimizer) + with optimizer.skip_synchronize(): + scaler.step(optimizer) + else: + scaler.step(optimizer) + scaler.update() + else: + total_loss.backward() + optimizer.step() + + # Note: we clamp to 4.6052 = ln(100), as in the original paper. + with torch.no_grad(): + unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100)) + unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100)) + + batch_time_m.update(time.time() - end) + end = time.time() + batch_count = i + 1 + + if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): + if isinstance(audio, dict): + batch_size = len(audio["waveform"]) + else: + batch_size = len(audio) + num_samples = batch_count * batch_size * args.world_size + samples_per_epoch = dataloader.num_samples + percent_complete = 100.0 * batch_count / num_batches_per_epoch + + # NOTE loss is coarsely sampled, just master node and per log update + loss_m.update(total_loss.item(), batch_size) + if isinstance(optimizer, dict): + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + ) + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "lr": optimizer.param_groups[0]["lr"], + } + for name, val in log_data.items(): + name = f"train{extra_suffix}/{name}" + if tb_writer is not None: + tb_writer.add_scalar(name, val, step) + if args.wandb: + assert wandb is not None, "Please install wandb." 
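The clamp applied earlier in this loop keeps the learned logit scales bounded: after exponentiation they act as the contrastive softmax temperature, so capping the raw value at ln(100) caps that temperature at 100, as in the original CLIP recipe. A tiny sketch with an invented starting value:

```python
import math
import torch

logit_scale = torch.nn.Parameter(torch.tensor(5.3))  # pretend the scale drifted above ln(100)
with torch.no_grad():
    logit_scale.clamp_(0, math.log(100))              # 4.6052 = ln(100)
print(round(logit_scale.item(), 4), round(logit_scale.exp().item(), 1))  # 4.6052 100.0
```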
+ wandb.log({name: val, "step": step}) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # end for + +def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""): + metrics = {} + if not args.parallel_eval: + if not is_master(args): + return metrics + device = torch.device(args.device) + model.eval() + + # CHANGE + # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + # metrics.update(zero_shot_metrics) + if is_master(args): + print('Evaluating...') + metric_names = args.lp_metrics.split(',') + eval_tool = LPMetrics(metric_names=metric_names) + + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + if "val" in data and ( + args.val_frequency + and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) + ): + if args.parallel_eval: + dataloader, sampler = data["val"].dataloader, data["val"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + samples_per_val = dataloader.num_samples + else: + dataloader = data["val"].dataloader + num_samples = 0 + samples_per_val = dataloader.num_samples + + eval_info = { + 'pred': [], + 'target': [] + } + with torch.no_grad(): + for i, batch in enumerate(dataloader): + audio = batch # contains mel_spec, wavform, and longer list + class_label = batch['class_label'] + + # audio = audio.to(device=device, non_blocking=True) + class_label = class_label.to(device=device, non_blocking=True) + + with autocast(): + pred = model(audio, device=device) + if args.parallel_eval: + pred, class_label = lp_gather_features(pred, class_label, args.world_size, args.horovod) + eval_info['pred'].append(pred) + eval_info['target'].append(class_label) + + num_samples += class_label.shape[0] + + if (i % 100) == 0: # and i != 0: + logging.info( + f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" + ) + + if is_master(args): + eval_info['pred'] = torch.cat(eval_info['pred'], 0).cpu() + eval_info['target'] = torch.cat(eval_info['target'], 0).cpu() + metric_dict = eval_tool.evaluate_mertics(eval_info['pred'], eval_info['target']) + metrics.update(metric_dict) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + + if is_master(args): + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\n".join( + [ + "\t".join([f"{m}: {round(metrics[m], 4):.4f}" ]) + for m in metrics + ] + ) + ) + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch) + + with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, "Please install wandb." 
+ for name, val in metrics.items(): + wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch}) + + return metrics + else: + return metrics diff --git a/src/laion_clap/training/main.py b/src/laion_clap/training/main.py new file mode 100644 index 0000000000000000000000000000000000000000..7c48b66e2287785961501dfd75f2b6f5d331c245 --- /dev/null +++ b/src/laion_clap/training/main.py @@ -0,0 +1,597 @@ +import logging +import os +import random +from datetime import datetime +import copy +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch.cuda.amp import GradScaler + +try: + import wandb +except ImportError: + wandb = None + +try: + import torch.utils.tensorboard as tensorboard +except ImportError: + tensorboard = None + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from clap_module import create_model_and_transforms, trace_model, create_model +from training.data import get_data +from training.distributed import is_master, init_distributed_device, world_info_from_env +from training.logger import setup_logging +from training.params import parse_args +from training.scheduler import cosine_lr +from training.train import train_one_epoch, evaluate +from clap_module.utils import dataset_split, get_optimizer + + +def maintain_ckpts(args, startidx, all_idx_len): + for i in reversed(range(startidx, all_idx_len)): + if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")): + os.rename( + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"), + ) + if os.path.exists( + os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt") + ): + os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")) + return + + +def update_top_k_performance( + new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True +): + """ + Record the top-k performance of the current epoch. 
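+    Checkpoints rotate as epoch_top_0.pt .. epoch_top_{k-1}.pt: maintain_ckpts shifts
+    lower-ranked files down before a new checkpoint is saved, so with bignumbetter=True
+    epoch_top_0.pt holds the best-scoring checkpoint seen so far.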
+ current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...} + """ + if isinstance(new_metrics_inputs, (list, tuple)): + new_metrics_inputs = np.mean(new_metrics_inputs) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, dict): + new_metrics_inputs = np.mean(list(new_metrics_inputs.values())) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, (float, int)): + update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()} + sorted_keys = sorted(current_top_k_ckpt_metrics.keys()) + sorted_values = sorted( + current_top_k_ckpt_metrics.values(), reverse=bignumbetter + ) + sorted_values_ = copy.deepcopy(sorted_values) + sorted_values.append(new_metrics_inputs) + sorted_values = sorted(sorted_values, reverse=bignumbetter) + sorted_values = sorted_values[:-1] + + if sorted_values == sorted_values_: + return current_top_k_ckpt_metrics, new_metrics_inputs + else: + for i in range(len(sorted_keys)): + if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]: + current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i] + update_flag[sorted_keys[i]] = True + for i in range(len(update_flag)): + if update_flag[i]: + maintain_ckpts(args, i, len(sorted_keys)) + torch.save( + ckpt, + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + ) + break + return current_top_k_ckpt_metrics, new_metrics_inputs + + +# def updateifNone(a, b): +# a = b if None else a +# return a + + +def is_pretrained_params(n): + return ( + n.startswith("transformer") + or n in ["positional_embedding", "text_projection"] + or n.startswith("token_embedding") + or n.startswith("ln_final") + or n.startswith("logit_scale_t") + ) + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + + +def main(): + args = parse_args() + # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? + args.amodel = args.amodel.replace("/", "-") + # download sizes.json file + + # (yusong): the below two lines are for debug + # print("setting up faulthandler") + # faulthandler.register(10) + + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + if args.tmodel == "bert" or args.tmodel == "roberta" or args.tmodel == "bart": + assert ( + args.pretrained == "" or args.pretrained is None + ), "bert/roberta/bart text encoder does not support pretrained models." 
+ + # get the name of the experiments + if args.name is None: + args.name = "-".join( + [ + datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), + f"model_{args.amodel}", + f"lr_{args.lr}", + f"b_{args.batch_size}", + f"j_{args.workers}", + f"p_{args.precision}", + ] + ) + + # discover initial world args early so we can log properly + args.distributed = False + args.local_rank, args.rank, args.world_size = world_info_from_env() + + if args.remotedata and is_master(args): + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + + args.log_path = None + if is_master(args, local=args.log_local): + log_base_path = os.path.join(args.logs, args.name) + os.makedirs(log_base_path, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path, log_filename) + if os.path.exists(args.log_path): + print( + "Error. Experiment already exists. Use --name {} to specify a new experiment." + ) + return -1 + + # Set logger + args.log_level = logging.DEBUG if args.debug else logging.INFO + setup_logging(args.log_path, args.log_level) + + # fully initialize distributed device environment + device = init_distributed_device(args) + + args.wandb = "wandb" in args.report_to or "all" in args.report_to + args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to + if is_master(args): + args.tensorboard_path = ( + os.path.join(args.logs, args.name, "tensorboard") + if args.tensorboard + else "" + ) + args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") + for dirname in [args.tensorboard_path, args.checkpoint_path]: + if dirname: + os.makedirs(dirname, exist_ok=True) + else: + args.tensorboard_path = "" + args.checkpoint_path = "" + + if args.copy_codebase: + copy_codebase(args) + + assert args.precision in ["amp", "fp16", "fp32"] + if args.precision == "fp16": + logging.warning( + "It is recommended to use fp32 mixed-precision instead of FP16 and AMP in this model. " + "They will cause NaN loss and NaN gradients. " + "FP16 and AMP support needs further verification and tuning, especially for train." + ) + + if args.horovod: + logging.info( + f"Running in horovod mode with multiple processes / nodes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + elif args.distributed: + logging.info( + f"Running in distributed mode with multiple processes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + else: + logging.info(f"Running with a single process. 
Device {args.device}.") + + logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}") + + model, model_cfg = create_model( + args.amodel, + args.tmodel, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=True, + pretrained_audio=args.pretrained_audio, + pretrained_text=args.pretrained_text, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type + ) + + if args.horovod: + with torch.no_grad(): + for param in model.parameters(): + param.set_(param.contiguous()) + + if args.trace: + model = trace_model(model, batch_size=args.batch_size, device=device) + + if is_master(args): + logging.info("Model:") + logging.info(f"{str(model)}") + logging.info("Params:") + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args["static_graph"] = True + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True, **ddp_args + ) + + data = get_data(args, model_cfg) + assert len(data), "At least one train or eval dataset must be specified." + if args.trace: + assert "train" not in data, "Cannot train with traced model" + + exclude = ( + lambda n, p: p.ndim < 2 + or "bn" in n + or "ln" in n + or "bias" in n + or "logit_scale" in n + ) + include = lambda n, p: not exclude(n, p) + + named_parameters = list(model.named_parameters()) + + # freeze text encoder + text_freeze_parameters = [ + p + for n, p in named_parameters + if 'text_branch' in n + ] + + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + + gain_or_bias_params = [ + p for n, p in named_parameters if exclude(n, p) and p.requires_grad + ] + rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] + + # set wd-related params to 0 if use adam optimizer + if args.optimizer == "adam": + args.wd = 0 + args.wd_pretrained = 0 + args.wd_new = 0 + + if args.train_data is None: + optimizer = None + scheduler = None + else: + total_steps = data["train"].dataloader.num_batches * args.epochs + + if args.split_opt: + for x in ["lr", "beta1", "beta2", "eps", "wd"]: + for y in ["_new", "_pretrained"]: + if getattr(args, x + y) is None: + setattr(args, x + y, getattr(args, x)) + + gain_or_bias_pretrained_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + rest_pretrained_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + gain_or_bias_new_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and (not is_pretrained_params(n)) + ] + rest_new_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and (not is_pretrained_params(n)) + ] + pretrained_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0}, + { + "params": rest_pretrained_params, + "weight_decay": 
args.wd_pretrained, + }, + ], + lr=args.lr_pretrained, + betas=(args.beta1_pretrained, args.beta2_pretrained), + eps=args.eps_pretrained, + momentum=args.momentum_pretrained, + optimizer_name=args.optimizer, + ) + pretrained_params_scheduler = cosine_lr( + pretrained_params_optimizer, + args.lr_pretrained, + args.warmup, + total_steps, + ) + new_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_new_params, "weight_decay": 0.0}, + {"params": rest_new_params, "weight_decay": args.wd_new}, + ], + lr=args.lr_new, + betas=(args.beta1_new, args.beta2_new), + eps=args.eps_new, + momentum=args.momentum_new, + optimizer_name=args.optimizer, + ) + + new_params_scheduler = cosine_lr( + new_params_optimizer, args.lr_new, args.warmup, total_steps + ) + + optimizer = { + "pretrained": pretrained_params_optimizer, + "new": new_params_optimizer, + } + scheduler = { + "pretrained": pretrained_params_scheduler, + "new": new_params_scheduler, + } + + if args.horovod: + pretrained_params_optimizer = hvd.DistributedOptimizer( + pretrained_params_optimizer, + named_parameters=model.named_parameters(), + ) + new_params_optimizer = hvd.DistributedOptimizer( + new_params_optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(pretrained_params_optimizer, root_rank=0) + hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0) + else: + optimizer = get_optimizer( + [ + {"params": gain_or_bias_params, "weight_decay": 0.0}, + {"params": rest_params, "weight_decay": args.wd}, + ], + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + momentum=args.momentum, + optimizer_name=args.optimizer, + ) + + scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) + + if args.horovod: + optimizer = hvd.DistributedOptimizer( + optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + + scaler = GradScaler() if args.precision == "amp" else None + + # optionally resume from a checkpoint + start_epoch = 0 + if args.resume is not None: + if os.path.isfile(args.resume): + checkpoint = torch.load(args.resume, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module.") :]: v for k, v in sd.items()} + model.load_state_dict(sd) + if args.split_opt: + if optimizer is not None: + for k, o_ in optimizer.items(): + o_.load_state_dict(checkpoint[k + "_" + "optimizer"]) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and "scaler" in checkpoint: + scaler.load_state_dict(checkpoint["scaler"]) + logging.info( + f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})" + ) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info( + f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})" + ) + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + else: + logging.info("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + cudnn.deterministic = False + + # determine if this worker should save logs and checkpoints. 
only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + if args.wandb and is_master(args): + assert wandb is not None, "Please install wandb." + logging.debug("Starting wandb.") + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! + wandb.init( + entity="clap", + project="clap", + notes=args.wandb_notes, + name=args.wandb_notes, + tags=[], + config=vars(args), + ) + if args.debug: + wandb.watch(model, log="all") + wandb.save(params_file) + logging.debug("Finished loading wandb.") + + if "train" not in data: + evaluate(model, data, start_epoch, args, writer) + return + elif start_epoch == 0 and "val" in data and not args.no_eval: + evaluate(model, data, 0, args, writer) + # print(f'rank {args.rank}, Start First Evaluation')# (yusong): for debug + if args.save_top_performance: + current_top_k_ckpt_metrics = { + i: 0 for i in range(args.save_top_performance) + } # initialize the top-k metric for ckpts to 0 + + # print(f'rank {args.rank}, Start Training') # (yusong): for debug + for epoch in range(start_epoch, args.epochs): + # freeze the text param after (include) args.freeze_text_after, this is -1 by default + if epoch == args.freeze_text_after: + print("Text pretrained parameters are freezed since this epoch.") + for k in text_freeze_parameters: + k.requires_grad = False + if is_master(args): + logging.info(f"Start epoch {epoch}") + + train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) + completed_epoch = epoch + 1 + + if ( + any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) + and not args.no_eval + ): + metrics = evaluate(model, data, completed_epoch, args, writer) + if args.save_top_performance: + top_k_dataset = args.top_k_checkpoint_select_dataset + top_k_metric = args.top_k_checkpoint_select_metric + filtered_metrics = [ + v + for k, v in metrics.items() + if top_k_metric in k and top_k_dataset in k + ] # check all R@10 metrics (all dataset) and use it to update the ckpt + # Saving checkpoints. + if args.save_logs: + if args.split_opt: + opt_dict = { + k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items() + } + else: + opt_dict = {"optimizer": optimizer.state_dict()} + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": model.state_dict(), + } + checkpoint_dict.update(opt_dict) + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.save_most_recent: + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_latest.pt"), + ) + if args.save_top_performance and not args.no_eval: + update_top_k_performance( + filtered_metrics, + current_top_k_ckpt_metrics, + args, + checkpoint_dict, + bignumbetter=True, + ) + + if args.wandb and is_master(args): + wandb.finish() + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. 
Experiment already exists at {new_code_path}. Use --name to specify a new experiment." + ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree( + current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb") + ) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main() diff --git a/src/laion_clap/training/params.py b/src/laion_clap/training/params.py new file mode 100644 index 0000000000000000000000000000000000000000..84cd3b43104007a14835e4de9cd4521899ba6345 --- /dev/null +++ b/src/laion_clap/training/params.py @@ -0,0 +1,567 @@ +import argparse + + +def get_default_params(model_name): + # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) + model_name = model_name.lower() + if "vit" in model_name: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} + else: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--train-data", + type=str, + default=None, + help="Path to h5 filewith training data", + ) + parser.add_argument( + "--val-data", + type=str, + default=None, + help="Path to h5 file with validation data", + ) + parser.add_argument( + "--freeze-text", + default=False, + action="store_true", + help="if you need to freeze the text encoder, make this True", + ) + parser.add_argument( + "--freeze-text-after", + type=int, + default=-1, + help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it", + ) + parser.add_argument( + "--train-ipc", + type=str, + default=None, + help="Path to npy file of the number of instance per class in training data", + ) + parser.add_argument( + "--val-ipc", + type=str, + default=None, + help="Path to npy file of the number of instance per class in validation data", + ) + parser.add_argument( + "--train-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Required for webdataset if not available in info file.", + ) + parser.add_argument( + "--val-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Useful for webdataset if not available in info file.", + ) + parser.add_argument( + "--dataset-type", + choices=["webdataset", "csv", "auto", "toy"], + default="auto", + help="Which type of dataset to process.", + ) + parser.add_argument( + "--csv-separator", + type=str, + default="\t", + help="For csv-like datasets, which separator to use.", + ) + parser.add_argument( + "--csv-img-key", + type=str, + default="filepath", + help="For csv-like datasets, the name of the key for the image paths.", + ) + parser.add_argument( + "--csv-caption-key", + type=str, + default="title", + help="For csv-like datasets, the name of the key for the captions.", + ) + parser.add_argument( + "--imagenet-val", + type=str, + default=None, + help="Path to imagenet val set for conducting zero shot evaluation.", + ) + parser.add_argument( + "--imagenet-v2", + type=str, + default=None, + help="Path to imagenet v2 for conducting zero shot evaluation.", + ) + parser.add_argument( + "--datasetnames", + nargs="+", + default=None, + help="If loading webdataset, spedify the dataset names to load. 
Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects", + ) + parser.add_argument( + "--full-train-dataset", + nargs="+", + default=None, + help="Which dataset will be trained with all the subsets. (train+test)", + ) + parser.add_argument( + "--exclude-eval-dataset", + nargs="+", + default=None, + help="Which dataset will be excluded with evaluation", + ) + parser.add_argument( + "--datasetinfos", + nargs="+", + default=None, + help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval", + ) + parser.add_argument( + "--dataset-proportion", + type=float, + default=1.0, + help="How much proportion of dataset we want to train.", + ) + parser.add_argument( + "--remotedata", + default=False, + action="store_true", + help="if the dataset is remote, set this flag", + ) + parser.add_argument( + "--class-label-path", + type=str, + default=None, + help="The path of the class label pickle or csv.", + ) + parser.add_argument( + "--datasetpath", + type=str, + default="/mnt/audio_clip/webdataset_tar", + help="The path to the dataset", + ) + parser.add_argument( + "--logs", + type=str, + default="./logs/", + help="Where to store tensorboard logs. Use None to avoid storing logs.", + ) + parser.add_argument( + "--log-local", + action="store_true", + default=False, + help="log files on local master, otherwise global master only.", + ) + parser.add_argument( + "--name", + type=str, + default=None, + help="Optional identifier for the experiment when storing logs. Otherwise use current time.", + ) + parser.add_argument( + "--workers", type=int, default=1, help="Number of workers per GPU." + ) + parser.add_argument( + "--batch-size", type=int, default=64, help="Batch size per GPU." + ) + parser.add_argument( + "--epochs", type=int, default=32, help="Number of epochs to train for." + ) + parser.add_argument("--lr", type=float, default=None, help="Learning rate.") + parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") + parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") + parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") + parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.") + parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") + + parser.add_argument( + "--split-opt", + action="store_true", + default=False, + help="Use this flag to skip the learning rate decay.", + ) + parser.add_argument( + "--lr-pretrained", type=float, default=None, help="Learning rate for text." + ) + parser.add_argument( + "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text." + ) + parser.add_argument( + "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text." + ) + parser.add_argument( + "--eps-pretrained", type=float, default=None, help="Adam epsilon for text." + ) + parser.add_argument( + "--wd-pretrained", type=float, default=0.2, help="Weight decay for text." + ) + parser.add_argument( + "--momentum-pretrained", type=float, default=0.9, help="Momentum for text." + ) + parser.add_argument( + "--lr-new", type=float, default=None, help="Learning rate for audio." + ) + parser.add_argument( + "--beta1-new", type=float, default=None, help="Adam beta 1 for audio." + ) + parser.add_argument( + "--beta2-new", type=float, default=None, help="Adam beta 2 for audio." + ) + parser.add_argument( + "--eps-new", type=float, default=None, help="Adam epsilon for audio." 
+ ) + parser.add_argument( + "--wd-new", type=float, default=0.2, help="Weight decay for audio." + ) + parser.add_argument( + "--momentum-new", type=float, default=0.9, help="Momentum for audio." + ) + parser.add_argument( + "--warmup", type=int, default=10000, help="Number of steps to warmup for." + ) + parser.add_argument( + "--use-bn-sync", + default=False, + action="store_true", + help="Whether to use batch norm sync.", + ) + parser.add_argument( + "--skip-scheduler", + action="store_true", + default=False, + help="Use this flag to skip the learning rate decay.", + ) + parser.add_argument( + "--save-frequency", type=int, default=1, help="How often to save checkpoints." + ) + parser.add_argument( + "--save-top-performance", + type=int, + default=0, + help="Save the top x performance weights if the value >0", + ) + parser.add_argument( + "--save-most-recent", + action="store_true", + default=False, + help="Always save the most recent model trained to epoch_latest.pt.", + ) + parser.add_argument( + "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." + ) + parser.add_argument( + "--val-frequency", + type=int, + default=1, + help="How often to run evaluation with val data.", + ) + parser.add_argument( + "--resume", + default=None, + type=str, + help="path to latest checkpoint (default: none)", + ) + parser.add_argument( + "--precision", + choices=["amp", "fp16", "fp32"], + default="amp", + help="Floating point precision.", + ) + parser.add_argument( + "--amodel", + type=str, + default="RN50", + help="Name of the audio backbone to use.", + ) + parser.add_argument( + "--tmodel", + type=str, + default="transformer", + help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]", + ) + parser.add_argument( + "--pretrained-audio", + default="", + type=str, + help="Use a pretrained audio model weights for the audio encoder of CLAP", + ) + parser.add_argument( + "--pretrained-text", + default="", + type=str, + help="Use a pretrained text model weights for the text encoder of CLAP", + ) + parser.add_argument( + "--pretrained", + default="", + type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--pretrained-image", + default=False, + action="store_true", + help="Load imagenet pretrained weights for image tower backbone if available.", + ) + parser.add_argument( + "--lock-image", + default=False, + action="store_true", + help="Lock full image tower by disabling gradients.", + ) + parser.add_argument( + "--lock-image-unlocked-groups", + type=int, + default=0, + help="Leave last n image tower layer groups unlocked.", + ) + parser.add_argument( + "--lock-image-freeze-bn-stats", + default=False, + action="store_true", + help="Freeze BatchNorm running stats in image tower for any locked layers.", + ) + parser.add_argument( + "--local-loss", + default=False, + action="store_true", + help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)", + ) + parser.add_argument( + "--gather-with-grad", + default=False, + action="store_true", + help="enable full distributed gradient for feature gather", + ) + parser.add_argument( + "--force-quick-gelu", + default=False, + action="store_true", + help="Force use of QuickGELU activation for non-OpenAI transformer models.", + ) + parser.add_argument( + "--torchscript", + default=False, + action="store_true", + help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", + ) + 
parser.add_argument( + "--trace", + default=False, + action="store_true", + help="torch.jit.trace the model for inference / eval only", + ) + # arguments for distributed training + parser.add_argument( + "--dist-url", + default="env://", + type=str, + help="url used to set up distributed training", + ) + parser.add_argument( + "--dist-backend", default="nccl", type=str, help="distributed backend" + ) + parser.add_argument( + "--report-to", + default="", + type=str, + help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']", + ) + parser.add_argument( + "--wandb-notes", default="", type=str, help="Notes if logging with wandb" + ) + parser.add_argument( + "--C", type=float, default=3.16, help="inverse regularizer for logistic reg." + ) + parser.add_argument( + "--debug", + default=False, + action="store_true", + help="If true, more information is logged.", + ) + parser.add_argument( + "--copy-codebase", + default=False, + action="store_true", + help="If true, we copy the entire base on the log diretory, and execute from there.", + ) + parser.add_argument( + "--horovod", + default=False, + action="store_true", + help="Use horovod for distributed training.", + ) + parser.add_argument( + "--ddp-static-graph", + default=False, + action="store_true", + help="Enable static graph optimization for DDP in PyTorch >= 1.11.", + ) + parser.add_argument( + "--no-set-device-rank", + default=False, + action="store_true", + help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", + ) + parser.add_argument("--seed", type=int, default=4242, help="Default random seed.") + + parser.add_argument( + "--top-k-checkpoint-select-dataset", + type=str, + default="all", + help="The dataset of selecting top-k checkpoint.", + ) + + # @R10, @R@5, @R1, mAP@10 + parser.add_argument( + "--top-k-checkpoint-select-metric", + type=str, + default="_R@10", + help="The metric for selecting top-k checkpoint.", + ) + parser.add_argument( + "--openai-model-cache-dir", + type=str, + default="~/.cache/clip", + help="Directory to download OpenAI models.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="adamw", + help="can be AdamW or SGD", + ) + parser.add_argument( + "--parallel-eval", + default=False, + action="store_true", + help="Eval in parallel (multi-GPU, multi-node).", + ) + + parser.add_argument( + "--no-eval", + default=False, + action="store_true", + help="Training without evaluation.", + ) + + parser.add_argument( + "--lp-mlp", + default=False, + action="store_true", + help="Linear Probe using MLP layer or not.", + ) + + parser.add_argument( + "--lp-freeze", + default=False, + action="store_true", + help="Linear Probe using Freeze CLAP or not", + ) + + parser.add_argument( + "--lp-act", + default="None", + type=str, + help="Options are ['relu','elu','prelu','softmax','sigmoid']", + ) + + parser.add_argument( + "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe." + ) + + parser.add_argument( + "--lp-metrics", + type=str, + default="map,mauc,acc", + help="Metrics of Linear Probe.", + ) + + parser.add_argument( + "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe" + ) + parser.add_argument( + "--kappa", type=float, default=0, + help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss" + ) + + parser.add_argument( + "--data-filling", + type=str, + default="pad", + help="type of data filling when the audio length is shorter than the max length." 
+ "Can be one of the following: repeat, repeatpad, pad", + ) + parser.add_argument( + "--data-truncating", + type=str, + default="rand_trunc", + help="type of data truncation when the audio length is longer than the max length." + "Can be one of the following: rand_trunc, fusion", + ) + + parser.add_argument( + "--clap-mlploss", + default=False, + action="store_true", + help="Using MLP loss for CLAP model or not", + ) + + parser.add_argument( + "--wandb-id", + type=str, + default=None, + help="the id of wandb experiment to restore.", + ) + + parser.add_argument( + "--sleep", type=float, default=0, help="sleep n seconds before start training" + ) + + # variable length processing + parser.add_argument( + "--enable-fusion", + default=False, + action="store_true", + help="Enable feature funsion for variable-length data", + ) + + parser.add_argument( + "--fusion-type", + type=str, + default='None', + help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']", + ) + + parser.add_argument( + "--mixup", + default=False, + action="store_true", + help="Enable mixup in finetuning training.", + ) + parser.add_argument( + "--text-augment-selection", + type=str, + default=None, + help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']", + ) + parser.add_argument( + "--prefetch-factor", + type=int, + default=None, + help="The prefetch factor for dataloader. Larger value will use more memory and CPU but faster.", + ) + + args = parser.parse_args() + + # If some params are not passed, we use the default values based on model name. + default_params = get_default_params(args.amodel) + for name, val in default_params.items(): + if getattr(args, name) is None: + setattr(args, name, val) + + return args diff --git a/src/laion_clap/training/scheduler.py b/src/laion_clap/training/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..e0bfdf796c95df003582ede43f0511bd9181c1e4 --- /dev/null +++ b/src/laion_clap/training/scheduler.py @@ -0,0 +1,23 @@ +import numpy as np + + +def assign_learning_rate(optimizer, new_lr): + for param_group in optimizer.param_groups: + param_group["lr"] = new_lr + + +def _warmup_lr(base_lr, warmup_length, step): + return base_lr * (step + 1) / warmup_length + + +def cosine_lr(optimizer, base_lr, warmup_length, steps): + def _lr_adjuster(step): + if step < warmup_length: + lr = _warmup_lr(base_lr, warmup_length, step) + else: + e = step - warmup_length + es = steps - warmup_length + lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr + assign_learning_rate(optimizer, lr) + return lr + return _lr_adjuster \ No newline at end of file diff --git a/src/laion_clap/training/train.py b/src/laion_clap/training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..06db94158fe15c0a17c95babf60b75d31e8893c7 --- /dev/null +++ b/src/laion_clap/training/train.py @@ -0,0 +1,781 @@ +import json +import logging +import math +import os +import time +from contextlib import suppress + +import numpy as np +import torch +import torch.nn.functional as F + +try: + import wandb +except ImportError: + wandb = None + +from clap_module import ClipLoss, gather_features +from .distributed import is_master + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n 
+ self.avg = self.sum / self.count + + +def unwrap_model(model): + if hasattr(model, "module"): + return model.module + else: + return model + + +def train_one_epoch( + model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None +): + device = torch.device(args.device) + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + model.train() + loss = ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss, + weight_loss_kappa=args.kappa, + ) + + dataloader, sampler = data["train"].dataloader, data["train"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + num_batches_per_epoch = dataloader.num_batches + sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + + # for toy dataset + if args.dataset_type == "toy": + dataloader.dataset.generate_queue() + + loss_m = AverageMeter() + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for i, batch in enumerate(dataloader): + # logging.info(f"batch {i} of {num_batches_per_epoch}") + step = num_batches_per_epoch * epoch + i + if isinstance(scheduler, dict): + for s in scheduler.values(): + s(step) + else: + scheduler(step) + audios = batch # contains mel_spec, wavform, and longer list + texts = batch['text'] + # audios = audios.to(device=device, non_blocking=True) + # texts = texts.to(device=device, non_blocking=True) + + data_time_m.update(time.time() - end) + if isinstance(optimizer, dict): + for o_ in optimizer.values(): + o_.zero_grad() + else: + optimizer.zero_grad() + + with autocast(): + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + logit_scale_a, + logit_scale_t, + ) = model(audios, texts, device) + + if args.clap_mlploss: + total_loss = loss( + audio_features=audio_features, + text_features=text_features, + logit_scale_a=logit_scale_a, + logit_scale_t=logit_scale_t, + audio_features_mlp=audio_features_mlp, + text_features_mlp=text_features_mlp + ) + else: + total_loss = loss( + audio_features=audio_features, + text_features=text_features, + logit_scale_a=logit_scale_a + ) + if isinstance(optimizer, dict): + if scaler is not None: + scaler.scale(total_loss).backward() + for o_ in optimizer.values(): + if args.horovod: + o_.synchronize() + scaler.unscale_(o_) + with o_.skip_synchronize(): + scaler.step(o_) + else: + scaler.step(o_) + scaler.update() + else: + total_loss.backward() + for o_ in optimizer.values(): + o_.step() + else: + if scaler is not None: + scaler.scale(total_loss).backward() + if args.horovod: + optimizer.synchronize() + scaler.unscale_(optimizer) + with optimizer.skip_synchronize(): + scaler.step(optimizer) + else: + scaler.step(optimizer) + scaler.update() + else: + total_loss.backward() + optimizer.step() + + # Note: we clamp to 4.6052 = ln(100), as in the original paper. 
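+        # Keeping logit_scale in [0, ln(100)] bounds the temperature exp(logit_scale) to [1, 100].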
+ with torch.no_grad(): + unwrap_model(model).logit_scale_a.clamp_(0, math.log(100)) + if args.clap_mlploss: + unwrap_model(model).logit_scale_t.clamp_(0, math.log(100)) + + batch_time_m.update(time.time() - end) + end = time.time() + batch_count = i + 1 + if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): + if isinstance(audios, dict): + batch_size = len(audios["waveform"]) + else: + batch_size = len(audios) + num_samples = batch_count * batch_size * args.world_size + samples_per_epoch = dataloader.num_samples + percent_complete = 100.0 * batch_count / num_batches_per_epoch + + # NOTE loss is coarsely sampled, just master node and per log update + loss_m.update(total_loss.item(), batch_size) + logit_scale_scalar_a = logit_scale_a.item() + logit_scale_scalar_t = logit_scale_t.item() + if isinstance(optimizer, dict): + if args.clap_mlploss: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + f"Logit Scale Text: {logit_scale_scalar_t:.3f}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "scale_text": logit_scale_scalar_t, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + + else: + if args.clap_mlploss: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + f"Logit Scale Text: {logit_scale_scalar_t:.3f}" + ) + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "scale_text": logit_scale_scalar_t, + "lr": optimizer.param_groups[0]["lr"], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + ) + + # Save train loss / etc. 
Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "lr": optimizer.param_groups[0]["lr"], + } + for name, val in log_data.items(): + name = "train/" + name + if tb_writer is not None: + tb_writer.add_scalar(name, val, step) + if args.wandb: + assert wandb is not None, "Please install wandb." + wandb.log({name: val, "step": step}) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # end for + + +def evaluate(model, data, epoch, args, tb_writer=None): + metrics = {} + if not args.parallel_eval: + if not is_master(args): + return metrics + device = torch.device(args.device) + model.eval() + + # CHANGE + # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + # metrics.update(zero_shot_metrics) + if is_master(args): + print('Evaluating...') + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + if args.val_dataset_names == ['Clotho', 'audiocaps']: + # if only clotho and audiocaps are used, then we will use a different evaluation function. + # This is because in the Clotho and audiocaps valid and test set, there are 5 text for 1 audio. + if args.parallel_eval: + # (yusong): just a hack here. Don't use parallel eval when evaluating only clotho and audiocaps. + raise NotImplementedError("Parallel evaluation not supported for eval only Clotho and audiocaps.") + val_metrics_per_dataset = evaluate_clotho_audiocaps(model, data, epoch, args, autocast, device, tb_writer) + for m in val_metrics_per_dataset.values(): + metrics.update(m) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + metrics = select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args) + elif "val" in data and ( + args.val_frequency + and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) + ): + dataloader = data["val"].dataloader + num_samples = 0 + samples_per_val = dataloader.num_samples + + # FIXME this does not scale past small eval datasets + # all_audio_features @ all_text_features will blow up memory and compute very quickly + eval_info = {} + if args.clap_mlploss: + eval_info["all"] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + "all_audio_features_mlp": [], + "all_text_features_mlp": [] + } # cumulative_loss = 0.0 + else: + eval_info["all"] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [] + } # cumu + # all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp = [], [], [], [] + with torch.no_grad(): + for i, batch in enumerate(dataloader): + audios = batch # contains mel_spec, wavform, and longer list + texts = batch['text'] + # audios = audios.to(device=device, non_blocking=True) + + all_names = list(set(["-".join(b.split("/")[-3:-1]) for b in batch['__url__']])) + for name in all_names: + if name not in eval_info.keys(): + if args.clap_mlploss: + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + "all_audio_features_mlp": [], + "all_text_features_mlp": [], + } + else: + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [] + } + with autocast(): + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + logit_scale_a, + logit_scale_t, + ) = 
model(audios, texts, device) + + if args.parallel_eval: + # multi-GPU eval + if args.clap_mlploss: + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + ) = gather_features( + audio_features=audio_features, + text_features=text_features, + audio_features_mlp=audio_features_mlp, + text_features_mlp=text_features_mlp, + local_loss=False, + gather_with_grad=False, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss + ) + else: + ( + audio_features, + text_features, + ) = gather_features( + audio_features=audio_features, + text_features=text_features, + local_loss=False, + gather_with_grad=False, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss + ) + + if is_master(args): + num_samples += audio_features.shape[0] + for n in [*all_names, "all"]: + if n == "all": + eval_info[n]["all_audio_features"].append( + audio_features.cpu() + ) + eval_info[n]["all_text_features"].append( + text_features.cpu() + ) + if args.clap_mlploss: + eval_info[n]["all_audio_features_mlp"].append( + audio_features_mlp.cpu() + ) + eval_info[n]["all_text_features_mlp"].append( + text_features_mlp.cpu() + ) + else: + idx = np.where( + np.array( + ["-".join(b.split("/")[-3:-1]) for b in batch['__url__']] + ) + == n + )[0] + eval_info[n]["all_audio_features"].append( + audio_features.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + eval_info[n]["all_text_features"].append( + text_features.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + if args.clap_mlploss: + eval_info[n]["all_audio_features_mlp"].append( + audio_features_mlp.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + eval_info[n]["all_text_features_mlp"].append( + text_features_mlp.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + # print(f'eval step {i}') # (yusong): for debug + + # cumulative_loss += total_loss * batch_size + # num_samples += batch_size + if is_master(args) and (i % 100) == 0: # and i != 0: + logging.info( + f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" + ) + if is_master(args): + val_metrics_per_dataset = {} + for n in eval_info.keys(): + if args.clap_mlploss: + metrics_single_dataset = get_metrics( + audio_features=torch.cat(eval_info[n]["all_audio_features"]), + text_features=torch.cat(eval_info[n]["all_text_features"]), + logit_scale_a=logit_scale_a.cpu(), + audio_features_mlp=torch.cat( + eval_info[n]["all_audio_features_mlp"] + ), + text_features_mlp=torch.cat(eval_info[n]["all_text_features_mlp"]), + logit_scale_t=logit_scale_t.cpu(), + mlp_loss=args.clap_mlploss + ) + else: + metrics_single_dataset = get_metrics( + audio_features=torch.cat(eval_info[n]["all_audio_features"]), + text_features=torch.cat(eval_info[n]["all_text_features"]), + logit_scale_a=logit_scale_a.cpu(), + mlp_loss=args.clap_mlploss + ) + val_metrics_per_dataset[n] = { + n + "/" + k: v for k, v in metrics_single_dataset.items() + } + metrics.update(val_metrics_per_dataset[n]) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + if is_master(args): + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\n".join( + [ + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in m.items()]) + for m in val_metrics_per_dataset.values() + ] + ) + ) + + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/{name}", val, epoch) + + with open(os.path.join(args.checkpoint_path, "results.jsonl"), 
"a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, "Please install wandb." + for name, val in metrics.items(): + wandb.log({f"val/{name}": val, "epoch": epoch}) + + return metrics + else: + return metrics + + +def get_metrics( + audio_features, + text_features, + logit_scale_a, + audio_features_mlp=None, + text_features_mlp=None, + logit_scale_t=None, + mlp_loss=False +): + metrics = {} + if mlp_loss: + # Set up audio to text & text to audio similary matrice + a_logits_per_audio = ( + (logit_scale_a * audio_features @ text_features_mlp.t()).detach().cpu() + ) + a_logits_per_text = a_logits_per_audio.t().detach().cpu() + t_logits_per_audio = ( + (logit_scale_t * audio_features_mlp @ text_features.t()).detach().cpu() + ) + t_logits_per_text = t_logits_per_audio.t().detach().cpu() + + labels = torch.arange(audio_features.shape[0]).long() + # Change the loss from two terms into four terms with 2x2 combined CE loss + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels) + + F.cross_entropy(a_logits_per_text, labels) + + F.cross_entropy(t_logits_per_audio, labels) + + F.cross_entropy(t_logits_per_text, labels) + ) / 4 + + metrics[f"cumulative_loss"] = total_loss.item() + metrics[f"num_samples"] = audio_features.shape[0] + + logits = { + "audio_to_text": (a_logits_per_audio + t_logits_per_audio) / 2, + "text_to_audio": (a_logits_per_text + t_logits_per_text) / 2, + } + ground_truth = torch.arange(len(text_features)).view(-1, 1) + + else: + # print("text_features", text_features) + # print("text_features.shape", text_features.shape) + logits_per_audio = (logit_scale_a * audio_features @ text_features.t()).detach().cpu() + logits_per_text = logits_per_audio.t().detach().cpu() + + labels = torch.arange(audio_features.shape[0]).long() + # Change the loss from two terms into four terms with 2x2 combined CE loss + total_loss = ( + F.cross_entropy(logits_per_audio, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + metrics[f"cumulative_loss"] = total_loss.item() + metrics[f"num_samples"] = audio_features.shape[0] + + logits = {"audio_to_text": logits_per_audio, "text_to_audio": logits_per_text} + + ground_truth = torch.arange(len(text_features)).view(-1, 1) + + for name, logit in logits.items(): + ranking = torch.argsort(logit, descending=True) + preds = torch.where(ranking == ground_truth)[1] # (yusong) this line is slow because it uses single thread + preds = preds.detach().cpu().numpy() + metrics[f"{name}_mean_rank"] = preds.mean() + 1 + metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = np.mean(preds < k) + # map@10 + metrics[f"{name}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0)) + + return metrics + + +def evaluate_clotho_audiocaps( + model, data, epoch, args, autocast, device, tb_writer=None +): + """ + Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py. + 1. for text-to-audio retrieval, do 5 times and average the results + 2. for R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among 5 text + 3. for map@10 in audio-to-text retrieval: + 3.1: sort the rank of 5 text + 3.2: exclude the rank >=10 (0-index) + 3.3: compute the map regarding the remaining ranks: np.mean(np.arange(1, len(ranks)+1) / ranks). + (3.3) That is, take the top ranks of 5 text that is < 10, and assign the descending number as ground truth. 
+ (3.3) E.g.: the ground truth of first rank of the 5 text should be 1, the second rank should be 2, etc. + """ + # TODO: (yusong) only support single GPU evaluation and only support non-mlp case for now. + dataloader = data["val"].dataloader + with torch.no_grad(): + eval_info = {} + for i, batch in enumerate(dataloader): + audios = batch # contains mel_spec, wavform, and longer list + + # each item in the list has 5 texts + if args.tmodel == "transformer": + from clap_module import tokenize + texts = [tokenize(t) for t in batch['full_text']] + texts = torch.cat(texts) + else: + from .data import tokenizer + texts = [tokenizer(t, tmodel=args.tmodel) for t in batch['full_text']] # 5 texts for each audio + texts = {k: torch.cat([t[k] for t in texts]) for k in texts[0].keys()} # 5 x batch + + # audios = audios.to(device=device, non_blocking=True) + + # batch['__url__'] contains the path to the data tar this sample is from + # So, b.split("/")[-3:-1] will get you '-' + all_names = list(set(["-".join(b.split("/")[-3:-1]) for b in batch['__url__']])) + for name in all_names: + if name not in eval_info.keys(): + # we will not use mlp outputs even if args.clap_mlploss=True + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [] + } + with autocast(): + audio_features = model(audios, None, device) + text_features = model(None, texts, device) + audio_features = F.normalize(audio_features, dim=-1) + text_features = F.normalize(text_features, dim=-1) + + all_names = list(set(["-".join(b.split("/")[-3:-1]) for b in batch['__url__']])) + for n in all_names: + idx = np.where( + np.array( + ["-".join(b.split("/")[-3:-1]) for b in batch['__url__']] + ) + == n + )[0] + eval_info[n]["all_audio_features"].append( + audio_features.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + # (yusong) please double-check. This is for selecting 5 text features at once. + # because idx is a list of indices in size of num_samples, + # and text_features is a tensor of size (5*num_samples, dim) + # so we need to select 5 consecutive indices at once for a single index in idx. + eval_info[n]["all_text_features"].append( + text_features.cpu().reshape([-1, 5, text_features.shape[1]]).index_select( + 0, torch.tensor(idx).long() + ).reshape([-1, text_features.shape[1]]) + ) + + val_metrics_all = {} + + for n in eval_info.keys(): + logit_scale_a, logit_scale_t = model(None, None, device) + logit_scale_a = logit_scale_a.cpu() + + audio_features = torch.cat(eval_info[n]["all_audio_features"], dim=0) + text_features = torch.cat(eval_info[n]["all_text_features"], dim=0) + + logits_per_audio = (logit_scale_a * audio_features @ text_features.t()).detach().cpu() + logits_per_text = logits_per_audio.t().detach().cpu() + + # logits_per_audio shape: [num_samples, num_samples*5] + # logits_per_text shape: [num_samples*5, num_samples] + + logging.info(f"dataset {n}, logits_per_audio shape: {logits_per_audio.shape}, " + f"logits_per_text shape: {logits_per_text.shape}") + + metrics = {} + num_samples = audio_features.shape[0] + metrics[f"num_samples"] = num_samples + + # (yusong) the following code is very important, please double-check: + # logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d] + # logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :] + # Those two are retrieving one of the 5 text for each audio. 
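+        # Worked example (illustrative values): with num_samples = 2,
+        #   logits_per_audio has shape [2, 10] and the 5 matching texts for audio d
+        #   occupy columns d*5 .. d*5+4 (audio 0 -> 0..4, audio 1 -> 5..9).
+        #   For audio_to_text_mAP@10 below, if the matching texts of one audio land at
+        #   0-indexed ranks [0, 3, 12, 20, 31], only [0, 3] pass the < 10 filter and
+        #   the per-audio score is (1/1 + 2/4) / 5 = 0.3.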
+ labels = torch.arange(audio_features.shape[0]).long() + audio_to_text_loss = [ + F.cross_entropy( + logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d], labels) for d in range(5) + ] + text_to_audio_loss = [ + F.cross_entropy( + logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :], labels) for d in range(5) + ] + total_loss = ( + np.mean(audio_to_text_loss) + np.mean(text_to_audio_loss) + ) / 2 + + metrics[f"cumulative_loss"] = total_loss.item() + + # text to audio: do 5 times + pred_text = [] + for d in range(5): + logit = logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :] + ground_truth = torch.arange(len(logit)).view(-1, 1) + ranking = torch.argsort(logit, descending=True) # [num_samples, num_samples] + preds = torch.where(ranking == ground_truth)[1] + pred_text.append(preds.detach().cpu().numpy()) + pred_text_concat = np.concatenate(pred_text, axis=0) # [5*num_samples] + metrics[f"text_to_audio_mean_rank"] = pred_text_concat.mean() + 1 + metrics[f"text_to_audio_median_rank"] = np.floor(np.median(pred_text_concat)) + 1 + for k in [1, 5, 10]: + metrics[f"text_to_audio_R@{k}"] = np.mean(pred_text_concat < k) + # map@10 + metrics[f"text_to_audio_mAP@10"] = np.mean(np.where(pred_text_concat < 10, 1 / (pred_text_concat + 1), 0.0)) + + # audio to text: take the best result + # for audio to text map 10, sort and assign descending ground truth. + # see https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py#L103 + # map@10 + map_all = [] + pred_audio_all = [] + for d in range(num_samples): + # logits_per_audio: [num_samples, num_samples*5] + logit_single = logits_per_audio[d, :] # [5*num_samples] + # Ground-truth index: [d*5, d*5+1, d*5+2, d*5+3, d*5+4] + ranking = torch.argsort(logit_single, descending=True) # [5*num_samples] + # ranking: the index of first match, second match, ... + ground_truth = torch.arange(d * 5, d * 5 + 5)[None] + all_pred = torch.where(torch.stack([ranking] * 5) == ground_truth.view(-1, 1))[1] + min_pred = torch.min(all_pred) + pred_audio_all.append(min_pred.detach().cpu().numpy()) + all_pred_filter = all_pred[all_pred < 10].detach().cpu().numpy() + # /5 because we have 5 text, so it means for the text rank >=10 we count as 0. + map_single = np.sum((np.arange(1, len(all_pred_filter) + 1) / (all_pred_filter + 1))) / 5 + map_all.append(map_single) + metrics[f"audio_to_text_mAP@10"] = np.mean(map_all) + for k in [1, 5, 10]: + metrics[f"audio_to_text_R@{k}"] = np.mean(np.array(pred_audio_all) < k) + + val_metrics_all[n] = { + n + "/" + k: v for k, v in metrics.items() + } + return val_metrics_all + + +def calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset): + """ + Calculate performance for Clotho+AudioCaps for model selection. 
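+ The selection score is the mean over the evaluated datasets of
+ (audio_to_text_mAP@10 + text_to_audio_mAP@10) / 2, and
+ select_top_metric_clotho_audiocaps below keeps the metrics of the epoch
+ with the highest score seen so far.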
+ """ + selection_performance_all = [] + for n in val_metrics_per_dataset.keys(): + selection_performance = (val_metrics_per_dataset[n][f"{n}/audio_to_text_mAP@10"] + + val_metrics_per_dataset[n][f"{n}/text_to_audio_mAP@10"]) / 2 + selection_performance_all.append(selection_performance) + return np.mean(selection_performance_all) + + +def select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args): + # val_metrics_per_dataset: dict, key: dataset name, value: dict, key: metric name, value: metric value + # metrics: dict, key: metric name, value: metric value + # Hack: use args to save the top performance + if not hasattr(args, "top_selection_performance"): + selection_performance = calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset) + # TODO: write the if and else together + metric_update = {} + for n in val_metrics_per_dataset.keys(): + for k in val_metrics_per_dataset[n].keys(): + metric_update[k.split('/')[0] + '-top' + '/' + k.split('/')[1]] = val_metrics_per_dataset[n][k] + metric_update['top_selection_performance'] = selection_performance + metric_update['top-selection-epoch'] = metrics['epoch'] + metrics.update(metric_update) + args.top_metric = metric_update + args.top_selection_performance = selection_performance + else: + selection_performance_new = calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset) + selection_performance_old = args.top_selection_performance + if selection_performance_new > selection_performance_old: + metric_update = {} + for n in val_metrics_per_dataset.keys(): + for k in val_metrics_per_dataset[n].keys(): + metric_update[k.split('/')[0] + '-top' + '/' + k.split('/')[1]] = val_metrics_per_dataset[n][k] + metric_update['top_selection_performance'] = selection_performance_new + metric_update['top-selection-epoch'] = metrics['epoch'] + metrics.update(metric_update) + args.top_metric = metric_update + args.top_selection_performance = selection_performance_new + else: + metrics.update(args.top_metric) + return metrics diff --git a/src/laion_clap/training/zero_shot.py b/src/laion_clap/training/zero_shot.py new file mode 100644 index 0000000000000000000000000000000000000000..04472c16e36041f90c8f229c5e026dcc394fb977 --- /dev/null +++ b/src/laion_clap/training/zero_shot.py @@ -0,0 +1,90 @@ +# NOTE: This script is currently not supported for CLAP. 
+import logging +from contextlib import suppress + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from clap_module import tokenize +from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template + + +def zero_shot_classifier(model, classnames, templates, args): + with torch.no_grad(): + zeroshot_weights = [] + for classname in tqdm(classnames): + texts = [template(classname) for template in templates] # format with class + texts = tokenize(texts).to(args.device) # tokenize + if args.distributed and not args.horovod: + class_embeddings = model.module.encode_text(texts) + else: + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) + return zeroshot_weights + + +def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] + + +def run(model, classifier, dataloader, args): + autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress + with torch.no_grad(): + top1, top5, n = 0., 0., 0. + for images, target in tqdm(dataloader, unit_scale=args.batch_size): + images = images.to(args.device) + target = target.to(args.device) + + with autocast(): + # predict + if args.distributed and not args.horovod: + image_features = model.module.encode_image(images) + else: + image_features = model.encode_image(images) + image_features = F.normalize(image_features, dim=-1) + logits = 100. * image_features @ classifier + + # measure accuracy + acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + top1 += acc1 + top5 += acc5 + n += images.size(0) + + top1 = (top1 / n) + top5 = (top5 / n) + return top1, top5 + + +def zero_shot_eval(model, data, epoch, args): + if 'imagenet-val' not in data and 'imagenet-v2' not in data: + return {} + if args.zeroshot_frequency == 0: + return {} + if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: + return {} + + logging.info('Starting zero-shot imagenet.') + + logging.info('Building zero-shot classifier') + classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args) + + logging.info('Using classifier') + results = {} + if 'imagenet-val' in data: + top1, top5 = run(model, classifier, data['imagenet-val'].dataloader, args) + results['imagenet-zeroshot-val-top1'] = top1 + results['imagenet-zeroshot-val-top5'] = top5 + if 'imagenet-v2' in data: + top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) + results['imagenetv2-zeroshot-val-top1'] = top1 + results['imagenetv2-zeroshot-val-top5'] = top5 + + logging.info('Finished zero-shot imagenet.') + + return results diff --git a/src/laion_clap/unit_test.py b/src/laion_clap/unit_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f138c627b43bb70b45bdd9364841ac0f113e32dc --- /dev/null +++ b/src/laion_clap/unit_test.py @@ -0,0 +1,75 @@ +""" +Contrastive Language-Audio Pretraining Model from LAION +-------------------------------------------------------- +Paper: https://arxiv.org/abs/2211.06687 +Authors (equal contributions): Ke Chen, Yusong Wu, Tianyu Zhang, Yuchen Hui +Support: LAION +""" + +import numpy as np +import librosa +import torch +import laion_clap + +# quantization +def 
int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def float32_to_int16(x): + x = np.clip(x, a_min=-1., a_max=1.) + return (x * 32767.).astype(np.int16) + +model = laion_clap.CLAP_Module(enable_fusion=False) +model.load_ckpt() + +# Directly get audio embeddings from audio files +audio_file = [ + '/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav', + '/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_long.wav' +] +audio_embed = model.get_audio_embedding_from_filelist(x = audio_file, use_tensor=False) +print(audio_embed[:,-20:]) +print(audio_embed.shape) + +# Get audio embeddings from audio data +audio_data, _ = librosa.load('/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav', sr=48000) # sample rate should be 48000 +audio_data = audio_data.reshape(1, -1) # Make it (1,T) or (N,T) +audio_embed = model.get_audio_embedding_from_data(x = audio_data, use_tensor=False) +print(audio_embed[:,-20:]) +print(audio_embed.shape) + +# Directly get audio embeddings from audio files, but return torch tensor +audio_file = [ + '/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav', + '/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_long.wav' +] +audio_embed = model.get_audio_embedding_from_filelist(x = audio_file, use_tensor=True) +print(audio_embed[:,-20:]) +print(audio_embed.shape) + +# Get audio embeddings from audio data +audio_data, _ = librosa.load('/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_short.wav', sr=48000) # sample rate should be 48000 +audio_data = audio_data.reshape(1, -1) # Make it (1,T) or (N,T) +audio_data = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float() # quantize before send it in to the model +audio_embed = model.get_audio_embedding_from_data(x = audio_data, use_tensor=True) +print(audio_embed[:,-20:]) +print(audio_embed.shape) + +# Get text embedings from texts: +text_data = ["I love the contrastive learning", "I love the pretrain model"] +text_embed = model.get_text_embedding(text_data) +print(text_embed) +print(text_embed.shape) + +# Get text embedings from texts, but return torch tensor: +text_data = ["I love the contrastive learning", "I love the pretrain model"] +text_embed = model.get_text_embedding(text_data, use_tensor=True) +print(text_embed) +print(text_embed.shape) + + + + + + diff --git a/src/tests/__init__.py b/src/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/tests/check_ckpt.py b/src/tests/check_ckpt.py new file mode 100644 index 0000000000000000000000000000000000000000..d034dd44ac67643f5f2b6ed8d229ee20d1ccbde4 --- /dev/null +++ b/src/tests/check_ckpt.py @@ -0,0 +1,802 @@ +import torch + +def keys_in_state_dict(ckpt, device='cpu'): + if device=="cpu": + a = torch.load(ckpt, map_location=torch.device('cpu'))["state_dict"] + else: + a = torch.load(ckpt)["state_dict"] + print("keys_in_state_dict", a.keys()) + + +def check_ckpt_diff(ckpt_a, ckpt_b, key_include=None, key_exclude=None, device='cpu', verbose=True): + if device=="cpu": + a = torch.load(ckpt_a, map_location=torch.device('cpu'))["state_dict"] + b = torch.load(ckpt_b, map_location=torch.device('cpu'))["state_dict"] + else: + a = torch.load(ckpt_a)["state_dict"] + b = torch.load(ckpt_b)["state_dict"] + a_sum = 0 + b_sum = 0 + difference_count = 0 + for k in a.keys(): + if key_include is not None and key_include not in k: + continue + if key_exclude is not None and key_exclude in k: + continue + if k in b.keys(): + a_sum += torch.sum(a[k]) + b_sum += 
torch.sum(b[k]) + if verbose: + if torch.sum(a[k]) != torch.sum(b[k]): + print(f"key {k} is different") + difference_count += 1 + print("a_sum: ", a_sum) + print("b_sum: ", b_sum) + print("diff: ", a_sum - b_sum) + if verbose: + print("difference_count: ", difference_count) + return bool(a_sum - b_sum) + +# Transformer no freeze: +# check_ckpt_diff("/fsx/clap_logs/2022_09_11-19_37_08-model_PANN-14-lr_0.001-b_160-j_4-p_fp32/checkpoints/epoch_10.pt", "/fsx/clap_logs/2022_09_11-19_37_08-model_PANN-14-lr_0.001-b_160-j_4-p_fp32/checkpoints/epoch_100.pt", "text_branch.resblocks") + +check_ckpt_diff("/fsx/clap_logs/2022_09_29-23_42_40-model_PANN-14-lr_0.001-b_160-j_4-p_fp32/checkpoints/epoch_1.pt", + "/fsx/clap_logs/2022_09_29-23_42_40-model_PANN-14-lr_0.001-b_160-j_4-p_fp32/checkpoints/epoch_2.pt", + "text_branch.resblocks") + +# key module.text_branch.resblocks.0.attn.in_proj_weight is different +# key module.text_branch.resblocks.0.attn.in_proj_bias is different +# key module.text_branch.resblocks.0.attn.out_proj.weight is different +# key module.text_branch.resblocks.0.attn.out_proj.bias is different +# key module.text_branch.resblocks.0.ln_1.weight is different +# key module.text_branch.resblocks.0.ln_1.bias is different +# key module.text_branch.resblocks.0.mlp.c_fc.weight is different +# key module.text_branch.resblocks.0.mlp.c_fc.bias is different +# key module.text_branch.resblocks.0.mlp.c_proj.weight is different +# key module.text_branch.resblocks.0.mlp.c_proj.bias is different +# key module.text_branch.resblocks.0.ln_2.weight is different +# key module.text_branch.resblocks.0.ln_2.bias is different +# key module.text_branch.resblocks.1.attn.in_proj_weight is different +# key module.text_branch.resblocks.1.attn.in_proj_bias is different +# key module.text_branch.resblocks.1.attn.out_proj.weight is different +# key module.text_branch.resblocks.1.attn.out_proj.bias is different +# key module.text_branch.resblocks.1.ln_1.weight is different +# key module.text_branch.resblocks.1.ln_1.bias is different +# key module.text_branch.resblocks.1.mlp.c_fc.weight is different +# key module.text_branch.resblocks.1.mlp.c_fc.bias is different +# key module.text_branch.resblocks.1.mlp.c_proj.weight is different +# key module.text_branch.resblocks.1.mlp.c_proj.bias is different +# key module.text_branch.resblocks.1.ln_2.weight is different +# key module.text_branch.resblocks.1.ln_2.bias is different +# key module.text_branch.resblocks.2.attn.in_proj_weight is different +# key module.text_branch.resblocks.2.attn.in_proj_bias is different +# key module.text_branch.resblocks.2.attn.out_proj.weight is different +# key module.text_branch.resblocks.2.attn.out_proj.bias is different +# key module.text_branch.resblocks.2.ln_1.weight is different +# key module.text_branch.resblocks.2.ln_1.bias is different +# key module.text_branch.resblocks.2.mlp.c_fc.weight is different +# key module.text_branch.resblocks.2.mlp.c_fc.bias is different +# key module.text_branch.resblocks.2.mlp.c_proj.weight is different +# key module.text_branch.resblocks.2.mlp.c_proj.bias is different +# key module.text_branch.resblocks.2.ln_2.weight is different +# key module.text_branch.resblocks.2.ln_2.bias is different +# key module.text_branch.resblocks.3.attn.in_proj_weight is different +# key module.text_branch.resblocks.3.attn.in_proj_bias is different +# key module.text_branch.resblocks.3.attn.out_proj.weight is different +# key module.text_branch.resblocks.3.attn.out_proj.bias is different +# key 
module.text_branch.resblocks.3.ln_1.weight is different +# key module.text_branch.resblocks.3.ln_1.bias is different +# key module.text_branch.resblocks.3.mlp.c_fc.weight is different +# key module.text_branch.resblocks.3.mlp.c_fc.bias is different +# key module.text_branch.resblocks.3.mlp.c_proj.weight is different +# key module.text_branch.resblocks.3.mlp.c_proj.bias is different +# key module.text_branch.resblocks.3.ln_2.weight is different +# key module.text_branch.resblocks.3.ln_2.bias is different +# key module.text_branch.resblocks.4.attn.in_proj_weight is different +# key module.text_branch.resblocks.4.attn.in_proj_bias is different +# key module.text_branch.resblocks.4.attn.out_proj.weight is different +# key module.text_branch.resblocks.4.attn.out_proj.bias is different +# key module.text_branch.resblocks.4.ln_1.weight is different +# key module.text_branch.resblocks.4.ln_1.bias is different +# key module.text_branch.resblocks.4.mlp.c_fc.weight is different +# key module.text_branch.resblocks.4.mlp.c_fc.bias is different +# key module.text_branch.resblocks.4.mlp.c_proj.weight is different +# key module.text_branch.resblocks.4.mlp.c_proj.bias is different +# key module.text_branch.resblocks.4.ln_2.weight is different +# key module.text_branch.resblocks.4.ln_2.bias is different +# key module.text_branch.resblocks.5.attn.in_proj_weight is different +# key module.text_branch.resblocks.5.attn.in_proj_bias is different +# key module.text_branch.resblocks.5.attn.out_proj.weight is different +# key module.text_branch.resblocks.5.attn.out_proj.bias is different +# key module.text_branch.resblocks.5.ln_1.weight is different +# key module.text_branch.resblocks.5.ln_1.bias is different +# key module.text_branch.resblocks.5.mlp.c_fc.weight is different +# key module.text_branch.resblocks.5.mlp.c_fc.bias is different +# key module.text_branch.resblocks.5.mlp.c_proj.weight is different +# key module.text_branch.resblocks.5.mlp.c_proj.bias is different +# key module.text_branch.resblocks.5.ln_2.weight is different +# key module.text_branch.resblocks.5.ln_2.bias is different +# key module.text_branch.resblocks.6.attn.in_proj_weight is different +# key module.text_branch.resblocks.6.attn.in_proj_bias is different +# key module.text_branch.resblocks.6.attn.out_proj.weight is different +# key module.text_branch.resblocks.6.attn.out_proj.bias is different +# key module.text_branch.resblocks.6.ln_1.weight is different +# key module.text_branch.resblocks.6.ln_1.bias is different +# key module.text_branch.resblocks.6.mlp.c_fc.weight is different +# key module.text_branch.resblocks.6.mlp.c_fc.bias is different +# key module.text_branch.resblocks.6.mlp.c_proj.weight is different +# key module.text_branch.resblocks.6.mlp.c_proj.bias is different +# key module.text_branch.resblocks.6.ln_2.weight is different +# key module.text_branch.resblocks.6.ln_2.bias is different +# key module.text_branch.resblocks.7.attn.in_proj_weight is different +# key module.text_branch.resblocks.7.attn.in_proj_bias is different +# key module.text_branch.resblocks.7.attn.out_proj.weight is different +# key module.text_branch.resblocks.7.attn.out_proj.bias is different +# key module.text_branch.resblocks.7.ln_1.weight is different +# key module.text_branch.resblocks.7.ln_1.bias is different +# key module.text_branch.resblocks.7.mlp.c_fc.weight is different +# key module.text_branch.resblocks.7.mlp.c_fc.bias is different +# key module.text_branch.resblocks.7.mlp.c_proj.weight is different +# key 
module.text_branch.resblocks.7.mlp.c_proj.bias is different +# key module.text_branch.resblocks.7.ln_2.weight is different +# key module.text_branch.resblocks.7.ln_2.bias is different +# key module.text_branch.resblocks.8.attn.in_proj_weight is different +# key module.text_branch.resblocks.8.attn.in_proj_bias is different +# key module.text_branch.resblocks.8.attn.out_proj.weight is different +# key module.text_branch.resblocks.8.attn.out_proj.bias is different +# key module.text_branch.resblocks.8.ln_1.weight is different +# key module.text_branch.resblocks.8.ln_1.bias is different +# key module.text_branch.resblocks.8.mlp.c_fc.weight is different +# key module.text_branch.resblocks.8.mlp.c_fc.bias is different +# key module.text_branch.resblocks.8.mlp.c_proj.weight is different +# key module.text_branch.resblocks.8.mlp.c_proj.bias is different +# key module.text_branch.resblocks.8.ln_2.weight is different +# key module.text_branch.resblocks.8.ln_2.bias is different +# key module.text_branch.resblocks.9.attn.in_proj_weight is different +# key module.text_branch.resblocks.9.attn.in_proj_bias is different +# key module.text_branch.resblocks.9.attn.out_proj.weight is different +# key module.text_branch.resblocks.9.attn.out_proj.bias is different +# key module.text_branch.resblocks.9.ln_1.weight is different +# key module.text_branch.resblocks.9.ln_1.bias is different +# key module.text_branch.resblocks.9.mlp.c_fc.weight is different +# key module.text_branch.resblocks.9.mlp.c_fc.bias is different +# key module.text_branch.resblocks.9.mlp.c_proj.weight is different +# key module.text_branch.resblocks.9.mlp.c_proj.bias is different +# key module.text_branch.resblocks.9.ln_2.weight is different +# key module.text_branch.resblocks.9.ln_2.bias is different +# key module.text_branch.resblocks.10.attn.in_proj_weight is different +# key module.text_branch.resblocks.10.attn.in_proj_bias is different +# key module.text_branch.resblocks.10.attn.out_proj.weight is different +# key module.text_branch.resblocks.10.attn.out_proj.bias is different +# key module.text_branch.resblocks.10.ln_1.weight is different +# key module.text_branch.resblocks.10.ln_1.bias is different +# key module.text_branch.resblocks.10.mlp.c_fc.weight is different +# key module.text_branch.resblocks.10.mlp.c_fc.bias is different +# key module.text_branch.resblocks.10.mlp.c_proj.weight is different +# key module.text_branch.resblocks.10.mlp.c_proj.bias is different +# key module.text_branch.resblocks.10.ln_2.weight is different +# key module.text_branch.resblocks.10.ln_2.bias is different +# key module.text_branch.resblocks.11.attn.in_proj_weight is different +# key module.text_branch.resblocks.11.attn.in_proj_bias is different +# key module.text_branch.resblocks.11.attn.out_proj.weight is different +# key module.text_branch.resblocks.11.attn.out_proj.bias is different +# key module.text_branch.resblocks.11.ln_1.weight is different +# key module.text_branch.resblocks.11.ln_1.bias is different +# key module.text_branch.resblocks.11.mlp.c_fc.weight is different +# key module.text_branch.resblocks.11.mlp.c_fc.bias is different +# key module.text_branch.resblocks.11.mlp.c_proj.weight is different +# key module.text_branch.resblocks.11.mlp.c_proj.bias is different +# key module.text_branch.resblocks.11.ln_2.weight is different +# key module.text_branch.resblocks.11.ln_2.bias is different +# a_sum: tensor(12113.6445) +# b_sum: tensor(9883.4424) +# diff: tensor(2230.2021) +# True + + +# Transformer freeze: +# 
check_ckpt_diff("/fsx/clap_logs/2022_09_16-18_55_10-model_PANN-14-lr_0.001-b_160-j_4-p_fp32/checkpoints/epoch_10.pt", "/fsx/clap_logs/2022_09_16-18_55_10-model_PANN-14-lr_0.001-b_160-j_4-p_fp32/checkpoints/epoch_100.pt", "text_branch.resblocks") + +# key module.text_branch.resblocks.0.attn.in_proj_weight is different +# key module.text_branch.resblocks.0.attn.in_proj_bias is different +# key module.text_branch.resblocks.0.attn.out_proj.weight is different +# key module.text_branch.resblocks.0.attn.out_proj.bias is different +# key module.text_branch.resblocks.0.ln_1.weight is different +# key module.text_branch.resblocks.0.ln_1.bias is different +# key module.text_branch.resblocks.0.mlp.c_fc.weight is different +# key module.text_branch.resblocks.0.mlp.c_fc.bias is different +# key module.text_branch.resblocks.0.mlp.c_proj.weight is different +# key module.text_branch.resblocks.0.mlp.c_proj.bias is different +# key module.text_branch.resblocks.0.ln_2.weight is different +# key module.text_branch.resblocks.0.ln_2.bias is different +# key module.text_branch.resblocks.1.attn.in_proj_weight is different +# key module.text_branch.resblocks.1.attn.in_proj_bias is different +# key module.text_branch.resblocks.1.attn.out_proj.weight is different +# key module.text_branch.resblocks.1.attn.out_proj.bias is different +# key module.text_branch.resblocks.1.ln_1.weight is different +# key module.text_branch.resblocks.1.ln_1.bias is different +# key module.text_branch.resblocks.1.mlp.c_fc.weight is different +# key module.text_branch.resblocks.1.mlp.c_fc.bias is different +# key module.text_branch.resblocks.1.mlp.c_proj.weight is different +# key module.text_branch.resblocks.1.mlp.c_proj.bias is different +# key module.text_branch.resblocks.1.ln_2.weight is different +# key module.text_branch.resblocks.1.ln_2.bias is different +# key module.text_branch.resblocks.2.attn.in_proj_weight is different +# key module.text_branch.resblocks.2.attn.in_proj_bias is different +# key module.text_branch.resblocks.2.attn.out_proj.weight is different +# key module.text_branch.resblocks.2.attn.out_proj.bias is different +# key module.text_branch.resblocks.2.ln_1.weight is different +# key module.text_branch.resblocks.2.ln_1.bias is different +# key module.text_branch.resblocks.2.mlp.c_fc.weight is different +# key module.text_branch.resblocks.2.mlp.c_fc.bias is different +# key module.text_branch.resblocks.2.mlp.c_proj.weight is different +# key module.text_branch.resblocks.2.mlp.c_proj.bias is different +# key module.text_branch.resblocks.2.ln_2.weight is different +# key module.text_branch.resblocks.2.ln_2.bias is different +# key module.text_branch.resblocks.3.attn.in_proj_weight is different +# key module.text_branch.resblocks.3.attn.in_proj_bias is different +# key module.text_branch.resblocks.3.attn.out_proj.weight is different +# key module.text_branch.resblocks.3.attn.out_proj.bias is different +# key module.text_branch.resblocks.3.ln_1.weight is different +# key module.text_branch.resblocks.3.ln_1.bias is different +# key module.text_branch.resblocks.3.mlp.c_fc.weight is different +# key module.text_branch.resblocks.3.mlp.c_fc.bias is different +# key module.text_branch.resblocks.3.mlp.c_proj.weight is different +# key module.text_branch.resblocks.3.mlp.c_proj.bias is different +# key module.text_branch.resblocks.3.ln_2.weight is different +# key module.text_branch.resblocks.3.ln_2.bias is different +# key module.text_branch.resblocks.4.attn.in_proj_weight is different +# key 
module.text_branch.resblocks.4.attn.in_proj_bias is different +# key module.text_branch.resblocks.4.attn.out_proj.weight is different +# key module.text_branch.resblocks.4.attn.out_proj.bias is different +# key module.text_branch.resblocks.4.ln_1.weight is different +# key module.text_branch.resblocks.4.ln_1.bias is different +# key module.text_branch.resblocks.4.mlp.c_fc.weight is different +# key module.text_branch.resblocks.4.mlp.c_fc.bias is different +# key module.text_branch.resblocks.4.mlp.c_proj.weight is different +# key module.text_branch.resblocks.4.mlp.c_proj.bias is different +# key module.text_branch.resblocks.4.ln_2.weight is different +# key module.text_branch.resblocks.4.ln_2.bias is different +# key module.text_branch.resblocks.5.attn.in_proj_weight is different +# key module.text_branch.resblocks.5.attn.in_proj_bias is different +# key module.text_branch.resblocks.5.attn.out_proj.weight is different +# key module.text_branch.resblocks.5.attn.out_proj.bias is different +# key module.text_branch.resblocks.5.ln_1.weight is different +# key module.text_branch.resblocks.5.ln_1.bias is different +# key module.text_branch.resblocks.5.mlp.c_fc.weight is different +# key module.text_branch.resblocks.5.mlp.c_fc.bias is different +# key module.text_branch.resblocks.5.mlp.c_proj.weight is different +# key module.text_branch.resblocks.5.mlp.c_proj.bias is different +# key module.text_branch.resblocks.5.ln_2.weight is different +# key module.text_branch.resblocks.5.ln_2.bias is different +# key module.text_branch.resblocks.6.attn.in_proj_weight is different +# key module.text_branch.resblocks.6.attn.in_proj_bias is different +# key module.text_branch.resblocks.6.attn.out_proj.weight is different +# key module.text_branch.resblocks.6.attn.out_proj.bias is different +# key module.text_branch.resblocks.6.ln_1.weight is different +# key module.text_branch.resblocks.6.ln_1.bias is different +# key module.text_branch.resblocks.6.mlp.c_fc.weight is different +# key module.text_branch.resblocks.6.mlp.c_fc.bias is different +# key module.text_branch.resblocks.6.mlp.c_proj.weight is different +# key module.text_branch.resblocks.6.mlp.c_proj.bias is different +# key module.text_branch.resblocks.6.ln_2.weight is different +# key module.text_branch.resblocks.6.ln_2.bias is different +# key module.text_branch.resblocks.7.attn.in_proj_weight is different +# key module.text_branch.resblocks.7.attn.in_proj_bias is different +# key module.text_branch.resblocks.7.attn.out_proj.weight is different +# key module.text_branch.resblocks.7.attn.out_proj.bias is different +# key module.text_branch.resblocks.7.ln_1.weight is different +# key module.text_branch.resblocks.7.ln_1.bias is different +# key module.text_branch.resblocks.7.mlp.c_fc.weight is different +# key module.text_branch.resblocks.7.mlp.c_fc.bias is different +# key module.text_branch.resblocks.7.mlp.c_proj.weight is different +# key module.text_branch.resblocks.7.mlp.c_proj.bias is different +# key module.text_branch.resblocks.7.ln_2.weight is different +# key module.text_branch.resblocks.7.ln_2.bias is different +# key module.text_branch.resblocks.8.attn.in_proj_weight is different +# key module.text_branch.resblocks.8.attn.in_proj_bias is different +# key module.text_branch.resblocks.8.attn.out_proj.weight is different +# key module.text_branch.resblocks.8.attn.out_proj.bias is different +# key module.text_branch.resblocks.8.ln_1.weight is different +# key module.text_branch.resblocks.8.ln_1.bias is different +# key 
module.text_branch.resblocks.8.mlp.c_fc.weight is different +# key module.text_branch.resblocks.8.mlp.c_fc.bias is different +# key module.text_branch.resblocks.8.mlp.c_proj.weight is different +# key module.text_branch.resblocks.8.mlp.c_proj.bias is different +# key module.text_branch.resblocks.8.ln_2.weight is different +# key module.text_branch.resblocks.8.ln_2.bias is different +# key module.text_branch.resblocks.9.attn.in_proj_weight is different +# key module.text_branch.resblocks.9.attn.in_proj_bias is different +# key module.text_branch.resblocks.9.attn.out_proj.weight is different +# key module.text_branch.resblocks.9.attn.out_proj.bias is different +# key module.text_branch.resblocks.9.ln_1.weight is different +# key module.text_branch.resblocks.9.ln_1.bias is different +# key module.text_branch.resblocks.9.mlp.c_fc.weight is different +# key module.text_branch.resblocks.9.mlp.c_fc.bias is different +# key module.text_branch.resblocks.9.mlp.c_proj.weight is different +# key module.text_branch.resblocks.9.mlp.c_proj.bias is different +# key module.text_branch.resblocks.9.ln_2.weight is different +# key module.text_branch.resblocks.9.ln_2.bias is different +# key module.text_branch.resblocks.10.attn.in_proj_weight is different +# key module.text_branch.resblocks.10.attn.in_proj_bias is different +# key module.text_branch.resblocks.10.attn.out_proj.weight is different +# key module.text_branch.resblocks.10.attn.out_proj.bias is different +# key module.text_branch.resblocks.10.ln_1.weight is different +# key module.text_branch.resblocks.10.ln_1.bias is different +# key module.text_branch.resblocks.10.mlp.c_fc.weight is different +# key module.text_branch.resblocks.10.mlp.c_fc.bias is different +# key module.text_branch.resblocks.10.mlp.c_proj.weight is different +# key module.text_branch.resblocks.10.mlp.c_proj.bias is different +# key module.text_branch.resblocks.10.ln_2.weight is different +# key module.text_branch.resblocks.10.ln_2.bias is different +# key module.text_branch.resblocks.11.attn.in_proj_weight is different +# key module.text_branch.resblocks.11.attn.in_proj_bias is different +# key module.text_branch.resblocks.11.attn.out_proj.weight is different +# key module.text_branch.resblocks.11.attn.out_proj.bias is different +# key module.text_branch.resblocks.11.ln_1.weight is different +# key module.text_branch.resblocks.11.ln_1.bias is different +# key module.text_branch.resblocks.11.mlp.c_fc.weight is different +# key module.text_branch.resblocks.11.mlp.c_fc.bias is different +# key module.text_branch.resblocks.11.mlp.c_proj.weight is different +# key module.text_branch.resblocks.11.mlp.c_proj.bias is different +# key module.text_branch.resblocks.11.ln_2.weight is different +# key module.text_branch.resblocks.11.ln_2.bias is different +# a_sum: tensor(12133.6348) +# b_sum: tensor(10423.9521) +# diff: tensor(1709.6826) +# True + + +# bert no freeze: +# check_ckpt_diff("/fsx/clap_logs/2022_09_14-02_33_11-model_PANN-14-lr_0.0001-b_160-j_4-p_fp32/checkpoints/epoch_10.pt", "/fsx/clap_logs/2022_09_14-02_33_11-model_PANN-14-lr_0.0001-b_160-j_4-p_fp32/checkpoints/epoch_100.pt", "text_branch.encoder") + +# key module.text_branch.encoder.layer.0.attention.self.query.weight is different +# key module.text_branch.encoder.layer.0.attention.self.query.bias is different +# key module.text_branch.encoder.layer.0.attention.self.key.weight is different +# key module.text_branch.encoder.layer.0.attention.self.key.bias is different +# key 
module.text_branch.encoder.layer.0.attention.self.value.weight is different +# key module.text_branch.encoder.layer.0.attention.self.value.bias is different +# key module.text_branch.encoder.layer.0.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.0.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.0.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.0.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.0.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.0.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.0.output.dense.weight is different +# key module.text_branch.encoder.layer.0.output.dense.bias is different +# key module.text_branch.encoder.layer.0.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.0.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.1.attention.self.query.weight is different +# key module.text_branch.encoder.layer.1.attention.self.query.bias is different +# key module.text_branch.encoder.layer.1.attention.self.key.weight is different +# key module.text_branch.encoder.layer.1.attention.self.key.bias is different +# key module.text_branch.encoder.layer.1.attention.self.value.weight is different +# key module.text_branch.encoder.layer.1.attention.self.value.bias is different +# key module.text_branch.encoder.layer.1.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.1.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.1.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.1.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.1.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.1.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.1.output.dense.weight is different +# key module.text_branch.encoder.layer.1.output.dense.bias is different +# key module.text_branch.encoder.layer.1.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.1.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.2.attention.self.query.weight is different +# key module.text_branch.encoder.layer.2.attention.self.query.bias is different +# key module.text_branch.encoder.layer.2.attention.self.key.weight is different +# key module.text_branch.encoder.layer.2.attention.self.key.bias is different +# key module.text_branch.encoder.layer.2.attention.self.value.weight is different +# key module.text_branch.encoder.layer.2.attention.self.value.bias is different +# key module.text_branch.encoder.layer.2.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.2.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.2.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.2.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.2.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.2.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.2.output.dense.weight is different +# key module.text_branch.encoder.layer.2.output.dense.bias is different +# key module.text_branch.encoder.layer.2.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.2.output.LayerNorm.bias is 
different +# key module.text_branch.encoder.layer.3.attention.self.query.weight is different +# key module.text_branch.encoder.layer.3.attention.self.query.bias is different +# key module.text_branch.encoder.layer.3.attention.self.key.weight is different +# key module.text_branch.encoder.layer.3.attention.self.key.bias is different +# key module.text_branch.encoder.layer.3.attention.self.value.weight is different +# key module.text_branch.encoder.layer.3.attention.self.value.bias is different +# key module.text_branch.encoder.layer.3.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.3.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.3.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.3.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.3.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.3.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.3.output.dense.weight is different +# key module.text_branch.encoder.layer.3.output.dense.bias is different +# key module.text_branch.encoder.layer.3.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.3.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.4.attention.self.query.weight is different +# key module.text_branch.encoder.layer.4.attention.self.query.bias is different +# key module.text_branch.encoder.layer.4.attention.self.key.weight is different +# key module.text_branch.encoder.layer.4.attention.self.key.bias is different +# key module.text_branch.encoder.layer.4.attention.self.value.weight is different +# key module.text_branch.encoder.layer.4.attention.self.value.bias is different +# key module.text_branch.encoder.layer.4.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.4.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.4.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.4.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.4.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.4.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.4.output.dense.weight is different +# key module.text_branch.encoder.layer.4.output.dense.bias is different +# key module.text_branch.encoder.layer.4.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.4.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.5.attention.self.query.weight is different +# key module.text_branch.encoder.layer.5.attention.self.query.bias is different +# key module.text_branch.encoder.layer.5.attention.self.key.weight is different +# key module.text_branch.encoder.layer.5.attention.self.key.bias is different +# key module.text_branch.encoder.layer.5.attention.self.value.weight is different +# key module.text_branch.encoder.layer.5.attention.self.value.bias is different +# key module.text_branch.encoder.layer.5.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.5.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.5.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.5.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.5.intermediate.dense.weight is different +# key 
module.text_branch.encoder.layer.5.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.5.output.dense.weight is different +# key module.text_branch.encoder.layer.5.output.dense.bias is different +# key module.text_branch.encoder.layer.5.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.5.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.6.attention.self.query.weight is different +# key module.text_branch.encoder.layer.6.attention.self.query.bias is different +# key module.text_branch.encoder.layer.6.attention.self.key.weight is different +# key module.text_branch.encoder.layer.6.attention.self.key.bias is different +# key module.text_branch.encoder.layer.6.attention.self.value.weight is different +# key module.text_branch.encoder.layer.6.attention.self.value.bias is different +# key module.text_branch.encoder.layer.6.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.6.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.6.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.6.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.6.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.6.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.6.output.dense.weight is different +# key module.text_branch.encoder.layer.6.output.dense.bias is different +# key module.text_branch.encoder.layer.6.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.6.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.7.attention.self.query.weight is different +# key module.text_branch.encoder.layer.7.attention.self.query.bias is different +# key module.text_branch.encoder.layer.7.attention.self.key.weight is different +# key module.text_branch.encoder.layer.7.attention.self.key.bias is different +# key module.text_branch.encoder.layer.7.attention.self.value.weight is different +# key module.text_branch.encoder.layer.7.attention.self.value.bias is different +# key module.text_branch.encoder.layer.7.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.7.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.7.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.7.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.7.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.7.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.7.output.dense.weight is different +# key module.text_branch.encoder.layer.7.output.dense.bias is different +# key module.text_branch.encoder.layer.7.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.7.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.8.attention.self.query.weight is different +# key module.text_branch.encoder.layer.8.attention.self.query.bias is different +# key module.text_branch.encoder.layer.8.attention.self.key.weight is different +# key module.text_branch.encoder.layer.8.attention.self.key.bias is different +# key module.text_branch.encoder.layer.8.attention.self.value.weight is different +# key module.text_branch.encoder.layer.8.attention.self.value.bias is different +# key module.text_branch.encoder.layer.8.attention.output.dense.weight is different +# key 
module.text_branch.encoder.layer.8.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.8.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.8.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.8.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.8.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.8.output.dense.weight is different +# key module.text_branch.encoder.layer.8.output.dense.bias is different +# key module.text_branch.encoder.layer.8.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.8.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.9.attention.self.query.weight is different +# key module.text_branch.encoder.layer.9.attention.self.query.bias is different +# key module.text_branch.encoder.layer.9.attention.self.key.weight is different +# key module.text_branch.encoder.layer.9.attention.self.key.bias is different +# key module.text_branch.encoder.layer.9.attention.self.value.weight is different +# key module.text_branch.encoder.layer.9.attention.self.value.bias is different +# key module.text_branch.encoder.layer.9.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.9.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.9.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.9.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.9.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.9.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.9.output.dense.weight is different +# key module.text_branch.encoder.layer.9.output.dense.bias is different +# key module.text_branch.encoder.layer.9.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.9.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.10.attention.self.query.weight is different +# key module.text_branch.encoder.layer.10.attention.self.query.bias is different +# key module.text_branch.encoder.layer.10.attention.self.key.weight is different +# key module.text_branch.encoder.layer.10.attention.self.key.bias is different +# key module.text_branch.encoder.layer.10.attention.self.value.weight is different +# key module.text_branch.encoder.layer.10.attention.self.value.bias is different +# key module.text_branch.encoder.layer.10.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.10.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.10.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.10.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.10.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.10.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.10.output.dense.weight is different +# key module.text_branch.encoder.layer.10.output.dense.bias is different +# key module.text_branch.encoder.layer.10.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.10.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.11.attention.self.query.weight is different +# key module.text_branch.encoder.layer.11.attention.self.query.bias is different +# key 
module.text_branch.encoder.layer.11.attention.self.key.weight is different +# key module.text_branch.encoder.layer.11.attention.self.key.bias is different +# key module.text_branch.encoder.layer.11.attention.self.value.weight is different +# key module.text_branch.encoder.layer.11.attention.self.value.bias is different +# key module.text_branch.encoder.layer.11.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.11.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.11.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.11.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.11.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.11.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.11.output.dense.weight is different +# key module.text_branch.encoder.layer.11.output.dense.bias is different +# key module.text_branch.encoder.layer.11.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.11.output.LayerNorm.bias is different +# a_sum: tensor(15185.1230) +# b_sum: tensor(15576.5596) +# diff: tensor(-391.4365) +# True + + +# bert freeze: +# check_ckpt_diff("/fsx/clap_logs/2022_09_13-01_25_15-model_PANN-14-lr_0.0001-b_160-j_4-p_fp32/checkpoints/epoch_10.pt", "/fsx/clap_logs/2022_09_13-01_25_15-model_PANN-14-lr_0.0001-b_160-j_4-p_fp32/checkpoints/epoch_100.pt", "text_branch.encoder") + +# key module.text_branch.encoder.layer.0.attention.self.query.weight is different +# key module.text_branch.encoder.layer.0.attention.self.query.bias is different +# key module.text_branch.encoder.layer.0.attention.self.key.weight is different +# key module.text_branch.encoder.layer.0.attention.self.key.bias is different +# key module.text_branch.encoder.layer.0.attention.self.value.weight is different +# key module.text_branch.encoder.layer.0.attention.self.value.bias is different +# key module.text_branch.encoder.layer.0.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.0.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.0.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.0.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.0.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.0.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.0.output.dense.weight is different +# key module.text_branch.encoder.layer.0.output.dense.bias is different +# key module.text_branch.encoder.layer.0.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.0.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.1.attention.self.query.weight is different +# key module.text_branch.encoder.layer.1.attention.self.query.bias is different +# key module.text_branch.encoder.layer.1.attention.self.key.weight is different +# key module.text_branch.encoder.layer.1.attention.self.key.bias is different +# key module.text_branch.encoder.layer.1.attention.self.value.weight is different +# key module.text_branch.encoder.layer.1.attention.self.value.bias is different +# key module.text_branch.encoder.layer.1.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.1.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.1.attention.output.LayerNorm.weight is different +# key 
module.text_branch.encoder.layer.1.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.1.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.1.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.1.output.dense.weight is different +# key module.text_branch.encoder.layer.1.output.dense.bias is different +# key module.text_branch.encoder.layer.1.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.1.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.2.attention.self.query.weight is different +# key module.text_branch.encoder.layer.2.attention.self.query.bias is different +# key module.text_branch.encoder.layer.2.attention.self.key.weight is different +# key module.text_branch.encoder.layer.2.attention.self.key.bias is different +# key module.text_branch.encoder.layer.2.attention.self.value.weight is different +# key module.text_branch.encoder.layer.2.attention.self.value.bias is different +# key module.text_branch.encoder.layer.2.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.2.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.2.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.2.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.2.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.2.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.2.output.dense.weight is different +# key module.text_branch.encoder.layer.2.output.dense.bias is different +# key module.text_branch.encoder.layer.2.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.2.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.3.attention.self.query.weight is different +# key module.text_branch.encoder.layer.3.attention.self.query.bias is different +# key module.text_branch.encoder.layer.3.attention.self.key.weight is different +# key module.text_branch.encoder.layer.3.attention.self.key.bias is different +# key module.text_branch.encoder.layer.3.attention.self.value.weight is different +# key module.text_branch.encoder.layer.3.attention.self.value.bias is different +# key module.text_branch.encoder.layer.3.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.3.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.3.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.3.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.3.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.3.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.3.output.dense.weight is different +# key module.text_branch.encoder.layer.3.output.dense.bias is different +# key module.text_branch.encoder.layer.3.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.3.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.4.attention.self.query.weight is different +# key module.text_branch.encoder.layer.4.attention.self.query.bias is different +# key module.text_branch.encoder.layer.4.attention.self.key.weight is different +# key module.text_branch.encoder.layer.4.attention.self.key.bias is different +# key module.text_branch.encoder.layer.4.attention.self.value.weight is different +# key 
module.text_branch.encoder.layer.4.attention.self.value.bias is different +# key module.text_branch.encoder.layer.4.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.4.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.4.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.4.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.4.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.4.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.4.output.dense.weight is different +# key module.text_branch.encoder.layer.4.output.dense.bias is different +# key module.text_branch.encoder.layer.4.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.4.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.5.attention.self.query.weight is different +# key module.text_branch.encoder.layer.5.attention.self.query.bias is different +# key module.text_branch.encoder.layer.5.attention.self.key.weight is different +# key module.text_branch.encoder.layer.5.attention.self.key.bias is different +# key module.text_branch.encoder.layer.5.attention.self.value.weight is different +# key module.text_branch.encoder.layer.5.attention.self.value.bias is different +# key module.text_branch.encoder.layer.5.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.5.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.5.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.5.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.5.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.5.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.5.output.dense.weight is different +# key module.text_branch.encoder.layer.5.output.dense.bias is different +# key module.text_branch.encoder.layer.5.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.5.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.6.attention.self.query.weight is different +# key module.text_branch.encoder.layer.6.attention.self.query.bias is different +# key module.text_branch.encoder.layer.6.attention.self.key.weight is different +# key module.text_branch.encoder.layer.6.attention.self.key.bias is different +# key module.text_branch.encoder.layer.6.attention.self.value.weight is different +# key module.text_branch.encoder.layer.6.attention.self.value.bias is different +# key module.text_branch.encoder.layer.6.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.6.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.6.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.6.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.6.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.6.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.6.output.dense.weight is different +# key module.text_branch.encoder.layer.6.output.dense.bias is different +# key module.text_branch.encoder.layer.6.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.6.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.7.attention.self.query.weight is 
different +# key module.text_branch.encoder.layer.7.attention.self.query.bias is different +# key module.text_branch.encoder.layer.7.attention.self.key.weight is different +# key module.text_branch.encoder.layer.7.attention.self.key.bias is different +# key module.text_branch.encoder.layer.7.attention.self.value.weight is different +# key module.text_branch.encoder.layer.7.attention.self.value.bias is different +# key module.text_branch.encoder.layer.7.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.7.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.7.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.7.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.7.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.7.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.7.output.dense.weight is different +# key module.text_branch.encoder.layer.7.output.dense.bias is different +# key module.text_branch.encoder.layer.7.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.7.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.8.attention.self.query.weight is different +# key module.text_branch.encoder.layer.8.attention.self.query.bias is different +# key module.text_branch.encoder.layer.8.attention.self.key.weight is different +# key module.text_branch.encoder.layer.8.attention.self.key.bias is different +# key module.text_branch.encoder.layer.8.attention.self.value.weight is different +# key module.text_branch.encoder.layer.8.attention.self.value.bias is different +# key module.text_branch.encoder.layer.8.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.8.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.8.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.8.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.8.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.8.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.8.output.dense.weight is different +# key module.text_branch.encoder.layer.8.output.dense.bias is different +# key module.text_branch.encoder.layer.8.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.8.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.9.attention.self.query.weight is different +# key module.text_branch.encoder.layer.9.attention.self.query.bias is different +# key module.text_branch.encoder.layer.9.attention.self.key.weight is different +# key module.text_branch.encoder.layer.9.attention.self.key.bias is different +# key module.text_branch.encoder.layer.9.attention.self.value.weight is different +# key module.text_branch.encoder.layer.9.attention.self.value.bias is different +# key module.text_branch.encoder.layer.9.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.9.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.9.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.9.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.9.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.9.intermediate.dense.bias is different +# key 
module.text_branch.encoder.layer.9.output.dense.weight is different +# key module.text_branch.encoder.layer.9.output.dense.bias is different +# key module.text_branch.encoder.layer.9.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.9.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.10.attention.self.query.weight is different +# key module.text_branch.encoder.layer.10.attention.self.query.bias is different +# key module.text_branch.encoder.layer.10.attention.self.key.weight is different +# key module.text_branch.encoder.layer.10.attention.self.key.bias is different +# key module.text_branch.encoder.layer.10.attention.self.value.weight is different +# key module.text_branch.encoder.layer.10.attention.self.value.bias is different +# key module.text_branch.encoder.layer.10.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.10.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.10.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.10.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.10.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.10.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.10.output.dense.weight is different +# key module.text_branch.encoder.layer.10.output.dense.bias is different +# key module.text_branch.encoder.layer.10.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.10.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.11.attention.self.query.weight is different +# key module.text_branch.encoder.layer.11.attention.self.query.bias is different +# key module.text_branch.encoder.layer.11.attention.self.key.weight is different +# key module.text_branch.encoder.layer.11.attention.self.key.bias is different +# key module.text_branch.encoder.layer.11.attention.self.value.weight is different +# key module.text_branch.encoder.layer.11.attention.self.value.bias is different +# key module.text_branch.encoder.layer.11.attention.output.dense.weight is different +# key module.text_branch.encoder.layer.11.attention.output.dense.bias is different +# key module.text_branch.encoder.layer.11.attention.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.11.attention.output.LayerNorm.bias is different +# key module.text_branch.encoder.layer.11.intermediate.dense.weight is different +# key module.text_branch.encoder.layer.11.intermediate.dense.bias is different +# key module.text_branch.encoder.layer.11.output.dense.weight is different +# key module.text_branch.encoder.layer.11.output.dense.bias is different +# key module.text_branch.encoder.layer.11.output.LayerNorm.weight is different +# key module.text_branch.encoder.layer.11.output.LayerNorm.bias is different +# a_sum: tensor(15078.6641) +# b_sum: tensor(15540.0723) +# diff: tensor(-461.4082) +# True + +# linear_prob_text +# check_ckpt_diff("/fsx/clap_logs/2022_09_15-02_05_29-linear_probemodel_PANN-14-lr_0.0001-b_512-j_4-p_fp32/checkpoints/pretrain_epoch_10_lp_epoch_50.pt", "/fsx/clap_logs/2022_09_15-02_05_29-linear_probemodel_PANN-14-lr_0.0001-b_512-j_4-p_fp32/checkpoints/pretrain_epoch_10_lp_epoch_100.pt", "text_branch.resblocks") + +# a_sum: tensor(12111.0244) +# b_sum: tensor(12111.0244) +# diff: tensor(0.) 
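The commented logs above and below record runs of `check_ckpt_diff(...)`: between two linear-probe checkpoints the frozen text branch comes out identical (`diff: tensor(0.)`), while the audio tower below only drifts in its BatchNorm running statistics. The repo's own `check_ckpt_diff` implementation is not shown in this patch; the snippet below is a minimal, hypothetical sketch of an equivalent check, assuming the checkpoints either store their weights under a `"state_dict"` key (as the `module.`-prefixed names suggest) or are raw state dicts.

```python
import torch


def _load_state_dict(path):
    ckpt = torch.load(path, map_location="cpu")
    # fall back to a raw state dict if there is no "state_dict" wrapper (assumption)
    return ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt


def compare_checkpoints(path_a, path_b, key_filter=""):
    """Print every shared key whose tensors differ, then compare the summed
    parameter mass of the keys containing `key_filter` (the a_sum/b_sum/diff
    lines in the logs)."""
    a, b = _load_state_dict(path_a), _load_state_dict(path_b)
    for k in sorted(a.keys() & b.keys()):
        if not torch.equal(a[k], b[k]):
            print(f"key {k} is different")
    a_sum = sum(v.float().sum() for k, v in a.items() if key_filter in k)
    b_sum = sum(v.float().sum() for k, v in b.items() if key_filter in k)
    print("a_sum:", a_sum)
    print("b_sum:", b_sum)
    print("diff:", a_sum - b_sum)
    return bool(a_sum != b_sum)
```

Called as `compare_checkpoints(ckpt_a, ckpt_b, "text_branch.resblocks")`, a helper along these lines should reproduce the kind of summary shown in these logs.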
+ +# linear_prob_audio +# check_ckpt_diff("/fsx/clap_logs/2022_09_15-02_05_29-linear_probemodel_PANN-14-lr_0.0001-b_512-j_4-p_fp32/checkpoints/pretrain_epoch_10_lp_epoch_50.pt", "/fsx/clap_logs/2022_09_15-02_05_29-linear_probemodel_PANN-14-lr_0.0001-b_512-j_4-p_fp32/checkpoints/pretrain_epoch_10_lp_epoch_100.pt", "clap_model") + +# key clap_model.audio_branch.bn0.num_batches_tracked is different +# key clap_model.audio_branch.conv_block1.bn1.running_mean is different +# key clap_model.audio_branch.conv_block1.bn1.running_var is different +# key clap_model.audio_branch.conv_block1.bn1.num_batches_tracked is different +# key clap_model.audio_branch.conv_block1.bn2.running_mean is different +# key clap_model.audio_branch.conv_block1.bn2.running_var is different +# key clap_model.audio_branch.conv_block1.bn2.num_batches_tracked is different +# key clap_model.audio_branch.conv_block2.bn1.running_mean is different +# key clap_model.audio_branch.conv_block2.bn1.running_var is different +# key clap_model.audio_branch.conv_block2.bn1.num_batches_tracked is different +# key clap_model.audio_branch.conv_block2.bn2.running_mean is different +# key clap_model.audio_branch.conv_block2.bn2.running_var is different +# key clap_model.audio_branch.conv_block2.bn2.num_batches_tracked is different +# key clap_model.audio_branch.conv_block3.bn1.running_mean is different +# key clap_model.audio_branch.conv_block3.bn1.running_var is different +# key clap_model.audio_branch.conv_block3.bn1.num_batches_tracked is different +# key clap_model.audio_branch.conv_block3.bn2.running_mean is different +# key clap_model.audio_branch.conv_block3.bn2.running_var is different +# key clap_model.audio_branch.conv_block3.bn2.num_batches_tracked is different +# key clap_model.audio_branch.conv_block4.bn1.running_mean is different +# key clap_model.audio_branch.conv_block4.bn1.running_var is different +# key clap_model.audio_branch.conv_block4.bn1.num_batches_tracked is different +# key clap_model.audio_branch.conv_block4.bn2.running_mean is different +# key clap_model.audio_branch.conv_block4.bn2.running_var is different +# key clap_model.audio_branch.conv_block4.bn2.num_batches_tracked is different +# key clap_model.audio_branch.conv_block5.bn1.running_mean is different +# key clap_model.audio_branch.conv_block5.bn1.running_var is different +# key clap_model.audio_branch.conv_block5.bn1.num_batches_tracked is different +# key clap_model.audio_branch.conv_block5.bn2.running_mean is different +# key clap_model.audio_branch.conv_block5.bn2.running_var is different +# key clap_model.audio_branch.conv_block5.bn2.num_batches_tracked is different +# key clap_model.audio_branch.conv_block6.bn1.running_mean is different +# key clap_model.audio_branch.conv_block6.bn1.running_var is different +# key clap_model.audio_branch.conv_block6.bn1.num_batches_tracked is different +# key clap_model.audio_branch.conv_block6.bn2.running_mean is different +# key clap_model.audio_branch.conv_block6.bn2.running_var is different +# key clap_model.audio_branch.conv_block6.bn2.num_batches_tracked is different +# a_sum: tensor(120061.5078) +# b_sum: tensor(122656.0469) +# diff: tensor(-2594.5391) +# True + diff --git a/src/tests/check_tars.py b/src/tests/check_tars.py new file mode 100644 index 0000000000000000000000000000000000000000..8dcf1c120dcc316f4f3166f57d448e64eaf2dbdd --- /dev/null +++ b/src/tests/check_tars.py @@ -0,0 +1,120 @@ +import webdataset as wds +import soundfile as sf +import io +import os +import random +import copy +from tqdm import tqdm 
+import shutil +import argparse +import traceback +import logging +import json +from laion_clap import tokenize + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--tar-path", + type=str, + default=None, + help="Path to the directory of numbered .tar shards", + ) + parser.add_argument( + "--start", + type=int, + default=0, + help="first shard index to check (tar-path/<start>.tar)", + ) + parser.add_argument( + "--end", + type=int, + default=99999, + help="one past the last shard index to check", + ) + parser.add_argument( + "--exclude", + nargs='+', + type=int, + default=None, + help="shard indices to skip", + ) + parser.add_argument( + "--batch-size", + type=int, + default=1, + ) + parser.add_argument( + "--order", + default=False, + action='store_true', + help="keep the shard order ascending instead of shuffling", + ) + args = parser.parse_args() + return args + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, issue a warning, and continue.""" + logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.") + return True + +def preprocess( + sample, +): + """ + Preprocess a single sample for the wds dataloader. + """ + audio_ext = "flac" + text_ext = "json" + audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) + json_dict_raw = json.loads(sample[text_ext].decode("utf-8")) + sample["waveform"] = audio_data + texts = json_dict_raw["text"] + if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1: + texts = random.choice(texts) + sample["raw_text"] = texts + sample["text"] = tokenize(texts) + return sample + +if __name__ == "__main__": + args = parse_args() + tar_path = args.tar_path + idx_list = list(range(args.start, args.end)) + if args.exclude is not None: + for x in args.exclude: + idx_list.remove(x) + if not args.order: + random.shuffle(idx_list) + # treat anything not on S3 ("aws" in the path) as a local shard directory + args.local = "aws" not in tar_path + if args.local: + input_shards = [os.path.join(args.tar_path, str(i)+".tar") for i in idx_list] + else: + input_shards = [os.path.join(args.tar_path, str(i)+".tar -") for i in idx_list] + pipeline = [wds.SimpleShardList(input_shards)] + pipeline.extend( + [ + wds.split_by_node, + wds.split_by_worker, + wds.tarfile_to_samples(handler=log_and_continue), + wds.map(preprocess), + wds.to_tuple("__url__", "__key__", "waveform"), + wds.batched(1), + ] + ) + dataset = wds.DataPipeline(*pipeline) + dataloader = wds.WebLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=0) + old_k = 0 + old_batch = None + try: + for k, batch in tqdm(enumerate(dataloader)): + print("k:", k) + print("batch:", batch) + old_k = k + old_batch = copy.deepcopy(batch) + except Exception: + with open("check_tar_log.txt", "a") as file: + traceback.print_exc(file=file) + print("old_k:", old_k) + print("old_batch:", old_batch) + pass diff --git a/src/tests/data_loader_test.py b/src/tests/data_loader_test.py new file mode 100644 index 0000000000000000000000000000000000000000..03be75e7ce16723053eb3d506b59f91160a7f3c7 --- /dev/null +++ b/src/tests/data_loader_test.py @@ -0,0 +1,60 @@ +from laion_clap import create_model +from laion_clap.training.data import get_data +from laion_clap.training import parse_args +import torch +import os +from tqdm import tqdm +from laion_clap.training.distributed import is_master, world_info_from_env +from laion_clap.utils import dataset_split + + +def run_dataloader(): + for i, batch in enumerate(tqdm(dataloader, total=data["train"].dataloader.num_samples // args.batch_size)): + pass + + +if __name__ == '__main__': + + args = parse_args() + # sanitize model name for filesystem / uri
use, easier if we don't use / in name as a rule? + args.amodel = args.amodel.replace("/", "-") + device = torch.device('cpu') + + # discover initial world args early so we can log properly + args.distributed = False + args.local_rank, args.rank, args.world_size = world_info_from_env() + + if args.remotedata and is_master(args): + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + + model, model_cfg = create_model( + args.amodel, + args.tmodel, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=True, + pretrained_audio=args.pretrained_audio, + pretrained_text=args.pretrained_text, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type + ) + + data = get_data(args, model_cfg) + + dataloader, sampler = data["train"].dataloader, data["train"].sampler + + print('dataset size:', data["train"].dataloader.num_samples) + print('batch size:', args.batch_size) + print('num batches:', data["train"].dataloader.num_samples // args.batch_size) + + run_dataloader()
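`check_tars.py` above walks numbered shards (`<tar-path>/<i>.tar`) and logs the last good batch when a sample fails to decode. For reference, the snippet below is a minimal, hypothetical smoke test of the per-sample layout its `preprocess()` assumes: a FLAC payload plus a JSON sidecar whose `"text"` field holds one caption or a list of captions. The shard path is a placeholder.

```python
# Example invocation of the checker (flags as defined in its parse_args; the
# shard directory is a placeholder):
#   python src/tests/check_tars.py --tar-path /path/to/shards --start 0 --end 10 --order
import io
import json

import soundfile as sf
import webdataset as wds

ds = wds.WebDataset("/path/to/shards/0.tar")  # placeholder local shard
for sample in ds:
    waveform, sr = sf.read(io.BytesIO(sample["flac"]))  # audio payload
    meta = json.loads(sample["json"].decode("utf-8"))   # caption sidecar
    print(sample["__key__"], waveform.shape, sr, meta["text"])
    break  # one decoded sample is enough for a smoke test
```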
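`data_loader_test.py` only verifies that the training dataloader can be iterated end to end; `run_dataloader()` discards every batch. If batch contents are also of interest, a small variant along these lines (a hypothetical helper that reuses the module-level `dataloader` and makes no assumption about the batch structure) prints the first few batches' shapes or types:

```python
def peek_dataloader(max_batches=3):
    """Print shapes/types of the first few batches instead of discarding them."""
    for i, batch in enumerate(dataloader):
        if isinstance(batch, dict):
            print(i, {k: getattr(v, "shape", type(v)) for k, v in batch.items()})
        else:
            print(i, type(batch))
        if i + 1 >= max_batches:
            break
```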