TrackRec / src /app.py
MoonlightXST's picture
fix features
9c52bca verified
import streamlit as st
import pandas as pd
import numpy as np
from annoy import AnnoyIndex
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import StandardScaler
import pickle
import os
from pathlib import Path
# Writable directory used as the model download cache.
CACHE_DIR = "/app/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Point the Hugging Face / XDG cache variables at the same directory.
# NOTE(review): sentence_transformers is imported above, before these
# assignments run; a library that reads these variables at import time may
# not see them. The explicit cache_folder=CACHE_DIR passed to
# SentenceTransformer does not rely on them.
os.environ.update(
    dict.fromkeys(("TRANSFORMERS_CACHE", "HF_HOME", "XDG_CACHE_HOME"), CACHE_DIR)
)

# The Annoy index stores [text embedding | scaled audio features].
TEXT_DIM = 384  # presumably the embedding size of 'all-MiniLM-L6-v2' -- confirm
AUDIO_DIM = 6   # one slot per audio feature column read in recommend()
@st.cache_resource
def load_data_and_models():
    """Load the track table and the models used for recommendations.

    Decorated with ``st.cache_resource``, so Streamlit reuses the returned
    objects across reruns instead of reloading them. Treat them as
    read-only: mutating them changes the shared cached copy.

    Paths are relative to the process working directory.

    Returns:
        tuple: ``(data, text_model, scaler, annoy_index)``. These are:
        - the CSV as a DataFrame;
        - the 'all-MiniLM-L6-v2' SentenceTransformer;
        - the unpickled scaler from ``src/scaler.pkl``;
        - an angular Annoy index of dimension ``TEXT_DIM + AUDIO_DIM``,
          loaded from ``src/annoy_index.ann``.

    Raises:
        FileNotFoundError: if ``src/dataset.csv`` does not exist. Errors
        raised by the other loaders propagate unchanged.
    """
    data_path = "src/dataset.csv"
    model_dir = "src/"
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Dataset not found at {data_path}")
    data = pd.read_csv(data_path)
    text_model = SentenceTransformer('all-MiniLM-L6-v2', cache_folder=CACHE_DIR)
    # Use a context manager so the file is closed deterministically
    # (the former bare open() leaked the handle).
    # NOTE: pickle.load can execute arbitrary code; only ship a scaler.pkl
    # produced by a trusted training step.
    with open(f"{model_dir}scaler.pkl", "rb") as scaler_file:
        scaler = pickle.load(scaler_file)
    annoy_index = AnnoyIndex(TEXT_DIM + AUDIO_DIM, 'angular')
    annoy_index.load(f"{model_dir}annoy_index.ann")
    return data, text_model, scaler, annoy_index
def recommend(track_name, data, model, scaler, annoy_index, n=5, artists=None):
    """Return up to ``n`` tracks nearest to the given track in the Annoy index.

    The query vector is built as follows. A sentence embedding of
    "artists - track_name - album_name" is computed for the track's row.
    The row's six audio features, passed through ``scaler``, are appended.

    Args:
        track_name: Title to look up. It is matched case-insensitively
            against ``data['track_name']``.
        data: Track table containing the text and audio columns used below.
            NOTE(review): Annoy item ids are assumed to be positional row
            numbers of ``data``, as the ``data.iloc`` lookup requires.
        model: Object with an ``encode(str)`` method (a SentenceTransformer
            in this app).
        scaler: Fitted scaler with a ``transform`` method.
        annoy_index: Index queried with ``get_nns_by_vector``.
        n: Maximum number of recommendations. Values <= 0 yield none.
        artists: Optional exact ``artists`` value. It picks the right row when
            several artists have a song with the same title. ``None`` keeps
            the previous behaviour of taking the first row with that title.

    Returns:
        A DataFrame with ``artists``, ``track_name`` and ``album_name``
        columns. The query song itself is never included: neither its row
        nor duplicate rows with the same title and artists. On any error the
        message is shown with ``st.error`` and an empty DataFrame is returned.
    """
    audio_columns = [
        'danceability', 'energy', 'valence', 'tempo',
        'loudness', 'acousticness'
    ]
    try:
        name_match = data['track_name'].str.lower() == track_name.lower()
        if artists is None:
            match = name_match
        else:
            match = name_match & (data['artists'] == artists)
        positions = np.flatnonzero(match.to_numpy())
        if positions.size == 0:
            raise LookupError(f"трек «{track_name}» не найден")
        position = int(positions[0])
        track = data.iloc[position]

        text_emb = model.encode(
            f"{track['artists']} - {track['track_name']} - {track['album_name']}"
        )
        # A row of a mixed-dtype frame is an object Series; cast explicitly.
        audio_features = track[audio_columns].to_numpy(dtype=float).reshape(1, -1)
        audio_scaled = scaler.transform(audio_features)
        combined = np.hstack([text_emb, audio_scaled[0]])

        # The nearest hit is not guaranteed to be the query row itself.
        # Ties, or duplicate rows of the same song, can come first.
        # Exclude every row of this song explicitly instead of dropping
        # the first hit.
        same_song = name_match & (data['artists'] == track['artists'])
        excluded = {int(i) for i in np.flatnonzero(same_song.to_numpy())}
        excluded.add(position)  # NaN artists never compare equal
        count = max(n, 0)
        # Over-fetch by the number of excluded rows so `count` can still be met.
        neighbours = annoy_index.get_nns_by_vector(combined, count + len(excluded))
        picked = [i for i in neighbours if i not in excluded][:count]
        similar_tracks = data.iloc[picked]
        return similar_tracks[['artists', 'track_name', 'album_name']]
    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing the page.
        st.error(f"Ошибка: {str(e)}")
        return pd.DataFrame()


def main():
    """Render the page: pick a track, then list similar tracks on demand."""
    st.set_page_config(page_title="🎧 Music Recommender", layout="wide")
    st.markdown(
        "<h1 style='text-align: center; color: #4A4A4A;'>🎵 Музыкальные рекомендации</h1>",
        unsafe_allow_html=True
    )
    st.markdown("<p style='text-align: center;'>Найдите треки, похожие на ваш любимый — по звучанию и настроению.</p>", unsafe_allow_html=True)
    try:
        data, text_model, scaler, annoy_index = load_data_and_models()
    except Exception as e:
        st.error(f"Ошибка загрузки данных: {str(e)}")
        return
    # Keep labels in a local Series rather than a new column. `data` comes
    # from st.cache_resource, so adding a column would mutate the shared
    # cached object. A missing title or artist yields a NaN label; such rows
    # are left out because a "nan" option cannot be resolved back to a row.
    display_names = data['track_name'] + " — " + data['artists']
    with st.container():
        st.markdown("### 🔍 Выбор трека")
        track_display = st.selectbox(
            "Выберите трек из списка:",
            options=display_names.dropna().unique(),
            index=0
        )
        col1, _ = st.columns([1, 3])
        with col1:
            find_button = st.button("🔎 Найти похожие", type="primary")
    if track_display is None:
        # Nothing selectable (e.g. an empty dataset).
        return
    # Resolve the label back to its row so the artist is known.
    # Several songs can share a title.
    selected_position = int(np.flatnonzero((display_names == track_display).to_numpy())[0])
    selected = data.iloc[selected_position]
    if find_button:
        with st.spinner("🔬 Анализируем звучание и стиль..."):
            recommendations = recommend(
                selected['track_name'], data, text_model, scaler, annoy_index,
                artists=selected['artists'],
            )
        if not recommendations.empty:
            st.markdown("### 🎧 Вам могут понравиться:")
            st.dataframe(
                recommendations.rename(columns={
                    "artists": "👤 Исполнитель",
                    "track_name": "🎵 Трек",
                    "album_name": "💿 Альбом"
                }),
                use_container_width=True,
                hide_index=True
            )
        else:
            st.warning("К сожалению, похожих треков не найдено. Попробуйте другой.")
# Run the app when this file is executed as the main script.
# NOTE(review): presumably launched with `streamlit run`, which executes the
# script as __main__ -- confirm if the launch method changes.
if __name__ == "__main__":
    main()