dennisvdang committed on
Commit
f440926
1 Parent(s): 443db5b

Add application file

Files changed (1)
  1. app.py +174 -0
app.py ADDED
@@ -0,0 +1,174 @@
+ import os
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow logs (must be set before importing TensorFlow)
+ import tensorflow as tf
+ tf.get_logger().setLevel('ERROR')  # Suppress TensorFlow ERROR logs
+ import warnings
+ warnings.filterwarnings("ignore")  # Suppress all warnings
+
+ from functools import reduce
+ import shutil
+ import librosa
+ import numpy as np
+ from matplotlib import pyplot as plt
+ from pydub import AudioSegment
+ from pydub.silence import detect_nonsilent
+ from pytube import YouTube
+ from sklearn.preprocessing import StandardScaler
+ import streamlit as st
+
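+ # Overview: this Streamlit app downloads the audio track of a YouTube video,
+ # strips silence, extracts frame-level audio features, and runs a pre-trained
+ # CRNN model to estimate per-frame chorus probability.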
+
+ # Constants
+ SR = 12000  # target sample rate (Hz)
+ HOP_LENGTH = 128  # hop length (samples) shared by all frame-level features
+ MAX_FRAMES = 300  # maximum number of feature frames fed to the model
+ MAX_METERS = 201  # maximum number of meter positions
+ N_FEATURES = 15
+ MODEL_PATH = "models/CRNN/best_model_V3.h5"
+ AUDIO_TEMP_PATH = "output/temp"
+
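+ # Note: N_FEATURES is not referenced elsewhere in this file; the feature matrix
+ # stacked in extract_features has 14 rows (1 RMS + 3 mel + 4 chroma +
+ # 3 tempogram + 3 MFCC activations).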
+ def extract_audio(url, output_path=AUDIO_TEMP_PATH):
+     """Download the audio stream of a YouTube video; return (file path, title)."""
+     try:
+         yt = YouTube(url)
+         video_title = yt.title
+         audio_stream = yt.streams.filter(only_audio=True).first()
+         if audio_stream:
+             os.makedirs(output_path, exist_ok=True)
+             out_file = audio_stream.download(output_path)
+             base, _ = os.path.splitext(out_file)
+             audio_file = base + '.mp3'
+             if os.path.exists(audio_file):
+                 os.remove(audio_file)
+             os.rename(out_file, audio_file)  # renames the extension only; the audio is not transcoded
+             return audio_file, video_title
+         else:
+             st.error("No audio stream found")
+             return None, None
+     except Exception as e:
+         st.error(f"An error occurred: {e}")
+         return None, None
+
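+ # Note: pydub relies on ffmpeg (or libav) being installed to decode the
+ # downloaded audio and to export mp3 below.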
+ def strip_silence(audio_path):
+     """Remove silent stretches from the audio file, overwriting it in place."""
+     sound = AudioSegment.from_file(audio_path)
+     nonsilent_ranges = detect_nonsilent(sound, min_silence_len=500, silence_thresh=-50)
+     stripped = reduce(lambda acc, rng: acc + sound[rng[0]:rng[1]], nonsilent_ranges, AudioSegment.empty())
+     stripped.export(audio_path, format='mp3')
+
+ class AudioFeature:
+     """Extracts and holds frame-level audio features for a single track."""
+
+     def __init__(self, audio_path, sr=SR, hop_length=HOP_LENGTH):
+         self.audio_path = audio_path
+         self.sr = sr
+         self.hop_length = hop_length
+         self.y = None
+         self.y_harm, self.y_perc = None, None
+         self.spectrogram = None
+         self.rms = None
+         self.melspectrogram = None
+         self.mel_acts = None
+         self.chromagram = None
+         self.chroma_acts = None
+         self.onset_env = None
+         self.tempogram = None
+         self.tempogram_acts = None
+         self.mfccs = None
+         self.mfcc_acts = None
+         self.combined_features = None
+         self.n_frames = None
+         self.tempo = None
+         self.beats = None
+         self.meter_grid = None
+         self.key, self.mode = None, None
+
+     def detect_key(self, chroma_vals):
+         """Estimate key and mode by correlating summed chroma against the
+         Krumhansl-Schmuckler major/minor key profiles."""
+         note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
+         major_profile = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
+         minor_profile = np.array([6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17])
+         major_profile /= np.linalg.norm(major_profile)
+         minor_profile /= np.linalg.norm(minor_profile)
+
+         major_correlations = [np.corrcoef(chroma_vals, np.roll(major_profile, i))[0, 1] for i in range(12)]
+         minor_correlations = [np.corrcoef(chroma_vals, np.roll(minor_profile, i))[0, 1] for i in range(12)]
+
+         max_major_idx = np.argmax(major_correlations)
+         max_minor_idx = np.argmax(minor_correlations)
+
+         self.mode = 'major' if major_correlations[max_major_idx] > minor_correlations[max_minor_idx] else 'minor'
+         self.key = note_names[max_major_idx if self.mode == 'major' else max_minor_idx]
+         return self.key, self.mode
+
+     def calculate_ki_chroma(self, waveform, sr, hop_length):
+         """Compute a key-invariant chromagram: detect the key, then roll the
+         chroma bins so the tonic (or relative major for minor keys) sits at C."""
+         chromagram = librosa.feature.chroma_cqt(y=waveform, sr=sr, hop_length=hop_length, bins_per_octave=24)
+         chromagram = (chromagram - chromagram.min()) / (chromagram.max() - chromagram.min())
+         chroma_vals = np.sum(chromagram, axis=1)
+         key, mode = self.detect_key(chroma_vals)
+         key_idx = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'].index(key)
+         shift_amount = -key_idx if mode == 'major' else -(key_idx + 3) % 12
+         return librosa.util.normalize(np.roll(chromagram, shift_amount, axis=0), axis=1)
+
+     def extract_features(self):
+         self.y, self.sr = librosa.load(self.audio_path, sr=self.sr)
+         self.y_harm, self.y_perc = librosa.effects.hpss(self.y)
+         self.spectrogram, _ = librosa.magphase(librosa.stft(self.y, hop_length=self.hop_length))
+         self.rms = librosa.feature.rms(S=self.spectrogram, hop_length=self.hop_length).astype(np.float32)
+         self.melspectrogram = librosa.feature.melspectrogram(y=self.y, sr=self.sr, n_mels=128, hop_length=self.hop_length).astype(np.float32)
+         self.mel_acts = librosa.decompose.decompose(self.melspectrogram, n_components=3, sort=True)[1].astype(np.float32)
+         self.chromagram = self.calculate_ki_chroma(self.y_harm, self.sr, self.hop_length).astype(np.float32)
+         self.chroma_acts = librosa.decompose.decompose(self.chromagram, n_components=4, sort=True)[1].astype(np.float32)
+         self.onset_env = librosa.onset.onset_strength(y=self.y_perc, sr=self.sr, hop_length=self.hop_length)
+         # Compute the tempogram first, then clip outliers at its 99th percentile
+         tempogram = librosa.feature.tempogram(onset_envelope=self.onset_env, sr=self.sr, hop_length=self.hop_length)
+         self.tempogram = np.clip(tempogram, 0, np.percentile(tempogram, 99)).astype(np.float32)
+         self.tempogram_acts = librosa.decompose.decompose(self.tempogram, n_components=3, sort=True)[1].astype(np.float32)
+         self.mfccs = librosa.feature.mfcc(y=self.y, sr=self.sr, n_mfcc=13, hop_length=self.hop_length).astype(np.float32)
+         self.mfcc_acts = librosa.decompose.decompose(self.mfccs, n_components=3, sort=True)[1].astype(np.float32)
+         self.combined_features = np.vstack([self.rms, self.mel_acts, self.chroma_acts, self.tempogram_acts, self.mfcc_acts])
+         self.n_frames = self.combined_features.shape[1]
+         self.tempo, self.beats = librosa.beat.beat_track(y=self.y_perc, sr=self.sr, hop_length=self.hop_length)
+         # Clip beat frames to the valid range and cap at MAX_METERS positions
+         # (fix_frames expects a 1-D array of frame indices)
+         self.meter_grid = librosa.util.fix_frames(self.beats, x_min=0, x_max=self.n_frames)[:MAX_METERS]
+         self.key, self.mode = self.detect_key(np.sum(self.chromagram, axis=1))
+
+     def get_features(self):
+         self.extract_features()
+         return self.combined_features, self.n_frames, self.tempo, self.beats, self.meter_grid, self.key, self.mode
+
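+ # Streamlit reruns the whole script on every interaction, so the model below is
+ # reloaded each time; st.cache_resource could be used to cache it across reruns.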
+ def load_model(model_path=MODEL_PATH):
+     return tf.keras.models.load_model(model_path)
+
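+ # Shape walkthrough for predict_chorus (assuming the track fills MAX_FRAMES):
+ # combined_features (14, n_frames) -> truncated (14, 300) -> batched (1, 14, 300);
+ # the reshape flattens to (14, 300) for standardization, then restores the batch
+ # shape before model.predict.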
+ def predict_chorus(audio_features, model):
+     features, n_frames, tempo, beats, meter_grid, key, mode = audio_features.get_features()
+     features = features[:, :MAX_FRAMES]  # truncate to the model's input length
+     features = np.expand_dims(features, axis=0)  # add a batch dimension
+     # Standardize per clip; the scaler is fit on this clip alone at inference time
+     scaler = StandardScaler()
+     features = scaler.fit_transform(features.reshape(-1, features.shape[-1])).reshape(features.shape)
+     predictions = model.predict(features)
+     return predictions
+
+ def plot_predictions(predictions, title):
+     fig, ax = plt.subplots(figsize=(10, 4))
+     ax.plot(predictions[0], label='Chorus Probability')
+     ax.set_title(title)
+     ax.set_xlabel('Frame')
+     ax.set_ylabel('Probability')
+     ax.legend()
+     st.pyplot(fig)
+
+ def main():
+     st.title("Chorus Finder")
+     st.write("Paste a YouTube URL to find the chorus in the song.")
+     url = st.text_input("YouTube URL")
+     if st.button("Find Chorus"):
+         if url:
+             audio_file, video_title = extract_audio(url)
+             if audio_file:
+                 strip_silence(audio_file)
+                 audio_features = AudioFeature(audio_file)
+                 model = load_model()
+                 predictions = predict_chorus(audio_features, model)
+                 plot_predictions(predictions, video_title)
+                 shutil.rmtree(AUDIO_TEMP_PATH)  # clean up the downloaded audio
+         else:
+             st.error("Please enter a valid YouTube URL")
+
+ if __name__ == "__main__":
+     main()
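
To try the app locally (assuming the CRNN weights exist at models/CRNN/best_model_V3.h5 and ffmpeg is installed for pydub), run: streamlit run app.py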