diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7101e7c8dd35d88daa7aeb845f35f584f0ad3869
--- /dev/null
+++ b/README.md
@@ -0,0 +1,186 @@
+---
+title: SongFormer
+emoji: 🎵
+colorFrom: blue
+colorTo: indigo
+sdk: gradio
+python_version: "3.10"
+app_file: app.py
+tags:
+ - music-structure-annotation
+ - transformer
+short_description: State-of-the-art music analysis with multi-scale datasets
+fullWidth: true
+---
+
+
+
+
+
+
+# SONGFORMER: SCALING MUSIC STRUCTURE ANALYSIS WITH HETEROGENEOUS SUPERVISION
+
+
+
+[]()
+[](https://github.com/ASLP-lab/SongFormer)
+[](https://huggingface.co/spaces/ASLP-lab/SongFormer)
+[](https://huggingface.co/ASLP-lab/SongFormer)
+[](https://huggingface.co/datasets/ASLP-lab/SongFormDB)
+[](https://huggingface.co/datasets/ASLP-lab/SongFormBench)
+[](https://discord.gg/rwcqh7Em)
+[](http://www.npu-aslp.org/)
+
+Chunbo Hao*, Ruibin Yuan*, Jixun Yao, Qixin Deng, Xinyi Bai, Wei Xue, Lei Xie†
+
+
+----
+
+
+SongFormer is a music structure analysis framework that leverages multi-resolution self-supervised representations and heterogeneous supervision, accompanied by the large-scale multilingual dataset SongFormDB and the high-quality benchmark SongFormBench to foster fair and reproducible research.
+
+
+
+## News and Updates
+
+## 📋 To-Do List
+
+- [x] Complete and push inference code to GitHub
+- [x] Upload model checkpoint(s) to Hugging Face Hub
+- [ ] Upload the paper to arXiv
+- [x] Fix readme
+- [ ] Deploy an out-of-the-box inference version on Hugging Face (via Inference API or Spaces)
+- [ ] Publish the package to PyPI for easy installation via `pip`
+- [ ] Open-source evaluation code
+- [ ] Open-source training code
+
+## Installation
+
+### Setting up Python Environment
+
+```bash
+git clone https://github.com/ASLP-lab/SongFormer.git
+
+# Get MuQ and MusicFM source code
+git submodule update --init --recursive
+
+conda create -n songformer python=3.10 -y
+conda activate songformer
+```
+
+For users in mainland China, you may need to set up pip mirror source:
+
+```bash
+pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple
+```
+
+Install dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+We tested this on Ubuntu 22.04.1 LTS and it works normally. If you cannot install, you may need to remove version constraints in `requirements.txt`
+
+### Download Pre-trained Models
+
+```bash
+cd src/SongFormer
+# For users in mainland China, you can modify according to the py file instructions to use hf-mirror.com for downloading
+python utils/fetch_pretrained.py
+```
+
+After downloading, you can verify the md5sum values in `src/SongFormer/ckpts/MusicFM/md5sum.txt` match the downloaded files:
+
+```bash
+md5sum ckpts/MusicFM/msd_stats.json
+md5sum ckpts/MusicFM/pretrained_msd.pt
+md5sum ckpts/SongFormer.safetensors
+# md5sum ckpts/SongFormer.pt
+```
+
+## Inference
+
+## Inference
+
+### 1. One-Click Inference with HuggingFace Space (coming soon)
+
+Available at: [https://huggingface.co/spaces/ASLP-lab/SongFormer](https://huggingface.co/spaces/ASLP-lab/SongFormer)
+
+### 2. Gradio App
+
+First, cd to the project root directory and activate the environment:
+
+```bash
+conda activate songformer
+```
+
+You can modify the server port and listening address in the last line of `app.py` according to your preference.
+
+> If you're using an HTTP proxy, please ensure you include:
+>
+> ```bash
+> export no_proxy="localhost, 127.0.0.1, ::1"
+> export NO_PROXY="localhost, 127.0.0.1, ::1"
+> ```
+>
+> Otherwise, Gradio may incorrectly assume the service hasn't started, causing startup to exit directly.
+
+When first running `app.py`, it will connect to Hugging Face to download MuQ-related weights. We recommend creating an empty folder in an appropriate location and using `export HF_HOME=XXX` to point to this folder, so cache will be stored there for easy cleanup and transfer.
+
+And for users in mainland China, you may need `export HF_ENDPOINT=https://hf-mirror.com`. For details, refer to https://hf-mirror.com/
+
+```bash
+python app.py
+```
+
+### 3. Python Code
+
+You can refer to the file `src/SongFormer/infer/infer.py`. The corresponding execution script is located at `src/SongFormer/infer.sh`. This is a ready-to-use, single-machine, multi-process annotation script.
+
+Below are some configurable parameters from the `src/SongFormer/infer.sh` script. You can set `CUDA_VISIBLE_DEVICES` to specify which GPUs to use:
+
+```bash
+-i # Input SCP folder path, each line containing the absolute path to one audio file
+-o # Output directory for annotation results
+--model # Annotation model; the default is 'SongFormer', change if using a fine-tuned model
+--checkpoint # Path to the model checkpoint file
+--config_pat # Path to the configuration file
+-gn # Total number of GPUs to use — should match the number specified in CUDA_VISIBLE_DEVICES
+-tn # Number of processes to run per GPU
+```
+
+You can control which GPUs are used by setting the `CUDA_VISIBLE_DEVICES` environment variable.
+
+### 4. CLI Inference
+
+Coming soon
+
+### 4. Pitfall
+
+- You may need to modify line 121 in `src/third_party/musicfm/model/musicfm_25hz.py` to:
+`S = torch.load(model_path, weights_only=False)["state_dict"]`
+
+## Training
+
+## Citation
+
+If our work and codebase is useful for you, please cite as:
+
+````
+comming soon
+````
+## License
+
+Our code is released under CC-BY-4.0 License.
+
+## Contact Us
+
+
+
+
+
+
+
+
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd13701ee78b138689f55e7f0d3bd29e2145a2b4
--- /dev/null
+++ b/app.py
@@ -0,0 +1,636 @@
+# import os
+# import sys
+
+# os.chdir(os.path.join("src", "SongFormer"))
+# sys.path.append(os.path.join("..", "third_party"))
+# sys.path.append(".")
+
+import os
+import sys
+# 获取当前文件的绝对路径和脚本名称
+current_file = os.path.abspath(__file__)
+current_dir = os.path.dirname(current_file)
+script_name = os.path.basename(__file__)
+print(f"[INFO] 正在运行脚本:{script_name}")
+print(f"[INFO] 当前文件所在目录为:{current_dir}")
+# 设置工作目录为 `src/SongFormer`(如果该路径存在)
+songformer_path = os.path.join(current_dir, "src", "SongFormer")
+if os.path.exists(songformer_path):
+ os.chdir(songformer_path)
+ print(f"[INFO] 工作目录已修改为:{songformer_path}")
+else:
+ print(f"[WARNING] 目标工作目录不存在:{songformer_path}")
+# 获取当前工作目录,即运行 os.chdir 后的路径
+working_dir = os.getcwd()
+print(f"[INFO] 当前工作目录为:{working_dir}")
+# 添加第三方库路径到 sys.path(third_party)
+third_party_path = os.path.join(current_dir, "third_party")
+if os.path.exists(third_party_path):
+ sys.path.insert(0, third_party_path)
+ print(f"[INFO] 已添加第三方库路径到 sys.path:{third_party_path}")
+else:
+ print(f"[WARNING] third_party 路径不存在:{third_party_path}")
+# 添加当前工作目录到 sys.path(通常是 src/SongFormer)
+sys.path.insert(0, working_dir)
+print(f"[INFO] 已添加当前工作目录到 sys.path:{working_dir}")
+# 尝试添加多个可能用于 musicfm 导入的路径
+musicfm_paths = [
+ os.path.join(current_dir, "src"),
+ os.path.join(current_dir, "third_party"),
+ os.path.join(current_dir, "src", "SongFormer"),
+]
+for path in musicfm_paths:
+ if os.path.exists(path):
+ sys.path.insert(0, path)
+ print(f"[INFO] 已添加路径到 sys.path:{path}")
+ else:
+ print(f"[DEBUG] 路径不存在,跳过添加:{path}")
+# 可选:打印 sys.path 的当前状态
+print("\n[DEBUG] 当前 sys.path 设置如下:")
+for idx, p in enumerate(sys.path):
+ print(f" {idx}: {p}")
+
+# monkey patch to fix issues in msaf
+import scipy
+import numpy as np
+
+scipy.inf = np.inf
+
+import gradio as gr
+import torch
+import librosa
+import json
+import math
+import importlib
+import matplotlib.pyplot as plt
+import matplotlib.ticker as ticker
+from pathlib import Path
+from argparse import Namespace
+from omegaconf import OmegaConf
+from ema_pytorch import EMA
+from muq import MuQ
+from musicfm.model.musicfm_25hz import MusicFM25Hz
+from postprocessing.functional import postprocess_functional_structure
+from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, DATASET_LABEL_TO_DATASET_ID
+from utils.fetch_pretrained import download_all
+
+# Constants
+MUSICFM_HOME_PATH = os.path.join("ckpts", "MusicFM")
+BEFORE_DOWNSAMPLING_FRAME_RATES = 25
+AFTER_DOWNSAMPLING_FRAME_RATES = 8.333
+DATASET_LABEL = "SongForm-HX-8Class"
+DATASET_IDS = [5]
+TIME_DUR = 420
+INPUT_SAMPLING_RATE = 24000
+
+# Global model variables
+muq_model = None
+musicfm_model = None
+msa_model = None
+device = None
+
+
+def load_checkpoint(checkpoint_path, device=None):
+ """Load checkpoint from path"""
+ if device is None:
+ device = "cpu"
+
+ if checkpoint_path.endswith(".pt"):
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+ elif checkpoint_path.endswith(".safetensors"):
+ from safetensors.torch import load_file
+
+ checkpoint = {"model_ema": load_file(checkpoint_path, device=device)}
+ else:
+ raise ValueError("Unsupported checkpoint format. Use .pt or .safetensors")
+ return checkpoint
+
+
+def initialize_models(model_name: str, checkpoint: str, config_path: str):
+ """Initialize all models"""
+ global muq_model, musicfm_model, msa_model, device
+
+ # Set device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load MuQ
+ muq_model = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
+ muq_model = muq_model.to(device).eval()
+
+ # Load MusicFM
+ musicfm_model = MusicFM25Hz(
+ is_flash=False,
+ stat_path=os.path.join(MUSICFM_HOME_PATH, "msd_stats.json"),
+ model_path=os.path.join(MUSICFM_HOME_PATH, "pretrained_msd.pt"),
+ )
+ musicfm_model = musicfm_model.to(device).eval()
+
+ # Load MSA model
+ module = importlib.import_module("models." + str(model_name))
+ Model = getattr(module, "Model")
+ hp = OmegaConf.load(os.path.join("configs", config_path))
+ msa_model = Model(hp)
+
+ ckpt = load_checkpoint(checkpoint_path=os.path.join("ckpts", checkpoint))
+ if ckpt.get("model_ema", None) is not None:
+ model_ema = EMA(msa_model, include_online_model=False)
+ model_ema.load_state_dict(ckpt["model_ema"])
+ msa_model.load_state_dict(model_ema.ema_model.state_dict())
+ else:
+ msa_model.load_state_dict(ckpt["model"])
+
+ msa_model.to(device).eval()
+
+ return hp
+
+
+def process_audio(audio_path, win_size=420, hop_size=420, num_classes=128):
+ """Process audio file and return structure analysis results"""
+ global muq_model, musicfm_model, msa_model, device
+
+ if muq_model is None:
+ hp = initialize_models()
+ else:
+ hp = OmegaConf.load(os.path.join("configs", "SongFormer.yaml"))
+
+ # Load audio
+ wav, sr = librosa.load(audio_path, sr=INPUT_SAMPLING_RATE)
+ audio = torch.tensor(wav).to(device)
+
+ # Prepare output
+ total_len = (
+ (audio.shape[0] // INPUT_SAMPLING_RATE) // TIME_DUR * TIME_DUR
+ ) + TIME_DUR
+ total_frames = math.ceil(total_len * AFTER_DOWNSAMPLING_FRAME_RATES)
+
+ logits = {
+ "function_logits": np.zeros([total_frames, num_classes]),
+ "boundary_logits": np.zeros([total_frames]),
+ }
+ logits_num = {
+ "function_logits": np.zeros([total_frames, num_classes]),
+ "boundary_logits": np.zeros([total_frames]),
+ }
+
+ # Prepare label masks
+ dataset_id2label_mask = {}
+ for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
+ dataset_id2label_mask[key] = np.ones(num_classes, dtype=bool)
+ dataset_id2label_mask[key][allowed_ids] = False
+
+ lens = 0
+ i = 0
+
+ with torch.no_grad():
+ while True:
+ start_idx = i * INPUT_SAMPLING_RATE
+ end_idx = min((i + win_size) * INPUT_SAMPLING_RATE, audio.shape[-1])
+ if start_idx >= audio.shape[-1]:
+ break
+ if end_idx - start_idx <= 1024:
+ continue
+
+ audio_seg = audio[start_idx:end_idx]
+
+ # Get embeddings
+ muq_output = muq_model(audio_seg.unsqueeze(0), output_hidden_states=True)
+ muq_embd_420s = muq_output["hidden_states"][10]
+ del muq_output
+ torch.cuda.empty_cache()
+
+ _, musicfm_hidden_states = musicfm_model.get_predictions(
+ audio_seg.unsqueeze(0)
+ )
+ musicfm_embd_420s = musicfm_hidden_states[10]
+ del musicfm_hidden_states
+ torch.cuda.empty_cache()
+
+ # Process 30-second segments
+ wraped_muq_embd_30s = []
+ wraped_musicfm_embd_30s = []
+
+ for idx_30s in range(i, i + hop_size, 30):
+ start_idx_30s = idx_30s * INPUT_SAMPLING_RATE
+ end_idx_30s = min(
+ (idx_30s + 30) * INPUT_SAMPLING_RATE,
+ audio.shape[-1],
+ (i + hop_size) * INPUT_SAMPLING_RATE,
+ )
+ if start_idx_30s >= audio.shape[-1]:
+ break
+ if end_idx_30s - start_idx_30s <= 1024:
+ continue
+
+ wraped_muq_embd_30s.append(
+ muq_model(
+ audio[start_idx_30s:end_idx_30s].unsqueeze(0),
+ output_hidden_states=True,
+ )["hidden_states"][10]
+ )
+ torch.cuda.empty_cache()
+
+ wraped_musicfm_embd_30s.append(
+ musicfm_model.get_predictions(
+ audio[start_idx_30s:end_idx_30s].unsqueeze(0)
+ )[1][10]
+ )
+ torch.cuda.empty_cache()
+
+ if wraped_muq_embd_30s:
+ wraped_muq_embd_30s = torch.concatenate(wraped_muq_embd_30s, dim=1)
+ wraped_musicfm_embd_30s = torch.concatenate(
+ wraped_musicfm_embd_30s, dim=1
+ )
+
+ all_embds = [
+ wraped_musicfm_embd_30s,
+ wraped_muq_embd_30s,
+ musicfm_embd_420s,
+ muq_embd_420s,
+ ]
+
+ # Align embedding lengths
+ if len(all_embds) > 1:
+ embd_lens = [x.shape[1] for x in all_embds]
+ min_embd_len = min(embd_lens)
+ for idx in range(len(all_embds)):
+ all_embds[idx] = all_embds[idx][:, :min_embd_len, :]
+
+ embd = torch.concatenate(all_embds, axis=-1)
+
+ # Inference
+ dataset_ids = torch.Tensor(DATASET_IDS).to(device, dtype=torch.long)
+ msa_info, chunk_logits = msa_model.infer(
+ input_embeddings=embd,
+ dataset_ids=dataset_ids,
+ label_id_masks=torch.Tensor(
+ dataset_id2label_mask[
+ DATASET_LABEL_TO_DATASET_ID[DATASET_LABEL]
+ ]
+ )
+ .to(device, dtype=bool)
+ .unsqueeze(0)
+ .unsqueeze(0),
+ with_logits=True,
+ )
+
+ # Accumulate logits
+ start_frame = int(i * AFTER_DOWNSAMPLING_FRAME_RATES)
+ end_frame = start_frame + min(
+ math.ceil(hop_size * AFTER_DOWNSAMPLING_FRAME_RATES),
+ chunk_logits["boundary_logits"][0].shape[0],
+ )
+
+ logits["function_logits"][start_frame:end_frame, :] += (
+ chunk_logits["function_logits"][0].detach().cpu().numpy()
+ )
+ logits["boundary_logits"][start_frame:end_frame] = (
+ chunk_logits["boundary_logits"][0].detach().cpu().numpy()
+ )
+ logits_num["function_logits"][start_frame:end_frame, :] += 1
+ logits_num["boundary_logits"][start_frame:end_frame] += 1
+ lens += end_frame - start_frame
+
+ i += hop_size
+
+ # Average logits
+ logits["function_logits"] /= np.maximum(logits_num["function_logits"], 1)
+ logits["boundary_logits"] /= np.maximum(logits_num["boundary_logits"], 1)
+
+ logits["function_logits"] = torch.from_numpy(
+ logits["function_logits"][:lens]
+ ).unsqueeze(0)
+ logits["boundary_logits"] = torch.from_numpy(
+ logits["boundary_logits"][:lens]
+ ).unsqueeze(0)
+
+ # Post-process
+ msa_infer_output = postprocess_functional_structure(logits, hp)
+
+ return logits, msa_infer_output
+
+
+def format_as_segments(msa_output):
+ """Format as list of segments"""
+ segments = []
+ for idx in range(len(msa_output) - 1):
+ segments.append(
+ {
+ "start": str(round(msa_output[idx][0], 2)),
+ "end": str(round(msa_output[idx + 1][0], 2)),
+ "label": msa_output[idx][1],
+ }
+ )
+ return segments
+
+
+def format_as_msa(msa_output):
+ """Format as MSA format"""
+ lines = []
+ for time, label in msa_output:
+ lines.append(f"{time:.2f} {label}")
+ return "\n".join(lines)
+
+
+def format_as_json(segments):
+ """Format as JSON"""
+ return json.dumps(segments, indent=2, ensure_ascii=False)
+
+
+def create_visualization(
+ logits, msa_output, label_num=8, frame_rates=AFTER_DOWNSAMPLING_FRAME_RATES
+):
+ """Create visualization plot"""
+ # Assume ID_TO_LABEL mapping exists
+ try:
+ from dataset.label2id import ID_TO_LABEL
+ except:
+ ID_TO_LABEL = {i: f"Class_{i}" for i in range(128)}
+
+ function_vals = logits["function_logits"].squeeze().cpu().numpy()
+ boundary_vals = logits["boundary_logits"].squeeze().cpu().numpy()
+
+ top_classes = np.argsort(function_vals.mean(axis=0))[-label_num:]
+ T = function_vals.shape[0]
+ time_axis = np.arange(T) / frame_rates
+
+ fig, ax = plt.subplots(2, 1, figsize=(15, 8), sharex=True)
+
+ # Plot function logits
+ for cls in top_classes:
+ ax[1].plot(
+ time_axis,
+ function_vals[:, cls],
+ label=f"{ID_TO_LABEL.get(cls, f'Class_{cls}')}",
+ )
+
+ ax[1].set_title("Top 8 Function Logits by Mean Activation")
+ ax[1].set_xlabel("Time (seconds)")
+ ax[1].set_ylabel("Logit")
+ ax[1].xaxis.set_major_locator(ticker.MultipleLocator(20))
+ ax[1].xaxis.set_minor_locator(ticker.MultipleLocator(5))
+ ax[1].xaxis.set_major_formatter(ticker.FormatStrFormatter("%.1f"))
+ ax[1].legend()
+ ax[1].grid(True)
+
+ # Plot boundary logits
+ ax[0].plot(time_axis, boundary_vals, label="Boundary Logit", color="orange")
+ ax[0].set_title("Boundary Logits")
+ ax[0].set_ylabel("Logit")
+ ax[0].legend()
+ ax[0].grid(True)
+
+ # Add vertical lines for markers
+ for t_sec, label in msa_output:
+ for a in ax:
+ a.axvline(x=t_sec, color="red", linestyle="--", linewidth=0.8, alpha=0.7)
+ if label != "end":
+ ax[1].text(
+ t_sec + 0.3,
+ ax[1].get_ylim()[1] * 0.85,
+ label,
+ rotation=90,
+ fontsize=8,
+ color="red",
+ )
+
+ plt.suptitle("Music Structure Analysis - Logits Overview", fontsize=16)
+ plt.tight_layout()
+
+ return fig
+
+
+def rule_post_processing(msa_list):
+ if len(msa_list) <= 2:
+ return msa_list
+
+ result = msa_list.copy()
+
+ while len(result) > 2:
+ first_duration = result[1][0] - result[0][0]
+ if first_duration < 1.0 and len(result) > 2:
+ result[0] = (result[0][0], result[1][1])
+ result = [result[0]] + result[2:]
+ else:
+ break
+
+ while len(result) > 2:
+ last_label_duration = result[-1][0] - result[-2][0]
+ if last_label_duration < 1.0:
+ result = result[:-2] + [result[-1]]
+ else:
+ break
+
+ while len(result) > 2:
+ if result[0][1] == result[1][1] and result[1][0] <= 10.0:
+ result = [(result[0][0], result[0][1])] + result[2:]
+ else:
+ break
+
+ while len(result) > 2:
+ last_duration = result[-1][0] - result[-2][0]
+ if result[-2][1] == result[-3][1] and last_duration <= 10.0:
+ result = result[:-2] + [result[-1]]
+ else:
+ break
+
+ return result
+
+
+def process_and_analyze(audio_file):
+ """Main processing function"""
+
+ def format_time(t: float) -> str:
+ minutes = int(t // 60)
+ seconds = t % 60
+ return f"{minutes:02d}:{seconds:06.3f}" # 这个格式是正确的
+
+ if audio_file is None:
+ return None, "", "", None
+
+ try:
+ # Process audio
+ logits, msa_output = process_audio(audio_file)
+ # Apply rule-based post-processing, if not needed, use in cli infer
+ msa_output = rule_post_processing(msa_output)
+ # Format outputs
+ segments = format_as_segments(msa_output)
+ msa_format = format_as_msa(msa_output)
+ json_format = format_as_json(segments)
+
+ # Create table data
+ table_data = [
+ [
+ f"{float(seg['start']):.2f} ({format_time(float(seg['start']))})",
+ f"{float(seg['end']):.2f} ({format_time(float(seg['end']))})",
+ seg["label"],
+ ]
+ for seg in segments
+ ]
+
+ # Create visualization
+ fig = create_visualization(logits, msa_output)
+
+ return table_data, json_format, msa_format, fig
+
+ except Exception as e:
+ import traceback
+
+ error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
+ print(error_msg) # 在命令行输出完整错误
+ return None, "", error_msg, None
+
+
+# Create Gradio interface
+with gr.Blocks(
+ title="Music Structure Analysis",
+ css="""
+ .logo-container {
+ text-align: center;
+ margin-bottom: 20px;
+ }
+ .links-container {
+ display: flex;
+ justify-content: center;
+ column-gap: 10px;
+ margin-bottom: 10px;
+ }
+ .model-title {
+ text-align: center;
+ font-size: 24px;
+ font-weight: bold;
+ margin-bottom: 30px;
+ }
+ """,
+) as demo:
+ # Top Logo
+ gr.HTML("""
+
+

+
+ """)
+
+ # Model title
+ gr.HTML("""
+
+ SongFormer: Scaling Music Structure Analysis with Heterogeneous Supervision
+
+ """)
+
+ # Links
+ gr.HTML("""
+
+ """)
+
+ # Main input area
+ with gr.Row():
+ with gr.Column(scale=3):
+ audio_input = gr.Audio(
+ label="Upload Audio File", type="filepath", elem_id="audio-input"
+ )
+
+ with gr.Column(scale=1):
+ gr.Markdown("### 📌 Examples")
+ gr.Examples(
+ examples=[
+ # Add your example audio file paths
+ # ["example1.mp3"],
+ # ["example2.mp3"],
+ ],
+ inputs=[audio_input],
+ label="Click to load example",
+ )
+
+ # Analyze button
+ with gr.Row():
+ analyze_btn = gr.Button(
+ "🚀 Analyze Music Structure", variant="primary", scale=1
+ )
+
+ # Results display area
+ with gr.Row():
+ with gr.Column(scale=13):
+ segments_table = gr.Dataframe(
+ headers=["Start / s (m:s.ms)", "End / s (m:s.ms)", "Label"],
+ label="Detected Music Segments",
+ interactive=False,
+ elem_id="result-table",
+ )
+ with gr.Column(scale=8):
+ with gr.Row():
+ with gr.Accordion("📄 JSON Output", open=False):
+ json_output = gr.Textbox(
+ label="JSON Format",
+ lines=15,
+ max_lines=20,
+ interactive=False,
+ show_copy_button=True,
+ )
+ with gr.Row():
+ with gr.Accordion("📋 MSA Text Output", open=False):
+ msa_output = gr.Textbox(
+ label="MSA Format",
+ lines=15,
+ max_lines=20,
+ interactive=False,
+ show_copy_button=True,
+ )
+
+ # Visualization plot
+ with gr.Row():
+ plot_output = gr.Plot(label="Activation Curves Visualization")
+
+ gr.HTML("""
+
+

+
+ """)
+
+ # Set event handlers
+ analyze_btn.click(
+ fn=process_and_analyze,
+ inputs=[audio_input],
+ outputs=[segments_table, json_output, msa_output, plot_output],
+ )
+
+if __name__ == "__main__":
+ # Download pretrained models if not exist
+ download_all(use_mirror=False)
+ # Initialize models
+ print("Initializing models...")
+ initialize_models(
+ model_name="SongFormer",
+ checkpoint="SongFormer.safetensors",
+ config_path="SongFormer.yaml",
+ )
+ print("Models loaded successfully!")
+
+ # Launch interface
+ demo.launch(server_name="127.0.0.1", server_port=7891, debug=True)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9a4bb60b8252e6fcb8e14632e06cd38fc8e19027
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,86 @@
+# Core Deep Learning Framework
+torch==2.4.0
+torchaudio==2.4.0
+lightning==2.5.1.post0
+
+# ML/DL Libraries
+transformers==4.51.1
+accelerate==1.5.2
+datasets==3.6.0
+tokenizers==0.21.1
+huggingface-hub==0.30.1
+safetensors==0.5.3
+
+# Scientific Computing
+numpy==1.25.0
+scipy==1.15.2
+scikit-learn==1.6.1
+pandas==2.2.3
+
+# Audio Processing
+librosa==0.11.0
+audioread==3.0.1
+soundfile==0.13.1
+pesq==0.0.4
+auraloss==0.4.0
+nnAudio==0.3.3
+julius==0.2.7
+soxr==0.5.0.post1
+mir_eval==0.8.2
+jams==0.3.4
+msaf==0.1.80
+
+# Visualization & Monitoring
+matplotlib==3.10.1
+seaborn==0.13.2
+tensorboard==2.19.0
+wandb==0.19.8
+gpustat==1.1.1
+
+# Configuration & CLI
+hydra-core==1.3.2
+omegaconf==2.3.0
+fire==0.7.1
+click==8.1.8
+
+# Deep Learning Utils
+einops==0.8.1
+einx==0.3.0
+x-transformers==2.4.14
+x-clip==0.14.4
+ema-pytorch==0.7.7
+schedulefree==1.4.1
+torchmetrics==1.7.1
+
+# Data Processing
+h5py==3.13.0
+pyarrow==19.0.1
+pillow==11.1.0
+
+# Text Processing
+ftfy==6.3.1
+regex==2024.11.6
+pypinyin==0.54.0
+textgrid==1.6.1
+pylrc==0.1.2
+
+# Model Management
+modelscope==1.27.1
+
+# Utilities
+tqdm==4.67.1
+loguru==0.7.3
+joblib==1.4.2
+easydict==1.13
+addict==2.4.0
+beartype==0.21.0
+
+# Others
+triton==3.0.0
+muq==0.1.0
+vmo==0.30.5
+
+# others
+gradio
+einops
+beartype
\ No newline at end of file
diff --git a/src/SongFormer/ckpts/md5sum.txt b/src/SongFormer/ckpts/md5sum.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4f22671d6756d8dc184c9b9561cdd13fbd6be7c4
--- /dev/null
+++ b/src/SongFormer/ckpts/md5sum.txt
@@ -0,0 +1,4 @@
+df930aceac8209818556c4a656a0714c MusicFM/pretrained_msd.pt
+75ab2e47b093e07378f7f703bdb82c14 MusicFM/msd_stats.json
+5a24800e12ab357744f8b47e523ba3e6 SongFormer.safetensors
+2c66c0bb91364e318e90dbc2d9a79ee2 _SongFormer.pt
\ No newline at end of file
diff --git a/src/SongFormer/configs/SongFormer.yaml b/src/SongFormer/configs/SongFormer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d13566aeb43e0f96037ed51f1db566a532e07ce1
--- /dev/null
+++ b/src/SongFormer/configs/SongFormer.yaml
@@ -0,0 +1,186 @@
+# ============================
+# Model Configuration
+# ============================
+
+input_dim_raw: 4096 # Downsampled Fused SSL Representation Dimension
+input_dim: 2048 # Input Dimension after Linear Layer
+
+# Downsampling Module
+down_sample_conv_kernel_size: 3
+down_sample_conv_stride: 3
+down_sample_conv_dropout: 0.1
+down_sample_conv_padding: 0
+
+# Transformer Module
+transformer_encoder_input_dim: 1024
+transformer_input_dim: 512
+num_transformer_layers: 4
+transformer_nhead: 8
+transformer_dropout: 0.1
+
+# task-specific heads
+boundary_head_hidden_dims: [128, 64, 8]
+function_head_hidden_dims: []
+
+num_classes: 128
+num_dataset_classes: 64
+
+# scheduler
+warmup_steps: 300
+total_steps: 12010
+warmup_max_lr: 0.0001
+
+# frame rates after downsampling
+output_logits_frame_rates: 8.333
+# it means output_logits_frame_rates = input_embd_frame_rates // downsample_rates, because the padding is 0.
+downsample_rates: 3
+# frame rates after downsampling, used by model and post process
+frame_rates: 8.333
+
+# ema config
+ema_kwargs:
+ {update_after_step: 200}
+
+# ============================
+# Loss Functions configuration
+# ============================
+
+# Focal loss
+label_focal_loss_weight: 0.2
+
+label_focal_loss_alpha: 0.25
+label_focal_loss_gamma: 2.0
+
+# Boundary TV loss
+boundary_tvloss_weight: 0.05
+
+boundary_tv_loss_beta: 0.6
+boundary_tv_loss_lambda: 0.4
+boundary_tv_loss_boundary_threshold: 0.01
+boundary_tv_loss_reduction_weight: 0.1
+
+loss_weight_section: 0.2
+loss_weight_function: 0.8
+
+# ============================
+# Training config
+# ============================
+
+# Number of neighbors used to augment boundaries in the dataset.
+# Example: 1/25*3 * 10s = 1.2s (both sides total 4.2s)
+num_neighbors: 10
+learn_label: true
+learn_segment: true
+accumulation_steps: 2
+slice_dur: 420
+early_stopping_step: 3
+local_maxima_filter_size: 3
+
+# ============================
+# Dataset config
+# ============================
+
+train_dataset:
+ _target_: dataset.SongFormerDataset.Dataset
+ dataset_abstracts:
+ [
+ {
+ "internal_tmp_id": "SongForm-HX-8Class",
+ "dataset_type": "SongForm-HX-8Class",
+ "input_embedding_dir": "your_data_dir/30s_420s/harmonix/musicfm_hop420/layer_10 your_data_dir/30s_420s/harmonix/muq_hop420/layer_10 your_data_dir/420s/harmonix/musicfm_hop420/layer_10 your_data_dir/420s/harmonix/muq_hop420/layer_10",
+ "label_path": "your_data_dir/labels/harmonixset_8class_rule_revision.jsonl",
+ "split_ids_path": "your_data_dir/separated_ids/harmonixset_separated_ids_with_val_set/train.txt",
+ "multiplier": 4,
+ },
+ {
+ "internal_tmp_id": "SongForm-Private",
+ "dataset_type": "SongForm-Private",
+ "input_embedding_dir": "your_data_dir/30s_420s/Internal_data/musicfm_hop420/layer_10 your_data_dir/30s_420s/Internal_data/muq_hop420/layer_10 your_data_dir/420s/Internal_data/musicfm_hop420/layer_10 your_data_dir/420s/Internal_data/muq_hop420/layer_10",
+ "label_path": "your_data_dir/labels/0006_single_layer_transformer_musicfm_muq_along_time_00_5k_v1.jsonl",
+ "split_ids_path": "your_data_dir/separated_ids/internal_data_sofa_clean/train.txt",
+ "multiplier": 1,
+ },
+ {
+ adapter: HookTheoryAdapter,
+ internal_tmp_id: "SongForm-Hook",
+ structure_jsonl_paths: [
+ "your_data_dir/HookTheoryStructure.train.jsonl"
+ ],
+ dataset_type: "SongForm-Hook",
+ input_embedding_dir: "your_data_dir/30s_420s/HookTheory/musicfm_hop420/layer_10 your_data_dir/30s_420s/HookTheory/muq_hop420/layer_10 your_data_dir/420s/HookTheory/musicfm_hop420/layer_10 your_data_dir/420s/HookTheory/muq_hop420/layer_10",
+ split_ids_path: "your_data_dir/separated_ids/hooktheory_separated_ids/train.txt",
+ multiplier: 1,
+ },
+ ]
+ hparams:
+ output_logits_frame_rates: ${output_logits_frame_rates}
+ downsample_rates: ${downsample_rates}
+ num_neighbors: ${num_neighbors}
+ input_dim: ${input_dim_raw}
+ slice_dur: ${slice_dur}
+ num_classes: ${num_classes}
+ frame_rates: ${frame_rates}
+
+eval_dataset:
+ _target_: dataset.SongFormerDataset.Dataset
+ dataset_abstracts:
+ [
+ {
+ "internal_tmp_id": "SongForm-HX-8Classs_val",
+ "dataset_type": "SongForm-HX-8Class",
+ "input_embedding_dir": "your_data_dir/30s_420s/harmonix/musicfm_hop420/layer_10 your_data_dir/30s_420s/harmonix/muq_hop420/layer_10 your_data_dir/420s/harmonix/musicfm_hop420/layer_10 your_data_dir/420s/harmonix/muq_hop420/layer_10",
+ "label_path": "your_data_dir/processed_data/labels/harmonixset_8class_rule_revision.jsonl",
+ "split_ids_path": "your_data_dir/separated_ids/harmonixset_separated_ids_with_val_set/val.txt",
+ "multiplier": 1,
+ },
+ ]
+ hparams:
+ output_logits_frame_rates: ${output_logits_frame_rates}
+ downsample_rates: ${downsample_rates}
+ num_neighbors: ${num_neighbors}
+ input_dim: ${input_dim_raw}
+ slice_dur: ${slice_dur}
+ num_classes: ${num_classes}
+ frame_rates: ${frame_rates}
+
+# ============================
+# DataLoader configuration
+# ============================
+
+train_dataloader:
+ num_workers: 4
+ batch_size: 4
+ pin_memory: True
+ prefetch_factor: 4
+ drop_last: True
+ persistent_workers: True
+ shuffle: true
+
+eval_dataloader:
+ num_workers: 0
+ batch_size: 1
+ shuffle: false
+
+# ============================
+# Optimizer configuration
+# ============================
+
+optimizer:
+ lr: ${warmup_max_lr}
+ betas: [0.8, 0.999]
+ eps: 1e-08
+ weight_decay: 3e-7
+
+# ============================
+# Training Run configuration
+# ============================
+
+args:
+ run_name: SongFormer
+ model_name: SongFormer
+ save_interval: 800
+ eval_interval: 800
+ checkpoint_dir: output/SongFormer
+ max_epochs: 1000
+ max_steps: 12010
+ tags: null
\ No newline at end of file
diff --git a/src/SongFormer/dataset/DatasetAdaper.py b/src/SongFormer/dataset/DatasetAdaper.py
new file mode 100644
index 0000000000000000000000000000000000000000..abc64f597868ca8befbf979b1ec200fe1e071ad7
--- /dev/null
+++ b/src/SongFormer/dataset/DatasetAdaper.py
@@ -0,0 +1,33 @@
+from abc import ABC, abstractmethod
+
+
+class DatasetAdapter(ABC):
+ """
+ Abstract base class for dataset adapters.
+ """
+
+ @abstractmethod
+ def __init__(self, *args, **kwargs):
+ """
+ Initialize the dataset adapter with necessary parameters.
+ """
+ raise NotImplementedError("Subclasses must implement the __init__ method.")
+
+ @abstractmethod
+ def get_ids(self):
+ """
+ Get the IDs of the dataset.
+ This method should be implemented by subclasses.
+
+ Returns:
+ A list or set of IDs representing the dataset. In format: ID + start_time
+ must cosider the split of dataset, e.g. train, val, test.
+ """
+ raise NotImplementedError("Subclasses must implement this method.")
+
+ @abstractmethod
+ def get_item_json(self, *args, **kwargs):
+ """
+ Get the item JSON representation from the dataset.
+ """
+ raise NotImplementedError("Subclasses must implement this method.")
diff --git a/src/SongFormer/dataset/GeminiOnlyLabelAdapter.py b/src/SongFormer/dataset/GeminiOnlyLabelAdapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..81a570ecbc1364db3ee694a503a0a4376265431a
--- /dev/null
+++ b/src/SongFormer/dataset/GeminiOnlyLabelAdapter.py
@@ -0,0 +1,332 @@
+# 1. It was found that the annotations generated by Gemini are discontinuous between segments
+# (possibly differing by more than 1.7 seconds, accounting for approximately 1/4 to 1/3 of the cases).
+# 2. Gemini's labels can compete with our SOTA model, but Gemini's boundary metrics are very poor.
+# With a tolerance of 3 seconds, they are similar to the metrics of our best model.
+import pdb
+import random
+import os
+from collections import defaultdict
+from pathlib import Path
+import json
+from venv import logger
+import numpy as np
+import math
+from .label2id import (
+ DATASET_ID_ALLOWED_LABEL_IDS,
+ DATASET_LABEL_TO_DATASET_ID,
+ ID_TO_LABEL,
+ LABEL_TO_ID,
+)
+from argparse import Namespace
+from scipy.ndimage import gaussian_filter1d
+from .DatasetAdaper import DatasetAdapter
+from omegaconf import ListConfig
+import copy
+
+
+# Adapter for datasets labeled only by Gemini
+class GeminiOnlyLabelAdapter(DatasetAdapter):
+ def __init__(self, **kwargs):
+ (
+ label_paths,
+ hparams,
+ internal_tmp_id,
+ dataset_type,
+ input_embedding_dir,
+ split_ids_path,
+ ) = (
+ kwargs["label_paths"],
+ kwargs["hparams"],
+ kwargs["internal_tmp_id"],
+ kwargs["dataset_type"],
+ kwargs["input_embedding_dir"],
+ kwargs["split_ids_path"],
+ )
+ self.frame_rates = hparams.frame_rates
+ self.hparams = hparams
+ self.label_to_id = LABEL_TO_ID
+ self.dataset_id_to_dataset_id = DATASET_LABEL_TO_DATASET_ID
+ self.id_to_label = ID_TO_LABEL
+ self.internal_tmp_id = internal_tmp_id
+ self.dataset_type = dataset_type
+ self.EPS = 1e-6
+ self.dataset_id2label_mask = {}
+ for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
+ self.dataset_id2label_mask[key] = np.ones(
+ self.hparams.num_classes, dtype=bool
+ )
+ self.dataset_id2label_mask[key][allowed_ids] = False
+
+ self.id2segments = {}
+ data = self.load_jsonl(label_paths)
+
+ self.input_embedding_dir = input_embedding_dir
+ all_input_embedding_dirs = input_embedding_dir.split()
+
+ valid_data_ids = self.get_ids_from_dir(all_input_embedding_dirs[0])
+
+ for x in all_input_embedding_dirs:
+ valid_data_ids = valid_data_ids.intersection(self.get_ids_from_dir(x))
+ split_ids = []
+ with open(split_ids_path) as f:
+ for line in f:
+ if not line.strip():
+ continue
+ split_ids.append(line.strip())
+ split_ids = set(split_ids)
+
+ valid_data_ids = [
+ x for x in valid_data_ids if "_".join(x.split("_")[:-1]) in split_ids
+ ]
+ valid_data_ids = [
+ (internal_tmp_id, dataset_type, x, "HookTheoryAdapter")
+ for x in valid_data_ids
+ ]
+ self.valid_data_ids = valid_data_ids
+ rng = random.Random(42)
+ rng.shuffle(self.valid_data_ids)
+ for item in data:
+ self.id2segments[item["data_id"]] = item["msa_info"]
+
+ def get_ids_from_dir(self, dir_path: str):
+ ids = os.listdir(dir_path)
+ ids = [Path(x).stem for x in ids if x.endswith(".npy")]
+ return set(ids)
+
+ def time2frame(self, this_time):
+ return int(this_time * self.frame_rates)
+
+ def load_jsonl(self, paths):
+ data = []
+ for path in paths:
+ with open(path, "r", encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+ obj = json.loads(line)
+ data.append(obj)
+ return data
+
+ def get_ids(self):
+ return list(self.valid_data_ids)
+
+ def widen_temporal_events(self, events, num_neighbors):
+ def theoretical_gaussian_max(sigma):
+ return 1 / (np.sqrt(2 * np.pi) * sigma)
+
+ widen_events = events
+ sigma = num_neighbors / 3.0
+ smoothed = gaussian_filter1d(widen_events.astype(float), sigma=sigma)
+ smoothed /= theoretical_gaussian_max(sigma)
+ smoothed = np.clip(smoothed, 0, 1)
+
+ return smoothed
+
+ def get_item_json(self, utt, start_time, end_time):
+ embd_list = []
+ embd_dirs = self.input_embedding_dir.split()
+ for embd_dir in embd_dirs:
+ if not Path(embd_dir).exists():
+ raise FileNotFoundError(
+ f"Embedding directory {embd_dir} does not exist"
+ )
+ tmp = np.load(Path(embd_dir) / f"{utt}.npy").squeeze(axis=0)
+ embd_list.append(tmp)
+
+ # Check that max and min lengths of all representations differ by at most 2
+ if len(embd_list) > 1:
+ embd_shapes = [x.shape for x in embd_list]
+ max_shape = max(embd_shapes, key=lambda x: x[0])
+ min_shape = min(embd_shapes, key=lambda x: x[0])
+ if abs(max_shape[0] - min_shape[0]) > 2:
+ raise ValueError(
+ f"Embedding shapes differ too much: {max_shape} vs {min_shape}"
+ )
+
+ for idx in range(len(embd_list)):
+ embd_list[idx] = embd_list[idx][: min_shape[0], :]
+
+ input_embedding = np.concatenate(embd_list, axis=-1)
+
+ return_json = self._get_item_json_without_embedding(
+ "_".join(utt.split("_")[:-1]), start_time, end_time
+ )
+
+ if return_json is None:
+ logger.warning(
+ f"Skip {utt} because no valid segments found in {start_time} to {end_time}."
+ )
+ return None
+ else:
+ return_json["input_embedding"] = input_embedding
+ return return_json
+
+ def get_local_times_labels(self, utt):
+ assert utt in self.id2segments, f"utt {utt} not found in id2segments"
+ time_datas = [x[0] for x in self.id2segments[utt]]
+ time_datas = list(map(float, time_datas))
+ label_datas = [
+ -1 if x[1] == "end" else self.label_to_id[x[1]]
+ for x in self.id2segments[utt]
+ ]
+ return np.array(time_datas), label_datas
+
+ def _get_item_json_without_embedding(self, utt, start_time, end_time):
+ SLICE_DUR = int(math.ceil(end_time - start_time))
+
+ local_times, local_labels = self.get_local_times_labels(utt)
+
+ local_times, local_labels = (
+ copy.deepcopy(local_times),
+ copy.deepcopy(local_labels),
+ )
+
+ assert np.all(local_times[:-1] < local_times[1:]), (
+ f"time must be sorted, but {utt} is {local_times}"
+ )
+
+ local_times = local_times - start_time
+
+ time_L = max(0.0, float(local_times.min()))
+ time_R = min(float(SLICE_DUR), float(local_times.max()))
+ # Note whether boundary labels are reachable
+ keep_boundarys = (time_L + self.EPS < local_times) & (
+ local_times < time_R - self.EPS
+ )
+
+ # If no valid boundaries, return None
+ if keep_boundarys.sum() <= 0:
+ return None
+
+ mask = np.ones([int(SLICE_DUR * self.frame_rates)], dtype=bool)
+ mask[self.time2frame(time_L) : self.time2frame(time_R)] = False
+
+ true_boundary = np.zeros([int(SLICE_DUR * self.frame_rates)], dtype=float)
+ for idx in np.flatnonzero(keep_boundarys):
+ true_boundary[self.time2frame(local_times[idx])] = 1
+
+ true_function = np.zeros(
+ [int(SLICE_DUR * self.frame_rates), self.hparams.num_classes],
+ dtype=float,
+ )
+ true_function_list = []
+ msa_info = []
+ last_pos = self.time2frame(time_L)
+ for idx in np.flatnonzero(keep_boundarys):
+
+ true_function[
+ last_pos : self.time2frame(local_times[idx]),
+ int(local_labels[idx - 1]),
+ ] = 1
+ true_function_list.append(
+ [int(x) for x in local_labels[idx - 1]]
+ if isinstance(local_labels[idx - 1], list)
+ else int(local_labels[idx - 1])
+ )
+ msa_info.append(
+ (
+ float(max(local_times[idx - 1], time_L)),
+ [str(self.id_to_label[int(x)]) for x in local_labels[idx - 1]]
+ if isinstance(local_labels[idx - 1], list)
+ else str(self.id_to_label[int(local_labels[idx - 1])]),
+ )
+ )
+ last_pos = self.time2frame(local_times[idx])
+
+ # Check last label correctness
+ true_function[
+ last_pos : self.time2frame(time_R),
+ local_labels[int(np.flatnonzero(keep_boundarys)[-1])],
+ ] = 1
+ true_function_list.append(
+ [int(x) for x in local_labels[int(np.flatnonzero(keep_boundarys)[-1])]]
+ if isinstance(local_labels[int(np.flatnonzero(keep_boundarys)[-1])], list)
+ else int(local_labels[int(np.flatnonzero(keep_boundarys)[-1])])
+ )
+
+ msa_info.append(
+ (
+ float(local_times[int(np.flatnonzero(keep_boundarys)[-1])]),
+ [
+ str(self.id_to_label[int(x)])
+ for x in local_labels[int(np.flatnonzero(keep_boundarys)[-1])]
+ ]
+ if isinstance(
+ local_labels[int(np.flatnonzero(keep_boundarys)[-1])], list
+ )
+ else str(
+ self.id_to_label[
+ int(local_labels[int(np.flatnonzero(keep_boundarys)[-1])])
+ ]
+ ),
+ )
+ )
+ # Append final label at end; decide if it's necessary
+ msa_info.append((float(time_R), "end"))
+
+ # Add boundary_mask & function_mask
+ frame_len = int(SLICE_DUR * self.frame_rates)
+ # During loss computation, boundaries are fully masked
+ boundary_mask = np.ones([frame_len], dtype=bool)
+ function_mask = np.zeros([frame_len], dtype=bool)
+
+ # Set masks according to msa_info
+ for i in range(len(msa_info) - 1):
+ seg_start, seg_label = msa_info[i]
+ seg_end, _ = msa_info[i + 1]
+ start_frame = self.time2frame(seg_start)
+ end_frame = self.time2frame(seg_end)
+
+ # Handle case where label may be string or list
+ is_no_label = (
+ seg_label == "NO_LABEL"
+ if isinstance(seg_label, str)
+ else "NO_LABEL" in seg_label
+ )
+
+ if is_no_label:
+ # function_mask set True
+ function_mask[start_frame:end_frame] = True
+
+ # ------~~------------
+ # During loss computation, boundaries are fully masked
+ boundary_mask = np.ones([frame_len], dtype=bool)
+ function_mask = np.zeros([frame_len], dtype=bool)
+
+ # Set masks according to msa_info
+ for i in range(len(msa_info) - 1):
+ seg_start, seg_label = msa_info[i]
+ seg_end, _ = msa_info[i + 1]
+ start_frame = self.time2frame(seg_start)
+ end_frame = self.time2frame(seg_end)
+
+ # Handle case where label may be string or list
+ is_no_label = (
+ seg_label == "NO_LABEL"
+ if isinstance(seg_label, str)
+ else "NO_LABEL" in seg_label
+ )
+
+ if is_no_label:
+ # function_mask set True
+ function_mask[start_frame:end_frame] = True
+
+ # return all things except for input_embedding
+ return {
+ "data_id": self.internal_tmp_id + "_" + f"{utt}_{start_time}",
+ "mask": mask,
+ "true_boundary": true_boundary,
+ "widen_true_boundary": self.widen_temporal_events(
+ true_boundary, num_neighbors=self.hparams.num_neighbors
+ ),
+ "true_function": true_function,
+ "true_function_list": true_function_list,
+ "msa_info": msa_info,
+ "dataset_id": self.dataset_id_to_dataset_id[self.dataset_type],
+ "label_id_mask": self.dataset_id2label_mask[
+ self.dataset_id_to_dataset_id[self.dataset_type]
+ ],
+ "boundary_mask": boundary_mask, # Only effective during loss calculation
+ "function_mask": function_mask, # Only effective during loss calculation
+ }
\ No newline at end of file
diff --git a/src/SongFormer/dataset/HookTheoryAdapter.py b/src/SongFormer/dataset/HookTheoryAdapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c126cf93e622e44624a134f7d94c91d5892aee7
--- /dev/null
+++ b/src/SongFormer/dataset/HookTheoryAdapter.py
@@ -0,0 +1,448 @@
+import random
+import os
+from collections import defaultdict
+from pathlib import Path
+import json
+import numpy as np
+import math
+from .label2id import (
+ DATASET_ID_ALLOWED_LABEL_IDS,
+ DATASET_LABEL_TO_DATASET_ID,
+ ID_TO_LABEL,
+ LABEL_TO_ID,
+)
+from argparse import Namespace
+from scipy.ndimage import gaussian_filter1d
+from .DatasetAdaper import DatasetAdapter
+from omegaconf import ListConfig
+
+
+class HookTheoryAdapter(DatasetAdapter):
+ def __init__(self, **kwargs):
+ (
+ structure_jsonl_paths,
+ hparams,
+ internal_tmp_id,
+ dataset_type,
+ input_embedding_dir,
+ split_ids_path,
+ ) = (
+ kwargs["structure_jsonl_paths"],
+ kwargs["hparams"],
+ kwargs["internal_tmp_id"],
+ kwargs["dataset_type"],
+ kwargs.get("input_embedding_dir", None),
+ kwargs.get("split_ids_path", None),
+ )
+
+ # basic attrs
+ self.frame_rates = hparams.frame_rates
+ self.hparams = hparams
+ self.label_to_id = LABEL_TO_ID
+ self.dataset_id_to_dataset_id = DATASET_LABEL_TO_DATASET_ID
+ self.id_to_label = ID_TO_LABEL
+ self.internal_tmp_id = internal_tmp_id
+ self.dataset_type = dataset_type
+ self.EPS = 1e-6
+
+ # build dataset-specific label mask
+ self.dataset_id2label_mask = {}
+ for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
+ self.dataset_id2label_mask[key] = np.ones(
+ self.hparams.num_classes, dtype=bool
+ )
+ self.dataset_id2label_mask[key][allowed_ids] = False
+
+ assert isinstance(structure_jsonl_paths, (ListConfig, tuple, list))
+
+ # load segments per audio id
+ self.id2segments = defaultdict(list)
+ data = self.load_jsonl(structure_jsonl_paths)
+
+ # input embedding dirs (space-separated)
+ self.input_embedding_dir = input_embedding_dir
+ all_input_embedding_dirs = input_embedding_dir.split()
+
+ # get valid ids that exist in all embedding dirs
+ valid_data_ids = self.get_ids_from_dir(all_input_embedding_dirs[0])
+ for x in all_input_embedding_dirs:
+ valid_data_ids = valid_data_ids.intersection(self.get_ids_from_dir(x))
+
+ # read split ids
+ split_ids = []
+ with open(split_ids_path) as f:
+ for line in f:
+ if not line.strip():
+ continue
+ split_ids.append(line.strip())
+ split_ids = set(split_ids)
+
+ # filter valid ids by split
+ valid_data_ids = [
+ x for x in valid_data_ids if "_".join(x.split("_")[:-1]) in split_ids
+ ]
+ valid_data_ids = [
+ (internal_tmp_id, dataset_type, x, "HookTheoryAdapter")
+ for x in valid_data_ids
+ ]
+ self.valid_data_ids = valid_data_ids
+
+ rng = random.Random(42)
+ rng.shuffle(self.valid_data_ids)
+
+ for item in data:
+ self.id2segments[Path(item["ori_audio_path"]).stem].append(item)
+ # logger.info(f"load {len(self.id2segments)} songs from {structure_jsonl_paths}")
+
+ def get_ids_from_dir(self, dir_path: str):
+ ids = os.listdir(dir_path)
+ ids = [Path(x).stem for x in ids if x.endswith(".npy")]
+ return set(ids)
+
+ def time2frame(self, this_time):
+ # convert time (s) to frame index
+ return int(this_time * self.frame_rates)
+
+ def load_jsonl(self, paths):
+ # load list of jsonl files
+ data = []
+ for path in paths:
+ with open(path, "r", encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+ obj = json.loads(line)
+ data.append(obj)
+ return data
+
+ def split_and_label(self, query_start, query_end, segments):
+ """
+ segments: List of dicts, each with keys: "segment_start", "segment_end", 'label'
+ """
+ # Step 1: collect all boundary points (only within query interval)
+ points = set([query_start, query_end])
+ for seg in segments:
+ if query_start <= seg["segment_start"] <= query_end:
+ points.add(seg["segment_start"])
+ if query_start <= seg["segment_end"] <= query_end:
+ points.add(seg["segment_end"])
+ sorted_points = sorted(points)
+
+ result = []
+ # Step 2: for each small interval, check which segments cover it
+ for i in range(len(sorted_points) - 1):
+ part_start = sorted_points[i]
+ part_end = sorted_points[i + 1]
+ labels = []
+ for seg in segments:
+ if (
+ seg["segment_start"] <= part_start
+ and seg["segment_end"] >= part_end
+ ):
+ labels.extend(seg["label"])
+ if not labels:
+ labels = ["NO_LABEL"]
+ result.append(
+ {"segment_start": part_start, "segment_end": part_end, "labels": labels}
+ )
+
+ # deduplicate labels per interval
+ for idx in range(len(result)):
+ result[idx]["labels"] = list(set(result[idx]["labels"]))
+ return result
+
+ def merge_small_intervals(self, parts, min_duration=2.0):
+ """
+ parts: list of dicts with "segment_start", "segment_end", 'labels'
+ Merge intervals shorter than min_duration into neighbor intervals.
+ """
+ new_parts = []
+ i = 0
+ while i < len(parts):
+ part = parts[i]
+ duration = part["segment_end"] - part["segment_start"]
+ if duration < min_duration:
+ # decide where to merge
+ if len(new_parts) > 0 and (i + 1) < len(parts):
+ # randomly choose previous or next
+ if random.choice([True, False]):
+ prev = new_parts[-1]
+ prev["segment_end"] = part["segment_end"]
+ else:
+ next_part = parts[i + 1]
+ next_part["segment_start"] = part["segment_start"]
+ # skip adding this part
+ elif len(new_parts) > 0:
+ # only previous exists - merge into previous
+ prev = new_parts[-1]
+ prev["segment_end"] = part["segment_end"]
+ elif (i + 1) < len(parts):
+ # only next exists - merge into next
+ next_part = parts[i + 1]
+ next_part["segment_start"] = part["segment_start"]
+ # else: nothing to merge, drop
+ i += 1
+ else:
+ new_parts.append(part)
+ i += 1
+ return new_parts
+
+ def rounding_time(self, segments, num_decimals=3):
+ # round segment boundaries to given decimals
+ for idx in range(len(segments)):
+ segments[idx]["segment_start"] = round(
+ segments[idx]["segment_start"], num_decimals
+ )
+ segments[idx]["segment_end"] = round(
+ segments[idx]["segment_end"], num_decimals
+ )
+ return segments
+
+ def get_ids(self):
+ return list(self.valid_data_ids)
+
+ def convert_label(self, label: str):
+ # map various labels to canonical labels
+ mapping = {
+ "chorus": "chorus",
+ "intro": "intro",
+ "bridge": "bridge",
+ "verse": "verse",
+ "pre-chorus": "pre-chorus",
+ "solo": "inst",
+ "instrumental": "inst",
+ "outro": "outro",
+ "NO_LABEL": "NO_LABEL",
+ }
+ assert label in mapping, f"Unknown label: {label}"
+ return mapping[label]
+
+ def parts_to_label_and_times(self, parts, use_random_tag=True):
+ """
+ parts: list of dicts with 'segment_start', 'segment_end', 'labels'
+
+ if use_random_tag: label will be random from valid labels
+ else: label will be all valid labels (labels list)
+
+ return:
+ local_times: np.array of right boundary time points (excluding query_end)
+ local_labels: list of label indices corresponding to self.label_to_id
+ """
+ local_times = []
+ local_labels = []
+
+ for part in parts:
+ local_times.append(part["segment_start"])
+ label = random.choice(part["labels"]) if use_random_tag else part["labels"]
+ local_labels.append(self.label_to_id[self.convert_label(label)])
+ return np.array(local_times), local_labels
+
+ def get_parts(self, utt, query_start, query_end):
+ key = "_".join(utt.split("_")[:-1])
+ assert key in self.id2segments
+ segments = self.id2segments[key]
+ segments = self.rounding_time(segments)
+ parts = self.split_and_label(query_start, query_end, segments)
+
+ # Apply merging twice to remove very short intervals
+ new_parts = self.merge_small_intervals(parts, min_duration=2.0)
+ new_parts = self.merge_small_intervals(new_parts, min_duration=2.0)
+
+ return new_parts
+
+ def widen_temporal_events(self, events, num_neighbors):
+ # smooth binary events with a normalized gaussian
+ def theoretical_gaussian_max(sigma):
+ return 1 / (np.sqrt(2 * np.pi) * sigma)
+
+ widen_events = events
+ sigma = num_neighbors / 3.0
+ smoothed = gaussian_filter1d(widen_events.astype(float), sigma=sigma)
+ smoothed /= theoretical_gaussian_max(sigma)
+ smoothed = np.clip(smoothed, 0, 1)
+
+ return smoothed
+
+ def get_item_json(self, utt, start_time, end_time):
+ # load embeddings from all embedding dirs
+ embd_list = []
+ embd_dirs = self.input_embedding_dir.split()
+ for embd_dir in embd_dirs:
+ if not Path(embd_dir).exists():
+ raise FileNotFoundError(
+ f"Embedding directory {embd_dir} does not exist"
+ )
+ tmp = np.load(Path(embd_dir) / f"{utt}.npy").squeeze(axis=0)
+ embd_list.append(tmp)
+
+ # Check that max/min length difference across embeddings <= 2
+ if len(embd_list) > 1:
+ embd_shapes = [x.shape for x in embd_list]
+ max_shape = max(embd_shapes, key=lambda x: x[0])
+ min_shape = min(embd_shapes, key=lambda x: x[0])
+ if abs(max_shape[0] - min_shape[0]) > 2:
+ raise ValueError(
+ f"Embedding shapes differ too much: {max_shape} vs {min_shape}"
+ )
+
+ for idx in range(len(embd_list)):
+ embd_list[idx] = embd_list[idx][: min_shape[0], :]
+
+ input_embedding = np.concatenate(embd_list, axis=-1)
+
+ return_json = self.get_item_json_without_embedding(utt, start_time, end_time)
+ if return_json is None:
+ return None
+ else:
+ return_json["input_embedding"] = input_embedding
+ return return_json
+
+ def get_item_json_without_embedding(self, utt, start_time, end_time):
+ SLICE_DUR = int(math.ceil(end_time - start_time))
+
+ local_times, local_labels = self.parts_to_label_and_times(
+ self.get_parts(utt, start_time, end_time)
+ )
+
+ assert np.all(local_times[:-1] < local_times[1:]), (
+ f"time must be sorted, but {utt} is {local_times}"
+ )
+
+ # normalize local times relative to slice start
+ local_times = local_times - start_time
+ time_L = 0.0
+ # here time_R is full slice duration because NO_LABEL may appear
+ time_R = float(SLICE_DUR)
+
+ # determine which boundaries are within (time_L, time_R)
+ keep_boundarys = (time_L + self.EPS < local_times) & (
+ local_times < time_R - self.EPS
+ )
+
+ # if no valid boundary, return None
+ if keep_boundarys.sum() <= 0:
+ return None
+
+ mask = np.ones([int(SLICE_DUR * self.frame_rates)], dtype=bool)
+ mask[self.time2frame(time_L) : self.time2frame(time_R)] = False
+
+ true_boundary = np.zeros([int(SLICE_DUR * self.frame_rates)], dtype=float)
+ for idx in np.flatnonzero(keep_boundarys):
+ true_boundary[self.time2frame(local_times[idx])] = 1
+
+ true_function = np.zeros(
+ [int(SLICE_DUR * self.frame_rates), self.hparams.num_classes],
+ dtype=float,
+ )
+ true_function_list = []
+ msa_info = []
+ last_pos = self.time2frame(time_L)
+ for idx in np.flatnonzero(keep_boundarys):
+ # local_labels[idx] might be int or list(int)
+ true_function[
+ last_pos : self.time2frame(local_times[idx]),
+ local_labels[idx - 1],
+ ] = 1
+ true_function_list.append(
+ [int(x) for x in local_labels[idx - 1]]
+ if isinstance(local_labels[idx - 1], list)
+ else int(local_labels[idx - 1])
+ )
+ msa_info.append(
+ (
+ float(max(local_times[idx - 1], time_L)),
+ [str(self.id_to_label[int(x)]) for x in local_labels[idx - 1]]
+ if isinstance(local_labels[idx - 1], list)
+ else str(self.id_to_label[int(local_labels[idx - 1])]),
+ )
+ )
+ last_pos = self.time2frame(local_times[idx])
+
+ # check last label correctness
+ true_function[
+ last_pos : self.time2frame(time_R),
+ local_labels[int(np.flatnonzero(keep_boundarys)[-1])],
+ ] = 1
+ true_function_list.append(
+ [int(x) for x in local_labels[int(np.flatnonzero(keep_boundarys)[-1])]]
+ if isinstance(local_labels[int(np.flatnonzero(keep_boundarys)[-1])], list)
+ else int(local_labels[int(np.flatnonzero(keep_boundarys)[-1])])
+ )
+ msa_info.append(
+ (
+ float(local_times[int(np.flatnonzero(keep_boundarys)[-1])]),
+ [
+ str(self.id_to_label[int(x)])
+ for x in local_labels[int(np.flatnonzero(keep_boundarys)[-1])]
+ ]
+ if isinstance(
+ local_labels[int(np.flatnonzero(keep_boundarys)[-1])], list
+ )
+ else str(
+ self.id_to_label[
+ int(local_labels[int(np.flatnonzero(keep_boundarys)[-1])])
+ ]
+ ),
+ )
+ )
+ # append final "end" marker
+ msa_info.append((float(time_R), "end"))
+
+ # -------------------------
+ # boundary_mask & function_mask
+ # -------------------------
+ frame_len = int(SLICE_DUR * self.frame_rates)
+ boundary_mask = np.zeros([frame_len], dtype=bool)
+ function_mask = np.zeros([frame_len], dtype=bool)
+
+ # set masks according to msa_info
+ for i in range(len(msa_info) - 1):
+ seg_start, seg_label = msa_info[i]
+ seg_end, _ = msa_info[i + 1]
+ start_frame = self.time2frame(seg_start)
+ end_frame = self.time2frame(seg_end)
+
+ # handle label being string or list
+ is_no_label = (
+ seg_label == "NO_LABEL"
+ if isinstance(seg_label, str)
+ else "NO_LABEL" in seg_label
+ )
+
+ if is_no_label:
+ # set function_mask True for NO_LABEL regions
+ function_mask[start_frame:end_frame] = True
+
+ # set boundary_mask True for regions >4s away from ends
+ left_offset = self.time2frame(seg_start + 4)
+ right_offset = self.time2frame(seg_end - 4)
+ if i == 0:
+ if right_offset > 0:
+ boundary_mask[0 : min(right_offset, frame_len)] = True
+ elif i == len(msa_info) - 2:
+ if left_offset < frame_len:
+ boundary_mask[left_offset:frame_len] = True
+ elif right_offset > left_offset:
+ boundary_mask[left_offset:right_offset] = True
+
+ # -------------------------
+ # return all things except input_embedding
+ # -------------------------
+ return {
+ "data_id": self.internal_tmp_id + "_" + f"{utt}_{start_time}",
+ "mask": mask,
+ "true_boundary": true_boundary,
+ "widen_true_boundary": self.widen_temporal_events(
+ true_boundary, num_neighbors=self.hparams.num_neighbors
+ ),
+ "true_function": true_function,
+ "true_function_list": true_function_list,
+ "msa_info": msa_info,
+ "dataset_id": self.dataset_id_to_dataset_id[self.dataset_type],
+ "label_id_mask": self.dataset_id2label_mask[
+ self.dataset_id_to_dataset_id[self.dataset_type]
+ ],
+ "boundary_mask": boundary_mask, # only effective during loss computation
+ "function_mask": function_mask, # only effective during loss computation
+ }
diff --git a/src/SongFormer/dataset/custom_types.py b/src/SongFormer/dataset/custom_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0af8fdf420e97ca8156c4e4d7d183e4aefb8feb
--- /dev/null
+++ b/src/SongFormer/dataset/custom_types.py
@@ -0,0 +1,14 @@
+"""
+MsaInfo
+ A list of (timestamp, label) tuples used to represent music structure
+ analysis (MSA). The first element of the tuple is a float timestamp
+ (in seconds) and the second is a string label
+
+Example
+-------
+ >>> msa: MsaInfo = [(0.0, "intro"), (12.5, "verse"), (34.0, "chorus")]
+"""
+
+from typing import List, Tuple
+
+MsaInfo = List[Tuple[float, str]]
\ No newline at end of file
diff --git a/src/SongFormer/dataset/label2id.py b/src/SongFormer/dataset/label2id.py
new file mode 100644
index 0000000000000000000000000000000000000000..595d9cdae826cb2b4332012e54e061437fd13c1f
--- /dev/null
+++ b/src/SongFormer/dataset/label2id.py
@@ -0,0 +1,163 @@
+LABEL_TO_ID = {
+ "intro": 0,
+ "verse": 1,
+ "chorus": 2,
+ "bridge": 3,
+ "inst": 4,
+ "outro": 5,
+ "silence": 6,
+ "intchorus": 7,
+ "prechorus": 8,
+ "gtrbreak": 9,
+ "solo": 10,
+ "quietchorus": 11,
+ "bre": 12,
+ "break": 13,
+ "introverse": 14,
+ "mainriff": 15,
+ "chorushalf": 16,
+ "instintro": 17,
+ "gtr": 18,
+ "vocaloutro": 19,
+ "verse_slow": 20,
+ "fadein": 21,
+ "saxobeat": 22,
+ "transition": 23,
+ "verse1a": 24,
+ "build": 25,
+ "pre-chorus": 26,
+ "outroa": 27,
+ "bigoutro": 28,
+ "fast": 29,
+ "instrumentalverse": 30,
+ "section": 31,
+ "choruspart": 32,
+ "instbridge": 33,
+ "guitar": 34,
+ "instrumental": 35,
+ "breakdown": 36,
+ "rhythmlessintro": 37,
+ "intropt": 38,
+ "interlude": 39,
+ "postchorus": 40,
+ "postverse": 41,
+ "opening": 42,
+ "altchorus": 43,
+ "stutter": 44,
+ "oddriff": 45,
+ "synth": 46,
+ "preverse": 47,
+ "quiet": 48,
+ "raps": 49,
+ "verseinst": 50,
+ "instchorus": 51,
+ "chorus_instrumental": 52,
+ "slowverse": 53,
+ "slow": 54,
+ "worstthingever": 55,
+ "transition2a": 56,
+ "miniverse": 57,
+ "refrain": 58,
+ "introchorus": 59,
+ "drumroll": 60,
+ "guitarsolo": 61,
+ "versepart": 62,
+ "chorusinst": 63,
+ "ending": 64,
+ "no-vocal-intro": 65,
+ "no-vocal-interlude": 66,
+ "no-vocal-outro": 67,
+ "NO_LABEL": 68, # Only referring to cases without labels, this portion of labels will be ignored during the loss calculation process.
+}
+
+ID_TO_LABEL = {v: k for k, v in LABEL_TO_ID.items()}
+
+# Reserve 64 embedding positions for dataset identifiers in the model.
+DATASET_LABEL_TO_DATASET_ID = {
+ "SongForm-HX-7Class": 0, # Categories after rule mapping for HarmonixSet
+ "SongForm-HX-Widen": 1, # Original HarmonixSet
+ "SongForm-Private-Raw": 2,
+ "SongForm-Private": 3,
+ "SongForm-HX-Gemini-Relabeled": 4, # Rule-mapped HarmonixSet corrected by Gemini
+ "SongForm-HX-8Class": 5, # Rule-mapped (pre-chorus retained)
+ "SongForm-Hook": 6,
+ "SongForm-Gem": 7,
+ "SongForm-Gem-Only-Label": 8, # Use only segments with labels in SongForm-Gem
+}
+
+DATASET_ID_TO_DATASET_LABEL = {v: k for k, v in DATASET_LABEL_TO_DATASET_ID.items()}
+
+DATASET_ID_ALLOWED_LABEL_IDS = {
+ 0: [0, 1, 2, 3, 4, 5, 6],
+ 1: [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 23,
+ 24,
+ 25,
+ 27,
+ 28,
+ 29,
+ 30,
+ 31,
+ 32,
+ 33,
+ 34,
+ 35,
+ 36,
+ 37,
+ 38,
+ 40,
+ 41,
+ 42,
+ 43,
+ 44,
+ 45,
+ 46,
+ 47,
+ 48,
+ 49,
+ 50,
+ 51,
+ 52,
+ 53,
+ 54,
+ 55,
+ 56,
+ 57,
+ 58,
+ 59,
+ 60,
+ 61,
+ 62,
+ 63,
+ ],
+ 2: [0, 1, 2, 3, 26, 39, 64, 65, 66, 67],
+ 3: [0, 1, 2, 3, 4, 5, 6, 26, 39, 64, 65, 66, 67],
+ 4: [0, 1, 2, 3, 4, 5, 6, 26],
+ 5: [0, 1, 2, 3, 4, 5, 6, 26],
+ 6: [0, 1, 2, 3, 4, 5, 6, 26],
+ 7: [0, 1, 2, 3, 4, 5, 6, 26],
+ 8: [0, 1, 2, 3, 4, 5, 6, 26],
+}
diff --git a/src/SongFormer/dataset/msa_info_utils.py b/src/SongFormer/dataset/msa_info_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..71a4e04612c8e97af5dc29487258c829bc6e75f6
--- /dev/null
+++ b/src/SongFormer/dataset/msa_info_utils.py
@@ -0,0 +1,47 @@
+from dataset.custom_types import MsaInfo
+from dataset.label2id import LABEL_TO_ID
+
+
+def load_msa_info(msa_info_path):
+ msa_info: MsaInfo = []
+ with open(msa_info_path) as f:
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+ time_, label = line.split()
+ time_ = float(time_)
+ label = str(label)
+ assert label in LABEL_TO_ID or label == "end", f"{label} not in LABEL_TO_ID"
+ msa_info.append((time_, label))
+ assert msa_info[-1][1] == "end", f"last {msa_info[-1][1]} != end"
+ return msa_info
+
+
+def load_msa_infos(msa_str):
+ msa_info: MsaInfo = []
+ for line in msa_str:
+ line = line.strip()
+ if not line:
+ continue
+ time_, label = line.split()
+ time_ = float(time_)
+ label = str(label)
+ assert label in LABEL_TO_ID or label == "end", f"{label} not in LABEL_TO_ID"
+ msa_info.append((time_, label))
+ assert msa_info[-1][1] == "end", f"last {msa_info[-1][1]} != end"
+ return msa_info
+
+
+def dump_msa_info(msa_info_path, msa_info: MsaInfo):
+ with open(msa_info_path, "w") as f:
+ for time_, label in msa_info:
+ f.write(f"{time_} {label}\n")
+
+
+def dump_msa_infos(msa_info: MsaInfo):
+ mas_strs = []
+ for time_, label in msa_info:
+ mas_strs.append(f"{round(time_, 2)} {label}")
+
+ return "\n".join(mas_strs)
diff --git a/src/SongFormer/eval.sh b/src/SongFormer/eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2f6d6a1cff88987f7eeafac01924d4dbd4df9f9f
--- /dev/null
+++ b/src/SongFormer/eval.sh
@@ -0,0 +1,22 @@
+export CUDA_VISIBLE_DEVICES=-1
+export PYTHONPATH=${PWD}:$PYTHONPATH
+
+export HYDRA_FULL_ERROR=1
+export OMP_NUM_THREADS=1
+export MPI_NUM_THREADS=1
+export NCCL_P2P_DISABLE=1
+export NCCL_IB_DISABLE=1
+
+
+EST_DIR=
+ANN_DIR=
+OUTPUT_DIR=
+echo "$EST_DIR --> $OUTPUT_DIR"
+mkdir -p "$OUTPUT_DIR"
+
+python evaluation/eval_infer_results.py \
+ --ann_dir $ANN_DIR \
+ --est_dir $EST_DIR \
+ --output_dir $OUTPUT_DIR \
+ --prechorus2what verse
+ # --armerge_continuous_segments
\ No newline at end of file
diff --git a/src/SongFormer/evaluation/eval_infer_results.py b/src/SongFormer/evaluation/eval_infer_results.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b547116dd93db020b26e1d7d69aeb3a1143d415
--- /dev/null
+++ b/src/SongFormer/evaluation/eval_infer_results.py
@@ -0,0 +1,198 @@
+import argparse
+import os
+from collections import defaultdict
+from pathlib import Path
+import mir_eval
+import numpy as np
+import pandas as pd
+from dataset.custom_types import MsaInfo
+from dataset.label2id import LABEL_TO_ID
+from dataset.msa_info_utils import load_msa_info
+from msaf.eval import compute_results
+from postprocessing.calc_acc import cal_acc
+from postprocessing.calc_iou import cal_iou
+from tqdm import tqdm
+from loguru import logger
+
+LEGAL_LABELS = {
+ "end",
+ "intro",
+ "verse",
+ "chorus",
+ "bridge",
+ "inst",
+ "outro",
+ "silence",
+ "pre-chorus",
+}
+
+
+def to_inters_labels(msa_info: MsaInfo):
+ label_ids = np.array([LABEL_TO_ID[x[1]] for x in msa_info[:-1]])
+ times = [x[0] for x in msa_info]
+ start_times = np.column_stack([np.array(times[:-1]), np.array(times[1:])])
+ return start_times, label_ids
+
+
+def merge_continuous_segments(segments):
+ """
+ Merge continuous segments with the same label.
+
+ Parameters:
+ segments: List of tuples [(start_time, label), ...], where the last element is (end_time, 'end')
+
+ Returns:
+ Merged segment list in the same format [(start_time, label), ...], with the last element being (end_time, 'end')
+ """
+ if not segments or len(segments) < 2:
+ return segments
+
+ merged = []
+ current_start = segments[0][0]
+ current_label = segments[0][1]
+
+ for i in range(1, len(segments)):
+ time, label = segments[i]
+
+ if label == "end":
+ if current_label != "end":
+ merged.append((current_start, current_label))
+ merged.append((time, "end"))
+ break
+
+ if label != current_label:
+ merged.append((current_start, current_label))
+ current_start = time
+ current_label = label
+
+ return merged
+
+
+def main():
+ argparser = argparse.ArgumentParser()
+ argparser.add_argument("--ann_dir", type=str, required=True)
+ argparser.add_argument("--est_dir", type=str, required=True)
+ argparser.add_argument("--output_dir", type=str, default="./eval_infer_results")
+ argparser.add_argument("--prechorus2what", type=str, default=None)
+ argparser.add_argument("--armerge_continuous_segments", action="store_true")
+ args = argparser.parse_args()
+
+ ann_dir = args.ann_dir
+ est_dir = args.est_dir
+ output_dir = args.output_dir
+ if args.armerge_continuous_segments:
+ logger.info("Merging continuous segments")
+ os.makedirs(output_dir, exist_ok=True)
+
+ ann_id_lists = [x for x in os.listdir(ann_dir) if x.endswith(".txt")]
+ est_id_lists = [x for x in os.listdir(est_dir) if x.endswith(".txt")]
+
+ common_id_lists = set(ann_id_lists) & set(est_id_lists)
+ common_id_lists = list(common_id_lists)
+ logger.info(f"Common number of files: {len(common_id_lists)}")
+
+ resultes = []
+ ious = {}
+
+ for id in tqdm(common_id_lists):
+ try:
+ logger.info(f"Processing {id}")
+ ann_msa = load_msa_info(os.path.join(ann_dir, id))
+ est_msa = load_msa_info(os.path.join(est_dir, id))
+
+ if args.prechorus2what == "verse":
+ ann_msa = [
+ (t, "verse") if l == "pre-chorus" else (t, l) for t, l in ann_msa
+ ]
+ est_msa = [
+ (t, "verse") if l == "pre-chorus" else (t, l) for t, l in est_msa
+ ]
+ elif args.prechorus2what == "chorus":
+ ann_msa = [
+ (t, "chorus") if l == "pre-chorus" else (t, l) for t, l in ann_msa
+ ]
+ est_msa = [
+ (t, "chorus") if l == "pre-chorus" else (t, l) for t, l in est_msa
+ ]
+ elif args.prechorus2what is not None:
+ raise ValueError(f"Unknown prechorus2what: {args.prechorus2what}")
+ if args.armerge_continuous_segments:
+ ann_msa = merge_continuous_segments(ann_msa)
+ est_msa = merge_continuous_segments(est_msa)
+
+ ann_inter, ann_labels = to_inters_labels(ann_msa)
+ est_inter, est_labels = to_inters_labels(est_msa)
+
+ result = compute_results(
+ ann_inter,
+ est_inter,
+ ann_labels,
+ est_labels,
+ bins=11,
+ est_file="test.txt",
+ weight=0.58,
+ )
+ acc = cal_acc(ann_msa, est_msa, post_digit=3)
+
+ ious[id] = cal_iou(ann_msa, est_msa)
+ result["HitRate_1P"], result["HitRate_1R"], result["HitRate_1F"] = (
+ mir_eval.segment.detection(ann_inter, est_inter, window=1, trim=False)
+ )
+ result.update({"id": Path(id).stem})
+ result.update({"acc": acc})
+ for v in ious[id]:
+ result.update({f"iou-{v['label']}": v["iou"]})
+ del result["track_id"]
+ del result["ds_name"]
+
+ resultes.append(result)
+ except Exception as e:
+ logger.error(f"Error processing {id}: {e}")
+ continue
+
+ df = pd.DataFrame(resultes)
+ df.to_csv(f"{output_dir}/eval_infer.csv", index=False)
+
+ intsec_dur_total = defaultdict(float)
+ uni_dur_total = defaultdict(float)
+
+ for tid, value in ious.items():
+ for item in value:
+ label = item["label"]
+ intsec_dur_total[label] += item.get("intsec_dur", 0)
+ uni_dur_total[label] += item.get("uni_dur", 0)
+
+ total_intsec = sum(intsec_dur_total.values())
+ total_uni = sum(uni_dur_total.values())
+ overall_iou = total_intsec / total_uni if total_uni > 0 else 0.0
+
+ class_ious = {}
+ for label in intsec_dur_total:
+ intsec = intsec_dur_total[label]
+ uni = uni_dur_total[label]
+ class_ious[label] = intsec / uni if uni > 0 else 0.0
+
+ summary = pd.DataFrame(
+ [
+ {
+ "num_samples": len(df),
+ "HR.5F": df["HitRate_0.5F"].mean(),
+ "HR3F": df["HitRate_3F"].mean(),
+ "HR1F": df["HitRate_1F"].mean(),
+ "PWF": df["PWF"].mean(),
+ "Sf": df["Sf"].mean(),
+ "acc": df["acc"].mean(),
+ "iou": overall_iou,
+ **{f"iou_{k}": v for k, v in class_ious.items()},
+ }
+ ]
+ )
+ with open(f"{output_dir}/eval_infer_summary.md", "w") as f:
+ print(summary.to_markdown(), file=f)
+
+ summary.to_csv(f"{output_dir}/eval_infer_summary.csv", index=False)
+ logger.info(f"Results saved to {output_dir}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/SongFormer/infer.sh b/src/SongFormer/infer.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5341f09169bf74b7b02ba20029a0657850893649
--- /dev/null
+++ b/src/SongFormer/infer.sh
@@ -0,0 +1,21 @@
+
+export CUDA_VISIBLE_DEVICES=
+echo "use gpu ${CUDA_VISIBLE_DEVICES}"
+
+export PYTHONPATH=../third_party:$PYTHONPATH
+
+export OMP_NUM_THREADS=1
+export MPI_NUM_THREADS=1
+export NCCL_P2P_DISABLE=1
+export NCCL_IB_DISABLE=1
+
+python infer/infer.py \
+-i XXX.scp \
+-o XXX_dir \
+--model SongFormer \
+--checkpoint SongFormer.safetensors \
+--config_path SongFormer.yaml \
+-gn 1 \
+-tn 1
+# --debug
+# --no_rule_post_processing
\ No newline at end of file
diff --git a/src/SongFormer/infer/infer.py b/src/SongFormer/infer/infer.py
new file mode 100755
index 0000000000000000000000000000000000000000..5cb99cf4fb1677efcd7cef4657b25a42a65e8b92
--- /dev/null
+++ b/src/SongFormer/infer/infer.py
@@ -0,0 +1,439 @@
+import argparse
+import importlib
+import json
+import math
+import multiprocessing as mp
+import os
+import time
+from argparse import Namespace
+from pathlib import Path
+
+# monkey patch to fix issues in msaf
+import scipy
+import numpy as np
+
+scipy.inf = np.inf
+
+import librosa
+import torch
+from ema_pytorch import EMA
+from loguru import logger
+from muq import MuQ
+from musicfm.model.musicfm_25hz import MusicFM25Hz
+from omegaconf import OmegaConf
+from tqdm import tqdm
+
+mp.set_start_method("spawn", force=True)
+
+MUSICFM_HOME_PATH = os.path.join("ckpts", "MusicFM")
+
+BEFORE_DOWNSAMPLING_FRAME_RATES = 25
+AFTER_DOWNSAMPLING_FRAME_RATES = 8.333
+
+DATASET_LABEL = "SongForm-HX-8Class"
+DATASET_IDS = [5]
+
+TIME_DUR = 420
+INPUT_SAMPLING_RATE = 24000
+
+from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, DATASET_LABEL_TO_DATASET_ID
+from postprocessing.functional import postprocess_functional_structure
+
+
+def get_processed_ids(output_path):
+ """Get already processed IDs from output directory"""
+ ids = os.listdir(output_path)
+ ret = []
+ for x in ids:
+ if x.endswith(".json"):
+ ret.append(x.replace(".json", ""))
+ return set(ret)
+
+
+def get_processing_ids(input_path, processed_ids_set):
+ """Get IDs to be processed from input directory"""
+ ret = []
+ with open(input_path) as f:
+ for line in f:
+ if line.strip() and Path(line.strip()).stem not in processed_ids_set:
+ ret.append(line.strip())
+ return ret
+
+
+def load_checkpoint(checkpoint_path, device=None):
+ """Load checkpoint from path"""
+ if device is None:
+ device = "cpu"
+
+ if checkpoint_path.endswith(".pt"):
+ checkpoint = torch.load(checkpoint_path, map_location=device)
+ elif checkpoint_path.endswith(".safetensors"):
+ from safetensors.torch import load_file
+
+ checkpoint = {"model_ema": load_file(checkpoint_path, device=device)}
+ else:
+ raise ValueError("Unsupported checkpoint format. Use .pt or .safetensors")
+ return checkpoint
+
+
+def rule_post_processing(msa_list):
+ if len(msa_list) <= 2:
+ return msa_list
+
+ result = msa_list.copy()
+
+ while len(result) > 2:
+ first_duration = result[1][0] - result[0][0]
+ if first_duration < 1.0 and len(result) > 2:
+ result[0] = (result[0][0], result[1][1])
+ result = [result[0]] + result[2:]
+ else:
+ break
+
+ while len(result) > 2:
+ last_label_duration = result[-1][0] - result[-2][0]
+ if last_label_duration < 1.0:
+ result = result[:-2] + [result[-1]]
+ else:
+ break
+
+ while len(result) > 2:
+ if result[0][1] == result[1][1] and result[1][0] <= 10.0:
+ result = [(result[0][0], result[0][1])] + result[2:]
+ else:
+ break
+
+ while len(result) > 2:
+ last_duration = result[-1][0] - result[-2][0]
+ if result[-2][1] == result[-3][1] and last_duration <= 10.0:
+ result = result[:-2] + [result[-1]]
+ else:
+ break
+
+ return result
+
+
+def inference(rank, queue_input: mp.Queue, queue_output: mp.Queue, args):
+ """Run inference on the input audio"""
+ device = f"cuda:{rank}"
+
+ # MuQ model loading (this will automatically fetch the checkpoint from huggingface)
+ muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
+ muq = muq.to(device).eval()
+
+ # MusicFM model loading
+ musicfm = MusicFM25Hz(
+ is_flash=False,
+ stat_path=os.path.join(MUSICFM_HOME_PATH, "msd_stats.json"),
+ model_path=os.path.join(MUSICFM_HOME_PATH, "pretrained_msd.pt"),
+ )
+ musicfm = musicfm.to(device)
+ musicfm.eval()
+
+ # Custom model loading based on the config
+ module = importlib.import_module("models." + str(args.model))
+ Model = getattr(module, "Model")
+ hp = OmegaConf.load(os.path.join("configs", args.config_path))
+ model = Model(hp)
+
+ ckpt = load_checkpoint(checkpoint_path=os.path.join("ckpts", args.checkpoint))
+ if ckpt.get("model_ema", None) is not None:
+ logger.info("Loading EMA model parameters")
+ model_ema = EMA(model, include_online_model=False)
+ model_ema.load_state_dict(ckpt["model_ema"])
+ model.load_state_dict(model_ema.ema_model.state_dict())
+ else:
+ logger.info("No EMA model parameters found, using original model")
+ model.load_state_dict(ckpt["model"])
+
+ model.to(device)
+ model.eval()
+
+ num_classes = args.num_classes
+ dataset_id2label_mask = {}
+
+ for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
+ dataset_id2label_mask[key] = np.ones(args.num_classes, dtype=bool)
+ dataset_id2label_mask[key][allowed_ids] = False
+
+ with torch.no_grad():
+ while True:
+ item = queue_input.get()
+ if not item:
+ queue_output.put(None)
+ break
+
+ try:
+ # Loading the audio file
+ wav, sr = librosa.load(item, sr=INPUT_SAMPLING_RATE)
+ audio = torch.tensor(wav).to(device)
+
+ win_size = args.win_size
+ hop_size = args.hop_size
+ total_len = (
+ (audio.shape[0] // INPUT_SAMPLING_RATE) // TIME_DUR
+ ) * TIME_DUR + TIME_DUR
+ total_frames = math.ceil(total_len * AFTER_DOWNSAMPLING_FRAME_RATES)
+
+ logits = {
+ "function_logits": np.zeros([total_frames, num_classes]),
+ "boundary_logits": np.zeros([total_frames]),
+ }
+ logits_num = {
+ "function_logits": np.zeros([total_frames, num_classes]),
+ "boundary_logits": np.zeros([total_frames]),
+ }
+
+ lens = 0
+ i = 0
+ while True:
+ start_idx = i * INPUT_SAMPLING_RATE
+ end_idx = min((i + win_size) * INPUT_SAMPLING_RATE, audio.shape[-1])
+ if start_idx >= audio.shape[-1]:
+ break
+ if end_idx - start_idx <= 1024:
+ continue
+ audio_seg = audio[start_idx:end_idx]
+
+ # MuQ embedding
+ muq_output = muq(audio_seg.unsqueeze(0), output_hidden_states=True)
+ muq_embd_420s = muq_output["hidden_states"][10]
+ del muq_output
+ torch.cuda.empty_cache()
+
+ # MusicFM embedding
+ _, musicfm_hidden_states = musicfm.get_predictions(
+ audio_seg.unsqueeze(0)
+ )
+ musicfm_embd_420s = musicfm_hidden_states[10]
+ del musicfm_hidden_states
+ torch.cuda.empty_cache()
+
+ wraped_muq_embd_30s = []
+ wraped_musicfm_embd_30s = []
+
+ for idx_30s in range(i, i + hop_size, 30):
+ start_idx_30s = idx_30s * INPUT_SAMPLING_RATE
+ end_idx_30s = min(
+ (idx_30s + 30) * INPUT_SAMPLING_RATE,
+ audio.shape[-1],
+ (i + hop_size) * INPUT_SAMPLING_RATE,
+ )
+ if start_idx_30s >= audio.shape[-1]:
+ break
+ if end_idx_30s - start_idx_30s <= 1024:
+ continue
+ wraped_muq_embd_30s.append(
+ muq(
+ audio[start_idx_30s:end_idx_30s].unsqueeze(0),
+ output_hidden_states=True,
+ )["hidden_states"][10]
+ )
+ torch.cuda.empty_cache()
+ wraped_musicfm_embd_30s.append(
+ musicfm.get_predictions(
+ audio[start_idx_30s:end_idx_30s].unsqueeze(0)
+ )[1][10]
+ )
+ torch.cuda.empty_cache()
+
+ wraped_muq_embd_30s = torch.concatenate(wraped_muq_embd_30s, dim=1)
+ wraped_musicfm_embd_30s = torch.concatenate(
+ wraped_musicfm_embd_30s, dim=1
+ )
+ all_embds = [
+ wraped_musicfm_embd_30s,
+ wraped_muq_embd_30s,
+ musicfm_embd_420s,
+ muq_embd_420s,
+ ]
+
+ if len(all_embds) > 1:
+ embd_lens = [x.shape[1] for x in all_embds]
+ max_embd_len = max(embd_lens)
+ min_embd_len = min(embd_lens)
+ if abs(max_embd_len - min_embd_len) > 4:
+ raise ValueError(
+ f"Embedding shapes differ too much: {max_embd_len} vs {min_embd_len}"
+ )
+
+ for idx in range(len(all_embds)):
+ all_embds[idx] = all_embds[idx][:, :min_embd_len, :]
+
+ embd = torch.concatenate(all_embds, axis=-1)
+
+ dataset_label = DATASET_LABEL
+ dataset_ids = torch.Tensor(DATASET_IDS).to(device, dtype=torch.long)
+ msa_info, chunk_logits = model.infer(
+ input_embeddings=embd,
+ dataset_ids=dataset_ids,
+ label_id_masks=torch.Tensor(
+ dataset_id2label_mask[
+ DATASET_LABEL_TO_DATASET_ID[dataset_label]
+ ]
+ )
+ .to(device, dtype=bool)
+ .unsqueeze(0)
+ .unsqueeze(0),
+ with_logits=True,
+ )
+
+ start_frame = int(i * AFTER_DOWNSAMPLING_FRAME_RATES)
+ end_frame = start_frame + min(
+ math.ceil(hop_size * AFTER_DOWNSAMPLING_FRAME_RATES),
+ chunk_logits["boundary_logits"][0].shape[0],
+ )
+
+ logits["function_logits"][start_frame:end_frame, :] += (
+ chunk_logits["function_logits"][0].detach().cpu().numpy()
+ )
+ logits["boundary_logits"][start_frame:end_frame] = (
+ chunk_logits["boundary_logits"][0].detach().cpu().numpy()
+ )
+ logits_num["function_logits"][start_frame:end_frame, :] += 1
+ logits_num["boundary_logits"][start_frame:end_frame] += 1
+ lens += end_frame - start_frame
+
+ i += hop_size
+ logits["function_logits"] /= logits_num["function_logits"]
+ logits["boundary_logits"] /= logits_num["boundary_logits"]
+
+ logits["function_logits"] = torch.from_numpy(
+ logits["function_logits"][:lens]
+ ).unsqueeze(0)
+ logits["boundary_logits"] = torch.from_numpy(
+ logits["boundary_logits"][:lens]
+ ).unsqueeze(0)
+
+ msa_infer_output = postprocess_functional_structure(logits, hp)
+
+ assert msa_infer_output[-1][-1] == "end"
+ if not args.no_rule_post_processing:
+ msa_infer_output = rule_post_processing(msa_infer_output)
+ msa_json = []
+ for idx in range(len(msa_infer_output) - 1):
+ msa_json.append(
+ {
+ "label": msa_infer_output[idx][1],
+ "start": msa_infer_output[idx][0],
+ "end": msa_infer_output[idx + 1][0],
+ }
+ )
+ json.dump(
+ msa_json,
+ open(os.path.join(args.output_dir, f"{Path(item).stem}.json"), "w"),
+ indent=4,
+ ensure_ascii=False,
+ )
+
+ queue_output.put(None)
+
+ except Exception as e:
+ queue_output.put(None)
+ logger.error(f"process {rank} error\n{item}\n{e}")
+
+
+def deal_with_output(output_path, queue_output, length):
+ """Handle output data from the queue"""
+ pbar = tqdm(range(length), desc="getting inference output")
+ for _ in pbar:
+ data = queue_output.get()
+ if not data:
+ continue
+
+
+def main(args):
+ input_path = args.input_path
+ output_path = args.output_path
+ gpu_num = args.gpu_num
+ num_thread_per_gpu = args.num_thread_per_gpu
+ debug = args.debug
+
+ os.makedirs(output_path, exist_ok=True)
+
+ processed_ids = get_processed_ids(output_path=output_path)
+ processing_ids = get_processing_ids(input_path, processed_ids)
+
+ num_threads = num_thread_per_gpu * gpu_num
+
+ queue_input: mp.Queue = mp.Queue()
+ queue_output: mp.Queue = mp.Queue()
+
+ init_args = Namespace(
+ output_dir=output_path,
+ win_size=420,
+ hop_size=420,
+ num_classes=128,
+ model=args.model,
+ checkpoint=args.checkpoint,
+ config_path=args.config_path,
+ no_rule_post_processing=args.no_rule_post_processing,
+ )
+
+ processes = []
+
+ if debug:
+ queue_input.put(processing_ids[0])
+ queue_input.put(None)
+
+ inference(0, queue_input, queue_output, init_args)
+
+ print("debug exit")
+ exit(0)
+
+ for thread_num in range(num_threads):
+ rank = thread_num % gpu_num
+ print(f"num_threads: {thread_num} on GPU {rank}")
+ time.sleep(0.2)
+ p = mp.Process(
+ target=inference,
+ args=(rank, queue_input, queue_output, init_args),
+ daemon=True,
+ )
+ p.start()
+ processes.append(p)
+
+ for wav_id in tqdm(processing_ids, desc="add data to queue"):
+ queue_input.put(wav_id)
+
+ for _ in range(num_threads):
+ queue_input.put(None)
+
+ deal_with_output(output_path, queue_output, len(processing_ids))
+
+ for p in processes:
+ p.join()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--input_path", "-i", type=str, required=True, help="Input file path"
+ )
+ parser.add_argument(
+ "--output_path", "-o", type=str, required=True, help="Output file path"
+ )
+ parser.add_argument(
+ "--gpu_num", "-gn", type=int, default=1, help="Number of GPUs, default is 1"
+ )
+ parser.add_argument(
+ "--num_thread_per_gpu",
+ "-tn",
+ type=int,
+ default=1,
+ help="Number of threads per GPU, default is 1",
+ )
+ parser.add_argument("--model", type=str, help="Model to use")
+ parser.add_argument("--checkpoint", type=str, help="Checkpoint path")
+ parser.add_argument("--config_path", type=str, help="Configuration file path")
+ parser.add_argument(
+ "--no_rule_post_processing",
+ action="store_true",
+ help="Disable rule-based post-processing",
+ )
+ parser.add_argument("--debug", action="store_true", help="Enable debug mode")
+
+ args = parser.parse_args()
+
+ main(args=args)
diff --git a/src/SongFormer/models/SongFormer.py b/src/SongFormer/models/SongFormer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc6d860e2ee3674cfb04fd883b028d9b28936b33
--- /dev/null
+++ b/src/SongFormer/models/SongFormer.py
@@ -0,0 +1,521 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+from dataset.custom_types import MsaInfo
+from msaf.eval import compute_results
+from postprocessing.functional import postprocess_functional_structure
+from x_transformers import Encoder
+import bisect
+
+
+class Head(nn.Module):
+ def __init__(self, input_dim, output_dim, hidden_dims=None, activation="silu"):
+ super().__init__()
+ hidden_dims = hidden_dims or []
+ act_layers = {"relu": nn.ReLU, "silu": nn.SiLU, "gelu": nn.GELU}
+ act_layer = act_layers.get(activation.lower())
+ if not act_layer:
+ raise ValueError(f"Unsupported activation: {activation}")
+
+ dims = [input_dim] + hidden_dims + [output_dim]
+ layers = []
+ for i in range(len(dims) - 1):
+ layers.append(nn.Linear(dims[i], dims[i + 1]))
+ if i < len(dims) - 2:
+ layers.append(act_layer())
+ self.net = nn.Sequential(*layers)
+
+ def reset_parameters(self, confidence):
+ bias_value = -torch.log(torch.tensor((1 - confidence) / confidence))
+ self.net[-1].bias.data.fill_(bias_value.item())
+
+ def forward(self, x):
+ batch, T, C = x.shape
+ x = x.reshape(-1, C)
+ x = self.net(x)
+ return x.reshape(batch, T, -1)
+
+
+class WrapedTransformerEncoder(nn.Module):
+ def __init__(
+ self, input_dim, transformer_input_dim, num_layers=1, nhead=8, dropout=0.1
+ ):
+ super().__init__()
+ self.input_dim = input_dim
+ self.transformer_input_dim = transformer_input_dim
+
+ if input_dim != transformer_input_dim:
+ self.input_proj = nn.Sequential(
+ nn.Linear(input_dim, transformer_input_dim),
+ nn.LayerNorm(transformer_input_dim),
+ nn.GELU(),
+ nn.Dropout(dropout * 0.5),
+ nn.Linear(transformer_input_dim, transformer_input_dim),
+ )
+ else:
+ self.input_proj = nn.Identity()
+
+ self.transformer = Encoder(
+ dim=transformer_input_dim,
+ depth=num_layers,
+ heads=nhead,
+ layer_dropout=dropout,
+ attn_dropout=dropout,
+ ff_dropout=dropout,
+ attn_flash=True,
+ rotary_pos_emb=True,
+ )
+
+ def forward(self, x, src_key_padding_mask=None):
+ """
+ The input src_key_padding_mask is a B x T boolean mask, where True indicates masked positions.
+ However, in x-transformers, False indicates masked positions.
+ Therefore, it needs to be converted so that False represents masked positions.
+ """
+ x = self.input_proj(x)
+ mask = (
+ ~torch.tensor(src_key_padding_mask, dtype=torch.bool, device=x.device)
+ if src_key_padding_mask is not None
+ else None
+ )
+ return self.transformer(x, mask=mask)
+
+
+def prefix_dict(d, prefix: str):
+ if prefix:
+ return d
+ return {prefix + key: value for key, value in d.items()}
+
+
+class TimeDownsample(nn.Module):
+ def __init__(
+ self, dim_in, dim_out=None, kernel_size=5, stride=5, padding=0, dropout=0.1
+ ):
+ super().__init__()
+ self.dim_out = dim_out or dim_in
+ assert self.dim_out % 2 == 0
+
+ self.depthwise_conv = nn.Conv1d(
+ in_channels=dim_in,
+ out_channels=dim_in,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=dim_in,
+ bias=False,
+ )
+ self.pointwise_conv = nn.Conv1d(
+ in_channels=dim_in,
+ out_channels=self.dim_out,
+ kernel_size=1,
+ bias=False,
+ )
+ self.pool = nn.AvgPool1d(kernel_size, stride, padding=padding)
+ self.norm1 = nn.LayerNorm(self.dim_out)
+ self.act1 = nn.GELU()
+ self.dropout1 = nn.Dropout(dropout)
+
+ if dim_in != self.dim_out:
+ self.residual_conv = nn.Conv1d(
+ dim_in, self.dim_out, kernel_size=1, bias=False
+ )
+ else:
+ self.residual_conv = None
+
+ def forward(self, x):
+ residual = x # [B, T, D_in]
+ # Convolutional module
+ x_c = x.transpose(1, 2) # [B, D_in, T]
+ x_c = self.depthwise_conv(x_c) # [B, D_in, T_down]
+ x_c = self.pointwise_conv(x_c) # [B, D_out, T_down]
+
+ # Residual module
+ res = self.pool(residual.transpose(1, 2)) # [B, D_in, T]
+ if self.residual_conv:
+ res = self.residual_conv(res) # [B, D_out, T_down]
+ x_c = x_c + res # [B, D_out, T_down]
+ x_c = x_c.transpose(1, 2) # [B, T_down, D_out]
+ x_c = self.norm1(x_c)
+ x_c = self.act1(x_c)
+ x_c = self.dropout1(x_c)
+ return x_c
+
+
+class AddFuse(nn.Module):
+ def __init__(self):
+ super(AddFuse, self).__init__()
+
+ def forward(self, x, cond):
+ return x + cond
+
+
+class TVLoss1D(nn.Module):
+ def __init__(
+ self, beta=1.0, lambda_tv=0.4, boundary_threshold=0.01, reduction_weight=0.1
+ ):
+ """
+ Args:
+ beta: Exponential parameter for TV loss (recommended 0.5~1.0)
+ lambda_tv: Overall weight for TV loss
+ boundary_threshold: Label threshold to determine if a region is a "boundary area" (e.g., 0.01)
+ reduction_weight: Scaling factor for TV penalty within boundary regions (e.g., 0.1, meaning only 10% penalty)
+ """
+ super().__init__()
+ self.beta = beta
+ self.lambda_tv = lambda_tv
+ self.boundary_threshold = boundary_threshold
+ self.reduction_weight = reduction_weight
+
+ def forward(self, pred, target=None):
+ """
+ Args:
+ pred: (B, T) or (B, T, 1), float boundary scores output by the model
+ target: (B, T) or (B, T, 1), ground truth labels (optional, used for spatial weighting if provided)
+
+ Returns:
+ scalar: weighted TV loss
+ """
+ if pred.dim() == 3:
+ pred = pred.squeeze(-1)
+ if target is not None and target.dim() == 3:
+ target = target.squeeze(-1)
+
+ diff = pred[:, 1:] - pred[:, :-1]
+ tv_base = torch.pow(torch.abs(diff) + 1e-8, self.beta)
+
+ if target is None:
+ return self.lambda_tv * tv_base.mean()
+
+ left_in_boundary = target[:, :-1] > self.boundary_threshold
+ right_in_boundary = target[:, 1:] > self.boundary_threshold
+ near_boundary = left_in_boundary | right_in_boundary
+ weight_mask = torch.where(
+ near_boundary,
+ self.reduction_weight * torch.ones_like(tv_base),
+ torch.ones_like(tv_base),
+ )
+ tv_weighted = (tv_base * weight_mask).mean()
+ return self.lambda_tv * tv_weighted
+
+
+class SoftmaxFocalLoss(nn.Module):
+ """
+ Softmax Focal Loss for single-label multi-class classification.
+ Suitable for mutually exclusive classes.
+ """
+
+ def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
+ super().__init__()
+ self.alpha = alpha
+ self.gamma = gamma
+
+ def forward(self, pred: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ pred: [B, T, C], raw logits
+ targets: [B, T, C] (soft) or [B, T] (hard, dtype=long)
+ Returns:
+ loss: scalar or [B, T] depending on reduction
+ """
+ log_probs = F.log_softmax(pred, dim=-1)
+ probs = torch.exp(log_probs)
+
+ if targets.dtype == torch.long:
+ targets_onehot = F.one_hot(targets, num_classes=pred.size(-1)).float()
+ else:
+ targets_onehot = targets
+
+ p_t = (probs * targets_onehot).sum(dim=-1)
+ p_t = p_t.clamp(min=1e-8, max=1.0 - 1e-8)
+
+ if self.alpha > 0:
+ alpha_t = self.alpha * targets_onehot + (1 - self.alpha) * (
+ 1 - targets_onehot
+ )
+ alpha_weight = (alpha_t * targets_onehot).sum(dim=-1)
+ else:
+ alpha_weight = 1.0
+
+ focal_weight = (1 - p_t) ** self.gamma
+ ce_loss = -log_probs * targets_onehot
+ ce_loss = ce_loss.sum(dim=-1)
+
+ loss = alpha_weight * focal_weight * ce_loss
+ return loss
+
+
+class Model(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ self.input_norm = nn.LayerNorm(config.input_dim)
+ self.mixed_win_downsample = nn.Linear(config.input_dim_raw, config.input_dim)
+ self.dataset_class_prefix = nn.Embedding(
+ num_embeddings=config.num_dataset_classes,
+ embedding_dim=config.transformer_encoder_input_dim,
+ )
+ self.down_sample_conv = TimeDownsample(
+ dim_in=config.input_dim,
+ dim_out=config.transformer_encoder_input_dim,
+ kernel_size=config.down_sample_conv_kernel_size,
+ stride=config.down_sample_conv_stride,
+ dropout=config.down_sample_conv_dropout,
+ padding=config.down_sample_conv_padding,
+ )
+ self.AddFuse = AddFuse()
+ self.transformer = WrapedTransformerEncoder(
+ input_dim=config.transformer_encoder_input_dim,
+ transformer_input_dim=config.transformer_input_dim,
+ num_layers=config.num_transformer_layers,
+ nhead=config.transformer_nhead,
+ dropout=config.transformer_dropout,
+ )
+ self.boundary_TVLoss1D = TVLoss1D(
+ beta=config.boundary_tv_loss_beta,
+ lambda_tv=config.boundary_tv_loss_lambda,
+ boundary_threshold=config.boundary_tv_loss_boundary_threshold,
+ reduction_weight=config.boundary_tv_loss_reduction_weight,
+ )
+ self.label_focal_loss = SoftmaxFocalLoss(
+ alpha=config.label_focal_loss_alpha, gamma=config.label_focal_loss_gamma
+ )
+ self.boundary_head = Head(config.transformer_input_dim, 1)
+ self.function_head = Head(config.transformer_input_dim, config.num_classes)
+
+ def cal_metrics(self, gt_info: MsaInfo, msa_info: MsaInfo):
+ assert gt_info[-1][1] == "end" and msa_info[-1][1] == "end", (
+ "gt_info and msa_info should end with 'end'"
+ )
+ gt_info_labels = [label for time_, label in gt_info][:-1]
+ gt_info_inters = [time_ for time_, label in gt_info]
+ gt_info_inters = np.column_stack(
+ [np.array(gt_info_inters[:-1]), np.array(gt_info_inters[1:])]
+ )
+
+ msa_info_labels = [label for time_, label in msa_info][:-1]
+ msa_info_inters = [time_ for time_, label in msa_info]
+ msa_info_inters = np.column_stack(
+ [np.array(msa_info_inters[:-1]), np.array(msa_info_inters[1:])]
+ )
+ result = compute_results(
+ ann_inter=gt_info_inters,
+ est_inter=msa_info_inters,
+ ann_labels=gt_info_labels,
+ est_labels=msa_info_labels,
+ bins=11,
+ est_file="test.txt",
+ weight=0.58,
+ )
+ return result
+
+ def cal_acc(
+ self, ann_info: MsaInfo | str, est_info: MsaInfo | str, post_digit: int = 3
+ ):
+ ann_info_time = [
+ int(round(time_, post_digit) * (10**post_digit))
+ for time_, label in ann_info
+ ]
+ est_info_time = [
+ int(round(time_, post_digit) * (10**post_digit))
+ for time_, label in est_info
+ ]
+
+ common_start_time = max(ann_info_time[0], est_info_time[0])
+ common_end_time = min(ann_info_time[-1], est_info_time[-1])
+
+ time_points = {common_start_time, common_end_time}
+ time_points.update(
+ {
+ time_
+ for time_ in ann_info_time
+ if common_start_time <= time_ <= common_end_time
+ }
+ )
+ time_points.update(
+ {
+ time_
+ for time_ in est_info_time
+ if common_start_time <= time_ <= common_end_time
+ }
+ )
+
+ time_points = sorted(time_points)
+ total_duration, total_score = 0, 0
+
+ for idx in range(len(time_points) - 1):
+ duration = time_points[idx + 1] - time_points[idx]
+ ann_label = ann_info[
+ bisect.bisect_right(ann_info_time, time_points[idx]) - 1
+ ][1]
+ est_label = est_info[
+ bisect.bisect_right(est_info_time, time_points[idx]) - 1
+ ][1]
+ total_duration += duration
+ if ann_label == est_label:
+ total_score += duration
+ return total_score / total_duration
+
+ def infer_with_metrics(self, batch, prefix: str = None):
+ with torch.no_grad():
+ logits = self.forward_func(batch)
+
+ losses = self.compute_losses(logits, batch, prefix=None)
+
+ expanded_mask = batch["label_id_masks"].expand(
+ -1, logits["function_logits"].size(1), -1
+ )
+ logits["function_logits"] = logits["function_logits"].masked_fill(
+ expanded_mask, -float("inf")
+ )
+
+ msa_info = postprocess_functional_structure(
+ logits=logits, config=self.config
+ )
+ gt_info = batch["msa_infos"][0]
+ results = self.cal_metrics(gt_info=gt_info, msa_info=msa_info)
+
+ ret_results = {
+ "loss": losses["loss"].item(),
+ "HitRate_3P": results["HitRate_3P"],
+ "HitRate_3R": results["HitRate_3R"],
+ "HitRate_3F": results["HitRate_3F"],
+ "HitRate_0.5P": results["HitRate_0.5P"],
+ "HitRate_0.5R": results["HitRate_0.5R"],
+ "HitRate_0.5F": results["HitRate_0.5F"],
+ "PWF": results["PWF"],
+ "PWP": results["PWP"],
+ "PWR": results["PWR"],
+ "Sf": results["Sf"],
+ "So": results["So"],
+ "Su": results["Su"],
+ "acc": self.cal_acc(ann_info=gt_info, est_info=msa_info),
+ }
+ if prefix:
+ ret_results = prefix_dict(ret_results, prefix)
+
+ return ret_results
+
+ def infer(
+ self,
+ input_embeddings,
+ dataset_ids,
+ label_id_masks,
+ prefix: str = None,
+ with_logits=False,
+ ):
+ with torch.no_grad():
+ input_embeddings = self.mixed_win_downsample(input_embeddings)
+ input_embeddings = self.input_norm(input_embeddings)
+ logits = self.down_sample_conv(input_embeddings)
+
+ dataset_prefix = self.dataset_class_prefix(dataset_ids)
+ dataset_prefix_expand = dataset_prefix.unsqueeze(1).expand(
+ logits.size(0), 1, -1
+ )
+ logits = self.AddFuse(x=logits, cond=dataset_prefix_expand)
+ logits = self.transformer(x=logits, src_key_padding_mask=None)
+
+ function_logits = self.function_head(logits)
+ boundary_logits = self.boundary_head(logits).squeeze(-1)
+
+ logits = {
+ "function_logits": function_logits,
+ "boundary_logits": boundary_logits,
+ }
+
+ expanded_mask = label_id_masks.expand(
+ -1, logits["function_logits"].size(1), -1
+ )
+ logits["function_logits"] = logits["function_logits"].masked_fill(
+ expanded_mask, -float("inf")
+ )
+
+ msa_info = postprocess_functional_structure(
+ logits=logits, config=self.config
+ )
+
+ return (msa_info, logits) if with_logits else msa_info
+
+ def compute_losses(self, outputs, batch, prefix: str = None):
+ loss = 0.0
+ losses = {}
+
+ loss_section = F.binary_cross_entropy_with_logits(
+ outputs["boundary_logits"],
+ batch["widen_true_boundaries"],
+ reduction="none",
+ )
+ loss_section += self.config.boundary_tvloss_weight * self.boundary_TVLoss1D(
+ pred=outputs["boundary_logits"],
+ target=batch["widen_true_boundaries"],
+ )
+ loss_function = F.cross_entropy(
+ outputs["function_logits"].transpose(1, 2),
+ batch["true_functions"].transpose(1, 2),
+ reduction="none",
+ )
+ # input is [B, T, C]
+ ttt = self.config.label_focal_loss_weight * self.label_focal_loss(
+ pred=outputs["function_logits"], targets=batch["true_functions"]
+ )
+ loss_function += ttt
+
+ float_masks = (~batch["masks"]).float()
+ boundary_mask = batch.get("boundary_mask", None)
+ function_mask = batch.get("function_mask", None)
+ if boundary_mask is not None:
+ boundary_mask = (~boundary_mask).float()
+ else:
+ boundary_mask = 1
+
+ if function_mask is not None:
+ function_mask = (~function_mask).float()
+ else:
+ function_mask = 1
+
+ loss_section = torch.mean(boundary_mask * float_masks * loss_section)
+ loss_function = torch.mean(function_mask * float_masks * loss_function)
+
+ loss_section *= self.config.loss_weight_section
+ loss_function *= self.config.loss_weight_function
+
+ if self.config.learn_label:
+ loss += loss_function
+ if self.config.learn_segment:
+ loss += loss_section
+
+ losses.update(
+ loss=loss,
+ loss_section=loss_section,
+ loss_function=loss_function,
+ )
+ if prefix:
+ losses = prefix_dict(losses, prefix)
+ return losses
+
+ def forward_func(self, batch):
+ input_embeddings = batch["input_embeddings"]
+ input_embeddings = self.mixed_win_downsample(input_embeddings)
+ input_embeddings = self.input_norm(input_embeddings)
+ logits = self.down_sample_conv(input_embeddings)
+
+ dataset_prefix = self.dataset_class_prefix(batch["dataset_ids"])
+ logits = self.AddFuse(x=logits, cond=dataset_prefix.unsqueeze(1))
+ src_key_padding_mask = batch["masks"]
+ logits = self.transformer(x=logits, src_key_padding_mask=src_key_padding_mask)
+
+ function_logits = self.function_head(logits)
+ boundary_logits = self.boundary_head(logits).squeeze(-1)
+
+ logits = {
+ "function_logits": function_logits,
+ "boundary_logits": boundary_logits,
+ }
+ return logits
+
+ def forward(self, batch):
+ logits = self.forward_func(batch)
+ losses = self.compute_losses(logits, batch, prefix=None)
+ return logits, losses["loss"], losses
diff --git a/src/SongFormer/postprocessing/calc_acc.py b/src/SongFormer/postprocessing/calc_acc.py
new file mode 100644
index 0000000000000000000000000000000000000000..5128d33e1759928c57304b83fd994537d1c04ada
--- /dev/null
+++ b/src/SongFormer/postprocessing/calc_acc.py
@@ -0,0 +1,82 @@
+import os
+import bisect
+from dataset.msa_info_utils import (
+ load_msa_info,
+)
+from dataset.custom_types import MsaInfo
+import glob
+import pdb
+import pandas as pd
+
+
+def cal_acc(ann_info: MsaInfo | str, est_info: MsaInfo | str, post_digit: int = 3):
+ if type(ann_info) is str:
+ assert os.path.exists(ann_info), f"{ann_info} not exists"
+ ann_info = load_msa_info(ann_info)
+
+ if type(ann_info) is str:
+ assert os.path.exists(est_info), f"{est_info} not exists"
+ est_info = load_msa_info(est_info)
+
+ ann_info_time = [
+ int(round(time_, post_digit) * (10**post_digit)) for time_, label in ann_info
+ ]
+ est_info_time = [
+ int(round(time_, post_digit) * (10**post_digit)) for time_, label in est_info
+ ]
+
+ common_start_time = max(ann_info_time[0], est_info_time[0])
+ common_end_time = min(ann_info_time[-1], est_info_time[-1])
+
+ time_points = set()
+ time_points.add(common_start_time)
+ time_points.add(common_end_time)
+
+ for time_ in ann_info_time:
+ if time_ >= common_start_time and time_ <= common_end_time:
+ time_points.add(time_)
+ for time_ in est_info_time:
+ if time_ >= common_start_time and time_ <= common_end_time:
+ time_points.add(time_)
+
+ time_points = sorted(list(time_points))
+ total_duration = 0
+ total_score = 0
+
+ for idx in range(len(time_points) - 1):
+ duration = time_points[idx + 1] - time_points[idx]
+ ann_label = ann_info[bisect.bisect_right(ann_info_time, time_points[idx]) - 1][
+ 1
+ ]
+ est_label = est_info[bisect.bisect_right(est_info_time, time_points[idx]) - 1][
+ 1
+ ]
+ total_duration += duration
+ if ann_label == est_label:
+ total_score += duration
+ return total_score / total_duration
+
+
+if __name__ == "__main__":
+ ext_paths = glob.glob("")
+ results = []
+ for ext_path in ext_paths:
+ try:
+ ann_path = os.path.join(
+ "",
+ os.path.basename(ext_path).split(".")[0] + ".txt",
+ )
+ results.append(
+ {
+ "data_id": os.path.basename(ext_path).split(".")[0],
+ "acc": cal_acc(
+ ann_info=ann_path,
+ est_info=ext_path,
+ ),
+ }
+ )
+ except Exception as e:
+ print(e)
+ continue
+ df = pd.DataFrame(results)
+ print(df["acc"].mean())
diff --git a/src/SongFormer/postprocessing/calc_iou.py b/src/SongFormer/postprocessing/calc_iou.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cc2e282bdc06f5be473c535f346222da2d854b2
--- /dev/null
+++ b/src/SongFormer/postprocessing/calc_iou.py
@@ -0,0 +1,89 @@
+import os
+from dataset.custom_types import MsaInfo
+from dataset.label2id import LABEL_TO_ID
+from pprint import pprint
+
+
+def load_msa_info(msa_info_path):
+ msa_info: MsaInfo = []
+ with open(msa_info_path) as f:
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+ time_, label = line.split()
+ time_ = float(time_)
+ label = str(label)
+ assert label in LABEL_TO_ID or label == "end", f"{label} not in LABEL_TO_ID"
+ msa_info.append((time_, label))
+ assert msa_info[-1][1] == "end", f"last {msa_info[-1][1]} != end"
+ return msa_info
+
+
+def msa_info_to_segments(msa_info):
+ # skip the last "end"
+ segments = []
+ for i in range(len(msa_info) - 1):
+ start = msa_info[i][0]
+ end = msa_info[i + 1][0]
+ label = msa_info[i][1]
+ segments.append((start, end, label))
+ return segments
+
+
+def compute_iou_for_label(segments_a, segments_b, label):
+ # segments_a, segments_b: [(start, end, label)]
+ # only process the current label
+ intervals_a = [(s, e) for s, e, l in segments_a if l == label]
+ intervals_b = [(s, e) for s, e, l in segments_b if l == label]
+ # sum up all intersections between a and b
+ intersection = 0.0
+ for sa, ea in intervals_a:
+ for sb, eb in intervals_b:
+ left = max(sa, sb)
+ right = min(ea, eb)
+ if left < right:
+ intersection += right - left
+ # union = total length of both sets - overlapping intersection
+ length_a = sum([e - s for s, e in intervals_a])
+ length_b = sum([e - s for s, e in intervals_b])
+ union = length_a + length_b - intersection
+ if union == 0:
+ return 0.0
+ return intersection / union, intersection, union
+
+
+def compute_mean_iou(segments_a, segments_b, labels):
+ ious = []
+ for label in labels:
+ iou, intsec_dur, uni_dur = compute_iou_for_label(segments_a, segments_b, label)
+ ious.append(
+ {"label": label, "iou": iou, "intsec_dur": intsec_dur, "uni_dur": uni_dur}
+ )
+ return ious
+
+
+def cal_iou(ann_info, est_info):
+ if type(ann_info) is str:
+ assert os.path.exists(ann_info), f"{ann_info} not exists"
+ ann_info = load_msa_info(ann_info)
+
+ if type(est_info) is str:
+ assert os.path.exists(est_info), f"{est_info} not exists"
+ est_info = load_msa_info(est_info)
+
+ segments_ann = msa_info_to_segments(ann_info)
+ segments_est = msa_info_to_segments(est_info)
+
+ occurred_labels = list(
+ set([l for s, e, l in segments_ann]) | set(l for s, e, l in segments_est)
+ )
+
+ mean_iou = compute_mean_iou(segments_ann, segments_est, occurred_labels)
+ return mean_iou
+
+
+if __name__ == "__main__":
+ ann_info = ""
+ est_info = ""
+ pprint(cal_iou(ann_info, est_info))
diff --git a/src/SongFormer/postprocessing/functional.py b/src/SongFormer/postprocessing/functional.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dedd22391b7e2308661e8ad861732e4df749af0
--- /dev/null
+++ b/src/SongFormer/postprocessing/functional.py
@@ -0,0 +1,71 @@
+# This file contains code adapted from the following sources:
+# [MIT license] https://github.com/mir-aidj/all-in-one/blob/main/src/allin1/postprocessing/functional.py
+
+import numpy as np
+import torch
+from .helpers import (
+ local_maxima,
+ peak_picking,
+ # event_frames_to_time,
+)
+from dataset.label2id import LABEL_TO_ID, ID_TO_LABEL
+from dataset.custom_types import MsaInfo
+
+
+def event_frames_to_time(frame_rates, boundary: np.array):
+ boundary = np.array(boundary)
+ boundary_times = boundary / frame_rates
+ return boundary_times
+
+
+def postprocess_functional_structure(
+ logits,
+ config,
+):
+ # pdb.set_trace()
+ boundary_logits = logits["boundary_logits"]
+ function_logits = logits["function_logits"]
+
+ assert boundary_logits.shape[0] == 1 and function_logits.shape[0] == 1, (
+ "Only batch size 1 is supported"
+ )
+ raw_prob_sections = torch.sigmoid(boundary_logits[0])
+ raw_prob_functions = torch.softmax(function_logits[0].transpose(0, 1), dim=0)
+
+ # filter_size=4 * cfg.min_hops_per_beat + 1
+ prob_sections, _ = local_maxima(
+ raw_prob_sections, filter_size=config.local_maxima_filter_size
+ )
+ prob_sections = prob_sections.cpu().numpy()
+
+ prob_functions = raw_prob_functions.cpu().numpy()
+
+ boundary_candidates = peak_picking(
+ boundary_activation=prob_sections,
+ window_past=int(12 * config.frame_rates), # 原来是fps
+ window_future=int(12 * config.frame_rates),
+ )
+ boundary = boundary_candidates > 0.0
+
+ duration = len(prob_sections) / config.frame_rates
+ pred_boundary_times = event_frames_to_time(
+ frame_rates=config.frame_rates, boundary=np.flatnonzero(boundary)
+ )
+ if pred_boundary_times[0] != 0:
+ pred_boundary_times = np.insert(pred_boundary_times, 0, 0)
+ if pred_boundary_times[-1] != duration:
+ pred_boundary_times = np.append(pred_boundary_times, duration)
+ pred_boundaries = np.stack([pred_boundary_times[:-1], pred_boundary_times[1:]]).T
+
+ pred_boundary_indices = np.flatnonzero(boundary)
+ pred_boundary_indices = pred_boundary_indices[pred_boundary_indices > 0]
+ prob_segment_function = np.split(prob_functions, pred_boundary_indices, axis=1)
+ pred_labels = [p.mean(axis=1).argmax().item() for p in prob_segment_function]
+
+ segments: MsaInfo = []
+ for (start, end), label in zip(pred_boundaries, pred_labels):
+ segment = (float(start), str(ID_TO_LABEL[label]))
+ segments.append(segment)
+
+ segments.append((float(pred_boundary_times[-1]), "end"))
+ return segments
diff --git a/src/SongFormer/postprocessing/helpers.py b/src/SongFormer/postprocessing/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd46d8b89829e74ecd677d27e9bd03023b31920c
--- /dev/null
+++ b/src/SongFormer/postprocessing/helpers.py
@@ -0,0 +1,101 @@
+# This file contains code adapted from the following sources:
+# [MIT license] https://github.com/mir-aidj/all-in-one/blob/main/src/allin1/postprocessing/helpers.py
+
+import numpy as np
+import torch.nn.functional as F
+import torch
+import librosa
+from typing import Union
+from scipy.signal import argrelextrema
+from scipy.interpolate import interp1d
+from numpy.lib.stride_tricks import sliding_window_view
+from numpy.typing import NDArray
+
+
+def local_maxima(tensor, filter_size=41):
+ assert len(tensor.shape) in (1, 2), "Input tensor should have 1 or 2 dimensions"
+ assert filter_size % 2 == 1, "Filter size should be an odd number"
+
+ original_shape = tensor.shape
+ if len(original_shape) == 1:
+ tensor = tensor.unsqueeze(0)
+
+ # Pad the input array with the minimum value
+ padding = filter_size // 2
+ padded_arr = F.pad(tensor, (padding, padding), mode="constant", value=-torch.inf)
+
+ # Create a rolling window view of the padded array
+ rolling_view = padded_arr.unfold(1, filter_size, 1)
+
+ # Find the indices of the local maxima
+ center = filter_size // 2
+ local_maxima_mask = torch.eq(
+ rolling_view[:, :, center], torch.max(rolling_view, dim=-1).values
+ )
+ local_maxima_indices = local_maxima_mask.nonzero()
+
+ # Initialize a new PyTorch tensor with zeros and the same shape as the input tensor
+ output_arr = torch.zeros_like(tensor)
+
+ # Set the local maxima values in the output tensor
+ output_arr[local_maxima_mask] = tensor[local_maxima_mask]
+
+ output_arr = output_arr.reshape(original_shape)
+
+ return output_arr, local_maxima_indices
+
+
+def local_maxima_numpy(arr, order=20):
+ is_batch = len(arr.shape) == 2
+ if is_batch:
+ return np.stack([local_maxima_numpy(x, order) for x in arr])
+
+ # Define a comparison function for argrelextrema to find local maxima
+ compare_func = np.greater
+
+ # Find the indices of the local maxima
+ local_maxima_indices = argrelextrema(arr, compare_func, order=order)
+
+ # Initialize a new numpy array with zeros and the same shape as the input array
+ output_arr = np.zeros_like(arr)
+
+ # Set the local maxima values in the output array
+ output_arr[local_maxima_indices] = arr[local_maxima_indices]
+
+ return output_arr
+
+
+def peak_picking(boundary_activation, window_past=12, window_future=6):
+ # Find local maxima using a sliding window
+ window_size = window_past + window_future
+ assert window_size % 2 == 0, "window_past + window_future must be even"
+ window_size += 1
+
+ # Pad boundary_activation
+ boundary_activation_padded = np.pad(
+ boundary_activation, (window_past, window_future), mode="constant"
+ )
+ max_filter = sliding_window_view(boundary_activation_padded, window_size)
+ local_maxima = (boundary_activation == np.max(max_filter, axis=-1)) & (
+ boundary_activation > 0
+ )
+
+ # Compute strength values by subtracting the mean of the past and future windows
+ past_window_filter = sliding_window_view(
+ boundary_activation_padded[: -(window_future + 1)], window_past
+ )
+ future_window_filter = sliding_window_view(
+ boundary_activation_padded[window_past + 1 :], window_future
+ )
+ past_mean = np.mean(past_window_filter, axis=-1)
+ future_mean = np.mean(future_window_filter, axis=-1)
+ strength_values = boundary_activation - ((past_mean + future_mean) / 2)
+
+ # Get boundary candidates and their corresponding strength values
+ boundary_candidates = np.flatnonzero(local_maxima)
+ strength_values = strength_values[boundary_candidates]
+
+ strength_activations = np.zeros_like(boundary_activation)
+ strength_activations[boundary_candidates] = strength_values
+
+ return strength_activations
diff --git a/src/SongFormer/train/accelerate_config/single_gpu.yaml b/src/SongFormer/train/accelerate_config/single_gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1d8e95b6998c7eb8d3f890510965d40abde22999
--- /dev/null
+++ b/src/SongFormer/train/accelerate_config/single_gpu.yaml
@@ -0,0 +1,17 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: 'NO'
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/SongFormer/utils/average_checkpoints.py b/src/SongFormer/utils/average_checkpoints.py
new file mode 100644
index 0000000000000000000000000000000000000000..783a87d3a30aa27160746cf14cdf3dce78d11fb9
--- /dev/null
+++ b/src/SongFormer/utils/average_checkpoints.py
@@ -0,0 +1,152 @@
+import torch
+import copy
+from typing import List, Dict, Any
+
+
+def average_checkpoints(checkpoint_paths: List[str], output_path: str = None):
+ """
+ Average the model and model_ema weights from multiple checkpoints
+
+ Parameters:
+ checkpoint_paths: List of checkpoint file paths
+ output_path: Output path; if None, return the averaged checkpoint dictionary
+
+ Returns:
+ Averaged checkpoint dictionary
+ """
+ if not checkpoint_paths:
+ raise ValueError("At least one checkpoint path is required")
+
+ # Load the first checkpoint as the base
+ print(f"Loading base checkpoint: {checkpoint_paths[0]}")
+ avg_checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")
+
+ if len(checkpoint_paths) == 1:
+ if output_path:
+ torch.save(avg_checkpoint, output_path)
+ return avg_checkpoint
+
+ # Initialize accumulators
+ avg_model_state = copy.deepcopy(avg_checkpoint["model"])
+ avg_model_ema_state = None
+
+ if "model_ema" in avg_checkpoint:
+ avg_model_ema_state = copy.deepcopy(avg_checkpoint["model_ema"])
+
+ # Accumulate the weights from the other checkpoints
+ for i, ckpt_path in enumerate(checkpoint_paths[1:], 1):
+ print(f"Processing checkpoint {i + 1}/{len(checkpoint_paths)}: {ckpt_path}")
+ ckpt = torch.load(ckpt_path, map_location="cpu")
+
+ # Accumulate model weights
+ for key in avg_model_state.keys():
+ if key in ckpt["model"]:
+ avg_model_state[key] += ckpt["model"][key]
+
+ # Accumulate model_ema weights (if available)
+ if avg_model_ema_state is not None and "model_ema" in ckpt:
+ for key in avg_model_ema_state.keys():
+ if key in ckpt["model_ema"]:
+ avg_model_ema_state[key] += ckpt["model_ema"][key]
+
+ # Compute the average
+ num_checkpoints = len(checkpoint_paths)
+ print(f"Averaging over {num_checkpoints} checkpoints...")
+
+ for key in avg_model_state.keys():
+ avg_model_state[key] = avg_model_state[key] / num_checkpoints
+
+ if avg_model_ema_state is not None:
+ for key in avg_model_ema_state.keys():
+ avg_model_ema_state[key] = avg_model_ema_state[key] / num_checkpoints
+
+ # Update the checkpoint dictionary
+ avg_checkpoint["model"] = avg_model_state
+ if avg_model_ema_state is not None:
+ avg_checkpoint["model_ema"] = avg_model_ema_state
+
+ # Save (if an output path is specified)
+ if output_path:
+ print(f"Saving averaged checkpoint to: {output_path}")
+ torch.save(avg_checkpoint, output_path)
+
+ return avg_checkpoint
+
+
+def average_checkpoints_memory_efficient(
+ checkpoint_paths: List[str], output_path: str = None
+):
+ """
+ Memory efficient version: Load and process checkpoints one by one, suitable for large models
+ """
+ if not checkpoint_paths:
+ raise ValueError("At least one checkpoint path is required")
+
+ print(f"Loading base checkpoint: {checkpoint_paths[0]}")
+ avg_checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")
+
+ if len(checkpoint_paths) == 1:
+ if output_path:
+ torch.save(avg_checkpoint, output_path)
+ return avg_checkpoint
+
+ # Convert to float32 for better precision
+ for key in avg_checkpoint["model"].keys():
+ avg_checkpoint["model"][key] = avg_checkpoint["model"][key].float()
+
+ if "model_ema" in avg_checkpoint:
+ for key in avg_checkpoint["model_ema"].keys():
+ avg_checkpoint["model_ema"][key] = avg_checkpoint["model_ema"][key].float()
+
+ # Load and accumulate checkpoints one by one
+ for i, ckpt_path in enumerate(checkpoint_paths[1:], 1):
+ print(f"Processing checkpoint {i + 1}/{len(checkpoint_paths)}: {ckpt_path}")
+ ckpt = torch.load(ckpt_path, map_location="cpu")
+
+ # Accumulate model weights
+ for key in avg_checkpoint["model"].keys():
+ if key in ckpt["model"]:
+ avg_checkpoint["model"][key] += ckpt["model"][key].float()
+
+ # Accumulate model_ema weights
+ if "model_ema" in avg_checkpoint and "model_ema" in ckpt:
+ for key in avg_checkpoint["model_ema"].keys():
+ if key in ckpt["model_ema"]:
+ avg_checkpoint["model_ema"][key] += ckpt["model_ema"][key].float()
+
+ # Free memory
+ del ckpt
+ torch.cuda.empty_cache()
+
+ # Compute the average
+ num_checkpoints = len(checkpoint_paths)
+ print(f"Averaging over {num_checkpoints} checkpoints...")
+
+ for key in avg_checkpoint["model"].keys():
+ avg_checkpoint["model"][key] /= num_checkpoints
+
+ if "model_ema" in avg_checkpoint:
+ for key in avg_checkpoint["model_ema"].keys():
+ avg_checkpoint["model_ema"][key] /= num_checkpoints
+
+ if output_path:
+ print(f"Saving averaged checkpoint to: {output_path}")
+ torch.save(avg_checkpoint, output_path)
+
+ return avg_checkpoint
+
+
+# Example usage
+if __name__ == "__main__":
+ # Method 1: Simple usage
+ checkpoint_paths = []
+
+ # Average and save
+ average_checkpoints(checkpoint_paths, "")
+
+ # Method 2: Get the averaged checkpoint and further process it
+ # avg_ckpt = average_checkpoints(checkpoint_paths)
+ # print("Averaged checkpoint keys:", avg_ckpt.keys())
+
+ # Method 3: Use memory-efficient version (suitable for large models)
+ # average_checkpoints_memory_efficient(checkpoint_paths, 'averaged_checkpoint_efficient.pt')
diff --git a/src/SongFormer/utils/convert_res2msa_txt.py b/src/SongFormer/utils/convert_res2msa_txt.py
new file mode 100644
index 0000000000000000000000000000000000000000..e598ec5553203bb81ebe7e5fc7125984dac0565e
--- /dev/null
+++ b/src/SongFormer/utils/convert_res2msa_txt.py
@@ -0,0 +1,79 @@
+import json
+import os
+from pathlib import Path
+import fire
+
+
+def convert_json_to_format(json_data):
+ """Convert JSON data to the specified format"""
+ result = []
+
+ # Process the start time and label for each segment
+ for segment in json_data:
+ start_time = segment["start"]
+ label = segment["label"]
+ result.append(f"{start_time:.6f} {label}")
+
+ # Add the last end time
+ if json_data:
+ last_end_time = json_data[-1]["end"]
+ result.append(f"{last_end_time:.6f} end")
+
+ return "\n".join(result)
+
+
+def process_json_files(input_folder, output_folder):
+ """Process all JSON files in the input folder"""
+
+ # Create the output folder if it doesn't exist
+ Path(output_folder).mkdir(parents=True, exist_ok=True)
+
+ # Get all JSON files
+ json_files = [f for f in os.listdir(input_folder) if f.endswith(".json")]
+
+ if not json_files:
+ print(f"No JSON files found in {input_folder}")
+ return
+
+ print(f"Found {len(json_files)} JSON files")
+
+ # Process each JSON file
+ for json_file in json_files:
+ input_path = os.path.join(input_folder, json_file)
+
+ try:
+ # Read the JSON file
+ with open(input_path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ # Convert the format
+ converted_data = convert_json_to_format(data)
+
+ # Generate the output filename (replace .json with .txt)
+ output_filename = json_file.replace(".json", ".txt")
+ output_path = os.path.join(output_folder, output_filename)
+
+ # Write to the output file
+ with open(output_path, "w", encoding="utf-8") as f:
+ f.write(converted_data)
+
+ print(f"✓ Processed: {json_file} -> {output_filename}")
+
+ except Exception as e:
+ print(f"✗ Error processing {json_file}: {str(e)}")
+
+
+def main(input_folder: str, output_folder: str):
+ print(f"Input folder: {input_folder}")
+ print(f"Output folder: {output_folder}")
+ print("-" * 50)
+
+ # Process the files
+ process_json_files(input_folder, output_folder)
+
+ print("-" * 50)
+ print("Processing complete!")
+
+
+if __name__ == "__main__":
+ fire.Fire(main)
diff --git a/src/SongFormer/utils/fetch_pretrained.py b/src/SongFormer/utils/fetch_pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..584fce8affe979ebfdb73d9a360d63adebb0a46d
--- /dev/null
+++ b/src/SongFormer/utils/fetch_pretrained.py
@@ -0,0 +1,40 @@
+import os
+import requests
+from tqdm import tqdm
+
+
+def download(url, path):
+ if os.path.exists(path):
+ print(f"File already exists, skipping download: {path}")
+ return
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ response = requests.get(url, stream=True)
+ total_size = int(response.headers.get("content-length", 0))
+ with (
+ open(path, "wb") as f,
+ tqdm(
+ desc=path,
+ total=total_size,
+ unit="iB",
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as bar,
+ ):
+ for data in response.iter_content(chunk_size=1024):
+ size = f.write(data)
+ bar.update(size)
+
+
+# 根据 https://github.com/minzwon/musicfm 下载预训练模型
+download(
+ "https://huggingface.co/minzwon/MusicFM/resolve/main/msd_stats.json",
+ os.path.join("ckpts", "MusicFM", "msd_stats.json"),
+)
+download(
+ "https://huggingface.co/minzwon/MusicFM/resolve/main/pretrained_msd.pt",
+ os.path.join("ckpts", "MusicFM", "pretrained_msd.pt"),
+)
+
+# for Mainland China
+# download('https://hf-mirror.com/minzwon/MusicFM/resolve/main/msd_stats.json', os.path.join("ckpts", "MusicFM", "msd_stats.json"))
+# download('https://hf-mirror.com/minzwon/MusicFM/resolve/main/pretrained_msd.pt', os.path.join("ckpts", "MusicFM", "pretrained_msd.pt"))
diff --git a/src/third_party/MuQ/.gitattributes b/src/third_party/MuQ/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..dfe0770424b2a19faf507a501ebfc23be8f54e7b
--- /dev/null
+++ b/src/third_party/MuQ/.gitattributes
@@ -0,0 +1,2 @@
+# Auto detect text files and perform LF normalization
+* text=auto
diff --git a/src/third_party/MuQ/.gitignore b/src/third_party/MuQ/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..4d4668182b454060a00e13f8ce3c3ec262038737
--- /dev/null
+++ b/src/third_party/MuQ/.gitignore
@@ -0,0 +1,46 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+*.egg*/
+*pyc
+
+# Distribution / packaging
+.Python
+env/
+build/
+dist/
+*.log
+
+# pyenv
+.python-version
+
+# dotenv
+.env
+
+# virtualenv
+.venv/
+venv/
+ENV/
+
+# VSCode settings
+.vscode
+
+# IDEA files
+.idea
+
+# OSX dir files
+.DS_Store
+
+# Sublime Text settings
+*.sublime-workspace
+*.sublime-project
+
+# custom
+open/
+src/recipes/pretrain/dataset/music4all/*.json
+src/recipes/contrastive_learning/datasets/mtg-jamendo/*.json
+runs/
+output/
+logs
+outputs/
\ No newline at end of file
diff --git a/src/third_party/MuQ/.gitmodules b/src/third_party/MuQ/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..ebf2b4a60ee1c518f4db7c6458ea84a32ce927f9
--- /dev/null
+++ b/src/third_party/MuQ/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "src/recipes/pretrain/fairseq"]
+ path = src/recipes/pretrain/fairseq
+ url = https://github.com/facebookresearch/fairseq
diff --git a/src/third_party/MuQ/LICENSE b/src/third_party/MuQ/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d690c4b29c6b3e3b2f83dda027d06de3c34bfafb
--- /dev/null
+++ b/src/third_party/MuQ/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) Tencent.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/src/third_party/MuQ/LICENSE_weights b/src/third_party/MuQ/LICENSE_weights
new file mode 100644
index 0000000000000000000000000000000000000000..e395ca3e2cdebf48a6375a3c1022d10caabba7db
--- /dev/null
+++ b/src/third_party/MuQ/LICENSE_weights
@@ -0,0 +1,399 @@
+Attribution-NonCommercial 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More_considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial 4.0 International Public
+License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+ d. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ f. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ g. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ h. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ i. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ j. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ k. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ l. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's
+ License You apply must not prevent recipients of the Adapted
+ Material from complying with this Public License.
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
\ No newline at end of file
diff --git a/src/third_party/MuQ/README.md b/src/third_party/MuQ/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a715493cbf7a3ee0b87f81ea96de8ffa9b419c29
--- /dev/null
+++ b/src/third_party/MuQ/README.md
@@ -0,0 +1,129 @@
+#
MuQ & MuQ-MuLan
+
+
+
+This is the official repository for the paper *"**MuQ**: Self-Supervised **Mu**sic Representation Learning
+ with Mel Residual Vector **Q**uantization"*.
+
+In this repo, the following models are released:
+
+- **MuQ**: A large music foundation model pre-trained via Self-Supervised Learning (SSL), achieving SOTA in various MIR tasks.
+- **MuQ-MuLan**: A music-text joint embedding model trained via contrastive learning, supporting both English and Chinese texts.
+
+## Overview
+
+We develop the **MuQ** for music SSL. MuQ applys our proposed Mel-RVQ as quantitative targets and achieves SOTA performance on many music understanding (or MIR) tasks.
+
+We also construct the **MuQ-MuLan**, a CLIP-like model trained by contrastive learning, which jointly represents music and text into embeddings.
+
+For more details, please refer to our [paper](https://arxiv.org/abs/2501.01108).
+
+
+

+

+
+
+## Usage
+
+To begin with, please use pip to install the official `muq` lib, and ensure that your `python>=3.8`:
+```bash
+pip3 install muq
+```
+
+
+To extract music audio features using **MuQ**, you can refer to the following code:
+```python
+import torch, librosa
+from muq import MuQ
+
+device = 'cuda'
+wav, sr = librosa.load("path/to/music_audio.wav", sr = 24000)
+wavs = torch.tensor(wav).unsqueeze(0).to(device)
+
+# This will automatically fetch the checkpoint from huggingface
+muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
+muq = muq.to(device).eval()
+
+with torch.no_grad():
+ output = muq(wavs, output_hidden_states=True)
+
+print('Total number of layers: ', len(output.hidden_states))
+print('Feature shape: ', output.last_hidden_state.shape)
+
+```
+
+Using **MuQ-MuLan** to extract the music and text embeddings and calculate the similarity:
+```python
+import torch, librosa
+from muq import MuQMuLan
+
+# This will automatically fetch checkpoints from huggingface
+device = 'cuda'
+mulan = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large")
+mulan = mulan.to(device).eval()
+
+# Extract music embeddings
+wav, sr = librosa.load("path/to/music_audio.wav", sr = 24000)
+wavs = torch.tensor(wav).unsqueeze(0).to(device)
+with torch.no_grad():
+ audio_embeds = mulan(wavs = wavs)
+
+# Extract text embeddings (texts can be in English or Chinese)
+texts = ["classical genres, hopeful mood, piano.", "一首适合海边风景的小提琴曲,节奏欢快"]
+with torch.no_grad():
+ text_embeds = mulan(texts = texts)
+
+# Calculate dot product similarity
+sim = mulan.calc_similarity(audio_embeds, text_embeds)
+print(sim)
+```
+
+> Note that both MuQ and MuQ-MuLan strictly require **24 kHz** audio as input.
+> We recommend using **fp32** during MuQ inference to avoid potential NaN issues.
+
+
+## Performance
+
+
+
+
+## Model Checkpoints
+
+| Model Name | Parameters | Data | HuggingFace🤗 |
+| ----------- | --- | --- | ----------- |
+| MuQ | ~300M | MSD dataset | [OpenMuQ/MuQ-large-msd-iter](https://huggingface.co/OpenMuQ/MuQ-large-msd-iter) |
+| MuQ-MuLan | ~700M | music-text pairs | [OpenMuQ/MuQ-MuLan-large](https://huggingface.co/OpenMuQ/MuQ-MuLan-large) |
+
+**Note**: Please note that the open-sourced MuQ was trained on the Million Song Dataset. Due to differences in dataset size, the open-sourced model may not achieve the same level of performance as reported in the paper. The training recipes can be found [here](./src/recipes).
+
+## License
+
+The code in this repository is released under the MIT license as found in the [LICENSE](LICENSE) file.
+
+The model weights (MuQ-large-msd-iter, MuQ-MuLan-large) in this repository are released under the CC-BY-NC 4.0 license, as detailed in the [LICENSE_weights](LICENSE_weights) file.
+
+## Citation
+
+```
+@article{zhu2025muq,
+ title={MuQ: Self-Supervised Music Representation Learning with Mel Residual Vector Quantization},
+ author={Haina Zhu and Yizhi Zhou and Hangting Chen and Jianwei Yu and Ziyang Ma and Rongzhi Gu and Yi Luo and Wei Tan and Xie Chen},
+ journal={arXiv preprint arXiv:2501.01108},
+ year={2025}
+}
+```
+
+## Acknowledgement
+
+We borrow many codes from the following repositories:
+- [lucidrains/musiclm-pytorch](https://github.com/lucidrains/musiclm-pytorch)
+- [minzwon/musicfm](https://github.com/minzwon/musicfm)
+
+
+Also, we are especially grateful to the awesome [MARBLE-Benchmark](https://github.com/a43992899/MARBLE-Benchmark).
diff --git a/src/third_party/MuQ/images/muq-logo.jpeg b/src/third_party/MuQ/images/muq-logo.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..37cd1531747f00a00d001e43946f012053a15146
Binary files /dev/null and b/src/third_party/MuQ/images/muq-logo.jpeg differ
diff --git a/src/third_party/MuQ/images/radar.jpg b/src/third_party/MuQ/images/radar.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..69961a7e8546e39427825c19717800633d7faff8
--- /dev/null
+++ b/src/third_party/MuQ/images/radar.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c128d768e4888aa0bfbef2d3caa47e819b81840b153c2e4265fd40d921c3685
+size 43098
diff --git a/src/third_party/MuQ/images/tab-marble.jpg b/src/third_party/MuQ/images/tab-marble.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5d85812cd86a6776f1cda603bc53fa2f0c9b697d
--- /dev/null
+++ b/src/third_party/MuQ/images/tab-marble.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7287c7741b06062fb5cb57b10149c9138cbb56ad3eabef7e3b957ea32db1639
+size 263577
diff --git a/src/third_party/MuQ/images/tab-mulan.png b/src/third_party/MuQ/images/tab-mulan.png
new file mode 100644
index 0000000000000000000000000000000000000000..38943c8679a1e2a35a4243281b58166010336b41
--- /dev/null
+++ b/src/third_party/MuQ/images/tab-mulan.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f473ebd635d2f0c5e4f3fdf7b0a11b9b99ad000215035b4dd412e0c9f7fa3304
+size 83551
diff --git a/src/third_party/MuQ/images/tagging.jpg b/src/third_party/MuQ/images/tagging.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ffa4c311b055f3641764445b4e87ce6cda134a05
--- /dev/null
+++ b/src/third_party/MuQ/images/tagging.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57717afef5c8341b64d685410311bc1335752508c0d127b6573a226688eb61b0
+size 44360
diff --git a/src/third_party/MuQ/requirements.txt b/src/third_party/MuQ/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..21619ff79263184ec409ffac46ff7967d6ec4160
--- /dev/null
+++ b/src/third_party/MuQ/requirements.txt
@@ -0,0 +1,11 @@
+einops
+librosa
+nnAudio
+numpy
+soundfile
+torch
+torchaudio
+tqdm
+transformers
+easydict
+x_clip
\ No newline at end of file
diff --git a/src/third_party/MuQ/setup.py b/src/third_party/MuQ/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..1006570f10a897f88b59a4e92cdb869c1299a34e
--- /dev/null
+++ b/src/third_party/MuQ/setup.py
@@ -0,0 +1,34 @@
+from setuptools import setup, find_packages
+
+setup(
+ name='muq', # Name of the package
+ version='0.1.0', # Version of the package
+ packages=find_packages(where='src'), # Automatically discover packages under the 'src' directory
+ package_dir={'': 'src'}, # Specify the root directory for packages as 'src'
+ include_package_data=True, # Include additional files, such as static files
+ install_requires=[ # List of dependencies
+ "einops",
+ "librosa",
+ "nnAudio",
+ "numpy",
+ "soundfile",
+ "torch",
+ "torchaudio",
+ "tqdm",
+ "transformers",
+ "easydict",
+ "x_clip",
+ ],
+ author='Haina Zhu', # Author name
+ author_email='juhayna@qq.com', # Author email address
+ description='MuQ: A deep learning model for music and text', # Short description of the package
+ long_description=open('README.md', encoding='utf-8').read(), # Long description from the README file
+ long_description_content_type='text/markdown', # Format of the long description (Markdown)
+ url='https://github.com/tencent-ailab/MuQ', # Project URL
+ classifiers=[
+ 'Programming Language :: Python :: 3', # Python 3 support
+ 'License :: OSI Approved :: MIT License', # License type
+ 'Operating System :: OS Independent', # Supports all operating systems
+ ],
+ python_requires='>=3.8', # Supported Python version
+)
diff --git a/src/third_party/MuQ/src/muq/__init__.py b/src/third_party/MuQ/src/muq/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7575a352fcd9af16c791d2ce5dae3dfe72521b6f
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/__init__.py
@@ -0,0 +1,2 @@
+from .muq import MuQ, MuQConfig
+from .muq_mulan import MuQMuLan, MuQMuLanConfig
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/muq/muq/__init__.py b/src/third_party/MuQ/src/muq/muq/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..40a3cb62132cb3b3628b977c2a13563b78fe59c1
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq/__init__.py
@@ -0,0 +1 @@
+from .muq import MuQConfig, MuQ
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/muq/muq/models/__init__.py b/src/third_party/MuQ/src/muq/muq/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/third_party/MuQ/src/muq/muq/models/muq_model.py b/src/third_party/MuQ/src/muq/muq/models/muq_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd40397376cd25960909619aff68fdc3846d02d9
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq/models/muq_model.py
@@ -0,0 +1,366 @@
+import json
+import random
+import torch
+from torch import nn
+from einops import rearrange
+import os
+from easydict import EasyDict
+
+from ..modules.random_quantizer import RandomProjectionQuantizer
+from ..modules.features import MelSTFT
+from ..modules.conv import Conv2dSubsampling
+
+class MuQModel(nn.Module):
+
+ def __init__(
+ self,
+ num_codebooks=1,
+ codebook_dim=16,
+ codebook_size=4096,
+ features=["melspec_2048"],
+ hop_length=240,
+ n_mels=128,
+ conv_dim=512,
+ encoder_dim=1024,
+ encoder_depth=12,
+ mask_hop=0.4,
+ mask_prob=0.6,
+ is_flash=False,
+ stat=dict(),
+ w2v2_config=dict(),
+ use_rvq_target=False,
+ use_vq_target=False,
+ use_encodec_target=False,
+ rvq_ckpt_path=None,
+ recon_loss_ratio=None,
+ label_rate=25,
+ rvq_n_codebooks=8,
+ rvq_multi_layer_num=1,
+ ):
+ super().__init__()
+
+ # global variables
+ self.hop_length = hop_length
+ self.mask_hop = mask_hop
+ self.mask_prob = mask_prob
+ self.num_codebooks = num_codebooks
+ self.codebook_size = codebook_size
+ self.features = features
+ self.recon_loss_ratio = recon_loss_ratio
+ self.n_fold = int(100//label_rate)
+ self.label_rate = label_rate
+
+ # load feature mean / std stats
+ self.stat = stat
+
+ # feature extractor
+ self.preprocessor_melspec_2048 = MelSTFT(
+ n_fft=2048, hop_length=hop_length, is_db=True
+ )
+
+ # random quantizer
+ self.use_rvq_target = use_rvq_target
+ self.use_vq_target = use_vq_target
+ self.use_encodec_target = use_encodec_target
+
+ seed = 142
+ if self.use_rvq_like_target:
+ if use_rvq_target:
+ from ..modules.rvq import ResidualVectorQuantize
+
+ inp_dim = 128*self.n_fold
+ self.rvq = ResidualVectorQuantize(
+ input_dim = inp_dim,
+ n_codebooks = rvq_n_codebooks,
+ codebook_size = 1024,
+ codebook_dim = 16,
+ quantizer_dropout = 0.0,
+ use_multi_layer_num = rvq_multi_layer_num,
+ )
+ elif use_vq_target:
+ from ..modules.rvq import VectorQuantize
+
+ self.rvq = VectorQuantize(
+ input_dim = 128*self.n_fold,
+ codebook_size = 1024,
+ codebook_dim = 8,
+ stale_tolerance = 1000,
+ mfcc_clustering = False
+ )
+ elif use_encodec_target:
+ from encodec import EncodecModel
+ self.rvq = EncodecModel.encodec_model_24khz()
+ self.rvq.set_target_bandwidth(6.0)
+ for param in self.rvq.parameters():
+ param.requires_grad = False
+
+ if rvq_ckpt_path is not None and os.path.exists(rvq_ckpt_path):
+ state_dict = torch.load(rvq_ckpt_path, map_location="cpu")
+ self.rvq.load_state_dict(state_dict)
+ else:
+ pass
+ # print(f'Checkpoint for rvq `{rvq_ckpt_path}` not found. Using random initialization.')
+ else:
+ for feature in self.features:
+ for i in range(num_codebooks):
+ setattr(
+ self,
+ f"quantizer_{feature}", # _{i}
+ RandomProjectionQuantizer(
+ n_mels * self.n_fold, codebook_dim, codebook_size, seed=seed + i
+ ),
+ )
+
+ # two residual convolution layers + one projection layer
+ strides_factory = {
+ 4: [2, 2],
+ 2: [2, 1]
+ }
+ self.conv = Conv2dSubsampling(
+ 1, conv_dim, encoder_dim, strides=strides_factory.get(self.n_fold), n_bands=n_mels
+ )
+
+ # Conformer
+ if is_flash:
+ from modules.flash_conformer import (
+ Wav2Vec2ConformerEncoder,
+ Wav2Vec2ConformerConfig,
+ )
+ else:
+ from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
+ Wav2Vec2ConformerEncoder,
+ Wav2Vec2ConformerConfig,
+ )
+ config = EasyDict(w2v2_config)
+ config.num_hidden_layers = encoder_depth
+ config.hidden_size = encoder_dim
+
+ self.conformer = Wav2Vec2ConformerEncoder(config)
+
+ self.linear = nn.Linear(encoder_dim, codebook_size) # projection layer
+
+ # reconstruct melspec
+ if self.recon_loss_ratio is not None and self.recon_loss_ratio > 0:
+ self.recon_proj = nn.Linear(encoder_dim, n_mels * self.n_fold)
+ self.recon_loss = nn.MSELoss()
+
+ # loss function
+ self.loss = nn.CrossEntropyLoss()
+
+ # cls token (used for sequence classification)
+ random.seed(seed)
+ self.cls_token = nn.Parameter(torch.randn(encoder_dim))
+
+
+ @property
+ def use_rvq_like_target(self):
+ return self.use_rvq_target or self.use_vq_target or self.use_encodec_target
+
+ def masking(self, x, attention_mask=None):
+ """random masking of 400ms with given probability"""
+ mx = x.clone()
+ b, t = mx.shape
+ len_masking_raw = int(24000 * self.mask_hop)
+ len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)
+
+ # get random mask indices
+ start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
+ time_domain_masked_indices = torch.nonzero(
+ start_indices.repeat_interleave(len_masking_raw, dim=1)
+ )
+ token_domain_masked_indices = torch.nonzero(
+ start_indices.repeat_interleave(len_masking_token, dim=1)
+ )
+
+ # mask with random values
+ masking_noise = (
+ torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
+ ) # 0 mean 0.1 std
+ mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)
+
+ return mx, token_domain_masked_indices
+
+
+ @torch.no_grad()
+ def preprocessing(self, x, features):
+ """extract classic audio features"""
+ # check precision
+ if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
+ precision = 16
+ else:
+ precision = 32
+
+ out = {}
+ for key in features:
+ layer = getattr(self, "preprocessor_%s" % key)
+ layer.to(x.device)
+ dtype = x.dtype
+ out[key] = layer(x.float())[..., :-1]
+ if precision == 16:
+ out[key] = out[key].half()
+ if out[key].dtype != dtype:
+ out[key].to(dtype=dtype)
+ return out
+
+ def encoder(self, x, *, attention_mask=None, is_features_only=False):
+ """2-layer conv + w2v-conformer"""
+ x = self.conv(x)
+ mask_indices = None
+ if attention_mask is None:
+ out = self.conformer(x, output_hidden_states=True)
+ else:
+ attention_mask = attention_mask.bool()
+ skip_n = int(attention_mask.size(-1) / x.size(1))
+ attention_mask = attention_mask[:, ::skip_n]
+ attention_mask = attention_mask[:, :x.size(1)]
+ out = self.conformer(x, attention_mask=attention_mask, output_hidden_states=True)
+ hidden_emb = out["hidden_states"]
+ last_emb = out["last_hidden_state"]
+ logits = self.linear(last_emb)
+ interval = self.codebook_size
+ logits = {
+ key: logits[:, :, i * interval : (i + 1) * interval]
+ for i, key in enumerate(self.features)
+ }
+ return logits, hidden_emb, mask_indices
+
+ @torch.no_grad()
+ def normalize(self, x):
+ """normalize the input audio to have zero mean unit variance"""
+ for key in x.keys():
+ x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
+ return x
+
+ @torch.no_grad()
+ def rearrange(self, x):
+ """rearrange the batch to flatten every 4 steps"""
+ for key in x.keys():
+ if key == "chromagram":
+ x[key] = rearrange(x[key], "b f t -> b t f")
+ else:
+ x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=self.n_fold)
+ return x
+
+ def get_rvq_codes(self, inp, raw_wav):
+ if self.use_rvq_target:
+ quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.rvq(inp)
+ return codes
+ if self.use_vq_target:
+ quantized_prompt_embeds, commitment_loss, codebook_loss, codes, _ = self.rvq(inp)
+ return codes.unsqueeze(1)
+ if self.use_encodec_target:
+ encoded_frames = self.rvq.encode(raw_wav.unsqueeze(1)) #list, B,[ 8,T ]
+ codes = torch.cat([encoded[0].detach() for encoded in encoded_frames], dim=-1)
+ if self.label_rate == 25:
+ codes = codes[:, :, ::3]
+ return codes
+
+ @torch.no_grad()
+ def tokenize(self, x, raw_wav):
+ out = {}
+ for key in x.keys():
+ if self.use_rvq_like_target:
+ self.rvq.eval()
+ inp = x[key].permute((0, 2, 1))
+ codes = self.get_rvq_codes(inp, raw_wav)
+ out[key] = torch.cat([codes[:, idx, ...] for idx in range(int(self.codebook_size//1024))], dim=-1)
+ else:
+ layer = getattr(self, "quantizer_%s" % key)
+ out[key] = layer(x[key])
+ return out
+
+ def get_targets(self, x, label=None):
+ if self.use_encodec_target:
+ raw_x = x.clone()
+ else:
+ raw_x = None
+ x = self.preprocessing(x, features=self.features)
+ x = self.normalize(x)
+ x = self.rearrange(x)
+ melspec = x['melspec_2048']
+ if label is None:
+ # Use labels from Mel-RVQ
+ target_tokens = self.tokenize(x, raw_x)
+ else:
+ # Use labels pre-extracted for iteration training
+ target_tokens = {'melspec_2048': rearrange(label, "b n s -> b (n s)").long()}
+ return target_tokens, melspec
+
+ def get_predictions(self, x, *, mask=None, attention_mask=None, return_new_mask=False, is_features_only=False):
+ # preprocessing
+ x = self.preprocessing(x, features=["melspec_2048"])
+ x = self.normalize(x)
+
+ # encoding
+ logits, hidden_emb, new_mask = self.encoder(x["melspec_2048"], attention_mask=attention_mask, is_features_only=is_features_only)
+
+ if return_new_mask:
+ return logits, hidden_emb, mask if new_mask is None else new_mask
+ else:
+ return logits, hidden_emb
+
+ def get_latent(self, x, layer_ix=12):
+ _, hidden_states = self.get_predictions(x)
+ emb = hidden_states[layer_ix]
+ return emb
+
+ def compute_nce(self, x, pos, negs):
+ neg_is_pos = (pos == negs).all(-1)
+ pos = pos.unsqueeze(0)
+ targets = torch.cat([pos, negs], dim=0)
+
+ logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
+ logits /= 0.1
+ if neg_is_pos.any():
+ logits[1:][neg_is_pos] = float("-inf")
+ logits = logits.transpose(0, 1)
+ return logits
+
+ def get_loss(self, logits, target_tokens, masked_indices):
+ losses = {}
+ accuracies = {}
+ for key in logits.keys():
+ if not self.use_rvq_like_target:
+ masked_logits = logits[key][tuple(masked_indices.t())]
+ masked_tokens = target_tokens[key][tuple(masked_indices.t())]
+ else:
+ Batch, SeqLen, N_Codebook_x_CodebookSize = logits[key].shape
+ Batch, N_Codebook_x_SeqLen = target_tokens[key].shape
+ N_Codebook = int(N_Codebook_x_SeqLen // SeqLen)
+ target_tokens[key] = rearrange(target_tokens[key], "b (n s) -> b s n", n=N_Codebook) # Batch, SeqLen=750, N_Codebook=4
+ masked_logits = logits[key][tuple(masked_indices.t())]
+ masked_tokens = target_tokens[key][tuple(masked_indices.t())]
+ masked_logits = rearrange(masked_logits, "b (n c) -> (b n) c", n=N_Codebook)
+ masked_tokens = rearrange(masked_tokens, "b n -> (b n)", n=N_Codebook)
+
+ losses[key] = self.loss(masked_logits, masked_tokens)
+ accuracies[key] = (
+ torch.sum(masked_logits.argmax(-1) == masked_tokens)
+ / masked_tokens.numel()
+ )
+ return losses, accuracies
+
+ def get_recon_loss(self, last_hidden_emb, melspec, masked_indices):
+ pred_melspec = self.recon_proj(last_hidden_emb[tuple(masked_indices.t())])
+ target_melspec = melspec[tuple(masked_indices.t())]
+ recon_loss = self.recon_loss(pred_melspec, target_melspec)
+ return recon_loss
+
+ def forward(self, x, attention_mask=None, label=None):
+ dtype = x.dtype
+ # get target feature tokens
+ target_tokens, melspec = self.get_targets(x, label=label)
+
+ # masking
+ x, masked_indices = self.masking(x, attention_mask=attention_mask)
+
+ # forward
+ logits, hidden_emb, masked_indices = self.get_predictions(x, mask=masked_indices, attention_mask=attention_mask, return_new_mask=True)
+
+ # get loss
+ losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)
+
+ if self.recon_loss_ratio:
+ losses["recon_loss"] = self.get_recon_loss(hidden_emb[-1], melspec, masked_indices) * self.recon_loss_ratio
+
+ return logits, hidden_emb, losses, accuracies
diff --git a/src/third_party/MuQ/src/muq/muq/modules/__init__.py b/src/third_party/MuQ/src/muq/muq/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq/modules/__init__.py
@@ -0,0 +1,2 @@
+
+
diff --git a/src/third_party/MuQ/src/muq/muq/modules/conv.py b/src/third_party/MuQ/src/muq/muq/modules/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e4695371474a19789b77b3668d01e0f430f11a1
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq/modules/conv.py
@@ -0,0 +1,77 @@
+from torch import nn
+from einops import rearrange
+
+
+class Res2dModule(nn.Module):
+ def __init__(self, idim, odim, stride=(2, 2)):
+ super(Res2dModule, self).__init__()
+ self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
+ self.bn1 = nn.BatchNorm2d(odim)
+ self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
+ self.bn2 = nn.BatchNorm2d(odim)
+ self.relu = nn.ReLU()
+
+ # residual
+ self.diff = False
+ if (idim != odim) or (stride[0] > 1):
+ self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
+ self.bn3 = nn.BatchNorm2d(odim)
+ self.diff = True
+
+ def forward(self, x):
+ out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
+ if self.diff:
+ x = self.bn3(self.conv3(x))
+ out = x + out
+ out = self.relu(out)
+ return out
+
+
+class Conv2dSubsampling(nn.Module):
+ """Convolutional 2D subsampling (to 1/4 length).
+
+ Args:
+ idim (int): Input dimension.
+ hdim (int): Hidden dimension.
+ odim (int): Output dimension.
+ strides (list): Sizes of strides.
+ n_bands (int): Number of frequency bands.
+ """
+
+ def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64):
+ """Construct an Conv2dSubsampling object."""
+ super(Conv2dSubsampling, self).__init__()
+
+ self.conv = nn.Sequential(
+ Res2dModule(idim, hdim, (2, strides[0])),
+ Res2dModule(hdim, hdim, (2, strides[1])),
+ )
+ self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)
+
+ def forward(self, x):
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, idim, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 4.
+ """
+
+ if x.dim() == 3:
+ x = x.unsqueeze(1) # (b, c, f, t)
+ x = self.conv(x)
+ x = rearrange(x, "b c f t -> b t (c f)")
+ x = self.linear(x)
+ return x
+
+if __name__ == '__main__':
+ import torch
+ conv_dim, encoder_dim = 512, 1024
+ conv = Conv2dSubsampling(
+ 1, conv_dim, encoder_dim, strides=[2, 1], n_bands=128
+ )
+ inp = torch.randn((1, 128, 3000))
+ out = conv(inp)
+ print(out.shape)
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/muq/muq/modules/features.py b/src/third_party/MuQ/src/muq/muq/modules/features.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c206ae569b5b0d2d770ae4ab533ed45add81eb8
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq/modules/features.py
@@ -0,0 +1,37 @@
+import torchaudio
+from torch import nn
+import torch
+
+
+class MelSTFT:
+ def __init__(
+ self,
+ sample_rate=24000,
+ n_fft=2048,
+ hop_length=240,
+ n_mels=128,
+ is_db=False,
+ ):
+ super(MelSTFT, self).__init__()
+
+ # spectrogram
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
+ sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
+ )
+
+ # amplitude to decibel
+ self.is_db = is_db
+ if is_db:
+ self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
+
+ def __call__(self, waveform):
+ if self.is_db:
+ return self.amplitude_to_db(self.mel_stft(waveform))
+ else:
+ return self.mel_stft(waveform)
+
+ def to(self, device):
+ self.mel_stft = self.mel_stft.to(device)
+ if self.is_db:
+ self.amplitude_to_db = self.amplitude_to_db.to(device)
+ return self
diff --git a/src/third_party/MuQ/src/muq/muq/modules/flash_conformer.py b/src/third_party/MuQ/src/muq/muq/modules/flash_conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..89012c476c27973748e2cee914dae3b400348465
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq/modules/flash_conformer.py
@@ -0,0 +1,2114 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Wav2Vec2-Conformer model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from torch.nn import functional as F
+
+from transformers.activations import ACT2FN
+from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+# General docstring
+_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 64.21
+
+
+WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "facebook/wav2vec2-conformer-rel-pos-large",
+ # See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
+]
+
+
+@dataclass
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
+ """
+ Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
+
+ Args:
+ loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
+ paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
+ projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
+ projected quantized states.
+ projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
+ target vectors for contrastive loss.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
+ diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ projected_states: torch.FloatTensor = None
+ projected_quantized_states: torch.FloatTensor = None
+ codevector_perplexity: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ contrastive_loss: Optional[torch.FloatTensor] = None
+ diversity_loss: Optional[torch.FloatTensor] = None
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+ f" and `sequence_length`: {sequence_length}`"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.sum(-1).detach().tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+ # this case can only happen if `input_length` is strictly smaller then
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
+def _sample_negative_indices(
+ features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
+):
+ """
+ Sample `num_negatives` vectors from feature vectors.
+ """
+ batch_size, sequence_length = features_shape
+
+ # generate indices of the positive vectors themselves, repeat them `num_negatives` times
+ sequence_length_range = np.arange(sequence_length)
+
+ # get `num_negatives` random vector indices from the same utterance
+ sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
+
+ mask_time_indices = (
+ mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
+ )
+
+ for batch_idx in range(batch_size):
+ high = mask_time_indices[batch_idx].sum() - 1
+ mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
+
+ feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
+ sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
+ # avoid sampling the same positive vector, but keep the distribution uniform
+ sampled_indices[sampled_indices >= feature_indices] += 1
+
+ # remap to actual indices
+ sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
+
+ # correct for batch size
+ sampled_negative_indices[batch_idx] += batch_idx * sequence_length
+
+ return sampled_negative_indices
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ hidden_states = hidden_states.transpose(-2, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(-2, -1)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.num_conv_pos_embeddings,
+ padding=config.num_conv_pos_embeddings // 2,
+ groups=config.num_conv_pos_embedding_groups,
+ )
+
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+ deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
+ deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
+ else:
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+
+ self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.transpose(1, 2)
+
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.padding(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
+ """Rotary positional embedding
+ Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ dim = config.hidden_size // config.num_attention_heads
+ base = config.rotary_embedding_base
+
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self.cached_sequence_length = None
+ self.cached_rotary_positional_embedding = None
+
+ def forward(self, hidden_states):
+ sequence_length = hidden_states.shape[1]
+
+ if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
+ return self.cached_rotary_positional_embedding
+
+ self.cached_sequence_length = sequence_length
+ time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
+ freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
+ embeddings = torch.cat((freqs, freqs), dim=-1)
+
+ cos_embeddings = embeddings.cos()[:, None, None, :]
+ sin_embeddings = embeddings.sin()[:, None, None, :]
+ self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
+ return self.cached_rotary_positional_embedding
+
+
+class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
+ """Relative positional encoding module."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.max_len = config.max_source_positions
+ self.d_model = config.hidden_size
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+
+ def extend_pe(self, x):
+ # Reset the positional encodings
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` is the position of query vector and `j` is the
+ # position of key vector. We use positive relative positions when keys
+ # are to the left (i>j) and negative relative positions otherwise (iWav2Vec2Conformer
+class Wav2Vec2ConformerSamePadLayer(nn.Module):
+ def __init__(self, num_conv_pos_embeddings):
+ super().__init__()
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
+
+ def forward(self, hidden_states):
+ if self.num_pad_remove > 0:
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeatureEncoder(nn.Module):
+ """Construct the features from raw audio waveform"""
+
+ def __init__(self, config):
+ super().__init__()
+
+ if config.feat_extract_norm == "group":
+ conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
+ Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
+ for i in range(config.num_feat_extract_layers - 1)
+ ]
+ elif config.feat_extract_norm == "layer":
+ conv_layers = [
+ Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
+ ]
+ else:
+ raise ValueError(
+ f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
+ )
+ self.conv_layers = nn.ModuleList(conv_layers)
+ self.gradient_checkpointing = False
+ self._requires_grad = True
+
+ def _freeze_parameters(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self._requires_grad = False
+
+ def forward(self, input_values):
+ hidden_states = input_values[:, None]
+
+ # make sure hidden_states require grad for gradient_checkpointing
+ if self._requires_grad and self.training:
+ hidden_states.requires_grad = True
+
+ for conv_layer in self.conv_layers:
+ if self._requires_grad and self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(conv_layer),
+ hidden_states,
+ )
+ else:
+ hidden_states = conv_layer(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeatureProjection(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+ self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+ def forward(self, hidden_states):
+ # non-projected hidden states are needed for quantization
+ norm_hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.projection(norm_hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states, norm_hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeedForward(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.intermediate_dropout = nn.Dropout(config.activation_dropout)
+
+ self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.output_dropout = nn.Dropout(config.hidden_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.intermediate_dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ hidden_states = self.intermediate_dropout(hidden_states)
+
+ hidden_states = self.output_dense(hidden_states)
+ hidden_states = self.output_dropout(hidden_states)
+ return hidden_states
+
+
+class Wav2Vec2ConformerConvolutionModule(nn.Module):
+ """Convolution block used in the conformer block"""
+
+ def __init__(self, config):
+ super().__init__()
+ if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
+ raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
+ self.pointwise_conv1 = torch.nn.Conv1d(
+ config.hidden_size,
+ 2 * config.hidden_size,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.glu = torch.nn.GLU(dim=1)
+ self.depthwise_conv = torch.nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ config.conv_depthwise_kernel_size,
+ stride=1,
+ padding=(config.conv_depthwise_kernel_size - 1) // 2,
+ groups=config.hidden_size,
+ bias=False,
+ )
+ self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
+ self.activation = ACT2FN[config.hidden_act]
+ self.pointwise_conv2 = torch.nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.layer_norm(hidden_states)
+ # exchange the temporal dimension and the feature dimension
+ hidden_states = hidden_states.transpose(1, 2)
+
+ # GLU mechanism
+ # => (batch, 2*channel, dim)
+ hidden_states = self.pointwise_conv1(hidden_states)
+ # => (batch, channel, dim)
+ hidden_states = self.glu(hidden_states)
+
+ # 1D Depthwise Conv
+ hidden_states = self.depthwise_conv(hidden_states)
+ hidden_states = self.batch_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = self.pointwise_conv2(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class Wav2Vec2ConformerSelfAttention(nn.Module):
+ """Construct an Wav2Vec2ConformerSelfAttention object.
+ Can be enhanced with rotary or relative position embeddings.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.head_size = config.hidden_size // config.num_attention_heads
+ self.num_heads = config.num_attention_heads
+ self.position_embeddings_type = config.position_embeddings_type
+
+ self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
+
+ self.dropout = nn.Dropout(p=config.attention_dropout)
+ self.dropout_p = config.attention_dropout
+
+ self.is_causal = config.is_causal
+
+ if self.position_embeddings_type == "relative":
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+ self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ relative_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # self-attention mechanism
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ # make sure query/key states can be != value states
+ query_key_states = hidden_states
+ value_states = hidden_states
+
+ if self.position_embeddings_type == "rotary":
+ if relative_position_embeddings is None:
+ raise ValueError(
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
+ )
+ query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
+
+ # project query_key_states and value_states
+ query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+ key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+ value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
+
+ # => (batch, head, time1, d_k)
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=self.dropout_p, is_causal=self.is_causal)
+ probs = None
+
+ # # apply attention_mask if necessary
+ # if attention_mask is not None:
+ # scores = scores + attention_mask
+
+ # # => (batch, head, time1, time2)
+ # probs = torch.softmax(scores, dim=-1)
+ # probs = self.dropout(probs)
+
+ # # => (batch, head, time1, d_k)
+ # hidden_states = torch.matmul(probs, value)
+
+ # => (batch, time1, hidden_size)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
+ hidden_states = self.linear_out(hidden_states)
+
+ return hidden_states, probs
+
+ def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
+
+ cos = relative_position_embeddings[0, :sequence_length, ...]
+ sin = relative_position_embeddings[1, :sequence_length, ...]
+
+ # rotate hidden_states with rotary embeddings
+ hidden_states = hidden_states.transpose(0, 1)
+ rotated_states_begin = hidden_states[..., : self.head_size // 2]
+ rotated_states_end = hidden_states[..., self.head_size // 2 :]
+ rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
+ hidden_states = (hidden_states * cos) + (rotated_states * sin)
+ hidden_states = hidden_states.transpose(0, 1)
+
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
+
+ return hidden_states
+
+ def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
+ # 1. project positional embeddings
+ # => (batch, head, 2*time1-1, d_k)
+ proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
+ proj_relative_position_embeddings = proj_relative_position_embeddings.view(
+ relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
+ )
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
+
+ # 2. Add bias to query
+ # => (batch, head, time1, d_k)
+ query = query.transpose(1, 2)
+ q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+ # 3. attention score: first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # => (batch, head, time1, time2)
+ scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+ # 4. then compute matrix b and matrix d
+ # => (batch, head, time1, 2*time1-1)
+ scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
+
+ # 5. shift matrix b and matrix d
+ zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
+ scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
+ scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
+ scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
+ scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
+ scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
+
+ # 6. sum matrices
+ # => (batch, head, time1, time2)
+ scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
+
+ return scores
+
+
+class Wav2Vec2ConformerEncoderLayer(nn.Module):
+ """Conformer block based on https://arxiv.org/abs/2005.08100."""
+
+ def __init__(self, config):
+ super().__init__()
+ embed_dim = config.hidden_size
+ dropout = config.attention_dropout
+
+ # Feed-forward 1
+ self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
+ self.ffn1 = Wav2Vec2ConformerFeedForward(config)
+
+ # Self-Attention
+ self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
+ self.self_attn_dropout = torch.nn.Dropout(dropout)
+ self.self_attn = Wav2Vec2ConformerSelfAttention(config)
+
+ # Conformer Convolution
+ self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
+
+ # Feed-forward 2
+ self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
+ self.ffn2 = Wav2Vec2ConformerFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask: Optional[torch.Tensor] = None,
+ relative_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ):
+ hidden_states = hidden_states
+
+ # 1. Feed-Forward 1 layer
+ residual = hidden_states
+ hidden_states = self.ffn1_layer_norm(hidden_states)
+ hidden_states = self.ffn1(hidden_states)
+ hidden_states = hidden_states * 0.5 + residual
+ residual = hidden_states
+
+ # 2. Self-Attention layer
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weigts = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.self_attn_dropout(hidden_states)
+ hidden_states = hidden_states + residual
+
+ # 3. Convolutional Layer
+ residual = hidden_states
+ hidden_states = self.conv_module(hidden_states)
+ hidden_states = residual + hidden_states
+
+ # 4. Feed-Forward 2 Layer
+ residual = hidden_states
+ hidden_states = self.ffn2_layer_norm(hidden_states)
+ hidden_states = self.ffn2(hidden_states)
+ hidden_states = hidden_states * 0.5 + residual
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ return hidden_states, attn_weigts
+
+
+class Wav2Vec2ConformerEncoder(nn.Module):
+ def __init__(self, config, is_causal=False):
+ super().__init__()
+ config.is_causal = is_causal
+ self.config = config
+
+ if config.position_embeddings_type == "relative":
+ self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
+ elif config.position_embeddings_type == "rotary":
+ self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
+ else:
+ self.embed_positions = None
+
+ self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if attention_mask is not None:
+ # make sure padded tokens output 0
+ hidden_states[~attention_mask] = 0.0
+
+ # extend attention_mask
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
+ attention_mask = attention_mask.expand(
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+ )
+
+ hidden_states = self.dropout(hidden_states)
+
+ if self.embed_positions is not None:
+ relative_position_embeddings = self.embed_positions(hidden_states)
+ else:
+ relative_position_embeddings = None
+
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = np.random.uniform(0, 1)
+
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
+ # under deepspeed zero3 all gpus must run in sync
+ if self.gradient_checkpointing and self.training:
+ # create gradient checkpointing function
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer),
+ hidden_states,
+ attention_mask,
+ relative_position_embeddings,
+ )
+ else:
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
+ """
+ Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
+ GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.num_groups = config.num_codevector_groups
+ self.num_vars = config.num_codevectors_per_group
+
+ if config.codevector_dim % self.num_groups != 0:
+ raise ValueError(
+ f"`config.codevector_dim {config.codevector_dim} must be divisible "
+ f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
+ )
+
+ # storage for codebook variables (codewords)
+ self.codevectors = nn.Parameter(
+ torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
+ )
+ self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
+
+ # can be decayed for training
+ self.temperature = 2
+
+ @staticmethod
+ def _compute_perplexity(probs, mask=None):
+ if mask is not None:
+ mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
+ probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
+ marginal_probs = probs.sum(dim=0) / mask.sum()
+ else:
+ marginal_probs = probs.mean(dim=0)
+
+ perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
+ return perplexity
+
+ def forward(self, hidden_states, mask_time_indices=None):
+ batch_size, sequence_length, hidden_size = hidden_states.shape
+
+ # project to codevector dim
+ hidden_states = self.weight_proj(hidden_states)
+ hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
+
+ if self.training:
+ # sample code vector probs via gumbel in differentiateable way
+ codevector_probs = nn.functional.gumbel_softmax(
+ hidden_states.float(), tau=self.temperature, hard=True
+ ).type_as(hidden_states)
+
+ # compute perplexity
+ codevector_soft_dist = torch.softmax(
+ hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
+ )
+ perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
+ else:
+ # take argmax in non-differentiable way
+ # comptute hard codevector distribution (one hot)
+ codevector_idx = hidden_states.argmax(dim=-1)
+ codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
+ -1, codevector_idx.view(-1, 1), 1.0
+ )
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
+
+ perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
+
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
+ # use probs to retrieve codevectors
+ codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
+ codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
+ codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
+
+ return codevectors, perplexity
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerAdapter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ # feature dim might need to be down-projected
+ if config.output_hidden_size != config.hidden_size:
+ self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
+ self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
+ else:
+ self.proj = self.proj_layer_norm = None
+
+ self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
+ self.layerdrop = config.layerdrop
+
+ def forward(self, hidden_states):
+ # down project hidden_states if necessary
+ if self.proj is not None and self.proj_layer_norm is not None:
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.proj_layer_norm(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+
+ for layer in self.layers:
+ layerdrop_prob = np.random.random()
+ if not self.training or (layerdrop_prob > self.layerdrop):
+ hidden_states = layer(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerAdapterLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.output_hidden_size,
+ 2 * config.output_hidden_size,
+ config.adapter_kernel_size,
+ stride=config.adapter_stride,
+ padding=1,
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
+
+ return hidden_states
+
+
+class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = Wav2Vec2ConformerConfig
+ base_model_prefix = "wav2vec2_conformer"
+ main_input_name = "input_values"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
+ if isinstance(module, Wav2Vec2ConformerForPreTraining):
+ module.project_hid.reset_parameters()
+ module.project_q.reset_parameters()
+ module.project_hid._is_hf_initialized = True
+ module.project_q._is_hf_initialized = True
+ # gumbel softmax requires special init
+ elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
+ module.weight_proj.weight.data.normal_(mean=0.0, std=1)
+ module.weight_proj.bias.data.zero_()
+ nn.init.uniform_(module.codevectors)
+ elif isinstance(module, Wav2Vec2ConformerSelfAttention):
+ if hasattr(module, "pos_bias_u"):
+ nn.init.xavier_uniform_(module.pos_bias_u)
+ if hasattr(module, "pos_bias_v"):
+ nn.init.xavier_uniform_(module.pos_bias_v)
+ elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
+ nn.init.normal_(
+ module.conv.weight,
+ mean=0,
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+ )
+ nn.init.constant_(module.conv.bias, 0)
+ elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+
+ def _get_feat_extract_output_lengths(
+ self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
+ ):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+ if add_adapter:
+ for _ in range(self.config.num_adapter_layers):
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(
+ self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
+ ):
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
+ # on inference mode.
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+ output_lengths = output_lengths.to(torch.long)
+
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ # these two operations makes sure that all values before the output lengths idxs are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
+ module.gradient_checkpointing = value
+
+
+WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
+ Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
+ Auli.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving etc.).
+
+ This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
+ regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
+
+ Parameters:
+ config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+ soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+ 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+
+
+ `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
+ True`. For all models whose processor has `config.return_attention_mask == False`, such as
+ [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
+ `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
+ such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
+ that these models also yield slightly different results depending on whether `input_values` is padded or
+ not.
+
+
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
+ def __init__(self, config: Wav2Vec2ConformerConfig):
+ super().__init__(config)
+ self.config = config
+ self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
+ self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
+
+ # model only needs masking vector if mask prob is > 0.0
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
+
+ self.encoder = Wav2Vec2ConformerEncoder(config)
+
+ self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.feature_extractor._freeze_parameters()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
+ def _mask_hidden_states(
+ self,
+ hidden_states: torch.FloatTensor,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ """
+ Masks extracted features along time axis and/or along feature axis according to
+ [SpecAugment](https://arxiv.org/abs/1904.08779).
+ """
+
+ # `config.apply_spec_augment` can set masking to False
+ if not getattr(self.config, "apply_spec_augment", True):
+ return hidden_states
+
+ # generate indices & apply SpecAugment along time axis
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ if mask_time_indices is not None:
+ # apply SpecAugment along time axis with given mask_time_indices
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+ elif self.config.mask_time_prob > 0 and self.training:
+ mask_time_indices = _compute_mask_indices(
+ (batch_size, sequence_length),
+ mask_prob=self.config.mask_time_prob,
+ mask_length=self.config.mask_time_length,
+ attention_mask=attention_mask,
+ min_masks=self.config.mask_time_min_masks,
+ )
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+
+ if self.config.mask_feature_prob > 0 and self.training:
+ # generate indices & apply SpecAugment along feature axis
+ mask_feature_indices = _compute_mask_indices(
+ (batch_size, hidden_size),
+ mask_prob=self.config.mask_feature_prob,
+ mask_length=self.config.mask_feature_length,
+ min_masks=self.config.mask_feature_min_masks,
+ )
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
+ hidden_states[mask_feature_indices] = 0
+
+ return hidden_states
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=Wav2Vec2BaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(
+ extract_features.shape[1], attention_mask, add_adapter=False
+ )
+
+ hidden_states, extract_features = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if self.adapter is not None:
+ hidden_states = self.adapter(hidden_states)
+
+ if not return_dict:
+ return (hidden_states, extract_features) + encoder_outputs[1:]
+
+ return Wav2Vec2BaseModelOutput(
+ last_hidden_state=hidden_states,
+ extract_features=extract_features,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
+)
+class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def __init__(self, config: Wav2Vec2ConformerConfig):
+ super().__init__(config)
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
+
+ self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
+
+ self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
+ self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
+ def set_gumbel_temperature(self, temperature: int):
+ """
+ Set the Gumbel softmax temperature to a given value. Only necessary for training
+ """
+ self.quantizer.temperature = temperature
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ @staticmethod
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
+ def compute_contrastive_logits(
+ target_features: torch.FloatTensor,
+ negative_features: torch.FloatTensor,
+ predicted_features: torch.FloatTensor,
+ temperature: int = 0.1,
+ ):
+ """
+ Compute logits for contrastive loss based using cosine similarity as the distance measure between
+ `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
+ """
+ target_features = torch.cat([target_features, negative_features], dim=0)
+
+ logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
+ target_features
+ )
+
+ # apply temperature
+ logits = logits / temperature
+ return logits
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.BoolTensor] = None,
+ sampled_negative_indices: Optional[torch.BoolTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
+ r"""
+ mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
+ masked extracted features in *config.proj_codevector_dim* space.
+ sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
+ Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
+ Required input for pre-training.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
+ >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
+ ... _compute_mask_indices,
+ ... _sample_negative_indices,
+ ... )
+ >>> from datasets import load_dataset
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+ >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
+
+ >>> # compute masked indices
+ >>> batch_size, raw_sequence_length = input_values.shape
+ >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
+ >>> mask_time_indices = _compute_mask_indices(
+ ... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
+ ... )
+ >>> sampled_negative_indices = _sample_negative_indices(
+ ... features_shape=(batch_size, sequence_length),
+ ... num_negatives=model.config.num_negatives,
+ ... mask_time_indices=mask_time_indices,
+ ... )
+ >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
+ >>> sampled_negative_indices = torch.tensor(
+ ... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
+ ... )
+
+ >>> with torch.no_grad():
+ ... outputs = model(input_values, mask_time_indices=mask_time_indices)
+
+ >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
+ >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
+
+ >>> # show that cosine similarity is much higher than random
+ >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
+ tensor(True)
+
+ >>> # for contrastive loss training model should be put into train mode
+ >>> model = model.train()
+ >>> loss = model(
+ ... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
+ ... ).loss
+ ```"""
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if mask_time_indices is not None:
+ mask_time_indices = mask_time_indices.to(torch.bool)
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ mask_time_indices=mask_time_indices,
+ return_dict=return_dict,
+ )
+
+ # 1. project all transformed features (including masked) to final vq dim
+ transformer_features = self.project_hid(outputs[0])
+
+ # 2. quantize all (unmasked) extracted features and project to final vq dim
+ extract_features = self.dropout_features(outputs[1])
+
+ if attention_mask is not None:
+ # compute reduced attention_mask correponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(
+ extract_features.shape[1], attention_mask, add_adapter=False
+ )
+
+ quantized_features, codevector_perplexity = self.quantizer(
+ extract_features, mask_time_indices=mask_time_indices
+ )
+ quantized_features = self.project_q(quantized_features)
+
+ loss = contrastive_loss = diversity_loss = None
+ if sampled_negative_indices is not None:
+ batch_size, sequence_length, hidden_size = quantized_features.shape
+
+ # for training, we sample negatives
+ # 3. sample K negatives (distractors) quantized states for contrastive loss
+ # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
+ # sample negative quantized vectors BTC => (BxT)C
+ negative_quantized_features = quantized_features.view(-1, hidden_size)[
+ sampled_negative_indices.long().view(-1)
+ ]
+ negative_quantized_features = negative_quantized_features.view(
+ batch_size, sequence_length, -1, hidden_size
+ ).permute(2, 0, 1, 3)
+
+ # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
+ # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
+ logits = self.compute_contrastive_logits(
+ quantized_features[None, :],
+ negative_quantized_features,
+ transformer_features,
+ self.config.contrastive_logits_temperature,
+ )
+
+ # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
+ # its cosine similarity will be masked
+ neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
+
+ if neg_is_pos.any():
+ logits[1:][neg_is_pos] = float("-inf")
+
+ # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
+ # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
+ logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
+ target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
+
+ contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
+ # 7. compute diversity loss: \mathbf{L}_d
+ num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
+ diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
+
+ # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
+ loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
+
+ if not return_dict:
+ if loss is not None:
+ return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+ return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+
+ return Wav2Vec2ConformerForPreTrainingOutput(
+ loss=loss,
+ projected_states=transformer_features,
+ projected_quantized_states=quantized_features,
+ codevector_perplexity=codevector_perplexity,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ contrastive_loss=contrastive_loss,
+ diversity_loss=diversity_loss,
+ )
+
+
+@add_start_docstrings(
+ """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ self.dropout = nn.Dropout(config.final_dropout)
+
+ if config.vocab_size is None:
+ raise ValueError(
+ f"You are trying to instantiate {self.__class__} with a configuration that "
+ "does not define the vocabulary size of the language model head. Please "
+ "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
+ "or define `vocab_size` of your model's configuration."
+ )
+ output_hidden_size = (
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+ )
+ self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CTC_EXPECTED_OUTPUT,
+ expected_loss=_CTC_EXPECTED_LOSS,
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, CausalLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+ config.vocab_size - 1]`.
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.dropout(hidden_states)
+
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ if labels.max() >= self.config.vocab_size:
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+ # retrieve loss input_lengths from attention_mask
+ attention_mask = (
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+ )
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
+ # assuming that padded tokens are filled with -100
+ # when not being attended to
+ labels_mask = labels >= 0
+ target_lengths = labels_mask.sum(-1)
+ flattened_targets = labels.masked_select(labels_mask)
+
+ # ctc_loss doesn't support fp16
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = nn.functional.ctc_loss(
+ log_probs,
+ flattened_targets,
+ input_lengths,
+ target_lengths,
+ blank=self.config.pad_token_id,
+ reduction=self.config.ctc_loss_reduction,
+ zero_infinity=self.config.ctc_zero_infinity,
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
+ tasks like SUPERB Keyword Spotting.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def __init__(self, config):
+ super().__init__(config)
+
+ if hasattr(config, "add_adapter") and config.add_adapter:
+ raise ValueError(
+ "Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
+ )
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+ self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_conformer.parameters():
+ param.requires_grad = False
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+ if attention_mask is None:
+ pooled_output = hidden_states.mean(dim=1)
+ else:
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+ hidden_states[~padding_mask] = 0.0
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+ def __init__(self, config):
+ super().__init__(config)
+
+ if hasattr(config, "add_adapter") and config.add_adapter:
+ raise ValueError(
+ "Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
+ )
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ self.num_labels = config.num_labels
+
+ self.init_weights()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_conformer.parameters():
+ param.requires_grad = False
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
+class AMSoftmaxLoss(nn.Module):
+ def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
+ super(AMSoftmaxLoss, self).__init__()
+ self.scale = scale
+ self.margin = margin
+ self.num_labels = num_labels
+ self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
+ self.loss = nn.CrossEntropyLoss()
+
+ def forward(self, hidden_states, labels):
+ labels = labels.flatten()
+ weight = nn.functional.normalize(self.weight, dim=0)
+ hidden_states = nn.functional.normalize(hidden_states, dim=1)
+ cos_theta = torch.mm(hidden_states, weight)
+ psi = cos_theta - self.margin
+
+ onehot = nn.functional.one_hot(labels, self.num_labels)
+ logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
+ loss = self.loss(logits, labels)
+
+ return loss
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
+class TDNNLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
+ self.out_conv_dim = config.tdnn_dim[layer_id]
+ self.kernel_size = config.tdnn_kernel[layer_id]
+ self.dilation = config.tdnn_dilation[layer_id]
+
+ self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
+ self.activation = nn.ReLU()
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.unsqueeze(1)
+ hidden_states = nn.functional.unfold(
+ hidden_states,
+ (self.kernel_size, self.in_conv_dim),
+ stride=(1, self.in_conv_dim),
+ dilation=(self.dilation, 1),
+ )
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.kernel(hidden_states)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
+
+ tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
+ self.tdnn = nn.ModuleList(tdnn_layers)
+
+ self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
+ self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
+
+ self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
+
+ self.init_weights()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_conformer.parameters():
+ param.requires_grad = False
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
+ def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+ """
+ Computes the output length of the TDNN layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return (input_length - kernel_size) // stride + 1
+
+ for kernel_size in self.config.tdnn_kernel:
+ input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
+
+ return input_lengths
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=XVectorOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, XVectorOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+
+ for tdnn_layer in self.tdnn:
+ hidden_states = tdnn_layer(hidden_states)
+
+ # Statistic Pooling
+ if attention_mask is None:
+ mean_features = hidden_states.mean(dim=1)
+ std_features = hidden_states.std(dim=1)
+ else:
+ feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
+ tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
+ mean_features = []
+ std_features = []
+ for i, length in enumerate(tdnn_output_lengths):
+ mean_features.append(hidden_states[i, :length].mean(dim=0))
+ std_features.append(hidden_states[i, :length].std(dim=0))
+ mean_features = torch.stack(mean_features)
+ std_features = torch.stack(std_features)
+ statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
+
+ output_embeddings = self.feature_extractor(statistic_pooling)
+ logits = self.classifier(output_embeddings)
+
+ loss = None
+ if labels is not None:
+ loss = self.objective(logits, labels)
+
+ if not return_dict:
+ output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return XVectorOutput(
+ loss=loss,
+ logits=logits,
+ embeddings=output_embeddings,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/src/third_party/MuQ/src/muq/muq/modules/random_quantizer.py b/src/third_party/MuQ/src/muq/muq/modules/random_quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c40a74b7614b448332de0d671e617212aded7b1f
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq/modules/random_quantizer.py
@@ -0,0 +1,68 @@
+import torch
+from torch import nn, einsum
+from einops import rearrange
+
+
+class RandomProjectionQuantizer(nn.Module):
+ """
+ Random projection and codebook lookup module
+
+ Some code is borrowed from:
+ https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
+ But I did normalization using pre-computed global mean & variance instead of using layer norm.
+ """
+
+ def __init__(
+ self,
+ input_dim,
+ codebook_dim,
+ codebook_size,
+ seed=142,
+ ):
+ super().__init__()
+
+ # random seed
+ torch.manual_seed(seed)
+
+ # randomly initialized projection
+ random_projection = torch.empty(input_dim, codebook_dim)
+ nn.init.xavier_normal_(random_projection)
+ self.register_buffer("random_projection", random_projection)
+
+ # randomly initialized codebook
+ codebook = torch.empty(codebook_size, codebook_dim)
+ nn.init.normal_(codebook)
+ self.register_buffer("codebook", codebook)
+
+ def codebook_lookup(self, x):
+ # reshape
+ b = x.shape[0]
+ x = rearrange(x, "b n e -> (b n) e")
+
+ # L2 normalization
+ normalized_x = nn.functional.normalize(x, dim=1, p=2)
+ normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)
+
+ # compute distances
+ distances = torch.cdist(normalized_codebook, normalized_x)
+
+ # get nearest
+ nearest_indices = torch.argmin(distances, dim=0)
+
+ # reshape
+ xq = rearrange(nearest_indices, "(b n) -> b n", b=b)
+
+ return xq
+
+ @torch.no_grad()
+ def forward(self, x):
+ # always eval
+ self.eval()
+
+ # random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
+ x = einsum("b n d, d e -> b n e", x, self.random_projection)
+
+ # codebook lookup
+ xq = self.codebook_lookup(x)
+
+ return xq
diff --git a/src/third_party/MuQ/src/muq/muq/modules/rvq.py b/src/third_party/MuQ/src/muq/muq/modules/rvq.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7553cfab1b3233f1db91aa2853224c7320711cb
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq/modules/rvq.py
@@ -0,0 +1,314 @@
+
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+try:
+ from torch.nn.utils import weight_norm
+except:
+ try:
+ from torch.nn.utils.parametrizations import weight_norm
+ except:
+ from torch.nn.utils.parametrize import weight_norm
+
+def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+class VectorQuantize(nn.Module):
+ """
+ Implementation of VQ similar to Karpathy's repo:
+ https://github.com/karpathy/deep-vector-quantization
+ Additionally uses following tricks from Improved VQGAN
+ (https://arxiv.org/pdf/2110.04627.pdf):
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
+ for improved codebook usage
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
+ improves training stability
+ """
+
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 1000, mfcc_clustering=False, n_layer=1):
+ super().__init__()
+ self.codebook_size = codebook_size
+ self.codebook_dim = codebook_dim
+ self.mfcc_clustering = mfcc_clustering
+
+ ProjClass = nn.Identity if mfcc_clustering else WNConv1d
+ if n_layer==1:
+ self.in_proj = ProjClass(input_dim, codebook_dim, kernel_size=1)
+ self.out_proj = ProjClass(codebook_dim, input_dim, kernel_size=1)
+ elif n_layer >= 2:
+ ndim_hidden = 128
+ self.in_proj = nn.Sequential(
+ ProjClass(input_dim, ndim_hidden, kernel_size=1),
+ *[nn.Sequential(nn.ReLU(), ProjClass(ndim_hidden, ndim_hidden, kernel_size=1),) for _ in range(n_layer-2)],
+ nn.ReLU(),
+ ProjClass(ndim_hidden, codebook_dim, kernel_size=1)
+ )
+ self.out_proj = nn.Sequential(
+ ProjClass(codebook_dim, ndim_hidden, kernel_size=1),
+ nn.ReLU(),
+ *[nn.Sequential(ProjClass(ndim_hidden, ndim_hidden, kernel_size=1), nn.ReLU()) for _ in range(n_layer-2)],
+ ProjClass(ndim_hidden, input_dim, kernel_size=1),
+ )
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
+ self.register_buffer("stale_counter", torch.zeros(self.codebook_size,))
+ self.stale_tolerance = stale_tolerance
+
+ def forward(self, z):
+ """Quantized the input tensor using a fixed codebook and returns
+ the corresponding codebook vectors
+
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized continuous representation of input
+ Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ Tensor[1]
+ Codebook loss to update the codebook
+ Tensor[B x T]
+ Codebook indices (quantized discrete representation of input)
+ Tensor[B x D x T]
+ Projected latents (continuous representation of input before quantization)
+ """
+
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
+
+ z_e = self.in_proj(z) # z_e : (B x D x T)
+ z_q, indices = self.decode_latents(z_e)
+
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+
+ z_q = (
+ z_e + (z_q - z_e).detach()
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
+
+ z_q = self.out_proj(z_q)
+
+ return z_q, commitment_loss, codebook_loss, indices, z_e
+
+ def embed_code(self, embed_id):
+ return F.embedding(embed_id, self.codebook.weight)
+
+ def decode_code(self, embed_id):
+ return self.embed_code(embed_id).transpose(1, 2)
+
+ def decode_latents(self, latents):
+ encodings = rearrange(latents, "b d t -> (b t) d")
+ codebook = self.codebook.weight # codebook: (N x D)
+
+ # L2 normalize encodings and codebook (ViT-VQGAN)
+ encodings = F.normalize(encodings)
+ codebook = F.normalize(codebook)
+
+ # Compute euclidean distance with codebook
+ dist = (
+ encodings.pow(2).sum(1, keepdim=True)
+ - 2 * encodings @ codebook.t()
+ + codebook.pow(2).sum(1, keepdim=True).t()
+ )
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+ z_q = self.decode_code(indices)
+
+ if(self.training):
+ onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size
+ stale_codes = (onehots.sum(0).sum(0) == 0).float()
+ self.stale_counter = self.stale_counter * stale_codes + stale_codes
+
+ # random replace codes that haven't been used for a while
+ replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size
+ if replace_code.sum(-1) > 0:
+ print("Replace {} codes".format(replace_code.sum(-1)))
+ random_input_idx = torch.randperm(encodings.shape[0])
+ random_input = encodings[random_input_idx].view(encodings.shape)
+ if random_input.shape[0] < self.codebook_size:
+ random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0)
+ random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim
+
+ self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1)
+ self.stale_counter = self.stale_counter * (1 - replace_code)
+
+ return z_q, indices
+
+
+class ResidualVectorQuantize(nn.Module):
+ """
+ Introduced in SoundStream: An end2end neural audio codec
+ https://arxiv.org/abs/2107.03312
+ """
+
+ def __init__(
+ self,
+ input_dim: int = 512,
+ n_codebooks: int = 9,
+ codebook_size: int = 1024,
+ codebook_dim: Union[int, list] = 8,
+ quantizer_dropout: float = 0.0,
+ stale_tolerance: int = 100,
+ use_multi_layer_num:int = 1,
+ ):
+ super().__init__()
+ if isinstance(codebook_dim, int):
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
+
+ self.n_codebooks = n_codebooks
+ self.codebook_dim = codebook_dim
+ self.codebook_size = codebook_size
+
+ self.quantizers = nn.ModuleList(
+ [
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance, n_layer=use_multi_layer_num)
+ for i in range(n_codebooks)
+ ]
+ )
+ self.quantizer_dropout = quantizer_dropout
+
+ def forward(self, z, n_quantizers: int = None):
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
+ the corresponding codebook vectors
+ Parameters
+ ----------
+ z : Tensor[B x D x T]
+ n_quantizers : int, optional
+ No. of quantizers to use
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
+ when in training mode, and a random number of quantizers is used.
+ Returns
+ -------
+ dict
+ A dictionary with the following keys:
+
+ "z" : Tensor[B x D x T]
+ Quantized continuous representation of input
+ "codes" : Tensor[B x N x T]
+ Codebook indices for each codebook
+ (quantized discrete representation of input)
+ "latents" : Tensor[B x N*D x T]
+ Projected latents (continuous representation of input before quantization)
+ "vq/commitment_loss" : Tensor[1]
+ Commitment loss to train encoder to predict vectors closer to codebook
+ entries
+ "vq/codebook_loss" : Tensor[1]
+ Codebook loss to update the codebook
+ """
+ z_q = 0
+ residual = z
+ commitment_loss = 0
+ codebook_loss = 0
+
+ codebook_indices = []
+ latents = []
+
+ if n_quantizers is None:
+ n_quantizers = self.n_codebooks
+ if self.training:
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
+ n_quantizers = n_quantizers.to(z.device)
+ else:
+ n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1
+ n_quantizers = n_quantizers.to(z.device)
+
+ for i, quantizer in enumerate(self.quantizers):
+ # if self.training is False and i >= n_quantizers:
+ # break
+
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
+ residual
+ )
+
+ # Create mask to apply quantizer dropout
+ mask = (
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
+ )
+ z_q = z_q + z_q_i * mask[:, None, None]
+ residual = residual - z_q_i
+
+ # Sum losses
+ commitment_loss += (commitment_loss_i * mask).mean()
+ codebook_loss += (codebook_loss_i * mask).mean()
+
+ codebook_indices.append(indices_i)
+ latents.append(z_e_i)
+
+ codes = torch.stack(codebook_indices, dim=1)
+ latents = torch.cat(latents, dim=1)
+
+ encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024
+
+ return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1
+
+ def get_loss(self, x, quantized_prompt_embeds, commitment_loss, codebook_loss):
+ final_loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()
+ return final_loss
+
+ def from_codes(self, codes: torch.Tensor):
+ """Given the quantized codes, reconstruct the continuous representation
+ Parameters
+ ----------
+ codes : Tensor[B x N x T]
+ Quantized discrete representation of input
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized continuous representation of input
+ """
+ z_q = 0.0
+ z_p = []
+ n_codebooks = codes.shape[1]
+ for i in range(n_codebooks):
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
+ z_p.append(z_p_i)
+
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
+ z_q = z_q + z_q_i
+ return z_q, torch.cat(z_p, dim=1), codes
+
+ def from_latents(self, latents: torch.Tensor):
+ """Given the unquantized latents, reconstruct the
+ continuous representation after quantization.
+
+ Parameters
+ ----------
+ latents : Tensor[B x N x T]
+ Continuous representation of input after projection
+
+ Returns
+ -------
+ Tensor[B x D x T]
+ Quantized representation of full-projected space
+ Tensor[B x D x T]
+ Quantized representation of latent space
+ """
+ z_q = 0
+ z_p = []
+ codes = []
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
+
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
+ 0
+ ]
+ for i in range(n_codebooks):
+ j, k = dims[i], dims[i + 1]
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
+ z_p.append(z_p_i)
+ codes.append(codes_i)
+
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
+ z_q = z_q + z_q_i
+
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
+
diff --git a/src/third_party/MuQ/src/muq/muq/muq.py b/src/third_party/MuQ/src/muq/muq/muq.py
new file mode 100644
index 0000000000000000000000000000000000000000..c55e00e8dec0a6ea3e03d40a1c3ee22ff58ddce2
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq/muq.py
@@ -0,0 +1,90 @@
+import torch.nn as nn
+import torch
+from .models.muq_model import MuQModel
+from dataclasses import dataclass, field
+from typing import List, Optional
+from transformers.modeling_outputs import BaseModelOutput
+from huggingface_hub import PyTorchModelHubMixin
+
+@dataclass
+class MuQConfig:
+ label_rate:int = field(default=25)
+ num_codebooks:int = field(default=1)
+ codebook_dim:int = field(default=16)
+ codebook_size:int = field(default=4096)
+ features:List[str] = field(default_factory=lambda:["melspec_2048"])
+ hop_length:int = field(default=240)
+ n_mels:int = field(default=128)
+ conv_dim:int = field(default=512)
+ encoder_dim:int = field(default=1024)
+ encoder_depth:int = field(default=12)
+ mask_hop:float = field(default=0.4)
+ mask_prob:float = field(default=0.6)
+ is_flash:bool = field(default=False)
+ stat:Optional[dict] = field(default_factory=dict)
+ w2v2_config:Optional[dict] = field(default_factory=dict)
+ use_rvq_target:bool = field(default=False)
+ use_vq_target:bool = field(default=False)
+ use_encodec_target:bool = field(default=False)
+ rvq_ckpt_path: Optional[str] = field(default=None)
+ recon_loss_ratio: Optional[float] = field(default=None)
+ resume_checkpoint: Optional[str] = None
+ rvq_n_codebooks:int = field(default=8)
+ rvq_multi_layer_num:int = field(default=1)
+
+class MuQ(nn.Module, PyTorchModelHubMixin):
+ def __init__(self, config: MuQConfig):
+ super().__init__()
+ if isinstance(config, dict):
+ config = MuQConfig(**config)
+ self.config = config
+ self.model = MuQModel(
+ num_codebooks=config.num_codebooks,
+ codebook_dim=config.codebook_dim,
+ codebook_size=config.codebook_size,
+ features=config.features,
+ hop_length=config.hop_length,
+ n_mels=config.n_mels,
+ conv_dim=config.conv_dim,
+ encoder_dim=config.encoder_dim,
+ encoder_depth=config.encoder_depth,
+ mask_hop=config.mask_hop,
+ mask_prob=config.mask_prob,
+ is_flash=config.is_flash,
+ stat=config.stat,
+ w2v2_config=config.w2v2_config,
+ use_rvq_target=config.use_rvq_target,
+ use_vq_target=config.use_vq_target,
+ use_encodec_target=config.use_encodec_target,
+ rvq_ckpt_path=config.rvq_ckpt_path,
+ recon_loss_ratio=config.recon_loss_ratio,
+ label_rate=config.label_rate,
+ rvq_n_codebooks=config.rvq_n_codebooks,
+ rvq_multi_layer_num=config.rvq_multi_layer_num,
+ )
+
+ def forward(self, x, attention_mask:Optional[torch.Tensor]=None, output_hidden_states:bool=True) ->BaseModelOutput:
+ """
+ Forward pass through the MuQ model and extract features.
+
+ Args:
+ x (torch.Tensor): Input waveform tensor of shape (batch_size, time).
+ attention_mask (torch.Tensor, optional): Mask to avoid performing attention on padding token indices.
+ Default is None.
+ output_hidden_states (bool, optional): Whether to return all hidden states or only the last one.
+ Default is False.
+
+ Returns:
+ BaseModelOutput: An object containing the last hidden state and optionally all hidden states.
+ - last_hidden_state (torch.Tensor): The last hidden state of the model, i.e. extracted MuQ features, of shape (batch_size, sequence_length, hidden_size).
+ - hidden_states (tuple(torch.Tensor), optional): A tuple containing all hidden states produced by the model,
+ each of shape (batch_size, sequence_length, hidden_size). Only returned if output_hidden_states is True.
+ """
+ _, hidden_states = self.model.get_predictions(x, attention_mask=attention_mask, is_features_only=True)
+ last_hidden_state = hidden_states[-1]
+ if not output_hidden_states:
+ return BaseModelOutput(last_hidden_state=last_hidden_state)
+ return BaseModelOutput(
+ last_hidden_state=last_hidden_state,
+ hidden_states=hidden_states
+ )
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/muq/muq_mulan/__init__.py b/src/third_party/MuQ/src/muq/muq_mulan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a9979cecaaa28b5f297abe6ed94f62e4eb2e90b
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq_mulan/__init__.py
@@ -0,0 +1 @@
+from .muq_mulan import MuQMuLan, MuQMuLanConfig, MuLanConfig, ModalModelConfig, TextTransformerConfig, AudioTransformerConfig
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/muq/muq_mulan/models/__init__.py b/src/third_party/MuQ/src/muq/muq_mulan/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/third_party/MuQ/src/muq/muq_mulan/models/audio.py b/src/third_party/MuQ/src/muq/muq_mulan/models/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1289976c7e1e60d3ea894b64b7a01f99a2a9c97
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq_mulan/models/audio.py
@@ -0,0 +1,294 @@
+from contextlib import suppress
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat, reduce
+from einops.layers.torch import Rearrange
+from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking
+from transformers import Wav2Vec2FeatureExtractor,AutoModel
+from ..modules.transformer import Transformer, LayerNorm, posemb_sincos_2d
+from ..modules.utils import print_once, round_down_nearest_multiple, frozen_params, Sequential
+
+
+def pair(t):
+ return (t, t) if not isinstance(t, tuple) else t
+
+class AudioSpectrogramTransformer(nn.Module):
+ def __init__(
+ self,
+ dim,
+ depth,
+ patch_size = 16,
+ dim_head = 64,
+ heads = 8,
+ attn_dropout = 0.,
+ ff_mult = 4,
+ ff_dropout = 0.,
+ accept_spec = False,
+ accept_spec_time_first = True,
+ spec_n_fft = 128,
+ spec_power = 2,
+ spec_win_length = 24,
+ spec_hop_length = None,
+ spec_pad = 0,
+ spec_center = True,
+ spec_pad_mode = 'reflect',
+ spec_aug_stretch_factor = 0.8,
+ spec_aug_freq_mask = 80,
+ spec_aug_time_mask = 80,
+ patch_dropout_prob = 0.25
+ ):
+ super().__init__()
+ self.dim = dim
+ self.depth = depth
+
+ self.patch_size = pair(patch_size)
+ patch_input_dim = self.patch_size[0] * self.patch_size[1]
+
+ self.to_patch_tokens = Sequential(
+ Rearrange('b (h p1) (w p2) -> b h w (p1 p2)', p1 = self.patch_size[0], p2 = self.patch_size[1]),
+ nn.LayerNorm(patch_input_dim),
+ nn.Linear(patch_input_dim, dim),
+ nn.LayerNorm(dim)
+ )
+
+ self.accept_spec = accept_spec
+ self.accept_spec_time_first = accept_spec_time_first
+
+ self.spec = Spectrogram(
+ n_fft = spec_n_fft,
+ power = spec_power,
+ win_length = spec_win_length,
+ hop_length = spec_hop_length,
+ pad = spec_pad,
+ center = spec_center,
+ pad_mode = spec_pad_mode
+ )
+
+ # SpecAugment - seems to be widely used in audio field https://arxiv.org/abs/1904.08779
+
+ self.aug = torch.nn.Sequential(
+ TimeStretch(spec_aug_stretch_factor, fixed_rate = True),
+ FrequencyMasking(freq_mask_param = spec_aug_freq_mask),
+ TimeMasking(time_mask_param = spec_aug_time_mask),
+ )
+
+ self.transformer = Transformer(
+ dim = dim,
+ depth = depth,
+ dim_head = dim_head,
+ heads = heads,
+ attn_dropout = attn_dropout,
+ ff_mult = ff_mult,
+ ff_dropout = ff_dropout
+ )
+
+ self.norm = LayerNorm(dim)
+
+ # patch dropout
+
+ self.patch_dropout_prob = patch_dropout_prob
+
+ # 2d dynamic positional bias
+
+ mlp_hidden_dim = dim // 4
+
+ self.dynamic_pos_bias_mlp = nn.Sequential(
+ nn.Linear(2, mlp_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(mlp_hidden_dim, mlp_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(mlp_hidden_dim, heads),
+ Rearrange('... i j h -> ... h i j')
+ )
+
+ def forward(
+ self,
+ x,
+ force_no_patch_dropout = False,
+ return_all_layers = False
+ ):
+ batch, device = x.shape[0], x.device
+ assert (self.accept_spec and x.ndim == 3) or (not self.accept_spec and x.ndim == 2)
+
+ if self.accept_spec and self.accept_spec_time_first:
+ x = rearrange(x, 'b t f -> b f t')
+
+ if not self.accept_spec:
+ x = self.spec(x)
+
+ if self.training:
+ x = self.aug(x)
+
+ # automatically crop if audio does not yield a 2d spectrogram that is divisible by patch sizes
+
+ height, width = x.shape[-2:]
+ patch_height, patch_width = self.patch_size
+
+ rounded_height, rounded_width = map(lambda args: round_down_nearest_multiple(*args), ((height, patch_height), (width, patch_width)))
+
+ if (height, width) != (rounded_height, rounded_width): # just keep printing to be annoying until it is fixed
+ print_once(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')
+
+ x = x[..., :rounded_height, :rounded_width]
+
+ # to patches
+
+ x = self.to_patch_tokens(x)
+
+ # get number of patches along height and width
+
+ _, num_patch_height, num_patch_width, _ = x.shape
+
+ # get 2d relative positions
+
+ grid = torch.stack(torch.meshgrid(
+ torch.arange(num_patch_height, device = device),
+ torch.arange(num_patch_width, device = device)
+ , indexing = 'ij'), dim = -1)
+
+ grid = rearrange(grid, '... c -> (...) c')
+
+ # 2d sinusoidal positional embedding
+
+ x = x + posemb_sincos_2d(x)
+
+ x = rearrange(x, 'b ... c -> b (...) c')
+
+ # patch dropout
+
+ if self.training and self.patch_dropout_prob > 0. and not force_no_patch_dropout:
+ n, device = x.shape[1], x.device
+
+ batch_indices = torch.arange(batch, device = device)
+ batch_indices = rearrange(batch_indices, '... -> ... 1')
+ num_patches_keep = max(1, int(n * (1 - self.patch_dropout_prob)))
+ patch_indices_keep = torch.randn(batch, n, device = device).topk(num_patches_keep, dim = -1).indices
+
+ x = x[batch_indices, patch_indices_keep]
+
+ grid = repeat(grid, '... -> b ...', b = batch)
+ grid = grid[batch_indices, patch_indices_keep]
+
+ # 2d relative positional bias
+
+ rel_dist = rearrange(grid, '... i c -> ... i 1 c') - rearrange(grid, '... j c -> ... 1 j c')
+ rel_pos_bias = self.dynamic_pos_bias_mlp(rel_dist.float())
+
+ # attention, what else
+
+ x, all_layers = self.transformer(x, rel_pos_bias = rel_pos_bias, return_all_layers = True)
+
+ # final global average and norm (most recent papers show this is superior to CLS token)
+
+ x = reduce(x, 'b n d -> b d', 'mean')
+
+ out = self.norm(x)
+
+ if not return_all_layers:
+ return out
+
+ return out, all_layers
+
+class AudioSpectrogramTransformerPretrained(nn.Module):
+ def __init__(
+ self,
+ model_name = 'm-a-p/MERT-v1-330M',
+ dim = 768,
+ model_dim = 1024,
+ sr = 24000,
+ tf_depth = 12,
+ dim_head = 64,
+ heads = 8,
+ attn_dropout = 0.,
+ ff_dropout = 0.,
+ ff_mult = 4,
+ use_layer_idx = -1,
+ frozen_pretrained = True,
+ hf_hub_cache_dir = None,
+ ):
+ super().__init__()
+ self.model_name = model_name
+ self.dim = dim
+ self.sr = sr
+ self.use_layer_idx = use_layer_idx # Which layer's features should be used
+ self.hf_hub_cache_dir = hf_hub_cache_dir
+
+ self.model_name
+
+ self._init_pretrained_model(model_name)
+
+ self.aggregator = nn.Conv1d(in_channels=25, out_channels=1, kernel_size=1)
+
+
+ self.transformer = Transformer(
+ dim = dim,
+ depth = tf_depth,
+ dim_head = dim_head,
+ heads = heads,
+ attn_dropout = attn_dropout,
+ ff_dropout = ff_dropout,
+ ff_mult = ff_mult
+ ) # if tf_depth > 0 else torch.nn.Identity()
+
+ self.proj = nn.Linear(model_dim, dim)
+
+ if frozen_pretrained:
+ frozen_params(self.model)
+ frozen_params(self.aggregator)
+ self.frozen_pretrained = frozen_pretrained
+
+ def _init_pretrained_model(self, model_name):
+ if 'muq' in model_name.lower():
+ from muq import MuQ
+ self.model = MuQ.from_pretrained(model_name, cache_dir=self.hf_hub_cache_dir)
+ else:
+ self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True, cache_dir=self.hf_hub_cache_dir)
+ self.processor = Wav2Vec2FeatureExtractor.from_pretrained(model_name,trust_remote_code=True, cache_dir=self.hf_hub_cache_dir)
+
+ assert self.processor.sampling_rate == self.sr
+
+ @property
+ def device(self):
+ return next(self.model.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.model.parameters()).dtype
+
+ def _forward_pretrained_model(self, x):
+ if 'muq' in self.model_name.lower():
+ outputs = self.model(x, output_hidden_states=True)
+ return outputs.hidden_states # 13 layer x [batch_size, Time steps, 1024 feature_dim]
+ else:
+ inputs = self.processor(x, sampling_rate=self.sr, return_tensors="pt")
+ input_values = inputs['input_values'].squeeze(0).to(self.device, dtype = self.dtype)
+ outputs = self.model(input_values, output_hidden_states=True) # [25 layer, batch_size, Time steps, 1024 feature_dim]
+ return outputs.hidden_states
+
+ def forward(
+ self,
+ x,
+ return_all_layers = False,
+ return_mean = True,
+ no_proj = False,
+ ):
+ batch, device = x.shape[0], x.device
+
+ with torch.no_grad() if self.frozen_pretrained else suppress():
+ outputs = self._forward_pretrained_model(x)
+ layer_hidden_states = outputs[self.use_layer_idx]
+
+ if no_proj:
+ outputs = layer_hidden_states
+ else:
+ outputs = self.proj(layer_hidden_states)
+ outputs, layer_results = self.transformer(outputs, return_all_layers=True)
+
+ if return_mean:
+ outputs = outputs.mean(dim = -2)
+
+ if return_all_layers:
+ return outputs, layer_results
+ return outputs
+
+
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/muq/muq_mulan/models/mulan.py b/src/third_party/MuQ/src/muq/muq_mulan/models/mulan.py
new file mode 100644
index 0000000000000000000000000000000000000000..2abaca29c2cb84051d87af8d77d2fb7e2d50862a
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq_mulan/models/mulan.py
@@ -0,0 +1,148 @@
+import math
+from typing import List, Optional, Union
+from collections import OrderedDict
+from functools import partial
+
+import torch
+from torch import nn, einsum
+
+from .audio import AudioSpectrogramTransformer, AudioSpectrogramTransformerPretrained
+from .text import TextTransformer, TextTransformerPretrained
+from ..modules.contrastive import RankSoftmaxContrastiveLearning, SoftmaxContrastiveLearning, SigmoidContrastiveLearning, MultiLayerContrastiveLoss, interspersed_indices
+from ..modules.utils import exists, default, l2norm
+
+class MuLanModel(nn.Module):
+ def __init__(
+ self,
+ audio_transformer: Union[AudioSpectrogramTransformer, AudioSpectrogramTransformerPretrained],
+ text_transformer: Union[TextTransformer, TextTransformerPretrained],
+ dim_latent = 128, # they use 128
+ decoupled_contrastive_learning = True, # think this was used, make it optional
+ hierarchical_contrastive_loss = False,
+ hierarchical_contrastive_loss_layers = None,
+ sigmoid_contrastive_loss = False,
+ rank_contrast = False, # apply contrast on rank dimension
+ proj_to_latent = True,
+ norm_type = 'l2norm',
+ **kwargs,
+ ):
+ super().__init__()
+ self.dim_latent = dim_latent
+
+ # audio and text transformer
+ self.audio = audio_transformer
+ self.text = text_transformer
+
+ # two linear layers to project embeddings to latent space
+ if proj_to_latent:
+ self.text_to_latents = nn.Linear(self.text.dim, dim_latent)
+ self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent)
+
+ self.sigmoid_contrastive_loss = sigmoid_contrastive_loss
+ self.decoupled_contrastive_learning = decoupled_contrastive_learning
+ self.rank_contrast = rank_contrast
+ self.norm_type = norm_type
+
+ # use decoupled contrastive learning or not, where self.contrast is loss module for contrastive learning
+ if sigmoid_contrastive_loss:
+ klass = SigmoidContrastiveLearning
+ else:
+ if rank_contrast:
+ klass = partial(RankSoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning)
+ else:
+ klass = partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning)
+
+ self.contrast = klass()
+
+ self.multi_layer_contrastive_learning = None
+
+ if hierarchical_contrastive_loss:
+ num_layers = default(hierarchical_contrastive_loss_layers, min(audio_transformer.depth, text_transformer.depth) - 1)
+ assert num_layers > 0
+
+ self.register_buffer('text_layers_indices', interspersed_indices(num_layers, text_transformer.depth))
+ self.register_buffer('audio_layers_indices', interspersed_indices(num_layers, audio_transformer.depth))
+
+ self.multi_layer_contrastive_learning = MultiLayerContrastiveLoss(
+ audio_dim = self.audio.dim,
+ text_dim = self.text.dim,
+ dim_latent = dim_latent,
+ layers = num_layers,
+ decoupled_contrastive_learning = decoupled_contrastive_learning,
+ sigmoid_contrastive_loss = sigmoid_contrastive_loss
+ )
+
+ def get_audio_latents(
+ self,
+ wavs,
+ return_all_layers = False,
+ ):
+ audio_embeds, audio_layers = self.audio(wavs, return_all_layers = True)
+ audio_latents = self.audio_to_latents(audio_embeds)
+ out = self._norm_latents(audio_latents) #->[Batch, Feat=128]
+
+ if not return_all_layers:
+ return out
+
+ return out, audio_layers #[nLayer=5, Batch=2, 15, 512]
+
+ def get_text_latents(
+ self,
+ texts = None,
+ raw_texts: Optional[List[str]] = None,
+ return_all_layers = False
+ ):
+ text_embeds, text_layers = self.text(texts, raw_texts = raw_texts, return_all_layers = True)
+ text_latents = self.text_to_latents(text_embeds)
+ out = self._norm_latents(text_latents)
+
+ if not return_all_layers:
+ return out
+
+ return out, text_layers
+
+ def _norm_latents(self, latents):
+ if self.norm_type == 'l2norm':
+ return l2norm(latents)
+ else:
+ return self.norm(latents)
+
+ def forward(
+ self,
+ wavs,
+ texts = None,
+ raw_texts: Optional[List[str]] = None,
+ return_latents = False,
+ return_similarities = False,
+ return_pairwise_similarities = False,
+ ):
+ batch, device = wavs.shape[0], wavs.device
+
+ # both latents are of [Batch, Feat=128]
+ audio_latents, audio_layers = self.get_audio_latents(wavs, return_all_layers = True)
+ text_latents, text_layers = self.get_text_latents(texts, raw_texts = raw_texts, return_all_layers = True)
+
+ if return_latents: # used in inference
+ return audio_latents, text_latents
+
+ if return_similarities:
+ return einsum('i d, i d -> i', audio_latents, text_latents)
+
+ if return_pairwise_similarities:
+ cosine_sim = einsum('i d, j d -> i j', audio_latents, text_latents)
+ return cosine_sim
+
+ cl_loss = self.contrast(audio_latents, text_latents) #contrastive loss
+
+ if not exists(self.multi_layer_contrastive_learning):
+ return cl_loss
+
+ audio_layers = audio_layers[self.audio_layers_indices]
+ text_layers = text_layers[self.text_layers_indices]
+
+ hierarchical_cl_loss = self.multi_layer_contrastive_learning(
+ audio_layers = audio_layers,
+ text_layers = text_layers
+ )
+
+ return cl_loss + hierarchical_cl_loss
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/muq/muq_mulan/models/text.py b/src/third_party/MuQ/src/muq/muq_mulan/models/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..8df6a1da744fef8fbc4abc0d1ccb123fa2b1fc3c
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq_mulan/models/text.py
@@ -0,0 +1,241 @@
+from typing import Optional, List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from x_clip.tokenizer import tokenizer
+from einops import rearrange, repeat, reduce, pack, unpack
+from transformers import AutoTokenizer,XLMRobertaModel,AutoModelForCausalLM
+
+from ..modules.utils import *
+from ..modules.transformer import Transformer, LayerNorm
+from ..modules.utils import frozen_params
+
+
+# text transformer
+
+class TextTransformer(nn.Module):
+ def __init__(
+ self,
+ dim,
+ depth,
+ num_tokens = tokenizer.vocab_size,
+ max_seq_len = 256,
+ dim_head = 64,
+ heads = 8,
+ attn_dropout = 0.,
+ ff_dropout = 0.,
+ ff_mult = 4,
+ pad_id = 0
+ ):
+ super().__init__()
+ self.dim = dim
+
+ self.token_emb = nn.Embedding(num_tokens, dim)
+ self.pos_emb = nn.Embedding(max_seq_len, dim)
+
+ self.depth = depth
+ self.max_seq_len = max_seq_len
+
+ self.cls_token = nn.Parameter(torch.randn(dim))
+
+ self.transformer = Transformer(
+ dim = dim,
+ depth = depth,
+ dim_head = dim_head,
+ heads = heads,
+ attn_dropout = attn_dropout,
+ ff_dropout = ff_dropout,
+ ff_mult = ff_mult
+ )
+
+ self.pad_id = pad_id
+ self.norm = LayerNorm(dim)
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ def forward(
+ self,
+ x = None,
+ raw_texts: Optional[List[str]] = None,
+ mask = None,
+ return_all_layers = False
+ ):
+ assert exists(x) ^ exists(raw_texts)
+
+ if exists(raw_texts):
+ x = tokenizer.tokenize(raw_texts).to(self.device)
+
+ if not exists(mask):
+ mask = x != self.pad_id
+
+ b, n, device = *x.shape, x.device
+
+ # token embedding + positional embedding
+
+ x = self.token_emb(x)
+
+ assert n <= self.max_seq_len, f'text sequence length {n} must be less than {self.max_seq_len}'
+
+ x = x + self.pos_emb(torch.arange(n, device = device))
+
+ # cls tokens, as in bert
+
+ cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
+ x, ps = pack([cls_tokens, x], 'b * d')
+
+ # account for attending to cls token with self attention mask
+
+ mask = F.pad(mask, (1, 0), value = True)
+
+ # attention
+
+ x, all_layers = self.transformer(x, mask = mask, return_all_layers = True)
+
+ # unpack the cls tokens
+
+ cls_tokens, _ = unpack(x, ps, 'b * d')
+
+ out = self.norm(cls_tokens)
+
+ if not return_all_layers:
+ return out
+
+ return out, all_layers
+
+class TextPretrainedModelType:
+ Roberta = 'roberta'
+ Qwen = 'qwen'
+
+class TextTransformerPretrained(nn.Module):
+ def __init__(
+ self,
+ model_name = 'xlm-roberta-base',
+ dim = 768,
+ model_dim = None,
+ max_seq_len = 256,
+ tf_depth = 12,
+ dim_head = 64,
+ heads = 8,
+ attn_dropout = 0.,
+ ff_dropout = 0.,
+ ff_mult = 4,
+ frozen_pretrained = True,
+ hf_hub_cache_dir = None,
+ ):
+ super().__init__()
+ self.dim = dim
+
+ self.model_name = model_name
+
+ self.hf_hub_cache_dir = hf_hub_cache_dir
+
+ self.pretrained_model_type = self._get_pretrained_model_type(model_name)
+
+ self.model = self._init_pretrained_model()
+
+ self._tokenizer = None
+
+ self.max_seq_len = max_seq_len
+
+ self.transformer = Transformer(
+ dim = dim,
+ depth = tf_depth,
+ dim_head = dim_head,
+ heads = heads,
+ attn_dropout = attn_dropout,
+ ff_dropout = ff_dropout,
+ ff_mult = ff_mult
+ ) # if tf_depth > 0 else torch.nn.Identity()
+
+ is_proj = exists(model_dim) and model_dim != dim
+
+ self.proj = nn.Linear(model_dim, dim) if is_proj else torch.nn.Identity()
+ if frozen_pretrained:
+ frozen_params(self.model)
+ self.frozen_pretrained = frozen_pretrained
+
+ @staticmethod
+ def _get_pretrained_model_type(model_name):
+ if 'xlm-roberta' in model_name:
+ return TextPretrainedModelType.Roberta
+ elif 'Qwen' in model_name:
+ return TextPretrainedModelType.Qwen
+ else:
+ raise ValueError(f"Unknown pretrained model named: {model_name}")
+
+ def _init_pretrained_model(self):
+ if self.pretrained_model_type == TextPretrainedModelType.Roberta:
+ model = XLMRobertaModel.from_pretrained(self.model_name, trust_remote_code=True, cache_dir=self.hf_hub_cache_dir)
+ elif self.pretrained_model_type == TextPretrainedModelType.Qwen:
+ model = AutoModelForCausalLM.from_pretrained(self.model_name, trust_remote_code=True, fp16=True, cache_dir=self.hf_hub_cache_dir)
+ else:
+ raise ValueError(f"Failed to init pretrained model type: {self.pretrained_model_type}")
+ return model
+
+ def _init_tokenizer(self):
+ if self.pretrained_model_type == TextPretrainedModelType.Roberta:
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True, cache_dir=self.hf_hub_cache_dir)
+ elif self.pretrained_model_type == TextPretrainedModelType.Qwen:
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True, cache_dir=self.hf_hub_cache_dir)
+ tokenizer.pad_token = '<|im_end|>'
+ else:
+ raise ValueError(f"Failed to init tokenizer of pretrained model type: {self.pretrained_model_type}")
+ return tokenizer
+
+ @property
+ def tokenizer(self):
+ if not exists(self._tokenizer):
+ self._tokenizer = self._init_tokenizer()
+ return self._tokenizer
+
+ @property
+ def device(self):
+ return next(self.model.parameters()).device
+
+ @property
+ def dtype(self):
+ return next(self.transformer.parameters()).dtype
+
+
+ def pred_pretrained_model_hidden(self, **kw):
+ if self.pretrained_model_type == TextPretrainedModelType.Roberta:
+ outputs = self.model(**kw)
+ outputs = outputs.last_hidden_state
+ elif self.pretrained_model_type == TextPretrainedModelType.Qwen:
+ last_hidden_state = self.model(**kw, output_hidden_states=True)['hidden_states'][-1]
+ outputs = last_hidden_state.to(dtype = self.dtype)
+ else:
+ raise ValueError(f"Unknown pretrained model type: {self.pretrained_model_type}")
+ return outputs
+
+ def forward(
+ self,
+ x = None,
+ raw_texts: Optional[List[str]] = None,
+ mask = None,
+ return_all_layers = False,
+ return_mean = True
+ ):
+ assert exists(x) ^ exists(raw_texts)
+ with torch.no_grad():
+ if exists(raw_texts):
+ inputs = self.tokenizer(raw_texts, return_tensors='pt', padding=True)
+ inputs = inputs.to(self.device)
+
+ if exists(mask):
+ inputs['attention_mask'] = mask
+
+ outputs = self.pred_pretrained_model_hidden(**inputs)
+
+ outputs = self.proj(outputs)
+
+ outputs, layer_results = self.transformer(outputs, return_all_layers=True)
+ if return_mean:
+ outputs = outputs.mean(dim = -2)
+
+ if return_all_layers:
+ return outputs, layer_results
+ return outputs
diff --git a/src/third_party/MuQ/src/muq/muq_mulan/modules/__init__.py b/src/third_party/MuQ/src/muq/muq_mulan/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/third_party/MuQ/src/muq/muq_mulan/modules/contrastive.py b/src/third_party/MuQ/src/muq/muq_mulan/modules/contrastive.py
new file mode 100644
index 0000000000000000000000000000000000000000..064377b58105f16780c37998a4f37d8fa62d772a
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq_mulan/modules/contrastive.py
@@ -0,0 +1,238 @@
+import math
+from functools import partial
+
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+from einops import rearrange, reduce
+from torch import einsum
+import torch.distributed as dist
+
+from .utils import exists, l2norm, log, print_once
+from .distributed import AllGather
+from .extend_distributed import all_gather
+from .transformer import LayerNorm
+
+def matrix_diag(t):
+ device = t.device
+ i, j = t.shape[-2:]
+ num_diag_el = min(i, j)
+ i_range = torch.arange(i, device = device)
+ j_range = torch.arange(j, device = device)
+ diag_mask = rearrange(i_range, 'i -> i 1') == rearrange(j_range, 'j -> 1 j')
+ diag_el = t.masked_select(diag_mask)
+ return rearrange(diag_el, '(b d) -> b d', d = num_diag_el)
+
+# contrastive losses
+
+class SoftmaxContrastiveLearning(nn.Module):
+ def __init__(
+ self,
+ *,
+ layers = 1,
+ decoupled_contrastive_learning = False,
+ init_temp = 10
+ ):
+ super().__init__()
+ self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
+ self.decoupled_contrastive_learning = decoupled_contrastive_learning
+
+ self.all_gather = AllGather(dim = 2)
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ def forward(self, audio_latents, text_latents):
+ if audio_latents.ndim == 2:
+ audio_latents = rearrange(audio_latents, '... -> 1 ...')
+
+ if text_latents.ndim == 2:
+ text_latents = rearrange(text_latents, '... -> 1 ...')
+
+ batch = audio_latents.shape[1]
+
+ if self.all_gather.is_distributed:
+ latents = torch.stack((audio_latents, text_latents))
+ latents, _ = self.all_gather(latents)
+ audio_latents, text_latents = latents
+
+ sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)
+
+ sims = sims * self.temperatures.exp()
+
+ cosine_sims_exp = sims.exp() # Similarity matrix [Rank, N, N]
+
+ numerator = matrix_diag(cosine_sims_exp) # Take diagonal elements, that is, for t [l, i, j], take all elements of i==j to obtain a array of l * min (i, j)
+
+ if self.decoupled_contrastive_learning:
+ eye = torch.eye(batch, device = self.device, dtype = torch.bool)
+ cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.) # Set the diagonal to 0
+
+ denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum')
+ denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum')
+
+ contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))
+
+ contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean')
+ return contrastive_loss.sum()
+
+
+class RankSoftmaxContrastiveLearning(nn.Module):
+ def __init__(
+ self,
+ *,
+ layers = 1,
+ decoupled_contrastive_learning = False,
+ init_temp = 10,
+ ):
+ super().__init__()
+ self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
+ self.decoupled_contrastive_learning = decoupled_contrastive_learning
+
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ def forward(self, audio_latents, text_latents):
+ if audio_latents.ndim == 2:
+ audio_latents = rearrange(audio_latents, '... -> 1 ...')
+
+ if text_latents.ndim == 2:
+ text_latents = rearrange(text_latents, '... -> 1 ...')
+
+ audio_latents = all_gather(audio_latents, None)
+ text_latents = all_gather(text_latents, None)
+
+ print_once("audio_latents:"+str(audio_latents.shape) + "text_latents:" + str(text_latents.shape))
+
+
+ batch = audio_latents.shape[1]
+ rank = audio_latents.shape[0]
+
+ audio_latents = rearrange(audio_latents, 'l i d -> (l i) d')
+ text_latents = rearrange(text_latents, 'l j d -> (l j) d')
+
+ sims = einsum('i d, j d -> i j', audio_latents, text_latents)
+
+ sims = sims * self.temperatures.exp()
+
+ sims = rearrange(sims, '1 i j -> i j')
+
+ cosine_sims_exp = sims.exp() # Similarity matrix [Rank, N, N]
+
+
+ numerator = matrix_diag(cosine_sims_exp) # Take diagonal elements, that is, for t [l, i, j], take all elements of i==j to obtain a array of l * min (i, j)
+
+ if self.decoupled_contrastive_learning:
+ eye = torch.eye(batch*rank, device = self.device, dtype = torch.bool)
+ cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.) # Set the diagonal to 0
+
+ denominator_i = reduce(cosine_sims_exp, 'i j -> i', 'sum')
+ denominator_j = reduce(cosine_sims_exp, 'i j -> j', 'sum')
+
+ contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))
+
+ contrastive_loss = reduce(contrastive_loss, '1 n -> 1', 'mean')
+ return contrastive_loss
+
+
+class SigmoidContrastiveLearning(nn.Module):
+ """ https://arxiv.org/abs/2303.15343 """
+
+ def __init__(
+ self,
+ *,
+ layers = 1,
+ init_temp = 10,
+ init_bias = -10
+ ):
+ super().__init__()
+ self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
+ self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias)
+
+ self.all_gather = AllGather(dim = 1, all_reduce_grads = True)
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ def forward(self, audio_latents, text_latents):
+ device = self.device
+
+ if audio_latents.ndim == 2:
+ audio_latents = rearrange(audio_latents, '... -> 1 ...') # To [Rank, Batch, Latent]
+
+ if text_latents.ndim == 2:
+ text_latents = rearrange(text_latents, '... -> 1 ...')
+
+ text_latents, rank_sizes = self.all_gather(text_latents)
+
+ n = text_latents.shape[1]
+
+ sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents) # Calculate dot product similarity between pairs
+
+ sims = sims * self.temperatures.exp() + self.bias
+
+ labels = torch.eye(n, device = device)
+
+ if exists(rank_sizes):
+ labels_by_ranks = labels.split(rank_sizes.tolist(), dim = 0)
+ labels = labels_by_ranks[dist.get_rank()] # labels to the n elements of the current rank
+
+ labels = 2 * rearrange(labels, 'i j -> 1 i j') - torch.ones_like(sims)
+
+ return -F.logsigmoid(labels * sims).sum() / n
+
+
+
+
+# hierarchical cl loss
+
+def interspersed_indices(layers, total_layers):
+ assert total_layers >= layers
+ step = total_layers / layers
+ return (torch.arange(0, layers) * step).floor().long()
+
+class MultiLayerContrastiveLoss(nn.Module):
+ def __init__(
+ self,
+ *,
+ audio_dim,
+ text_dim,
+ dim_latent,
+ layers,
+ decoupled_contrastive_learning = False,
+ sigmoid_contrastive_loss = False
+ ):
+ super().__init__()
+ self.layers = layers
+
+ self.audio_norm = LayerNorm(audio_dim, scale = False)
+ self.audio_gamma = nn.Parameter(torch.ones(layers, 1, audio_dim))
+ self.audio_latent_weight = nn.Parameter(torch.randn(layers, audio_dim, dim_latent))
+ self.audio_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent))
+
+ self.text_norm = LayerNorm(text_dim, scale = False)
+ self.text_gamma = nn.Parameter(torch.ones(layers, 1, text_dim))
+ self.text_latent_weight = nn.Parameter(torch.randn(layers, text_dim, dim_latent))
+ self.text_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent))
+
+ klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning)
+ self.contrast = klass(layers = layers)
+
+ def forward(self, *, audio_layers, text_layers):
+ device, batch = audio_layers.device, audio_layers.shape[1]
+
+ audio_gap = reduce(audio_layers, 'l b n d -> l b d', 'mean')
+ audio_embeds = self.audio_norm(audio_gap) * self.audio_gamma
+ audio_latents = einsum('l b d, l d e -> l b e', audio_embeds, self.audio_latent_weight) + self.audio_latent_bias
+ audio_latents = l2norm(audio_latents)
+
+ text_cls_tokens = text_layers[:, :, 0]
+ text_embeds = self.text_norm(text_cls_tokens) * self.text_gamma
+ text_latents = einsum('l b d, l d e -> l b e', text_embeds, self.text_latent_weight) + self.text_latent_bias
+ text_latents = l2norm(text_latents)
+
+ return self.contrast(audio_latents, text_latents)
diff --git a/src/third_party/MuQ/src/muq/muq_mulan/modules/distributed.py b/src/third_party/MuQ/src/muq/muq_mulan/modules/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f360a357f8c68c30febf73fb9539e5f5413b65a
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq_mulan/modules/distributed.py
@@ -0,0 +1,83 @@
+import torch
+from torch import nn
+from torch.autograd import Function
+import torch.distributed as dist
+
+from einops import rearrange
+
+def exists(val):
+ return val is not None
+
+# distributed helpers
+
+def all_gather_same_dim(t):
+ world_size = dist.get_world_size()
+ gathered_tensors = [torch.empty_like(t, device = t.device, dtype = t.dtype) for i in range(world_size)]
+ dist.all_gather(gathered_tensors, t)
+ return gathered_tensors
+
+def all_gather_variable_dim(t, dim = 0, sizes = None):
+ device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()
+
+ if not exists(sizes):
+ size = torch.tensor(t.shape[dim], device = device, dtype = torch.long)
+ sizes = all_gather_same_dim(size)
+ sizes = torch.stack(sizes)
+
+ if torch.unique(sizes).numel() == 1:
+ gathered_tensors = all_gather_same_dim(t)
+ return torch.cat(gathered_tensors, dim = dim), sizes
+
+ max_size = sizes.amax().item()
+
+ padded_t = pad_dim_to(t, max_size, dim = dim)
+ gathered_tensors = all_gather_same_dim(padded_t)
+
+ gathered_tensor = torch.cat(gathered_tensors, dim = dim)
+ seq = torch.arange(max_size, device = device)
+
+ mask = rearrange(seq, 'j -> 1 j') < rearrange(sizes, 'i -> i 1')
+ mask = rearrange(mask, 'i j -> (i j)')
+ seq = torch.arange(mask.shape[-1], device = device)
+ indices = seq[mask]
+
+ gathered_tensor = gathered_tensor.index_select(dim, indices)
+
+ return gathered_tensor, sizes
+
+class AllGatherFunction(Function):
+ @staticmethod
+ def forward(ctx, x, dim, sizes, all_reduce_grads):
+ x, batch_sizes = all_gather_variable_dim(x, dim = dim, sizes = sizes)
+ ctx.dim = dim
+ ctx.all_reduce_grads = all_reduce_grads
+ ctx.batch_sizes = batch_sizes.tolist()
+ return x, batch_sizes
+
+ @staticmethod
+ def backward(ctx, grads, _):
+ batch_sizes, rank = ctx.batch_sizes, dist.get_rank()
+ if ctx.all_reduce_grads:
+ dist.all_reduce(grads)
+
+ grads_by_rank = grads.split(batch_sizes, dim = ctx.dim)
+ return grads_by_rank[rank], None, None, None
+
+class AllGather(nn.Module):
+ def __init__(
+ self,
+ dim,
+ *,
+ all_reduce_grads = False
+ ):
+ super().__init__()
+ self.dim = dim
+ self.all_reduce_grads = all_reduce_grads
+ self.is_distributed = dist.is_initialized() and dist.get_world_size() > 1
+
+ def forward(
+ self,
+ x,
+ sizes = None
+ ):
+ return AllGatherFunction.apply(x, self.dim, sizes, self.all_reduce_grads)
diff --git a/src/third_party/MuQ/src/muq/muq_mulan/modules/extend_distributed.py b/src/third_party/MuQ/src/muq/muq_mulan/modules/extend_distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f0d0c3f57135173747c06f378b1f5a6d3568668
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq_mulan/modules/extend_distributed.py
@@ -0,0 +1,604 @@
+
+import builtins
+import os
+import sys
+
+import torch
+import torch.distributed as dist
+from torch.autograd import Function
+from torch.autograd.profiler import record_function
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import torch.distributed as dist
+
+try:
+ import torch_ccl
+except ImportError as e:
+ # print(e)
+ torch_ccl = False
+
+try:
+ import torch_ucc
+except ImportError as e:
+ torch_ucc = False
+
+
+my_rank = -1
+my_size = -1
+my_local_rank = 1
+my_local_size = 1
+alltoall_supported = False
+a2a_impl = os.environ.get("DLRM_ALLTOALL_IMPL", "")
+
+myreq = None
+
+
+def env2int(env_list, default=-1):
+ for e in env_list:
+ val = int(os.environ.get(e, -1))
+ if val >= 0:
+ return val
+ return default
+
+
+def get_my_slice(n):
+ k, m = divmod(n, my_size)
+ return slice(
+ my_rank * k + min(my_rank, m), (my_rank + 1) * k + min(my_rank + 1, m), 1
+ )
+
+
+def get_split_lengths(n):
+ k, m = divmod(n, my_size)
+ if m == 0:
+ splits = None
+ my_len = k
+ else:
+ splits = [(k + 1) if i < m else k for i in range(my_size)]
+ my_len = splits[my_rank]
+ return (my_len, splits)
+
+
+def init_distributed(rank=-1, local_rank=-1, size=-1, use_gpu=False, backend=""):
+ global myreq
+ global my_rank
+ global my_size
+ global my_local_rank
+ global my_local_size
+ global a2a_impl
+ global alltoall_supported
+
+ # guess MPI ranks from env (works for IMPI, OMPI and MVAPICH2)
+ num_mpi_ranks = env2int(
+ ["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"]
+ )
+ if backend == "" and num_mpi_ranks > 1:
+ if torch_ccl and env2int(["CCL_WORKER_COUNT"]) > 0:
+ backend = "ccl"
+ elif use_gpu and dist.is_nccl_available():
+ backend = "nccl"
+ elif dist.is_mpi_available():
+ backend = "mpi"
+ else:
+ print(
+ "WARNING: MPI multi-process launch detected but PyTorch MPI backend not available."
+ )
+ backend = "gloo"
+
+ if backend != "":
+ # guess Rank and size
+ if rank == -1:
+ rank = env2int(
+ ["PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK", "RANK"], 0
+ )
+ if size == -1:
+ size = env2int(
+ [
+ "PMI_SIZE",
+ "OMPI_COMM_WORLD_SIZE",
+ "MV2_COMM_WORLD_SIZE",
+ "WORLD_SIZE",
+ ],
+ 1,
+ )
+ if not os.environ.get("RANK", None) and rank != -1:
+ os.environ["RANK"] = str(rank)
+ if not os.environ.get("WORLD_SIZE", None) and size != -1:
+ os.environ["WORLD_SIZE"] = str(size)
+ if not os.environ.get("MASTER_PORT", None):
+ os.environ["MASTER_PORT"] = "29500"
+ if not os.environ.get("MASTER_ADDR", None):
+ local_size = env2int(
+ [
+ "MPI_LOCALNRANKS",
+ "OMPI_COMM_WORLD_LOCAL_SIZE",
+ "MV2_COMM_WORLD_LOCAL_SIZE",
+ ],
+ 1,
+ )
+ if local_size != size and backend != "mpi":
+ print(
+ "Warning: Looks like distributed multinode run but MASTER_ADDR env not set, using '127.0.0.1' as default"
+ )
+ print(
+ "If this run hangs, try exporting rank 0's hostname as MASTER_ADDR"
+ )
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
+
+ if size > 1:
+ if local_rank == -1:
+ my_local_rank = env2int(
+ [
+ "MPI_LOCALRANKID",
+ "OMPI_COMM_WORLD_LOCAL_RANK",
+ "MV2_COMM_WORLD_LOCAL_RANK",
+ "LOCAL_RANK",
+ ],
+ 0,
+ )
+ else:
+ my_local_rank = local_rank
+ my_local_size = env2int(
+ [
+ "MPI_LOCALNRANKS",
+ "OMPI_COMM_WORLD_LOCAL_SIZE",
+ "MV2_COMM_WORLD_LOCAL_SIZE",
+ ],
+ 1,
+ )
+ if use_gpu:
+ if my_local_size > torch.cuda.device_count():
+ print(
+ "Not sufficient GPUs available... local_size = %d, ngpus = %d"
+ % (my_local_size, torch.cuda.device_count())
+ )
+ sys.exit(1)
+ torch.cuda.set_device(my_local_rank)
+ dist.init_process_group(backend, rank=rank, world_size=size)
+ my_rank = dist.get_rank()
+ my_size = dist.get_world_size()
+ if my_rank == 0:
+ print("Running on %d ranks using %s backend" % (my_size, backend))
+ if hasattr(dist, "all_to_all_single"):
+ try:
+ t = torch.zeros([4])
+ if use_gpu:
+ t = t.cuda()
+ dist.all_to_all_single(t, t)
+ alltoall_supported = True
+ except RuntimeError as err:
+ print("fail to enable all_to_all_single primitive: %s" % err)
+ if a2a_impl == "alltoall" and alltoall_supported == False:
+ print(
+ "Requested DLRM_ALLTOALL_IMPL=%s but backend %s does not support it, use scatter/gather based alltoall"
+ % (a2a_impl, backend)
+ )
+ a2a_impl = "scatter"
+ if a2a_impl != "":
+ print("Using DLRM_ALLTOALL_IMPL=%s" % a2a_impl)
+ else:
+ my_rank = 0
+ my_size = 1
+ my_local_rank = 0
+ my_local_size = 1
+ print_all(
+ "world size: %d, current rank: %d, local rank: %d"
+ % (my_size, my_rank, my_local_rank)
+ )
+ myreq = Request()
+
+
+class Request(object):
+ def __init__(self):
+ self.req = None
+ self.tensor = None
+ self.WaitFunction = All2All_Scatter_Wait
+
+ def wait(self):
+ ret = self.WaitFunction.apply(*self.tensor)
+ self.req = None
+ self.tensor = None
+ return ret
+
+
+class All2All_ScatterList_Req(Function):
+ @staticmethod
+ def forward(ctx, a2a_info, *inputs):
+ global myreq
+ batch_split_lengths = (
+ a2a_info.global_batch_partition_slices
+ if a2a_info.global_batch_partition_slices
+ else a2a_info.local_batch_num
+ )
+ table_split_lengths = (
+ a2a_info.global_table_wise_parition_slices
+ if a2a_info.global_table_wise_parition_slices
+ else [a2a_info.local_table_num] * my_size
+ )
+ gather_list = []
+ req_list = []
+ for i in range(my_size):
+ for j in range(table_split_lengths[i]):
+ out_tensor = inputs[0].new_empty(
+ [a2a_info.local_batch_num, a2a_info.emb_dim]
+ )
+ scatter_list = (
+ list(inputs[j].split(batch_split_lengths, dim=0))
+ if i == my_rank
+ else []
+ )
+ req = dist.scatter(out_tensor, scatter_list, src=i, async_op=True)
+ gather_list.append(out_tensor)
+ req_list.append(req)
+ myreq.req = req_list
+ myreq.tensor = tuple(gather_list)
+ myreq.a2a_info = a2a_info
+ return myreq.tensor
+
+ @staticmethod
+ def backward(ctx, *grad_output):
+ global myreq
+ for r in myreq.req:
+ r.wait()
+ myreq.req = None
+ grad_inputs = myreq.tensor
+ myreq.tensor = None
+ return (None, *grad_inputs)
+
+
+class All2All_ScatterList_Wait(Function):
+ @staticmethod
+ def forward(ctx, *output):
+ global myreq
+ ctx.a2a_info = myreq.a2a_info
+ for r in myreq.req:
+ r.wait()
+ myreq.req = None
+ myreq.tensor = None
+ return output
+
+ @staticmethod
+ def backward(ctx, *grad_output):
+ global myreq
+ a2a_info = ctx.a2a_info
+ grad_output = [t.contiguous() for t in grad_output]
+ batch_split_lengths = (
+ a2a_info.global_batch_partition_slices
+ if a2a_info.global_batch_partition_slices
+ else [a2a_info.local_batch_num] * my_size
+ )
+ per_rank_table_splits = (
+ a2a_info.global_table_wise_parition_slices
+ if a2a_info.global_table_wise_parition_slices
+ else [a2a_info.local_table_num] * my_size
+ )
+ grad_inputs = [
+ grad_output[0].new_empty([ctx.a2a_info.batch_size, ctx.a2a_info.emb_dim])
+ for _ in range(a2a_info.local_table_num)
+ ]
+ req_list = []
+ ind = 0
+ for i in range(my_size):
+ for j in range(per_rank_table_splits[i]):
+ gather_list = (
+ list(grad_inputs[j].split(batch_split_lengths, dim=0))
+ if i == my_rank
+ else None
+ )
+ req = dist.gather(grad_output[ind], gather_list, dst=i, async_op=True)
+ req_list.append(req)
+ ind += 1
+ myreq.req = req_list
+ myreq.tensor = grad_inputs
+ return tuple(grad_output)
+
+
+class All2All_Scatter_Req(Function):
+ @staticmethod
+ def forward(ctx, a2a_info, *inputs):
+ global myreq
+ batch_split_lengths = (
+ a2a_info.global_batch_partition_slices
+ if a2a_info.global_batch_partition_slices
+ else a2a_info.local_batch_num
+ )
+ table_split_lengths = (
+ a2a_info.global_table_wise_parition_slices
+ if a2a_info.global_table_wise_parition_slices
+ else [a2a_info.local_table_num] * my_size
+ )
+ input = torch.cat(inputs, dim=1)
+ scatter_list = list(input.split(batch_split_lengths, dim=0))
+ gather_list = []
+ req_list = []
+ for i in range(my_size):
+ out_tensor = input.new_empty(
+ [a2a_info.local_batch_num, table_split_lengths[i] * a2a_info.emb_dim]
+ )
+ req = dist.scatter(
+ out_tensor, scatter_list if i == my_rank else [], src=i, async_op=True
+ )
+ gather_list.append(out_tensor)
+ req_list.append(req)
+ myreq.req = req_list
+ myreq.tensor = tuple(gather_list)
+ myreq.a2a_info = a2a_info
+ ctx.a2a_info = a2a_info
+ return myreq.tensor
+
+ @staticmethod
+ def backward(ctx, *grad_output):
+ global myreq
+ for r in myreq.req:
+ r.wait()
+ myreq.req = None
+ grad_input = myreq.tensor
+ grad_inputs = grad_input.split(ctx.a2a_info.emb_dim, dim=1)
+ myreq.tensor = None
+ return (None, *grad_inputs)
+
+
+class All2All_Scatter_Wait(Function):
+ @staticmethod
+ def forward(ctx, *output):
+ global myreq
+ ctx.a2a_info = myreq.a2a_info
+ for r in myreq.req:
+ r.wait()
+ myreq.req = None
+ myreq.tensor = None
+ return output
+
+ @staticmethod
+ def backward(ctx, *grad_output):
+ global myreq
+ assert len(grad_output) == my_size
+ scatter_list = [t.contiguous() for t in grad_output]
+ a2a_info = ctx.a2a_info
+ batch_split_lengths = (
+ a2a_info.global_batch_partition_slices
+ if a2a_info.global_batch_partition_slices
+ else a2a_info.local_batch_num
+ )
+ table_split_lengths = (
+ a2a_info.global_table_wise_parition_slices
+ if a2a_info.global_table_wise_parition_slices
+ else [a2a_info.local_table_num] * my_size
+ )
+ grad_input = grad_output[0].new_empty(
+ [a2a_info.batch_size, a2a_info.emb_dim * a2a_info.local_table_num]
+ )
+ gather_list = list(grad_input.split(batch_split_lengths, dim=0))
+ req_list = []
+ for i in range(my_size):
+ req = dist.gather(
+ scatter_list[i],
+ gather_list if i == my_rank else [],
+ dst=i,
+ async_op=True,
+ )
+ req_list.append(req)
+ myreq.req = req_list
+ myreq.tensor = grad_input
+ return grad_output
+
+
+class All2All_Req(Function):
+ @staticmethod
+ def forward(ctx, a2a_info, *inputs):
+ global myreq
+ with record_function("DLRM alltoall_req_fwd_single"):
+ batch_split_lengths = a2a_info.global_batch_partition_slices
+ if batch_split_lengths:
+ batch_split_lengths = [
+ m * a2a_info.emb_dim * a2a_info.local_table_num
+ for m in batch_split_lengths
+ ]
+ table_split_lengths = a2a_info.global_table_wise_parition_slices
+ if table_split_lengths:
+ table_split_lengths = [
+ a2a_info.local_batch_num * e * a2a_info.emb_dim
+ for e in table_split_lengths
+ ]
+ input = torch.cat(inputs, dim=1).view([-1])
+ output = input.new_empty(
+ [
+ a2a_info.global_table_num
+ * a2a_info.local_batch_num
+ * a2a_info.emb_dim
+ ]
+ )
+ req = dist.all_to_all_single(
+ output, input, table_split_lengths, batch_split_lengths, async_op=True
+ )
+
+ myreq.req = req
+ myreq.tensor = []
+ myreq.tensor.append(output)
+ myreq.tensor = tuple(myreq.tensor)
+ a2a_info.batch_split_lengths = batch_split_lengths
+ a2a_info.table_split_lengths = table_split_lengths
+ myreq.a2a_info = a2a_info
+ ctx.a2a_info = a2a_info
+ return myreq.tensor
+
+ @staticmethod
+ def backward(ctx, *grad_output):
+ global myreq
+ with record_function("DLRM alltoall_req_bwd_single"):
+ a2a_info = ctx.a2a_info
+ myreq.req.wait()
+ myreq.req = None
+ grad_input = myreq.tensor
+ grad_inputs = grad_input.view([a2a_info.batch_size, -1]).split(
+ a2a_info.emb_dim, dim=1
+ )
+ grad_inputs = [gin.contiguous() for gin in grad_inputs]
+ myreq.tensor = None
+ return (None, *grad_inputs)
+
+
+class All2All_Wait(Function):
+ @staticmethod
+ def forward(ctx, *output):
+ global myreq
+ with record_function("DLRM alltoall_wait_fwd_single"):
+ a2a_info = myreq.a2a_info
+ ctx.a2a_info = a2a_info
+ myreq.req.wait()
+ myreq.req = None
+ myreq.tensor = None
+ table_split_lengths = (
+ a2a_info.table_split_lengths
+ if a2a_info.table_split_lengths
+ else a2a_info.local_table_num
+ * a2a_info.local_batch_num
+ * a2a_info.emb_dim
+ )
+ outputs = output[0].split(table_split_lengths)
+ outputs = tuple(
+ [out.view([a2a_info.local_batch_num, -1]) for out in outputs]
+ )
+ return outputs
+
+ @staticmethod
+ def backward(ctx, *grad_outputs):
+ global myreq
+ with record_function("DLRM alltoall_wait_bwd_single"):
+ a2a_info = ctx.a2a_info
+ grad_outputs = [gout.contiguous().view([-1]) for gout in grad_outputs]
+ grad_output = torch.cat(grad_outputs)
+ grad_input = grad_output.new_empty(
+ [a2a_info.batch_size * a2a_info.local_table_num * a2a_info.emb_dim]
+ )
+ req = dist.all_to_all_single(
+ grad_input,
+ grad_output,
+ a2a_info.batch_split_lengths,
+ a2a_info.table_split_lengths,
+ async_op=True,
+ )
+ myreq.req = req
+ myreq.tensor = grad_input
+ return (grad_output,)
+
+
+class AllGather(Function):
+ @staticmethod
+ def forward(ctx, input, global_lengths, dim=0):
+ if not isinstance(global_lengths, (list, tuple)):
+ global_lengths = [global_lengths] * my_size
+
+ assert len(global_lengths) == my_size
+ assert global_lengths[my_rank] == input.size(dim)
+ local_start = sum(global_lengths[:my_rank])
+
+ output_size = list(input.size())
+
+ ctx.dim = dim
+ ctx.local_start = local_start
+ ctx.local_length = global_lengths[my_rank]
+
+ input = input.contiguous()
+ if dim == 0:
+ out_len = sum(global_lengths)
+ output_size[dim] = out_len
+ output = input.new_empty(output_size)
+ gather_list = list(output.split(global_lengths, dim=0))
+ else:
+ gather_list = [torch.empty_like(input) for _ in range(my_size)]
+ gather_list = []
+ for length in global_lengths:
+ output_size[dim] = length
+ gather_list.append(input.new_empty(output_size))
+
+ dist.all_gather(gather_list, input)
+
+ if dim != 0:
+ output = torch.cat(gather_list, dim=dim)
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ # print("Inside All2AllBackward")
+ dim = ctx.dim
+ start = ctx.local_start
+ length = ctx.local_length
+
+ grad_input = grad_output.narrow(dim, start, length)
+
+ return (grad_input, None, None)
+
+
+class All2AllInfo(object):
+ pass
+
+
+def alltoall(inputs, per_rank_table_splits):
+ global myreq
+ batch_size, emb_dim = inputs[0].size()
+ a2a_info = All2AllInfo()
+ a2a_info.local_table_num = len(inputs)
+ a2a_info.global_table_wise_parition_slices = per_rank_table_splits
+ (
+ a2a_info.local_batch_num,
+ a2a_info.global_batch_partition_slices,
+ ) = get_split_lengths(batch_size)
+ a2a_info.emb_dim = emb_dim
+ a2a_info.batch_size = batch_size
+ a2a_info.global_table_num = (
+ sum(per_rank_table_splits)
+ if per_rank_table_splits
+ else a2a_info.local_table_num * my_size
+ )
+
+ if a2a_impl == "" and alltoall_supported or a2a_impl == "alltoall":
+ # print("Using All2All_Req")
+ output = All2All_Req.apply(a2a_info, *inputs)
+ myreq.WaitFunction = All2All_Wait
+ elif a2a_impl == "" or a2a_impl == "scatter":
+ # print("Using All2All_Scatter_Req")
+ output = All2All_Scatter_Req.apply(a2a_info, *inputs)
+ myreq.WaitFunction = All2All_Scatter_Wait
+ elif a2a_impl == "scatter_list":
+ # print("Using All2All_ScatterList_Req")
+ output = All2All_ScatterList_Req.apply(a2a_info, *inputs)
+ myreq.WaitFunction = All2All_ScatterList_Wait
+ else:
+ print(
+ "Unknown value set for DLRM_ALLTOALL_IMPL (%s), "
+ "please use one of [alltoall, scatter, scatter_list]" % a2a_impl
+ )
+ return myreq
+
+
+def all_gather(input, lengths, dim=0):
+ global my_rank, my_size
+ if my_size == -1:
+ my_size = dist.get_world_size()
+ my_rank = dist.get_rank()
+ if not lengths:
+ lengths = [input.size(0)] * my_size
+ return AllGather.apply(input, lengths, dim)
+
+
+def barrier():
+ if my_size > 1:
+ dist.barrier()
+
+
+# Override builtin print function to print only from rank 0
+orig_print = builtins.print
+
+
+def rank0_print(*args, **kwargs):
+ if my_rank <= 0 or kwargs.get("print_all", False):
+ orig_print(*args, **kwargs)
+
+
+# builtins.print = rank0_print
+
+# Allow printing from all rank with explicit print_all
+def print_all(*args, **kwargs):
+ orig_print(*args, **kwargs)
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/muq/muq_mulan/modules/transformer.py b/src/third_party/MuQ/src/muq/muq_mulan/modules/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27c8821436f46adf875c94e9ad6d21321509150
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq_mulan/modules/transformer.py
@@ -0,0 +1,185 @@
+# attention
+from torch import einsum
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+from einops import rearrange
+from .utils import l2norm, default, exists
+
+# 2d sinusoidal positional embedding
+# simple vit paper shows it is good enough compared to learned
+
+def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
+ _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
+
+ y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
+ assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
+
+ omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
+ omega = 1. / (temperature ** omega)
+
+ y = y.flatten()[:, None] * omega[None, :]
+ x = x.flatten()[:, None] * omega[None, :]
+
+ pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
+ pe = pe.type(dtype)
+
+ return rearrange(pe, '(h w) d -> h w d', h = h, w = w)
+
+# biasless layernorm
+
+class LayerNorm(nn.Module):
+ def __init__(self, dim, scale = True):
+ super().__init__()
+ self.learned_gamma = nn.Parameter(torch.ones(dim)) if scale else None
+
+ self.register_buffer('gamma', torch.ones(dim), persistent = False)
+ self.register_buffer('beta', torch.zeros(dim), persistent = False)
+
+ def forward(self, x):
+ return F.layer_norm(x, x.shape[-1:], default(self.learned_gamma, self.gamma), self.beta)
+
+# feedforward
+
+class GEGLU(nn.Module):
+ def forward(self, x):
+ x, gate = x.chunk(2, dim = -1)
+ return F.gelu(gate) * x
+
+def FeedForward(dim, mult = 4, dropout = 0.):
+ dim_hidden = int(dim * mult * 2 / 3)
+
+ return nn.Sequential(
+ LayerNorm(dim),
+ nn.Linear(dim, dim_hidden * 2, bias = False),
+ GEGLU(),
+ nn.Dropout(dropout),
+ nn.Linear(dim_hidden, dim, bias = False)
+ )
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ causal = False,
+ dim_head = 64,
+ heads = 8,
+ dropout = 0.,
+ scale = 8
+ ):
+ super().__init__()
+ self.heads = heads
+ self.scale = scale
+ self.causal = causal
+ inner_dim = dim_head * heads
+
+ self.norm = LayerNorm(dim)
+
+ self.attn_dropout = nn.Dropout(dropout)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
+
+ self.q_scale = nn.Parameter(torch.ones(dim_head))
+ self.k_scale = nn.Parameter(torch.ones(dim_head))
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, dim, bias = False),
+ nn.Dropout(dropout)
+ )
+
+ def forward(
+ self,
+ x,
+ rel_pos_bias = None,
+ mask = None
+ ):
+ b, n, _, device = *x.shape, x.device
+
+ # prenorm
+
+ x = self.norm(x)
+
+ # project for queries, keys, values
+
+ q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)
+
+ # split for multi-headed attention
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
+
+ # qk rmsnorm, technique circulating within brain used to stabilize a 22B parameter vision model training
+
+ q, k = map(l2norm, (q, k))
+ q = q * self.q_scale
+ k = k * self.k_scale
+
+ # similarities
+
+ sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+
+ if exists(rel_pos_bias):
+ sim = sim + rel_pos_bias
+
+ if exists(mask):
+ mask = rearrange(mask, 'b j -> b 1 1 j')
+ sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
+
+ if self.causal:
+ i, j = sim.shape[-2:]
+ causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
+ sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
+
+ # attention
+
+ attn = sim.softmax(dim = -1)
+ attn = self.attn_dropout(attn)
+
+ # aggregate
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+
+ # merge heads
+
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+# transformer
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ dim,
+ depth,
+ dim_head = 64,
+ heads = 8,
+ attn_dropout = 0.,
+ ff_mult = 4,
+ ff_dropout = 0.
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(nn.ModuleList([
+ Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
+ FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout),
+ ]))
+
+ def forward(
+ self,
+ x,
+ rel_pos_bias = None,
+ mask = None,
+ return_all_layers = False
+ ):
+ layers = []
+
+ for attn, ff in self.layers:
+ x = attn(x, rel_pos_bias = rel_pos_bias, mask = mask) + x
+ x = ff(x) + x
+ layers.append(x)
+
+ if not return_all_layers:
+ return x
+
+ return x, torch.stack(layers[:-1]) if len(self.layers)>1 else None
diff --git a/src/third_party/MuQ/src/muq/muq_mulan/modules/utils.py b/src/third_party/MuQ/src/muq/muq_mulan/modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2757fffc71898c3b9ea1a216689a9c09144e829a
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq_mulan/modules/utils.py
@@ -0,0 +1,45 @@
+from functools import wraps
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+def exists(val):
+ return val is not None
+
+def first(it):
+ return it[0]
+
+def default(val, d):
+ return val if exists(val) else d
+
+def round_down_nearest_multiple(n, divisor):
+ return n // divisor * divisor
+
+def Sequential(*modules):
+ return nn.Sequential(*filter(exists, modules))
+
+
+def once(fn):
+ called = False
+ @wraps(fn)
+ def inner(x):
+ nonlocal called
+ if called:
+ return
+ called = True
+ return fn(x)
+ return inner
+
+print_once = once(print)
+
+# tensor functions
+
+def log(t, eps = 1e-20):
+ return torch.log(t.clamp(min = eps))
+
+def l2norm(t):
+ return F.normalize(t, p = 2, dim = -1)
+
+def frozen_params(model:nn.Module):
+ for param in model.parameters():
+ param.requires_grad = False
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/muq/muq_mulan/muq_mulan.py b/src/third_party/MuQ/src/muq/muq_mulan/muq_mulan.py
new file mode 100644
index 0000000000000000000000000000000000000000..96f90e4066dba10107a96c078e8e09178b4a3049
--- /dev/null
+++ b/src/third_party/MuQ/src/muq/muq_mulan/muq_mulan.py
@@ -0,0 +1,271 @@
+from typing import List, Optional
+from dataclasses import dataclass, field
+import os
+
+from torch.nn.parallel.distributed import DistributedDataParallel
+import torch
+import torch.nn as nn
+from torch import einsum
+from einops import rearrange
+from huggingface_hub import PyTorchModelHubMixin
+from easydict import EasyDict
+
+from .models.mulan import MuLanModel
+from .models.audio import AudioSpectrogramTransformerPretrained
+from .models.text import TextTransformerPretrained
+from .modules.utils import exists, frozen_params
+
+
+@dataclass
+class MuLanConfig:
+ sr:int = field(default=24000)
+ clip_secs:float = field(default=10)
+ dim_latent:int = field(default=512)
+ decoupled_contrastive_learning:bool = field(default=True)
+ hierarchical_contrastive_loss:bool = field(default=False)
+ hierarchical_contrastive_loss_layers:Optional[List] = field(default=None)
+ sigmoid_contrastive_loss:bool = field(default=False)
+ rank_contrast:bool = field(default=True)
+
+@dataclass
+class AudioTransformerConfig:
+ dim:int = field(default=768)
+ tf_depth:int = field(default=8)
+ heads:int = field(default=8)
+ dim_head:int = field(default=64)
+ attn_dropout:float = field(default=0.)
+ ff_dropout:float = field(default=0.)
+ ff_mult:int = field(default=4)
+
+@dataclass
+class TextTransformerConfig:
+ dim:int = field(default=768)
+ tf_depth:int = field(default=8)
+ max_seq_len:int = field(default=1024)
+ dim_head:int = field(default=64)
+ heads:int = field(default=8)
+ attn_dropout:float = field(default=0.)
+ ff_dropout:float = field(default=0.)
+ ff_mult:int = field(default=4)
+
+@dataclass
+class ModalModelConfig:
+ name:str = field(default='')
+ model_dim: Optional[int] = field(default=None)
+ use_layer_idx: int = field(default=-1)
+
+
+@dataclass
+class MuQMuLanConfig:
+ mulan: MuLanConfig
+ audio_model: ModalModelConfig
+ text_model: ModalModelConfig
+ audio_transformer: AudioTransformerConfig
+ text_transformer: TextTransformerConfig
+
+class MuQMuLan(nn.Module, PyTorchModelHubMixin):
+ def __init__(self, config: MuQMuLanConfig, hf_hub_cache_dir=None):
+ super().__init__()
+ config = self._to_obj(config)
+ self.config = config
+ self.mulan = self.create_MuLan_from_config(config, hf_hub_cache_dir)
+ self.sr = config.mulan.sr
+ self.clip_secs = config.mulan.clip_secs
+
+ def _to_obj(self, config):
+ if isinstance(config, MuQMuLanConfig):
+ config = EasyDict(
+ mulan = config.mulan,
+ audio_model = config.audio_model,
+ text_model = config.text_model,
+ audio_transformer = config.audio_transformer,
+ text_transformer = config.text_transformer,
+ )
+ else:
+ config = EasyDict(config)
+ return config
+
+ @classmethod
+ def from_pretrained(cls, *args, cache_dir=None, **kwargs):
+ kwargs['hf_hub_cache_dir'] = cache_dir
+ return super().from_pretrained(*args, cache_dir=cache_dir, **kwargs)
+
+
+ @classmethod
+ def create_MuLan_from_config(cls, config:MuQMuLanConfig, hf_hub_cache_dir=None) -> MuLanModel:
+
+ audio_transformer = AudioSpectrogramTransformerPretrained(
+ model_name = config.audio_model.name,
+ model_dim = config.audio_model.model_dim,
+ use_layer_idx = config.audio_model.use_layer_idx,
+ **config.audio_transformer,
+ frozen_pretrained = False,
+ hf_hub_cache_dir = hf_hub_cache_dir,
+ )
+ text_transformer = TextTransformerPretrained(
+ model_name = config.text_model.name,
+ model_dim = config.text_model.model_dim,
+ **config.text_transformer,
+ frozen_pretrained = False,
+ hf_hub_cache_dir = hf_hub_cache_dir,
+ )
+
+ mulan = MuLanModel(
+ audio_transformer = audio_transformer,
+ text_transformer = text_transformer,
+ **config.mulan
+ )
+
+ return mulan
+
+ def frozen(self):
+ frozen_params(self)
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def mulan_module(self):
+ if isinstance(self.mulan, DistributedDataParallel):
+ return self.mulan.module
+ else:
+ return self.mulan
+
+ def forward(self,
+ wavs: Optional[torch.Tensor] = None,
+ texts: Optional[List[str]] = None,
+ *,
+ parallel_processing = False,
+ ) -> torch.Tensor:
+ """
+ Extract audio or text features, takes audio OR texts batch as input.
+ Note that if the audio is longer than 10s, it will be crop to multi cips and returns the average latent.
+ The param `parallel_processing` is used to control whether to use parallel processing or not.
+ If set to True, it uses parallel processing extractraction, which is faster but uses more GPU memory.
+ If set to False(the default), it uses serial processing extraction, which is slower but memory-friendly.
+
+ Args:
+ wavs (Optional[torch.Tensor]): Audio waveform tensor. Defaults to None.
+ texts (Optional[List[str]]): List of text strings. Defaults to None.
+ parallel_processing (bool): Whether to use parallel processing. Defaults to False.
+
+ Returns:
+ torch.Tensor: Latent representation of audio or text input.
+
+ Raises:
+ AssertionError: If both wavs and texts are provided or if neither is provided.
+
+ Note:
+ - Either wavs or texts must be provided, but not both.
+ - If wavs is provided, it calls extract_audio_latents method to process audio.
+ - If texts is provided, it calls extract_text_latents method to process text.
+ """
+ assert exists(wavs) ^ exists(texts), "Please provide either wavs or texts, but not both"
+
+ if exists(wavs):
+ return self.extract_audio_latents(wavs = wavs, parallel_processing = parallel_processing)
+ else:
+ return self.extract_text_latents(texts = texts)
+
+ def calc_similarity(self, audio_latents: torch.Tensor, text_latents: torch.Tensor) -> torch.Tensor:
+ """
+ Calculate the dot-product similarity between audio and text latent representations.
+ It supports various dimensions of input tensors (with/without batch dimension) for both audio and text.
+
+ Note:
+ The effect of this function is basically equivalent to the dot product.
+ mulan.calc_similarity(lat_a, lat_t) <==> einsum('i d, j d -> i j', lat_a, lat_t)
+
+ Args:
+ audio_latents (torch.Tensor): Latent representation of audio.
+ text_latents (torch.Tensor): Latent representation of text.
+
+ Returns:
+ torch.Tensor: Similarity scores between audio and text latent representations.
+
+ """
+ dim_a, dim_t = len(audio_latents.shape), len(text_latents.shape)
+ if dim_a == 2 and dim_t == 2:
+ return einsum('i d, j d -> i j', audio_latents, text_latents)
+ elif dim_a == 1 and dim_t == 1:
+ return torch.dot(audio_latents, text_latents)
+ elif dim_a == 2 and dim_t == 1:
+ return einsum('i d, d -> i', audio_latents, text_latents)
+ elif dim_a == 1 and dim_t == 2:
+ return einsum('d, j d -> j', audio_latents, text_latents)
+
+ raise RuntimeError(f"Invalid dimensions: audio {dim_a}, text {dim_t}")
+
+
+ def extract_audio_latents(self, wavs:torch.Tensor, *, parallel_processing = False) -> torch.Tensor:
+ """
+ Extract latent representations from audio waveforms.
+
+ This function processes a batch of audio waveforms and extracts their latent representations.
+ It supports parallel processing for faster computation but uses more GPU memory.
+
+ Args:
+ wavs (torch.Tensor): A batch of audio waveform tensors.
+ parallel_processing (bool): Flag to enable parallel processing. Defaults to False.
+
+ Returns:
+ torch.Tensor: A tensor containing the latent representations of the input audio waveforms.
+ """
+ audio_latents = []
+
+ def audio_to_latent(wav):
+ return self.mulan_module.get_audio_latents(wav)
+ for wav in wavs:
+ wav_tensors = []
+ if isinstance(wav, torch.Tensor):
+ wav_tensors = self._get_all_clips(wav)
+ else:
+ raise TypeError('wavs must be a Tensor')
+
+ if parallel_processing:
+ wav_tensors = wav_tensors.to(self.device)
+ audio_latent = audio_to_latent(wav_tensors)
+ audio_latent = audio_latent.mean(dim=0)
+ else:
+ wav_tensors = rearrange(wav_tensors, "i j -> i 1 j")
+ audio_latent = []
+ for wav_tensor in wav_tensors:
+ audio_latent.append(audio_to_latent(wav_tensor).squeeze(0))
+ del wav_tensor
+ audio_latent = torch.stack(audio_latent, dim=0)
+ audio_latent = audio_latent.mean(dim=0).to(self.device)
+
+ audio_latents.append(audio_latent)
+ audio_latents = torch.stack(audio_latents, dim=0)
+ return audio_latents
+
+ def extract_text_latents(self, texts: List[str]) -> torch.Tensor:
+ """
+ Extract latent representations from text inputs.
+
+ This function processes a list of text strings and extracts their latent representations
+ using the MuLan model's text tower.
+
+ Args:
+ texts (List[str]): A list of text strings to be processed.
+
+ Returns:
+ torch.Tensor: A tensor containing the latent representations of the input texts.
+ """
+ return self.mulan_module.get_text_latents(raw_texts=texts)
+
+ def _get_all_clips(self, audio):
+ origin_length = len(audio)
+ accum_length = 0
+ delta = self.sr * self.clip_secs
+ audio_clips = []
+ while accum_length + delta <= origin_length:
+ clip = audio[accum_length:accum_length + delta]
+ audio_clips.append(clip)
+ accum_length += delta
+ if accum_length < origin_length:
+ audio_clips.append(torch.cat([audio[accum_length:], audio[0:delta - (origin_length - accum_length)]]))
+
+ return torch.stack(audio_clips, dim=0)
+
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/README.md b/src/third_party/MuQ/src/recipes/contrastive_learning/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ad0e259357a689fdbf282f6026f4d56714628df7
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/README.md
@@ -0,0 +1,107 @@
+`
+# Guidance on MuQ-MuLan Training (Contrastive Learning)
+
+This guide provides instructions for training **MuQ-MuLan**, a contrastive learning model that jointly encodes music and text.
+
+We recommend training on **32 GPUs**, each with **at least 32 GB of memory**. Training typically takes **1–2 days**, depending on hardware environment.
+
+---
+
+## Step 1: Environment Setup
+
+First, install all required dependencies listed in [requirements.txt](./requirements.txt):
+
+```bash
+pip install -r requirements.txt
+```
+
+Then, install this repository as a library in editable mode:
+
+```bash
+pip install -e .
+```
+
+
+---
+
+## Step 2: Data Preparation
+
+We provide an example setup using the [MTG-Jamendo](https://github.com/MTG/mtg-jamendo-dataset) open-source music dataset.
+
+Please download our preprocessed data split files from [this link](https://drive.google.com/file/d/1PMCUpmtw8JwUv9Y-bNl8WIAZDtX8UaFY/view?usp=sharing) and place them in the appropriate directory.
+
+If you wish to use your own dataset, ensure that your data follow the required format.
+
+---
+
+## Step 3: Run Training
+
+### Option 1: Manual Configuration
+
+This method will prompt you to configure distributed training manually via `accelerate`:
+
+```bash
+accelerate config
+accelerate launch train.py
+```
+
+### Option 2: Using a Predefined Config File
+
+We provide example config files for multi-node multi-GPU training. For example, to launch training on **4 nodes**, each with **8 GPUs**, run:
+
+```bash
+accelerate launch \
+ --config_file config/accelerate/32gpu4node_fp16.yaml \
+ --machine_rank $NODE_RANK \
+ --main_process_ip $CHIEF_IP \
+ --main_process_port 29500 \
+ train.py
+```
+
+* `$NODE_RANK`: Index of the current machine (0 for the main node).
+* `$CHIEF_IP`: IP address of the main (rank-0) node.
+* Adjust the config file as needed for different hardware setups.
+
+If you are training on a **single machine**, simply set `--main_process_ip=127.0.0.1` and `--machine_rank=0`.
+
+If you wish to use your own pretrained MuQ model for initialization, simply modify the `model.mulan.audio_model.name` field in `config/model/muq_mulan.yaml`.
+
+---
+
+## Step 4: Convert to HF Checkpoint
+
+After training, convert your Fairseq-style checkpoint to HuggingFace format:
+
+```bash
+python scripts/convert_muqmulan_fairseq_ckpt_to_huggingface.py \
+ --checkpoint_path outputs/YYYY-MM-DD/hh-mm-ss/ckpt/mulan.1100.pt \
+ --save_dir outputs/hf-username/My-MuQ-MuLan-large
+```
+
+You can then load and use the model via the HuggingFace-style interface:
+
+```python
+from muq import MuQMuLan
+
+# Load from local HuggingFace-style checkpoint
+mulan = MuQMuLan.from_pretrained("outputs/hf-username/My-MuQ-MuLan-large")
+mulan = mulan.to(device).eval()
+
+# Extract music embeddings
+audio_embeds = mulan(wavs=wavs)
+
+# Extract text embeddings (in English or Chinese)
+text_embeds = mulan(texts=texts)
+
+# Compute similarity
+sim = mulan.calc_similarity(audio_embeds, text_embeds)
+```
+
+You can also upload your converted checkpoint to the Hugging Face Hub using `huggingface-cli`. The uploaded model will remain fully compatible with the `MuQMuLan.from_pretrained()` interface.
+
+---
+
+## Evaluation
+
+For evaluation, we recommend using the [sota-music-tagging-models](https://github.com/minzwon/sota-music-tagging-models/) toolkit. It supports various metrics and datasets widely used in music tagging and retrieval.
+
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/2gpu.yaml b/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/2gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dc213d6ebd2a86a93427412dddfbae95d63fb1ea
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/2gpu.yaml
@@ -0,0 +1,17 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 2
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/32gpu4node.yaml b/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/32gpu4node.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..df1245ad0ceeea1d33696a6e50adc3e8c9f9c978
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/32gpu4node.yaml
@@ -0,0 +1,19 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+gpu_ids: all
+machine_rank: 0
+main_process_ip: '11.214.123.26'
+main_process_port: 25545
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 4
+num_processes: 32
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/32gpu4node_fp16.yaml b/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/32gpu4node_fp16.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d90bcd1b72f0b0759bc719164289a85dafcac986
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/32gpu4node_fp16.yaml
@@ -0,0 +1,19 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+gpu_ids: all
+machine_rank: 0
+main_process_ip: $INDEX
+main_process_port: 25529
+main_training_function: main
+mixed_precision: fp16
+num_machines: 4
+num_processes: 32
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/32gpu4node_zero2.yaml b/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/32gpu4node_zero2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5efac83d7fd32f652ed6cf9ee761729181b6185e
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/32gpu4node_zero2.yaml
@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ deepspeed_multinode_launcher: standard
+ gradient_accumulation_steps: 1
+ offload_optimizer_device: none
+ offload_param_device: none
+ zero3_init_flag: false
+ zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_process_ip: $CHIEF_IP
+main_process_port: 25520
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 4
+num_processes: 32
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/8gpu_fp16.yaml b/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/8gpu_fp16.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e2674d103901d3d2a27d2c6351104fe08306e25b
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/8gpu_fp16.yaml
@@ -0,0 +1,17 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/8gpu_fp16_zero2.yaml b/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/8gpu_fp16_zero2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9e59275dabc4f23c5fbb444ddbe8bebec60dc90c
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/config/accelerate/8gpu_fp16_zero2.yaml
@@ -0,0 +1,22 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+ gradient_accumulation_steps: 1
+ offload_optimizer_device: none
+ offload_param_device: none
+ zero3_init_flag: false
+ zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/config/dataset/mtg_en_zh.yaml b/src/third_party/MuQ/src/recipes/contrastive_learning/config/dataset/mtg_en_zh.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ec54b6a5a6095dec9a2c6e4e374f2af3c0bc5adc
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/config/dataset/mtg_en_zh.yaml
@@ -0,0 +1,28 @@
+# @package _group_
+mtg_jamendo_json_en: &MTG
+ use: true
+ ratio: 1
+ duration: ${basics.duration}
+ lang: en
+ plain_rate: 0.5
+
+ data_dir: path/to/MTG-Jamendo # Please modify this
+ json_path:
+ train: /root/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/music_info.train.json
+ valid: /root/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/music_info.valid.json
+ tag_types: [genre, instrument, moods] # artist, title
+ prompt: /root/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/prompt/en.txt
+
+
+mtg_jamendo_json_zh:
+ <<: *MTG
+ use: true
+ ratio: 1
+
+ lang: zh
+ prompt: /root/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/prompt/zh.txt
+ translate:
+ genre: /root/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/genre_en2zh.json
+ instrument: /root/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/instrument_en2zh.json
+ keywords: /root/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/keywords_en2zh.json
+ moods: /root/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/moods_en2zh.json
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/config/model/muq_mulan.yaml b/src/third_party/MuQ/src/recipes/contrastive_learning/config/model/muq_mulan.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5996428037e38d8eb23337cbc242ddc45154ab50
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/config/model/muq_mulan.yaml
@@ -0,0 +1,40 @@
+# @package _global_
+model:
+ mulan:
+ sr: 24000
+ clip_secs: 10
+ dim_latent: 512
+ decoupled_contrastive_learning: true
+ hierarchical_contrastive_loss: false
+ hierarchical_contrastive_loss_layers: null
+ sigmoid_contrastive_loss: false
+ rank_contrast: true
+
+ audio_model:
+ name: OpenMuQ/MuQ-large-msd-iter
+ model_dim: 1024
+ use_layer_idx: -1
+
+ text_model:
+ name: xlm-roberta-base
+ model_dim: null
+ use_layer_idx: -1
+
+ audio_transformer:
+ dim: 768
+ tf_depth: 0
+ heads: 8
+ dim_head: 64
+ attn_dropout: 0
+ ff_dropout: 0
+ ff_mult: 4
+
+ text_transformer:
+ dim: 768
+ tf_depth: 8
+ max_seq_len: 1024
+ dim_head: 64
+ heads: 8
+ attn_dropout: 0
+ ff_dropout: 0
+ ff_mult: 4
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/config/train.yaml b/src/third_party/MuQ/src/recipes/contrastive_learning/config/train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..960a12c1f6ee8cf76b51cad54d43a3814aa4e4b2
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/config/train.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - model: muq_mulan
+ - dataset: mtg_en_zh
+
+basics:
+ #logs and checkpoints will be saved to outputs/yyyy-mm-dd/hh-mm-ss/
+ sr: 24000 #sample rate
+ orig_sr: 44100 #origin sample rate used for datatset loading
+ random_seed: 42 #set to null if require non-deterministic
+ duration: 10
+
+train:
+ lr: 1e-4
+ batch_size: 24
+ num_workers: 6
+ data_max_secs: 10
+ save_model_every: 100
+ valid_every: 1
+ log_tensorboard: true
+ resume:
+ use: false
+ checkpoint_path: path/to/mulan.1100.pt
+ load_optimizer: false
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/index/genre_idx.json b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/index/genre_idx.json
new file mode 100644
index 0000000000000000000000000000000000000000..fff212bf5fa83164fc3cb41a423f1996b9aaf7c8
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/index/genre_idx.json
@@ -0,0 +1,97 @@
+{
+ "club": 0,
+ "gothic": 1,
+ "easylistening": 2,
+ "60s": 3,
+ "folk": 4,
+ "medieval": 5,
+ "tribal": 6,
+ "downtempo": 7,
+ "breakbeat": 8,
+ "90s": 9,
+ "synthpop": 10,
+ "industrial": 11,
+ "jazzfunk": 12,
+ "dubstep": 13,
+ "instrumentalpop": 14,
+ "bluesrock": 15,
+ "grunge": 16,
+ "ska": 17,
+ "deephouse": 18,
+ "house": 19,
+ "ethno": 20,
+ "jazzfusion": 21,
+ "poprock": 22,
+ "trance": 23,
+ "orchestral": 24,
+ "funk": 25,
+ "blues": 26,
+ "ethnicrock": 27,
+ "dub": 28,
+ "singersongwriter": 29,
+ "techno": 30,
+ "choir": 31,
+ "electronica": 32,
+ "reggae": 33,
+ "rock": 34,
+ "country": 35,
+ "postrock": 36,
+ "idm": 37,
+ "groove": 38,
+ "celtic": 39,
+ "triphop": 40,
+ "rocknroll": 41,
+ "rnb": 42,
+ "hardrock": 43,
+ "electropop": 44,
+ "latin": 45,
+ "alternative": 46,
+ "swing": 47,
+ "soundtrack": 48,
+ "darkambient": 49,
+ "drumnbass": 50,
+ "hiphop": 51,
+ "psychedelic": 52,
+ "classical": 53,
+ "dance": 54,
+ "instrumentalrock": 55,
+ "worldfusion": 56,
+ "fusion": 57,
+ "edm": 58,
+ "80s": 59,
+ "improvisation": 60,
+ "rap": 61,
+ "soul": 62,
+ "punkrock": 63,
+ "disco": 64,
+ "hard": 65,
+ "atmospheric": 66,
+ "newage": 67,
+ "minimal": 68,
+ "70s": 69,
+ "newwave": 70,
+ "acidjazz": 71,
+ "experimental": 72,
+ "alternativerock": 73,
+ "african": 74,
+ "eurodance": 75,
+ "darkwave": 76,
+ "pop": 77,
+ "heavymetal": 78,
+ "classicrock": 79,
+ "world": 80,
+ "contemporary": 81,
+ "metal": 82,
+ "chanson": 83,
+ "progressive": 84,
+ "ambient": 85,
+ "chillout": 86,
+ "symphonic": 87,
+ "oriental": 88,
+ "lounge": 89,
+ "bossanova": 90,
+ "electronic": 91,
+ "jazz": 92,
+ "popfolk": 93,
+ "indie": 94
+}
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/index/instrument_idx.json b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/index/instrument_idx.json
new file mode 100644
index 0000000000000000000000000000000000000000..2a260481d87550a29847992f5c4842c8b64b9016
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/index/instrument_idx.json
@@ -0,0 +1,43 @@
+{
+ "accordion": 0,
+ "classicalguitar": 1,
+ "ukulele": 2,
+ "voice": 3,
+ "violin": 4,
+ "horn": 5,
+ "percussion": 6,
+ "oboe": 7,
+ "strings": 8,
+ "electricpiano": 9,
+ "orchestra": 10,
+ "clarinet": 11,
+ "piano": 12,
+ "organ": 13,
+ "bass": 14,
+ "electricguitar": 15,
+ "bongo": 16,
+ "beat": 17,
+ "synthesizer": 18,
+ "harmonica": 19,
+ "saxophone": 20,
+ "guitar": 21,
+ "rhodes": 22,
+ "doublebass": 23,
+ "acousticguitar": 24,
+ "drums": 25,
+ "sampler": 26,
+ "harp": 27,
+ "pad": 28,
+ "keyboard": 29,
+ "trombone": 30,
+ "trumpet": 31,
+ "viola": 32,
+ "brass": 33,
+ "computer": 34,
+ "acousticbassguitar": 35,
+ "cello": 36,
+ "bell": 37,
+ "flute": 38,
+ "drummachine": 39,
+ "pipeorgan": 40
+}
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/index/keywords_idx.json b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/index/keywords_idx.json
new file mode 100644
index 0000000000000000000000000000000000000000..878cf942e36071a1342207741d39e34d3854880e
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/index/keywords_idx.json
@@ -0,0 +1,61 @@
+{
+ "retro": 0,
+ "soft": 1,
+ "soundscape": 2,
+ "cool": 3,
+ "romantic": 4,
+ "heavy": 5,
+ "upbeat": 6,
+ "calm": 7,
+ "melancholic": 8,
+ "fun": 9,
+ "adventure": 10,
+ "sport": 11,
+ "corporate": 12,
+ "commercial": 13,
+ "happy": 14,
+ "children": 15,
+ "summer": 16,
+ "action": 17,
+ "energetic": 18,
+ "hopeful": 19,
+ "dream": 20,
+ "drama": 21,
+ "love": 22,
+ "trailer": 23,
+ "space": 24,
+ "deep": 25,
+ "background": 26,
+ "documentary": 27,
+ "sad": 28,
+ "holiday": 29,
+ "groovy": 30,
+ "dark": 31,
+ "horror": 32,
+ "emotional": 33,
+ "slow": 34,
+ "inspiring": 35,
+ "advertising": 36,
+ "film": 37,
+ "mellow": 38,
+ "game": 39,
+ "nature": 40,
+ "sexy": 41,
+ "christmas": 42,
+ "ballad": 43,
+ "fast": 44,
+ "powerful": 45,
+ "uplifting": 46,
+ "dramatic": 47,
+ "funny": 48,
+ "meditative": 49,
+ "positive": 50,
+ "relaxing": 51,
+ "party": 52,
+ "motivational": 53,
+ "travel": 54,
+ "movie": 55,
+ "epic": 56,
+ "melodic": 57,
+ "ambiental": 58
+}
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/index/moods_idx.json b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/index/moods_idx.json
new file mode 100644
index 0000000000000000000000000000000000000000..878cf942e36071a1342207741d39e34d3854880e
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/index/moods_idx.json
@@ -0,0 +1,61 @@
+{
+ "retro": 0,
+ "soft": 1,
+ "soundscape": 2,
+ "cool": 3,
+ "romantic": 4,
+ "heavy": 5,
+ "upbeat": 6,
+ "calm": 7,
+ "melancholic": 8,
+ "fun": 9,
+ "adventure": 10,
+ "sport": 11,
+ "corporate": 12,
+ "commercial": 13,
+ "happy": 14,
+ "children": 15,
+ "summer": 16,
+ "action": 17,
+ "energetic": 18,
+ "hopeful": 19,
+ "dream": 20,
+ "drama": 21,
+ "love": 22,
+ "trailer": 23,
+ "space": 24,
+ "deep": 25,
+ "background": 26,
+ "documentary": 27,
+ "sad": 28,
+ "holiday": 29,
+ "groovy": 30,
+ "dark": 31,
+ "horror": 32,
+ "emotional": 33,
+ "slow": 34,
+ "inspiring": 35,
+ "advertising": 36,
+ "film": 37,
+ "mellow": 38,
+ "game": 39,
+ "nature": 40,
+ "sexy": 41,
+ "christmas": 42,
+ "ballad": 43,
+ "fast": 44,
+ "powerful": 45,
+ "uplifting": 46,
+ "dramatic": 47,
+ "funny": 48,
+ "meditative": 49,
+ "positive": 50,
+ "relaxing": 51,
+ "party": 52,
+ "motivational": 53,
+ "travel": 54,
+ "movie": 55,
+ "epic": 56,
+ "melodic": 57,
+ "ambiental": 58
+}
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/prompt/en.txt b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/prompt/en.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5a34a5bd5ab0c9ededb4b2097d66ed42c987c22e
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/prompt/en.txt
@@ -0,0 +1,50 @@
+Create a [{genre}] track [with {instrument}][ that evokes a {moods} atmosphere][ by {artist}][, named {title}].
+Compose a [{genre}] piece [by {artist}] [with a focus on {instrument}][, named {title}][, and evokes a {moods} moods].
+Produce a [{genre}] song [that captures the essence of {moods}][ by {artist}][, featuring {instrument}][, named {title}].
+Craft a [{genre}] composition [by {artist}][ that explores the {moods}][, with {instrument}][, named {title}].
+Develop a [{genre}] track [with a {moods} vibe][ by {artist}][, featuring {instrument}][, named {title}].
+Generate a [{genre}] song [that highlights {artist}'s signature sound][, with {instrument}][, named {title}][, inspired by {moods} moods].
+Write a [{genre}] piece [that combines {instrument}][ to create a {moods} atmosphere][ by {artist}][, named {title}].
+Arrange a [{genre}] composition [with a focus on {instrument}][, by {artist}][, named {title}][, and conveys {moods} moods].
+Collaborate on a [{genre}] track [with {artist}][ that showcases {instrument}][, named {title}][, inspired by {moods}].
+Experiment [with {genre}] and [{instrument}] to create a [{moods}-infused] piece [by {artist}][, named {title}].
+Mix [{genre} elements] [with {instrument}] to produce a [{moods}-driven] track [by {artist}][, named {title}].
+Fuse [{genre}] and [{instrument}] to craft [an innovative {moods}-inspired] piece [by {artist}][, named {title}].
+Synthesize a [{genre}] composition [that highlights {artist}'s skills] on [{instrument}] [, named {title}][, exressing {moods}].
+Explore [{genre}] and [{moods}] through a composition [featuring {instrument}][ by {artist}][, named {title}].
+Combine [{genre}] and [{instrument}] to create a track [that exudes {moods}][ by {artist}][, named {title}].
+[Channel {moods} into] a [{genre}] piece [with {instrument}][, by {artist}][, named {title}].
+[Capture the essence of {moods} in] a [{genre}] composition [by {artist}][, featuring {instrument}][, named {title}].
+[Express {moods} through] a [{genre}] track with [{instrument}][ by {artist}][, named {title}].
+[With {title}, ][{artist}] brings us a [{genre}] track [highlighting {instrument}][, and evoking a {moods} atmosphere].
+[Introducing {title},] a song [by {artist}], that [features {instrument}] and [captures the essence of {moods}][, as a {genre} masterpiece].
+[Explore {title},] a [{genre}] track [by {artist}][, that masterfully utilizes {instrument}][, and encapsulates a {moods} essence].
+[Enter the world of {title},] a [{genre}] composition [by {artist}][, that merges {instrument}][, and creates a {moods} ambience].
+[Dive into {title},] a [{genre}] piece [by {artist}][, that seamlessly integrates {instrument}][, and exudes a {moods} atmosphere].
+This is a [{genre}] masterpiece [by {artist}][, which is titled {title}][, featuring {instrument}][, and exudes a {moods} moods].
+It is a [{genre}] composition [by {artist}][, named {title}][, that showcases the beauty of {instrument}][, and captures the essence of {moods}].
+A [{genre}] track [by {artist}][, titled {title}][, skillfully blending {instrument}][, and creating a {moods} atmosphere].
+This is a song [by {artist}][, which is titled {title},] that [combines {instrument}][, has a {moods} moods][, belongs to the {genre} category].
+It is a [{genre}] piece [by {artist}][, called {title}][, where {instrument} takes center stage][, and evokes a {moods} feeling].
+A captivating [{genre}] composition [by {artist}][, named {title}][, highlighting {instrument}][, and infused with a {moods} vibe].
+This is a [{genre}] creation [by {artist}][, titled {title}][, that seamlessly merges {instrument}][, and transports listeners to a {moods} state].
+It is a [{genre}] song [by {artist}][, called {title}][, that masterfully interweaves {instrument}][, and conveys a {moods} emotion].
+A [{genre}] musical journey [by {artist}][, named {title}][, featuring {instrument}][, and inspired by the {moods} sentiment].
+This is a [{genre}] gem [by {artist}][, titled {title},] that [skillfully intertwines {instrument}][, and resonates with a {moods} feeling].
+It is a [{genre}] work of art [by {artist}][, called {title}][, that showcases {instrument}][, and embodies the essence of {moods}].
+This is a [{genre}] opus [by {artist}][, titled {title}][, that fuses {instrument}][, and immerses listeners in a {moods} atmosphere].
+It is a [{genre}] creation [by {artist}][, called {title}][, that celebrates {instrument}][, and evokes a {moods} sensation].
+A [{genre}] piece [by {artist}][, named {title},] [that focuses on {instrument}] [, and transports the audience to a {moods} realm].
+This is a [{genre}] composition [by {artist}][, titled {title}][, that embraces {instrument}][, and captures the spirit of {moods}].
+A [{genre}] song [by {artist}][, named {title},] [that experiments with {instrument}][, and dives into the depths of {moods}].
+[{genre}][, by {artist}][, {instrument}][, {moods}][, titled {title}]
+[{moods}][, {genre}][, from {artist}][, {instrument}][, name: {title}]
+[{instrument}][, {moods}][, {genre}][, artist: {artist}][, song: {title}]
+[by {artist}][, {genre}][, {moods}][, {instrument}][, called {title}]
+[{genre}][, {instrument}][, by {artist}][, {moods}][, title: {title}]
+[{instrument}][, {moods}][, by {artist}][, {genre}][, titled {title}]
+[{genre}][, {instrument}][, {artist}][, {moods}][, name: {title}]
+[by {artist}][, {genre}][, {instrument}][, {moods}][, called {title}]
+[{moods}][, {genre}][, by {artist}][, {instrument}][, title: {title}]
+[{instrument}][, {genre}][, artist: {artist}][, {moods}][, song: {title}]
+[{genre}][, {instrument}][, {moods}][, {artist}][, song: {title}]
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/prompt/zh.txt b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/prompt/zh.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8fa47849436da25b6e26af79182774f51ab278aa
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/prompt/zh.txt
@@ -0,0 +1,51 @@
+创作一首[{genre}风格的]曲目,[使用{instrument}][营造一种{moods}氛围,][由{artist}创作][取名为{title}]。
+创作一段[{genre}流派的]乐曲,[由{artist}创作][以{instrument}为重点][,取名为{title}][,并营造{moods}的氛围]。
+创作一首[{genre}风格的]歌曲,[捕捉{moods}的精髓][,类似{artist}的作品][,以{instrument}为特色][,取名为{title}]。
+创作一段[{genre}流派的]作曲,[由{artist}创作][探索{moods}][,以{instrument}为特色][,取名为{title}]。
+创作一首[{genre}风格的]曲目,[呈现出{moods}氛围][由{artist}创作][,以{instrument}为特色][,取名为{title}]。
+生成一首[{genre}流派的]歌曲,[突显出{artist}独有的音乐风格][,以{instrument}为特色][,取名为{title}][,灵感来自{moods}的氛围]。
+创作一段[{genre}风格的]乐曲,[结合{instrument}][创造出{moods}的氛围][,模仿{artist}的作品][,取名为{title}]。
+编排一段[{genre}流派的]作曲,[以{instrument}为重点][由{artist}创作][,取名为{title}][,传达{moods}的感受]。
+一首[{genre}]歌曲[,具有{moods}气息][,以{instrument}为中心][,并由以下标签识别:{Tags}][,{artist}的风格][,取名为{title}]。
+一首[{genre}]歌[,带有{moods}感觉][,展示了{instrument}][,{Tags}]。
+一首[{genre}]曲子[,具有{moods}氛围][,专注于{instrument}][,并由以下标签识别:{Tags}]。
+一首[{genre}]作品[带有{moods}氛围][,强调{instrument}的声音][,并由以下标签定义:{Tags}]。
+一首[{genre}]音乐[,突出使用{instrument}][,具有{moods}触感][,由以下标签定义:{Tags}]。
+一首[{genre}]作品[带有{moods}氛围][,强调{instrument}的声音][,具有标签:{Tags}]。
+一段[{genre}]表演[伴随着{moods}氛围][,以{instrument}为中心][,{Tags}]。
+一首[{genre}]乐章[,散发着{moods}气氛][,以{instrument}为主导][,标签为:{Tags}]。
+一首[{genre}]旋律[具有{moods}风格][,突出表现了{instrument}][,标签为{Tags}]。
+一首[{genre}]歌曲[,带有{moods}情感][,重点展现{instrument}][,{Tags}]。
+一首[{genre}]创作[融入了{moods}元素][,强调{instrument}的音色][,并由以下标签定义:{Tags}]。
+一首[{genre}]演奏曲[突显{instrument}的运用][,带有{moods}特质][,由以下标签:{Tags}]。
+一首[{genre}]乐曲[具有{moods}气氛][,凸显{instrument}的声音][,{Tags}]。
+一段[{genre}]演奏[营造出{moods}气氛][,以{instrument}为焦点][,标签:{Tags}]。
+[{genre}][,{moods}][,{instrument}][,{Tags}]
+[{Tags}][,{genre}],[{instrument}],[{moods}]
+[风格: {genre}][,情绪: {moods}][,乐器: {instrument}][,标签: {Tags}]
+[{genre}音乐][,{BPMDescript}][,{moods}氛围][,{instrument}声音][,{Tags}]
+[{genre}类型][,{moods}感觉][,节奏为{BPM}][,{BPMDescript}][,使用{instrument}][,{Tags}]
+[{instrument}演奏][,节奏{BPM}拍][,{BPMDescript}][,{moods}情感][,{genre}类别][,{Tags}]
+[{moods}的情感][,{BPMDescript}][,{genre}风格][,{instrument}融合][,{Tags}][,BPM: {BPM}]
+[BPM为{BPM},][{Tags}][{genre}形式][,{moods}本质][,{instrument}重点][,{BPMDescript}]
+[{BPMDescript}][,{genre}传统][,{BPM} bpm][,{moods}感觉][,{instrument}展示][,标签: {Tags}]
+[{instrument}为重点][,{genre}的风格][,{moods}氛围][,{BPMDescript}] [,{Tags}][,{BPM} bpm]
+[{genre}风格][,{moods}情感][,节奏{BPM}bpm][,{BPMDescript}][,{instrument}特色][,{Tags}]
+[{Tags}][,{BPMDescript}][,{genre}类型][,节奏{BPM} bpm,],[{instrument}],[{moods}]
+[类型: {genre}][,氛围: {moods}][,{BPM}BPM][,乐器特点: {instrument}][,关键词: {Tags}][,{BPMDescript}]
+[{genre}类别][,{BPMDescript}][,{moods}基调][,{instrument}元素][,{Tags}][,{BPM}bpm]
+[{genre}特点][,{moods}背景][,{Tags}][,{instrument}主题][,{BPMDescript}][,{BPM} bpm]
+[{instrument}呈现][,节奏{BPM}][,{BPMDescript}][,{moods}内涵][,{genre}流派][,{Tags}]
+[{moods}表达][,{BPMDescript}][,{genre}艺术形式][,{instrument}结合][,{Tags}][,BPM为{BPM}]
+[每分钟{BPM}拍][,{Tags}][,{genre}手法][,{moods}内核][,{instrument}突出][,{BPMDescript}]
+[{BPMDescript}][,{genre}特色][,{BPM} bpm][,{moods}情感渲染][,{instrument}呈现方式][,关键词: {Tags}]
+[{instrument}核心][,{genre}表现形式][,{moods}气氛][,{BPMDescript}] [,{Tags}][,节奏{BPM} bpm]
+[{genre}][,{moods}][,节奏{BPM}bpm][,{BPMDescript}][,{instrument}][,{Tags}]
+[{Tags}] [,{BPMDescript}][,{genre}][,节奏{BPM} bpm,],[{instrument}],[{moods}]
+[类型: {genre}][,氛围: {moods}][,{BPM}BPM][,乐器: {instrument}][,标签: {Tags}][,{BPMDescript}]
+[{genre}类别][,{BPMDescript}][,{moods}的感觉][,{instrument}音色][,{Tags}][,节奏{BPM}bpm]
+[{genre}][,{moods}情绪][,{Tags}][,{instrument}特点][,{BPMDescript}][,节奏{BPM} bpm]
+[{instrument}演绎][,节奏{BPM}][,{BPMDescript}][,{moods}心情][,{genre}风格][,{Tags}]
+[{moods}的心境][,{BPMDescript}][,{genre}类别][,{instrument}结合][,{Tags}][,BPM: {BPM}]
+[{BPMDescript}][,{genre}习惯][,节奏{BPM} bpm][,{moods}感受][,{instrument}呈现][,标签: {Tags}]
+[{instrument}][,{genre}风格][,{moods}氛围][,{BPMDescript}] [,{Tags}][,节奏{BPM} bpm]
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/genre_en2zh.json b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/genre_en2zh.json
new file mode 100644
index 0000000000000000000000000000000000000000..130bad5f7ea0cc92a603788e6c0f772f936b1c48
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/genre_en2zh.json
@@ -0,0 +1,97 @@
+{
+ "club": "俱乐部",
+ "gothic": "哥特式",
+ "easylistening": "轻音乐",
+ "60s": "60年代",
+ "folk": "民谣",
+ "medieval": "中世纪",
+ "tribal": "Tribal音乐",
+ "downtempo": "慢摇",
+ "breakbeat": "碎拍乐",
+ "90s": "90年代",
+ "synthpop": "合成器流行",
+ "industrial": "工业",
+ "jazzfunk": "爵士放克",
+ "dubstep": "Dubstep",
+ "instrumentalpop": "流行器乐",
+ "bluesrock": "布鲁斯摇滚",
+ "grunge": "垃圾摇滚",
+ "ska": "斯卡",
+ "deephouse": "深度浩室",
+ "house": "House音乐",
+ "ethno": "民族",
+ "jazzfusion": "爵士融合",
+ "poprock": "流行摇滚",
+ "trance": "Trance音乐",
+ "orchestral": "管弦乐",
+ "funk": "放克",
+ "blues": "蓝调",
+ "ethnicrock": "民族摇滚",
+ "dub": "Dub",
+ "singersongwriter": "Singer-songwriter",
+ "techno": "电子音乐",
+ "choir": "合唱",
+ "electronica": "电子乐",
+ "reggae": "雷鬼",
+ "rock": "摇滚",
+ "country": "乡村音乐",
+ "postrock": "后摇滚",
+ "idm": "智能舞曲",
+ "groove": "Groove",
+ "celtic": "凯尔特音乐",
+ "triphop": "Trip-Hop",
+ "rocknroll": "摇滚乐",
+ "rnb": "节奏布鲁斯",
+ "hardrock": "硬摇滚",
+ "electropop": "电子流行",
+ "latin": "拉丁",
+ "alternative": "非主流",
+ "swing": "摇摆",
+ "soundtrack": "原声带",
+ "darkambient": "黑暗氛围",
+ "drumnbass": "鼓打贝斯",
+ "hiphop": "嘻哈",
+ "psychedelic": "迷幻",
+ "classical": "古典",
+ "dance": "舞曲",
+ "instrumentalrock": "器乐摇滚",
+ "worldfusion": "世界融合",
+ "fusion": "融合",
+ "edm": "电子舞曲",
+ "80s": "80年代",
+ "improvisation": "即兴创作",
+ "rap": "说唱",
+ "soul": "灵魂乐",
+ "punkrock": "朋克摇滚",
+ "disco": "迪斯科",
+ "hard": "硬核",
+ "atmospheric": "Atmospheric音乐",
+ "newage": "新时代",
+ "minimal": "极简",
+ "70s": "70年代",
+ "newwave": "新浪潮",
+ "acidjazz": "酸爵士",
+ "experimental": "实验性的",
+ "alternativerock": "另类摇滚",
+ "african": "非洲",
+ "eurodance": "欧洲舞曲",
+ "darkwave": "暗潮音乐",
+ "pop": "流行",
+ "heavymetal": "重金属",
+ "classicrock": "经典摇滚",
+ "world": "世界",
+ "contemporary": "当代",
+ "metal": "金属",
+ "chanson": "Chanson",
+ "progressive": "前卫音乐",
+ "ambient": "氛围音乐",
+ "chillout": "驰放音乐",
+ "symphonic": "交响乐",
+ "oriental": "东方",
+ "lounge": "沙发音乐",
+ "bossanova": "博萨诺瓦",
+ "electronic": "电子音乐",
+ "jazz": "爵士",
+ "popfolk": "民谣流行",
+ "indie": "独立音乐"
+}
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/instrument_en2zh.json b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/instrument_en2zh.json
new file mode 100644
index 0000000000000000000000000000000000000000..8b6e825d04a74f6c4b3e5e98e6e3d7da05867262
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/instrument_en2zh.json
@@ -0,0 +1,43 @@
+{
+ "accordion": "手风琴",
+ "classicalguitar": "古典吉他",
+ "ukulele": "尤克里里",
+ "voice": "声音",
+ "violin": "小提琴",
+ "horn": "喇叭",
+ "percussion": "打击乐器",
+ "oboe": "双簧管",
+ "strings": "弦乐器",
+ "electricpiano": "电钢琴",
+ "orchestra": "管弦乐队",
+ "clarinet": "单簧管",
+ "piano": "钢琴",
+ "organ": "风琴",
+ "bass": "贝斯",
+ "electricguitar": "电吉他",
+ "bongo": "邦戈鼓",
+ "beat": "节奏",
+ "synthesizer": "合成器",
+ "harmonica": "口琴",
+ "saxophone": "萨克斯",
+ "guitar": "吉他",
+ "rhodes": "罗德斯钢琴",
+ "doublebass": "大提琴",
+ "acousticguitar": "木吉他",
+ "drums": "鼓",
+ "sampler": "抽样器",
+ "harp": "竖琴",
+ "pad": "电子合成器",
+ "keyboard": "键盘",
+ "trombone": "长号",
+ "trumpet": "小号",
+ "viola": "中提琴",
+ "brass": "铜管乐器",
+ "computer": "电脑",
+ "acousticbassguitar": "木贝斯",
+ "cello": "大提琴",
+ "bell": "钟",
+ "flute": "长笛",
+ "drummachine": "鼓机",
+ "pipeorgan": "管风琴"
+ }
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/keywords_en2zh.json b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/keywords_en2zh.json
new file mode 100644
index 0000000000000000000000000000000000000000..ad39ffe98680977e943d9c261f1eae00505d909b
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/keywords_en2zh.json
@@ -0,0 +1,61 @@
+{
+ "retro": "复古",
+ "soft": "柔和",
+ "soundscape": "声景",
+ "cool": "酷",
+ "romantic": "浪漫",
+ "heavy": "沉重",
+ "upbeat": "欢快",
+ "calm": "平静",
+ "melancholic": "忧郁",
+ "fun": "有趣",
+ "adventure": "冒险",
+ "sport": "运动",
+ "corporate": "公司",
+ "commercial": "商业",
+ "happy": "快乐",
+ "children": "儿童",
+ "summer": "夏天",
+ "action": "动作",
+ "energetic": "有活力",
+ "hopeful": "充满希望",
+ "dream": "梦想",
+ "drama": "戏剧",
+ "love": "爱",
+ "trailer": "预告片",
+ "space": "太空",
+ "deep": "深沉",
+ "background": "背景",
+ "documentary": "纪录片",
+ "sad": "悲伤",
+ "holiday": "假期",
+ "groovy": "酷炫",
+ "dark": "黑暗",
+ "horror": "恐怖",
+ "emotional": "情感",
+ "slow": "慢",
+ "inspiring": "启发",
+ "advertising": "广告",
+ "film": "电影",
+ "mellow": "柔和",
+ "game": "游戏",
+ "nature": "自然",
+ "sexy": "性感",
+ "christmas": "圣诞节",
+ "ballad": "叙事歌",
+ "fast": "快速",
+ "powerful": "强大",
+ "uplifting": "鼓舞人心",
+ "dramatic": "戏剧性",
+ "funny": "有趣",
+ "meditative": "冥想",
+ "positive": "积极",
+ "relaxing": "放松",
+ "party": "派对",
+ "motivational": "激励",
+ "travel": "旅行",
+ "movie": "电影",
+ "epic": "史诗",
+ "melodic": "旋律优美",
+ "ambiental": "环境"
+ }
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/moods_en2zh.json b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/moods_en2zh.json
new file mode 100644
index 0000000000000000000000000000000000000000..ad39ffe98680977e943d9c261f1eae00505d909b
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/datasets/mtg-jamendo/translate/moods_en2zh.json
@@ -0,0 +1,61 @@
+{
+ "retro": "复古",
+ "soft": "柔和",
+ "soundscape": "声景",
+ "cool": "酷",
+ "romantic": "浪漫",
+ "heavy": "沉重",
+ "upbeat": "欢快",
+ "calm": "平静",
+ "melancholic": "忧郁",
+ "fun": "有趣",
+ "adventure": "冒险",
+ "sport": "运动",
+ "corporate": "公司",
+ "commercial": "商业",
+ "happy": "快乐",
+ "children": "儿童",
+ "summer": "夏天",
+ "action": "动作",
+ "energetic": "有活力",
+ "hopeful": "充满希望",
+ "dream": "梦想",
+ "drama": "戏剧",
+ "love": "爱",
+ "trailer": "预告片",
+ "space": "太空",
+ "deep": "深沉",
+ "background": "背景",
+ "documentary": "纪录片",
+ "sad": "悲伤",
+ "holiday": "假期",
+ "groovy": "酷炫",
+ "dark": "黑暗",
+ "horror": "恐怖",
+ "emotional": "情感",
+ "slow": "慢",
+ "inspiring": "启发",
+ "advertising": "广告",
+ "film": "电影",
+ "mellow": "柔和",
+ "game": "游戏",
+ "nature": "自然",
+ "sexy": "性感",
+ "christmas": "圣诞节",
+ "ballad": "叙事歌",
+ "fast": "快速",
+ "powerful": "强大",
+ "uplifting": "鼓舞人心",
+ "dramatic": "戏剧性",
+ "funny": "有趣",
+ "meditative": "冥想",
+ "positive": "积极",
+ "relaxing": "放松",
+ "party": "派对",
+ "motivational": "激励",
+ "travel": "旅行",
+ "movie": "电影",
+ "epic": "史诗",
+ "melodic": "旋律优美",
+ "ambiental": "环境"
+ }
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/musiclm_pytorch/__init__.py b/src/third_party/MuQ/src/recipes/contrastive_learning/musiclm_pytorch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c437640e4a4d4428985bde49d666d6501af22c5b
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/musiclm_pytorch/__init__.py
@@ -0,0 +1 @@
+from musiclm_pytorch.trainer import MuLaNTrainer
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/musiclm_pytorch/dataset.py b/src/third_party/MuQ/src/recipes/contrastive_learning/musiclm_pytorch/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..f293015db2930d6ea690157e7dd8577665427b67
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/musiclm_pytorch/dataset.py
@@ -0,0 +1,459 @@
+from torch.utils.data import Dataset
+from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List, Union
+from beartype import beartype
+from beartype.door import is_bearable
+import random
+import pandas as pd
+import os
+from torchaudio.functional import resample
+import torch
+import typing as tp
+from pathlib import Path
+import torchaudio as ta
+import torch.nn.functional as F
+import numpy as np
+import json
+import yaml
+import torchaudio
+import math
+import re
+from loguru import logger
+import ffmpeg
+
+class Read_and_PadCrop_Normalized_T(torch.nn.Module):
+ def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
+
+ super().__init__()
+
+ self.n_samples = n_samples
+ self.sample_rate = sample_rate
+ self.randomize = randomize
+
+ def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
+ if self.n_samples < 0: #means not clip
+ chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
+ t_start = 0.
+ t_end = 1.0
+ offset = 0
+ else:
+ if(duration<(float(self.n_samples)/self.sample_rate+1)):
+ # print(duration,(float(self.n_samples)/self.sample_rate+1))
+ chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
+ t_start = 0.
+ t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
+ offset = 0
+ # print('c1:',chunk.shape)
+ else:
+ offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
+ t_start = offset / float(cur_sample_rate) / duration
+ t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
+ chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
+ # print('offset:',offset)
+ # print('c0:',chunk.shape)
+ # Pad with silence if necessary.
+ if(chunk.shape[0]>1):
+ chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
+ else:
+ chunk = chunk[[0],:].float()
+ if(cur_sample_rate!=self.sample_rate):
+ # print('a:',cur_sample_rate,chunk.shape)
+ chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
+ # print('b:',self.sample_rate,chunk.shape)
+
+ if self.n_samples > 0:
+ if chunk.shape[-1] < self.n_samples:
+ chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
+ else:
+ chunk = chunk[:,0:self.n_samples]
+ seconds_start = math.floor(offset / cur_sample_rate)
+ seconds_total = math.floor(duration)
+
+ return (
+ chunk,
+ t_start,
+ t_end,
+ seconds_start,
+ seconds_total
+ )
+
+USE_DUMMY_AUDIO = False
+if USE_DUMMY_AUDIO:
+ logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!")
+
+class SafeAudioReader:
+ """
+ This class is an adaptor to Read_and_PadCrop_Normalized_T, make it safe to read audio data.
+ """
+ def __init__(self,
+ duration: float,
+ sample_rate: int,
+ randomize: bool = True,
+ ):
+ self.n_samples = int(sample_rate * duration)
+ self.reader = Read_and_PadCrop_Normalized_T(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)
+
+ def __call__(self,
+ filepath: os.PathLike,
+ origin_sample_rate: Optional[int] = None,
+ origin_duration: float = None,
+ ) -> torch.Tensor:
+ if USE_DUMMY_AUDIO:
+ wav = torch.zeros(self.n_samples, dtype=torch.float32)
+ return wav
+ try:
+ if origin_sample_rate is None or origin_duration is None:
+ # audio_info = torchaudio.info(filepath)
+ # origin_sample_rate = audio_info.sample_rate
+ # origin_duration = audio_info.num_frames / origin_sample_rate
+ info = ffmpeg.probe(filepath)
+ origin_duration = float(info['format']['duration'])
+ origin_sample_rate = int(info['streams'][0]['sample_rate'])
+ wav, *ignored = self.reader(filepath, origin_duration, origin_sample_rate)
+ wav = wav.squeeze_(0)
+ except Exception as e:
+ logger.error(f"Error reading {filepath}: {e}")
+ from traceback import print_exc
+ print_exc()
+ wav = torch.zeros(self.n_samples, dtype=torch.float32)
+ return wav
+
+
+class PromptTemplate:
+ def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
+ self.template_text = template_text
+ self.tag_map = tag_map
+ self.lang = lang
+
+ @property
+ def tags(self):
+ return tuple(self.tag_map.keys())
+
+ def apply(self, **kwargs):
+ for tag in list(kwargs.keys()):
+ if kwargs[tag] == '':
+ kwargs.pop(tag)
+ for tag in self.tags:
+ if tag in kwargs:
+ kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
+ else:
+ kwargs[tag] = ''
+ prompt = self.template_text.format(**kwargs)
+
+ return self.beautify(prompt)
+
+ def beautify(self, text):
+ if self.lang == 'en':
+ return self._beautify_en(text)
+ elif self.lang == 'zh':
+ return self._beautify_zh(text)
+ else:
+ raise ValueError(f'Unknown language {self.lang}')
+
+ @staticmethod
+ def _beautify_en(text):
+ # no continuous commas without content between them
+ text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
+ # no continuous whitespace
+ text = re.sub(r'\s+', ' ', text)
+ # the comma is NOT followed by whitespace, and should be followed by ONE whitespace
+ text = re.sub(r'\s+,', r',', text)
+ text = re.sub(r',\s+', r', ', text)
+ # no whitespace before the full stop
+ text = re.sub(r'\s+\.', r'.', text)
+ # strip whitespace, comma, and replace ',.'
+ text = text.strip(' ,')
+ text = text.replace(',.', '.')
+ return text
+
+ @staticmethod
+ def _beautify_zh(text):
+ # no continuous commas without content between them
+ text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
+ text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
+ # assume there should be NO whitespace in Chinese
+ text = re.sub(r'\s+', r'', text)
+ # strip whitespace, comma, and replace ',。'
+ text = text.strip(', 、')
+ text = text.replace(',。', '。')
+ return text
+
+ def __repr__(self):
+ return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
+
+ __str__ = __repr__
+
+def parse_prompt_template(prompt_template_text, lang='en'):
+ span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
+ tag_pattern = re.compile(r'{.+?}', re.DOTALL)
+
+ template_text = prompt_template_text.strip()
+ span_texts = span_pattern.findall(prompt_template_text)
+ tag_map = {}
+ for span_text in span_texts:
+ tag = tag_pattern.findall(span_text)[0].strip('{}')
+ tag_map[tag] = span_text
+ template_text = template_text.replace(span_text, '{'+tag+'}')
+
+ return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
+
+def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
+ with open(path, 'r') as f:
+ lines = f.readlines()
+ cnt = 0
+ pts = []
+ for line in lines:
+ pt = parse_prompt_template(line, lang=lang)
+ cnt += 1
+ if len(pt.tags) < num:
+ logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
+ pts.append(pt)
+
+ return pts
+
+
+def get_base_dir_file(key: os.PathLike):
+ base = os.path.basename(key)
+ dirname = os.path.basename(os.path.dirname(key))
+ return os.path.join(dirname, base)
+
+def read_jsonlike(path: os.PathLike):
+ #json or jsonl
+ if str(path).endswith(".json"):
+ with open(path, 'r', encoding='utf8') as f:
+ data = json.load(f)
+ return data
+ elif str(path).endswith(".jsonl"):
+ with open(path, 'r', encoding='utf8') as f:
+ data = [json.loads(line) for line in f.readlines()]
+ return data
+ else:
+ raise ValueError("Unknown file format")
+
+dist_prob_map = {
+ 1: (1.0,),
+ 2: (0.5, 0.5),
+ 3: (0.3, 0.4, 0.3),
+ 4: (0.2, 0.3, 0.3, 0.2),
+ 5: (0.2, 0.2, 0.3, 0.2, 0.1),
+ 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
+ 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
+ 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
+ 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
+ 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
+}
+
+dist_prob_map_low = {
+ 1: (1.0,),
+ 2: (0.8, 0.2),
+ 3: (0.8, 0.1, 0.1),
+ 4: (0.7, 0.1, 0.1, 0.1),
+ 5: (0.7, 0.1, 0.1, 0.05, 0.05),
+ 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
+}
+
+def read_translate(translate: Union[Dict[str, os.PathLike], os.PathLike, None]):
+ if translate is None:
+ return None
+ if isinstance(translate, str):
+ return read_jsonlike(translate)
+ return {k: read_jsonlike(path) for k, path in translate.items()}
+
+
+def gen_plain_prompt(key_list, sep=', '):
+ if len(key_list) == 0:
+ return 'none'
+
+ key_list = [k.strip() for k in key_list]
+
+ if len(key_list) > 10:
+ random.shuffle(key_list)
+ key_list = key_list[:10]
+
+ probs = dist_prob_map[len(key_list)]
+
+ num_tags = random.choices(range(1, len(key_list)+1), probs, k=1)[0]
+
+ random.shuffle(key_list)
+ tags = key_list[:num_tags]
+ tags_str = sep.join(tags)
+ return tags_str
+
+
+def tags_to_desc(tag_list, sep=',') -> str:
+ if not isinstance(tag_list, Sequence):
+ return str(tag_list)
+ if isinstance(tag_list, str):
+ return tag_list
+ if len(tag_list) <= 0:
+ return ''
+ elif len(tag_list) <= 5:
+ probs = dist_prob_map[len(tag_list)]
+ tags_num = random.choices(range(1, len(tag_list)+1), probs)[0]
+ random.shuffle(tag_list)
+ tag_list = tag_list[:tags_num]
+ return sep.join(tag_list)
+ else:
+ probs = dist_prob_map[5]
+ tags_num = random.choices(range(1, 6), probs)[0]
+ random.shuffle(tag_list)
+ tag_list = tag_list[:tags_num]
+ return sep.join(tag_list)
+
+def get_sr_and_duration_info(item):
+ return item.get('sample_rate', None), item.get('duration', None)
+
+class MtgJamendoDatasetFromJson(Dataset):
+ def __init__(self,
+ data_dir:str,
+ json_path:str,
+ duration:float=10,
+ sr:int = 0,
+ lang = 'en',
+ plain_rate = 0,
+ return_audio = True,
+ return_path = False,
+ prompt_template_path: os.PathLike = None,
+ tag_types = [],
+ translate:Optional[Dict[str, os.PathLike]] = None,
+ use_literal_none = True,
+ ):
+ self.audio_reader = SafeAudioReader(duration, sr)
+
+ self.data_dir = data_dir
+ self._load_metadata_json(json_path)
+ self.sr = sr
+ self.duration = duration
+ self.plain_rate = plain_rate
+ self.return_audio = return_audio
+ self.return_path = return_path
+ self.use_literal_none = use_literal_none
+ self.lang = lang
+
+ self.use_dynamic_prompt = prompt_template_path is not None and plain_rate < 1.0
+ if self.use_dynamic_prompt:
+ self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types))
+ self.tag_types = tag_types
+
+ self.translate = read_translate(translate)
+
+ # These tags are considered to be weak semantics, avoiding text prompts containing only these tags
+ WEAK_TAG_LIST = ["title", "artist"]
+
+ def _load_metadata_json(self, json_path):
+ with open(json_path) as fp:
+ self.data = json.load(fp)
+
+ def convert_key_to_path(self, key):
+ return os.path.join(self.data_dir, get_base_dir_file(key))
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ item = self.data[idx]
+ path = self.convert_key_to_path(item['key'])
+ description = self.generate_description(item)
+
+ if self.return_audio:
+ sr, duration = get_sr_and_duration_info(item)
+ audio = self.audio_reader(path, sr, duration)
+ else:
+ audio = None
+
+ if self.return_path:
+ return audio, description, path
+ return audio, description
+
+ def tags_to_desc(self, tag_list, tag_type) -> str:
+ if self.lang == 'en':
+ return tags_to_desc(tag_list)
+ elif self.lang == 'zh':
+ translator = self.translate[tag_type]
+ translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ]
+ return tags_to_desc(translated_tag_list, sep='、')
+
+ def generate_description(self, item):
+ if random.random() > self.plain_rate:
+ # dynamically generate prompt from given prompt template
+ prompt_template = random.choice(self.prompt_templates)
+ description = self.generate_description_dynamic(item, prompt_template)
+ else:
+ # use plain prompt, i.e. tags sequence separated by comma
+ description = self.generate_description_plain(item)
+ return description
+
+ def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
+ exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
+ exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag))
+ exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag))
+
+ if len(exists_strong_tag) > 0:
+ probs = dist_prob_map[len(exists_strong_tag)]
+ tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0]
+ random.shuffle(exists_strong_tag)
+ tags = exists_strong_tag[:tags_num]
+ weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1]
+ weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0]
+ random.shuffle(exists_weak_tag)
+ weak_tags = exists_weak_tag[:weak_tags_num]
+ tags += weak_tags
+ tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
+ prompt = prompt_template.apply(**tags_args)
+ else:
+ # no strong tags, use all weak tags instead
+ tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag}
+ prompt = prompt_template.apply(**tags_args)
+
+ if self.use_literal_none and len(tags_args) == 0:
+ return 'none'
+
+ return prompt
+
+ def generate_description_plain(self, item):
+ keywords = []
+ for tag_t in self.tag_types:
+ this_key = item[tag_t]
+ if this_key is None:
+ continue
+ if isinstance(this_key, str):
+ this_key = [this_key]
+ if self.lang != 'en':
+ this_key = [self.get_translation(tag_t, k) for k in this_key]
+ keywords += this_key
+ return gen_plain_prompt(keywords, sep=self.keysep)
+
+ def get_translation(self, tag_t, k):
+ k = k.strip()
+ if k in self.translate[tag_t]:
+ return self.translate[tag_t][k]
+ else:
+ return k
+
+ @property
+ def keysep(self):
+ if self.lang == 'zh':
+ return ',' if random.random() > 0.5 else '、'
+ elif self.lang == 'en':
+ return ', '
+
+
+class CombinedDataset(Dataset):
+ @beartype
+ def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]):
+ self.datasets = datasets
+ self.datasets_index = []
+
+ for i,dataset in enumerate(datasets):
+ if dataset is None:
+ continue
+ for dup in range(ratios[i]):
+ for j in range(len(dataset)):
+ self.datasets_index.append((i,j))
+
+ def __len__(self):
+ return len(self.datasets_index)
+
+ def __getitem__(self, idx):
+ index = self.datasets_index[idx]
+ i,j = index
+ return self.datasets[i][j]
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/musiclm_pytorch/trainer.py b/src/third_party/MuQ/src/recipes/contrastive_learning/musiclm_pytorch/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c6d32e40a91705e0b55f5c23bf6c8e5e2f5a09d
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/musiclm_pytorch/trainer.py
@@ -0,0 +1,429 @@
+import copy
+from math import sqrt
+from random import choice
+from pathlib import Path
+from shutil import rmtree
+from functools import wraps, partial
+
+from typing_extensions import Annotated
+
+from beartype import beartype
+from beartype.door import is_bearable
+from beartype.vale import Is
+from beartype.typing import Union, List, Optional, Tuple, Callable, Any
+
+import torch
+from torch import nn
+from torch.optim import Adam
+from torch.utils.data import Dataset, DataLoader, random_split
+from torch.nn.utils.rnn import pad_sequence
+
+from lion_pytorch import Lion
+
+from einops import rearrange
+
+from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs
+
+from loguru import logger
+
+from torchaudio.transforms import Resample
+
+import os
+#surpress warnings on huggingface model cause by num_workers!=0 in dataloader
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+#solving nccl childFailedError
+# os.environ['NCCL_DEBUG'] = 'INFO'
+os.environ['NCCL_P2P_DISABLE']='1'
+os.environ['NCCL_IB_GID_INDEX']='3'
+
+# for automatically routing data emitted from a dataset to keywords of the transformer wrappers
+
+DATASET_FIELD_TYPE_CONFIG = dict(
+ wavs = Annotated[
+ torch.Tensor,
+ Is[lambda t: t.dtype == torch.float and t.ndim in {2, 3}]
+ ],
+ raw_texts = List[str],
+ texts = Annotated[
+ torch.Tensor,
+ Is[lambda t: t.dtype == torch.long and t.ndim == 2]
+ ],
+)
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+def default(*args):
+ for arg in args:
+ if exists(arg):
+ return arg
+ return None
+
+def noop(*args, **kwargs):
+ pass
+
+def cycle(dl):
+ while True:
+ for data in dl:
+ yield data
+
+def cast_tuple(t):
+ return t if isinstance(t, (tuple, list)) else (t,)
+
+def yes_or_no(question):
+ return True
+ #NOTE: pause using interactive input for debugging convenience
+ answer = input(f'{question} (y/n) ')
+ return answer.lower() in ('yes', 'y')
+
+def accum_log(log, new_logs):
+ for key, new_value in new_logs.items():
+ old_value = log.get(key, 0.)
+ log[key] = old_value + new_value
+ return log
+
+# auto data to module keyword argument routing functions
+
+def has_duplicates(tup):
+ counts = dict()
+ for el in tup:
+ if el not in counts:
+ counts[el] = 0
+ counts[el] += 1
+ return any(filter(lambda count: count > 1, counts.values()))
+
+def determine_types(data, config):
+ output = []
+ for el in data:
+ for name, data_type in config.items():
+ if is_bearable(el, data_type):
+ output.append(name)
+ break
+ else:
+ raise TypeError(f'unable to determine type of {data}')
+
+ return tuple(output)
+
+# optimizer functions
+
+def separate_weight_decayable_params(params):
+ wd_params, no_wd_params = [], []
+ for param in params:
+ param_list = no_wd_params if param.ndim < 2 else wd_params
+ param_list.append(param)
+ return wd_params, no_wd_params
+
+# dataloader functions
+
+def collate_one_or_multiple_tensors(fn):
+ @wraps(fn)
+ def inner(data):
+ is_one_data = not isinstance(data[0], tuple)
+
+ if is_one_data:
+ data = torch.stack(data)
+ return (data,)
+
+ outputs = []
+ for datum in zip(*data):
+ if is_bearable(datum, Tuple[str, ...]):
+ output = list(datum)
+ else:
+ output = fn(datum)
+
+ outputs.append(output)
+
+ return tuple(outputs)
+
+ return inner
+
+@collate_one_or_multiple_tensors
+def curtail_to_shortest_collate(data):
+ min_len = min(*[datum.shape[0] for datum in data])
+ data = [datum[:min_len] for datum in data]
+ return torch.stack(data)
+
+@collate_one_or_multiple_tensors
+def pad_to_longest_fn(data):
+ return pad_sequence(data, batch_first = True)
+
+def get_dataloader(ds, pad_to_longest = True, **kwargs):
+ # collate_fn = pad_to_longest_fn if pad_to_longest else curtail_to_shortest_collate
+ # return DataLoader(ds, collate_fn = collate_fn, **kwargs)
+ return DataLoader(ds, **kwargs)
+
+# semantic transformer trainer
+
+class MuLaNTrainer(nn.Module):
+ def __init__(
+ self,
+ config: Any,
+ mulan: nn.Module,
+ dataset: Dataset,
+ val_dataset: Optional[Dataset] = None,
+ *,
+ num_train_steps = None,
+ batch_size,
+ data_max_length = None,
+ lr = 3e-5,
+ sr = 24000,
+ orig_sr = 44100,
+ num_workers = 0,
+ grad_accum_every = 10,
+ betas = (0.9, 0.99),
+ max_grad_norm = 0.5,
+ valid_every = 1,
+ valid_frac = 0.05,
+ random_split_seed = 42,
+ save_model_every = 10,
+ results_folder = './results',
+ accelerate_kwargs: dict = dict(),
+ use_lion = False,
+ force_clear_prev_results = False, # set to True | False to skip the prompt
+ resume_from_checkpoint = None,
+ load_optimizer_when_resume = True,
+ ):
+ super().__init__()
+ assert batch_size > 1, 'batch size must be greater than 1 for contrastive learning (but ideally as large as possible)'
+
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+
+ self.config = config
+
+ self.accelerator = Accelerator(**accelerate_kwargs, kwargs_handlers=[ddp_kwargs])
+
+ self.resampler = Resample(orig_sr, sr).to(self.device)
+
+ self.mulan = mulan
+
+ self.register_buffer('steps', torch.Tensor([0]))
+
+ self.num_train_steps = default(num_train_steps, float('inf')) # infinite by default
+ self.batch_size = batch_size
+ self.grad_accum_every = grad_accum_every
+
+ # optimizers
+
+ optim_klass = Lion if use_lion else Adam
+ self.optim = optim_klass(filter(lambda p:p.requires_grad, mulan.parameters()), lr = lr, betas = betas)
+
+ # max grad norm
+
+ self.max_grad_norm = max_grad_norm
+
+ self.data_max_length = data_max_length
+
+
+ # create dataset
+
+ self.ds = dataset
+ self.ds_fields = None
+
+ # split for validation
+ if val_dataset is None:
+ self.print("training with random split dataset as valid set, this can be dangerous, use with caution!")
+ train_size = int((1 - valid_frac) * len(self.ds))
+ valid_size = len(self.ds) - train_size
+ self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
+ self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
+ else:
+ self.print("training with fixed validation set")
+ self.valid_ds = val_dataset
+
+ # dataloader
+ self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True, pad_to_longest = False, drop_last = True, num_workers=num_workers)
+
+ self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True, pad_to_longest = False, drop_last = True, num_workers=num_workers)
+
+ # handle resume
+ if resume_from_checkpoint:
+ self._load(resume_from_checkpoint, load_optimizer = load_optimizer_when_resume)
+ print("resume from", resume_from_checkpoint)
+
+ def getsize(model):
+ import numpy as np
+ s = 0
+ for param in model.parameters():
+ s += np.product(param.size())
+ print("[INFO] # of mulan's parameters: "+str(s/1024.0/1024.0))
+
+
+ getsize(self.mulan)
+
+ # prepare with accelerator
+
+ (
+ self.mulan,
+ self.optim,
+ self.dl,
+ self.valid_dl
+ ) = self.accelerator.prepare(
+ self.mulan,
+ self.optim,
+ self.dl,
+ self.valid_dl
+ )
+
+ # dataloader iterators
+
+ self.dl_iter = cycle(self.dl)
+ self.valid_dl_iter = cycle(self.valid_dl)
+
+ self.valid_every = valid_every
+ self.save_model_every = save_model_every
+
+ hps = dict(
+ num_train_steps = num_train_steps,
+ data_max_length = data_max_length,
+ learning_rate = lr
+ )
+
+ self.accelerator.init_trackers("tb", config = hps)
+
+ # results folder
+
+ self.results_folder = Path(results_folder)
+
+ if force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
+ rmtree(str(self.results_folder))
+
+ self.results_folder.mkdir(parents = True, exist_ok = True)
+
+ # to device
+
+ self.mulan.to(self.device)
+
+ def save(self, path):
+ state_dict = self.accelerator.get_state_dict(self.mulan)
+ pkg = dict(
+ model = state_dict,
+ optim = self.optim.state_dict(),
+ config = self.config,
+ )
+ torch.save(pkg, path)
+
+ def _load(self, path, load_optimizer = True):
+ path = Path(path)
+ assert path.exists()
+ pkg = torch.load(str(path), map_location = 'cpu')
+
+ mulan = self.mulan
+
+ if 'model' in pkg:
+ model_state_dict = pkg['model']
+ else:
+ model_state_dict = pkg
+
+ if not load_optimizer and 'contrast.temperatures' in model_state_dict:
+ del model_state_dict['contrast.temperatures']
+ mulan.load_state_dict(model_state_dict, strict=False)
+
+ if load_optimizer:
+ if 'base_optimizer_state' in pkg['optim']:
+ optim_state = pkg['optim']['base_optimizer_state']
+ else:
+ optim_state = pkg['optim']
+ # self.optim.load_state_dict(optim_state)
+ # FIXME: There is a problem with the optimizer loading. Let's do this for the time being
+
+ def print(self, msg):
+ self.accelerator.print(msg)
+ logger.info(msg)
+
+ @property
+ def device(self):
+ return self.accelerator.device
+
+ @property
+ def is_distributed(self):
+ return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
+
+ @property
+ def is_main(self):
+ return self.accelerator.is_main_process
+
+ @property
+ def is_local_main(self):
+ return self.accelerator.is_local_main_process
+
+ def data_tuple_to_kwargs(self, data):
+
+ data_kwargs = dict(
+ wavs = data[0],
+ raw_texts = data[1],
+ )
+
+ wavs = data_kwargs['wavs']
+ wavs = self.resampler(wavs)
+ data_kwargs.update(wavs = wavs[..., :self.data_max_length]) #Use fixed maximum length audio as in Mulan
+
+ return data_kwargs
+
+ def train_step(self):
+
+ self.print(f"start steps: {self.steps.item()}")
+
+ device = self.device
+
+ steps = int(self.steps.item())
+
+ self.mulan.train()
+
+ logs = {}
+
+
+ for _ in range(self.grad_accum_every):
+ data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))
+
+ loss = self.mulan(**data_kwargs)
+
+ self.accelerator.backward(loss / self.grad_accum_every )
+
+ accum_log(logs, {'loss/train': loss.item() / self.grad_accum_every })
+
+ del data_kwargs, loss
+
+ if exists(self.max_grad_norm):
+ self.accelerator.clip_grad_norm_(self.mulan.parameters(), self.max_grad_norm)
+
+ self.optim.step()
+ self.optim.zero_grad()
+
+ # log
+
+ self.print(f"{steps}: loss: {logs['loss/train']}")
+
+ # valid
+ if not (steps % self.valid_every):
+ with torch.no_grad():
+ # self.mulan.eval()
+ for _ in range(self.grad_accum_every):
+ data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter))
+
+ loss = self.mulan(**data_kwargs)
+
+ accum_log(logs, {'loss/valid': loss.item() / self.grad_accum_every})
+
+ self.print(f"{steps}: valid_loss: {logs['loss/valid']}")
+
+ self.accelerator.log(logs, step = steps)
+
+ # save model every so often
+ if self.is_main and not (steps % self.save_model_every):
+ model_path = str(self.results_folder / f'mulan.{steps}.pt')
+ self.save(model_path)
+
+ self.print(f'{steps}: saving model to {str(self.results_folder)}')
+ #'''
+ self.steps += 1
+ return logs
+
+ def train(self, log_fn: Callable = noop):
+
+ while self.steps < self.num_train_steps:
+ logs = self.train_step()
+ log_fn(logs)
+
+ self.print('training complete')
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/requirements.txt b/src/third_party/MuQ/src/recipes/contrastive_learning/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fc38b60bb72325c9f6671676546c00eac0f9c6e3
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/requirements.txt
@@ -0,0 +1,31 @@
+accelerate
+audiolm-pytorch
+beartype
+bidict
+datasets
+deepspeed
+einops
+fairseq
+hydra-core==1.3
+librosa
+lion-pytorch
+loguru
+lxml
+networkx
+soundfile
+soxr
+tensorboardx
+torch
+torchaudio
+torchvision
+tqdm
+# transformers
+triton
+typing-extensions
+vector-quantize-pytorch
+x-clip
+xxhash
+yarl
+zipp
+ffmpeg-python
+fire
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/scripts/convert_muqmulan_fairseq_ckpt_to_huggingface.py b/src/third_party/MuQ/src/recipes/contrastive_learning/scripts/convert_muqmulan_fairseq_ckpt_to_huggingface.py
new file mode 100644
index 0000000000000000000000000000000000000000..328d5c33dc9bd8fc68f8b32bdd9094b6d2bd5647
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/scripts/convert_muqmulan_fairseq_ckpt_to_huggingface.py
@@ -0,0 +1,32 @@
+import os
+import torch
+import json
+import argparse
+from dataclasses import dataclass
+
+import torch
+from transformers import PretrainedConfig
+from omegaconf import OmegaConf
+
+from safetensors.torch import save_file
+
+def main(args):
+ checkpoint_path = args.checkpoint_path
+ save_dir = args.save_dir
+
+ # save model and config.json
+ ckpt = torch.load(checkpoint_path, map_location='cpu')
+ os.makedirs(save_dir, exist_ok=True)
+ torch.save(ckpt['model'], os.path.join(save_dir, 'pytorch_model.bin'))
+ save_file(ckpt['model'], os.path.join(save_dir, 'model.safetensors'))
+ with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf8') as w:
+ config = OmegaConf.to_container(ckpt['config']['model'], resolve=True)
+ json.dump(config, w, indent=2)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--checkpoint_path", type=str, help="path to the fairseq checkpoint")
+ parser.add_argument("--save_dir", type=str, help="path to the result directory")
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/scripts/inference_muqmulan_huggingface.py b/src/third_party/MuQ/src/recipes/contrastive_learning/scripts/inference_muqmulan_huggingface.py
new file mode 100644
index 0000000000000000000000000000000000000000..36dfac8a6df77111cc5b9ef92af0c1f3246aafd5
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/scripts/inference_muqmulan_huggingface.py
@@ -0,0 +1,22 @@
+import torch, librosa
+from muq import MuQMuLan
+
+# This will automatically fetch checkpoints from huggingface
+device = 'cuda'
+mulan = MuQMuLan.from_pretrained("outputs/hf-username/My-MuQ-MuLan-large")
+mulan = mulan.to(device).eval()
+
+# Extract music embeddings
+wav, sr = librosa.load("path/to/music_audio.wav", sr = 24000)
+wavs = torch.tensor(wav).unsqueeze(0).to(device)
+with torch.no_grad():
+ audio_embeds = mulan(wavs = wavs)
+
+# Extract text embeddings (texts can be in English or Chinese)
+texts = ["classical genres, hopeful mood, piano.", "一首适合海边风景的小提琴曲,节奏欢快"]
+with torch.no_grad():
+ text_embeds = mulan(texts = texts)
+
+# Calculate dot product similarity
+sim = mulan.calc_similarity(audio_embeds, text_embeds)
+print(sim)
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/setup.py b/src/third_party/MuQ/src/recipes/contrastive_learning/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..d08100aaa4337904e2fdbd8a6e8aefc3cc1f58ac
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/setup.py
@@ -0,0 +1,50 @@
+from setuptools import setup, find_packages
+
+setup(
+ name = 'musiclm-pytorch',
+ packages = find_packages(exclude=[]),
+ version = '0.2.8',
+ license='MIT',
+ description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
+ author = 'Phil Wang',
+ author_email = 'lucidrains@gmail.com',
+ long_description_content_type = 'text/markdown',
+ url = 'https://github.com/lucidrains/musiclm-pytorch',
+ keywords = [
+ 'artificial intelligence',
+ 'deep learning',
+ 'transformers',
+ 'attention mechanism',
+ 'text to music',
+ 'contrastive learning'
+ ],
+ install_requires=[
+ 'numpy==1.23.5',
+ 'accelerate',
+ 'audiolm-pytorch>=0.17.0',
+ 'beartype',
+ 'einops>=0.6',
+ 'lion-pytorch',
+ 'vector-quantize-pytorch>=1.0.0',
+ 'x-clip',
+ 'torch>=1.12',
+ 'torchaudio',
+ 'loguru',
+ 'pandas',
+ 'librosa',
+ 'nnAudio',
+ 'bidict',
+ 'hydra-core>=1.3.0',
+ 'deepspeed',
+ 'diffusers[torch]',
+ 'ffmpeg-python',
+ 'tensorboardX',
+ ],
+ classifiers=[
+ 'Development Status :: 4 - Beta',
+ 'Intended Audience :: Developers',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ 'License :: OSI Approved :: MIT License',
+ 'Programming Language :: Python :: 3.6',
+ ],
+)
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/train.py b/src/third_party/MuQ/src/recipes/contrastive_learning/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..6066430d76a2587c51ba0b0d43646072220cdd65
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/train.py
@@ -0,0 +1,168 @@
+import torch
+from musiclm_pytorch.trainer import MuLaNTrainer
+
+from musiclm_pytorch.dataset import CombinedDataset, MtgJamendoDatasetFromJson
+from accelerate.utils import ProjectConfiguration
+
+import os
+import sys
+
+import random
+import numpy as np
+import torch
+
+from loguru import logger
+import hydra
+from omegaconf import DictConfig, OmegaConf
+
+from muq import MuQMuLan
+
+def gen_tag():
+ import time
+ import random
+ return "%d_%03d" % (time.time(), random.randint(0, 999))
+
+def get_or_default(value, default):
+ return value if value is not None else default
+
+def set_random_seed(seed):
+ random.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = True
+
+def create_datasets_from_config(config):
+ dataset_names = []
+ datasets = []
+ valid_datasets = []
+ ratios = []
+
+ sound_valid_datasets = []
+ sound_ratios = []
+
+ for dataset_name in config.dataset.keys():
+ cfg = config.dataset[dataset_name]
+ if not cfg.use:
+ continue
+ # print(f"Creating dataset {dataset_name}")
+
+ deco_tag_kwargs = cfg.get('deco_tag', dict(switch_on = False))
+ dataset_names.append(dataset_name)
+
+ if dataset_name.startswith('mtg_jamendo_json'):
+ mtg_dataset = MtgJamendoDatasetFromJson(
+ data_dir = cfg.data_dir,
+ json_path = cfg.json_path.train,
+ duration= cfg.duration,
+ sr = config.basics.orig_sr,
+ plain_rate = cfg.get('plain_rate', 0),
+ prompt_template_path= cfg.prompt,
+ tag_types = cfg.tag_types,
+ lang = cfg.get('lang', 'en'),
+ translate = cfg.get('translate', None)
+ )
+ datasets.append(mtg_dataset)
+ ratios.append(cfg.ratio)
+ if cfg.json_path.valid is not None:
+ valid_mtg_dataset = MtgJamendoDatasetFromJson(
+ data_dir = cfg.data_dir,
+ json_path = cfg.json_path.valid,
+ duration= cfg.duration,
+ sr = config.basics.orig_sr,
+ plain_rate = cfg.get('plain_rate', 0),
+ prompt_template_path = cfg.prompt,
+ tag_types = cfg.tag_types,
+ lang = cfg.get('lang', 'en'),
+ translate = cfg.get('translate', None)
+ )
+ valid_datasets.append(valid_mtg_dataset)
+ else:
+ valid_datasets.append(None)
+
+ print(
+ dataset_name, len(mtg_dataset), mtg_dataset[0]
+ )
+
+ else:
+ raise ValueError("Unknown dataset type: %s" % dataset_name)
+
+ if config.get('dataset_test', False):
+ if isinstance(config.dataset_test, str):
+ start_idx = int(config.dataset_test)
+ else:
+ start_idx = 0
+ for i in range(start_idx, start_idx + 100):
+ for ds_name, ds in zip(dataset_names, datasets):
+ if ds is None:
+ continue
+ print(f'{ds_name}: {ds[i]}')
+ input('continue?')
+
+
+ dataset = CombinedDataset(datasets=datasets, ratios=ratios)
+ val_dataset = CombinedDataset(datasets=valid_datasets, ratios=ratios) if any(valid_datasets) else None
+
+ print("dataset len:", len(dataset))
+ print("val_datatset len:", len(val_dataset))
+
+ return dataset, val_dataset
+
+
+@hydra.main(config_path='config', config_name='train')
+def main(config:DictConfig):
+ print(config)
+ ### read basics config
+ save_dir = config.basics.get('save_dir', os.getcwd())
+ os.makedirs(os.path.join(save_dir, 'tb'), exist_ok=True)
+ os.makedirs(os.path.join(save_dir, 'result'), exist_ok=True)
+ os.makedirs(os.path.join(save_dir, 'ckpt'), exist_ok=True)
+ tag = gen_tag()
+ out_fname = "output." + str(tag) + '.log'
+ out_fpath = os.path.join(save_dir, 'result', out_fname)
+ logger.add(out_fpath, level='INFO', format='{time} | {level} | {message}')
+ logger.info("Output log saved to " + out_fpath)
+
+
+ if config.basics.random_seed:
+ set_random_seed(config.basics.random_seed)
+
+ ### read dataset config
+
+ dataset, val_dataset = create_datasets_from_config(config)
+
+
+ ### read model config
+
+ mulan = MuQMuLan.create_MuLan_from_config(config.model)
+
+ ### read training config
+
+ trainer = MuLaNTrainer(
+ config = config,
+ mulan = mulan,
+ dataset = dataset,
+ val_dataset = val_dataset,
+ batch_size = config.train.batch_size,
+ num_workers = config.train.num_workers,
+ num_train_steps = config.train.get('num_train_steps', None),
+ lr = config.train.lr,
+ sr = config.basics.sr,
+ orig_sr = config.basics.orig_sr,
+ data_max_length = config.train.data_max_secs * config.basics.sr,
+ save_model_every = config.train.save_model_every,
+ valid_every = config.train.valid_every,
+ random_split_seed = get_or_default(config.basics.random_seed, 42),
+ results_folder=os.path.join(save_dir, 'ckpt'),
+ accelerate_kwargs = dict(log_with="tensorboard", project_dir=save_dir) if config.train.log_tensorboard else dict(),
+ resume_from_checkpoint = config.train.resume.checkpoint_path if config.train.resume.use else None,
+ load_optimizer_when_resume = config.train.resume.get('load_optimizer', True),
+ )
+
+ logger.info("Ready to start training.")
+ trainer.train()
+
+if __name__ == '__main__':
+ main()
diff --git a/src/third_party/MuQ/src/recipes/contrastive_learning/utils.py b/src/third_party/MuQ/src/recipes/contrastive_learning/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9d01e5291c548a630f3e4fc2ca15035eaa319aa
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/contrastive_learning/utils.py
@@ -0,0 +1,63 @@
+from omegaconf import DictConfig
+import os
+
+def get_pretrained_config(root, name):
+ if root is None:
+ return name
+ path = os.path.join(root, name)
+ config_dir = os.path.join(path, 'snapshots')
+ if os.path.exists(config_dir):
+ config_files = os.listdir(config_dir)
+ assert len(config_files) == 1
+ config_path = os.path.join(config_dir, config_files[0])
+ else:
+ config_path = path
+ return config_path
+
+def create_CLAP_model( model_kwargs = {}, ckpt_path = None ):
+ from musiclm_pytorch import SoftmaxContrastiveLearning
+ import laion_clap
+
+ from torch import nn
+ import torch
+ from torchaudio.functional import resample
+
+ import numpy as np
+
+ from functools import partial
+
+ # quantization
+ def int16_to_float32(x):
+ return (x / 32767.0).float()
+
+ def float32_to_int16(x):
+ x = torch.clip(x, min=-1., max=1.)
+ return (x * 32767.).int()
+
+ model = laion_clap.CLAP_Module(enable_fusion=False, **model_kwargs)
+ if ckpt_path is not None:
+ model.load_ckpt(ckpt_path)
+ else:
+ model.load_ckpt()
+
+ class CLAP_Model(nn.Module):
+ def __init__(self, model, sr = 24000, decoupled_contrastive_learning = True):
+ super().__init__()
+ self.model = model
+ self.model.eval()
+ self.orig_sr = sr
+
+ klass = partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning)
+ self.contrast = klass()
+
+
+ def forward(self, wavs, raw_texts):
+ with torch.no_grad():
+ wavs = int16_to_float32(float32_to_int16(resample(wavs, self.orig_sr, 48000)))
+ audio_latents = self.model.get_audio_embedding_from_data(x = wavs, use_tensor=True).float()
+ text_latents = model.get_text_embedding(raw_texts, use_tensor=True)
+ cl_loss = self.contrast(audio_latents, text_latents)
+ return cl_loss
+
+ clap = CLAP_Model(model)
+ return clap
diff --git a/src/third_party/MuQ/src/recipes/pretrain/README.md b/src/third_party/MuQ/src/recipes/pretrain/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d531d05aa47eff7ceaee1b6ebe85bc7a68efa5e5
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/README.md
@@ -0,0 +1,83 @@
+`
+# Guidance on MuQ Pretraining (SSL)
+
+This page provides a detailed explanation of the MuQ pretraining process, including an example training setup using the **open-source [Music4All dataset](https://sites.google.com/view/contact4music4all)**.
+
+To train MuQ effectively, we recommend using **32 GPUs**, each with **at least 32 GB** of VRAM.
+The self-supervised pretraining process may take approximately **3 days to 1 week** to complete.
+
+Before getting started, initialize and update the submodules to sync the fairseq directory:
+```
+git submodule init
+git submodule update
+```
+Then, please make sure to install all required dependencies listed in [requirements.txt](./requirements.txt).
+
+## Step1: Data Preparation
+
+You need to download the Music4All dataset from the [Music4All site](https://sites.google.com/view/contact4music4all).
+Then, download our [music4all splits index](https://drive.google.com/file/d/1GbZRzPZP989j1b4SYFSlRkBJryykNCSF/view?usp=sharing) files and place it in the `src/recipes/pretrain/dataset` directory. Note that you must adjust the `path` field within the splits files to align with your local path.
+
+While we provide an example using Music4All here, we strongly recommend performing self-supervised pretraining on larger datasets such as [MSD](https://labs.acousticbrainz.org/million-song-dataset-echonest-archive/) for better performance.
+
+
+
+## Step2: Prepare the Mel-RVQ
+
+MuQ relies on a pretrained Mel-RVQ model for training. We provide a pretrained Mel-RVQ checkpoint trained on the Music4All dataset [here](https://drive.google.com/file/d/1GgY9ZfZQOJFQqLPcAyxxw6yhu09WGtI1/view?usp=sharing) for direct download. This checkpoint is suitable for music modality.
+
+If you want to train Mel-RVQ on your own dataset, you can run the following command:
+
+```
+python rvq_trainer.py --train_path dataset/music4all/train.json --valid_path dataset/music4all/valid.json
+```
+
+Training Mel-RVQ is very fast, it typically takes around 1 hour on a single GPU.
+
+
+## Step3: SSL Training
+To start self-supervised pretraining of MuQ, please run the following script. This script is designed for distributed training on 4 nodes, each with 8 GPUs (total 32 GPUs):
+
+```
+export NUM_NODES=4
+export GPUS_PER_NODE=8
+
+bash scripts/run_training_muq.sh $NODE_INDEX MuQ_large_multinodes_v100 MUQ $CHIEF_IP 25520 music4all $NUM_NODES $GPUS_PER_NODE
+```
+
+
+🔧 Arguments:
+
+* `$NODE_INDEX`: Index of the current node (0 for chief, 1, 2, … for workers).
+* `$CHIEF_IP`: IP address of the chief node (i.e., the node with `NODE_INDEX=0`).
+If you are **not using multi-node training**, simply set this to `127.0.0.1`.
+* `$NUM_NODES`: Total number of nodes (e.g., `1` for single-machine training, `4` for multi-node).
+* `$GPUS_PER_NODE`: Number of GPUs per node (e.g., `8`).
+
+
+## Step 4: Convert to HF Checkpoint
+
+After finishing the training, you can convert the trained Fairseq checkpoint to **HuggingFace format**, just run:
+
+```bash
+python scripts/convert_muq_fairseq_ckpt_to_huggingface.py \
+ --model_dir $PWD \
+ --checkpoint_path ./output/ckpt_MUQ_XXX/MuQ_large_multinodes_v100/checkpoint_XXX.pt \
+ --save_dir ./output/hf-username/My-MuQ-large-huggingface
+```
+
+Once converted, you can load the model using the HuggingFace-style interface:
+
+```python
+from muq import MuQ
+import torch
+
+# Load from local HuggingFace-style checkpoint
+muq = MuQ.from_pretrained("./output/hf-username/My-MuQ-large-huggingface")
+muq = muq.to(device).eval()
+
+with torch.no_grad():
+ output = muq(wavs, output_hidden_states=True)
+```
+
+You can also upload your converted checkpoint to the HuggingFace Hub using `huggingface-cli`. The model will be fully compatible with the `MuQ.from_pretrained()` interface. Feel free to share it :)
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/pretrain/config/pretrain/MuQ_large_iter_multinodes_v100.yaml b/src/third_party/MuQ/src/recipes/pretrain/config/pretrain/MuQ_large_iter_multinodes_v100.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8a78f85321c5e9801903b0709dbce2631bb130b5
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/config/pretrain/MuQ_large_iter_multinodes_v100.yaml
@@ -0,0 +1,126 @@
+# @package _group_
+common:
+ fp16: false
+ log_format: json
+ log_interval: 200
+ seed: 1337
+ # tensorboard_logdir: tblog_proj_name
+ # wandb_project: wandb_proj_name
+
+checkpoint:
+ save_interval_updates: 12500
+ keep_interval_updates: -1
+ no_epoch_checkpoints: true
+
+
+distributed_training:
+ ddp_backend: no_c10d
+ distributed_backend: 'nccl'
+ distributed_world_size: 64
+ nprocs_per_node: 8
+ find_unused_parameters: true
+
+task:
+ _name: muq_pretraining
+ data: ???
+ label_dir: ???
+ labels: ???
+ label_rate: ${model.label_rate}
+ sample_rate: 24000
+
+ # crop to 30s
+ max_sample_size: null # 720000
+ min_sample_size: 720000
+ clip_secs: 30
+
+ pad_audio: false
+ random_crop: true
+ normalize: false # must be consistent with extractor
+
+ label_scp_path: ???
+ label_scp_clip_duration: 30
+
+
+dataset:
+ num_workers: 6
+ max_tokens: 2000000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: 1
+ validate_interval_updates: 10000
+
+criterion:
+ _name: model
+ # log_keys:
+ # - accuracies
+
+optimization:
+ max_update: 400000
+ lr: [0.0005]
+ clip_norm: 10.0
+ update_freq: [2]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 32000
+
+model:
+ _name: muq
+ label_rate: 25
+ num_codebooks: 1
+ codebook_dim: 16
+ codebook_size: 8192 # 4096
+ features: ["melspec_2048"]
+ hop_length: 240
+ n_mels: 128
+ conv_dim: 512
+ encoder_dim: 1024
+ encoder_depth: 12
+ mask_hop: 0.4
+ mask_prob: 0.6
+ is_flash: false
+
+ use_rvq_target: true
+ rvq_ckpt_path: null
+
+ stat: {melspec_2048_cnt: 14282760192, melspec_2048_mean: 6.768444971712967, melspec_2048_std: 18.417922652295623}
+ w2v2_config: {
+ activation_dropout: 0.1, adapter_kernel_size: 3, adapter_stride: 2, add_adapter: false, apply_spec_augment: true,
+ architectures: [Wav2Vec2ConformerForCTC], attention_dropout: 0.1, bos_token_id: 1, classifier_proj_size: 256,
+ codevector_dim: 768, conformer_conv_dropout: 0.1, contrastive_logits_temperature: 0.1, conv_bias: true,
+ conv_depthwise_kernel_size: 31, conv_dim: [512, 512, 512, 512, 512, 512, 512], conv_kernel: [10, 3, 3, 3, 3, 2, 2],
+ conv_stride: [5, 2, 2, 2, 2, 2, 2], ctc_loss_reduction: sum, ctc_zero_infinity: false, diversity_loss_weight: 0.1,
+ do_stable_layer_norm: true, eos_token_id: 2, feat_extract_activation: gelu, feat_extract_dropout: 0.0,
+ feat_extract_norm: layer, feat_proj_dropout: 0.1, feat_quantizer_dropout: 0.0, final_dropout: 0.1,
+ gradient_checkpointing: false, hidden_act: swish, hidden_dropout: 0.1, hidden_dropout_prob: 0.1, hidden_size: 1024,
+ initializer_range: 0.02, intermediate_size: 4096, layer_norm_eps: 1e-5, layerdrop: 0.0, mask_feature_length: 10,
+ mask_feature_min_masks: 0, mask_feature_prob: 0.0, mask_time_length: 10, mask_time_min_masks: 2, mask_time_prob: 0.05,
+ max_source_positions: 5000, model_type: wav2vec2-conformer, num_adapter_layers: 3, num_attention_heads: 16,
+ num_codevector_groups: 2, num_codevectors_per_group: 320, num_conv_pos_embedding_groups: 16,
+ num_conv_pos_embeddings: 128, num_feat_extract_layers: 7, num_hidden_layers: 24, num_negatives: 100,
+ output_hidden_size: 1024, pad_token_id: 0, position_embeddings_type: rotary, proj_codevector_dim: 768,
+ rotary_embedding_base: 10000, tdnn_dilation: [1, 2, 3, 1, 1], tdnn_dim: [512, 512, 512, 512, 1500],
+ tdnn_kernel: [5, 3, 3, 1, 1], torch_dtype: float32, transformers_version: 4.19.0.dev0,
+ use_weighted_layer_sum: false, vocab_size: 32, xvector_output_dim: 512
+ }
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/src/third_party/MuQ/src/recipes/pretrain/config/pretrain/MuQ_large_multinodes_v100.yaml b/src/third_party/MuQ/src/recipes/pretrain/config/pretrain/MuQ_large_multinodes_v100.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ae18691233b928bbdd74d0e4ba948715608c48c1
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/config/pretrain/MuQ_large_multinodes_v100.yaml
@@ -0,0 +1,123 @@
+# @package _group_
+common:
+ fp16: false
+ log_format: json
+ log_interval: 200
+ seed: 1337
+ # tensorboard_logdir: tblog_proj_name
+ # wandb_project: wandb_proj_name
+
+checkpoint:
+ save_interval_updates: 12500
+ keep_interval_updates: -1
+ no_epoch_checkpoints: true
+
+
+distributed_training:
+ ddp_backend: no_c10d
+ distributed_backend: 'nccl'
+ distributed_world_size: 64
+ nprocs_per_node: 8
+ find_unused_parameters: true
+
+task:
+ _name: muq_pretraining
+ data: ???
+ label_dir: ???
+ labels: ???
+ label_rate: ${model.label_rate}
+ sample_rate: 24000
+
+ # crop to 30s
+ max_sample_size: 720000
+ min_sample_size: 432000
+ clip_secs: 30
+
+ pad_audio: false
+ random_crop: true
+ normalize: false # must be consistent with extractor
+
+
+dataset:
+ num_workers: 6
+ max_tokens: 2000000
+ skip_invalid_size_inputs_valid_test: true
+ validate_interval: 1
+ validate_interval_updates: 10000
+
+criterion:
+ _name: model
+ # log_keys:
+ # - accuracies
+
+optimization:
+ max_update: 400000
+ lr: [0.0005]
+ clip_norm: 10.0
+ update_freq: [2]
+
+optimizer:
+ _name: adam
+ adam_betas: (0.9,0.98)
+ adam_eps: 1e-06
+ weight_decay: 0.01
+
+lr_scheduler:
+ _name: polynomial_decay
+ warmup_updates: 32000
+
+model:
+ _name: muq
+ label_rate: 25
+ num_codebooks: 1
+ codebook_dim: 16
+ codebook_size: 8192 # 4096
+ features: ["melspec_2048"]
+ hop_length: 240
+ n_mels: 128
+ conv_dim: 512
+ encoder_dim: 1024
+ encoder_depth: 12
+ mask_hop: 0.4
+ mask_prob: 0.6
+ is_flash: false
+
+ use_rvq_target: true
+ rvq_ckpt_path: /inspire/hdd/project/project-public/niuzhikang-240108120093/hainazhu/MuQ/open/RVQ_4000.pth # ??? # Please specify this to your path/to/mel/rvq/RVQ_4000.pth
+
+ stat: {melspec_2048_cnt: 14282760192, melspec_2048_mean: 6.768444971712967, melspec_2048_std: 18.417922652295623}
+ w2v2_config: {
+ activation_dropout: 0.1, adapter_kernel_size: 3, adapter_stride: 2, add_adapter: false, apply_spec_augment: true,
+ architectures: [Wav2Vec2ConformerForCTC], attention_dropout: 0.1, bos_token_id: 1, classifier_proj_size: 256,
+ codevector_dim: 768, conformer_conv_dropout: 0.1, contrastive_logits_temperature: 0.1, conv_bias: true,
+ conv_depthwise_kernel_size: 31, conv_dim: [512, 512, 512, 512, 512, 512, 512], conv_kernel: [10, 3, 3, 3, 3, 2, 2],
+ conv_stride: [5, 2, 2, 2, 2, 2, 2], ctc_loss_reduction: sum, ctc_zero_infinity: false, diversity_loss_weight: 0.1,
+ do_stable_layer_norm: true, eos_token_id: 2, feat_extract_activation: gelu, feat_extract_dropout: 0.0,
+ feat_extract_norm: layer, feat_proj_dropout: 0.1, feat_quantizer_dropout: 0.0, final_dropout: 0.1,
+ gradient_checkpointing: false, hidden_act: swish, hidden_dropout: 0.1, hidden_dropout_prob: 0.1, hidden_size: 1024,
+ initializer_range: 0.02, intermediate_size: 4096, layer_norm_eps: 1e-5, layerdrop: 0.0, mask_feature_length: 10,
+ mask_feature_min_masks: 0, mask_feature_prob: 0.0, mask_time_length: 10, mask_time_min_masks: 2, mask_time_prob: 0.05,
+ max_source_positions: 5000, model_type: wav2vec2-conformer, num_adapter_layers: 3, num_attention_heads: 16,
+ num_codevector_groups: 2, num_codevectors_per_group: 320, num_conv_pos_embedding_groups: 16,
+ num_conv_pos_embeddings: 128, num_feat_extract_layers: 7, num_hidden_layers: 24, num_negatives: 100,
+ output_hidden_size: 1024, pad_token_id: 0, position_embeddings_type: rotary, proj_codevector_dim: 768,
+ rotary_embedding_base: 10000, tdnn_dilation: [1, 2, 3, 1, 1], tdnn_dim: [512, 512, 512, 512, 1500],
+ tdnn_kernel: [5, 3, 3, 1, 1], torch_dtype: float32, transformers_version: 4.19.0.dev0,
+ use_weighted_layer_sum: false, vocab_size: 32, xvector_output_dim: 512
+ }
+
+hydra:
+ job:
+ config:
+ override_dirname:
+ kv_sep: '-'
+ item_sep: '__'
+ exclude_keys:
+ - run
+ - task.data
+ - task.label_dir
+ run:
+ dir: ???
+ sweep:
+ dir: ???
+ subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/src/third_party/MuQ/src/recipes/pretrain/data/__init__.py b/src/third_party/MuQ/src/recipes/pretrain/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1aafc5d02bff3e82340b0e51d922dd28b4f7a7f
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/data/__init__.py
@@ -0,0 +1 @@
+from .mert_dataset import MERTDataset
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/pretrain/data/ark_dataset.py b/src/third_party/MuQ/src/recipes/pretrain/data/ark_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f1a5cb04189c6a3e228e2f28815106eb6148948
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/data/ark_dataset.py
@@ -0,0 +1,109 @@
+import logging
+import torch
+import torch.nn.functional as F
+from fairseq.data.audio.raw_audio_dataset import RawAudioDataset
+from typing import Tuple
+try:
+ import kaldiio
+except:
+ kaldiio = None
+import warnings
+
+logger = logging.getLogger(__name__)
+
+
+class ArkDataset(RawAudioDataset):
+ def __init__(
+ self,
+ wav_scp,
+ dur_scp,
+ sr = 24000,
+ max_dur = 20,
+ num_buckets=0,
+ normalize=False,
+ ):
+ super().__init__(
+ sample_rate=sr,
+ max_sample_size=max_dur*sr,
+ min_sample_size=1200,
+ shuffle=True,
+ pad=True,
+ normalize=normalize,
+ compute_mask=False,
+ )
+ self.sr = sr
+ self.max_dur = max_dur
+ self.normalize = normalize
+
+ logger.info("Loading Kaldi scp files from {}".format(wav_scp))
+
+ self.wav_data = kaldiio.load_scp(wav_scp)
+ self.keys = list(self.wav_data.keys())
+ dur_data = {}
+ keys_set = set(self.keys)
+
+ with open(dur_scp, 'r') as f:
+ for line in f:
+ line = line.strip().split()
+ if line[0] in keys_set:
+ dur_data[line[0]] = float(line[-1])
+ self.sizes = [int(dur_data[k]*self.sr/100) for k in self.keys]
+
+ logger.info("Loading Kaldi scp files done")
+
+ self.dataset_len = len(self.keys)
+ self.set_bucket_info(num_buckets)
+
+ def __len__(self):
+ return self.dataset_len
+
+ def __getitem__(self, idx):
+ # print("getitem idx: ", idx)
+ try_cnt = 0
+ while True:
+ idx = idx + try_cnt
+ try:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ key = self.keys[idx]
+ # print(self.wav_data[key].keys())
+ wav = self.wav_data[key]['wav']
+
+ wav = torch.from_numpy(wav).float()
+ wav = self.postprocess(wav)
+ return {"id": idx, "source": wav}
+ except Exception as e:
+ try_cnt += 1
+ if try_cnt > 50:
+ return {"id": idx, "source": None}
+ continue
+
+ def size(self, idx):
+ return self.sizes[idx]
+
+ def postprocess(self, wav):
+ if wav.dim() == 2:
+ wav = wav.mean(-1)
+ assert wav.dim() == 1, wav.dim()
+
+ if self.normalize:
+ with torch.no_grad():
+ wav = F.layer_norm(wav, wav.shape)
+ return wav
+
+ def collater(self, samples):
+ return super().collater(samples)
+
+if __name__ == '__main__':
+ import torch
+ raw_tensor_str = torch.Tensor.__repr__
+ torch.Tensor.__str__ = torch.Tensor.__repr__ = lambda self: f'Tensor{{Size({[*self.shape]}) {self.device} {str(self.dtype)[6]}{str(self.dtype)[-2:]}}}' if self.numel() > 10 else raw_tensor_str(self)
+
+ ds = ArkDataset(
+ wav_scp='data/ark_demo/wav_ark.scp',
+ dur_scp='data/ark_demo/dur_ark.scp',
+ sr=24000,
+ )
+
+ for i in range(len(ds)):
+ print(ds[i])
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/pretrain/data/mert_dataset.py b/src/third_party/MuQ/src/recipes/pretrain/data/mert_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..abecd750a627906b4bd8e4c0218394a189c56589
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/data/mert_dataset.py
@@ -0,0 +1,658 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import itertools
+import logging
+import os
+import sys
+from typing import Any, List, Optional, Union
+
+import numpy as np
+from typing import Tuple
+import torch
+import torch.nn.functional as F
+from fairseq.data import data_utils
+from fairseq.data.fairseq_dataset import FairseqDataset
+from fairseq.data.audio.audio_utils import (
+ parse_path,
+ read_from_stored_zip,
+)
+
+import math
+import io
+import torchaudio
+# this is in the user_dir
+from nnAudio import features as nnAudioFeatures
+
+# from tqdm import tqdm
+import tqdm
+import json
+import random
+import traceback
+from einops import rearrange
+# from scripts.prepare_codecs_from_manifest import *
+
+logger = logging.getLogger(__name__)
+
+class model_cqt_pred(torch.nn.Module):
+ def __init__(self, n_bins=84, sr=16000, freq=50):
+ super().__init__()
+ self.epsilon=1e-10
+ # Getting Mel Spectrogram on the fly
+ self.spec_layer = nnAudioFeatures.cqt.CQT(sr=sr, hop_length=sr//freq, fmin=32.7,
+ fmax=None, n_bins=n_bins, bins_per_octave=n_bins//7,
+ filter_scale=1, norm=1, window='hann', center=True,
+ pad_mode='constant', trainable=False,
+ output_format='Magnitude', verbose=True)
+
+ # self.fc = nn.Linear(input_dim, n_bins)
+
+ # self.criterion = nn.MSELoss()
+ self.forward_dict = {
+ # 'masked_transformer_output': self.plain_forward
+ 'compute_cqt': self.compute_cqt
+ }
+ def compute_cqt(self, x):
+ '''
+ convert waveform to CQT -> [batch, bins, len] -> transpose
+ '''
+ # align with the padding of HuBERT model,
+ # the truncation is calculated by bruteforce search since the nnAudio padding strategy and fairseq models are different
+ # x = x[..., :-560]
+ return torch.transpose(self.spec_layer(x), -1, -2)
+
+ def forward(self, x, forward_type='masked_transformer_output'):
+ '''
+ take input from transformer hidden states: [batch, len_seq, channel]
+ output: [batch, len_seq, n_bins]
+ '''
+
+ return self.forward_dict[forward_type](x)
+
+def load_audio_by_json(json_path, max_keep, min_keep, tgt_sample_rate, clip_secs=5):
+ # read json file
+ print(json_path)
+ datas = []
+ inds = []
+ sizes = []
+ with open(json_path) as fp:
+ for ind,line in enumerate(fp):
+ data = json.loads(line)
+ if 'duration' in data and min_keep is not None and tgt_sample_rate*data['duration'] < min_keep:
+ continue
+ datas.append(data)
+ inds.append(ind)
+ # sz = int(data['duration'] * data['sample_rate'])
+ if clip_secs > 0:
+ sz = int(tgt_sample_rate * clip_secs)
+ else:
+ sz = int(tgt_sample_rate * data['duration'])
+ sizes.append(sz)
+ tot = ind + 1
+ return datas,inds,tot,sizes
+def load_audio(manifest_path, max_keep, min_keep):
+ print(manifest_path)
+
+ n_long, n_short = 0, 0
+ names, inds, sizes = [], [], []
+ with open(manifest_path) as f:
+ root = f.readline().strip()
+ for ind, line in enumerate(f):
+ items = line.strip().split("\t")
+ assert len(items) == 2, line
+ sz = int(items[1])
+ if min_keep is not None and sz < min_keep:
+ n_short += 1
+ elif max_keep is not None and sz > max_keep:
+ n_long += 1
+ else:
+ names.append(items[0])
+ inds.append(ind)
+ sizes.append(sz)
+ tot = ind + 1
+ logger.info(
+ (
+ f"max_keep={max_keep}, min_keep={min_keep}, "
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
+ )
+ )
+ return root, names, inds, tot, sizes
+
+
+def load_label(label_path, inds, tot):
+ with open(label_path) as f:
+ labels = []
+ for line in tqdm.tqdm(f):
+ labels.append(line.rstrip())
+ # labels = [line.rstrip() ]
+ assert (
+ len(labels) == tot
+ ), f"number of labels does not match ({len(labels)} != {tot})"
+ labels = [labels[i] for i in inds]
+ return labels
+
+def load_numpy_label(label_path, inds, tot):
+ labels = np.load(label_path, mmap_mode='r')
+ assert (labels.shape[0] == tot), f"number of labels does not match ({labels.shape[0]} != {tot})"
+ return labels
+
+def verify_label_lengths(
+ audio_sizes,
+ audio_rate,
+ label_path,
+ label_rate,
+ inds,
+ tot,
+ tol=0.1, # tolerance in seconds
+):
+ if label_rate < 0:
+ logger.info(f"{label_path} is sequence label. skipped")
+ return
+
+ with open(label_path) as f:
+ lengths = []
+ for line in tqdm.tqdm(f):
+ lengths.append(len(line.rstrip().split()))
+ assert len(lengths) == tot
+ lengths = [lengths[i] for i in inds]
+ num_invalid = 0
+ for i, ind in enumerate(inds):
+ dur_from_audio = audio_sizes[i] / audio_rate
+ dur_from_label = lengths[i] / label_rate
+ if abs(dur_from_audio - dur_from_label) > tol:
+ logger.warning(
+ (
+ f"audio and label duration differ too much "
+ f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
+ f"in line {ind+1} of {label_path}. Check if `label_rate` "
+ f"is correctly set (currently {label_rate}). "
+ f"num. of samples = {audio_sizes[i]}; "
+ f"label length = {lengths[i]}"
+ )
+ )
+ num_invalid += 1
+ if num_invalid > 0:
+ logger.warning(
+ f"total {num_invalid} (audio, label) pairs with mismatched lengths"
+ )
+
+class Read_and_PadCrop_Normalized_T(torch.nn.Module):
+ def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
+
+ super().__init__()
+
+ self.n_samples = n_samples
+ self.sample_rate = sample_rate
+ self.randomize = randomize
+
+
+ def __call__(self, filename: str, duration: float, cur_sample_rate: int, fixed_offset_duration=None) -> Tuple[torch.Tensor, float, float, int, int]:
+ if self.n_samples is None:
+ chunk, cur_sample_rate = torchaudio.load(filename)
+ t_start = 0.
+ t_end = 1.0
+ offset = 0
+ elif(duration<(float(self.n_samples)/self.sample_rate+1)):
+ chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
+ t_start = 0.
+ t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
+ offset = 0
+ else:
+ if fixed_offset_duration is None:
+ offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
+ else:
+ offset = int(cur_sample_rate*fixed_offset_duration)
+ t_start = offset / float(cur_sample_rate) / duration
+ t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
+ chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
+ # Pad with silence if necessary.
+ if(chunk.shape[0]>1):
+ chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
+ else:
+ chunk = chunk[[0],:].float()
+ if(cur_sample_rate!=self.sample_rate):
+ chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
+ if self.n_samples is None:
+ pass
+ elif chunk.shape[-1] < self.n_samples:
+ chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
+ else:
+ chunk = chunk[:,0:self.n_samples]
+ seconds_start = math.floor(offset / cur_sample_rate)
+ seconds_total = math.floor(duration)
+
+ return (
+ chunk,
+ t_start,
+ t_end,
+ seconds_start,
+ seconds_total
+ )
+
+
+class MERTDataset(FairseqDataset):
+ def __init__(
+ self,
+ manifest_path: str,
+ sample_rate: float,
+ label_paths: List[str],
+ label_rates: Union[List[float], float], # -1 for sequence labels
+ pad_list: List[str],
+ eos_list: List[str],
+ label_scp_path: Optional[str] = None,
+ label_scp_clip_duration: float = -1,
+ label_processors: Optional[List[Any]] = None,
+ max_keep_sample_size: Optional[int] = None,
+ min_keep_sample_size: Optional[int] = None,
+ max_sample_size: Optional[int] = None,
+ shuffle: bool = True,
+ pad_audio: bool = False,
+ normalize: bool = False,
+ store_labels: bool = True,
+ npmemmap: bool = False,
+ random_crop: bool = False,
+ single_target: bool = False,
+ augmentation_effects: List[str] = [],
+ augmentation_probs: List[float] = [],
+ inbatch_noise_augment_len_range: List[int] = [8000, 24000],
+ inbatch_noise_augment_number_range: List[int] = [1, 3],
+ inbatch_noise_augment_volume: float = 1.0,
+ cqt_prediction_bin: int = -1,
+ dataset_len:int = 128*3000,
+ clip_secs = 5,
+ ):
+ self.sample_rate = sample_rate
+ self.shuffle = shuffle
+ self.random_crop = random_crop
+ self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path,max_keep_sample_size,min_keep_sample_size, self.sample_rate, clip_secs)
+ self.inds = inds
+
+ self.num_labels = len(label_paths)
+ self.pad_list = pad_list
+ self.eos_list = eos_list
+ self.label_processors = label_processors
+ self.single_target = single_target
+ self.label_rates = (
+ [label_rates for _ in range(len(label_paths))]
+ if isinstance(label_rates, float)
+ else label_rates
+ )
+ self.store_labels = store_labels
+ self.npmemmap = npmemmap
+ self.label_scp_path = label_scp_path
+ self.label_scp_clip_duration = label_scp_clip_duration
+
+
+ if self.label_scp_path is not None:
+ from kaldiio import load_scp
+ self.label_scp = load_scp(self.label_scp_path)
+
+ # self.dataset_len = dataset_len
+ self.dataset_len = len(self.datas)
+ logger.info('preparing labels')
+ logger.info('========dataset len: {}=========='.format(self.dataset_len))
+ if store_labels:
+ if self.npmemmap:
+ self.label_list = [load_numpy_label(p+'.npy', inds, tot) for p in label_paths]
+ else:
+ self.label_list = [load_label(p, inds, tot) for p in label_paths]
+ else:
+ self.label_paths = label_paths
+ # self.label_offsets_list = [
+ # load_label_offset(p, inds, tot) for p in label_paths
+ # ]
+ assert label_processors is None or len(label_processors) == self.num_labels
+ # logger.info('skip verify labels and audio lengths')
+ # for label_path, label_rate in zip(label_paths, self.label_rates):
+ # verify_label_lengths(
+ # self.sizes, sample_rate, label_path, label_rate, inds, tot
+ # )
+
+ self.max_sample_size = (
+ max_sample_size if max_sample_size is not None else sys.maxsize
+ )
+ self.pad_audio = pad_audio
+ self.normalize = normalize
+ logger.info(
+ f"pad_audio={pad_audio}, random_crop={random_crop}, "
+ f"normalize={normalize}, max_sample_size={self.max_sample_size}"
+ )
+
+ self.augmentation_effects = augmentation_effects
+ self.augmentation_probs = augmentation_probs
+ # if len(self.augmentation_effects) > 0:
+ # self.augmentor_init()
+ # self.apply_augmentation = self.augmentation_factry(sample_rate)
+
+ self.inbatch_noise_augment_len_range = inbatch_noise_augment_len_range
+ self.inbatch_noise_augment_number_range = inbatch_noise_augment_number_range
+ self.inbatch_noise_augment_volume = inbatch_noise_augment_volume
+
+
+ self.cqt_prediction_bin = cqt_prediction_bin
+ if self.cqt_prediction_bin > 0:
+ self.encoder_cqt_model = model_cqt_pred(n_bins=self.cqt_prediction_bin)
+ logger.info('preparing cqt loss objective in dataloader with cpu')
+
+ self.epoch = -1
+
+ self.reader = Read_and_PadCrop_Normalized_T(n_samples=clip_secs*sample_rate if clip_secs>0 else None, sample_rate = self.sample_rate)
+
+
+
+ @property
+ def can_reuse_epoch_itr_across_epochs(self):
+ """
+ Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for
+ this dataset across epochs.
+
+ This needs to return ``False`` if the sample sizes can change across
+ epochs, in which case we may need to regenerate batches at each epoch.
+ If your dataset relies in ``set_epoch`` then you should consider setting
+ this to ``False``.
+ """
+ return False
+ def set_epoch(self, epoch):
+ """Will receive the updated epoch number at the beginning of the epoch."""
+ self.epoch = epoch
+
+ def inbatch_noise_augment(self,
+ target_audio: torch.Tensor, target_audio_idx: int ,
+ batch_audios: torch.Tensor, # [bsz, audio_lengths]
+ noise_len_min: int, noise_len_max: int,
+ n_noise_min: int, n_noise_max: int,
+ noise_vol: float = 1.0):
+ '''
+ augmenation that leverages in-batch noise audios.
+ noise_len_min and noise_len_max are the range of the lengths of noises (counted as samples)
+ n_noise_min and n_noise_max are the range of number of noises,
+ '''
+ # assert noise_len_max <= target_audio.shape[0] and noise_len_min >= 1 # should assert this outside?
+
+ augmented_audio = torch.clone(target_audio)
+ if noise_vol <= 0:
+ return augmented_audio
+
+ # exclude the target audio and use the rest as noise candidates
+ noise_pool = torch.cat( batch_audios[:target_audio_idx] + batch_audios[target_audio_idx+1:], dim=0).view(-1)
+
+ n_noise = np.random.randint(n_noise_min, n_noise_max)
+ # n_noise
+ random_start_idxs = np.random.randint(0, noise_pool.shape[0] - noise_len_max, size=(n_noise,))
+ random_durations = np.random.randint(noise_len_min, noise_len_max, size=(n_noise,))
+
+ for noise_idx in range(n_noise):
+ augmentation_position = np.random.randint(0, target_audio.shape[0] - random_durations[noise_idx], size=None)
+ # assign noise to the original audio
+ augmented_audio[augmentation_position:augmentation_position+random_durations[noise_idx]] += \
+ noise_vol * noise_pool[random_start_idxs[noise_idx]: random_start_idxs[noise_idx]+random_durations[noise_idx]]
+
+ return augmented_audio
+
+ def get_audio_by_slice(self,index):
+
+ metas = self.datas[index]
+ wav_path = metas['path']
+ # print(wav_path)
+
+ if 'sample_rate' in metas and 'duration' in metas:
+ origin_sample_rate = metas['sample_rate']
+ origin_duration = metas['duration']
+ else:
+ audio_info = torchaudio.info(wav_path)
+ origin_sample_rate = audio_info.sample_rate
+ origin_duration = audio_info.num_frames / origin_sample_rate
+
+ if self.label_scp_path is not None:
+ if origin_duration < self.label_scp_clip_duration:
+ raise ValueError(f"origin duration {origin_duration} too small")
+ max_n_step = min(int(origin_duration//self.label_scp_clip_duration), 10)
+ step = int(math.floor(random.random() * max_n_step)) % max_n_step
+ ind = self.inds[index]
+ scp_key = f'{ind}:{step}'
+ if scp_key not in self.label_scp:
+ raise KeyError(scp_key)
+ label = torch.from_numpy(self.label_scp[scp_key])
+ label = rearrange(label, "(n s) -> n s", n=8)
+ wav, *ignored = self.reader(wav_path, origin_duration, origin_sample_rate, fixed_offset_duration=step*self.label_scp_clip_duration)
+ else:
+ wav, *ignored = self.reader(wav_path, origin_duration, origin_sample_rate)
+ label = None
+ wav = wav.float()
+
+ wav = wav.permute(1,0)
+ wav = self.postprocess(wav, self.sample_rate)
+ return wav, label
+ def get_audio(self, index):
+ import soundfile as sf
+
+ # wav_path = os.path.join(self.audio_root, self.audio_names[index])
+ wav_path = os.path.join('/apdcephfs/share_1316500/cloudezhou/MERT/MERT/converted', self.audio_names[index])
+ _path, slice_ptr = parse_path(wav_path)
+ # original way
+ if len(slice_ptr) == 0:
+ wav, cur_sample_rate = sf.read(_path)
+ else:
+ assert _path.endswith(".zip")
+ data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
+ f = io.BytesIO(data)
+ wav, cur_sample_rate = sf.read(f)
+ wav = torch.from_numpy(wav).float()
+
+ wav = self.postprocess(wav, cur_sample_rate)
+ return wav
+
+ def get_label(self, index, label_idx):
+
+ if self.store_labels and (not self.npmemmap):
+ label = self.label_list[label_idx][index]
+ elif self.store_labels and self.npmemmap:
+ label = self.label_list[label_idx][index]
+ else:
+ with open(self.label_paths[label_idx]) as f:
+ offset_s, offset_e = self.label_offsets_list[label_idx][index]
+ f.seek(offset_s)
+ label = f.read(offset_e - offset_s)
+
+ if self.label_processors is not None:
+ label = self.label_processors[label_idx](label)
+ return 0
+
+ def get_labels(self, index):
+ return [self.get_label(index, i) for i in range(self.num_labels)]
+
+ def __getitem__(self, i):
+ index = i
+ item = None
+ while item is None:
+ try:
+ wav, label = self.get_audio_by_slice(index)
+ item = {"id": index, "source": wav}
+ if label is not None:
+ item['label'] = label
+ except Exception as e:
+ traceback.print_exc()
+ print(f'skip damaged data {index}')
+ index = np.random.randint(0,len(self.sizes)-1)
+ return item
+
+ def __len__(self):
+ return self.dataset_len
+
+ def crop_to_max_size(self, wav, target_size):
+ size = len(wav)
+ diff = size - target_size
+ if diff <= 0:
+ return wav, 0
+
+ start, end = 0, target_size
+ if self.random_crop:
+ start = np.random.randint(0, diff + 1)
+ end = size - diff + start
+ return wav[start:end], start
+
+ def collater(self, samples):
+ samples = [s for s in samples if s["source"] is not None]
+ if len(samples) == 0:
+ return {}
+
+ audios = [s["source"] for s in samples]
+ audio_sizes = [len(s) for s in audios]
+ # print("Audio sizes in batch:", audio_sizes, "Ids:", [s["id"] for s in samples])
+ if self.pad_audio:
+ audio_size = min(max(audio_sizes), self.max_sample_size)
+ else:
+ audio_size = min(min(audio_sizes), self.max_sample_size)
+ collated_audios, padding_mask, audio_starts, collated_cqt_labels = self.collater_audio(
+ audios, audio_size
+ )
+
+
+ net_input = {"source": collated_audios, "padding_mask": padding_mask, "cqt_labels": collated_cqt_labels}
+ if len(samples) > 0 and 'label' in samples[0]:
+ net_input['label'] = torch.stack([s['label'] for s in samples])
+
+ batch = {
+ "id": torch.LongTensor([s["id"] for s in samples]),
+ "net_input": net_input,
+ }
+
+ if self.single_target:
+ batch["target_lengths"] = None
+ batch["ntokens"] = None
+ batch["target"] = None
+ else:
+ batch["target_lengths_list"] = None
+ batch["ntokens_list"] = None
+ batch["target_list"] = None
+ return batch
+
+ def collater_audio(self, audios, audio_size):
+ collated_audios = audios[0].new_zeros(len(audios), audio_size)
+ padding_mask = (
+ torch.BoolTensor(collated_audios.shape).fill_(False)
+ )
+ audio_starts = [0 for _ in audios]
+
+ for i, audio in enumerate(audios):
+ diff = len(audio) - audio_size
+ if diff == 0:
+ collated_audios[i] = audio
+ elif diff < 0:
+ assert self.pad_audio
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
+ padding_mask[i, diff:] = True
+ else:
+ collated_audios[i], audio_starts[i] = self.crop_to_max_size(
+ audio, audio_size
+ )
+
+ cqt_labels = None
+ if self.cqt_prediction_bin > 0:
+ cqt_labels = self.encoder_cqt_model(collated_audios.float(), forward_type='compute_cqt')
+
+ for i, _ in enumerate(audios):
+ # compute cqt labels in advance
+ # cqt_labels
+
+ # apply audio augmentation effects here
+ # the audio should be as the type torch.Tensor, in the shape [1, length] TODO?
+ if len(self.augmentation_effects) > 0:
+ with torch.no_grad():
+ for effect, prob in zip(self.augmentation_effects, self.augmentation_probs):
+ if torch.rand(1).item() > prob:
+ if effect == 'composed_augmentation_v1':
+ # collated_audios[i] = self.composed_augment_v1(collated_audios[i])
+ pass
+ elif effect == 'inbatch_noise_augment':
+ assert len(audios) > 1
+ collated_audios[i] = self.inbatch_noise_augment(
+ target_audio = collated_audios[i], target_audio_idx = i, batch_audios = audios,
+ noise_len_min = self.inbatch_noise_augment_len_range[0], noise_len_max = self.inbatch_noise_augment_len_range[1],
+ n_noise_min = self.inbatch_noise_augment_number_range[0], n_noise_max = self.inbatch_noise_augment_number_range[1],
+ noise_vol = self.inbatch_noise_augment_volume)
+ else:
+ raise NotImplementedError()
+
+
+ return collated_audios, padding_mask, audio_starts, cqt_labels
+
+ def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
+ assert label_rate > 0
+ s2f = label_rate / self.sample_rate # 0.00625 for 100Hz and 16k sr
+ frm_starts = [int(round(s * s2f)) for s in audio_starts] # should be all 0 if the audios are not croped
+ frm_size = int(round(audio_size * s2f)) # this is the expected total number of given pseudo labels
+ if not self.pad_audio:
+ rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
+ frm_size = min(frm_size, *rem_size)
+ targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
+ logger.debug(f"audio_starts={audio_starts}")
+ logger.debug(f"frame_starts={frm_starts}")
+ logger.debug(f"frame_size={frm_size}")
+
+ lengths = torch.LongTensor([len(t) for t in targets])
+ ntokens = lengths.sum().item()
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
+ return targets, lengths, ntokens
+
+ def collater_seq_label(self, targets, pad):
+ lengths = torch.LongTensor([len(t) for t in targets])
+ ntokens = lengths.sum().item()
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
+ return targets, lengths, ntokens
+
+ def collater_label(self, targets_by_label, audio_size, audio_starts):
+ targets_list, lengths_list, ntokens_list = [], [], []
+ itr = zip(targets_by_label, self.label_rates, self.pad_list)
+ for targets, label_rate, pad in itr:
+ if label_rate == -1.0:
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
+ else:
+ targets, lengths, ntokens = self.collater_frm_label(
+ targets, audio_size, audio_starts, label_rate, pad
+ )
+ targets_list.append(targets)
+ lengths_list.append(lengths)
+ ntokens_list.append(ntokens)
+ return targets_list, lengths_list, ntokens_list
+
+ def num_tokens(self, index):
+ return self.size(index)
+
+ def size(self, index):
+ if self.pad_audio:
+ return self.sizes[index]
+ return min(self.sizes[index], self.max_sample_size)
+
+ def ordered_indices(self):
+ if self.shuffle:
+ try:
+ print("========Local rank :",torch.distributed.get_rank(),"========")
+ WORLD_SIZE = int(torch.distributed.get_world_size())
+ WORLD_RANK = int(torch.distributed.get_rank())
+ np.random.seed(self.epoch * WORLD_SIZE + WORLD_RANK)
+ order = np.random.permutation(len(self.sizes))
+ print("==================multinode multigpu shuffle==================")
+ except:
+ print("==================singlenode shuffle==================")
+ order = np.random.permutation(len(self.sizes))
+ else:
+ order = np.arange(len(self.sizes))
+
+ return order
+
+ def postprocess(self, wav, cur_sample_rate):
+ if wav.dim() == 2:
+ wav = wav.mean(-1)
+ assert wav.dim() == 1, wav.dim()
+
+ if cur_sample_rate != self.sample_rate:
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
+
+ if self.normalize:
+ with torch.no_grad():
+ wav = F.layer_norm(wav, wav.shape)
+ return wav
diff --git a/src/third_party/MuQ/src/recipes/pretrain/data/utils/data_utils.py b/src/third_party/MuQ/src/recipes/pretrain/data/utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0234692a83f93272e868452e8ba13743264ce6d
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/data/utils/data_utils.py
@@ -0,0 +1,535 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import math
+import numpy as np
+import torch
+
+from typing import Optional, Tuple
+
+
+
+logger = logging.getLogger(__name__)
+
+
+
+def compute_mask_indices(
+ shape: Tuple[int, int],
+ padding_mask: Optional[torch.Tensor],
+ mask_prob: float,
+ mask_length: int,
+ mask_type: str = "static",
+ mask_other: float = 0.0,
+ min_masks: int = 0,
+ no_overlap: bool = False,
+ min_space: int = 0,
+ require_same_masks: bool = True,
+ mask_dropout: float = 0.0,
+ add_masks: bool = False,
+ seed: Optional[int] = None,
+ epoch: Optional[int] = None,
+ indices: Optional[torch.Tensor] = None,
+ idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset
+ num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape
+
+ Args:
+ shape: the the shape for which to compute masks.
+ should be of size 2 where first element is batch size and 2nd is timesteps
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+ mask_type: how to compute mask lengths
+ static = fixed size
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+ poisson = sample from possion distribution with lambda = mask length
+ min_masks: minimum number of masked spans
+ no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+ require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
+ mask_dropout: randomly dropout this percentage of masks in each example
+ """
+
+ bsz, all_sz = shape
+ mask = np.full((bsz, all_sz), False)
+
+ if num_mask_ver == 1:
+ all_num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * all_sz / float(mask_length)
+ + np.random.rand()
+ )
+ all_num_mask = max(min_masks, all_num_mask)
+
+ mask_idcs = []
+ for i in range(bsz):
+ if seed is not None and epoch is not None and indices is not None:
+ seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
+ else:
+ seed_i = None
+
+ rng = np.random.default_rng(seed_i)
+
+ if padding_mask is not None:
+ sz = all_sz - padding_mask[i].long().sum().item()
+ assert sz >= 0, sz
+ else:
+ sz = all_sz
+
+ if num_mask_ver == 1:
+ if padding_mask is not None:
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * sz / float(mask_length)
+ + np.random.rand()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ num_mask = all_num_mask
+ elif num_mask_ver == 2:
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * sz / float(mask_length)
+ + rng.random()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ raise ValueError()
+
+ if mask_type == "static":
+ lengths = np.full(num_mask, mask_length)
+ elif mask_type == "uniform":
+ lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask)
+ elif mask_type == "normal":
+ lengths = rng.normal(mask_length, mask_other, size=num_mask)
+ lengths = [max(1, int(round(x))) for x in lengths]
+ elif mask_type == "poisson":
+ lengths = rng.poisson(mask_length, size=num_mask)
+ lengths = [int(round(x)) for x in lengths]
+ else:
+ raise Exception("unknown mask selection " + mask_type)
+
+ if sum(lengths) == 0:
+ if mask_type == "static":
+ raise ValueError(f"this should never happens")
+ else:
+ lengths = [min(mask_length, sz - 1)]
+
+ if no_overlap:
+ mask_idc = []
+
+ def arrange(s, e, length, keep_length):
+ span_start = rng.randint(s, e - length)
+ mask_idc.extend(span_start + i for i in range(length))
+
+ new_parts = []
+ if span_start - s - min_space >= keep_length:
+ new_parts.append((s, span_start - min_space + 1))
+ if e - span_start - length - min_space > keep_length:
+ new_parts.append((span_start + length + min_space, e))
+ return new_parts
+
+ parts = [(0, sz)]
+ min_length = min(lengths)
+ for length in sorted(lengths, reverse=True):
+ lens = np.fromiter(
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
+ np.int,
+ )
+ l_sum = np.sum(lens)
+ if l_sum == 0:
+ break
+ probs = lens / np.sum(lens)
+ c = rng.choice(len(parts), p=probs)
+ s, e = parts.pop(c)
+ parts.extend(arrange(s, e, length, min_length))
+ mask_idc = np.asarray(mask_idc)
+ else:
+ if idc_select_ver == 1:
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+ mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
+ elif idc_select_ver == 2:
+ mask_idc = rng.choice(sz, num_mask, replace=False)
+ else:
+ raise ValueError()
+
+ mask_idc = np.asarray(
+ [
+ mask_idc[j] + offset
+ for j in range(len(mask_idc))
+ for offset in range(lengths[j])
+ ]
+ )
+
+ mask_idc = np.unique(mask_idc[mask_idc < sz])
+ if len(mask_idc) >= sz:
+ raise ValueError(
+ (
+ f"the entire sequence is masked. "
+ f"sz={sz}; mask_idc[mask_idc]; "
+ f"index={indices[i] if indices is not None else None}"
+ )
+ )
+ mask_idcs.append(mask_idc)
+
+ target_len = None
+ if require_same_masks:
+ if add_masks:
+ target_len = max([len(m) for m in mask_idcs])
+ else:
+ target_len = min([len(m) for m in mask_idcs])
+
+ for i, mask_idc in enumerate(mask_idcs):
+ if target_len is not None and len(mask_idc) > target_len:
+ mask_idc = rng.choice(mask_idc, target_len, replace=False)
+
+ mask[i, mask_idc] = True
+
+ if target_len is not None and len(mask_idc) < target_len:
+ unmasked = np.flatnonzero(~mask[i])
+ to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False)
+ mask[i, to_mask] = True
+
+ if mask_dropout > 0:
+ masked = np.flatnonzero(mask[i])
+ num_holes = np.rint(len(masked) * mask_dropout).astype(int)
+ to_drop = rng.choice(masked, num_holes, replace=False)
+ mask[i, to_drop] = False
+
+ return mask
+
+
+def compute_block_mask_2d(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ mask_prob_adjust: float = 0,
+ inverse_mask: bool = False,
+ require_same_masks: bool = True,
+ expand_adjcent: bool = False,
+ mask_dropout: float = 0,
+ non_overlapping: bool = False,
+ img_shape: tuple = None, # For the situation when d[0] != d[1], especially in audio spce ways
+ flexible_mask: bool = False,
+) -> torch.Tensor:
+
+ assert mask_length > 1
+
+ B, L = shape
+
+ d = (int(L**0.5),int(L**0.5))
+
+ if img_shape:
+ d = (img_shape[0],img_shape[1])
+
+ if flexible_mask:
+ index = np.random.randint(0,3)
+ block_size_options = np.array([(6, 4), (5, 5), (8, 3)])
+ block_size = block_size_options[index]
+
+ if inverse_mask:
+ mask_prob = 1 - mask_prob
+
+ if flexible_mask:
+ mask = torch.zeros((B, d[0], d[1]))
+ mask_inds = torch.randint(
+ 0,
+ L,
+ size=(
+ B,
+ int(
+ L
+ * ((mask_prob + mask_prob_adjust) / (block_size[0]*block_size[1]))
+ * (1 + mask_dropout)
+ ),
+ ),
+ )
+ mask.view(B, -1).scatter_(1, mask_inds, 1)
+ centers = mask.nonzero(as_tuple=True)
+
+ inds = ([], [], [])
+
+ offset = mask_length // 2
+ for i in range(block_size[0]):
+ for j in range(block_size[1]):
+ k1 = i - offset
+ k2 = j - offset
+ inds[0].append(centers[0])
+ inds[1].append(centers[1] + k1)
+ inds[2].append(centers[2] + k2)
+
+ i0 = torch.cat(inds[0])
+ i1 = torch.cat(inds[1]).clamp_(min=0, max=d[0] - 1)
+ i2 = torch.cat(inds[2]).clamp_(min=0, max=d[1] - 1)
+
+ mask[(i0, i1, i2)] = 1
+
+ elif non_overlapping:
+ sz = math.ceil(d[0] / mask_length)
+ inp_len = sz * sz
+
+ inp = torch.zeros((B, 1, sz, sz))
+ w = torch.ones((1, 1, mask_length, mask_length))
+
+ mask_inds = torch.multinomial(
+ 1 - inp.view(B, -1),
+ int(inp_len * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
+ replacement=False,
+ )
+ inp.view(B, -1).scatter_(1, mask_inds, 1)
+
+ mask = torch.nn.functional.conv_transpose2d(inp, w, stride=mask_length).squeeze(
+ 1
+ )
+ if mask.size(-1) > d[0]:
+ mask = mask[..., :d, :d]
+ else:
+ mask = torch.zeros((B, d[0], d[1]))
+ mask_inds = torch.randint(
+ 0,
+ L,
+ size=(
+ B,
+ int(
+ L
+ * ((mask_prob + mask_prob_adjust) / mask_length**2)
+ * (1 + mask_dropout)
+ ),
+ ),
+ )
+ mask.view(B, -1).scatter_(1, mask_inds, 1)
+ centers = mask.nonzero(as_tuple=True)
+
+ inds = ([], [], [])
+
+ offset = mask_length // 2
+ for i in range(mask_length):
+ for j in range(mask_length):
+ k1 = i - offset
+ k2 = j - offset
+ inds[0].append(centers[0])
+ inds[1].append(centers[1] + k1)
+ inds[2].append(centers[2] + k2)
+
+ i0 = torch.cat(inds[0])
+ i1 = torch.cat(inds[1]).clamp_(min=0, max=d[0] - 1)
+ i2 = torch.cat(inds[2]).clamp_(min=0, max=d[1] - 1)
+
+ mask[(i0, i1, i2)] = 1
+
+ def get_nbs(b, m, w):
+ all_nbs = torch.nn.functional.conv2d(m.unsqueeze(1), w, padding="same")
+ all_nbs = all_nbs.clamp_max_(1).view(b, -1)
+ return all_nbs
+
+ if require_same_masks and expand_adjcent:
+ w = torch.zeros((1, 1, 3, 3))
+ w[..., 0, 1] = 1
+ w[..., 2, 1] = 1
+ w[..., 1, 0] = 1
+ w[..., 1, 2] = 1
+
+ all_nbs = get_nbs(B, mask, w)
+
+ mask = mask.reshape(B, -1)
+
+ if require_same_masks:
+ n_masks = mask.sum(dim=-1)
+ final_target_len = int(L * (mask_prob))
+ target_len = int(final_target_len * (1 + mask_dropout))
+
+ for i in range(len(mask)):
+ n = n_masks[i]
+ m = mask[i]
+ r = 0
+ while expand_adjcent and n < target_len:
+ if r == 0:
+ nbs = all_nbs[i]
+ else:
+ nbs = get_nbs(1, m.view(1, d[0], d[1]), w).flatten()
+
+ cands = (1 - m + nbs) > 1
+ cand_sz = int(cands.sum().item())
+
+ assert cand_sz > 0, f"{nbs} {cand_sz}"
+
+ to_mask = torch.multinomial(
+ cands.float(), min(cand_sz, int(target_len - n)), replacement=False
+ )
+ m[to_mask] = 1
+ assert to_mask.numel() > 0
+ n += to_mask.numel()
+ r += 1
+
+ if n > final_target_len:
+ to_unmask = torch.multinomial(
+ m, int(n - final_target_len), replacement=False
+ )
+ m[to_unmask] = 0
+ elif n < final_target_len:
+ to_mask = torch.multinomial(
+ (1 - m), int(final_target_len - n), replacement=False
+ )
+ m[to_mask] = 1
+
+ if inverse_mask:
+ mask = 1 - mask
+
+ return mask
+
+
+def compute_block_mask_1d(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ mask_prob_adjust: float = 0,
+ inverse_mask: bool = False,
+ require_same_masks: bool = True,
+ expand_adjcent: bool = False,
+ mask_dropout: float = 0,
+ non_overlapping: bool = False,
+) -> torch.Tensor:
+
+ B, L = shape
+
+ if inverse_mask:
+ mask_prob = 1 - mask_prob
+
+ if non_overlapping:
+ sz = math.ceil(L / mask_length)
+
+ inp = torch.zeros((B, 1, sz))
+ w = torch.ones((1, 1, mask_length))
+
+ mask_inds = torch.multinomial(
+ 1 - inp.view(B, -1),
+ int(sz * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)),
+ replacement=False,
+ )
+ inp.view(B, -1).scatter_(1, mask_inds, 1)
+
+ mask = torch.nn.functional.conv_transpose1d(inp, w, stride=mask_length).squeeze(
+ 1
+ )
+ if mask.size(-1) > L:
+ mask = mask[..., :L]
+
+ else:
+ mask = torch.zeros((B, L))
+ mask_inds = torch.randint(
+ 0,
+ L,
+ size=(
+ B,
+ int(
+ L
+ * ((mask_prob + mask_prob_adjust) / mask_length)
+ * (1 + mask_dropout)
+ ),
+ ),
+ )
+
+ mask.view(B, -1).scatter_(1, mask_inds, 1)
+ centers = mask.nonzero(as_tuple=True)
+
+ inds = ([], [])
+
+ offset = mask_length // 2
+ for i in range(mask_length):
+ k1 = i - offset
+ inds[0].append(centers[0])
+ inds[1].append(centers[1] + k1)
+
+ i0 = torch.cat(inds[0])
+ i1 = torch.cat(inds[1]).clamp_(min=0, max=L - 1)
+
+ mask[(i0, i1)] = 1
+
+ def get_nbs(b, m, w):
+ all_nbs = torch.nn.functional.conv1d(m.unsqueeze(1), w, padding="same")
+ all_nbs = all_nbs.clamp_max_(1).view(b, -1)
+ return all_nbs
+
+ if require_same_masks and expand_adjcent:
+ w = torch.ones((1, 1, 3))
+ w[..., 1] = 0
+ all_nbs = get_nbs(B, mask, w)
+
+ mask = mask.view(B, -1)
+
+ if require_same_masks:
+ n_masks = mask.sum(dim=-1)
+ final_target_len = int(L * (mask_prob))
+ target_len = int(final_target_len * (1 + mask_dropout))
+
+ for i in range(len(mask)):
+ n = n_masks[i]
+ m = mask[i]
+ r = 0
+ while expand_adjcent and n < target_len:
+ if r == 0:
+ nbs = all_nbs[i]
+ else:
+ nbs = get_nbs(1, m.unsqueeze(0), w).squeeze(0)
+
+ cands = (1 - m + nbs) > 1
+ cand_sz = int(cands.sum().item())
+
+ assert cand_sz > 0, f"{nbs} {cand_sz}"
+
+ to_mask = torch.multinomial(
+ cands.float(), min(cand_sz, int(target_len - n)), replacement=False
+ )
+ m[to_mask] = 1
+ assert to_mask.numel() > 0
+ n += to_mask.numel()
+ r += 1
+
+ if n > final_target_len:
+ to_unmask = torch.multinomial(
+ m, int(n - final_target_len), replacement=False
+ )
+ m[to_unmask] = 0
+ elif n < final_target_len:
+ to_mask = torch.multinomial(
+ (1 - m), int(final_target_len - n), replacement=False
+ )
+ m[to_mask] = 1
+
+ if inverse_mask:
+ mask = 1 - mask
+
+ return mask
+
+
+def get_buckets(sizes, num_buckets):
+ buckets = np.unique(
+ np.percentile(
+ sizes,
+ np.linspace(0, 100, num_buckets + 1),
+ interpolation="lower",
+ )[1:]
+ )
+ return buckets
+
+
+def get_bucketed_sizes(orig_sizes, buckets):
+ sizes = np.copy(orig_sizes)
+ assert np.min(sizes) >= 0
+ start_val = -1
+ for end_val in buckets:
+ mask = (sizes > start_val) & (sizes <= end_val)
+ sizes[mask] = end_val
+ start_val = end_val
+ return sizes
+
+
diff --git a/src/third_party/MuQ/src/recipes/pretrain/data/utils/mixup.py b/src/third_party/MuQ/src/recipes/pretrain/data/utils/mixup.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cd0d2c333a8973e1c994bc686d935117854ef40
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/data/utils/mixup.py
@@ -0,0 +1,220 @@
+""" Mixup and Cutmix
+
+Papers:
+mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
+
+CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
+
+Code Reference:
+CutMix: https://github.com/clovaai/CutMix-PyTorch
+
+Hacked together by / Copyright 2019, Ross Wightman
+"""
+import numpy as np
+import torch
+
+
+def one_hot(x, num_classes, on_value=1., off_value=0.):
+ x = x.long().view(-1, 1)
+ return torch.full((x.size()[0], num_classes), off_value, device=x.device).scatter_(1, x, on_value)
+
+# adapted from using one_hot to directly using target values
+def mixup_target(target, num_classes, lam=1., smoothing=0.0):
+ # off_value = smoothing / num_classes
+ # on_value = 1. - smoothing + off_value
+ # y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
+ # y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
+ y1 = target
+ y2 = target.flip(0)
+ return y1 * lam + y2 * (1. - lam)
+
+
+def rand_bbox(img_shape, lam, margin=0., count=None):
+ """ Standard CutMix bounding-box
+ Generates a random square bbox based on lambda value. This impl includes
+ support for enforcing a border margin as percent of bbox dimensions.
+
+ Args:
+ img_shape (tuple): Image shape as tuple
+ lam (float): Cutmix lambda value
+ margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
+ count (int): Number of bbox to generate
+ """
+ ratio = np.sqrt(1 - lam)
+ img_h, img_w = img_shape[-2:]
+ cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
+ margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
+ cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
+ cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
+ yl = np.clip(cy - cut_h // 2, 0, img_h)
+ yh = np.clip(cy + cut_h // 2, 0, img_h)
+ xl = np.clip(cx - cut_w // 2, 0, img_w)
+ xh = np.clip(cx + cut_w // 2, 0, img_w)
+ return yl, yh, xl, xh
+
+
+def rand_bbox_minmax(img_shape, minmax, count=None):
+ """ Min-Max CutMix bounding-box
+ Inspired by Darknet cutmix impl, generates a random rectangular bbox
+ based on min/max percent values applied to each dimension of the input image.
+
+ Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
+
+ Args:
+ img_shape (tuple): Image shape as tuple
+ minmax (tuple or list): Min and max bbox ratios (as percent of image size)
+ count (int): Number of bbox to generate
+ """
+ assert len(minmax) == 2
+ img_h, img_w = img_shape[-2:]
+ cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
+ cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
+ yl = np.random.randint(0, img_h - cut_h, size=count)
+ xl = np.random.randint(0, img_w - cut_w, size=count)
+ yu = yl + cut_h
+ xu = xl + cut_w
+ return yl, yu, xl, xu
+
+
+def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
+ """ Generate bbox and apply lambda correction.
+ """
+ if ratio_minmax is not None:
+ yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
+ else:
+ yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
+ if correct_lam or ratio_minmax is not None:
+ bbox_area = (yu - yl) * (xu - xl)
+ lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
+ return (yl, yu, xl, xu), lam
+
+
+class Mixup:
+ """ Mixup/Cutmix that applies different params to each element or whole batch
+
+ Args:
+ mixup_alpha (float): mixup alpha value, mixup is active if > 0.
+ cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
+ cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
+ prob (float): probability of applying mixup or cutmix per batch or element
+ switch_prob (float): probability of switching to cutmix instead of mixup when both are active
+ mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
+ correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
+ label_smoothing (float): apply label smoothing to the mixed target tensor
+ num_classes (int): number of classes for target
+ """
+ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
+ mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
+ self.mixup_alpha = mixup_alpha
+ self.cutmix_alpha = cutmix_alpha
+ self.cutmix_minmax = cutmix_minmax
+ if self.cutmix_minmax is not None:
+ assert len(self.cutmix_minmax) == 2
+ # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
+ self.cutmix_alpha = 1.0
+ self.mix_prob = prob
+ self.switch_prob = switch_prob
+ self.label_smoothing = label_smoothing
+ self.num_classes = num_classes
+ self.mode = mode
+ self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
+ self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop)
+
+ def _params_per_elem(self, batch_size):
+ lam = np.ones(batch_size, dtype=np.float32)
+ use_cutmix = np.zeros(batch_size, dtype=bool)
+ if self.mixup_enabled:
+ if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+ use_cutmix = np.random.rand(batch_size) < self.switch_prob
+ lam_mix = np.where(
+ use_cutmix,
+ np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
+ np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
+ elif self.mixup_alpha > 0.:
+ lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
+ elif self.cutmix_alpha > 0.:
+ use_cutmix = np.ones(batch_size, dtype=bool)
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
+ else:
+ assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+ lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
+ return lam, use_cutmix
+
+ def _params_per_batch(self):
+ lam = 1.
+ use_cutmix = False
+ if self.mixup_enabled and np.random.rand() < self.mix_prob:
+ if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+ use_cutmix = np.random.rand() < self.switch_prob
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
+ np.random.beta(self.mixup_alpha, self.mixup_alpha)
+ elif self.mixup_alpha > 0.:
+ lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+ elif self.cutmix_alpha > 0.:
+ use_cutmix = True
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
+ else:
+ assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+ lam = float(lam_mix)
+ return lam, use_cutmix
+
+ def _mix_elem(self, x):
+ batch_size = len(x)
+ lam_batch, use_cutmix = self._params_per_elem(batch_size)
+ x_orig = x.clone() # need to keep an unmodified original for mixing source
+ for i in range(batch_size):
+ j = batch_size - i - 1
+ lam = lam_batch[i]
+ if lam != 1.:
+ if use_cutmix[i]:
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+ lam_batch[i] = lam
+ else:
+ x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+ return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
+ def _mix_pair(self, x):
+ batch_size = len(x)
+ lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+ x_orig = x.clone() # need to keep an unmodified original for mixing source
+ for i in range(batch_size // 2):
+ j = batch_size - i - 1
+ lam = lam_batch[i]
+ if lam != 1.:
+ if use_cutmix[i]:
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+ x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
+ lam_batch[i] = lam
+ else:
+ x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+ x[j] = x[j] * lam + x_orig[i] * (1 - lam)
+ lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
+ return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
+ def _mix_batch(self, x):
+ lam, use_cutmix = self._params_per_batch()
+ if lam == 1.:
+ return 1.
+ if use_cutmix:
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
+ else:
+ x_flipped = x.flip(0).mul_(1. - lam)
+ x.mul_(lam).add_(x_flipped)
+ return lam
+
+ def __call__(self, x, target):
+ assert len(x) % 2 == 0, 'Batch size should be even when using this'
+ if self.mode == 'elem':
+ lam = self._mix_elem(x)
+ elif self.mode == 'pair':
+ lam = self._mix_pair(x)
+ else:
+ lam = self._mix_batch(x)
+ target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
+ return x, target
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/pretrain/dataset/music4all/.gitkeep b/src/third_party/MuQ/src/recipes/pretrain/dataset/music4all/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/third_party/MuQ/src/recipes/pretrain/models/muq_fairseq.py b/src/third_party/MuQ/src/recipes/pretrain/models/muq_fairseq.py
new file mode 100644
index 0000000000000000000000000000000000000000..8705e71739cfc3486b8de888b98625c3d215f16b
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/models/muq_fairseq.py
@@ -0,0 +1,177 @@
+from muq import MuQ
+from fairseq.dataclass import FairseqDataclass
+from fairseq.models import BaseFairseqModel, register_model
+from fairseq.tasks.fairseq_task import FairseqTask
+
+from dataclasses import dataclass, field
+from typing import List, Tuple, Optional, Dict, Any
+import torch
+import json
+
+from logging import getLogger
+
+logger = getLogger(__name__)
+
+@dataclass
+class Stat:
+ melspec_2048_cnt: int = 14282760192
+ melspec_2048_mean: float = 6.768444971712967
+ melspec_2048_std: float = 18.417922652295623
+
+@dataclass
+class W2v2Config:
+ activation_dropout: float = 0.1
+ adapter_kernel_size: int = 3
+ adapter_stride: int = 2
+ add_adapter: bool = False
+ apply_spec_augment: bool = True
+ architectures: List[str] = field(default_factory=lambda: ["Wav2Vec2ConformerForCTC"])
+ attention_dropout: float = 0.1
+ bos_token_id: int = 1
+ classifier_proj_size: int = 256
+ codevector_dim: int = 768
+ conformer_conv_dropout: float = 0.1
+ contrastive_logits_temperature: float = 0.1
+ conv_bias: bool = True
+ conv_depthwise_kernel_size: int = 31
+ conv_dim: List[int] = field(default_factory=lambda: [512]*7)
+ conv_kernel: List[int] = field(default_factory=lambda: [10, 3, 3, 3, 3, 2, 2])
+ conv_stride: List[int] = field(default_factory=lambda: [5, 2, 2, 2, 2, 2, 2])
+ ctc_loss_reduction: str = "sum"
+ ctc_zero_infinity: bool = False
+ diversity_loss_weight: float = 0.1
+ do_stable_layer_norm: bool = True
+ eos_token_id: int = 2
+ feat_extract_activation: str = "gelu"
+ feat_extract_dropout: float = 0.0
+ feat_extract_norm: str = "layer"
+ feat_proj_dropout: float = 0.1
+ feat_quantizer_dropout: float = 0.0
+ final_dropout: float = 0.1
+ gradient_checkpointing: bool = False
+ hidden_act: str = "swish"
+ hidden_dropout: float = 0.1
+ hidden_dropout_prob: float = 0.1
+ hidden_size: int = 1024
+ initializer_range: float = 0.02
+ intermediate_size: int = 4096
+ layer_norm_eps: float = 1e-5
+ layerdrop: float = 0.0
+ mask_feature_length: int = 10
+ mask_feature_min_masks: int = 0
+ mask_feature_prob: float = 0.0
+ mask_time_length: int = 10
+ mask_time_min_masks: int = 2
+ mask_time_prob: float = 0.05
+ max_source_positions: int = 5000
+ model_type: str = "wav2vec2-conformer"
+ num_adapter_layers: int = 3
+ num_attention_heads: int = 16
+ num_codevector_groups: int = 2
+ num_codevectors_per_group: int = 320
+ num_conv_pos_embedding_groups: int = 16
+ num_conv_pos_embeddings: int = 128
+ num_feat_extract_layers: int = 7
+ num_hidden_layers: int = 24
+ num_negatives: int = 100
+ output_hidden_size: int = 1024
+ pad_token_id: int = 0
+ position_embeddings_type: str = "rotary"
+ proj_codevector_dim: int = 768
+ rotary_embedding_base: int = 10000
+ tdnn_dilation: List[int] = field(default_factory=lambda: [1, 2, 3, 1, 1])
+ tdnn_dim: List[int] = field(default_factory=lambda: [512, 512, 512, 512, 1500])
+ tdnn_kernel: List[int] = field(default_factory=lambda: [5, 3, 3, 1, 1])
+ torch_dtype: str = "float32"
+ transformers_version: str = "4.19.0.dev0"
+ use_weighted_layer_sum: bool = False
+ vocab_size: int = 32
+ xvector_output_dim: int = 512
+
+@dataclass
+class MuQFairseqConfig(FairseqDataclass):
+ label_rate:int = field(default=25)
+ num_codebooks:int = field(default=1)
+ codebook_dim:int = field(default=16)
+ codebook_size:int = field(default=4096)
+ features:List[str] = field(default_factory=lambda:["melspec_2048"])
+ hop_length:int = field(default=240)
+ n_mels:int = field(default=128)
+ conv_dim:int = field(default=512)
+ encoder_dim:int = field(default=1024)
+ encoder_depth:int = field(default=12)
+ mask_hop:float = field(default=0.4)
+ mask_prob:float = field(default=0.6)
+ is_flash:bool = field(default=False)
+ stat:Stat = field(default_factory=Stat)
+ w2v2_config:W2v2Config = field(default_factory=W2v2Config)
+ use_rvq_target:bool = field(default=False)
+ use_vq_target:bool = field(default=False)
+ use_encodec_target:bool = field(default=False)
+ rvq_ckpt_path: Optional[str] = field(default=None)
+ recon_loss_ratio: Optional[float] = field(default=None)
+ resume_checkpoint: Optional[str] = None
+ rvq_n_codebooks:int = field(default=8)
+ rvq_multi_layer_num:int = field(default=1)
+
+SAMPLE_RATE = 24_000
+
+@register_model("muq", dataclass=MuQFairseqConfig)
+class MuQFairseqModel(BaseFairseqModel):
+ def __init__(self, cfg: MuQFairseqConfig, task_cfg: FairseqTask):
+ super().__init__()
+ self.muq_config = cfg
+ muq = MuQ(self.muq_config)
+ self.muq = muq
+ self.model = muq.model
+
+ def forward(
+ self,
+ source: torch.Tensor, # B,L
+ features_only: bool = False,
+ label = None, # pre-extracted labeks, dim is [Batch, N_Codebook, SeqLen]
+ **kwargs,
+ ):
+ source = source[..., :int((source.shape[-1]//(SAMPLE_RATE//self.muq_config['label_rate']))*(SAMPLE_RATE//self.muq_config['label_rate'])) ]
+ if features_only:
+ if 'attention_mask' in kwargs:
+ attention_mask = kwargs['attention_mask']
+ elif 'padding_mask' in kwargs:
+ attention_mask = ~kwargs['padding_mask'].bool()
+ else:
+ attention_mask = None
+ _, hidden_states = self.model.get_predictions(source, attention_mask=attention_mask, is_features_only=True)
+ result = {
+ "layer_results": hidden_states
+ }
+ return result
+ else:
+ result = {}
+ logits, hidden_emb, losses, accuracies = self.model(source, label=label)
+ result["losses"] = losses
+ result["accuracies"] = accuracies
+ result["logits"] = logits
+ result["hidden_emb"] = hidden_emb
+ for k, v in losses.items():
+ result[k] = v
+ return result
+
+ @classmethod
+ def build_model(cls, cfg: MuQFairseqConfig, task: FairseqTask):
+ """Build a new model instance."""
+
+ model = MuQFairseqModel(cfg, task.cfg)
+ import numpy as np
+ s = 0
+ for param in model.parameters():
+ s += np.product(param.size())
+ # print('# of parameters: '+str(s/1024.0/1024.0))
+
+ if cfg.get("resume_checkpoint", None):
+ print("Loading checkpoint from {}".format(cfg.resume_checkpoint))
+ model.load_state_dict(torch.load(cfg.resume_checkpoint)['model'], strict=False)
+
+ return model
+
+ def get_losses(self, result, batch):
+ return result['losses']
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/pretrain/requirements.txt b/src/third_party/MuQ/src/recipes/pretrain/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a3444bb99deb6fe3cd079b2cb2c330efe75c30e8
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/requirements.txt
@@ -0,0 +1,5 @@
+fairseq # You may need to manually install fairseq by following the instructions in https://github.com/facebookresearch/fairseq
+hydra-core==1.3.1
+tensorboardX
+bitarray
+sacrebleu
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/pretrain/rvq_trainer.py b/src/third_party/MuQ/src/recipes/pretrain/rvq_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2918e1e58469fbdee06ee779d0b84ae5d416a8fb
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/rvq_trainer.py
@@ -0,0 +1,376 @@
+import sys, os
+import argparse
+
+
+from muq.muq.modules.rvq import *
+from muq.muq.modules.features import MelSTFT
+import fairseq
+import torch
+from torch.utils.data import Dataset, DataLoader
+import json, traceback
+import torchaudio
+import math
+import torch.nn as nn
+
+from typing import List, Tuple, Dict, Any
+
+CLIPSECS = 5 # 5 for rvq, 30 for model
+
+
+def load_audio_by_json(json_path, max_keep, min_keep, tgt_sample_rate):
+ # read json file
+ print(json_path)
+ datas = []
+ inds = []
+ sizes = []
+ with open(json_path) as fp:
+ for ind,line in enumerate(fp):
+ data = json.loads(line)
+ datas.append(data)
+ inds.append(ind)
+ # sz = int(data['duration'] * data['sample_rate'])
+ sz = int(tgt_sample_rate * CLIPSECS)
+ sizes.append(sz)
+ tot = ind + 1
+ return datas,inds,tot,sizes
+
+class Read_and_PadCrop_Normalized_T(torch.nn.Module):
+ def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
+
+ super().__init__()
+
+ self.n_samples = n_samples
+ self.sample_rate = sample_rate
+ self.randomize = randomize
+
+
+ def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
+ if(duration<(float(self.n_samples)/self.sample_rate+1)):
+ # print(duration,(float(self.n_samples)/self.sample_rate+1))
+ chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
+ t_start = 0.
+ t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
+ offset = 0
+ # print('c1:',chunk.shape)
+ else:
+ offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
+ t_start = offset / float(cur_sample_rate) / duration
+ t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
+ chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
+ # print('offset:',offset)
+ # print('c0:',chunk.shape)
+ # Pad with silence if necessary.
+ if(chunk.shape[0]>1):
+ chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
+ else:
+ chunk = chunk[[0],:].float()
+ if(cur_sample_rate!=self.sample_rate):
+ # print('a:',cur_sample_rate,chunk.shape)
+ chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
+ # print('b:',self.sample_rate,chunk.shape)
+ if chunk.shape[-1] < self.n_samples:
+ chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
+ else:
+ chunk = chunk[:,0:self.n_samples]
+ seconds_start = math.floor(offset / cur_sample_rate)
+ seconds_total = math.floor(duration)
+
+ return (
+ chunk,
+ t_start,
+ t_end,
+ seconds_start,
+ seconds_total
+ )
+
+class RVQDataset(Dataset):
+ def __init__(
+ self,
+ manifest_path: str,
+ sample_rate: float,
+ normalize: bool = False,
+ ):
+ self.sample_rate = sample_rate
+ self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path, None, None, self.sample_rate)
+ self.dataset_len = len(self.datas)
+
+ self.reader = Read_and_PadCrop_Normalized_T(n_samples=CLIPSECS*sample_rate,sample_rate = self.sample_rate)
+ self.normalize = normalize
+
+
+ def __getitem__(self, i):
+ index = i
+ item = None
+ while item is None:
+ try:
+ wav = self.get_audio_by_slice(index)
+ item = {"id": index, "source": wav}
+ except Exception as e:
+ # print(e)
+ traceback.print_exc()
+ print(f'skip damaged data {index}')
+ index = np.random.randint(0,len(self.sizes)-1)
+ return item
+
+ def __len__(self):
+ return self.dataset_len
+
+ def get_audio_by_slice(self,index):
+
+ wav_path = self.datas[index]['path']
+ audio_info = torchaudio.info(wav_path)
+ origin_sample_rate = audio_info.sample_rate
+ origin_duration = audio_info.num_frames / origin_sample_rate
+
+ wav, *ignored = self.reader(wav_path, origin_duration,origin_sample_rate)
+ wav = wav.float()
+
+ wav = wav.permute(1,0)
+ wav = self.postprocess(wav, self.sample_rate)
+ return wav
+
+ def postprocess(self, wav, cur_sample_rate):
+ if wav.dim() == 2:
+ wav = wav.mean(-1)
+ assert wav.dim() == 1, wav.dim()
+
+ if cur_sample_rate != self.sample_rate:
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
+
+ if self.normalize:
+ with torch.no_grad():
+ wav = F.layer_norm(wav, wav.shape)
+ return wav
+
+class Preprocessor(nn.Module):
+ def __init__(self,
+ codebook_dim=16,
+ codebook_size=4096,
+ hop_length=240,
+ n_mels=128,
+ stat_path=None,
+ is_spec_wise=False,
+ s=4,
+ ) -> None:
+ super().__init__()
+
+ self.features=["melspec_2048"]
+ self.s = s
+
+ # load feature mean / std stats
+ import os
+ if stat_path is not None and os.path.exists(stat_path):
+ with open(stat_path, "r") as f:
+ self.stat = json.load(f)
+ else:
+ # print("No stats file found at `{}`, use default from msd.".format(stat_path))
+ self.stat = {"spec_256_cnt": 14394344256, "spec_256_mean": -23.34296658431829, "spec_256_std": 26.189295587132637, "spec_512_cnt": 28677104448, "spec_512_mean": -21.31267396860235, "spec_512_std": 26.52644536245769, "spec_1024_cnt": 57242624832, "spec_1024_mean": -18.852271129208273, "spec_1024_std": 26.443154583585663, "spec_2048_cnt": 114373665600, "spec_2048_mean": -15.638743433896792, "spec_2048_std": 26.115825961611545, "spec_4096_cnt": 228635747136, "spec_4096_mean": -11.715532502794836, "spec_4096_std": 25.763972210234062, "melspec_256_cnt": 14282760192, "melspec_256_mean": -26.962600400166156, "melspec_256_std": 36.13614100912126, "melspec_512_cnt": 14282760192, "melspec_512_mean": -9.108344167718862, "melspec_512_std": 24.71910937988429, "melspec_1024_cnt": 14282760192, "melspec_1024_mean": 0.37302579246531126, "melspec_1024_std": 18.684082325919388, "melspec_2048_cnt": 14282760192, "melspec_2048_mean": 6.768444971712967, "melspec_2048_std": 18.417922652295623, "melspec_4096_cnt": 14282760192, "melspec_4096_mean": 13.617164614990036, "melspec_4096_std": 18.08552130124525, "cqt_cnt": 9373061376, "cqt_mean": 0.46341379757927165, "cqt_std": 0.9543998080910191, "mfcc_256_cnt": 1339008768, "mfcc_256_mean": -11.681755459447485, "mfcc_256_std": 29.183186444668316, "mfcc_512_cnt": 1339008768, "mfcc_512_mean": -2.540581461792183, "mfcc_512_std": 31.93752185832081, "mfcc_1024_cnt": 1339008768, "mfcc_1024_mean": 6.606636263169779, "mfcc_1024_std": 34.151644801729624, "mfcc_2048_cnt": 1339008768, "mfcc_2048_mean": 5.281600844245184, "mfcc_2048_std": 33.12784541220003, "mfcc_4096_cnt": 1339008768, "mfcc_4096_mean": 4.7616569480166095, "mfcc_4096_std": 32.61458906894133, "chromagram_256_cnt": 1339008768, "chromagram_256_mean": 55.15596556703181, "chromagram_256_std": 73.91858278719991, "chromagram_512_cnt": 1339008768, "chromagram_512_mean": 175.73092252759895, "chromagram_512_std": 248.48485148525953, "chromagram_1024_cnt": 1339008768, "chromagram_1024_mean": 589.2947481634608, "chromagram_1024_std": 913.857929063196, "chromagram_2048_cnt": 1339008768, "chromagram_2048_mean": 2062.286388327397, "chromagram_2048_std": 3458.92657915397, "chromagram_4096_cnt": 1339008768, "chromagram_4096_mean": 7673.039107997085, "chromagram_4096_std": 13009.883158267234}
+
+ # feature extractor
+ self.preprocessor_melspec_2048 = MelSTFT(
+ n_fft=2048, hop_length=hop_length, is_db=True
+ )
+
+ self.is_spec_wise = is_spec_wise
+
+
+ @torch.no_grad()
+ def normalize(self, x):
+ """normalize the input audio to have zero mean unit variance"""
+ for key in x.keys():
+ x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key] # {'melspec_2048_cnt': 14282760192, 'melspec_2048_mean': 6.768444971712967}
+ return x
+
+ @torch.no_grad()
+ def rearrange(self, x):
+ """rearrange the batch to flatten every 4 steps"""
+ for key in x.keys():
+ if key == "chromagram":
+ x[key] = rearrange(x[key], "b f t -> b t f")
+ else:
+ x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=self.s)
+ return x
+
+ @torch.no_grad()
+ def preprocessing(self, x, features):
+ """extract classic audio features"""
+ # check precision
+ if x.dtype == torch.float16:
+ precision = 16
+ else:
+ precision = 32
+
+ out = {}
+ for key in features:
+ layer = getattr(self, "preprocessor_%s" % key)
+ out[key] = layer(x.float())[..., :-1]
+ if precision == 16:
+ out[key] = out[key].half()
+ return out
+
+ @torch.no_grad()
+ def tokenize(self, x):
+ out = {}
+ for key in x.keys():
+ layer = getattr(self, "quantizer_%s" % key)
+ out[key] = layer(x[key])
+ return out
+
+ def to_spec_wise(self, x):
+ Batch, Spec, Time = x.shape
+ SubSpec, N_SubSpec = 16, 8
+ assert SubSpec * N_SubSpec == Spec == 128
+ x = rearrange(x, "b (n s) t -> b s (n t)", n=N_SubSpec, s=SubSpec)
+ return x # [Batch, SubSpec=16, N_SubSpec*Time=8*100Hz]
+
+ @torch.no_grad()
+ def __call__(self, x):
+ x = self.preprocessing(x, features=self.features) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}}
+ x = self.normalize(x)
+ if self.is_spec_wise:
+ x = {k:self.to_spec_wise(v) for k,v in x.items()}
+ x = self.rearrange(x) # -> {'melspec_2048': Tensor{Size([3, 750, 512]) cuda:0 f32}}
+ return x['melspec_2048'].permute((0, 2, 1))
+
+ def to(self, device):
+ self.preprocessor_melspec_2048.to(device)
+ return super().to(device)
+
+
+
+def main(config):
+ train_dataset = RVQDataset(**config['train_dataset'])
+ if config['valid_dataset']['manifest_path'] is None:
+ # split train and valid dataset
+ from torch.utils.data import random_split
+ train_dataset, valid_dataset = random_split(
+ train_dataset, lengths=[len(train_dataset) - 500, 500]
+ )
+ else:
+ valid_dataset = RVQDataset(**config['valid_dataset'])
+ train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=config['train']['batch_size'], drop_last=True, num_workers=config['train']['num_workers'])
+ valid_dataloader = DataLoader(valid_dataset, shuffle=False, batch_size=config['train']['batch_size'], drop_last=True, num_workers=config['train']['num_workers'])
+ model = ResidualVectorQuantize(**config['model'])
+
+ device = config['train']['device']
+ preprocess = config['train']['preprocess'].to(device)
+ model = model.to(device)
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr'])
+ cur_updates = 0
+ is_running = True
+ result = {}
+ from tqdm import tqdm
+ from tensorboardX import SummaryWriter
+ writer = SummaryWriter()
+ from collections import defaultdict
+ import os
+ from logging import getLogger
+ logger = getLogger()
+
+ while is_running:
+ results = defaultdict(lambda:0)
+ for item in tqdm(train_dataloader, desc='train'):
+ wavs = item['source']
+ optimizer.zero_grad()
+ wavs = wavs.to(device)
+ x = preprocess(wavs)
+ model.train()
+ quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x)
+ loss = eval(config['train']['loss'])
+ loss.backward()
+ optimizer.step()
+
+ results['loss/train'] += loss.item()
+ results['commitment_loss/train'] += commitment_loss.item()
+ results['codebook_loss/train'] += codebook_loss.item()
+ results['rvq_usage/train'] += rvq_usage.float().mean().item()
+
+ if cur_updates % config['train']['valid_interval'] == 0:
+ model.eval()
+ with torch.no_grad():
+ for item in tqdm(valid_dataloader, desc='valid'):
+ wavs = item['source']
+ wavs = wavs.to(device)
+ x = preprocess(wavs)
+ quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x)
+ valid_loss = eval(config['train']['loss'])
+
+ results['loss/valid'] += valid_loss.item()
+ results['commitment_loss/valid'] += commitment_loss.item()
+ results['codebook_loss/valid'] += codebook_loss.item()
+ results['rvq_usage/valid'] += rvq_usage.float().mean().item()
+
+ results['cur_updates'] = cur_updates
+ results['loss/train'] /= config['train']['valid_interval']
+ results['commitment_loss/train'] /= config['train']['valid_interval']
+ results['codebook_loss/train'] /= config['train']['valid_interval']
+ results['rvq_usage/train'] /= config['train']['valid_interval']
+
+ results['loss/valid'] /= len(valid_dataloader)
+ results['commitment_loss/valid'] /= len(valid_dataloader)
+ results['codebook_loss/valid'] /= len(valid_dataloader)
+ results['rvq_usage/valid'] /= len(valid_dataloader)
+
+ print('')
+ logger.info(str(results))
+ for k,v in results.items():
+ writer.add_scalar(k, v, cur_updates)
+
+ results.clear()
+
+ if cur_updates % config['train']['save_interval'] == 0:
+ os.makedirs(f'{writer.logdir}/ckpt/', exist_ok=True)
+ logger.info(f'saving checkpoint to {writer.logdir}/ckpt/RVQ_{cur_updates}.pth')
+ torch.save(model.state_dict(), f'{writer.logdir}/ckpt/RVQ_{cur_updates}.pth')
+
+
+ if cur_updates < config['train']['max_updates']:
+ cur_updates += 1
+ else:
+ is_running = False
+ break
+
+
+def Music_Mel_Target_Config(args):
+ config = dict(
+ train_dataset = dict(
+ manifest_path = args.train_path,
+ sample_rate = 24000,
+ normalize = False,
+ ),
+ valid_dataset = dict(
+ manifest_path = args.valid_path,
+ sample_rate = 24000,
+ normalize = False,
+ ),
+ model = dict(
+ input_dim = 128*4,
+ n_codebooks = 8,
+ codebook_size = 1024,
+ codebook_dim = 16,
+ quantizer_dropout = 0.0,
+ ),
+ train = dict(
+ batch_size = 32,
+ num_workers = 6,
+ valid_interval = 10,
+ save_interval = 100,
+ max_updates = 500000,
+ lr = 1e-4,
+ device = 'cuda:0',
+ loss = 'commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()',
+ preprocess = Preprocessor()
+ )
+ )
+ return config
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--train_path', type=str, required=True, help="path to the train dataset split JSON file")
+ parser.add_argument('--valid_path', type=str, required=True, help="path to the valid dataset split JSON file")
+ args = parser.parse_args()
+
+ config = Music_Mel_Target_Config(args)
+ main(config)
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/pretrain/scripts/convert_muq_fairseq_ckpt_to_huggingface.py b/src/third_party/MuQ/src/recipes/pretrain/scripts/convert_muq_fairseq_ckpt_to_huggingface.py
new file mode 100644
index 0000000000000000000000000000000000000000..928397e0c5876b72639fcc418a46953e766865bc
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/scripts/convert_muq_fairseq_ckpt_to_huggingface.py
@@ -0,0 +1,47 @@
+import os
+import torch
+import json
+import argparse
+from dataclasses import dataclass
+
+import torch
+import fairseq
+from transformers import PretrainedConfig
+from omegaconf import OmegaConf
+
+@dataclass
+class UserDirModule:
+ user_dir: str
+
+def load_model(model_dir, checkpoint_path):
+ '''Load Fairseq SSL model'''
+
+ model_path = UserDirModule(model_dir)
+ fairseq.utils.import_user_module(model_path)
+
+ model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path], strict=False)
+ model = model[0]
+
+ return model, cfg
+
+def main(args):
+ model_dir = args.model_dir
+ checkpoint_path = args.checkpoint_path
+ save_dir = args.save_dir
+
+ # save model and config.json
+ model, cfg = load_model(model_dir, checkpoint_path)
+ model.muq.save_pretrained(save_dir)
+ with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf8') as w:
+ model_cfg = OmegaConf.to_container(cfg['model'], resolve=True)
+ del model_cfg['_name']
+ json.dump(model_cfg, w, indent=2)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_dir", type=str, help="path to the `MuQ/src/recipes/pretrain` working directory")
+ parser.add_argument("--checkpoint_path", type=str, help="path to the fairseq checkpoint")
+ parser.add_argument("--save_dir", type=str, help="path to the result directory")
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/pretrain/scripts/inference_muq_huggingface.py b/src/third_party/MuQ/src/recipes/pretrain/scripts/inference_muq_huggingface.py
new file mode 100644
index 0000000000000000000000000000000000000000..8422e356d472935265beb064c89e256bcefe4356
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/scripts/inference_muq_huggingface.py
@@ -0,0 +1,16 @@
+import torch, librosa
+from muq import MuQ
+
+device = 'cuda'
+wav, sr = librosa.load("path/to/music_audio.wav", sr = 24000)
+wavs = torch.tensor(wav).unsqueeze(0).to(device)
+
+# Use local huggingface checkpoint
+muq = MuQ.from_pretrained("./output/hf-username/My-MuQ-large-huggingface")
+muq = muq.to(device).eval()
+
+with torch.no_grad():
+ output = muq(wavs, output_hidden_states=True)
+
+print('Total number of layers: ', len(output.hidden_states))
+print('Feature shape: ', output.last_hidden_state.shape)
\ No newline at end of file
diff --git a/src/third_party/MuQ/src/recipes/pretrain/scripts/run_training_muq.sh b/src/third_party/MuQ/src/recipes/pretrain/scripts/run_training_muq.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c29e300e4ac62eb2255db1b077c8da3ae39086d1
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/scripts/run_training_muq.sh
@@ -0,0 +1,82 @@
+WORKER_RANK=${1:-0}
+YAML_NAME_WITHOUT_EXT=${2:-'MuQ_large_multinodes_v100'}
+TRAINING_SETTING=${3:-'MUQ'}
+MASTER_PROC_ADD=${4:-$CHIEF_IP}
+DIST_PORT=${5:-'25520'}
+DATASET_NAME=${6:-'music4all'}
+NNODS=${7:-4}
+NPROCES_PER_NODE=${8:-8}
+
+echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}"
+
+MAP_PROJ_DIR=$(pwd)
+echo $MAP_PROJ_DIR
+
+NUM_WOKERS=0
+
+run_command_prefix=' '
+# Loading folders
+# 1. json files for audio paths
+DATA_DIR=${MAP_PROJ_DIR}/dataset/${DATASET_NAME} #audio_manifest
+# 2. working folder for saving checkpoints and loading config files
+CONFIG_DIR=/${MAP_PROJ_DIR}/config/pretrain
+
+FAIRSEQ_PATH=${MAP_PROJ_DIR}/fairseq;
+SAVE_DIR=${MAP_PROJ_DIR}/output/
+
+LABEL_RATE=25
+
+case $YAML_NAME_WITHOUT_EXT in
+ MuQ_large_multinodes_v100)
+ TASK_LABELS_POSTFIX='[]'
+ # NNODS=4
+ LABEL_RATE=25
+ # NPROCES_PER_NODE=8
+ MAX_TOKENS=2400000
+ ;;
+ MuQ_large_iter_multinodes_v100)
+ TASK_LABELS_POSTFIX='[]'
+ # NNODS=4
+ LABEL_RATE=25
+ # NPROCES_PER_NODE=8
+ MAX_TOKENS=2400000
+ OTHER_PARAMS=" task.label_scp_path=${MAP_PROJ_DIR}/data/msd_ark/reiter_musicssl_msd/train.scp "
+ ;;
+ *)
+ echo "Unknown running config: ${$YAML_NAME_WITHOUT_EXT} = ${YAML_NAME_WITHOUT_EXT}"
+ exit 1
+ ;;
+ esac
+
+ echo running $YAML_NAME_WITHOUT_EXT ..
+
+ mkdir -p ${SAVE_DIR}
+ echo "checkpoint save at: ${SAVE_DIR}"
+ cd ${SAVE_DIR}
+
+ echo "NPROCES_PER_NODE is ${NPROCES_PER_NODE}"
+
+ DISTRIBUTED_WORLD_SIZE=`expr ${NNODS} \* ${NPROCES_PER_NODE}`
+ ACTUAL_WORKER_RANK=`expr ${WORKER_RANK} \* ${NPROCES_PER_NODE}`
+ echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}, actual rank ${ACTUAL_WORKER_RANK}"
+
+ DATE_SUFFIX=`date +"%Y-%m-%d_%H-%M"`
+
+ OMP_NUM_THREADS=6 ${run_command_prefix} \
+ python3 -u ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \
+ --config-dir ${CONFIG_DIR} --config-name ${YAML_NAME_WITHOUT_EXT} \
+ common.user_dir=${MAP_PROJ_DIR}/ \
+ common.tensorboard_logdir=${MAP_PROJ_DIR}/logs/pretrain_tb_${TRAINING_SETTING}_${YAML_NAME_WITHOUT_EXT}_multinodes${NNODS} \
+ checkpoint.save_dir=${SAVE_DIR}/ckpt_${TRAINING_SETTING}_multinodes${NNODS}_${DATE_SUFFIX}/${YAML_NAME_WITHOUT_EXT} \
+ distributed_training.distributed_rank=${ACTUAL_WORKER_RANK} \
+ distributed_training.distributed_world_size=${DISTRIBUTED_WORLD_SIZE} \
+ distributed_training.distributed_num_procs=${DISTRIBUTED_WORLD_SIZE} \
+ distributed_training.nprocs_per_node=${NPROCES_PER_NODE} \
+ distributed_training.distributed_init_method="tcp://${MASTER_PROC_ADD}:${DIST_PORT}" \
+ task.data=${DATA_DIR} \
+ task.label_dir=${LABEL_DIR} \
+ task.labels=${TASK_LABELS_POSTFIX} \
+ dataset.num_workers=${NUM_WOKERS} \
+ dataset.max_tokens=${MAX_TOKENS} \
+ dataset.disable_validation=true \
+ ${OTHER_PARAMS} \
diff --git a/src/third_party/MuQ/src/recipes/pretrain/tasks/muq_pretraining.py b/src/third_party/MuQ/src/recipes/pretrain/tasks/muq_pretraining.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c3374399e343e767c0830ea08b8d46aa45b89d2
--- /dev/null
+++ b/src/third_party/MuQ/src/recipes/pretrain/tasks/muq_pretraining.py
@@ -0,0 +1,453 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import logging
+import os
+import sys
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+
+from dataclasses import dataclass, field
+from fairseq.data import Dictionary, HubertDataset
+from fairseq.dataclass.configs import FairseqDataclass
+from fairseq.tasks import register_task
+from fairseq.tasks.fairseq_task import FairseqTask
+from omegaconf import MISSING
+
+from ..data.mert_dataset import MERTDataset
+from ..data.ark_dataset import ArkDataset
+
+logger = logging.getLogger(__name__)
+
+
+class LabelEncoder(object):
+ def __init__(self, dictionary: Dictionary) -> None:
+ self.dictionary = dictionary
+
+ def __call__(self, label: str) -> List[str]:
+ # encode_line return a torch.IntTensor, should be all 1 for vanila HuBERT
+ return self.dictionary.encode_line(
+ label,
+ append_eos=False,
+ add_if_not_exist=False,
+ )
+class PaddedNumpyLabelEncoder(object):
+ def __init__(self):
+ # self.dictionary = dictionary
+ pass
+
+ def __call__(self, label):
+ t = torch.IntTensor(np.asarray(label))
+ t = t[t>=0] # remove padded -1 values at the end
+ return t
+
+@dataclass
+class MuQPretrainingConfig(FairseqDataclass):
+ data: str = field(default=MISSING, metadata={"help": "path to data directory"})
+ sharding_data: int = field(
+ default=-1,
+ metadata={
+ "help": "set this para >1 to use sharding dataset to prevent OOM"
+ "prepare data tsv and label files by adding postfix for sharding 64 like:"
+ "train_28_64.tsv and train_28_64.encodec_6"
+ },
+ )
+ load_random_data_shard: bool = field(
+ default=True,
+ metadata={
+ "help": "whether to laod shards randomly or in order when use sharding_data"
+ },
+ )
+ fine_tuning: bool = field(
+ default=False, metadata={"help": "set to true if fine-tuning Hubert"}
+ )
+ labels: List[str] = field(
+ default_factory=lambda: ["ltr"],
+ metadata={
+ "help": (
+ "extension of the label files to load, frame-level labels for"
+ " pre-training, and sequence-level label for fine-tuning"
+ )
+ },
+ )
+ label_dir: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "if set, looks for labels in this directory instead",
+ },
+ )
+ label_scp_path: Optional[str] = field(
+ default=None,
+ metadata={
+ 'help': 'if set, load label from scp file'
+ }
+ )
+ label_scp_clip_duration: float = field(
+ default=-1,
+ metadata={
+ 'help': 'clip duration for loading scp label. if set to -1, this will not make effect.'
+ }
+ )
+ label_rate: float = field(
+ default=-1.0,
+ metadata={"help": "label frame rate. -1.0 for sequence label"},
+ )
+ sample_rate: int = field(
+ default=16_000,
+ metadata={
+ "help": "target sample rate. audio files will be up/down "
+ "sampled to this rate"
+ },
+ )
+ normalize: bool = field(
+ default=False,
+ metadata={"help": "if set, normalizes input to have 0 mean and unit variance"},
+ )
+ enable_padding: bool = field(
+ default=False,
+ metadata={"help": "pad shorter samples instead of cropping"},
+ )
+ max_keep_size: Optional[int] = field(
+ default=None,
+ metadata={"help": "exclude sample longer than this"},
+ )
+ max_sample_size: Optional[int] = field(
+ default=None,
+ metadata={"help": "max sample size to crop to for batching"},
+ )
+ min_sample_size: Optional[int] = field(
+ default=None,
+ metadata={"help": "min sample size to crop to for batching"},
+ )
+ single_target: Optional[bool] = field(
+ default=False,
+ metadata={
+ "help": "if set, AddTargetDatasets outputs same keys " "as AddTargetDataset"
+ },
+ )
+ random_crop: Optional[bool] = field(
+ default=True,
+ metadata={"help": "always crop from the beginning if false"},
+ )
+ pad_audio: Optional[bool] = field(
+ default=False,
+ metadata={"help": "pad audio to the longest one in the batch if true"},
+ )
+
+ store_labels: Optional[bool] = field(
+ default=False,
+ metadata={"help": "whether to load all of the label into memory"},
+ )
+
+ numpy_memmap_label: Optional[bool] = field(
+ default=False,
+ metadata={"help": "whether the label file is saved as a numpy file, each line is ended with padding -1"},
+ )
+
+ augmentation_effects: Optional[str] = field(
+ default="[]",
+ metadata={
+ "help": (
+ "a list of effects that might apply to the audios"
+ "example: \"['random_mute', 'random_Gaussian', 'reverse_polarity']\" "
+ "supported: random_mute,"
+ "todo: "
+ )
+ },
+ )
+ augmentation_probs: Optional[str] = field(
+ default="[]",
+ metadata={
+ "help": (
+ "the corresponding probabilities for the data augmentation effects"
+ "example: \"[0.1, 0.5, 0.8]\" "
+ "the sum is not necessarily need to be 1.0, and multiple effects can be applied to the same audio"
+ )
+ },
+ )
+
+ # inbatch_noise_augment_len_range: Optional[List[int]] = field(
+ # default_factory=lambda: [8000, 24000],
+ # default = [8000, 24000],
+ inbatch_noise_augment_len_range: Optional[str] = field(
+ default = "[8000, 24000]",
+ metadata={
+ "help": (
+ "the range of length of the mix-up noise augmentation, unit in smaples"
+ )
+ },
+ )
+ # inbatch_noise_augment_number_range: Optional[List[int]] = field(
+ # default_factory=lambda: [1, 3],
+ # default = [1, 3],
+ inbatch_noise_augment_number_range: Optional[str] = field(
+ default = "[1, 3]",
+ metadata={
+ "help": (
+ "the range of numbers of the mix-up noise augmentation"
+ )
+ },
+ )
+ inbatch_noise_augment_volume: float = field(
+ default = 1.0,
+ metadata={
+ "help": (
+ "the coefficient used to modify the volume of the noise audios wavs"
+ )
+ },
+ )
+ dynamic_crops: Optional[str] = field(
+ default="[]",
+ metadata={
+ "help": (
+ "used to set the maximum audio length setting, for training"
+ "example: \"[1, 2, 3, 4, 5, 10]\" "
+ )
+ },
+ )
+ dynamic_crops_epoches: Optional[str] = field(
+ default="[]",
+ metadata={
+ "help": (
+ "used to set training epoches of changing the maximum audio length"
+ "example: \"[1, 10, 20, 40, 80, 160,]\" "
+ "then len need to be equal to len(dynamic_crops)"
+ )
+ },
+ )
+
+ cqt_loss_bin_dataloader: Optional[int] = field(
+ default=-1,
+ metadata={
+ "help": (
+ "use this parameter to prepare cqt prediction objective in dataloader"
+ )
+ },
+ )
+
+ clip_secs: int = field(
+ default=5,
+ metadata={
+ "help": "clip secs for each audio"
+ }
+ )
+
+ dataset_shuffle: bool = field(
+ default=True,
+ metadata={
+ "help": (
+ "dataset shuffle when sample a batch"
+ )
+ },
+ )
+
+
+@register_task("muq_pretraining", dataclass=MuQPretrainingConfig)
+class MuQPretrainingTask(FairseqTask):
+
+ cfg: MuQPretrainingConfig
+
+ def __init__(
+ self,
+ cfg: MuQPretrainingConfig,
+ ) -> None:
+ super().__init__(cfg)
+
+ logger.info(f"current directory is {os.getcwd()}")
+ logger.info(f"MuQPretrainingTask Config {cfg}")
+
+ self.cfg = cfg
+ self.fine_tuning = cfg.fine_tuning
+
+ if cfg.fine_tuning:
+ self.state.add_factory("target_dictionary", self.load_dictionaries)
+ else:
+ self.state.add_factory("dictionaries", self.load_dictionaries)
+
+ self.blank_symbol = ""
+
+ # use eval() to pass list parameters, skirt the fairseq/torch error: Can't pickle : attribute lookup Choices on fairseq.dataclass.constants failed
+ self.augmentation_effects = eval(self.cfg.augmentation_effects)
+ self.augmentation_probs = eval(self.cfg.augmentation_probs)
+ if len(self.augmentation_effects) > 0:
+ assert len(self.augmentation_effects) == len(self.augmentation_probs)
+ logger.info(f"Applying audio augmentation {self.augmentation_effects}, probabilities: {self.augmentation_probs}")
+
+ self.inbatch_noise_augment_number_range = eval(self.cfg.inbatch_noise_augment_number_range)
+ self.inbatch_noise_augment_len_range = eval(self.cfg.inbatch_noise_augment_len_range)
+
+ self.max_sample_size = self.cfg.max_sample_size
+
+ self.dynamic_crops = eval(self.cfg.dynamic_crops)
+ self.dynamic_crops_epoches = eval(self.cfg.dynamic_crops_epoches)
+ assert len(self.dynamic_crops) == len(self.dynamic_crops_epoches)
+ if len(self.dynamic_crops) > 0:
+ assert self.dynamic_crops_epoches[0] == 1
+
+ self.cqt_loss_bin_dataloader = self.cfg.cqt_loss_bin_dataloader
+
+ self.numpy_memmap_label = self.cfg.numpy_memmap_label
+ self.store_labels = self.cfg.store_labels
+ if self.numpy_memmap_label:
+ assert self.store_labels
+
+ @property
+ def source_dictionary(self) -> Optional[Dictionary]:
+ return None
+
+ @property
+ def target_dictionary(self) -> Optional[Dictionary]:
+ return self.state.target_dictionary
+
+ @property
+ def dictionaries(self) -> List[Dictionary]:
+ return self.state.dictionaries
+
+ @classmethod
+ def setup_task(
+ cls, cfg: MuQPretrainingConfig, **kwargs
+ ) -> "MuQPretrainingTask":
+ return cls(cfg)
+
+ def load_dictionaries(self):
+ label_dir = self.cfg.data if (self.cfg.label_dir is None or self.cfg.label_dir == '') else self.cfg.label_dir
+ print(label_dir)
+ dictionaries = [
+ Dictionary.load(f"{label_dir}/dict.{label}.txt")
+ for label in self.cfg.labels
+ ]
+ return dictionaries[0] if self.cfg.fine_tuning else dictionaries
+
+ def get_label_dir(self) -> str:
+ if self.cfg.label_dir is None or self.cfg.label_dir=='':
+ return self.cfg.data
+ return self.cfg.label_dir
+
+ # def has_sharded_data(self, split):
+ # """overwrite this function for let the trainier do dataset reload for changing the the dynamic croppings"""
+ # logger.info(f"check whether to re-load dataset for epoch {epoch} by overwritting task.has_sharded_data()")
+ # # find the threshold that holds epoch \in [threshold, next_threshold)
+ # is_reload_dataset = epoch in self.dynamic_crops_epoches
+
+ # return os.pathsep in getattr(self.cfg, "data", "") or is_reload_dataset
+ # def is_force_load_dataset(self, epoch):
+ def is_force_load_dataset(self, epoch, training_restore=False):
+ # find the threshold that holds epoch \in [threshold, next_threshold)
+ return (epoch in self.dynamic_crops_epoches) or training_restore or (self.cfg.sharding_data > 1)
+ # for idx in range(len(self.dynamic_crops_epoches)):
+ # if (idx == len(self.dynamic_crops_epoches)-1) or \
+ # (epoch >= self.dynamic_crops_epoches[idx] and epoch < self.dynamic_crops_epoches[idx+1]):
+ # return True
+ # return False
+
+ def set_dynamic_crop_max_sample(self, epoch):
+ """ force to set the max_sample_size config for the dynamic cropping function"""
+ if epoch in self.dynamic_crops_epoches:
+ for idx in range(len(self.dynamic_crops_epoches)):
+ if (idx == len(self.dynamic_crops_epoches)-1) or \
+ (epoch >= self.dynamic_crops_epoches[idx] and epoch < self.dynamic_crops_epoches[idx+1]):
+ # set new cropping parameters and end loop
+ self.max_sample_size = self.dynamic_crops[idx]*self.cfg.sample_rate
+ self.cfg.max_sample_size = self.dynamic_crops[idx]*self.cfg.sample_rate
+ logger.info(f"epoch {epoch} forcely set new maximum audio length as {self.dynamic_crops[idx]}s == {self.max_sample_size} samples")
+ break
+ # logger.info(f'reloading dataset for changing the sequence length')
+ # self.load_dataset('train')
+ def load_dataset(self, split: str, **kwargs) -> None:
+ if len(list(filter(lambda s: s.endswith('.scp'), os.listdir(self.cfg.data)))) > 0:
+ return self.load_dataset_ark(split, **kwargs)
+ else:
+ return self.load_dataset_mert(split, **kwargs)
+
+ def load_dataset_ark(self, split, **kwargs):
+ if 'train' not in split:
+ logger.info(f'split {split} is only used for training')
+ # raise ValueError(f"No support for split: {split}")
+ else:
+ self.datasets[split] = ArkDataset(
+ wav_scp=os.path.join(self.cfg.data, f"wav_ark.scp"),
+ dur_scp=os.path.join(self.cfg.data, f"dur_ark.scp"),
+ sr=self.cfg.sample_rate,
+ )
+
+ def load_dataset_mert(self, split: str, **kwargs) -> None:
+ if 'train' in split:
+ epoch = kwargs['epoch']
+ # the epoch to change crops
+ if self.is_force_load_dataset(epoch):
+ self.set_dynamic_crop_max_sample(epoch)
+
+ # load all training data
+ if self.cfg.sharding_data <= 1:
+ # manifest = f"{self.cfg.data}/{split}.tsv"
+ manifest = f"{self.cfg.data}/{split}.json"
+
+ paths = [f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels]
+ # load part of the training data
+ else:
+ if self.cfg.load_random_data_shard:
+ data_shard_idx = np.random.randint(self.cfg.sharding_data)
+ else:
+ data_shard_idx = (epoch-1) % self.cfg.sharding_data # epoch start from 1
+ assert data_shard_idx < self.cfg.sharding_data
+ logger.info(f'loading shard {data_shard_idx} of {self.cfg.sharding_data} training data for ecpoh {epoch}')
+
+ # manifest = f"{self.cfg.data}/{split}_{data_shard_idx}_{self.cfg.sharding_data}.tsv"
+ manifest = f"{self.cfg.data}/{split}_{data_shard_idx}_{self.cfg.sharding_data}.json"
+
+ paths = [f"{self.get_label_dir()}/{split}_{data_shard_idx}_{self.cfg.sharding_data}.{l}" for l in self.cfg.labels]
+ else:
+ # manifest = f"{self.cfg.data}/{split}.tsv"
+ manifest = f"{self.cfg.data}/{split}.json"
+
+ paths = [f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels]
+
+ dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries
+ pad_list = [dict.pad() for dict in dicts]
+ eos_list = [dict.eos() for dict in dicts]
+
+ if self.numpy_memmap_label:
+ procs = [PaddedNumpyLabelEncoder() for dict in dicts]
+ else:
+ procs = [LabelEncoder(dict) for dict in dicts]
+
+ self.datasets[split] = MERTDataset(
+ manifest,
+ sample_rate=self.cfg.sample_rate,
+ label_paths=paths, # this containes the ensemble label sequence names
+ label_rates=self.cfg.label_rate,
+ pad_list=pad_list,
+ eos_list=eos_list,
+ label_scp_path=self.cfg.label_scp_path,
+ label_scp_clip_duration=self.cfg.label_scp_clip_duration,
+ label_processors=procs,
+ max_keep_sample_size=self.cfg.max_keep_size,
+ min_keep_sample_size=self.cfg.min_sample_size,
+ max_sample_size=self.max_sample_size,
+ pad_audio=self.cfg.pad_audio,
+ normalize=self.cfg.normalize,
+ store_labels=self.store_labels,
+ npmemmap=self.numpy_memmap_label,
+ random_crop=self.cfg.random_crop,
+ single_target=self.cfg.single_target,
+ augmentation_effects=self.augmentation_effects,
+ augmentation_probs=self.augmentation_probs,
+ inbatch_noise_augment_len_range=self.inbatch_noise_augment_len_range,
+ inbatch_noise_augment_number_range=self.inbatch_noise_augment_number_range,
+ inbatch_noise_augment_volume=self.cfg.inbatch_noise_augment_volume,
+ cqt_prediction_bin=self.cqt_loss_bin_dataloader,
+ clip_secs=self.cfg.clip_secs,
+ shuffle=self.cfg.dataset_shuffle,
+ )
+
+ def max_positions(self) -> Tuple[int, int]:
+ return (sys.maxsize, sys.maxsize)
+
+ def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array:
+ return indices
diff --git a/src/third_party/musicfm/.gitignore b/src/third_party/musicfm/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..7d7cc2e70e656a230ce0f1649a70836bb39c6ba3
--- /dev/null
+++ b/src/third_party/musicfm/.gitignore
@@ -0,0 +1,10 @@
+# mac
+.DS_Store
+
+# cache
+*.pyc
+
+# data
+*.json
+*.pt
+
diff --git a/src/third_party/musicfm/LICENSE b/src/third_party/musicfm/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..9d8e4655469cb0acdae1e5ae01128c6499eaea09
--- /dev/null
+++ b/src/third_party/musicfm/LICENSE
@@ -0,0 +1,224 @@
+Dual Licensing Information
+-------------------------
+
+This software is dual-licensed under both the MIT License and the Apache License, Version 2.0.
+
+- The file `modules/flash_conformer.py` is distributed under the terms of the Apache License, Version 2.0.
+- All other files and modules in this software are distributed under the terms of the MIT License.
+
+### MIT License
+
+Copyright 2023 ByteDance Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+
+### Apache License, Version 2.0
+
+Copyright 2018- The Hugging Face team. All rights reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/src/third_party/musicfm/README.md b/src/third_party/musicfm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7e8b77522214da5fc54d339fa175ce4b6fb7184d
--- /dev/null
+++ b/src/third_party/musicfm/README.md
@@ -0,0 +1,173 @@
+# MusicFM 🤖
+[](https://opensource.org/licenses/MIT)
+[](https://www.apache.org/licenses/LICENSE-2.0.html)
+
+
+**A Foundation Model for Music Informatics**, ICASSP 2024 [[paper](https://arxiv.org/abs/2311.03318)]
+
+-- Minz Won, Yun-Ning Hung, and Duc Le
+
+
+## Quick start
+### Download models
+
+**MusicFM-FMA**
+
+- Pretrained using [FMA-large](https://github.com/mdeff/fma) data
+
+```
+wget -P YOUR_HOME_PATH/musicfm/data/ https://huggingface.co/minzwon/MusicFM/resolve/main/fma_stats.json
+wget -P YOUR_HOME_PATH/musicfm/data/ https://huggingface.co/minzwon/MusicFM/resolve/main/pretrained_fma.pt
+```
+⚠️ The model checkpoint prior to Feb 13, 2024, was incorrect. Please ensure to re-download these files if you've been using previous versions.
+
+
+**MusicFM-MSD**
+
+- Pretrained with the entire [Million Song Dataset](http://millionsongdataset.com/)
+- This version performs better than the FMA version
+- This version is not introduced in the paper
+
+```
+wget -P YOUR_HOME_PATH/musicfm/data/ https://huggingface.co/minzwon/MusicFM/resolve/main/msd_stats.json
+wget -P YOUR_HOME_PATH/musicfm/data/ https://huggingface.co/minzwon/MusicFM/resolve/main/pretrained_msd.pt
+```
+
+### Get embeddings
+```
+HOME_PATH = "/home/dev" # path where you cloned musicfm
+
+import os
+import sys
+import torch
+
+sys.path.append(HOME_PATH)
+from musicfm.model.musicfm_25hz import MusicFM25Hz
+
+# dummy audio (30 seconds, 24kHz)
+wav = (torch.rand(4, 24000 * 30) - 0.5) * 2
+
+# load MusicFM
+musicfm = MusicFM25Hz(
+ is_flash=False,
+ stat_path=os.path.join(HOME_PATH, "musicfm", "data", "msd_stats.json"),
+ model_path=os.path.join(HOME_PATH, "musicfm", "data", "pretrained_msd.pt"),
+)
+
+# to GPUs
+wav = wav.cuda()
+musicfm = musicfm.cuda()
+
+# get embeddings
+musicfm.eval()
+emb = musicfm.get_latent(wav, layer_ix=7)
+```
+
+### Mixed precision and Flash attention
+Suffering from memory issues? [Mixed precision](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html) and [Flash attention](https://arxiv.org/abs/2205.14135) will be good friends of yours!
+
+```
+# dummy audio (30 seconds, 24kHz)
+wav = (torch.rand(4, 24000 * 30) - 0.5) * 2
+
+# load MusicFM
+musicfm = MusicFM25Hz(is_flash=True)
+
+# to GPUs
+wav = wav.cuda().half()
+musicfm = musicfm.cuda().half()
+
+# get embeddings
+musicfm.eval()
+emb = musicfm.get_latent(wav, layer_ix=7)
+```
+
+However, I highly recommend using `float32` for better performance in specific downstream tasks, such as beat tracking.
+
+### Usage in downstream tasks
+The pretrained model operates at a 25Hz frame rate, but our downstream tasks demand varying temporal resolutions. To address this, we either summarize the sequence through global average pooling or adjust the temporal resolution using adaptive average pooling.
+
+```
+from torch import nn
+
+# Sequence-level representation
+seq_emb = emb.mean(-1) # (batch, time, channel) -> (batch, channel)
+
+# Frame-level representation
+"""
+ n_frame = desired_temporal_resolution * sequence_length_in_sec
+ 300 frames = 10Hz * 30s in this example
+ As a result, the sequence length becomes from 750 (25Hz * 30s) to 300
+"""
+n_frame = 300
+token_emb = nn.AdaptiveAvgPool1d(n_frame)(emb) # (batch, time, channel) -> (batch, time', channel)
+```
+We share the details of our downstream evaluation as follows. The selection of input lengths and temporal resolutions is based on our prior experience with each task.
+
+| | Beat | Chord | Structure | Key | Tagging |
+| :--------: | :--------: | :--------: | :--------: | :--------: | :--------: |
+| Input length | 6s | 12s | 24s | 12s | 29.1s |
+| Temporal resolution | 50Hz | 16Hz | 8Hz | 0.5Hz | - |
+| n_frame | 300 | 192 | 192 | 6 | 1 |
+
+### Fine-tuning
+You can expect better performance in downstream tasks by fine-tuning the foundation model. In this scenario, employ `musicfm.train()` and extract the final embeddings by setting `layer_ix=12`. However, when optimizing the model with the same learning rate, there's a risk of [catastrophic forgetting](https://en.wikipedia.org/wiki/Catastrophic_interference). To mitigate this issue, we utilized a learning rate of 1e-5 for the foundation model and 1e-4 for the probing layers.
+
+
+
+## Results
+
+
+
+\* FM1 is pretrained [MERT](https://arxiv.org/abs/2306.00107).
+
+\*\*FM8 mirrors the [BEST-RQ](https://arxiv.org/abs/2202.01855) but with the distinction that it was trained using music data.
+
+
+- Random tokenization generalizes well to music data.
+
+- Frame-level classification offers a more comprehensive understanding of foundation models. While FM4 excels in music tagging, its performance in structural analysis is subpar.
+
+- Input length used during training is critical for capturing
+long-term contexts. Check 5s models (FM1, FM2, and FM4) and a 30s model (FM5) in downbeat tracking and structure analysis.
+
+- Temporal resolution has less impact in our experimental setup. See FM5, FM6, and FM7.
+
+- Model architecture makes a significant difference. Conformer (FM5) consistently outperformed BERT encoder (FM3) for across all downstream tasks.
+
+- The influence of model size was relatively minimal (FM7 and FM8). However, we observed that FM8's performance continued to improve, which is typically indicative of underfitting. All models were trained for two weeks to ensure a fair comparison.
+
+- Data is undeniably crucial, as in any data-driven approach. Please compare FM7 and FM9.
+
+- Fine-tuning the foundation model further enhances downstream performance. However, we did observe a performance
+drop in the tagging task, primarily attributed to overfitting.
+
+## Masked token modeling
+
+
+MusicFM follows the training scheme of [BEST-RQ](https://arxiv.org/abs/2202.01855). Input audio is masked with noise, and the model predicts the masked representation. Target tokens are generated by random projection and a random codebook. Both the projection layer and codebook are **randomly initialized** and remain **non-trainable**. Isn't it fascinating?
+
+Note that input normalization is exceptionally crucial, considering the usage of random projection. You can check the details [here](https://github.com/minzwon/musicfm/blob/d5d0f313add9f3c32c41f95521760b1a136809ed/model/musicfm_25hz.py#L148).
+
+## Limitations
+- Self-supervised foundation models in music, such as [JukeMIR](https://arxiv.org/abs/2107.05677), [MERT](https://arxiv.org/abs/2306.00107), and [MusicFM](https://arxiv.org/abs/2311.03318), consistently report relatively low performance in key detection. While fine-tuning the model can help bridge the performance gap, the foundation model itself does not appear to learn musical keys inherently. Further investigation is required to develop more advanced music foundation models.
+
+- We share our model trained with the [FMA Dataset](https://github.com/mdeff/fma), which comprises 8k hours of Creative Common-licensed audio. While using a larger dataset (160k hours) can enhance performance, we've chosen to release the model trained on FMA to avoid potential licensing complications.
+
+- Fine-tuned models for downstream tasks are not made publicly available as they are primarily used for evaluation purposes. It is expected that carefully designed backends beyond simple probing layers will improve downstream performance. I look forward to the contributions of other researchers with more expertise in each specific task.
+
+- The downstream evaluation pipeline is not provided in this repository. Nonetheless, I believe creating a comprehensive evaluation pipeline is essential to expedite progress in music informatics research. I'm very open to discussing it together.
+
+
+## Acknowledgement
+We acknowledge and extend our sincere gratitude to Ju-Chiang Wang for his valuable contributions to data refinement and providing a crucial codebase for our downstream evaluation.
+
+## Citation
+```
+@article{won2023musicfm,
+ title={A Foundation Model for Music Informatics},
+ author = {Won, Minz and Hung, Yun-Ning and Le, Duc},
+ journal={arXiv preprint arXiv:2311.03318},
+ year={2023}
+}
+```
diff --git a/src/third_party/musicfm/data/.gitkeep b/src/third_party/musicfm/data/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/third_party/musicfm/figs/Fig1.png b/src/third_party/musicfm/figs/Fig1.png
new file mode 100644
index 0000000000000000000000000000000000000000..b9d1870e87f0486e62b8e7c1865622b07cd078fc
--- /dev/null
+++ b/src/third_party/musicfm/figs/Fig1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bbbb7a435402555125e996c747a619585906bc2cb7911afa5521ac35af1201e3
+size 396320
diff --git a/src/third_party/musicfm/figs/Table1.png b/src/third_party/musicfm/figs/Table1.png
new file mode 100644
index 0000000000000000000000000000000000000000..a9c349a52a4b294a8e1f1a4799731871068db805
--- /dev/null
+++ b/src/third_party/musicfm/figs/Table1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:773da09077da92f3e41fb9a53aff4efc559dcfc07d2e65b142cf23af2de512d7
+size 806922
diff --git a/src/third_party/musicfm/model/__init__.py b/src/third_party/musicfm/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438
--- /dev/null
+++ b/src/third_party/musicfm/model/__init__.py
@@ -0,0 +1,2 @@
+
+
diff --git a/src/third_party/musicfm/model/musicfm_25hz.py b/src/third_party/musicfm/model/musicfm_25hz.py
new file mode 100644
index 0000000000000000000000000000000000000000..e567c562c44e739beb42ebff9234303652670f75
--- /dev/null
+++ b/src/third_party/musicfm/model/musicfm_25hz.py
@@ -0,0 +1,252 @@
+# MIT License
+#
+# Copyright 2023 ByteDance Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
+# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+
+import json
+import random
+import torch
+from torch import nn
+from einops import rearrange
+
+from musicfm.modules.random_quantizer import RandomProjectionQuantizer
+from musicfm.modules.features import MelSTFT
+from musicfm.modules.conv import Conv2dSubsampling
+
+
+class MusicFM25Hz(nn.Module):
+ """
+ MusicFM
+
+ Input: 128-band mel spectrogram
+ Frontend: 2-layer Residual convolution
+ Backend: 12-layer Conformer
+ Quantizer: a codebook for mel spectrogram
+ """
+
+ def __init__(
+ self,
+ num_codebooks=1,
+ codebook_dim=16,
+ codebook_size=4096,
+ features=["melspec_2048"],
+ hop_length=240,
+ n_mels=128,
+ conv_dim=512,
+ encoder_dim=1024,
+ encoder_depth=12,
+ mask_hop=0.4,
+ mask_prob=0.6,
+ is_flash=False,
+ stat_path="./data/fma_stats.json",
+ model_path="./data/pretrained_fma.pt",
+ ):
+ super(MusicFM25Hz, self).__init__()
+
+ # global variables
+ self.hop_length = hop_length
+ self.mask_hop = mask_hop
+ self.mask_prob = mask_prob
+ self.num_codebooks = num_codebooks
+ self.codebook_size = codebook_size
+ self.features = features
+
+ # load feature mean / std stats
+ with open(stat_path, "r") as f:
+ self.stat = json.load(f)
+
+ # feature extractor
+ self.preprocessor_melspec_2048 = MelSTFT(
+ n_fft=2048, hop_length=hop_length, is_db=True
+ )
+
+ # random quantizer
+ seed = 142
+ for feature in self.features:
+ for i in range(num_codebooks):
+ setattr(
+ self,
+ f"quantizer_{feature}_{i}",
+ RandomProjectionQuantizer(
+ n_mels * 4, codebook_dim, codebook_size, seed=seed + i
+ ),
+ )
+
+ # two residual convolution layers + one projection layer
+ self.conv = Conv2dSubsampling(
+ 1, conv_dim, encoder_dim, strides=[2, 2], n_bands=n_mels
+ )
+
+ # Conformer
+ if is_flash:
+ from modules.flash_conformer import (
+ Wav2Vec2ConformerEncoder,
+ Wav2Vec2ConformerConfig,
+ )
+ else:
+ from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
+ Wav2Vec2ConformerEncoder,
+ Wav2Vec2ConformerConfig,
+ )
+ config = Wav2Vec2ConformerConfig.from_pretrained(
+ "facebook/wav2vec2-conformer-rope-large-960h-ft"
+ )
+ config.num_hidden_layers = encoder_depth
+ config.hidden_size = encoder_dim
+
+ self.conformer = Wav2Vec2ConformerEncoder(config)
+
+ # projection
+ self.linear = nn.Linear(encoder_dim, codebook_size)
+
+ # loss function
+ self.loss = nn.CrossEntropyLoss()
+
+ # cls token (used for sequence classification)
+ random.seed(seed)
+ self.cls_token = nn.Parameter(torch.randn(encoder_dim))
+
+ # load model
+ if model_path:
+ S = torch.load(model_path)["state_dict"]
+ SS = {k[6:]: v for k, v in S.items()}
+ self.load_state_dict(SS, strict=True)
+
+ def masking(self, x):
+ """random masking of 400ms with given probability"""
+ mx = x.clone()
+ b, t = mx.shape
+ len_masking_raw = int(24000 * self.mask_hop)
+ len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)
+
+ # get random mask indices
+ start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
+ time_domain_masked_indices = torch.nonzero(
+ start_indices.repeat_interleave(len_masking_raw, dim=1)
+ )
+ token_domain_masked_indices = torch.nonzero(
+ start_indices.repeat_interleave(len_masking_token, dim=1)
+ )
+
+ # mask with random values
+ masking_noise = (
+ torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
+ ) # 0 mean 0.1 std
+ mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)
+
+ return mx, token_domain_masked_indices
+
+ @torch.no_grad()
+ def preprocessing(self, x, features):
+ """extract classic audio features"""
+ # check precision
+ if x.dtype == torch.float16:
+ precision = 16
+ else:
+ precision = 32
+
+ out = {}
+ for key in features:
+ layer = getattr(self, "preprocessor_%s" % key)
+ out[key] = layer.float()(x.float())[..., :-1]
+ if precision == 16:
+ out[key] = out[key].half()
+ return out
+
+ def encoder(self, x):
+ """2-layer conv + w2v-conformer"""
+ x = self.conv(x)
+ out = self.conformer(x, output_hidden_states=True)
+ hidden_emb = out["hidden_states"]
+ last_emb = out["last_hidden_state"]
+ logits = self.linear(last_emb)
+ logits = {
+ key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size]
+ for i, key in enumerate(self.features)
+ }
+ return logits, hidden_emb
+
+ @torch.no_grad()
+ def normalize(self, x):
+ """normalize the input audio to have zero mean unit variance"""
+ for key in x.keys():
+ x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
+ return x
+
+ @torch.no_grad()
+ def rearrange(self, x):
+ """rearrange the batch to flatten every 4 steps"""
+ for key in x.keys():
+ if key == "chromagram":
+ x[key] = rearrange(x[key], "b f t -> b t f")
+ else:
+ x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4)
+ return x
+
+ @torch.no_grad()
+ def tokenize(self, x):
+ out = {}
+ for key in x.keys():
+ layer = getattr(self, "quantizer_%s" % key)
+ out[key] = layer(x[key])
+ return out
+
+ def get_targets(self, x):
+ x = self.preprocessing(x, features=self.features)
+ x = self.normalize(x)
+ x = self.rearrange(x)
+ target_tokens = self.tokenize(x)
+ return target_tokens
+
+ def get_predictions(self, x):
+ # preprocessing
+ x = self.preprocessing(x, features=["melspec_2048"])
+ x = self.normalize(x)
+
+ # encoding
+ logits, hidden_emb = self.encoder(x["melspec_2048"])
+
+ return logits, hidden_emb
+
+ def get_latent(self, x, layer_ix=12):
+ _, hidden_states = self.get_predictions(x)
+ emb = hidden_states[layer_ix]
+ return emb
+
+ def get_loss(self, logits, target_tokens, masked_indices):
+ losses = {}
+ accuracies = {}
+ for key in logits.keys():
+ masked_logits = logits[key][tuple(masked_indices.t())]
+ masked_tokens = target_tokens[key][tuple(masked_indices.t())]
+ losses[key] = self.loss(masked_logits, masked_tokens)
+ accuracies[key] = (
+ torch.sum(masked_logits.argmax(-1) == masked_tokens)
+ / masked_tokens.numel()
+ )
+ return losses, accuracies
+
+ def forward(self, x):
+ # get target feature tokens
+ target_tokens = self.get_targets(x)
+
+ # masking
+ x, masked_indices = self.masking(x)
+
+ # forward
+ logits, hidden_emb = self.get_predictions(x)
+
+ # get loss
+ losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)
+
+ return logits, hidden_emb, losses, accuracies
diff --git a/src/third_party/musicfm/modules/__init__.py b/src/third_party/musicfm/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438
--- /dev/null
+++ b/src/third_party/musicfm/modules/__init__.py
@@ -0,0 +1,2 @@
+
+
diff --git a/src/third_party/musicfm/modules/conv.py b/src/third_party/musicfm/modules/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cc1a8f16cd103d09c86ace5b7c48b0583134e02
--- /dev/null
+++ b/src/third_party/musicfm/modules/conv.py
@@ -0,0 +1,82 @@
+# MIT License
+#
+# Copyright 2023 ByteDance Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
+# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+
+from torch import nn
+from einops import rearrange
+
+
+class Res2dModule(nn.Module):
+ def __init__(self, idim, odim, stride=(2, 2)):
+ super(Res2dModule, self).__init__()
+ self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
+ self.bn1 = nn.BatchNorm2d(odim)
+ self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
+ self.bn2 = nn.BatchNorm2d(odim)
+ self.relu = nn.ReLU()
+
+ # residual
+ self.diff = False
+ if (idim != odim) or (stride[0] > 1):
+ self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
+ self.bn3 = nn.BatchNorm2d(odim)
+ self.diff = True
+
+ def forward(self, x):
+ out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
+ if self.diff:
+ x = self.bn3(self.conv3(x))
+ out = x + out
+ out = self.relu(out)
+ return out
+
+
+class Conv2dSubsampling(nn.Module):
+ """Convolutional 2D subsampling (to 1/4 length).
+
+ Args:
+ idim (int): Input dimension.
+ hdim (int): Hidden dimension.
+ odim (int): Output dimension.
+ strides (list): Sizes of strides.
+ n_bands (int): Number of frequency bands.
+ """
+
+ def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64):
+ """Construct an Conv2dSubsampling object."""
+ super(Conv2dSubsampling, self).__init__()
+
+ self.conv = nn.Sequential(
+ Res2dModule(idim, hdim, (2, strides[0])),
+ Res2dModule(hdim, hdim, (2, strides[1])),
+ )
+ self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)
+
+ def forward(self, x):
+ """Subsample x.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, idim, time).
+
+ Returns:
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
+ where time' = time // 4.
+ """
+
+ if x.dim() == 3:
+ x = x.unsqueeze(1) # (b, c, f, t)
+ x = self.conv(x)
+ x = rearrange(x, "b c f t -> b t (c f)")
+ x = self.linear(x)
+ return x
diff --git a/src/third_party/musicfm/modules/features.py b/src/third_party/musicfm/modules/features.py
new file mode 100644
index 0000000000000000000000000000000000000000..c38f525e856eeefffeb2580c7bf61058ed228e0e
--- /dev/null
+++ b/src/third_party/musicfm/modules/features.py
@@ -0,0 +1,45 @@
+# MIT License
+#
+# Copyright 2023 ByteDance Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
+# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+
+import torchaudio
+from torch import nn
+
+
+class MelSTFT(nn.Module):
+ def __init__(
+ self,
+ sample_rate=24000,
+ n_fft=2048,
+ hop_length=240,
+ n_mels=128,
+ is_db=False,
+ ):
+ super(MelSTFT, self).__init__()
+
+ # spectrogram
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
+ sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
+ )
+
+ # amplitude to decibel
+ self.is_db = is_db
+ if is_db:
+ self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
+
+ def forward(self, waveform):
+ if self.is_db:
+ return self.amplitude_to_db(self.mel_stft(waveform))
+ else:
+ return self.mel_stft(waveform)
diff --git a/src/third_party/musicfm/modules/flash_conformer.py b/src/third_party/musicfm/modules/flash_conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..89012c476c27973748e2cee914dae3b400348465
--- /dev/null
+++ b/src/third_party/musicfm/modules/flash_conformer.py
@@ -0,0 +1,2114 @@
+# coding=utf-8
+# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Wav2Vec2-Conformer model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from torch.nn import functional as F
+
+from transformers.activations import ACT2FN
+from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.modeling_outputs import (
+ BaseModelOutput,
+ CausalLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+ Wav2Vec2BaseModelOutput,
+ XVectorOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_HIDDEN_STATES_START_POSITION = 2
+
+# General docstring
+_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 64.21
+
+
+WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "facebook/wav2vec2-conformer-rel-pos-large",
+ # See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer
+]
+
+
+@dataclass
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput):
+ """
+ Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions.
+
+ Args:
+ loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
+ paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
+ projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
+ projected quantized states.
+ projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
+ Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
+ target vectors for contrastive loss.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
+ diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
+ The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) .
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ projected_states: torch.FloatTensor = None
+ projected_quantized_states: torch.FloatTensor = None
+ codevector_perplexity: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ contrastive_loss: Optional[torch.FloatTensor] = None
+ diversity_loss: Optional[torch.FloatTensor] = None
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+ f" and `sequence_length`: {sequence_length}`"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.sum(-1).detach().tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+ # this case can only happen if `input_length` is strictly smaller then
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices
+def _sample_negative_indices(
+ features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
+):
+ """
+ Sample `num_negatives` vectors from feature vectors.
+ """
+ batch_size, sequence_length = features_shape
+
+ # generate indices of the positive vectors themselves, repeat them `num_negatives` times
+ sequence_length_range = np.arange(sequence_length)
+
+ # get `num_negatives` random vector indices from the same utterance
+ sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
+
+ mask_time_indices = (
+ mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
+ )
+
+ for batch_idx in range(batch_size):
+ high = mask_time_indices[batch_idx].sum() - 1
+ mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
+
+ feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
+ sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
+ # avoid sampling the same positive vector, but keep the distribution uniform
+ sampled_indices[sampled_indices >= feature_indices] += 1
+
+ # remap to actual indices
+ sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
+
+ # correct for batch size
+ sampled_negative_indices[batch_idx] += batch_idx * sequence_length
+
+ return sampled_negative_indices
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerLayerNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+
+ hidden_states = hidden_states.transpose(-2, -1)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = hidden_states.transpose(-2, -1)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerGroupNormConvLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+ self.out_conv_dim = config.conv_dim[layer_id]
+
+ self.conv = nn.Conv1d(
+ self.in_conv_dim,
+ self.out_conv_dim,
+ kernel_size=config.conv_kernel[layer_id],
+ stride=config.conv_stride[layer_id],
+ bias=config.conv_bias,
+ )
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.num_conv_pos_embeddings,
+ padding=config.num_conv_pos_embeddings // 2,
+ groups=config.num_conv_pos_embedding_groups,
+ )
+
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+ deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
+ deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
+ else:
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+
+ self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.transpose(1, 2)
+
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.padding(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
+ """Rotary positional embedding
+ Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ dim = config.hidden_size // config.num_attention_heads
+ base = config.rotary_embedding_base
+
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+ self.cached_sequence_length = None
+ self.cached_rotary_positional_embedding = None
+
+ def forward(self, hidden_states):
+ sequence_length = hidden_states.shape[1]
+
+ if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None:
+ return self.cached_rotary_positional_embedding
+
+ self.cached_sequence_length = sequence_length
+ time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
+ freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
+ embeddings = torch.cat((freqs, freqs), dim=-1)
+
+ cos_embeddings = embeddings.cos()[:, None, None, :]
+ sin_embeddings = embeddings.sin()[:, None, None, :]
+ self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings])
+ return self.cached_rotary_positional_embedding
+
+
+class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
+ """Relative positional encoding module."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.max_len = config.max_source_positions
+ self.d_model = config.hidden_size
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))
+
+ def extend_pe(self, x):
+ # Reset the positional encodings
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ # Suppose `i` is the position of query vector and `j` is the
+ # position of key vector. We use positive relative positions when keys
+ # are to the left (i>j) and negative relative positions otherwise (iWav2Vec2Conformer
+class Wav2Vec2ConformerSamePadLayer(nn.Module):
+ def __init__(self, num_conv_pos_embeddings):
+ super().__init__()
+ self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
+
+ def forward(self, hidden_states):
+ if self.num_pad_remove > 0:
+ hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeatureEncoder(nn.Module):
+ """Construct the features from raw audio waveform"""
+
+ def __init__(self, config):
+ super().__init__()
+
+ if config.feat_extract_norm == "group":
+ conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [
+ Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1)
+ for i in range(config.num_feat_extract_layers - 1)
+ ]
+ elif config.feat_extract_norm == "layer":
+ conv_layers = [
+ Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
+ ]
+ else:
+ raise ValueError(
+ f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
+ )
+ self.conv_layers = nn.ModuleList(conv_layers)
+ self.gradient_checkpointing = False
+ self._requires_grad = True
+
+ def _freeze_parameters(self):
+ for param in self.parameters():
+ param.requires_grad = False
+ self._requires_grad = False
+
+ def forward(self, input_values):
+ hidden_states = input_values[:, None]
+
+ # make sure hidden_states require grad for gradient_checkpointing
+ if self._requires_grad and self.training:
+ hidden_states.requires_grad = True
+
+ for conv_layer in self.conv_layers:
+ if self._requires_grad and self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(conv_layer),
+ hidden_states,
+ )
+ else:
+ hidden_states = conv_layer(hidden_states)
+
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeatureProjection(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+ self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+ self.dropout = nn.Dropout(config.feat_proj_dropout)
+
+ def forward(self, hidden_states):
+ # non-projected hidden states are needed for quantization
+ norm_hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.projection(norm_hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states, norm_hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerFeedForward(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.intermediate_dropout = nn.Dropout(config.activation_dropout)
+
+ self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.output_dropout = nn.Dropout(config.hidden_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.intermediate_dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ hidden_states = self.intermediate_dropout(hidden_states)
+
+ hidden_states = self.output_dense(hidden_states)
+ hidden_states = self.output_dropout(hidden_states)
+ return hidden_states
+
+
+class Wav2Vec2ConformerConvolutionModule(nn.Module):
+ """Convolution block used in the conformer block"""
+
+ def __init__(self, config):
+ super().__init__()
+ if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
+ raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
+ self.layer_norm = nn.LayerNorm(config.hidden_size)
+ self.pointwise_conv1 = torch.nn.Conv1d(
+ config.hidden_size,
+ 2 * config.hidden_size,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.glu = torch.nn.GLU(dim=1)
+ self.depthwise_conv = torch.nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ config.conv_depthwise_kernel_size,
+ stride=1,
+ padding=(config.conv_depthwise_kernel_size - 1) // 2,
+ groups=config.hidden_size,
+ bias=False,
+ )
+ self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
+ self.activation = ACT2FN[config.hidden_act]
+ self.pointwise_conv2 = torch.nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ )
+ self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
+
+ def forward(self, hidden_states):
+ hidden_states = self.layer_norm(hidden_states)
+ # exchange the temporal dimension and the feature dimension
+ hidden_states = hidden_states.transpose(1, 2)
+
+ # GLU mechanism
+ # => (batch, 2*channel, dim)
+ hidden_states = self.pointwise_conv1(hidden_states)
+ # => (batch, channel, dim)
+ hidden_states = self.glu(hidden_states)
+
+ # 1D Depthwise Conv
+ hidden_states = self.depthwise_conv(hidden_states)
+ hidden_states = self.batch_norm(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ hidden_states = self.pointwise_conv2(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+class Wav2Vec2ConformerSelfAttention(nn.Module):
+ """Construct an Wav2Vec2ConformerSelfAttention object.
+ Can be enhanced with rotary or relative position embeddings.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.head_size = config.hidden_size // config.num_attention_heads
+ self.num_heads = config.num_attention_heads
+ self.position_embeddings_type = config.position_embeddings_type
+
+ self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
+ self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
+
+ self.dropout = nn.Dropout(p=config.attention_dropout)
+ self.dropout_p = config.attention_dropout
+
+ self.is_causal = config.is_causal
+
+ if self.position_embeddings_type == "relative":
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+ self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ relative_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # self-attention mechanism
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ # make sure query/key states can be != value states
+ query_key_states = hidden_states
+ value_states = hidden_states
+
+ if self.position_embeddings_type == "rotary":
+ if relative_position_embeddings is None:
+ raise ValueError(
+ "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
+ )
+ query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)
+
+ # project query_key_states and value_states
+ query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+ key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
+ value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
+
+ # => (batch, head, time1, d_k)
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=self.dropout_p, is_causal=self.is_causal)
+ probs = None
+
+ # # apply attention_mask if necessary
+ # if attention_mask is not None:
+ # scores = scores + attention_mask
+
+ # # => (batch, head, time1, time2)
+ # probs = torch.softmax(scores, dim=-1)
+ # probs = self.dropout(probs)
+
+ # # => (batch, head, time1, d_k)
+ # hidden_states = torch.matmul(probs, value)
+
+ # => (batch, time1, hidden_size)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
+ hidden_states = self.linear_out(hidden_states)
+
+ return hidden_states, probs
+
+ def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
+
+ cos = relative_position_embeddings[0, :sequence_length, ...]
+ sin = relative_position_embeddings[1, :sequence_length, ...]
+
+ # rotate hidden_states with rotary embeddings
+ hidden_states = hidden_states.transpose(0, 1)
+ rotated_states_begin = hidden_states[..., : self.head_size // 2]
+ rotated_states_end = hidden_states[..., self.head_size // 2 :]
+ rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
+ hidden_states = (hidden_states * cos) + (rotated_states * sin)
+ hidden_states = hidden_states.transpose(0, 1)
+
+ hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
+
+ return hidden_states
+
+ def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
+ # 1. project positional embeddings
+ # => (batch, head, 2*time1-1, d_k)
+ proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
+ proj_relative_position_embeddings = proj_relative_position_embeddings.view(
+ relative_position_embeddings.size(0), -1, self.num_heads, self.head_size
+ )
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)
+ proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)
+
+ # 2. Add bias to query
+ # => (batch, head, time1, d_k)
+ query = query.transpose(1, 2)
+ q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+ # 3. attention score: first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # => (batch, head, time1, time2)
+ scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+ # 4. then compute matrix b and matrix d
+ # => (batch, head, time1, 2*time1-1)
+ scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
+
+ # 5. shift matrix b and matrix d
+ zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
+ scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
+ scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
+ scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
+ scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
+ scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
+
+ # 6. sum matrices
+ # => (batch, head, time1, time2)
+ scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
+
+ return scores
+
+
+class Wav2Vec2ConformerEncoderLayer(nn.Module):
+ """Conformer block based on https://arxiv.org/abs/2005.08100."""
+
+ def __init__(self, config):
+ super().__init__()
+ embed_dim = config.hidden_size
+ dropout = config.attention_dropout
+
+ # Feed-forward 1
+ self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
+ self.ffn1 = Wav2Vec2ConformerFeedForward(config)
+
+ # Self-Attention
+ self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
+ self.self_attn_dropout = torch.nn.Dropout(dropout)
+ self.self_attn = Wav2Vec2ConformerSelfAttention(config)
+
+ # Conformer Convolution
+ self.conv_module = Wav2Vec2ConformerConvolutionModule(config)
+
+ # Feed-forward 2
+ self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
+ self.ffn2 = Wav2Vec2ConformerFeedForward(config)
+ self.final_layer_norm = nn.LayerNorm(embed_dim)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask: Optional[torch.Tensor] = None,
+ relative_position_embeddings: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ):
+ hidden_states = hidden_states
+
+ # 1. Feed-Forward 1 layer
+ residual = hidden_states
+ hidden_states = self.ffn1_layer_norm(hidden_states)
+ hidden_states = self.ffn1(hidden_states)
+ hidden_states = hidden_states * 0.5 + residual
+ residual = hidden_states
+
+ # 2. Self-Attention layer
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weigts = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ )
+ hidden_states = self.self_attn_dropout(hidden_states)
+ hidden_states = hidden_states + residual
+
+ # 3. Convolutional Layer
+ residual = hidden_states
+ hidden_states = self.conv_module(hidden_states)
+ hidden_states = residual + hidden_states
+
+ # 4. Feed-Forward 2 Layer
+ residual = hidden_states
+ hidden_states = self.ffn2_layer_norm(hidden_states)
+ hidden_states = self.ffn2(hidden_states)
+ hidden_states = hidden_states * 0.5 + residual
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ return hidden_states, attn_weigts
+
+
+class Wav2Vec2ConformerEncoder(nn.Module):
+ def __init__(self, config, is_causal=False):
+ super().__init__()
+ config.is_causal = is_causal
+ self.config = config
+
+ if config.position_embeddings_type == "relative":
+ self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config)
+ elif config.position_embeddings_type == "rotary":
+ self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config)
+ else:
+ self.embed_positions = None
+
+ self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if attention_mask is not None:
+ # make sure padded tokens output 0
+ hidden_states[~attention_mask] = 0.0
+
+ # extend attention_mask
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
+ attention_mask = attention_mask.expand(
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+ )
+
+ hidden_states = self.dropout(hidden_states)
+
+ if self.embed_positions is not None:
+ relative_position_embeddings = self.embed_positions(hidden_states)
+ else:
+ relative_position_embeddings = None
+
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = np.random.uniform(0, 1)
+
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
+ # under deepspeed zero3 all gpus must run in sync
+ if self.gradient_checkpointing and self.training:
+ # create gradient checkpointing function
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer),
+ hidden_states,
+ attention_mask,
+ relative_position_embeddings,
+ )
+ else:
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ )
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module):
+ """
+ Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
+ GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.num_groups = config.num_codevector_groups
+ self.num_vars = config.num_codevectors_per_group
+
+ if config.codevector_dim % self.num_groups != 0:
+ raise ValueError(
+ f"`config.codevector_dim {config.codevector_dim} must be divisible "
+ f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
+ )
+
+ # storage for codebook variables (codewords)
+ self.codevectors = nn.Parameter(
+ torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
+ )
+ self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
+
+ # can be decayed for training
+ self.temperature = 2
+
+ @staticmethod
+ def _compute_perplexity(probs, mask=None):
+ if mask is not None:
+ mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
+ probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
+ marginal_probs = probs.sum(dim=0) / mask.sum()
+ else:
+ marginal_probs = probs.mean(dim=0)
+
+ perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
+ return perplexity
+
+ def forward(self, hidden_states, mask_time_indices=None):
+ batch_size, sequence_length, hidden_size = hidden_states.shape
+
+ # project to codevector dim
+ hidden_states = self.weight_proj(hidden_states)
+ hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
+
+ if self.training:
+ # sample code vector probs via gumbel in differentiateable way
+ codevector_probs = nn.functional.gumbel_softmax(
+ hidden_states.float(), tau=self.temperature, hard=True
+ ).type_as(hidden_states)
+
+ # compute perplexity
+ codevector_soft_dist = torch.softmax(
+ hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
+ )
+ perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
+ else:
+ # take argmax in non-differentiable way
+ # comptute hard codevector distribution (one hot)
+ codevector_idx = hidden_states.argmax(dim=-1)
+ codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
+ -1, codevector_idx.view(-1, 1), 1.0
+ )
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
+
+ perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
+
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
+ # use probs to retrieve codevectors
+ codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
+ codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
+ codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
+
+ return codevectors, perplexity
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerAdapter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ # feature dim might need to be down-projected
+ if config.output_hidden_size != config.hidden_size:
+ self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
+ self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
+ else:
+ self.proj = self.proj_layer_norm = None
+
+ self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers))
+ self.layerdrop = config.layerdrop
+
+ def forward(self, hidden_states):
+ # down project hidden_states if necessary
+ if self.proj is not None and self.proj_layer_norm is not None:
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.proj_layer_norm(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+
+ for layer in self.layers:
+ layerdrop_prob = np.random.random()
+ if not self.training or (layerdrop_prob > self.layerdrop):
+ hidden_states = layer(hidden_states)
+
+ hidden_states = hidden_states.transpose(1, 2)
+ return hidden_states
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer
+class Wav2Vec2ConformerAdapterLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.output_hidden_size,
+ 2 * config.output_hidden_size,
+ config.adapter_kernel_size,
+ stride=config.adapter_stride,
+ padding=1,
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = nn.functional.glu(hidden_states, dim=1)
+
+ return hidden_states
+
+
+class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = Wav2Vec2ConformerConfig
+ base_model_prefix = "wav2vec2_conformer"
+ main_input_name = "input_values"
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
+ if isinstance(module, Wav2Vec2ConformerForPreTraining):
+ module.project_hid.reset_parameters()
+ module.project_q.reset_parameters()
+ module.project_hid._is_hf_initialized = True
+ module.project_q._is_hf_initialized = True
+ # gumbel softmax requires special init
+ elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer):
+ module.weight_proj.weight.data.normal_(mean=0.0, std=1)
+ module.weight_proj.bias.data.zero_()
+ nn.init.uniform_(module.codevectors)
+ elif isinstance(module, Wav2Vec2ConformerSelfAttention):
+ if hasattr(module, "pos_bias_u"):
+ nn.init.xavier_uniform_(module.pos_bias_u)
+ if hasattr(module, "pos_bias_v"):
+ nn.init.xavier_uniform_(module.pos_bias_v)
+ elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding):
+ nn.init.normal_(
+ module.conv.weight,
+ mean=0,
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+ )
+ nn.init.constant_(module.conv.bias, 0)
+ elif isinstance(module, Wav2Vec2ConformerFeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+
+ def _get_feat_extract_output_lengths(
+ self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
+ ):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+ if add_adapter:
+ for _ in range(self.config.num_adapter_layers):
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(
+ self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
+ ):
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
+ # on inference mode.
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+ output_lengths = output_lengths.to(torch.long)
+
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ # these two operations makes sure that all values before the output lengths idxs are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
+ module.gradient_checkpointing = value
+
+
+WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
+ Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
+ Auli.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving etc.).
+
+ This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a
+ regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
+
+ Parameters:
+ config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+ soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+ 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+
+
+ `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
+ True`. For all models whose processor has `config.return_attention_mask == False`, such as
+ [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large),
+ `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For
+ such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware
+ that these models also yield slightly different results depending on whether `input_values` is padded or
+ not.
+
+
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.",
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
+ def __init__(self, config: Wav2Vec2ConformerConfig):
+ super().__init__(config)
+ self.config = config
+ self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config)
+ self.feature_projection = Wav2Vec2ConformerFeatureProjection(config)
+
+ # model only needs masking vector if mask prob is > 0.0
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
+
+ self.encoder = Wav2Vec2ConformerEncoder(config)
+
+ self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.feature_extractor._freeze_parameters()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
+ def _mask_hidden_states(
+ self,
+ hidden_states: torch.FloatTensor,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ """
+ Masks extracted features along time axis and/or along feature axis according to
+ [SpecAugment](https://arxiv.org/abs/1904.08779).
+ """
+
+ # `config.apply_spec_augment` can set masking to False
+ if not getattr(self.config, "apply_spec_augment", True):
+ return hidden_states
+
+ # generate indices & apply SpecAugment along time axis
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ if mask_time_indices is not None:
+ # apply SpecAugment along time axis with given mask_time_indices
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+ elif self.config.mask_time_prob > 0 and self.training:
+ mask_time_indices = _compute_mask_indices(
+ (batch_size, sequence_length),
+ mask_prob=self.config.mask_time_prob,
+ mask_length=self.config.mask_time_length,
+ attention_mask=attention_mask,
+ min_masks=self.config.mask_time_min_masks,
+ )
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+
+ if self.config.mask_feature_prob > 0 and self.training:
+ # generate indices & apply SpecAugment along feature axis
+ mask_feature_indices = _compute_mask_indices(
+ (batch_size, hidden_size),
+ mask_prob=self.config.mask_feature_prob,
+ mask_length=self.config.mask_feature_length,
+ min_masks=self.config.mask_feature_min_masks,
+ )
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
+ hidden_states[mask_feature_indices] = 0
+
+ return hidden_states
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=Wav2Vec2BaseModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(
+ extract_features.shape[1], attention_mask, add_adapter=False
+ )
+
+ hidden_states, extract_features = self.feature_projection(extract_features)
+ hidden_states = self._mask_hidden_states(
+ hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+ )
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if self.adapter is not None:
+ hidden_states = self.adapter(hidden_states)
+
+ if not return_dict:
+ return (hidden_states, extract_features) + encoder_outputs[1:]
+
+ return Wav2Vec2BaseModelOutput(
+ last_hidden_state=hidden_states,
+ extract_features=extract_features,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING
+)
+class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def __init__(self, config: Wav2Vec2ConformerConfig):
+ super().__init__(config)
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
+
+ self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config)
+
+ self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
+ self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature
+ def set_gumbel_temperature(self, temperature: int):
+ """
+ Set the Gumbel softmax temperature to a given value. Only necessary for training
+ """
+ self.quantizer.temperature = temperature
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ @staticmethod
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits
+ def compute_contrastive_logits(
+ target_features: torch.FloatTensor,
+ negative_features: torch.FloatTensor,
+ predicted_features: torch.FloatTensor,
+ temperature: int = 0.1,
+ ):
+ """
+ Compute logits for contrastive loss based using cosine similarity as the distance measure between
+ `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
+ """
+ target_features = torch.cat([target_features, negative_features], dim=0)
+
+ logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
+ target_features
+ )
+
+ # apply temperature
+ logits = logits / temperature
+ return logits
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.BoolTensor] = None,
+ sampled_negative_indices: Optional[torch.BoolTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]:
+ r"""
+ mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
+ masked extracted features in *config.proj_codevector_dim* space.
+ sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
+ Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
+ Required input for pre-training.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
+ >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
+ ... _compute_mask_indices,
+ ... _sample_negative_indices,
+ ... )
+ >>> from datasets import load_dataset
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+ >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
+
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
+
+ >>> # compute masked indices
+ >>> batch_size, raw_sequence_length = input_values.shape
+ >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
+ >>> mask_time_indices = _compute_mask_indices(
+ ... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
+ ... )
+ >>> sampled_negative_indices = _sample_negative_indices(
+ ... features_shape=(batch_size, sequence_length),
+ ... num_negatives=model.config.num_negatives,
+ ... mask_time_indices=mask_time_indices,
+ ... )
+ >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
+ >>> sampled_negative_indices = torch.tensor(
+ ... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
+ ... )
+
+ >>> with torch.no_grad():
+ ... outputs = model(input_values, mask_time_indices=mask_time_indices)
+
+ >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
+ >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
+
+ >>> # show that cosine similarity is much higher than random
+ >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
+ tensor(True)
+
+ >>> # for contrastive loss training model should be put into train mode
+ >>> model = model.train()
+ >>> loss = model(
+ ... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
+ ... ).loss
+ ```"""
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if mask_time_indices is not None:
+ mask_time_indices = mask_time_indices.to(torch.bool)
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ mask_time_indices=mask_time_indices,
+ return_dict=return_dict,
+ )
+
+ # 1. project all transformed features (including masked) to final vq dim
+ transformer_features = self.project_hid(outputs[0])
+
+ # 2. quantize all (unmasked) extracted features and project to final vq dim
+ extract_features = self.dropout_features(outputs[1])
+
+ if attention_mask is not None:
+ # compute reduced attention_mask correponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(
+ extract_features.shape[1], attention_mask, add_adapter=False
+ )
+
+ quantized_features, codevector_perplexity = self.quantizer(
+ extract_features, mask_time_indices=mask_time_indices
+ )
+ quantized_features = self.project_q(quantized_features)
+
+ loss = contrastive_loss = diversity_loss = None
+ if sampled_negative_indices is not None:
+ batch_size, sequence_length, hidden_size = quantized_features.shape
+
+ # for training, we sample negatives
+ # 3. sample K negatives (distractors) quantized states for contrastive loss
+ # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
+ # sample negative quantized vectors BTC => (BxT)C
+ negative_quantized_features = quantized_features.view(-1, hidden_size)[
+ sampled_negative_indices.long().view(-1)
+ ]
+ negative_quantized_features = negative_quantized_features.view(
+ batch_size, sequence_length, -1, hidden_size
+ ).permute(2, 0, 1, 3)
+
+ # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
+ # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
+ logits = self.compute_contrastive_logits(
+ quantized_features[None, :],
+ negative_quantized_features,
+ transformer_features,
+ self.config.contrastive_logits_temperature,
+ )
+
+ # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
+ # its cosine similarity will be masked
+ neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
+
+ if neg_is_pos.any():
+ logits[1:][neg_is_pos] = float("-inf")
+
+ # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
+ # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
+ logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
+ target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
+
+ contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
+ # 7. compute diversity loss: \mathbf{L}_d
+ num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
+ diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
+
+ # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
+ loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
+
+ if not return_dict:
+ if loss is not None:
+ return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+ return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+
+ return Wav2Vec2ConformerForPreTrainingOutput(
+ loss=loss,
+ projected_states=transformer_features,
+ projected_quantized_states=quantized_features,
+ codevector_perplexity=codevector_perplexity,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ contrastive_loss=contrastive_loss,
+ diversity_loss=diversity_loss,
+ )
+
+
+@add_start_docstrings(
+ """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ self.dropout = nn.Dropout(config.final_dropout)
+
+ if config.vocab_size is None:
+ raise ValueError(
+ f"You are trying to instantiate {self.__class__} with a configuration that "
+ "does not define the vocabulary size of the language model head. Please "
+ "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
+ "or define `vocab_size` of your model's configuration."
+ )
+ output_hidden_size = (
+ config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+ )
+ self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=CausalLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_CTC_EXPECTED_OUTPUT,
+ expected_loss=_CTC_EXPECTED_LOSS,
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, CausalLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
+ config.vocab_size - 1]`.
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ hidden_states = self.dropout(hidden_states)
+
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ if labels.max() >= self.config.vocab_size:
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
+
+ # retrieve loss input_lengths from attention_mask
+ attention_mask = (
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+ )
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
+ # assuming that padded tokens are filled with -100
+ # when not being attended to
+ labels_mask = labels >= 0
+ target_lengths = labels_mask.sum(-1)
+ flattened_targets = labels.masked_select(labels_mask)
+
+ # ctc_loss doesn't support fp16
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+ with torch.backends.cudnn.flags(enabled=False):
+ loss = nn.functional.ctc_loss(
+ log_probs,
+ flattened_targets,
+ input_lengths,
+ target_lengths,
+ blank=self.config.pad_token_id,
+ reduction=self.config.ctc_loss_reduction,
+ zero_infinity=self.config.ctc_zero_infinity,
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for
+ tasks like SUPERB Keyword Spotting.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer
+ def __init__(self, config):
+ super().__init__(config)
+
+ if hasattr(config, "add_adapter") and config.add_adapter:
+ raise ValueError(
+ "Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
+ )
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+ self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_conformer.parameters():
+ param.requires_grad = False
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+ if attention_mask is None:
+ pooled_output = hidden_states.mean(dim=1)
+ else:
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+ hidden_states[~padding_mask] = 0.0
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel):
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+ def __init__(self, config):
+ super().__init__(config)
+
+ if hasattr(config, "add_adapter") and config.add_adapter:
+ raise ValueError(
+ "Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)"
+ )
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+ self.num_labels = config.num_labels
+
+ self.init_weights()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_conformer.parameters():
+ param.requires_grad = False
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
+
+ if not return_dict:
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
+class AMSoftmaxLoss(nn.Module):
+ def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
+ super(AMSoftmaxLoss, self).__init__()
+ self.scale = scale
+ self.margin = margin
+ self.num_labels = num_labels
+ self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
+ self.loss = nn.CrossEntropyLoss()
+
+ def forward(self, hidden_states, labels):
+ labels = labels.flatten()
+ weight = nn.functional.normalize(self.weight, dim=0)
+ hidden_states = nn.functional.normalize(hidden_states, dim=1)
+ cos_theta = torch.mm(hidden_states, weight)
+ psi = cos_theta - self.margin
+
+ onehot = nn.functional.one_hot(labels, self.num_labels)
+ logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
+ loss = self.loss(logits, labels)
+
+ return loss
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
+class TDNNLayer(nn.Module):
+ def __init__(self, config, layer_id=0):
+ super().__init__()
+ self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
+ self.out_conv_dim = config.tdnn_dim[layer_id]
+ self.kernel_size = config.tdnn_kernel[layer_id]
+ self.dilation = config.tdnn_dilation[layer_id]
+
+ self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
+ self.activation = nn.ReLU()
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.unsqueeze(1)
+ hidden_states = nn.functional.unfold(
+ hidden_states,
+ (self.kernel_size, self.in_conv_dim),
+ stride=(1, self.in_conv_dim),
+ dilation=(self.dilation, 1),
+ )
+ hidden_states = hidden_states.transpose(1, 2)
+ hidden_states = self.kernel(hidden_states)
+
+ hidden_states = self.activation(hidden_states)
+ return hidden_states
+
+
+@add_start_docstrings(
+ """
+ Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+ """,
+ WAV2VEC2_CONFORMER_START_DOCSTRING,
+)
+class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
+ if config.use_weighted_layer_sum:
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+ self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
+
+ tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
+ self.tdnn = nn.ModuleList(tdnn_layers)
+
+ self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
+ self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
+
+ self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
+
+ self.init_weights()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
+ not be updated during training.
+ """
+ self.wav2vec2_conformer.feature_extractor._freeze_parameters()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer
+ def freeze_base_model(self):
+ """
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
+ be updated during training. Only the classification head will be updated.
+ """
+ for param in self.wav2vec2_conformer.parameters():
+ param.requires_grad = False
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer
+ def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+ """
+ Computes the output length of the TDNN layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return (input_length - kernel_size) // stride + 1
+
+ for kernel_size in self.config.tdnn_kernel:
+ input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
+
+ return input_lengths
+
+ @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=XVectorOutput,
+ config_class=_CONFIG_FOR_DOC,
+ modality="audio",
+ )
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, XVectorOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+ outputs = self.wav2vec2_conformer(
+ input_values,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ if self.config.use_weighted_layer_sum:
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+ hidden_states = torch.stack(hidden_states, dim=1)
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+ else:
+ hidden_states = outputs[0]
+
+ hidden_states = self.projector(hidden_states)
+
+ for tdnn_layer in self.tdnn:
+ hidden_states = tdnn_layer(hidden_states)
+
+ # Statistic Pooling
+ if attention_mask is None:
+ mean_features = hidden_states.mean(dim=1)
+ std_features = hidden_states.std(dim=1)
+ else:
+ feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
+ tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
+ mean_features = []
+ std_features = []
+ for i, length in enumerate(tdnn_output_lengths):
+ mean_features.append(hidden_states[i, :length].mean(dim=0))
+ std_features.append(hidden_states[i, :length].std(dim=0))
+ mean_features = torch.stack(mean_features)
+ std_features = torch.stack(std_features)
+ statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
+
+ output_embeddings = self.feature_extractor(statistic_pooling)
+ logits = self.classifier(output_embeddings)
+
+ loss = None
+ if labels is not None:
+ loss = self.objective(logits, labels)
+
+ if not return_dict:
+ output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
+ return ((loss,) + output) if loss is not None else output
+
+ return XVectorOutput(
+ loss=loss,
+ logits=logits,
+ embeddings=output_embeddings,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
diff --git a/src/third_party/musicfm/modules/random_quantizer.py b/src/third_party/musicfm/modules/random_quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1257014658a24e4557814ccbb746de455ec111fa
--- /dev/null
+++ b/src/third_party/musicfm/modules/random_quantizer.py
@@ -0,0 +1,83 @@
+# MIT License
+#
+# Copyright 2023 ByteDance Inc.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
+# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+
+import torch
+from torch import nn, einsum
+from einops import rearrange
+
+
+class RandomProjectionQuantizer(nn.Module):
+ """
+ Random projection and codebook lookup module
+
+ Some code is borrowed from:
+ https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
+ But I did normalization using pre-computed global mean & variance instead of using layer norm.
+ """
+
+ def __init__(
+ self,
+ input_dim,
+ codebook_dim,
+ codebook_size,
+ seed=142,
+ ):
+ super().__init__()
+
+ # random seed
+ torch.manual_seed(seed)
+
+ # randomly initialized projection
+ random_projection = torch.empty(input_dim, codebook_dim)
+ nn.init.xavier_normal_(random_projection)
+ self.register_buffer("random_projection", random_projection)
+
+ # randomly initialized codebook
+ codebook = torch.empty(codebook_size, codebook_dim)
+ nn.init.normal_(codebook)
+ self.register_buffer("codebook", codebook)
+
+ def codebook_lookup(self, x):
+ # reshape
+ b = x.shape[0]
+ x = rearrange(x, "b n e -> (b n) e")
+
+ # L2 normalization
+ normalized_x = nn.functional.normalize(x, dim=1, p=2)
+ normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)
+
+ # compute distances
+ distances = torch.cdist(normalized_codebook, normalized_x)
+
+ # get nearest
+ nearest_indices = torch.argmin(distances, dim=0)
+
+ # reshape
+ xq = rearrange(nearest_indices, "(b n) -> b n", b=b)
+
+ return xq
+
+ @torch.no_grad()
+ def forward(self, x):
+ # always eval
+ self.eval()
+
+ # random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
+ x = einsum("b n d, d e -> b n e", x, self.random_projection)
+
+ # codebook lookup
+ xq = self.codebook_lookup(x)
+
+ return xq