import streamlit as st
from qdrant_client import QdrantClient
from transformers import pipeline
from audiocraft.models import MusicGen
import os
import torch
# import baseten
st.title("Music Recommendation App")
st.subheader("A :red[Generative AI]-to-Real Music Approach")
st.markdown("""
The purpose of this app is to help creative people explore what Generative AI can do in the music domain,
and to compare their creations with music made by real artists playing all sorts of instruments.
The app has several moving parts; the most important ones are `transformers`, `audiocraft`, and
Qdrant as our vector database.
""")

# Connect to the hosted Qdrant cluster; the API key is read from the QDRANT_API_KEY environment variable.
client = QdrantClient(
    "https://394294d5-30bb-4958-ad1a-15a3561edce5.us-east-1-0.aws.cloud.qdrant.io:6333",
    api_key=os.environ['QDRANT_API_KEY'],
)
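
# Load the genre classifier and the MusicGen model used to generate new audio.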
# classifier = baseten.deployed_model_id('20awxxq')
classifier = pipeline("audio-classification", model="ramonpzg/wav2musicgenre")#.to(device)
model = MusicGen.get_pretrained('small')
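
# Generation controls: the slider sets how many seconds of audio MusicGen will produce.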
val1 = st.slider("How many seconds?", 5.0, 30.0, value=5.0, step=0.5)
model.set_generation_params(
    use_sampling=True,
    top_k=250,
    duration=val1
)
music_prompt = st.text_input(
    label="Music Prompt",
    value="Fast-paced bachata in the style of Romeo Santos."
)
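
# Everything below runs when the user clicks the button: generate a clip from the prompt,
# classify its genre, embed it, and search Qdrant for similar real songs.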
if st.button("Generate Some Music!"):
    with st.spinner("Wait for it..."):
        output = model.generate(descriptions=[music_prompt], progress=True)[0, 0, :].cpu().numpy()
    st.success("Done! :)")
    st.audio(output, sample_rate=32000)
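
    # Classify the generated clip; the pipeline returns genre predictions sorted by score.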
    genres = classifier(output)
    if genres:
        st.markdown("## Best Prediction")
        col1, col2 = st.columns(2, gap="small")
        col1.subheader(genres[0]['label'])
        col2.metric(label="Score", value=f"{genres[0]['score']*100:.2f}%")

        st.markdown("### Other Predictions")
        col3, col4 = st.columns(2, gap="small")
        for idx, genre in enumerate(genres[1:]):
            if idx % 2 == 0:
                col3.metric(label=genre['label'], value=f"{genre['score']*100:.2f}%")
            else:
                col4.metric(label=genre['label'], value=f"{genre['score']*100:.2f}%")
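
    # Embed the clip for retrieval: pass it through the classifier's underlying model
    # and mean-pool the final hidden states into a single query vector.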
    features = classifier.feature_extractor(
        output, sampling_rate=16_000, return_tensors="pt", padding=True,
        return_attention_mask=True, max_length=16_000, truncation=True
    )
    with torch.no_grad():
        vectr = classifier.model(**features, output_hidden_states=True).hidden_states[-1].mean(dim=1)[0]
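
    # Retrieve the 10 nearest real songs from the "music_vectors" collection.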
    results = client.search(
        collection_name="music_vectors",
        query_vector=vectr.tolist(),
        limit=10
    )
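
    # Show the matches in two columns; results whose audio URL cannot be played are skipped.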
    st.markdown("## Real Recommendations")
    col5, col6 = st.columns(2)
    for idx, result in enumerate(results):
        if idx % 2 == 0:
            col5.header(f"Genre: {result.payload['genre']}")
            col5.markdown(f"### Artist: {result.payload['artist']}")
            col5.markdown(f"#### Song name: {result.payload['name']}")
            try:
                col5.audio(result.payload["urls"])
            except Exception:
                continue
        else:
            col6.header(f"Genre: {result.payload['genre']}")
            col6.markdown(f"### Artist: {result.payload['artist']}")
            col6.markdown(f"#### Song name: {result.payload['name']}")
            try:
                col6.audio(result.payload["urls"])
            except Exception:
                continue