import streamlit as st
from qdrant_client import QdrantClient
from transformers import pipeline
from audiocraft.models import MusicGen
import os
import torch
# import baseten
st.title("Music Recommendation App")
st.subheader("A :red[Generative AI]-to-Real Music Approach")
st.markdown("""
The purpose of this app is to help creative people explore the possibilities of Generative AI in the music
domain while comparing their creations to music made by people with all sorts of instruments.
There are several moving parts to this app; the most important ones are `transformers`, `audiocraft`, and
Qdrant as our vector database.
""")
# Connect to the hosted Qdrant cluster that stores the song embeddings.
client = QdrantClient(
    "https://394294d5-30bb-4958-ad1a-15a3561edce5.us-east-1-0.aws.cloud.qdrant.io:6333",
    api_key=os.environ['QDRANT_API_KEY'],
)

# Genre classifier; its backbone is also reused later to embed the generated audio.
# classifier = baseten.deployed_model_id('20awxxq')
classifier = pipeline("audio-classification", model="ramonpzg/wav2musicgenre")

# Music generation model from audiocraft.
model = MusicGen.get_pretrained('small')
# Let the user pick how long the generated clip should be.
val1 = st.slider("How many seconds?", 5.0, 30.0, value=5.0, step=0.5)

model.set_generation_params(
    use_sampling=True,
    top_k=250,
    duration=val1
)

music_prompt = st.text_input(
    label="Music Prompt",
    value="Fast-paced bachata in the style of Romeo Santos."
)
if st.button("Generate Some Music!"):
    with st.spinner("Wait for it..."):
        # Generate one clip and flatten it to a 1-D numpy array for playback.
        output = model.generate(descriptions=[music_prompt], progress=True)[0, 0, :].cpu().numpy()
    st.success("Done! :)")
    st.audio(output, sample_rate=32000)

    # Classify the generated clip into genres.
    genres = classifier(output)
    if genres:
        st.markdown("## Best Prediction")
        col1, col2 = st.columns(2, gap="small")
        col1.subheader(genres[0]['label'])
        col2.metric(label="Score", value=f"{genres[0]['score']*100:.2f}%")

        st.markdown("### Other Predictions")
        col3, col4 = st.columns(2, gap="small")
        for idx, genre in enumerate(genres[1:]):
            if idx % 2 == 0:
                col3.metric(label=genre['label'], value=f"{genre['score']*100:.2f}%")
            else:
                col4.metric(label=genre['label'], value=f"{genre['score']*100:.2f}%")

    # Reuse the classifier's backbone to embed the clip, then query Qdrant
    # for the most similar real songs.
    features = classifier.feature_extractor(
        output, sampling_rate=16_000, return_tensors="pt", padding=True,
        return_attention_mask=True, max_length=16_000, truncation=True
    )
    with torch.no_grad():
        vectr = classifier.model(**features, output_hidden_states=True).hidden_states[-1].mean(dim=1)[0]

    results = client.search(
        collection_name="music_vectors",
        query_vector=vectr.tolist(),
        limit=10
    )

    st.markdown("## Real Recommendations")
    col5, col6 = st.columns(2)
    for idx, result in enumerate(results):
        if idx % 2 == 0:
            col5.header(f"Genre: {result.payload['genre']}")
            col5.markdown(f"### Artist: {result.payload['artist']}")
            col5.markdown(f"#### Song name: {result.payload['name']}")
            col5.audio(result.payload["urls"])
        else:
            col6.header(f"Genre: {result.payload['genre']}")
            col6.markdown(f"### Artist: {result.payload['artist']}")
            col6.markdown(f"#### Song name: {result.payload['name']}")
            col6.audio(result.payload["urls"])
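The app assumes a Qdrant collection named `music_vectors` already exists, with one point per song and a payload containing `genre`, `artist`, `name`, and `urls`. Below is a minimal sketch of how such a collection might be created and populated in a separate one-time script; the 768-dimensional vector size (typical of a wav2vec2-base style encoder), the placeholder URL, credentials, and the `songs` list are assumptions for illustration only, not part of the app above.

# One-time setup sketch for the "music_vectors" collection (run separately from the app).
import os
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct

setup_client = QdrantClient(
    url="https://<your-cluster>.cloud.qdrant.io:6333",  # placeholder URL
    api_key=os.environ["QDRANT_API_KEY"],
)

# Create the collection; 768 is an assumed embedding size and should match the
# dimensionality produced by the embedding model used above.
setup_client.recreate_collection(
    collection_name="music_vectors",
    vectors_config=VectorParams(size=768, distance=Distance.COSINE),
)

# Hypothetical list of pre-computed song embeddings plus metadata.
songs = [
    {
        "vector": [0.0] * 768,
        "genre": "bachata",
        "artist": "Example Artist",
        "name": "Example Song",
        "urls": "https://example.com/song.mp3",
    },
]

# Upload each song as a point whose payload mirrors what the app reads back.
setup_client.upsert(
    collection_name="music_vectors",
    points=[
        PointStruct(
            id=idx,
            vector=song["vector"],
            payload={
                "genre": song["genre"],
                "artist": song["artist"],
                "name": song["name"],
                "urls": song["urls"],
            },
        )
        for idx, song in enumerate(songs)
    ],
)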