# Scraped page header (not part of the program): author ramonpzg,
# commit 6204178 "create app", file size 3.12 kB.
import os

import streamlit as st
import torch
from audiocraft.data.audio_utils import convert_audio
from audiocraft.models import MusicGen
from qdrant_client import QdrantClient
from transformers import pipeline

# import baseten
# Page header: app title, tagline, and a short introduction rendered as Markdown.
_INTRO_MARKDOWN = """
The purpose of this app is to help creative people explore the possibilities of Generative AI in the music
domain, while comparing their creations to music made by people with all sorts of instruments.
There are several moving parts to this app and the most important ones are `transformers`, `audiocraft`, and
Qdrant for our vector database.
"""

st.title("Music Recommendation App")
st.subheader("A :red[Generative AI]-to-Real Music Approach")
st.markdown(_INTRO_MARKDOWN)
# Heavy resources (vector-DB client, genre classifier, MusicGen) are loaded via
# st.cache_resource. Streamlit re-executes this whole script on every widget
# interaction, so without caching each slider move would reload both models.
QDRANT_URL = "https://394294d5-30bb-4958-ad1a-15a3561edce5.us-east-1-0.aws.cloud.qdrant.io:6333"

# Fail with a readable message instead of a KeyError traceback when the
# secret is not configured.
if not os.environ.get("QDRANT_API_KEY"):
    st.error("QDRANT_API_KEY is not set; the app cannot connect to the Qdrant vector database.")
    st.stop()


@st.cache_resource
def _load_qdrant_client():
    """Return a QdrantClient for the app's Qdrant Cloud cluster (created once, then cached)."""
    return QdrantClient(QDRANT_URL, api_key=os.environ["QDRANT_API_KEY"])


@st.cache_resource
def _load_classifier():
    """Return the ``ramonpzg/wav2musicgenre`` audio-classification pipeline (cached)."""
    return pipeline("audio-classification", model="ramonpzg/wav2musicgenre")


@st.cache_resource
def _load_musicgen():
    """Return the pretrained 'small' MusicGen model (cached).

    NOTE(review): st.cache_resource hands the same object to every session, and
    the script calls model.set_generation_params() on it, so concurrent users
    may briefly observe each other's duration setting.
    """
    return MusicGen.get_pretrained("small")


client = _load_qdrant_client()
classifier = _load_classifier()
model = _load_musicgen()
# Generation controls: the clip length (in seconds) feeds MusicGen's sampling
# settings, and the text prompt describes the music to generate.
val1 = st.slider(
    label="How many seconds?",
    min_value=5.0,
    max_value=30.0,
    value=5.0,
    step=0.5,
)
generation_kwargs = {"use_sampling": True, "top_k": 250, "duration": val1}
model.set_generation_params(**generation_kwargs)

music_prompt = st.text_input(
    "Music Prompt",
    "Fast-paced bachata in the style of Romeo Santos.",
)
# On click: generate a clip from the prompt, classify its genre, then recommend
# real songs whose stored vectors are nearest to the clip's embedding.
if st.button("Generate Some Music!"):
    with st.spinner("Wait for it..."):
        # generate() returns a batch of waveforms; keep the first (only) clip
        # together with its channel axis so it can be resampled below.
        wav = model.generate(descriptions=[music_prompt], progress=True)[0]
    st.success("Done! :)")

    # Play the clip at MusicGen's native rate rather than a hard-coded 32000.
    output = wav[0].cpu().numpy()
    st.audio(output, sample_rate=model.sample_rate)

    # The classifier interprets a raw array as being at its feature extractor's
    # sampling rate. When that differs from MusicGen's rate, the unresampled clip
    # would be heard at the wrong speed. Resample (and force mono) first.
    clf_sr = classifier.feature_extractor.sampling_rate
    clf_audio = convert_audio(wav, model.sample_rate, clf_sr, 1)[0].cpu().numpy()

    genres = classifier(clf_audio)
    if genres:
        best, *others = genres
        st.markdown("## Best Prediction")
        col1, col2 = st.columns(2, gap="small")
        col1.subheader(best['label'])
        col2.metric(label="Score", value=f"{best['score']*100:.2f}%")

        st.markdown("### Other Predictions")
        pred_cols = st.columns(2, gap="small")
        for idx, genre in enumerate(others):
            # Alternate predictions between the left and right columns.
            pred_cols[idx % 2].metric(label=genre['label'], value=f"{genre['score']*100:.2f}%")

    # Embed the clip with the classifier's own backbone. return_tensors="pt" is
    # required because the model's forward pass expects tensors, not lists.
    features = classifier.feature_extractor(
        clf_audio, sampling_rate=clf_sr, return_tensors="pt"
    )
    with torch.no_grad():
        last_hidden = classifier.model(**features, output_hidden_states=True).hidden_states[-1]
    # Mean-pool over dim 1 and take the single batch item.
    # NOTE(review): assumes last_hidden is (batch, frames, hidden) and that the
    # "music_vectors" collection was built the same way; confirm.
    vectr = last_hidden.mean(dim=1)[0]

    # Qdrant takes a plain sequence of floats, not a torch tensor.
    results = client.search(
        collection_name="music_vectors",
        query_vector=vectr.tolist(),
        limit=10,
    )

    st.markdown("## Real Recommendations")
    rec_cols = st.columns(2)
    for idx, result in enumerate(results):
        col = rec_cols[idx % 2]  # alternate left/right columns
        payload = result.payload
        col.header(f"Genre: {payload['genre']}")
        col.markdown(f"### Artist: {payload['artist']}")
        col.markdown(f"#### Song name: {payload['name']}")
        col.audio(payload["urls"])