dennisvdang's picture
Flatten directory structure for simpler imports
ad0da04
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Module for loading and managing the CRNN model for chorus detection."""
import os
from typing import Any, Optional, List, Tuple, Union
import numpy as np
import tensorflow as tf
from chorus_detection.config import MODEL_PATH, DOCKER_MODEL_PATH
from chorus_detection.utils.logging import logger
def load_CRNN_model(model_path: str = str(MODEL_PATH)) -> tf.keras.Model:
    """Load the CRNN chorus-detection model for inference.

    The saved model is deserialized with ``compile=False`` and then compiled
    with a generic optimizer and loss so that ``predict`` can be called. The
    training-time custom loss and metric are only stubbed out so that
    deserialization can resolve their names.

    If loading fails, ``DOCKER_MODEL_PATH`` exists, and it refers to a
    different location than ``model_path``, one retry is made from
    ``DOCKER_MODEL_PATH``.

    Args:
        model_path: Path to the saved model.

    Returns:
        The loaded and compiled Keras model.

    Raises:
        RuntimeError: If the model cannot be loaded (including the fallback
            attempt, when one is made).
    """
    # Placeholders so names referenced by the saved model resolve during
    # deserialization; with compile=False they are never actually evaluated.
    custom_objects = {
        'custom_binary_crossentropy': lambda y_true, y_pred: y_pred,
        'custom_accuracy': lambda y_true, y_pred: y_pred,
    }
    try:
        logger.info(f"Loading model from: {model_path}")
        model = tf.keras.models.load_model(
            model_path, custom_objects=custom_objects, compile=False)
        # Compile with a default optimizer/loss; only inference is needed.
        model.compile(optimizer='adam', loss='binary_crossentropy')
        return model
    except Exception as e:
        logger.error(f"Error loading model from {model_path}: {e}")
        # Compare normalized string forms: DOCKER_MODEL_PATH may be a
        # pathlib.Path (TODO confirm in config) while model_path is a str, and
        # a str never compares equal to a Path.
        same_location = (os.path.abspath(str(model_path))
                         == os.path.abspath(str(DOCKER_MODEL_PATH)))
        if not same_location and os.path.exists(DOCKER_MODEL_PATH):
            logger.info(f"Trying Docker path: {DOCKER_MODEL_PATH}")
            # Passing the Docker path makes the recursive call's guard see
            # the same location, so at most one retry happens.
            return load_CRNN_model(str(DOCKER_MODEL_PATH))
        raise RuntimeError(f"Failed to load model: {e}") from e
def smooth_predictions(predictions: np.ndarray,
                       min_sequence_length: int = 5) -> np.ndarray:
    """Smooth binary predictions by fixing isolated errors and dropping short runs.

    Three passes are applied in order:

    1. A 0 whose immediate neighbours are both 1 is set to 1.
    2. A 1 whose immediate neighbours are both 0 (after pass 1) is set to 0.
    3. Every run of consecutive 1's shorter than ``min_sequence_length`` is
       set to 0, including a run that extends to the end of the array.

    Args:
        predictions: 1-D sequence of binary (0/1) predictions. Non-ndarray
            inputs are converted with ``np.array``; the input is not modified.
        min_sequence_length: Minimum number of consecutive 1's a run must have
            to survive pass 3. Defaults to 5.

    Returns:
        A new array holding the smoothed predictions.
    """
    # Work on a copy so the caller's data is left untouched.
    data = (predictions.copy() if isinstance(predictions, np.ndarray)
            else np.array(predictions, copy=True))

    # Pass 1: fill isolated 0's. Editing in place is safe here: filling
    # position i requires data[i + 1] == 1, so position i + 1 is not a 0 and
    # cannot be affected by the change.
    for i in range(1, len(data) - 1):
        if data[i] == 0 and data[i - 1] == 1 and data[i + 1] == 1:
            data[i] = 1

    # Pass 2: clear isolated 1's. Decisions read `data` and writes go to a
    # separate copy, so each decision sees only pass-1 output.
    corrected = data.copy()
    for i in range(1, len(data) - 1):
        if data[i] == 1 and data[i - 1] == 0 and data[i + 1] == 0:
            corrected[i] = 0

    # Pass 3: zero out runs of 1's that are too short.
    smoothed = corrected.copy()
    run_start = None
    for i, value in enumerate(corrected):
        if value == 1:
            if run_start is None:
                run_start = i
        elif run_start is not None:
            if i - run_start < min_sequence_length:
                smoothed[run_start:i] = 0
            run_start = None

    # A run still open here reaches the end of the array.
    if run_start is not None and len(corrected) - run_start < min_sequence_length:
        smoothed[run_start:] = 0

    return smoothed
def make_predictions(model: tf.keras.Model, processed_audio: np.ndarray,
                     audio_features: Any, url: Optional[str] = None,
                     video_name: Optional[str] = None) -> np.ndarray:
    """Run the model on processed audio and return smoothed chorus labels.

    The first element of ``model.predict``'s output (presumably the first
    batch item) is clipped to ``len(audio_features.meter_grid) - 1`` entries,
    rounded, flattened and passed through ``smooth_predictions``. When both
    ``url`` and ``video_name`` are truthy, the start time of each detected
    chorus is printed via ``_print_chorus_results``.

    Args:
        model: Loaded Keras model used for inference.
        processed_audio: Model-ready input built from the audio.
        audio_features: Object exposing ``meter_grid``, ``sr`` and
            ``hop_length``.
        url: YouTube URL of the audio (optional; enables printing).
        video_name: Name of the video (optional; enables printing).

    Returns:
        The smoothed binary predictions.
    """
    import librosa

    logger.info("Generating predictions...")
    raw_output = model.predict(processed_audio)[0]

    # Keep at most one label per meter-grid interval (len(grid) - 1 entries).
    n_intervals = len(audio_features.meter_grid) - 1
    clipped = raw_output[:n_intervals] if len(raw_output) > n_intervals else raw_output
    labels = smooth_predictions(np.round(clipped).flatten())

    grid_times = librosa.frames_to_time(
        audio_features.meter_grid,
        sr=audio_features.sr,
        hop_length=audio_features.hop_length,
    )

    # A chorus starts at a 1 that opens the sequence or follows a 0.
    chorus_start_times = []
    previous_label = 0
    for idx, label in enumerate(labels):
        if label == 1 and previous_label == 0:
            chorus_start_times.append(grid_times[idx])
        previous_label = label

    # CLI mode: only report when both identifiers were supplied.
    if url and video_name:
        _print_chorus_results(url, video_name, chorus_start_times)

    return labels
def _print_chorus_results(url: str, video_name: str, chorus_start_times: List[float]) -> None:
"""Print formatted results showing identified choruses with links.
Args:
url: YouTube URL of the analyzed video
video_name: Name of the video
chorus_start_times: List of start times (in seconds) for identified choruses
"""
# Create YouTube links with time stamps
youtube_links = [
f"\033]8;;{url}&t={int(start_time)}s\033\\{url}&t={int(start_time)}s\033]8;;\033\\"
for start_time in chorus_start_times
]
# Format the output
link_lengths = [len(link) for link in youtube_links]
max_length = max(link_lengths + [len(video_name), len(f"Number of choruses identified: {len(chorus_start_times)}")]) if link_lengths else 50
header_footer = "=" * (max_length + 4)
# Print the results
print("\n\n")
print(header_footer)
print(f"{video_name.center(max_length + 2)}")
print(f"Number of choruses identified: {len(chorus_start_times)}".center(max_length + 4))
print(header_footer)
if chorus_start_times:
for link in youtube_links:
print(link)
else:
print("No choruses identified.")
print(header_footer)