Spaces:
Sleeping
Sleeping
import streamlit as st | |
from streamlit import session_state as session | |
from src.config.configs import ProjectPaths | |
import numpy as np | |
from src.laion_clap.inference import AudioEncoder | |
import pickle | |
import torch | |
import pandas as pd | |
def load_data():
    """Load the precomputed audio embeddings together with their song names.

    Both files live in the ``vectors`` sub-directory of
    ``ProjectPaths.DATA_DIR``.

    Returns:
        A ``(vectors, song_names)`` tuple: the array read from
        ``audio_representations.npy`` and the object unpickled from
        ``song_names.pkl``.
    """
    vector_dir = ProjectPaths.DATA_DIR.joinpath("vectors")
    embeddings = np.load(vector_dir.joinpath("audio_representations.npy"))

    # NOTE: pickle can execute arbitrary code while loading; this file must
    # only ever come from a trusted source (it ships with the project data).
    names_path = vector_dir.joinpath("song_names.pkl")
    with open(names_path, "rb") as handle:
        names = pickle.load(handle)

    return embeddings, names
def load_model():
    """Construct the ``AudioEncoder`` whose ``get_text_embedding`` the app
    uses to embed the playlist description."""
    return AudioEncoder()
# Streamlit re-runs this whole script on every widget interaction; caching the
# loaders keeps the encoder and the embedding files from being reloaded (and
# the model re-initialised) on each rerun.
recommender = st.cache_resource(load_model)()
audio_vectors, song_names = st.cache_data(load_data)()
dataframe = None

st.title("""Curate me a Playlist.""")
session.text_input = st.text_input(label="Describe a playlist")
session.slider_count = st.slider(label="Track counts", min_value=5, max_value=30, step=5)
# The outer columns are spacers that shift the button toward the middle.
buffer1, col1, buffer2 = st.columns([1.45, 1, 1])
is_clicked = col1.button(label="Curate")

if is_clicked:
    if not session.text_input.strip():
        # Nothing to embed: ask for a description instead of ranking on "".
        st.warning("Please describe a playlist first.")
    else:
        with torch.no_grad():
            # Run the text encoder without autograd bookkeeping as well.
            text_embed = recommender.get_text_embedding(session.text_input)
            audio = torch.as_tensor(audio_vectors)
            # Align dtype/device with the audio matrix so the matmul cannot fail
            # on e.g. a float32/float64 mix, and accept either a (dim,) or a
            # (1, dim) embedding. Assumes audio_vectors is (num_songs, dim).
            text = torch.as_tensor(text_embed).to(device=audio.device, dtype=audio.dtype)
            text = text.reshape(-1, audio.shape[-1])
            ranking = audio @ text.t()
        # Column 0 holds every song's similarity to the query. Convert to NumPy
        # explicitly instead of relying on pandas to unpack a torch tensor.
        scores = ranking[:, 0].cpu().numpy()
        top = pd.Series(scores, index=song_names).nlargest(int(session.slider_count))
        # Only the selected song names (the index) are displayed; scores are dropped.
        dataframe = pd.DataFrame(index=top.index)

st.dataframe(dataframe)