"""Streamlit app that scores an uploaded audio file with the healing-music classifier.

The model and its feature scaler are downloaded from the Hugging Face Hub,
cached for the lifetime of the Streamlit server process, and applied to the
features that ``predict.extract_features`` computes for the uploaded file.
"""

import logging
import os
import pickle
import tempfile
import traceback

import joblib
import numpy as np
import sklearn
import streamlit as st
from huggingface_hub import hf_hub_download, list_repo_files

from predict import extract_features

# Set up logging
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Page configuration
st.set_page_config(
    page_title="Healing Music Classifier",
    page_icon="🎵",
    layout="centered"
)

# Location of the trained artifacts on the Hugging Face Hub and on local disk.
MODEL_REPO_ID = "404Brain-Not-Found-yeah/healing-music-classifier"
MODEL_FILENAME = "models/model.joblib"
SCALER_FILENAME = "models/scaler.joblib"
MODEL_CACHE_DIR = "temp_models"


class ModelLoadError(RuntimeError):
    """Raised when the model or scaler cannot be listed, fetched or loaded."""


def _download_artifact(filename, label):
    """Download one file of the model repository and return its local path.

    Args:
        filename: Path of the file inside the Hub repository.
        label: Human-readable name ("Model" / "Scaler") used in messages.

    Raises:
        ModelLoadError: If the download fails or the file is missing afterwards.
    """
    try:
        path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=filename,
            local_dir=MODEL_CACHE_DIR
        )
    except Exception as e:
        raise ModelLoadError(f"Error downloading {label.lower()}: {str(e)}") from e

    logger.info(f"{label} downloaded to: {path}")
    st.write(f"{label} downloaded to: {path}")

    if not os.path.exists(path):
        raise ModelLoadError(f"{label} file not found at: {path}")
    return path


def _load_artifact(path):
    """Deserialize one artifact, falling back to plain pickle if joblib fails."""
    try:
        return joblib.load(path)
    except Exception as load_error:
        logger.warning(f"Standard loading failed: {str(load_error)}")
        # SECURITY NOTE(review): both joblib.load and pickle.load execute
        # arbitrary code embedded in the file. This is only acceptable while
        # MODEL_REPO_ID is a repository we fully trust.
        # Compatibility mode: latin1 lets pickles written by Python 2 load.
        with open(path, 'rb') as f:
            return pickle.load(f, encoding='latin1')


@st.cache_resource
def _load_model_cached():
    """Fetch and load the model and scaler; raise ModelLoadError on failure.

    Only a successful result is stored by ``st.cache_resource``. Failures are
    signalled by raising, so a transient network error is retried on the next
    rerun instead of being cached as a permanent ``(None, None)``.
    """
    # Report the scikit-learn version; a mismatch is a common unpickling failure.
    logger.info(f"Using scikit-learn version: {sklearn.__version__}")
    st.write(f"Using scikit-learn version: {sklearn.__version__}")

    # First list every file in the repository (diagnostic output).
    logger.info("Listing repository files...")
    try:
        files = list_repo_files(MODEL_REPO_ID)
    except Exception as e:
        raise ModelLoadError(f"Error listing repository files: {str(e)}") from e
    logger.info(f"Repository files: {files}")
    st.write("Available files in repository:", files)

    # Create the local download directory.
    os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
    logger.info("Created temp_models directory")

    logger.info("Downloading model from Hugging Face Hub...")
    model_path = _download_artifact(MODEL_FILENAME, "Model")
    scaler_path = _download_artifact(SCALER_FILENAME, "Scaler")

    try:
        logger.info("Loading model and scaler...")

        # Report file sizes; a tiny file usually means a truncated download.
        model_size = os.path.getsize(model_path)
        scaler_size = os.path.getsize(scaler_path)
        logger.info(f"Model file size: {model_size} bytes")
        logger.info(f"Scaler file size: {scaler_size} bytes")
        st.write(f"Model file size: {model_size} bytes")
        st.write(f"Scaler file size: {scaler_size} bytes")

        model = _load_artifact(model_path)
        scaler = _load_artifact(scaler_path)
    except Exception as e:
        raise ModelLoadError(f"Error loading model/scaler files: {str(e)}") from e

    logger.info("Model and scaler loaded successfully")
    st.success("Model and scaler loaded successfully!")
    return model, scaler


def load_model():
    """Load model from Hugging Face Hub.

    Returns:
        A ``(model, scaler)`` tuple on success, or ``(None, None)`` when any
        step fails. Failures are logged and shown with ``st.error``; they are
        not cached, so the next call retries.
    """
    try:
        return _load_model_cached()
    except ModelLoadError as e:
        logger.error(f"{str(e)}\n{traceback.format_exc()}")
        st.error(str(e))
    except Exception as e:
        logger.error(f"Unexpected error in load_model: {str(e)}\n{traceback.format_exc()}")
        st.error(f"Unexpected error in load_model: {str(e)}")
    return None, None


# Keep ``load_model.clear()`` working for callers that relied on the
# ``st.cache_resource`` wrapper previously applied to ``load_model`` itself.
load_model.clear = _load_model_cached.clear


def _remove_temp_file(path):
    """Best-effort removal of the temporary upload copy; never raises."""
    if path is None:
        return
    try:
        if os.path.exists(path):
            os.unlink(path)
    except Exception as e:
        logger.error(f"Failed to clean up temporary file: {str(e)}")


def main():
    """Render the page and, when a file is uploaded, display its healing index."""
    st.title("🎵 Healing Music Classifier")
    st.write(
        "Upload your music file, and AI will analyze its healing potential! "
        "Supports mp3, wav formats."
    )

    # Add file upload component
    uploaded_file = st.file_uploader("Choose an audio file...", type=['mp3', 'wav'])
    if uploaded_file is None:
        return

    # Create progress bar
    progress_bar = st.progress(0)
    status_text = st.empty()

    tmp_file_path = None  # set once the upload has been written to disk
    succeeded = False     # only a full run may report "Analysis complete!"

    try:
        # extract_features is given a file path, so persist the upload first.
        suffix = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_file_path = tmp_file.name

        # Update status
        status_text.text("Analyzing music...")
        progress_bar.progress(30)

        # Load model
        model, scaler = load_model()
        if model is None or scaler is None:
            st.error("Model loading failed. Please check the logs for details.")
            return
        progress_bar.progress(50)

        # Extract features
        features = extract_features(tmp_file_path)
        if features is None:
            st.error("Failed to extract audio features. Please ensure the file is a valid audio file.")
            return
        progress_bar.progress(70)

        # Predict
        try:
            scaled_features = scaler.transform([features])
            # Column 1 of predict_proba is used as the "healing" probability.
            healing_probability = float(model.predict_proba(scaled_features)[0][1])
            progress_bar.progress(90)
        except Exception as e:
            logger.error(f"Error during prediction: {str(e)}\n{traceback.format_exc()}")
            st.error(f"Error during prediction: {str(e)}")
            return

        # Display results
        st.subheader("Analysis Results")

        # st.progress only accepts floats within [0.0, 1.0].
        healing_percentage = healing_probability * 100
        st.progress(min(max(healing_probability, 0.0), 1.0))

        # Display percentage
        st.write(f"Healing Index: {healing_percentage:.1f}%")

        # Provide explanation
        if healing_percentage >= 75:
            st.success("This music has strong healing properties! 🌟")
        elif healing_percentage >= 50:
            st.info("This music has moderate healing effects. ✨")
        else:
            st.warning("This music has limited healing potential. 🎵")

        succeeded = True
    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")
        logger.exception("Unexpected error")
    finally:
        # Clean up temporary file
        _remove_temp_file(tmp_file_path)

        # Complete progress bar, without claiming success after a failure.
        progress_bar.progress(100)
        status_text.text("Analysis complete!" if succeeded else "Analysis failed.")


if __name__ == "__main__":
    main()