"""Streamlit app: curate a playlist from a free-text description.

Embeds the user's description with a CLAP text encoder, scores it against
pre-computed audio embeddings, and displays the top-N matching songs.
"""

import pickle

import numpy as np
import pandas as pd
import streamlit as st
import torch
from streamlit import session_state as session

from src.config.configs import ProjectPaths
from src.laion_clap.inference import AudioEncoder


@st.cache_data
def load_data():
    """Load the pre-computed audio embedding matrix and matching song names.

    Returns:
        tuple: ``(vectors, song_names)`` where ``vectors`` is an ndarray with
        one embedding row per song, aligned index-for-index with
        ``song_names``.
    """
    vectors = np.load(
        ProjectPaths.DATA_DIR.joinpath("vectors", "audio_representations.npy")
    )
    # NOTE(review): pickle.load is acceptable here only because this file
    # ships with the project; never point it at untrusted data.
    with open(
        ProjectPaths.DATA_DIR.joinpath("vectors", "song_names.pkl"), "rb"
    ) as reader:
        song_names = pickle.load(reader)
    return vectors, song_names


@st.cache_resource
def load_model():
    """Build the CLAP text/audio encoder once and reuse it across reruns."""
    recommender = AudioEncoder()
    return recommender


recommender = load_model()
audio_vectors, song_names = load_data()
dataframe = None

st.title("""Curate me a Playlist.""")
session.text_input = st.text_input(label="Describe a playlist")
session.slider_count = st.slider(
    label="Track counts", min_value=5, max_value=30, step=5
)

# Spacer columns center the Curate button.
buffer1, col1, buffer2 = st.columns([1.45, 1, 1])
is_clicked = col1.button(label="Curate")

if is_clicked:
    text_embed = recommender.get_text_embedding(session.text_input)
    with torch.no_grad():
        # Similarity of every song embedding against the text embedding;
        # presumably text_embed is (1, dim) so the product is (n_songs, 1)
        # -- TODO confirm against AudioEncoder.get_text_embedding.
        ranking = torch.tensor(audio_vectors) @ torch.tensor(text_embed).t()
        ranking = ranking[:, 0].reshape(-1, 1)
    # Rank by score, keep the top-N rows, then drop the score column so only
    # the song names (the index) are shown.
    dataframe = (
        pd.DataFrame(ranking, columns=[session.text_input], index=song_names)
        .nlargest(int(session.slider_count), session.text_input)
        .drop(columns=session.text_input)
    )

st.dataframe(dataframe)