Spaces:
Running
Running
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
"""Module for loading and managing the CRNN model for chorus detection.""" | |
import os | |
from typing import Any, Optional, List, Tuple, Union | |
import numpy as np | |
import tensorflow as tf | |
from chorus_detection.config import MODEL_PATH, DOCKER_MODEL_PATH | |
from chorus_detection.utils.logging import logger | |
def load_CRNN_model(model_path: str = str(MODEL_PATH)) -> tf.keras.Model:
    """Load the CRNN model and compile it for prediction.

    The model is deserialized with ``compile=False`` and then compiled with a
    stock optimizer and loss. If loading from ``model_path`` fails and the
    Docker container path exists, that path is tried once as a fallback.

    Args:
        model_path: Path to the saved model.

    Returns:
        Loaded Keras model.

    Raises:
        RuntimeError: If the model cannot be loaded from ``model_path`` nor
            from the Docker fallback path. The original exception is attached
            as ``__cause__``.
    """
    # The saved model refers to these names. Because the model is loaded with
    # compile=False, pass-through stand-ins are enough to resolve them.
    custom_objects = {
        'custom_binary_crossentropy': lambda y_true, y_pred: y_pred,
        'custom_accuracy': lambda y_true, y_pred: y_pred,
    }

    try:
        logger.info(f"Loading model from: {model_path}")
        model = tf.keras.models.load_model(
            model_path, custom_objects=custom_objects, compile=False)
        # Compile with a default optimizer and loss; used for prediction only.
        model.compile(optimizer='adam', loss='binary_crossentropy')
        return model
    except Exception as e:
        logger.error(f"Error loading model from {model_path}: {e}")

        # Compare as strings: model_path is a str while DOCKER_MODEL_PATH may
        # be a path object, in which case a direct != would always be True and
        # a failed Docker path would be retried needlessly.
        docker_path = str(DOCKER_MODEL_PATH)
        if str(model_path) != docker_path and os.path.exists(docker_path):
            logger.info(f"Trying Docker path: {DOCKER_MODEL_PATH}")
            return load_CRNN_model(docker_path)

        raise RuntimeError(f"Failed to load model: {e}") from e
def smooth_predictions(predictions: np.ndarray,
                       min_sequence_length: int = 5) -> np.ndarray:
    """Smooth binary predictions by removing isolated flips and short runs.

    Three passes are applied in order:

    1. A single 0 with a 1 on both sides becomes 1.
    2. A single 1 with a 0 on both sides becomes 0.
    3. Every run of consecutive 1's shorter than ``min_sequence_length``
       (including a run that reaches the end of the array) is set to 0.

    The first and last elements are never touched by passes 1 and 2 because
    they lack a neighbour on one side; pass 3 still applies to them.

    Args:
        predictions: Array (or array-like) of binary predictions. The input
            is not modified.
        min_sequence_length: Minimum number of consecutive 1's a run needs in
            order to be kept. Defaults to 5.

    Returns:
        Smoothed copy of the predictions, with the input's shape and dtype.
    """
    # Always work on a private C-ordered copy so reshape(-1) below is a view
    # and writes through `flat` land in `data`.
    data = np.array(predictions, copy=True, order="C")
    flat = data.reshape(-1)

    # Views of every interior element and its two neighbours. They are empty
    # for inputs shorter than three elements, so no length guard is needed.
    left, mid, right = flat[:-2], flat[1:-1], flat[2:]

    # Pass 1: fill isolated 0's. The mask is computed before writing.
    gaps = (mid == 0) & (left == 1) & (right == 1)
    mid[gaps] = 1

    # Pass 2: clear isolated 1's, evaluated on the result of pass 1.
    spikes = (mid == 1) & (left == 0) & (right == 0)
    mid[spikes] = 0

    # Pass 3: locate runs of 1's via rising/falling edges of a zero-padded
    # indicator, then drop the runs that are too short.
    indicator = np.concatenate(([0], flat == 1, [0])).astype(np.int8)
    edges = np.diff(indicator)
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)  # exclusive end index of each run
    for start, end in zip(run_starts, run_ends):
        if end - start < min_sequence_length:
            flat[start:end] = 0

    return data
def make_predictions(model: tf.keras.Model, processed_audio: np.ndarray,
                     audio_features: Any, url: Optional[str] = None,
                     video_name: Optional[str] = None) -> np.ndarray:
    """Generate predictions from the model and post-process them.

    The raw model output is truncated to the number of meter-grid intervals,
    rounded to binary labels, and smoothed. When both ``url`` and
    ``video_name`` are given (CLI mode), the start time of every detected
    chorus is computed and printed.

    Args:
        model: The loaded model used for prediction.
        processed_audio: The audio data that has been processed for prediction.
        audio_features: Audio features object; ``meter_grid``, ``sr`` and
            ``hop_length`` are read from it.
        url: YouTube URL of the audio file (optional).
        video_name: Name of the video (optional).

    Returns:
        The smoothed binary predictions as a flat array.
    """
    logger.info("Generating predictions...")

    # Only the first entry of the model output is used
    # (presumably a batch of one - TODO confirm against the caller).
    predictions = model.predict(processed_audio)[0]

    # Truncate to the number of intervals in the meter grid. The count is
    # clamped at zero: with an empty grid it would otherwise be -1 and the
    # slice below would drop only the last frame instead of truncating.
    meter_grid_length = max(len(audio_features.meter_grid) - 1, 0)
    if len(predictions) > meter_grid_length:
        predictions = predictions[:meter_grid_length]

    binary_predictions = np.round(predictions).flatten()

    # Apply smoothing to improve prediction quality
    smoothed_predictions = smooth_predictions(binary_predictions)

    # The chorus start times are only printed, so librosa is imported and the
    # frame-to-time conversion is done only in CLI mode.
    if url and video_name:
        import librosa

        meter_grid_times = librosa.frames_to_time(
            audio_features.meter_grid,
            sr=audio_features.sr,
            hop_length=audio_features.hop_length
        )

        # A chorus starts wherever a 1 is at index 0 or directly follows a 0.
        chorus_start_times = [
            meter_grid_times[i]
            for i, label in enumerate(smoothed_predictions)
            if label == 1 and (i == 0 or smoothed_predictions[i - 1] == 0)
        ]

        _print_chorus_results(url, video_name, chorus_start_times)

    return smoothed_predictions
def _print_chorus_results(url: str, video_name: str, chorus_start_times: List[float]) -> None:
    """Print a banner listing the identified choruses as clickable links.

    Each chorus is printed as a terminal hyperlink (OSC 8 escape sequence)
    pointing at ``url`` with a ``t=<seconds>s`` time-stamp parameter.

    Args:
        url: YouTube URL of the analyzed video.
        video_name: Name of the video.
        chorus_start_times: List of start times (in seconds) for identified
            choruses; each is truncated to whole seconds.
    """
    # Use '&' only when the URL already has a query string; URLs without one
    # (for example youtu.be/<id>) need '?' to start the query.
    separator = "&" if "?" in url else "?"
    link_texts = [
        f"{url}{separator}t={int(start_time)}s"
        for start_time in chorus_start_times
    ]

    # Wrap each visible link text in an OSC 8 hyperlink escape sequence.
    youtube_links = [
        f"\033]8;;{text}\033\\{text}\033]8;;\033\\"
        for text in link_texts
    ]

    summary = f"Number of choruses identified: {len(chorus_start_times)}"

    # Size the banner from what is actually visible: the escape bytes and the
    # hidden URL copy inside each hyperlink take no space on screen.
    widths = [len(video_name), len(summary)] + [len(text) for text in link_texts]
    max_length = max(widths) if link_texts else max(50, *widths)
    banner_width = max_length + 4
    header_footer = "=" * banner_width

    # Print the results
    print("\n\n")
    print(header_footer)
    print(video_name.center(banner_width))
    print(summary.center(banner_width))
    print(header_footer)

    if chorus_start_times:
        for link in youtube_links:
            print(link)
    else:
        print("No choruses identified.")

    print(header_footer)