# Scraped page header captured with this file ("Spaces: Sleeping"); not part of the app code.
import os
from urllib.parse import quote_plus

import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from langchain.llms import CTransformers, Ollama
from pymongo.mongo_client import MongoClient
from streamlit import session_state as session

from src.laion_clap.inference import AudioEncoder
from src.llm.chain import LLMChain
# from src.utils.spotify import SpotifyHandler, SpotifyAuthentication
# Populate os.environ from a local .env file (load_resources reads MONGODB_* from it).
load_dotenv()

# Page-level options. Keep this ahead of every other st.* call, as Streamlit's
# set_page_config expects to run before the rest of the page is rendered.
st.set_page_config(
    layout="wide",
    page_title="Curate me a playlist",
)
def load_llm_pipeline(): | |
ctransformers_config = { | |
"max_new_tokens": 3000, | |
"temperature": 0, | |
"top_k": 1, | |
"top_p": 1, | |
"context_length": 2800 | |
} | |
llm = CTransformers( | |
model="TheBloke/Mistral-7B-Instruct-v0.1-GGUF", | |
model_file="mistral-7b-instruct-v0.1.Q5_K_M.gguf", | |
config=ctransformers_config | |
) | |
# llm = Ollama(temperature=0, model="mistral:7b-instruct-q8_0", top_k=1, top_p=1, num_ctx=2800) | |
chain = LLMChain(llm) | |
return chain | |
@st.cache_resource
def load_resources():
    """Connect to MongoDB and build the audio recommender and the LLM chain.

    Streamlit re-executes this whole script on every widget interaction, so the
    function is wrapped in ``st.cache_resource``. As a result the Mongo client, the
    loaded audio vectors and the LLM are created once per server process rather
    than on every rerun.

    Environment variables:
        MONGODB_PASSWORD: required; password of the database user.
        MONGODB_URL: required; host part placed after ``@`` in the ``mongodb+srv`` URI.
        MONGODB_USERNAME: optional; defaults to ``"berkaygkv"`` (previously hard-coded).

    Returns:
        tuple: ``(recommender, llm_pipeline)``. ``recommender`` is an
        ``AudioEncoder`` bound to the ``spoti.saved_tracks`` collection with
        existing audio vectors loaded. ``llm_pipeline`` is the chain from
        ``load_llm_pipeline()``.

    Raises:
        RuntimeError: if MONGODB_PASSWORD or MONGODB_URL is unset or empty.
    """
    username = os.getenv("MONGODB_USERNAME", "berkaygkv")
    password = os.getenv("MONGODB_PASSWORD")
    url = os.getenv("MONGODB_URL")

    # Without this check a missing variable is interpolated as the literal "None"
    # and only fails later with an opaque connection/auth error.
    missing = [
        name
        for name, value in (("MONGODB_PASSWORD", password), ("MONGODB_URL", url))
        if not value
    ]
    if missing:
        raise RuntimeError(
            f"Missing required environment variable(s): {', '.join(missing)}"
        )

    # MongoDB URIs require percent-escaped credentials; characters such as
    # '@', ':' or '/' in a raw password would otherwise corrupt URI parsing.
    uri = (
        f"mongodb+srv://{quote_plus(username)}:{quote_plus(password)}"
        f"@{url}/?retryWrites=true&w=majority"
    )
    client = MongoClient(uri)
    mongo_db_collection = client.spoti.saved_tracks

    recommender = AudioEncoder(mongo_db_collection)
    recommender.load_existing_audio_vectors()
    llm_pipeline = load_llm_pipeline()
    return recommender, llm_pipeline
# Heavy shared resources: audio recommender + LLM chain.
recommender, llm_pipeline = load_resources()

st.title("""Curate me a Playlist.""")
# Widget values are also stored in session state under these attribute names.
session.text_input = st.text_input(label="Describe a playlist")
session.slider_count = st.slider(label="Track counts", min_value=5, max_value=35, step=5)
# Weighted columns; only col1 is used, to place the button horizontally.
buffer1, col1, col2, buffer2 = st.columns([1.45, 1, 1, 1])
is_clicked = col1.button(label="Curate")

if is_clicked:
    if not session.text_input.strip():
        # Skip the expensive LLM call when there is nothing to describe.
        st.warning("Please describe the playlist you want first.")
    else:
        output = llm_pipeline.process_user_description(session.text_input)

        # Fetch at least as many candidates per description as the user asked for,
        # so a single generated description can still fill the requested track count.
        per_description_k = max(15, session.slider_count)
        song_list = []
        for _, song_desc in output:
            print(song_desc)
            song_list += recommender.list_top_k_songs(song_desc, k=per_description_k)

        if not song_list:
            # An empty DataFrame has no "score" column; sort_values would raise KeyError.
            st.info("No matching tracks were found for this description.")
        else:
            # Best score first, keep each track once (its highest-scoring hit),
            # hide the internal id, then truncate to the requested count.
            dataframe = (
                pd.DataFrame(song_list)
                .sort_values("score", ascending=False)
                .drop_duplicates(subset=["track_id"])
                .drop(columns=["track_id"])
                .reset_index(drop=True)
                .iloc[: session.slider_count]
            )
            st.data_editor(
                dataframe,
                column_config={
                    "link": st.column_config.LinkColumn(
                        "link",
                        # help="The top trending Streamlit apps",
                        # validate="^https://[a-z]+\.streamlit\.app$",
                        # max_chars=100,
                    )
                },
                hide_index=False,
                use_container_width=True,
            )
# with st.form(key="spotiform"):
#     st.form_submit_button(on_click=authenticate_spotify, args=(session.access_url, ))
#     st.markdown(session.access_url)