# Scraped file-page header (not code), kept as a comment so the module parses:
# author: berkaygkv54
# commit: 9a08670 — "streamlit component changes"
# file size: 1.42 kB
import streamlit as st
from streamlit import session_state as session
from src.config.configs import ProjectPaths
import numpy as np
from src.laion_clap.inference import AudioEncoder
import pickle
import torch
import pandas as pd
@st.cache_data
def load_data():
    """Load the precomputed audio embeddings and the matching song names.

    Both files are read from ``ProjectPaths.DATA_DIR / "vectors"``; the result
    is memoised by ``st.cache_data`` so Streamlit reruns reuse it instead of
    re-reading the files.

    Returns:
        tuple: ``(vectors, song_names)`` where ``vectors`` is whatever
        ``np.load`` returns for ``audio_representations.npy`` and
        ``song_names`` is the unpickled contents of ``song_names.pkl``
        (presumably one name per row of ``vectors`` — TODO confirm).
    """
    vector_dir = ProjectPaths.DATA_DIR.joinpath("vectors")
    embeddings = np.load(vector_dir.joinpath("audio_representations.npy"))
    # NOTE: pickle can execute arbitrary code on load; only point this at
    # trusted, project-generated data.
    with open(vector_dir.joinpath("song_names.pkl"), "rb") as handle:
        names = pickle.load(handle)
    return embeddings, names
@st.cache_resource
def load_model():
    """Construct the ``AudioEncoder`` once; ``st.cache_resource`` shares it across reruns.

    Returns:
        AudioEncoder: the encoder used to embed the playlist description.
    """
    return AudioEncoder()
# Both loaders are decorated with Streamlit caches, so these calls are cheap
# on every rerun after the first.
recommender = load_model()
(audio_vectors, song_names) = load_data()

# Result table; stays None unless the "Curate" branch below assigns it.
dataframe = None
st.title("Curate me a Playlist.")

# Free-text prompt and desired playlist length, mirrored into session state.
session.text_input = st.text_input(label="Describe a playlist")
session.slider_count = st.slider(label="Track counts", min_value=5, max_value=30, step=5)

# Three columns with the button in the middle one; the wider left column is
# presumably there to push the button toward the centre of the page.
_, button_column, _ = st.columns([1.45, 1, 1])
is_clicked = button_column.button(label="Curate")
if is_clicked and not session.text_input.strip():
    # An empty description would still be embedded and ranked, producing an
    # arbitrary "playlist"; ask the user for input instead.
    st.warning("Please describe the playlist you want first.")
elif is_clicked:
    # Inference only: disable autograd for the text encoder as well as the
    # similarity product so no computation graph is built for either.
    with torch.no_grad():
        text_embed = recommender.get_text_embedding(session.text_input)
        # as_tensor shares memory with the NumPy matrix (and passes an
        # existing tensor through) instead of copying every song vector on
        # each click.
        ranking = torch.as_tensor(audio_vectors) @ torch.as_tensor(text_embed).t()
    # Assumes text_embed is (1, dim), so ranking is (num_songs, 1); take the
    # single score column — TODO confirm against AudioEncoder.
    scores = ranking[:, 0].cpu().numpy()
    top_songs = pd.Series(scores, index=song_names).nlargest(int(session.slider_count))
    # Same output shape as before: top song names (best first) as the index,
    # with no data columns.
    dataframe = pd.DataFrame(index=top_songs.index)
st.dataframe(dataframe)