Karlo Pintaric committed
Commit fdc1efd
1 Parent(s): 48fb9cc

Upload 25 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ src/api/test_files/test.wav filter=lfs diff=lfs merge=lfs -text
DockerFile.backend ADDED
@@ -0,0 +1,17 @@
+ # Use an official Python runtime as the base image
+ FROM python:3.9-slim
+
+ # Set the working directory in the container
+ WORKDIR /app
+
+ # Copy the setup.py file into the container
+ COPY ./setup.py .
+
+ # Copy the source and install the package with its dependencies
+ COPY ./src ./src
+
+ RUN pip install --no-cache-dir .[backend] torch==1.13.1+cpu torchvision==0.14.1+cpu torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cpu
+
+ EXPOSE 7860
+
+ CMD ["uvicorn", "src.api.main:app", "--host", "0.0.0.0", "--port", "7860"]
setup.py ADDED
@@ -0,0 +1,39 @@
+ from setuptools import setup, find_packages
+
+ setup(
+     name="lumen-irmas",
+     version="0.1.0",
+     description="LUMEN Data Science nagradni zadatak",
+     author="Karlo Pintaric i Tatjana Cigula",
+     packages=find_packages(include=["src"]),
+     python_requires=">=3.9",
+     install_requires=[
+         "numpy==1.23.5",
+         "transformers==4.27.4",
+     ],
+     extras_require={
+         "backend": ["fastapi==0.95.1", "uvicorn==0.21.1", "pydantic==1.10.7", "python-multipart==0.0.6"],
+         "frontend": ["streamlit==1.21.0", "requests==2.28.2", "soundfile==0.12.1"],
+         "user": [
+             "lumen-irmas[backend]",
+             "lumen-irmas[frontend]",
+             "torch==1.13.1",
+             "torchaudio==0.13.1",
+             "torchvision==0.14.1",
+         ],
+         "dev": [
+             "lumen-irmas[user]",
+             "librosa==0.10.0.post2",
+             "pandas==1.5.3",
+             "scikit-learn==1.2.2",
+             "tqdm==4.65.0",
+             "wandb==0.14.2",
+             "pytest==7.3.1",
+             "joblib==1.2.0",
+             "PyYAML==6.0",
+             "flake8==6.0.0",
+             "isort==5.12.0",
+             "black==23.3.0"
+         ]
+     },
+ )
src/__init__.py ADDED
File without changes
src/api/ModelService.py ADDED
@@ -0,0 +1,93 @@
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from torchvision import transforms
+
+ from src.modeling import ASTPretrained, FeatureExtractor, PreprocessPipeline, StudentAST
+
+ MODELS_FOLDER = Path(__file__).parent / "models"
+
+ CLASSES = ["tru", "sax", "vio", "gac", "org", "cla", "flu", "voi", "gel", "cel", "pia"]
+
+
+ def load_model(model_type: str):
+     """
+     Loads a pre-trained AST model of the specified type.
+
+     :param model_type: The type of model to load
+     :type model_type: str
+     :return: The loaded pre-trained AST model.
+     :rtype: ASTPretrained
+     """
+
+     if model_type == "accuracy":
+         model = ASTPretrained(n_classes=11, download_weights=False)
+         model.load_state_dict(torch.load(f"{MODELS_FOLDER}/acc_model_ast.pth", map_location=torch.device("cpu")))
+     else:
+         model = StudentAST(n_classes=11, hidden_size=192, num_heads=3)
+         model.load_state_dict(torch.load(f"{MODELS_FOLDER}/speed_model_ast.pth", map_location=torch.device("cpu")))
+     model.eval()
+     return model
+
+
+ def load_labels():
+     """
+     Loads a dictionary of class labels for the AST model.
+
+     :return: A dictionary where the keys are the class indices and the values are the class labels.
+     :rtype: Dict[int, str]
+     """
+
+     labels = {i: CLASSES[i] for i in range(len(CLASSES))}
+     return labels
+
+
+ def load_thresholds(model_type: str):
+     """
+     Loads the prediction thresholds for the AST model.
+
+     :return: The prediction thresholds for each class.
+     :rtype: np.ndarray
+     """
+     if model_type == "accuracy":
+         thresholds = np.load(f"{MODELS_FOLDER}/acc_model_thresh.npy", allow_pickle=True)
+     else:
+         thresholds = np.load(f"{MODELS_FOLDER}/speed_model_thresh.npy", allow_pickle=True)
+     return thresholds
+
+
+ class ModelServiceAST:
+     def __init__(self, model_type: str):
+         """
+         Initializes a ModelServiceAST instance with the specified model type.
+
+         :param model_type: The type of model to load
+         :type model_type: str
+         """
+
+         self.model = load_model(model_type)
+         self.labels = load_labels()
+         self.thresholds = load_thresholds(model_type)
+         self.transform = transforms.Compose([PreprocessPipeline(target_sr=16000), FeatureExtractor(sr=16000)])
+
+     def get_prediction(self, audio):
+         """
+         Gets the binary predictions for the given audio file.
+
+         :param audio: The file object for the input audio to make predictions for.
+         :type audio: file object
+         :return: A dictionary where the keys are the class labels and the values are binary predictions (0 or 1).
+         :rtype: Dict[str, int]
+         """
+         processed = self.transform(audio)
+         with torch.no_grad():
+             # Don't forget to transpose the output to seq_len x num_features!!!
+             output = torch.sigmoid(self.model(processed.mT))
+             output = output.squeeze().numpy().astype(float)
+
+         binary_predictions = {}
+         for i, label in enumerate(CLASSES):
+             binary_predictions[label] = int(output[i] >= self.thresholds[i])
+
+         return binary_predictions
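
For orientation, here is a minimal sketch (not part of the commit) of how `ModelServiceAST` might be used directly, assuming the repository root is on `PYTHONPATH` and the LFS weights under `src/api/models` have been pulled:

```python
# Illustrative usage sketch; the service is normally instantiated by src/api/main.py.
from src.api.ModelService import ModelServiceAST

service = ModelServiceAST(model_type="speed")  # "accuracy" loads the larger ASTPretrained checkpoint

# get_prediction expects a file-like object, as the API passes UploadFile.file
with open("src/api/test_files/test.wav", "rb") as audio_file:
    predictions = service.get_prediction(audio_file)

# Maps each instrument code to 0 or 1, e.g. {"tru": 0, "sax": 1, ...}
print(predictions)
```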
src/api/__init__.py ADDED
File without changes
src/api/main.py ADDED
@@ -0,0 +1,133 @@
+ import os
+ import logging
+ from logging.handlers import RotatingFileHandler
+ from pathlib import Path
+ from typing import Dict
+
+
+ from fastapi import Depends, FastAPI, File, UploadFile
+ from fastapi.exceptions import RequestValidationError
+ from fastapi.responses import JSONResponse
+ from src.api.ModelService import ModelServiceAST
+ from pydantic import BaseModel, validator
+
+ LOG_SAVE_DIR = Path(__file__).parent / "logs"
+ if not os.path.exists(LOG_SAVE_DIR):
+     os.makedirs(LOG_SAVE_DIR)
+
+ ml_models = {}
+ ml_models["Accuracy"] = ModelServiceAST(model_type="accuracy")
+ ml_models["Speed"] = ModelServiceAST(model_type="speed")
+
+ app = FastAPI()
+
+ # Define the allowed file formats
+ ALLOWED_FILE_FORMATS = ["wav"]
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.DEBUG)
+
+ # Create a rotating file handler to save logs to a file
+ handler = RotatingFileHandler(f"{LOG_SAVE_DIR}/app.log", maxBytes=100000, backupCount=5)
+ handler.setLevel(logging.DEBUG)
+
+ # Define the log format
+ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+ handler.setFormatter(formatter)
+
+ # Add the handler to the logger
+ logger.addHandler(handler)
+
+
+ class InvalidFileTypeError(Exception):
+     def __init__(self):
+         self.message = "Only wav files are supported"
+         super().__init__(self.message)
+
+
+ class InvalidModelError(Exception):
+     def __init__(self):
+         self.message = "Selected model doesn't exist"
+         super().__init__(self.message)
+
+
+ class MissingFileError(Exception):
+     def __init__(self):
+         self.message = "File cannot be None"
+         super().__init__(self.message)
+
+
+ class PredictionRequest(BaseModel):
+     model_name: str
+
+     @validator("model_name")
+     @classmethod
+     def valid_model(cls, v):
+         if v not in ml_models.keys():
+             raise InvalidModelError
+         return v
+
+
+ class PredictionResult(BaseModel):
+     prediction: Dict[str, Dict[str, int]]
+
+
+ @app.exception_handler(RequestValidationError)
+ def validation_exception_handler(request, ex):
+     logger.error(f"Request validation error: {ex}")
+     return JSONResponse(content={"error": "Bad Request", "detail": ex.errors()}, status_code=400)
+
+
+ @app.exception_handler(InvalidFileTypeError)
+ def filetype_exception_handler(request, ex):
+     logger.error(f"Invalid file type error: {ex}")
+     return JSONResponse(content={"error": "Bad Request", "detail": ex.message}, status_code=400)
+
+
+ @app.exception_handler(InvalidModelError)
+ def model_exception_handler(request, ex):
+     logger.error(f"Invalid model error: {ex}")
+     return JSONResponse(content={"error": "Bad Request", "detail": ex.message}, status_code=400)
+
+
+ @app.exception_handler(MissingFileError)
+ def handle_missing_file_error(request, ex):
+     logger.error(f"Missing file error: {ex}")
+     return JSONResponse(content={"error": "Bad Request", "detail": ex.message}, status_code=400)
+
+
+ @app.exception_handler(Exception)
+ def handle_exceptions(request, ex):
+     logger.exception(f"Internal server error: {ex}")
+     # If an exception occurs during processing, return a JSON response with an error message
+     return JSONResponse(content={"error": "Internal Server Error", "detail": str(ex)}, status_code=500)
+
+
+ @app.get("/")
+ def root():
+     logger.info("Received request to root endpoint")
+     return {"message": "Welcome to my API. Go to /docs to view the documentation."}
+
+
+ @app.get("/health-check")
+ def health_check():
+     """
+     Health check endpoint to verify if the API is running.
+     """
+     logger.info("Health check endpoint was hit")
+     return {"status": "API is running"}
+
+
+ @app.post("/predict")
+ def predict(request: PredictionRequest = Depends(), file: UploadFile = File(...)) -> PredictionResult:  # noqa
+     if not file:
+         raise MissingFileError
+     if file.filename.split(".")[-1].lower() not in ALLOWED_FILE_FORMATS:
+         raise InvalidFileTypeError
+     logger.info(f"Prediction request received: {request}")
+     output = ml_models[request.model_name].get_prediction(file.file)
+     logger.info(f"Prediction result: {output}")
+     prediction_result = PredictionResult(prediction={file.filename: output})
+
+     return prediction_result
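
A minimal sketch (not part of the commit) of calling the `/predict` endpoint from a client, assuming the backend is reachable at `http://localhost:7860` as exposed by DockerFile.backend; it mirrors the request that the frontend's `predict` helper sends:

```python
import requests

# model_name is validated against ml_models ("Accuracy" or "Speed") by PredictionRequest
with open("src/api/test_files/test.wav", "rb") as f:
    response = requests.post(
        "http://localhost:7860/predict",
        params={"model_name": "Accuracy"},
        files={"file": ("test.wav", f)},
        timeout=300,
    )

response.raise_for_status()
print(response.json())  # {"prediction": {"test.wav": {"tru": 0, "sax": 1, ...}}}
```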
src/api/main_test.py ADDED
@@ -0,0 +1,63 @@
+ import io
+ import sys
+ from pathlib import Path
+
+ import soundfile as sf
+ from fastapi.testclient import TestClient
+
+ sys.path.append(".")
+
+ from src.api.main import app  # noqa
+
+ TEST_FILES_DIR = Path(__file__).parent / "test_files"
+ TEST_WAV_FILE = TEST_FILES_DIR / "test.wav"
+
+ client = TestClient(app)
+
+
+ def test_health_check():
+     response = client.get("/health-check")
+     assert response.status_code == 200
+     assert response.json() == {"status": "API is running"}
+
+
+ def test_predict_valid_cut_file():
+     audio_data, sample_rate = sf.read(TEST_WAV_FILE)
+     audio_file = io.BytesIO()
+     sf.write(audio_file, audio_data, sample_rate, format="wav")
+     audio_file = ("test.wav", audio_file)
+
+     file = {"file": audio_file}
+     request_data = {"model_name": "Accuracy"}
+     # Make a request to the /predict endpoint
+     response = client.post("/predict", params=request_data, files=file)
+
+     # Check that the response is successful
+     assert response.status_code == 200
+     assert response.json()["prediction"]["test.wav"] is not None
+
+
+ def test_predict_valid_file():
+     with open(TEST_WAV_FILE, "rb") as file:
+         data = {"model_name": "Accuracy"}
+         response = client.post("/predict", params=data, files={"file": file})
+     assert response.status_code == 200
+     assert response.json()["prediction"]["test.wav"] is not None
+
+
+ def test_predict_invalid_file_type():
+     file_data = io.BytesIO(b"dummy txt data")
+     file = ("test.txt", file_data)
+     data = {"model_name": "Accuracy"}
+     response = client.post("/predict", params=data, files={"file": file})
+     assert response.status_code == 400
+     assert "Only wav files are supported" in response.json()["detail"]
+
+
+ def test_predict_invalid_model():
+     file_data = io.BytesIO(b"dummy wav data")
+     file = ("test.wav", file_data)
+     data = {"model_name": "InvalidModel"}
+     response = client.post("/predict", params=data, files={"file": file})
+     assert response.status_code == 400
+     assert "Selected model doesn't exist" in response.json()["detail"]
src/api/models/acc_model_ast.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2305b1d04ed918b6d6428f86dfde162d6912b5021741ff58785fa7b020094ec0
+ size 344860756
src/api/models/acc_model_thresh.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3034a1e953618280465b52b4104184b577e783afdf6231add9b96d119e12addf
+ size 216
src/api/models/speed_model_ast.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e529b7b85881d249f455b5386cdb5306915ad34cd5fc5fafeca35fc965573637
+ size 22573905
src/api/models/speed_model_thresh.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:56838178f12bccc05cf5ffc92a7ff570a70d3a42f3f87c977ad8c9ae0f4a3359
+ size 216
src/api/test_files/test.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:60f854cc407877512a3e68a286cfd26e95dc2f0a4e76ba313fbb3e21ddf2d2f9
+ size 3492764
src/frontend/.streamlit/config.toml ADDED
@@ -0,0 +1,10 @@
+ [theme]
+ base = "dark"
+ primaryColor = "#FFFFFF"
+ backgroundColor = "#212121"
+ secondaryBackgroundColor = "#757575"
+ textColor = "#FFFFFF"
+ font = "sans serif"
+
+ [browser]
+ gatherUsageStats = false
src/frontend/__init__.py ADDED
File without changes
src/frontend/ui.py ADDED
@@ -0,0 +1,97 @@
+ import json
+
+ import streamlit as st
+ from ui_backend import (
+     check_for_api,
+     cut_audio_file,
+     display_predictions,
+     load_audio,
+     predict_multiple,
+     predict_single,
+ )
+
+
+ def main():
+     # Page settings
+     st.set_page_config(
+         page_title="Music Instrument Recognition", page_icon="🎸", layout="wide", initial_sidebar_state="collapsed"
+     )
+
+     # Sidebar
+     with st.sidebar:
+         st.title("⚙️ Settings")
+         selected_model = st.selectbox(
+             "Select Model",
+             ("Accuracy", "Speed"),
+             index=0,
+             help="Select a slower but more accurate model or a faster but less accurate model",
+         )
+
+     # Main title
+     st.markdown(
+         "<h1 style='text-align: center; color: #FFFFFF; font-size: 3rem;'>Instrument Recognition 🎶</h1>",
+         unsafe_allow_html=True,
+     )
+
+     # Upload widget
+     audio_file = load_audio()
+
+     # Send a health check request to the API in a loop until it is running
+     api_running = check_for_api(10)
+
+     # Enable or disable a button based on API status
+     predict_valid = False
+     cut_valid = False
+
+     if api_running:
+         st.info("API is running", icon="🤖")
+
+     if audio_file:
+         num_files = len(audio_file)
+         st.write(f"Number of uploaded files: {num_files}")
+         predict_valid = True
+         if len(audio_file) > 1:
+             cut_valid = False
+         else:
+             audio_file = audio_file[0]
+             cut_valid = True
+             name = audio_file.name
+
+     if cut_valid:
+         cut_audio = st.checkbox(
+             "✂️ Cut duration",
+             disabled=not predict_valid,
+             help="Cut a long audio file. Model works best if audio is around 15 seconds",
+         )
+
+         if cut_audio:
+             audio_file = cut_audio_file(audio_file, name)
+
+     result = st.button("Predict", disabled=not predict_valid, help="Send the audio to API to get a prediction")
+
+     if result:
+         predictions = {}
+         if isinstance(audio_file, list):
+             predictions = predict_multiple(audio_file, selected_model)
+
+         else:
+             predictions = predict_single(audio_file, name, selected_model)
+
+         # Sort the dictionary alphabetically by key
+         sorted_predictions = dict(sorted(predictions.items()))
+
+         # Convert the sorted dictionary to a JSON string
+         json_string = json.dumps(sorted_predictions)
+         st.download_button(
+             label="Download JSON",
+             file_name="predictions.json",
+             mime="application/json",
+             data=json_string,
+             help="Download the predictions in JSON format",
+         )
+
+         display_predictions(sorted_predictions)
+
+
+ if __name__ == "__main__":
+     main()
src/frontend/ui_backend.py ADDED
@@ -0,0 +1,254 @@
+ import io
+ import os
+ import time
+ from json import JSONDecodeError
+ import math
+
+ import requests
+ import soundfile as sf
+ import streamlit as st
+
+ if os.environ.get("IS_DOCKER", False):
+     backend = "http://api:7860"
+ else:
+     backend = "http://0.0.0.0:7860"
+
+ INSTRUMENTS = {
+     "tru": "Trumpet",
+     "sax": "Saxophone",
+     "vio": "Violin",
+     "gac": "Acoustic Guitar",
+     "org": "Organ",
+     "cla": "Clarinet",
+     "flu": "Flute",
+     "voi": "Voice",
+     "gel": "Electric Guitar",
+     "cel": "Cello",
+     "pia": "Piano",
+ }
+
+
+ def load_audio():
+     """
+     Upload WAV audio files and display the first one in the Streamlit app.
+
+     :return: The list of uploaded audio files, or None if no file was uploaded.
+     :rtype: Optional[List[UploadedFile]]
+     """
+
+     audio_file = st.file_uploader(label="Upload audio file", type="wav", accept_multiple_files=True)
+     if len(audio_file) > 0:
+         st.audio(audio_file[0])
+         return audio_file
+     else:
+         return None
+
+
+ @st.cache_data(show_spinner=False)
+ def check_for_api(max_tries: int):
+     """
+     Check if the API is running by making a health check request.
+
+     :param max_tries: The maximum number of attempts to check the API's health.
+     :type max_tries: int
+     :return: True if the API is running, False otherwise.
+     :rtype: bool
+     """
+     trial_count = 0
+
+     with st.spinner("Waiting for API..."):
+         while trial_count <= max_tries:
+             try:
+                 response = health_check()
+                 if response:
+                     return True
+             except requests.exceptions.ConnectionError:
+                 trial_count += 1
+                 # Handle connection error, e.g. API not yet running
+                 time.sleep(5)  # Sleep for 5 seconds before retrying
+         st.error("API is not running. Please refresh the page to try again.", icon="🚨")
+         st.stop()
+
+
+ def cut_audio_file(audio_file, name):
+     """
+     Cut an audio file and return the cut audio data as a tuple.
+
+     :param audio_file: The uploaded audio file to be cut.
+     :type audio_file: UploadedFile
+     :param name: The name of the audio file to be cut.
+     :type name: str
+     :raises RuntimeError: If the audio file cannot be read.
+     :return: A tuple containing the name and the cut audio data as a BytesIO object.
+     :rtype: tuple
+     """
+     try:
+         audio_data, sample_rate = sf.read(audio_file)
+     except RuntimeError as e:
+         raise e
+
+     # Display audio duration
+     duration = round(len(audio_data) / sample_rate, 2)
+     st.info(f"Audio Duration: {duration} seconds")
+
+     # Get start and end time for cutting
+     start_time = st.number_input("Start Time (seconds)", min_value=0.0, max_value=duration - 1, step=0.1)
+     end_time = st.number_input("End Time (seconds)", min_value=start_time, value=duration, max_value=duration, step=0.1)
+
+     # Convert start and end time to sample indices
+     start_sample = int(start_time * sample_rate)
+     end_sample = int(end_time * sample_rate)
+
+     # Cut audio
+     cut_audio_data = audio_data[start_sample:end_sample]
+
+     # Create a temporary in-memory file for cut audio
+     audio_file = io.BytesIO()
+     sf.write(audio_file, cut_audio_data, sample_rate, format="wav")
+
+     # Display cut audio
+     st.audio(audio_file, format="audio/wav")
+     audio_file = (name, audio_file)
+
+     return audio_file
+
+
+ def display_predictions(predictions: dict):
+     """
+     Display the predictions using instrument names instead of codes.
+
+     :param predictions: A dictionary containing the filenames and instruments detected in them.
+     :type predictions: dict
+     """
+
+     # Display the results using instrument names instead of codes
+     for filename, instruments in predictions.items():
+         st.subheader(filename)
+
+         if isinstance(instruments, str):
+             st.write(instruments)
+
+         else:
+             with st.container():
+                 col1, col2 = st.columns([1, 3])
+                 present_instruments = [
+                     INSTRUMENTS[instrument_code] for instrument_code, presence in instruments.items() if presence
+                 ]
+                 if present_instruments:
+                     for instrument_name in present_instruments:
+                         with col1:
+                             st.write(instrument_name)
+                         with col2:
+                             st.write("✔️")
+                 else:
+                     st.write("No instruments found in this file.")
+
+
+ def health_check():
+     """
+     Sends a health check request to the API and checks if it's running.
+
+     :return: Returns True if the API is running, else False.
+     :rtype: bool
+     """
+
+     # Send a health check request to the API
+     response = requests.get(f"{backend}/health-check", timeout=100)
+
+     # Check if the API is running
+     if response.status_code == 200:
+         return True
+     else:
+         return False
+
+
+ def predict(data, model_name):
+     """
+     Sends a POST request to the API with the provided data and model name.
+
+     :param data: The audio data to be used for prediction.
+     :type data: bytes
+     :param model_name: The name of the model to be used for prediction.
+     :type model_name: str
+     :return: The response from the API.
+     :rtype: requests.Response
+     """
+
+     file = {"file": data}
+     request_data = {"model_name": model_name}
+
+     response = requests.post(
+         f"{backend}/predict", params=request_data, files=file, timeout=300
+     )  # Replace with your API endpoint URL
+
+     return response
+
+
+ @st.cache_data(show_spinner=False)
+ def predict_single(audio_file, name, selected_model):
+     """
+     Predicts the instruments in a single audio file using the selected model.
+
+     :param audio_file: The audio file to be used for prediction.
+     :type audio_file: bytes
+     :param name: The name of the audio file.
+     :type name: str
+     :param selected_model: The name of the selected model.
+     :type selected_model: str
+     :return: A dictionary containing the predicted instruments for the audio file.
+     :rtype: dict
+     """
+
+     predictions = {}
+
+     with st.spinner("Predicting instruments..."):
+         response = predict(audio_file, selected_model)
+
+         if response.status_code == 200:
+             prediction = response.json()["prediction"]
+             predictions[name] = prediction.get(name, "Error making prediction")
+         else:
+             st.write(response)
+             try:
+                 st.json(response.json())
+             except JSONDecodeError:
+                 st.error(response.text)
+             st.stop()
+     return predictions
+
+
+ @st.cache_data(show_spinner=False)
+ def predict_multiple(audio_files, selected_model):
+     """
+     Generates predictions for multiple audio files using the selected model.
+
+     :param audio_files: A list of audio files to make predictions on.
+     :type audio_files: List[UploadedFile]
+     :param selected_model: The model to use for making predictions.
+     :type selected_model: str
+     :return: A dictionary where the keys are the names of the audio files and the values are the predicted labels.
+     :rtype: Dict[str, str]
+     """
+
+     predictions = {}
+     progress_text = "Getting predictions for all files. Please wait."
+     progress_bar = st.empty()
+     progress_bar.progress(0, text=progress_text)
+
+     num_files = len(audio_files)
+
+     for i, file in enumerate(audio_files):
+         name = file.name
+         response = predict(file, selected_model)
+         if response.status_code == 200:
+             prediction = response.json()["prediction"]
+             predictions[name] = prediction[name]
+             progress_bar.progress((i + 1) / num_files, text=progress_text)
+         else:
+             predictions[name] = "Error making prediction."
+     progress_bar.empty()
+     return predictions
+
+
+ if __name__ == "__main__":
+     pass
src/modeling/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from src.modeling.models import ASTPretrained, StudentAST
+ from src.modeling.transforms import FeatureExtractor, PreprocessPipeline
src/modeling/dataset.py ADDED
@@ -0,0 +1,162 @@
+ from pathlib import Path
+ from typing import List, Optional, Tuple, Type, Union
+
+ import numpy as np
+ import torch
+ from torch.nn.utils.rnn import pad_sequence
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision.transforms import Compose
+
+ import modeling.transforms as transform_module
+ from modeling.transforms import (
+     LabelsFromTxt,
+     OneHotEncode,
+     ParentMultilabel,
+     Preprocess,
+     Transform,
+ )
+ from modeling.utils import CLASSES, get_wav_files, init_obj, init_transforms
+
+
+ class IRMASDataset(Dataset):
+     """Dataset class for IRMAS dataset.
+
+     :param audio_dir: Directory containing the audio files
+     :type audio_dir: Union[str, Path]
+     :param preprocess: Preprocessing method to apply to the audio files
+     :type preprocess: Type[Preprocess]
+     :param signal_augments: Signal augmentation method to apply to the audio files, defaults to None
+     :type signal_augments: Optional[Union[Type[Compose], Type[Transform]]], optional
+     :param transforms: Transform method to apply to the audio files, defaults to None
+     :type transforms: Optional[Union[Type[Compose], Type[Transform]]], optional
+     :param spec_augments: Spectrogram augmentation method to apply to the audio files, defaults to None
+     :type spec_augments: Optional[Union[Type[Compose], Type[Transform]]], optional
+     :param subset: Subset of the data to load (train, valid, or test), defaults to "train"
+     :type subset: str, optional
+     :raises AssertionError: Raises an assertion error if subset is not train, valid or test
+     :raises OSError: Raises an OS error if test_songs.txt is not found in the data folder
+     :return: A tuple of the preprocessed audio signal and the corresponding one-hot encoded label
+     :rtype: Tuple[Tensor, Tensor]
+     """
+
+     def __init__(
+         self,
+         audio_dir: Union[str, Path],
+         preprocess: Type[Preprocess],
+         signal_augments: Optional[Union[Type[Compose], Type[Transform]]] = None,
+         transforms: Optional[Union[Type[Compose], Type[Transform]]] = None,
+         spec_augments: Optional[Union[Type[Compose], Type[Transform]]] = None,
+         subset: str = "train",
+     ):
+         self.files = get_wav_files(audio_dir)
+         assert subset in ["train", "valid", "test"], "Subset can only be train, valid or test"
+         self.subset = subset
+
+         if self.subset != "train":
+             try:
+                 test_songs = np.genfromtxt("../data/test_songs.txt", dtype=str, ndmin=1, delimiter="\n")
+             except OSError as e:
+                 print(f"Error: {e}")
+                 print("test_songs.txt not found in data/. Please generate a split before training")
+                 raise e
+
+             if self.subset == "valid":
+                 self.files = [file for file in self.files if Path(file).stem not in test_songs]
+             if self.subset == "test":
+                 self.files = [file for file in self.files if Path(file).stem in test_songs]
+
+         self.preprocess = preprocess
+         self.transforms = transforms
+         self.signal_augments = signal_augments
+         self.spec_augments = spec_augments
+
+     def __len__(self):
+         """Return the length of the dataset.
+
+         :return: The length of the dataset
+         :rtype: int
+         """
+
+         return len(self.files)
+
+     def __getitem__(self, index):
+         """Get an item from the dataset.
+
+         :param index: The index of the item to get
+         :type index: int
+         :return: A tuple of the preprocessed audio signal and the corresponding one-hot encoded label
+         :rtype: Tuple[Tensor, Tensor]
+         """
+
+         sample_path = self.files[index]
+         signal = self.preprocess(sample_path)
+
+         if self.subset == "train":
+             target_transforms = Compose([ParentMultilabel(sep="-"), OneHotEncode(CLASSES)])
+         else:
+             target_transforms = Compose([LabelsFromTxt(), OneHotEncode(CLASSES)])
+
+         label = target_transforms(sample_path)
+
+         if self.signal_augments is not None and self.subset == "train":
+             signal = self.signal_augments(signal)
+
+         if self.transforms is not None:
+             signal = self.transforms(signal)
+
+         if self.spec_augments is not None and self.subset == "train":
+             signal = self.spec_augments(signal)
+
+         return signal, label.float()
+
+
+ def collate_fn(data: List[Tuple[torch.Tensor, torch.Tensor]]):
+     """
+     Function to collate a batch of audio signals and their corresponding labels.
+
+     :param data: A list of tuples containing the audio signals and their corresponding labels.
+     :type data: List[Tuple[torch.Tensor, torch.Tensor]]
+
+     :return: A tuple containing the batch of audio signals and their corresponding labels.
+     :rtype: Tuple[torch.Tensor, torch.Tensor]
+     """
+
+     features, labels = zip(*data)
+     features = [item.squeeze().T for item in features]
+     # Pads items to same length if they're not
+     features = pad_sequence(features, batch_first=True)
+     labels = torch.stack(labels)
+
+     return features, labels
+
+
+ def get_loader(config: dict, subset: str):
+     """
+     Function to create a PyTorch DataLoader for a given subset of the IRMAS dataset.
+
+     :param config: A configuration object.
+     :type config: Any
+     :param subset: The subset of the dataset to use. Can be "train" or "valid".
+     :type subset: str
+
+     :return: A PyTorch DataLoader for the specified subset of the dataset.
+     :rtype: torch.utils.data.DataLoader
+     """
+
+     dst = IRMASDataset(
+         config.train_dir if subset == "train" else config.valid_dir,
+         preprocess=init_obj(config.preprocess, transform_module),
+         transforms=init_obj(config.transforms, transform_module),
+         signal_augments=init_transforms(config.signal_augments, transform_module),
+         spec_augments=init_transforms(config.spec_augments, transform_module),
+         subset=subset,
+     )
+
+     return DataLoader(
+         dst,
+         batch_size=config.batch_size,
+         shuffle=True if subset == "train" else False,
+         pin_memory=True if torch.cuda.is_available() else False,
+         num_workers=torch.get_num_threads() - 1,
+         collate_fn=collate_fn,
+     )
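
A minimal sketch (not part of the commit) of what `collate_fn` does with variable-length spectrograms, using made-up shapes and only `torch`, so it can be run in isolation:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two (spectrogram, one-hot label) pairs with different numbers of time frames
batch = [
    (torch.randn(1, 128, 300), torch.zeros(11)),  # 300 frames
    (torch.randn(1, 128, 480), torch.ones(11)),   # 480 frames
]

features, labels = zip(*batch)
features = [item.squeeze().T for item in features]   # each becomes (time, n_mels)
features = pad_sequence(features, batch_first=True)  # zero-pad to the longest item
labels = torch.stack(labels)

print(features.shape)  # torch.Size([2, 480, 128])
print(labels.shape)    # torch.Size([2, 11])
```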
src/modeling/learner.py ADDED
@@ -0,0 +1,333 @@
+ from abc import ABC, abstractmethod
+ from typing import Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import wandb
+ from torch.utils.data import DataLoader
+ from tqdm.autonotebook import tqdm
+
+ import modeling.loss as loss_module
+ import modeling.metrics as metrics_module
+ from modeling.loss import HardDistillationLoss
+ from modeling.models import freeze, layerwise_lr_decay
+ from modeling.utils import init_obj
+
+
+ class BaseLearner(ABC):
+     """
+     Abstract base class for a learner.
+
+     :param train_dl: DataLoader for training data
+     :type train_dl: Type[DataLoader]
+     :param valid_dl: DataLoader for validation data
+     :type valid_dl: Type[DataLoader]
+     :param model: Model to be trained
+     :type model: Type[nn.Module]
+     :param config: Configuration object
+     :type config: Any
+     """
+
+     def __init__(self, train_dl: DataLoader, valid_dl: DataLoader, model: nn.Module, config):
+         self.train_dl = train_dl
+         self.valid_dl = valid_dl
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model = model.to(self.device)
+         self.config = config
+
+     @abstractmethod
+     def fit(
+         self,
+     ):
+         """Abstract method for fitting the model."""
+
+         pass
+
+     @abstractmethod
+     def _train_epoch(
+         self,
+     ):
+         """Abstract method for training the model for one epoch."""
+         pass
+
+     @abstractmethod
+     def _test_epoch(
+         self,
+     ):
+         """Abstract method for testing the model for one epoch."""
+         pass
+
+
+ class Learner(BaseLearner):
+     def __init__(self, train_dl: DataLoader, valid_dl: DataLoader, model: nn.Module, config):
+         """
+         A class that inherits from the BaseLearner class and represents a learner object.
+
+         :param train_dl: DataLoader for training data
+         :type train_dl: DataLoader
+         :param valid_dl: DataLoader for validation data
+         :type valid_dl: DataLoader
+         :param model: Model to be trained
+         :type model: nn.Module
+         :param config: Configuration object
+         :type config: Any
+         """
+
+         super().__init__(train_dl, valid_dl, model, config)
+
+         self.model = torch.nn.DataParallel(module=self.model, device_ids=list(range(config.num_gpus)))
+         self.loss_fn = init_obj(self.config.loss, loss_module)
+         params = layerwise_lr_decay(self.config, self.model)
+         self.optimizer = init_obj(self.config.optimizer, optim, params)
+         self.scheduler = init_obj(
+             self.config.scheduler,
+             optim.lr_scheduler,
+             self.optimizer,
+             max_lr=[param["lr"] for param in params],
+             epochs=self.config.epochs,
+             steps_per_epoch=int(np.ceil(len(train_dl) / self.config.num_accum)),
+         )
+
+         self.verbose = self.config.verbose
+         self.metrics = MetricTracker(self.config.metrics, self.verbose)
+         self.scaler = torch.cuda.amp.GradScaler()
+
+         self.train_step = 0
+         self.test_step = 0
+
+     def fit(self, model_name: str = "model"):
+         """
+         Method to train the model.
+
+         :param model_name: Name of the model to be saved, defaults to "model"
+         :type model_name: str, optional
+         """
+
+         loop = tqdm(range(self.config.epochs), leave=False)
+
+         for epoch in loop:
+             train_loss = self._train_epoch()
+             val_loss = self._test_epoch()
+
+             wandb.log({"train_loss": train_loss, "val_loss": val_loss, "epoch": epoch + 1})
+
+             if self.verbose:
+                 print(f"| EPOCH: {epoch+1} | train_loss: {train_loss:.3f} | val_loss: {val_loss:.3f} |\n")
+                 self.metrics.display()
+
+             if self.config.save_last_checkpoint:
+                 torch.save(self.model.module.state_dict(), f"{model_name}.pth")
+
+     def _train_epoch(self, distill: bool = False):
+         """
+         Method to perform one epoch of training.
+
+         :param distill: Flag to indicate if knowledge distillation is used, defaults to False
+         :type distill: bool, optional
+         :return: Average training loss for the epoch
+         :rtype: float
+         """
+
+         if distill:
+             print("Distilling knowledge...", flush=True)
+
+         loop = tqdm(self.train_dl, leave=False)
+         self.model.train()
+
+         num_batches = len(self.train_dl)
+         train_loss = 0
+
+         for idx, (xb, yb) in enumerate(loop):
+             xb = xb.to(self.device)
+             yb = yb.to(self.device)
+
+             # forward
+             with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=not distill):
+                 predictions = self.model(xb)
+
+                 if distill:
+                     loss = self.KDloss_fn(xb, predictions, yb)
+                 else:
+                     loss = self.loss_fn(predictions, yb)
+
+                 loss /= self.config.num_accum
+
+             # backward
+             self.scaler.scale(loss).backward()
+             wandb.log({f"lr_param_group_{i}": lr for i, lr in enumerate(self.scheduler.get_last_lr())})
+
+             if ((idx + 1) % self.config.num_accum == 0) or (idx + 1 == num_batches):
+                 self.scaler.step(self.optimizer)
+                 self.scaler.update()
+                 self.scheduler.step()
+                 self.optimizer.zero_grad()
+
+             # update loop
+             loop.set_postfix(loss=loss.item())
+             self.train_step += 1
+             wandb.log({"train_loss_per_batch": loss.item(), "train_step": self.train_step})
+             train_loss += loss.item()
+
+             if distill:
+                 if ((idx + 1) % 2500 == 0) and not (idx + 1 == num_batches):
+                     val_loss = self._test_epoch()
+                     wandb.log({"val_loss": val_loss})
+                     self.model.train()
+
+         train_loss /= num_batches
+
+         return train_loss
+
+     def _test_epoch(self):
+         """
+         Method to perform one epoch of validation/testing.
+
+         :return: Average validation/test loss for the epoch
+         :rtype: float
+         """
+
+         loop = tqdm(self.valid_dl, leave=False)
+         self.model.eval()
+
+         num_batches = len(self.valid_dl)
+         preds = []
+         targets = []
+         test_loss = 0
+
+         with torch.no_grad():
+             for xb, yb in loop:
+                 xb, yb = xb.to(self.device), yb.to(self.device)
+                 pred = self.model(xb)
+                 loss = self.loss_fn(pred, yb).item()
+                 self.test_step += 1
+                 wandb.log({"valid_loss_per_batch": loss, "test_step": self.test_step})
+                 test_loss += loss
+
+                 pred = torch.sigmoid(pred)
+                 preds.extend(pred.cpu().numpy())
+                 targets.extend(yb.cpu().numpy())
+
+         preds, targets = np.array(preds), np.array(targets)
+         self.metrics.update(preds, targets)
+         test_loss /= num_batches
+
+         return test_loss
+
+
+ class KDLearner(Learner):
+     """
+     Knowledge Distillation Learner class for training a student model with knowledge distillation.
+
+     :param train_dl: Train data loader
+     :type train_dl: DataLoader
+     :param valid_dl: Validation data loader
+     :type valid_dl: DataLoader
+     :param student_model: Student model to be trained
+     :type student_model: nn.Module
+     :param teacher: Teacher model for knowledge distillation
+     :type teacher: nn.Module
+     :param thresholds: Thresholds for HardDistillationLoss
+     :type thresholds: List[float]
+     :param config: Configuration object for training
+     :type config: Config
+     """
+
+     def __init__(self, train_dl, valid_dl, student_model, teacher, thresholds, config):
+         super().__init__(train_dl, valid_dl, student_model, config)
+
+         self.teacher = nn.DataParallel(freeze(teacher).to(self.device))
+         self.KDloss_fn = HardDistillationLoss(self.teacher, self.loss_fn, thresholds, self.device)
+         self.scaler = torch.cuda.amp.GradScaler(enabled=False)
+
+     def _train_epoch(self):
+         """
+         Method to perform one epoch of training with knowledge distillation.
+
+         :return: Average training loss for the epoch
+         :rtype: float
+         """
+
+         return super()._train_epoch(distill=True)
+
+
+ class MetricTracker:
+     """
+     Metric Tracker class for tracking evaluation metrics during model validation.
+     This class is used to track and display evaluation metrics during model validation.
+     It keeps track of the results of the provided metric functions for each validation batch,
+     and logs them to Weights & Biases using wandb.log(). The display() method can be used
+     to print the tracked metric results, if verbose is set to True during initialization.
+
+     :param metrics: List of metric functions to track
+     :type metrics: List[Callable]
+     :param verbose: Flag to indicate whether to print the results or not, defaults to True
+     :type verbose: bool, optional
+     """
+
+     def __init__(self, metrics, verbose: bool = True):
+         self.metrics_fn = [getattr(metrics_module, metric) for metric in metrics]
+         self.verbose = verbose
+         self.result = None
+
+     def update(self, preds, targets):
+         """
+         Update the metric tracker with the latest predictions and targets.
+
+         :param preds: Model predictions
+         :type preds: torch.Tensor
+         :param targets: Ground truth targets
+         :type targets: torch.Tensor
+         """
+
+         self.result = {metric.__name__: metric(preds, targets) for metric in self.metrics_fn}
+         wandb.log(self.result)
+
+     def display(self):
+         """Display the tracked metric results."""
+
+         for k, v in self.result.items():
+             print(f"{k}: {v:.2f}")
+
+
+ def get_preds(data: DataLoader, model: nn.Module, device: str = "cpu") -> Tuple[np.ndarray, np.ndarray]:
+     """
+     Get predictions and targets from a data loader and a PyTorch model.
+
+     :param data: A PyTorch DataLoader containing the data to predict on.
+     :type data: torch.utils.data.DataLoader
+     :param model: A PyTorch model to use for predictions.
+     :type model: torch.nn.Module
+     :param device: The device to use for predictions (default is "cpu").
+     :type device: str
+     :raises TypeError: If any of the input arguments is of an incorrect type.
+     :return: A tuple containing two NumPy arrays: the predictions and the targets.
+     :rtype: Tuple[numpy.ndarray, numpy.ndarray]
+     """
+
+     if not isinstance(data, DataLoader):
+         raise TypeError("The 'data' argument must be a PyTorch DataLoader.")
+     if not isinstance(model, nn.Module):
+         raise TypeError("The 'model' argument must be a PyTorch model.")
+     if not isinstance(device, str):
+         raise TypeError("The 'device' argument must be a string.")
+
+     loop = tqdm(data, leave=False)
+     model = model.to(device)
+     model.eval()
+
+     preds = []
+     targets = []
+
+     with torch.no_grad():
+         for xb, yb in loop:
+             xb, yb = xb.to(device), yb.to(device)
+             pred = model(xb)
+             pred = torch.sigmoid(pred)
+             preds.extend(pred.cpu().numpy())
+             targets.extend(yb.cpu().numpy())
+
+     preds, targets = np.array(preds), np.array(targets)
+
+     return preds, targets
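
A minimal, self-contained sketch (not part of the commit) of the gradient-accumulation plus mixed-precision pattern used in `Learner._train_epoch`, reduced to a toy linear model and dummy batches; the dtype choice and `num_accum` value here are assumptions for illustration only:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 11).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.BCEWithLogitsLoss()
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
num_accum = 4  # step the optimizer once every num_accum batches
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

batches = [(torch.randn(8, 16), torch.randint(0, 2, (8, 11)).float()) for _ in range(8)]

for idx, (xb, yb) in enumerate(batches):
    xb, yb = xb.to(device), yb.to(device)
    with torch.autocast(device_type=device, dtype=amp_dtype):
        loss = loss_fn(model(xb), yb) / num_accum  # scale so accumulated grads average out
    scaler.scale(loss).backward()
    if (idx + 1) % num_accum == 0 or (idx + 1) == len(batches):
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```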
src/modeling/loss.py ADDED
@@ -0,0 +1,96 @@
+ from functools import partial
+ from typing import Union
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torchvision.ops import sigmoid_focal_loss
+
+
+ class FocalLoss(nn.Module):
+     """
+     Focal Loss implementation.
+
+     This class defines the Focal Loss, which is a variant of the Binary Cross Entropy (BCE) loss that is
+     designed to address the problem of class imbalance in binary classification tasks.
+     The Focal Loss introduces two hyperparameters, alpha and gamma, to control the balance between easy
+     and hard examples during training.
+
+     :param alpha: The balancing parameter between positive and negative examples. A float value between 0 and 1.
+         If set to -1, no balancing is applied. Default is 0.25.
+     :type alpha: float
+     :param gamma: The focusing parameter to control the emphasis on hard examples. A positive integer. Default is 2.
+     :type gamma: int
+     """
+
+     def __init__(self, alpha: float = 0.25, gamma: int = 2):
+         super().__init__()
+         self.loss_fn = partial(sigmoid_focal_loss, alpha=alpha, gamma=gamma, reduction="mean")
+
+     def forward(self, inputs, targets):
+         """
+         Compute the Focal Loss.
+
+         :param inputs: The predicted inputs from the model.
+         :type inputs: torch.Tensor
+         :param targets: The ground truth targets.
+         :type targets: torch.Tensor
+         :return: The computed Focal Loss.
+         :rtype: torch.Tensor
+         :raises ValueError: If the inputs and targets have different shapes.
+         """
+
+         return self.loss_fn(inputs=inputs, targets=targets)
+
+
+ class HardDistillationLoss(nn.Module):
+     """Hard Distillation Loss implementation.
+
+     This class defines the Hard Distillation Loss, which is used for model distillation,
+     a technique used to transfer knowledge from a large, complex teacher model to a smaller,
+     simpler student model. The Hard Distillation Loss computes the loss by comparing the outputs
+     of the student model and the teacher model using a provided loss function. It also introduces a
+     threshold parameter to convert the teacher model outputs to binary labels for the distillation process.
+
+     :param teacher: The teacher model used for distillation.
+     :type teacher: torch.nn.Module
+     :param loss_fn: The loss function used for computing the distillation loss.
+     :type loss_fn: torch.nn.Module
+     :param threshold: The threshold value used to convert teacher model outputs to binary labels.
+         Can be a list or numpy array of threshold values.
+     :type threshold: Union[list, np.array]
+     :param device: The device to be used for computation. Default is "cuda".
+     :type device: str
+     """
+
+     def __init__(self, teacher: nn.Module, loss_fn: nn.Module, threshold: Union[list, np.array], device: str = "cuda"):
+         super().__init__()
+         self.teacher = teacher
+         self.loss_fn = loss_fn
+         self.threshold = torch.tensor(threshold).to(device)
+
+     def forward(self, inputs, student_outputs, targets):
+         """
+         Compute the Hard Distillation Loss.
+
+         :param inputs: The input data fed to the student model.
+         :type inputs: torch.Tensor
+         :param student_outputs: The output predictions from the student model, which consists of
+             both classification and distillation outputs.
+         :type student_outputs: tuple
+         :param targets: The ground truth targets.
+         :type targets: torch.Tensor
+         :return: The computed Hard Distillation Loss.
+         :rtype: torch.Tensor
+         :raises ValueError: If the inputs and targets have different shapes.
+         """
+
+         outputs_cls, outputs_dist = student_outputs
+
+         teacher_outputs = torch.sigmoid(self.teacher(inputs))
+         teacher_labels = (teacher_outputs > self.threshold).float()
+
+         base_loss = self.loss_fn(outputs_cls, targets)
+         teacher_loss = self.loss_fn(outputs_dist, teacher_labels)
+
+         return (base_loss + teacher_loss) / 2
+
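
A minimal sketch (not part of the commit) of the hard-distillation computation with tiny dummy tensors on CPU; the stand-in teacher, thresholds, and shapes are assumptions made only so the arithmetic can be run in isolation:

```python
import torch
import torch.nn as nn

n_classes = 11
teacher = nn.Linear(8, n_classes)            # stand-in for the frozen teacher model
loss_fn = nn.BCEWithLogitsLoss()
thresholds = torch.full((n_classes,), 0.5)   # stand-in for the tuned per-class thresholds

inputs = torch.randn(4, 8)
targets = torch.randint(0, 2, (4, n_classes)).float()
outputs_cls = torch.randn(4, n_classes)      # student classification head
outputs_dist = torch.randn(4, n_classes)     # student distillation head

# Teacher probabilities are binarised with the per-class thresholds...
teacher_labels = (torch.sigmoid(teacher(inputs)) > thresholds).float()

# ...and the final loss averages the ground-truth term and the teacher term.
loss = (loss_fn(outputs_cls, targets) + loss_fn(outputs_dist, teacher_labels)) / 2
print(loss.item())
```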
src/modeling/metrics.py ADDED
@@ -0,0 +1,179 @@
+ import numpy as np
+ from sklearn.metrics import (
+     accuracy_score,
+     average_precision_score,
+     f1_score,
+     hamming_loss,
+     precision_recall_curve,
+     zero_one_loss,
+ )
+
+
+ def hamming_score(preds, targets, thresholds: np.array = None):
+     """Compute Hamming Score.
+
+     This function computes the Hamming Score, a performance metric used for multi-label classification tasks.
+     The Hamming Score measures the similarity between the predicted labels and the ground truth labels, where
+     a higher score indicates better prediction accuracy.
+
+     :param preds: The predicted labels.
+     :type preds: numpy array
+     :param targets: The ground truth labels.
+     :type targets: numpy array
+     :return: The computed Hamming Score.
+     :rtype: float
+     """
+     if thresholds is None:
+         thresholds = optimize_accuracy(preds, targets)
+
+     preds = (preds > thresholds).astype(int)
+     return 1 - hamming_loss(targets, preds)
+
+
+ def zero_one_score(preds, targets, thresholds: np.array = None):
+     """
+     Compute Zero-One Score.
+
+     This function computes the Zero-One Score, a performance metric used for
+     multi-label classification tasks. The Zero-One Score measures the similarity
+     between the predicted labels and the ground truth labels, where a higher score
+     indicates better prediction accuracy. The Zero-One Score ranges from 0 to 1, with 1 being a perfect match.
+
+     :param preds: The predicted labels.
+     :type preds: numpy array
+     :param targets: The ground truth labels.
+     :type targets: numpy array
+     :return: The computed Zero-One Score.
+     :rtype: float
+     """
+
+     if thresholds is None:
+         thresholds = optimize_accuracy(preds, targets)
+
+     preds = (preds > thresholds).astype(int)
+     return 1 - zero_one_loss(targets, preds, normalize=True)
+
+
+ def mean_f1_score(preds, targets, thresholds: np.array = None):
+     """Compute Mean F1 Score.
+
+     This function computes the Mean F1 Score, a performance metric used for multi-label
+     classification tasks. The Mean F1 Score measures the trade-off between precision and recall,
+     where a higher score indicates better prediction accuracy. The Mean F1 Score ranges from
+     0 to 1, with 1 being a perfect match.
+
+     :param preds: The predicted labels.
+     :type preds: numpy array
+     :param targets: The ground truth labels.
+     :type targets: numpy array
+     :return: The computed Mean F1 Score.
+     :rtype: float
+     """
+     if thresholds is None:
+         thresholds = optimize_f1_score(preds, targets)
+
+     preds = (preds > thresholds).astype(int)
+     return f1_score(targets, preds, average="samples", zero_division=0)
+
+
+ def per_instr_f1_score(preds, targets, thresholds: np.array = None):
+     """Compute Per-Instrument F1 Score.
+
+     This function computes the F1 Score for each instrument separately in a multi-label
+     classification task. The Per-Instrument F1 Score measures the prediction accuracy for
+     each instrument class independently. The F1 Score is the harmonic mean of precision and recall,
+     where a higher score indicates better prediction accuracy. The Per-Instrument F1 Score ranges
+     from 0 to 1, with 1 being a perfect match.
+
+     :param preds: The predicted labels.
+     :type preds: numpy array
+     :param targets: The ground truth labels.
+     :type targets: numpy array
+     :return: The computed Per-Instrument F1 Score.
+     :rtype: numpy array
+     """
+
+     if thresholds is None:
+         thresholds = optimize_f1_score(preds, targets)
+
+     preds = (preds > thresholds).astype(int)
+     return f1_score(targets, preds, average=None, zero_division=0)
+
+
+ def mean_average_precision(preds, targets):
+     """
+     Compute mean Average Precision (mAP).
+
+     This function computes the mean Average Precision (mAP), a performance metric used
+     for multi-label classification tasks. The mAP measures the average precision across
+     all classes, taking into account the precision-recall trade-off, where a higher score
+     indicates better prediction accuracy.
+
+     :param preds: The predicted probabilities or scores.
+     :type preds: numpy array
+     :param targets: The ground truth labels.
+     :type targets: numpy array
+     :return: The computed mAP score.
+     :rtype: float
+     """
+
+     return average_precision_score(targets, preds, average="samples")
+
+
+ def optimize_f1_score(preds, targets):
+     """
+     Optimize Threshold.
+
+     This function optimizes the threshold for binary classification based on the predicted probabilities
+     and ground truth labels. It computes the precision, recall, and F1 Score for each class separately
+     using the precision_recall_curve function from sklearn.metrics module. It then selects the threshold
+     that maximizes the F1 Score for each class.
+
+     :param preds: The predicted probabilities.
+     :type preds: numpy array
+     :param targets: The ground truth labels.
+     :type targets: numpy array
+     :return: The optimized thresholds for binary classification.
+     :rtype: numpy array
+     """
+
+     label_thresholds = np.empty(preds.shape[1])
+
+     for i in range(preds.shape[1]):
+         precision, recall, thresholds = precision_recall_curve(targets[:, i], preds[:, i])
+         fscore = (2 * precision * recall) / (precision + recall)
+         ix = np.argmax(fscore)
+         best_thresh = thresholds[ix]
+         label_thresholds[i] = best_thresh
+
+     return label_thresholds
+
+
+ def optimize_accuracy(preds, targets):
+     """
+     Determine the optimal threshold for each label, based on the predicted probabilities and the true targets,
+     in order to maximize the accuracy of the predictions.
+
+     :param preds: A 2D NumPy array containing the predicted probabilities for each label.
+     :type preds: numpy.ndarray
+     :param targets: A 2D NumPy array containing the true binary targets for each label.
+     :type targets: numpy.ndarray
+     :raises ValueError: If the input arrays are not 2D arrays or have incompatible shapes.
+     :return: A 1D NumPy array containing the optimal threshold for each label.
+     :rtype: numpy.ndarray
+     """
+
+     # Vary the threshold for each label and calculate accuracy for each threshold
+     thresholds = np.arange(0.0001, 1, 0.0001)
+     best_thresholds = np.empty(preds.shape[1])
+     for i in range(preds.shape[1]):
+         accuracies = []
+         for th in thresholds:
+             y_pred = (preds[:, i] >= th).astype(int)  # Convert probabilities to binary predictions using the threshold
+             acc = accuracy_score(targets[:, i], y_pred)
+             accuracies.append(acc)
+         # Find the threshold that gives the highest accuracy for this label
+         best_idx = np.argmax(accuracies)
+         best_thresholds[i] = thresholds[best_idx]
+
+     return best_thresholds
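
A minimal sketch (not part of the commit) of the per-class F1 threshold search that `optimize_f1_score` performs, shown for a single label on synthetic scores; the small epsilon and the index guard on the last precision/recall point are additions for numerical safety in this illustration:

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

rng = np.random.default_rng(0)
targets = rng.integers(0, 2, size=200)                           # ground truth for one instrument
preds = np.clip(targets * 0.6 + rng.random(200) * 0.5, 0, 1)     # noisy scores

precision, recall, thresholds = precision_recall_curve(targets, preds)
fscore = (2 * precision * recall) / (precision + recall + 1e-12)  # epsilon avoids 0/0
best_thresh = thresholds[np.argmax(fscore[:-1])]                  # last P/R point has no threshold
print(f"best threshold: {best_thresh:.3f}, best F1: {fscore[:-1].max():.3f}")
```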
src/modeling/models.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from warnings import warn
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers import ASTConfig, ASTModel
7
+
8
+
9
+ class StudentAST(nn.Module):
10
+ """
11
+ A student model for audio classification using the AST architecture.
12
+
13
+ :param n_classes: The number of classes to classify.
14
+ :type n_classes: int
15
+ :param hidden_size: The number of units in the hidden layers, defaults to 384.
16
+ :type hidden_size: int, optional
17
+ :param num_heads: The number of attention heads to use, defaults to 6.
18
+ :type num_heads: int, optional
19
+ """
20
+
21
+ def __init__(self, n_classes: int, hidden_size: int = 384, num_heads: int = 6):
22
+ super().__init__()
23
+
24
+ config = ASTConfig(hidden_size=hidden_size, num_attention_heads=num_heads, intermediate_size=hidden_size * 4)
25
+ self.base_model = ASTModel(config=config)
26
+ self.classifier = StudentClassificationHead(hidden_size, n_classes)
27
+
28
+ def forward(self, x: torch.Tensor):
29
+ """
30
+ Forward pass of the student model.
31
+
32
+ :param x: The input tensor of shape [batch_size, sequence_length, input_dim].
33
+ :type x: torch.Tensor
34
+ :return: The output tensor of shape [batch_size, n_classes].
35
+ :rtype: torch.Tensor
36
+ """
37
+
38
+ x = self.base_model(x)[0]
39
+ x = self.classifier(x)
40
+ return x
41
+
42
+
43
+ class StudentClassificationHead(nn.Module):
44
+ """
45
+ A classification head for the student model.
46
+
47
+ :param emb_size: The size of the embedding.
48
+ :type emb_size: int
49
+ :param n_classes: The number of classes to classify.
50
+ :type n_classes: int
51
+ """
52
+
53
+ def __init__(self, emb_size: int, n_classes: int):
54
+ super().__init__()
55
+
56
+ self.cls_head = nn.Linear(emb_size, n_classes)
57
+ self.dist_head = nn.Linear(emb_size, n_classes)
58
+
59
+ def forward(self, x: torch.Tensor):
60
+ """
61
+ Forward pass of the classification head.
62
+
63
+ :param x: The input tensor of shape [batch_size, emb_size*2].
64
+ :type x: torch.Tensor
65
+ :return: The output tensor of shape [batch_size, n_classes].
66
+ :rtype: torch.Tensor
67
+ """
68
+
69
+ x_cls, x_dist = x[:, 0], x[:, 1]
70
+ x_cls_head = self.cls_head(x_cls)
71
+ x_dist_head = self.dist_head(x_dist)
72
+
73
+ if self.training:
74
+ x = x_cls_head, x_dist_head
75
+ else:
76
+ x = (x_cls_head + x_dist_head) / 2
77
+
78
+ return x
79
+
80
+
81
+ class ASTPretrained(nn.Module):
82
+ """
83
+ This class implements a PyTorch module wrapping a pre-trained Audio Spectrogram Transformer (AST),
84
+ using MIT's checkpoint fine-tuned on AudioSet, for audio event classification.
85
+
86
+ :param n_classes: The number of classes for audio event classification.
87
+ :type n_classes: int
88
+ :param dropout: The dropout probability for the fully connected layer, defaults to 0.5.
89
+ :type dropout: float, optional
90
+ :raises ValueError: If n_classes is not positive.
91
+ :raises TypeError: If dropout is not a float or is not between 0 and 1.
92
+ :return: The output tensor of shape [batch_size, n_classes] containing the logits for each class.
93
+ :rtype: torch.Tensor
94
+ """
95
+
96
+ def __init__(self, n_classes: int, download_weights: bool = True, freeze_body: bool = False, dropout: float = 0.5):
97
+ super().__init__()
98
+
99
+ if download_weights:
100
+ self.base_model = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
101
+ else:
102
+ config = ASTConfig()
103
+ self.base_model = ASTModel(config=config)
104
+
105
+ if freeze_body:
106
+ self.base_model = freeze(self.base_model)
107
+
108
+ fc_in = self.base_model.config.hidden_size
109
+
110
+ self.classifier = nn.Sequential(
111
+ nn.LayerNorm((fc_in,), eps=1e-12), nn.Dropout(p=dropout), nn.Linear(fc_in, n_classes)
112
+ )
113
+
114
+ def forward(self, x):
115
+ """Passes the input tensor through the pre-trained Audio Set Transformer (AST) model
116
+ followed by a fully connected layer.
117
+
118
+ :param x: The input tensor of shape [batch_size, seq_len, num_features].
119
+ :type x: torch.Tensor
120
+ :return: The output tensor of shape [batch_size, n_classes] containing the logits for each class.
121
+ :rtype: torch.Tensor
122
+ :raises ValueError: If the shape of x is not [batch_size, seq_len, num_features].
123
+ """
124
+
125
+ x = self.base_model(x)[1]
126
+ x = self.classifier(x)
127
+ return x
128
+
129
+
130
+ def layerwise_lr_decay(config, model: ASTModel):
131
+ """
132
+ LLRD (Layer-wise Learning Rate Decay) function computes the learning rate for each layer in a deep neural network
133
+ using a specific decay rate and a base learning rate for the optimizer.
134
+
135
+ :param config: A configuration object that contains the parameters required for LLRD.
136
+ :type config: Any
137
+ :param model: A PyTorch neural network model.
138
+ :type model: ASTModel
139
+
140
+ :raises Warning: If the configuration object does not contain the LLRD parameters.
141
+
142
+ :return: A list of optimizer parameter groups (params, weight decay, and learning rate),
143
+ with decay and no-decay groups for each layer, or None if no LLRD config is found.
144
+ :rtype: list or None
145
+ """
146
+
147
+ try:
148
+ config = config.LLRD
149
+ except Exception:
150
+ warn("No LLRD found in config. Learner will use single lr for whole model.")
151
+ return None
152
+
153
+ lr = config["base_lr"]
154
+ weight_decay = config["weight_decay"]
155
+ no_decay = ["bias", "layernorm"]
156
+ body = ["embeddings", "encoder.layer"]
157
+ head_params = [(n, p) for n, p in model.named_parameters() if not any(body_param in n for body_param in body)]
158
+ optimizer_grouped_parameters = [
159
+ {
160
+ "params": [p for n, p in head_params if not any(nd in n for nd in no_decay)],
161
+ "weight_decay": weight_decay,
162
+ "lr": lr,
163
+ },
164
+ {
165
+ "params": [p for n, p in head_params if any(nd in n for nd in no_decay)],
166
+ "weight_decay": 0.0,
167
+ "lr": lr,
168
+ },
169
+ ]
170
+
171
+ # initialize lrs for every layer
172
+ layers = [getattr(model.module, config["body"]).embeddings] + list(
173
+ getattr(model.module, config["body"]).encoder.layer
174
+ )
175
+ layers.reverse()
176
+ for layer in layers:
177
+ lr *= config["lr_decay_rate"]
178
+ optimizer_grouped_parameters += [
179
+ {
180
+ "params": [p for n, p in layer.named_parameters() if not any(nd in n for nd in no_decay)],
181
+ "weight_decay": weight_decay,
182
+ "lr": lr,
183
+ },
184
+ {
185
+ "params": [p for n, p in layer.named_parameters() if any(nd in n for nd in no_decay)],
186
+ "weight_decay": 0.0,
187
+ "lr": lr,
188
+ },
189
+ ]
190
+
191
+ return optimizer_grouped_parameters
192
+
193
+
194
+ def freeze(model: nn.Module):
195
+ """
196
+ Freeze function sets the requires_grad attribute to False for all parameters
197
+ in the given PyTorch neural network model. This is used to freeze the weights of
198
+ the model during training or inference.
199
+
200
+ :param model: A PyTorch neural network model.
201
+ :type model: nn.Module
202
+
203
+ :return: The same model with requires_grad attribute set to False for all parameters.
204
+ :rtype: nn.Module
205
+ """
206
+
207
+ model.eval()
208
+ for param in model.parameters():
209
+ param.requires_grad = False
210
+
211
+ return model
212
+
213
+
214
+ def unfreeze(model: nn.Module):
215
+ """
216
+ Unfreeze the model by setting requires_grad to True for all parameters.
217
+
218
+ :param model: The model to unfreeze.
219
+ :type model: nn.Module
220
+ :return: The unfrozen model.
221
+ :rtype: nn.Module
222
+ """
223
+
224
+ model.train()
225
+ for param in model.parameters():
226
+ param.requires_grad = True
227
+
228
+ return model
229
+
230
+
231
+ def interpolate_params(student: nn.Module, teacher: nn.Module):
232
+ """
233
+ Interpolate parameters between two models. This function scales the parameters of the
234
+ teacher model to match the shape of the corresponding parameters in the student model
235
+ using bilinear interpolation. If the shapes of the parameters in the two models are already the same,
236
+ the parameters are unchanged.
237
+
238
+ :param student: The student model.
239
+ :type student: nn.Module
240
+ :param teacher: The teacher model.
241
+ :type teacher: nn.Module
242
+ :return: A dictionary of interpolated parameters for the student model.
243
+ :rtype: dict
244
+ """
245
+
246
+ new_params = {}
247
+
248
+ # Iterate over the parameters in the first model
249
+ for name, param in teacher.base_model.named_parameters():
250
+ # Scale the parameter using interpolate if its shape is different from that of the second model
251
+ target_param = student.base_model.state_dict()[name]
252
+ if param.shape != target_param.shape:
253
+ squeeze_count = 0
254
+ permuted = False
255
+ while param.ndim < 4:
256
+ param = param.unsqueeze(0)
257
+ squeeze_count += 1
258
+
259
+ if param.shape[0] > 1:
260
+ param = param.permute(1, 2, 3, 0)
261
+ target_param = target_param.permute(1, 2, 3, 0)
262
+ permuted = True
263
+
264
+ if target_param.ndim < 2:
265
+ target_param = target_param.unsqueeze(0)
266
+
267
+ scaled_param = F.interpolate(param, size=(target_param.shape[-2:]), mode="bilinear")
268
+
269
+ while squeeze_count > 0:
270
+ scaled_param = scaled_param.squeeze(0)
271
+ squeeze_count -= 1
272
+
273
+ if permuted:
274
+ scaled_param = scaled_param.permute(-1, 0, 1, 2)
275
+
276
+ else:
277
+ scaled_param = param
278
+ new_params[name] = scaled_param
279
+
280
+ return new_params
281
+
282
+
283
+ def average_model_weights(model_weights_list):
284
+ """
285
+ Compute the average weights of a list of PyTorch models.
286
+
287
+ :param model_weights_list: A list of file paths to PyTorch model weight files.
288
+ :type model_weights_list: List[str]
289
+ :raises ValueError: If the input list is empty.
290
+ :return: A dictionary containing the average weights of the models.
291
+ :rtype: Dict[str, torch.Tensor]
292
+ """
293
+
294
+ if not model_weights_list:
295
+ raise ValueError("The input list cannot be empty.")
296
+
297
+ num_models = len(model_weights_list)
298
+ averaged_weights = {}
299
+
300
+ # Load the first model weights
301
+ state_dict = torch.load(model_weights_list[0])
302
+
303
+ # Iterate through the remaining models and add their weights to the first model's weights
304
+ for i in range(1, num_models):
305
+ state_dict_i = torch.load(model_weights_list[i])
306
+ for key in state_dict.keys():
307
+ state_dict[key] += state_dict_i[key]
308
+
309
+ # Compute the average of the weights
310
+ for key in state_dict.keys():
311
+ averaged_weights[key] = state_dict[key] / num_models
312
+
313
+ return averaged_weights
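
A short, hedged sketch of how the models above can be exercised. The src.modeling.models import path, the 11-class setup and the (1024, 128) spectrogram shape are assumptions based on this repository's layout and the AST defaults, not values confirmed by this diff.

import torch

from src.modeling.models import ASTPretrained, StudentAST   # import path assumed

teacher = ASTPretrained(n_classes=11, download_weights=False)      # random weights, nothing downloaded
student = StudentAST(n_classes=11, hidden_size=192, num_heads=3)

dummy = torch.randn(2, 1024, 128)   # [batch, time frames, mel bins] as expected by ASTModel

teacher.eval()
with torch.no_grad():
    print(teacher(dummy).shape)      # torch.Size([2, 11]) logits

student.train()
cls_logits, dist_logits = student(dummy)    # two heads (classification + distillation) during training

student.eval()
with torch.no_grad():
    print(student(dummy).shape)      # heads averaged at inference: torch.Size([2, 11])
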
src/modeling/preprocess.py ADDED
@@ -0,0 +1,336 @@
1
+ import itertools
2
+ import os
3
+ from pathlib import Path
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import librosa
7
+ import numpy as np
8
+ import pandas as pd
9
+ import soundfile as sf
10
+ from joblib import Parallel, delayed
11
+ from sklearn.model_selection import StratifiedGroupKFold
12
+ from tqdm.autonotebook import tqdm
13
+
14
+ from src.modeling.transforms import LabelsFromTxt, ParentMultilabel
15
+ from src.modeling.utils import get_file_info, sync_bpm, sync_onset, sync_pitch
16
+
17
+
18
+ def generate_metadata(
19
+ data_dir: Union[str, Path],
20
+ save_path: str = ".",
21
+ subset: str = "train",
22
+ extract_music_features: bool = False,
23
+ n_jobs: int = -2,
24
+ ):
25
+ """
26
+ Generate metadata CSV file containing information about audio files in a directory.
27
+
28
+ :param data_dir: Directory containing audio files.
29
+ :type data_dir: Union[str, Path]
30
+ :param save_path: Directory path to save metadata CSV file.
31
+ :type save_path: str
32
+ :param subset: Subset of the dataset (train or test), defaults to 'train'.
33
+ :type subset: str
34
+ :param extract_music_features: Flag to indicate whether to extract music features or not, defaults to False.
35
+ :type extract_music_features: bool
36
+ :param n_jobs: Number of parallel jobs to run, defaults to -2.
37
+ :type n_jobs: int
38
+ :raises FileNotFoundError: If the provided data directory does not exist.
39
+ :return: DataFrame containing the metadata information.
40
+ :rtype: pandas.DataFrame
41
+ """
42
+
43
+ data_dir = Path(data_dir) if isinstance(data_dir, str) else data_dir
44
+
45
+ if subset == "train":
46
+ pattern = r"(.*)__[\d]+$"
47
+ label_extractor = ParentMultilabel()
48
+ else:
49
+ pattern = r"(.*)-[\d]+$"
50
+ label_extractor = LabelsFromTxt()
51
+
52
+ sound_files = list(data_dir.glob("**/*.wav"))
53
+ output = Parallel(n_jobs=n_jobs)(delayed(get_file_info)(path, extract_music_features) for path in tqdm(sound_files))
54
+
55
+ df = pd.DataFrame(data=output)
56
+
57
+ df["fname"] = df.path.map(lambda x: Path(x).stem)
58
+ df["song_name"] = df.fname.str.extract(pattern)
59
+ df["inst"] = df.path.map(lambda x: "-".join(sorted(list(label_extractor(x)))))
60
+ df["label_count"] = df.inst.map(lambda x: len(x.split("-")))
61
+
62
+ df.to_csv(f"{save_path}/metadata_{subset}.csv", index=False)
63
+
64
+ return df
65
+
66
+
67
+ def create_test_split(metadata_path: str, txt_save_path: str, random_state: Optional[int] = None):
68
+ """Create test split by generating a list of test songs and saving them to a text file.
69
+
70
+ :param metadata_path: Path to the CSV file containing metadata of all songs
71
+ :type metadata_path: str
72
+ :param txt_save_path: Path to the directory where the text file containing test songs will be saved
73
+ :type txt_save_path: str
74
+ :param random_state: Seed value for the random number generator, defaults to None
75
+ :type random_state: int, optional
76
+ :raises TypeError: If metadata_path or txt_save_path is not a string or if random_state is not an integer or None
77
+ :raises FileNotFoundError: If metadata_path does not exist
78
+ :raises PermissionError: If the program does not have permission to write to txt_save_path
79
+ :return: None
80
+ :rtype: None
81
+ """
82
+
83
+ df = pd.read_csv(metadata_path)
84
+ kf = StratifiedGroupKFold(n_splits=2, shuffle=True, random_state=random_state)
85
+ splits = kf.split(df.fname, df.inst, groups=df.song_name)
86
+ _, test = list(splits)[0]
87
+
88
+ test_songs = df.iloc[test].fname.sort_values().to_numpy()
89
+
90
+ with open(f"{txt_save_path}/test_songs.txt", "w") as f:
91
+ # iterate over the list of names and write each one to a new line in the file
92
+ for song in test_songs:
93
+ f.write(song + "\n")
94
+
95
+
96
+ class IRMASPreprocessor:
97
+ """
98
+ A class to preprocess IRMAS dataset metadata and create a mapping between
99
+ file paths and their corresponding instrument labels.
100
+
101
+ :param metadata: A pandas DataFrame or path to csv file containing metadata, defaults to None
102
+ :type metadata: Union[pd.DataFrame, str], optional
103
+ :param data_dir: Path to the directory containing the IRMAS dataset, defaults to None
104
+ :type data_dir: Union[str, Path], optional
105
+ :param sample_rate: Sample rate of the audio files, defaults to 16000
106
+ :type sample_rate: int, optional
107
+
108
+ :raises AssertionError: Raised when metadata is None and data_dir is also None.
109
+
110
+ :return: An instance of IRMASPreprocessor
111
+ :rtype: IRMASPreprocessor
112
+ """
113
+
114
+ def __init__(
115
+ self, metadata: Union[pd.DataFrame, str] = None, data_dir: Union[str, Path] = None, sample_rate: int = 16000
116
+ ):
117
+ if metadata is not None:
118
+ self.metadata = pd.read_csv(metadata) if isinstance(metadata, str) else metadata
119
+ if data_dir is not None:
120
+ self.metadata["path"] = self.metadata.apply(lambda x: f"{data_dir}/{x.inst}/{x.fname}.wav", axis=1)
121
+ else:
122
+ assert data_dir is not None, "No metadata found. Need to provide data directory"
123
+ self.metadata = generate_metadata(data_dir=data_dir, subset="train", extract_music_features=True)
124
+
125
+ self.instruments = self.metadata.inst.unique()
126
+ self.sample_rate = sample_rate
127
+
128
+ def preprocess_and_mix(self, save_dir: str, sync: str, ordered: bool, num_track_to_mix: int, n_jobs: int = -2):
129
+ """
130
+ A method to preprocess and mix audio tracks from the IRMAS dataset.
131
+
132
+ :param save_dir: The directory to save the preprocessed and mixed tracks
133
+ :type save_dir: str
134
+ :param sync: The column name used to synchronize the audio tracks during mixing
135
+ :type sync: str
136
+ :param ordered: Whether to order the metadata by the sync column before mixing the tracks
137
+ :type ordered: bool
138
+ :param num_track_to_mix: The number of tracks to mix together
139
+ :type num_track_to_mix: int
140
+ :param n_jobs: The number of parallel jobs to run, defaults to -2
141
+ :type n_jobs: int, optional
142
+
143
+ :raises None
144
+
145
+ :return: None
146
+ :rtype: None
147
+ """
148
+
149
+ combs = itertools.combinations(self.instruments, r=num_track_to_mix)
150
+
151
+ if ordered:
152
+ self.metadata = self.metadata.sort_values(by=sync)
153
+ else:
154
+ self.metadata = self.metadata.sample(frac=1)
155
+
156
+ Parallel(n_jobs=n_jobs)(delayed(self._mix)(insts, save_dir, sync) for (insts) in tqdm(combs))
157
+ print("Parallel preprocessing done!")
158
+
159
+ def _mix(self, insts: Tuple[str], save_dir: str, sync: str):
160
+ """
161
+ A private method to mix audio tracks and save them to disk.
162
+
163
+ :param insts: A tuple of instrument labels to mix
164
+ :type insts: Tuple[str]
165
+ :param save_dir: The directory to save the mixed tracks
166
+ :type save_dir: str
167
+ :param sync: The column name used to synchronize the audio tracks during mixing
168
+ :type sync: str
169
+
170
+ :raises None
171
+
172
+ :return: None
173
+ :rtype: None
174
+ """
175
+
176
+ save_dir = self._create_save_dir(insts, save_dir)
177
+
178
+ insts_files_list = [self._get_filepaths(inst) for inst in insts]
179
+
180
+ max_length = max([inst_files.shape[0] for inst_files in insts_files_list])
181
+ for i, inst_files in enumerate(insts_files_list):
182
+ if inst_files.shape[0] < max_length:
183
+ diff = max_length - inst_files.shape[0]
184
+ inst_files = np.pad(inst_files, (0, diff), mode="symmetric")
185
+ insts_files_list[i] = [Path(x) for x in inst_files]
186
+
187
+ self._mix_files_and_save(insts_files_list, save_dir, sync)
188
+
189
+ def _get_filepaths(self, inst: str):
190
+ """
191
+ A private method to retrieve file paths of audio tracks for a given instrument label.
192
+
193
+ :param inst: The label of the instrument for which to retrieve the file paths
194
+ :type inst: str
195
+
196
+ :raises KeyError: Raised when the instrument label is not found in the metadata.
197
+
198
+ :return: A numpy array of file paths corresponding to the instrument label.
199
+ :rtype: numpy.ndarray
200
+ """
201
+
202
+ metadata = self.metadata.loc[self.metadata.inst == inst]
203
+
204
+ if metadata.empty:
205
+ raise KeyError("Instrument not found. Please regenerate metadata!")
206
+
207
+ files = metadata.path.to_numpy()
208
+
209
+ return files
210
+
211
+ def _mix_files_and_save(self, insts_files_list: List[List[Path]], save_dir: str, sync: str):
212
+ """
213
+ A private method to mix audio files, synchronize them using a given column name in the metadata,
214
+ and save the mixed file to disk.
215
+
216
+ :param insts_files_list: A list of lists of file paths corresponding to each instrument label
217
+ :type insts_files_list: List[List[Path]]
218
+ :param save_dir: The directory to save the mixed tracks
219
+ :type save_dir: str
220
+ :param sync: The column name used to synchronize the audio tracks during mixing
221
+ :type sync: str
222
+
223
+ :raises None
224
+
225
+ :return: None
226
+ :rtype: None
227
+ """
228
+
229
+ for i in range(len(insts_files_list[0])):
230
+ files_to_sync = [inst_files[i] for inst_files in insts_files_list]
231
+ new_name = f"{'-'.join([file.stem for file in files_to_sync])}.wav"
232
+
233
+ synced_file = self._sync_and_mix(files_to_sync, sync)
234
+ sf.write(os.path.join(save_dir, new_name), synced_file, samplerate=self.sample_rate)
235
+
236
+ def _sync_and_mix(self, files_to_sync: List[Path], sync: str):
237
+ """
238
+ Synchronize and mix audio files.
239
+
240
+ :param files_to_sync: A list of file paths to synchronize and mix.
241
+ :type files_to_sync: List[Path]
242
+ :param sync: The type of synchronization to use. One of ['bpm', 'pitch', None].
243
+ :type sync: str, optional
244
+ :raises KeyError: If any file in files_to_sync is not found in metadata.
245
+ :return: The synchronized and mixed audio signal.
246
+ :rtype: numpy.ndarray
247
+ """
248
+
249
+ cols = ["pitch", "bpm", "onset"]
250
+ files_metadata_df = self.metadata.loc[
251
+ self.metadata.path.isin([str(file_path) for file_path in files_to_sync])
252
+ ].set_index("path")
253
+
254
+ num_files = files_metadata_df.shape[0]
255
+ if num_files != len(files_to_sync):
256
+ raise KeyError("File not found in metadata. Please regenerate")
257
+
258
+ if sync is not None:
259
+ mean_features = files_metadata_df[cols].mean().to_dict()
260
+
261
+ metadata_dict = files_metadata_df.to_dict("index")
262
+
263
+ for i, (file_to_sync_path, features) in enumerate(metadata_dict.items()):
264
+ file_to_sync, sr_sync = librosa.load(file_to_sync_path, sr=None)
265
+
266
+ if sr_sync != 44100:
267
+ file_to_sync = librosa.resample(y=file_to_sync, orig_sr=sr_sync, target_sr=44100)  # bring odd rates to the common 44.1 kHz before mixing
268
+
269
+ if sync == "bpm":
270
+ file_to_sync = sync_bpm(file_to_sync, sr_sync, bpm_base=mean_features["bpm"], bpm=features["bpm"])
271
+
272
+ if sync == "pitch":
273
+ file_to_sync = sync_pitch(
274
+ file_to_sync, sr_sync, pitch_base=mean_features["pitch"], pitch=features["pitch"]
275
+ )
276
+
277
+ if sync is not None:
278
+ file_to_sync = sync_onset(
279
+ file_to_sync, sr_sync, onset_base=mean_features["onset"], onset=features["onset"]
280
+ )
281
+
282
+ file_to_sync = librosa.util.normalize(file_to_sync)
283
+
284
+ if i == 0:
285
+ mixed_sound = np.zeros_like(file_to_sync)
286
+
287
+ if mixed_sound.shape[0] > file_to_sync.shape[0]:
288
+ file_to_sync = np.resize(file_to_sync, mixed_sound.shape)
289
+ else:
290
+ mixed_sound = np.resize(mixed_sound, file_to_sync.shape)
291
+
292
+ mixed_sound += file_to_sync
293
+
294
+ mixed_sound /= num_files
295
+
296
+ return librosa.resample(y=mixed_sound, orig_sr=44100, target_sr=self.sample_rate)
297
+
298
+ def _create_save_dir(self, insts: Union[Tuple[str], List[str]], save_dir: str):
299
+ """
300
+ Create and return a directory to save instrument-specific files.
301
+
302
+ :param insts: A tuple or list of instrument names.
303
+ :type insts: Union[Tuple[str], List[str]]
304
+ :param save_dir: The path to the directory where the new directory will be created.
305
+ :type save_dir: str
306
+ :return: The path to the newly created directory.
307
+ :rtype: str
308
+ """
309
+
310
+ new_dir_name = "-".join(insts)
311
+ new_dir_path = os.path.join(save_dir, new_dir_name)
312
+ os.makedirs(new_dir_path, exist_ok=True)
313
+ return new_dir_path
314
+
315
+ @classmethod
316
+ def from_metadata(cls, metadata_path: str, **kwargs):
317
+ """
318
+ Create a new instance of the class from a metadata file.
319
+
320
+ :param metadata_path: The path to the metadata file.
321
+ :type metadata_path: str
322
+ :param **kwargs: Additional keyword arguments to pass to the class constructor.
323
+ :return: A new instance of the class.
324
+ :rtype: cls
325
+ """
326
+
327
+ metadata = pd.read_csv(metadata_path)
328
+ return cls(metadata, **kwargs)
329
+
330
+
331
+ if __name__ == "__main__":
332
+ data_dir = "/home/kpintaric/lumen-irmas/data/raw/IRMAS_Training_Data"
333
+ metadata_path = "/home/kpintaric/lumen-irmas/data/metadata_train.csv"
334
+ preprocess = IRMASPreprocessor(metadata=metadata_path, data_dir=data_dir)
335
+ preprocess.preprocess_and_mix(save_dir="data", sync="pitch", ordered=False, num_track_to_mix=3)
336
+ a = 1
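
Beyond the __main__ block above, a hedged end-to-end sketch of the preprocessing utilities; the directory paths are placeholders and the src.modeling.preprocess import path is an assumption.

from src.modeling.preprocess import IRMASPreprocessor, generate_metadata   # import path assumed

# 1) Scan the raw training set and write data/metadata_train.csv, including pitch/bpm/onset features
generate_metadata(
    data_dir="data/raw/IRMAS_Training_Data",   # placeholder location of the IRMAS training data
    save_path="data",
    subset="train",
    extract_music_features=True,               # required later for sync="pitch" or sync="bpm"
)

# 2) Build pitch-synced two-instrument mixes from that metadata
preprocessor = IRMASPreprocessor.from_metadata(
    "data/metadata_train.csv", data_dir="data/raw/IRMAS_Training_Data"
)
preprocessor.preprocess_and_mix(save_dir="data/processed", sync="pitch", ordered=False, num_track_to_mix=2)
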
src/modeling/transforms.py ADDED
@@ -0,0 +1,398 @@
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+ from functools import partial
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torchaudio
9
+ from torchaudio.transforms import FrequencyMasking, TimeMasking
10
+ from torchvision.transforms import Compose
11
+ from transformers import ASTFeatureExtractor
12
+
13
+
14
+ class Transform(ABC):
15
+ """Abstract base class for audio transformations."""
16
+
17
+ @abstractmethod
18
+ def __call__(self):
19
+ """
20
+ Abstract method to apply the transformation.
21
+
22
+ :raises NotImplementedError: If the subclass does not implement this method.
23
+
24
+ """
25
+ pass
26
+
27
+
28
+ class Preprocess(ABC):
29
+ """Abstract base class for preprocessing data.
30
+
31
+ This class defines the interface for preprocessing data. Subclasses must implement the call method.
32
+
33
+ """
34
+
35
+ @abstractmethod
36
+ def __call__(self):
37
+ """Process the data.
38
+
39
+ This method must be implemented by subclasses.
40
+
41
+ :raises NotImplementedError: Subclasses must implement this method.
42
+
43
+ """
44
+ pass
45
+
46
+
47
+ class OneHotEncode(Transform):
48
+ """Transform labels to one-hot encoded tensor.
49
+
50
+ This class is a transform that takes a list of labels and returns a one-hot encoded tensor.
51
+ The labels are converted to a tensor with one-hot encoding using the specified classes.
52
+
53
+ :param c: A list of classes to be used for one-hot encoding.
54
+ :type c: list
55
+ :return: A one-hot encoded tensor.
56
+ :rtype: torch.Tensor
57
+
58
+ """
59
+
60
+ def __init__(self, c: list):
61
+ self.c = c
62
+
63
+ def __call__(self, labels):
64
+ """
65
+ Transform labels to one-hot encoded tensor.
66
+
67
+ :param labels: A list of labels to be encoded.
68
+ :type labels: list
69
+ :return: A one-hot encoded tensor.
70
+ :rtype: torch.Tensor
71
+
72
+ """
73
+
74
+ target = torch.zeros(len(self.c), dtype=torch.float)
75
+ for label in labels:
76
+ idx = self.c.index(label)
77
+ target[idx] = 1
78
+ return target
79
+
80
+
81
+ class ParentMultilabel(Transform):
82
+ """
83
+ A transform that extracts a list of labels from the parent directory name of a file path.
84
+
85
+ :param sep: The separator used to split the parent directory name into labels. Defaults to " ".
86
+ :type sep: str
87
+ """
88
+
89
+ def __init__(self, sep=" "):
90
+ self.sep = sep
91
+
92
+ def __call__(self, path):
93
+ """
94
+ Extract a list of labels from the parent directory name of a file path.
95
+
96
+ :param path: The file path from which to extract labels.
97
+ :type path: str
98
+ :return: A list of labels extracted from the parent directory name of the input file path.
99
+ :rtype: List[str]
100
+ """
101
+
102
+ label = path.split(os.path.sep)[-2].split(self.sep)
103
+ return label
104
+
105
+
106
+ class LabelsFromTxt(Transform):
107
+ """
108
+ Load instrument labels from the .txt annotation file that accompanies an audio file.
109
+
110
+ This class is a transform that reads the ground-truth labels from the .txt file stored
111
+ next to the .wav file (same name, .txt extension), one label per line.
112
+
113
+ :param delimiter: The delimiter passed to numpy.loadtxt when parsing the label file, defaults to None.
114
+ :type delimiter: str, optional
115
+
116
+ """
117
+
118
+ def __init__(self, delimiter=None):
119
+ self.delimiter = delimiter
120
+
121
+ def __call__(self, path):
122
+ """
123
+ Load the labels listed in the .txt file that corresponds to the given .wav path.
124
+
125
+ :param path: The path of the .wav file whose labels should be loaded.
126
+ :type path: str
127
+ :return: An array of label strings read from the annotation file.
128
+ :rtype: numpy.ndarray
129
+
130
+ """
131
+
132
+ path = path.replace("wav", "txt")
133
+ label = np.loadtxt(path, dtype=str, ndmin=1, delimiter=self.delimiter)
134
+ return label
135
+
136
+
137
+ class PreprocessPipeline(Preprocess):
138
+ """A preprocessing pipeline for audio data.
139
+
140
+ This class is a preprocessing pipeline for audio data.
141
+ The pipeline includes resampling to a target sampling rate, mixing down stereo to mono,
142
+ and loading audio from a file.
143
+
144
+ :param target_sr: The target sampling rate to resample to.
145
+ :type target_sr: int
146
+ """
147
+
148
+ def __init__(self, target_sr):
149
+ self.target_sr = target_sr
150
+
151
+ def __call__(self, path):
152
+ """
153
+ Preprocess audio data using a pipeline.
154
+
155
+ :param path: The path to the audio file to load.
156
+ :type path: str
157
+ :return: A NumPy array of preprocessed audio data.
158
+ :rtype: numpy.ndarray
159
+
160
+ """
161
+
162
+ signal, sr = torchaudio.load(path)
163
+ signal = self._resample(signal, sr)
164
+ signal = self._mix_down(signal)
165
+ return signal.numpy()
166
+
167
+ def _mix_down(self, signal):
168
+ """
169
+ Mix down stereo to mono.
170
+
171
+ :param signal: The audio signal to mix down.
172
+ :type signal: torch.Tensor
173
+ :return: The mixed down audio signal.
174
+ :rtype: torch.Tensor
175
+
176
+ """
177
+
178
+ if signal.shape[0] > 1:
179
+ signal = torch.mean(signal, dim=0, keepdim=True)
180
+ return signal
181
+
182
+ def _resample(self, signal, input_sr):
183
+ """
184
+ Resample audio signal to a target sampling rate.
185
+
186
+ :param signal: The audio signal to resample.
187
+ :type signal: torch.Tensor
188
+ :param input_sr: The current sampling rate of the audio signal.
189
+ :type input_sr: int
190
+ :return: The resampled audio signal.
191
+ :rtype: torch.Tensor
192
+
193
+ """
194
+
195
+ if input_sr != self.target_sr:
196
+ resampler = torchaudio.transforms.Resample(input_sr, self.target_sr)
197
+ signal = resampler(signal)
198
+ return signal
199
+
200
+
201
+ class SpecToImage(Transform):
202
+ def __init__(self, mean=None, std=None, eps=1e-6):
203
+ self.mean = mean
204
+ self.std = std
205
+ self.eps = eps
206
+
207
+ def __call__(self, spec):
208
+ spec = torch.stack([spec, spec, spec], dim=-1)
209
+
210
+ mean = torch.mean(spec) if self.mean is None else self.mean
211
+ std = torch.std(spec) if self.std is None else self.std
212
+ spec_norm = (spec - mean) / std
213
+
214
+ spec_min, spec_max = torch.min(spec_norm), torch.max(spec_norm)
215
+ spec_scaled = 255 * (spec_norm - spec_min) / (spec_max - spec_min)
216
+
217
+ return spec_scaled.type(torch.uint8)
218
+
219
+
220
+ class MinMaxScale(Transform):
221
+ def __call__(self, spec):
222
+ spec_min, spec_max = torch.min(spec), torch.max(spec)
223
+
224
+ return (spec - spec_min) / (spec_max - spec_min)
225
+
226
+
227
+ class Normalize(Transform):
228
+ def __init__(self, mean, std):
229
+ self.mean = mean
230
+ self.std = std
231
+
232
+ def __call__(self, spec):
233
+ return (spec - self.mean) / self.std
234
+
235
+
236
+ class FeatureExtractor(Transform):
237
+ """Extract features from audio signal using an AST feature extractor.
238
+
239
+ This class is a transform that extracts features from an audio signal using an AST feature extractor.
240
+ The features are returned as a PyTorch tensor.
241
+
242
+ :param sr: The sampling rate of the audio signal.
243
+ :type sr: int
244
+ """
245
+
246
+ def __init__(self, sr):
247
+ self.transform = partial(ASTFeatureExtractor(), sampling_rate=sr, return_tensors="pt")
248
+
249
+ def __call__(self, signal):
250
+ """
251
+ Extract features from audio signal using an AST feature extractor.
252
+
253
+ :param signal: The audio signal to extract features from.
254
+ :type signal: numpy.ndarray
255
+ :return: A tensor of extracted audio features.
256
+ :rtype: torch.Tensor
257
+
258
+ """
259
+
260
+ return self.transform(signal.squeeze()).input_values.mT
261
+
262
+
263
+ class Preemphasis(Transform):
264
+ """perform preemphasis on the input signal.
265
+ :param signal: The signal to filter.
266
+ :param coeff: The preemphasis coefficient. 0 is none, default 0.97.
267
+ :returns: the filtered signal.
268
+ """
269
+
270
+ def __init__(self, coeff: float = 0.97):
271
+ self.coeff = coeff
272
+
273
+ def __call__(self, signal):
274
+ return torch.cat([signal[:, :1], signal[:, 1:] - self.coeff * signal[:, :-1]], dim=1)
275
+
276
+
277
+ class Spectrogram(Transform):
278
+ def __init__(self, sample_rate, n_mels, hop_length, n_fft):
279
+ self.transform = torchaudio.transforms.MelSpectrogram(
280
+ sample_rate=sample_rate, n_mels=n_mels, hop_length=hop_length, n_fft=n_fft, f_min=20, center=False
281
+ )
282
+
283
+ def __call__(self, signal):
284
+ return self.transform(signal)
285
+
286
+
287
+ class LogTransform(Transform):
288
+ def __call__(self, signal):
289
+ return torch.log(signal + 1e-8)
290
+
291
+
292
+ class PadCutToLength(Transform):
293
+ def __init__(self, max_length):
294
+ self.max_length = max_length
295
+
296
+ def __call__(self, spec):
297
+ seq_len = spec.shape[-1]
298
+
299
+ if seq_len > self.max_length:
300
+ return spec[..., : self.max_length]
301
+ if seq_len < self.max_length:
302
+ return F.pad(spec, (0, self.max_length - seq_len), mode="constant", value=0)
303
+ return spec  # unchanged when seq_len == max_length
304
+
305
+
306
+ class CustomFeatureExtractor(Transform):
307
+ def __init__(self, sample_rate, n_mels, hop_length, n_fft, max_length, mean, std):
308
+ self.extract = Compose(
309
+ [
310
+ Preemphasis(),
311
+ Spectrogram(sample_rate=sample_rate, n_mels=n_mels, hop_length=hop_length, n_fft=n_fft),
312
+ LogTransform(),
313
+ PadCutToLength(max_length=max_length),
314
+ Normalize(mean=mean, std=std),
315
+ ]
316
+ )
317
+
318
+ def __call__(self, x):
319
+ return self.extract(x)
320
+
321
+
322
+ class RepeatAudio(Transform):
323
+ """A transform to repeat audio data.
324
+
325
+ This class is a transform that repeats audio data a random number of times up to a maximum specified value.
326
+
327
+ :param max_repeats: The maximum number of times to repeat the audio data.
328
+ :type max_repeats: int
329
+ """
330
+
331
+ def __init__(self, max_repeats: int = 2):
332
+ self.max_repeats = max_repeats
333
+
334
+ def __call__(self, signal):
335
+ """
336
+ Repeat audio data a random number of times up to a maximum specified value.
337
+
338
+ :param signal: The audio data to repeat.
339
+ :type signal: numpy.ndarray
340
+ :return: The repeated audio data.
341
+ :rtype: numpy.ndarray
342
+
343
+ """
344
+
345
+ num_repeats = torch.randint(1, self.max_repeats + 1, (1,)).item()  # upper bound is exclusive, so +1 allows max_repeats
346
+ return np.tile(signal, reps=num_repeats)
347
+
348
+
349
+ class MaskFrequency(Transform):
350
+ """A transform to mask frequency of a spectrogram.
351
+
352
+ This class is a transform that masks out a random number of consecutive frequencies from a spectrogram.
353
+
354
+ :param max_mask_length: The maximum number of consecutive frequencies to mask out from the spectrogram.
355
+ :type max_mask_length: int
356
+ """
357
+
358
+ def __init__(self, max_mask_length: int = 0):
359
+ self.aug = FrequencyMasking(max_mask_length)
360
+
361
+ def __call__(self, spec):
362
+ """
363
+ Mask out a random number of consecutive frequencies from a spectrogram.
364
+
365
+ :param spec: The input spectrogram.
366
+ :type spec: torch.Tensor
367
+ :return: The spectrogram with masked frequencies.
368
+ :rtype: torch.Tensor
369
+
370
+ """
371
+
372
+ return self.aug(spec)
373
+
374
+
375
+ class MaskTime(Transform):
376
+ """A transform to mask time of a spectrogram.
377
+
378
+ This class is a transform that masks out a random number of consecutive time steps from a spectrogram.
379
+
380
+ :param max_mask_length: The maximum number of consecutive time steps to mask out from the spectrogram.
381
+ :type max_mask_length: int
382
+ """
383
+
384
+ def __init__(self, max_mask_length: int = 0):
385
+ self.aug = TimeMasking(max_mask_length)
386
+
387
+ def __call__(self, spec):
388
+ """
389
+ Mask out a random number of consecutive time steps from a spectrogram.
390
+
391
+ :param spec: The input spectrogram.
392
+ :type spec: torch.Tensor
393
+ :return: The spectrogram with masked time steps.
394
+ :rtype: torch.Tensor
395
+
396
+ """
397
+
398
+ return self.aug(spec)
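
A self-contained sketch chaining the transforms above on a synthetic clip; the 16 kHz target rate, the augmentation sizes and the import path are illustrative assumptions rather than settings taken from this diff.

import tempfile
from pathlib import Path

import torch
import torchaudio

from src.modeling.transforms import (   # import path assumed
    FeatureExtractor, MaskFrequency, MaskTime, OneHotEncode, PreprocessPipeline, RepeatAudio,
)

CLASSES = ["tru", "sax", "vio", "gac", "org", "cla", "flu", "voi", "gel", "cel", "pia"]

# Write a 3-second 440 Hz tone to a temporary wav file that stands in for an IRMAS clip
sr = 44100
t = torch.arange(0, 3 * sr) / sr
wav_path = Path(tempfile.mkdtemp()) / "demo.wav"
torchaudio.save(str(wav_path), torch.sin(2 * torch.pi * 440 * t).unsqueeze(0), sr)

signal = PreprocessPipeline(target_sr=16000)(str(wav_path))   # numpy array, mono, 16 kHz
signal = RepeatAudio(max_repeats=2)(signal)                   # waveform-level augmentation
features = FeatureExtractor(sr=16000)(signal)                 # AST input features as a torch.Tensor

features = MaskTime(max_mask_length=100)(features)            # SpecAugment-style masking
features = MaskFrequency(max_mask_length=20)(features)

target = OneHotEncode(c=CLASSES)(["cla", "pia"])              # multi-label target for a clarinet + piano clip
print(features.shape, target)
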
src/modeling/utils.py ADDED
@@ -0,0 +1,336 @@
1
+ from glob import glob
2
+ from pathlib import Path
3
+ from types import SimpleNamespace
4
+ from typing import Union
5
+
6
+ import librosa
7
+ import numpy as np
8
+ import yaml
9
+
10
+ CLASSES = ["tru", "sax", "vio", "gac", "org", "cla", "flu", "voi", "gel", "cel", "pia"]
11
+
12
+
13
+ def get_wav_files(base_path):
14
+ """
15
+ Function to recursively get all the .wav files in a directory.
16
+
17
+ :param base_path: The base path of the directory to search.
18
+ :type base_path: str or pathlib.Path
19
+
20
+ :return: A list of paths to .wav files found in the directory.
21
+ :rtype: List[str]
22
+ """
23
+
24
+ return glob(f"{base_path}/**/*.wav", recursive=True)
25
+
26
+
27
+ def parse_config(config_path):
28
+ """
29
+ Parse a YAML configuration file and return the configuration as a SimpleNamespace object.
30
+
31
+ :param config_path: The path to the YAML configuration file.
32
+ :type config_path: str or pathlib.Path
33
+
34
+ :return: A SimpleNamespace object representing the configuration.
35
+ :rtype: types.SimpleNamespace
36
+ """
37
+ with open(config_path) as file:
38
+ return SimpleNamespace(**yaml.safe_load(file))
39
+
40
+
41
+ def init_transforms(fn_dict, module):
42
+ """
43
+ Initialize a list of transforms from a dictionary of function names and their parameters.
44
+
45
+ :param fn_dict: A dictionary where keys are the names of transform functions
46
+ and values are dictionaries of parameters.
47
+ :type fn_dict: Dict[str, Dict[str, Any]]
48
+
49
+ :param module: The module where the transform functions are defined.
50
+ :type module: module
51
+
52
+ :return: A list of transform functions.
53
+ :rtype: List[Callable]
54
+ """
55
+ transforms = init_objs(fn_dict, module)
56
+ if transforms is not None:
57
+ transforms = ComposeTransforms(transforms)
58
+ return transforms
59
+
60
+
61
+ def init_objs(fn_dict, module):
62
+ """
63
+ Initialize a list of objects from a dictionary of object names and their parameters.
64
+
65
+ :param fn_dict: A dictionary where keys are the names of object classes and values are dictionaries of parameters.
66
+ :type fn_dict: Dict[str, Dict[str, Any]]
67
+
68
+ :param module: The module where the object classes are defined.
69
+ :type module: module
70
+
71
+ :return: A list of objects.
72
+ :rtype: List[Any]
73
+ """
74
+
75
+ if fn_dict is None:
76
+ return None
77
+
78
+ transforms = []
79
+ for transform in fn_dict.keys():
80
+ fn = getattr(module, transform, None)
81
+ if fn is None:
82
+ raise NotImplementedError(
83
+ "The attribute '{}' is not implemented in the module '{}'.".format(transform, module.__name__)
84
+ )
85
+
86
+ fn_args = fn_dict[transform]
87
+
88
+ if fn_args is None:
89
+ transforms.append(fn())
90
+ else:
91
+ transforms.append(fn(**fn_args))
92
+
93
+ return transforms
94
+
95
+
96
+ def init_obj(fn_dict, module, *args, **kwargs):
97
+ """
98
+ Initialize an object by calling a function with the provided arguments.
99
+
100
+ :param fn_dict: A dictionary that maps the function name to its arguments.
101
+ :type fn_dict: dict or None
102
+ :param module: The module containing the function.
103
+ :type module: module
104
+ :param args: The positional arguments for the function.
105
+ :type args: tuple
106
+ :param kwargs: The keyword arguments for the function.
107
+ :type kwargs: dict
108
+ :raises AssertionError: If a keyword argument is already specified in fn_dict.
109
+ :return: The result of calling the function with the provided arguments.
110
+ :rtype: Any
111
+ """
112
+
113
+ if fn_dict is None:
114
+ return None
115
+
116
+ name = list(fn_dict.keys())[0]
117
+
118
+ fn = getattr(module, name, None)
119
+ if fn is None:
120
+ raise NotImplementedError(
121
+ "The attribute '{}' is not implemented in the module '{}'.".format(name, module.__name__)
122
+ )
123
+
124
+ fn_args = fn_dict[name]
125
+
126
+ if fn_args is not None:
127
+ assert all(k not in fn_args for k in kwargs)
128
+ fn_args.update(kwargs)
129
+
130
+ return fn(*args, **fn_args)
131
+ else:
132
+ return fn(*args, **kwargs)
133
+
134
+
135
+ class ComposeTransforms:
136
+ """
137
+ Composes a list of transforms to be applied in sequence to input data.
138
+
139
+ :param transforms: A list of transforms to be applied.
140
+ :type transforms: List[callable]
141
+ """
142
+
143
+ def __init__(self, transforms: list):
144
+ self.transforms = transforms
145
+
146
+ def __call__(self, data, *args):
147
+ for t in self.transforms:
148
+ data = t(data, *args)
149
+ return data
150
+
151
+
152
+ def load_raw_file(path: Union[str, Path]):
153
+ """
154
+ Loads an audio file from disk and returns its raw waveform and sample rate.
155
+
156
+ :param path: The path to the audio file to load.
157
+ :type path: Union[str, Path]
158
+ :return: A tuple containing the raw waveform and sample rate.
159
+ :rtype: tuple
160
+ """
161
+ return librosa.load(path, sr=None, mono=False)
162
+
163
+
164
+ def get_onset(signal, sr):
165
+ """
166
+ Computes the onset of an audio signal.
167
+
168
+ :param signal: The audio signal.
169
+ :type signal: np.ndarray
170
+ :param sr: The sample rate of the audio signal.
171
+ :type sr: int
172
+ :return: The onset of the audio signal in seconds.
173
+ :rtype: float
174
+ """
175
+ onset = librosa.onset.onset_detect(y=signal, sr=sr, units="time")[0]
176
+ return onset
177
+
178
+
179
+ def get_bpm(signal, sr):
180
+ """
181
+ Computes the estimated beats per minute (BPM) of an audio signal.
182
+
183
+ :param signal: The audio signal.
184
+ :type signal: np.ndarray
185
+ :param sr: The sample rate of the audio signal.
186
+ :type sr: int
187
+ :return: The estimated BPM of the audio signal, or None if the BPM cannot be computed.
188
+ :rtype: Union[float, None]
189
+ """
190
+
191
+ bpm, _ = librosa.beat.beat_track(y=signal, sr=sr)
192
+ return bpm if bpm != 0 else None
193
+
194
+
195
+ def get_pitch(signal, sr):
196
+ """
197
+ Computes the estimated pitch of an audio signal.
198
+
199
+ :param signal: The audio signal.
200
+ :type signal: np.ndarray
201
+ :param sr: The sample rate of the audio signal.
202
+ :type sr: int
203
+ :return: The estimated pitch of the audio signal in logarithmic scale, or None if the pitch cannot be computed.
204
+ :rtype: Union[float, None]
205
+ """
206
+
207
+ eps = 1e-8
208
+ fmin = librosa.note_to_hz("C2")
209
+ fmax = librosa.note_to_hz("C7")
210
+
211
+ pitch, _, _ = librosa.pyin(y=signal, sr=sr, fmin=fmin, fmax=fmax)
212
+
213
+ if not np.isnan(pitch).all():
214
+ mean_log_pitch = np.nanmean(np.log(pitch + eps))
215
+ else:
216
+ mean_log_pitch = None
217
+
218
+ return mean_log_pitch
219
+
220
+
221
+ def get_file_info(path: Union[str, Path], extract_music_features: bool):
222
+ """
223
+ Loads an audio file and computes some basic information about it,
224
+ such as pitch, BPM, onset time, duration, sample rate, and number of channels.
225
+
226
+ :param path: The path to the audio file.
227
+ :type path: Union[str, Path]
228
+ :param extract_music_features: Whether to extract music features such as pitch, BPM, and onset time.
229
+ :type extract_music_features: bool
230
+ :return: A dictionary containing information about the audio file.
231
+ :rtype: dict
232
+ """
233
+
234
+ path = str(path) if isinstance(path, Path) else path
235
+
236
+ signal, sr = load_raw_file(path)
237
+ channels = signal.shape[0]
238
+
239
+ signal = librosa.to_mono(signal)
240
+ duration = len(signal) / sr
241
+
242
+ pitch, bpm, onset = None, None, None
243
+ if extract_music_features:
244
+ pitch = get_pitch(signal, sr)
245
+ bpm = get_bpm(signal, sr)
246
+ onset = get_onset(signal, sr)
247
+
248
+ return {
249
+ "path": path,
250
+ "pitch": pitch,
251
+ "bpm": bpm,
252
+ "onset": onset,
253
+ "sample_rate": sr,
254
+ "duration": duration,
255
+ "channels": channels,
256
+ }
257
+
258
+
259
+ def sync_pitch(file_to_sync: np.ndarray, sr: int, pitch_base: float, pitch: float):
260
+ """
261
+ Shift the pitch of an audio file to match a new pitch value.
262
+
263
+ :param file_to_sync: The input audio file as a NumPy array.
264
+ :type file_to_sync: np.ndarray
265
+ :param sr: The sample rate of the input file.
266
+ :type sr: int
267
+ :param pitch_base: The pitch value of the original file.
268
+ :type pitch_base: float
269
+ :param pitch: The pitch value to synchronize the input file to.
270
+ :type pitch: float
271
+ :return: The synchronized audio file as a NumPy array.
272
+ :rtype: np.ndarray
273
+ """
274
+
275
+ assert np.ndim(file_to_sync) == 1, "Input array has more than one dimension"
276
+
277
+ if any(np.isnan(x) for x in [pitch_base, pitch]):
278
+ return file_to_sync
279
+
280
+ steps = np.round(12 * np.log2(np.exp(pitch_base) / np.exp(pitch)), 0)
281
+
282
+ return librosa.effects.pitch_shift(y=file_to_sync, sr=sr, n_steps=steps)
283
+
284
+
285
+ def sync_bpm(file_to_sync: np.ndarray, sr: int, bpm_base: float, bpm: float):
286
+ """
287
+ Stretch or compress the duration of an audio file to match a new tempo.
288
+
289
+ :param file_to_sync: The input audio file as a NumPy array.
290
+ :type file_to_sync: np.ndarray
291
+ :param sr: The sample rate of the input file.
292
+ :type sr: int
293
+ :param bpm_base: The tempo of the original file.
294
+ :type bpm_base: float
295
+ :param bpm: The tempo to synchronize the input file to.
296
+ :type bpm: float
297
+ :return: The synchronized audio file as a NumPy array.
298
+ :rtype: np.ndarray
299
+ """
300
+
301
+ assert np.ndim(file_to_sync) == 1, "Input array has more than one dimension"
302
+
303
+ if any(np.isnan(x) for x in [bpm_base, bpm]):
304
+ return file_to_sync
305
+
306
+ return librosa.effects.time_stretch(y=file_to_sync, rate=bpm_base / bpm)
307
+
308
+
309
+ def sync_onset(file_to_sync: np.ndarray, sr: int, onset_base: float, onset: float):
310
+ """
311
+ Sync the onset of an audio signal by adding or removing silence at the beginning.
312
+
313
+ :param file_to_sync: The audio signal to synchronize.
314
+ :type file_to_sync: np.ndarray
315
+ :param sr: The sample rate of the audio signal.
316
+ :type sr: int
317
+ :param onset_base: The onset of the reference signal in seconds.
318
+ :type onset_base: float
319
+ :param onset: The onset of the signal to synchronize in seconds.
320
+ :type onset: float
321
+ :raises AssertionError: If the input array has more than one dimension.
322
+ :return: The synchronized audio signal.
323
+ :rtype: np.ndarray
324
+ """
325
+
326
+ assert np.ndim(file_to_sync) == 1, "Input array has more than one dimension"
327
+
328
+ if any(np.isnan(x) for x in [onset_base, onset]):
329
+ return file_to_sync
330
+
331
+ diff = int(round(abs(onset_base * sr - onset * sr), 0))
332
+
333
+ if onset_base > onset:
334
+ return np.pad(file_to_sync, (diff, 0), mode="constant", constant_values=0)
335
+ else:
336
+ return file_to_sync[diff:]
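
Finally, a hedged sketch of the config-driven helpers above: parse_config plus init_transforms building the augmentation pipeline from a YAML block. The YAML keys mirror the Transform classes in transforms.py; the file layout and import paths are assumptions for illustration.

import tempfile
from pathlib import Path

from src.modeling import transforms, utils   # import paths assumed

yaml_text = """
train_transforms:
  RepeatAudio:
    max_repeats: 2
  MaskFrequency:
    max_mask_length: 20
  MaskTime:
    max_mask_length: 100
"""

config_path = Path(tempfile.mkdtemp()) / "config.yaml"
config_path.write_text(yaml_text)

config = utils.parse_config(config_path)                       # SimpleNamespace with .train_transforms
augment = utils.init_transforms(config.train_transforms, transforms)

print(augment.transforms)                                      # ComposeTransforms wrapping the three transform objects
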