File size: 3,388 Bytes
6204178
 
 
 
 
cae9ed9
6204178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c77f36f
 
 
 
6204178
 
 
 
 
 
 
416ff82
6204178
 
 
 
 
 
 
 
 
 
 
 
50b6be3
 
 
 
6204178
 
 
 
50b6be3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import streamlit as st
from qdrant_client import QdrantClient
from transformers import pipeline
from audiocraft.models import MusicGen
import os
import torch
# import baseten

# --- Page header: app title, tagline, and a short introduction -------------
# The introduction is assembled from explicit line literals so that the
# significant whitespace (two trailing spaces = Markdown hard line break)
# stays visible and cannot be stripped by an editor.
intro_lines = (
    "",
    "The purpose of this app is to help creative people explore the possibilities of Generative AI in the music",
    "domain, while comparing their creations to music made by people with all sorts of instruments.  ",
    "",
    "There are several moving parts to this app and the most important ones are `transformers`, `audiocraft`, and ",
    "Qdrant for our vector database.",
    "",
)

st.title("Music Recommendation App")
st.subheader("A :red[Generative AI]-to-Real Music Approach")
st.markdown("\n".join(intro_lines))

# Address of the hosted Qdrant cluster that is queried for similar tracks.
QDRANT_URL = "https://394294d5-30bb-4958-ad1a-15a3561edce5.us-east-1-0.aws.cloud.qdrant.io:6333"


@st.cache_resource
def _get_qdrant_client():
    """Create the Qdrant client once and reuse it across script reruns.

    Raises:
        KeyError: if the ``QDRANT_API_KEY`` environment variable is not set.
    """
    return QdrantClient(QDRANT_URL, api_key=os.environ["QDRANT_API_KEY"])


@st.cache_resource
def _get_genre_classifier():
    """Load the audio-classification pipeline once and reuse it across reruns."""
    return pipeline("audio-classification", model="ramonpzg/wav2musicgenre")


@st.cache_resource
def _get_music_generator():
    """Load the pretrained 'small' MusicGen model once and reuse it across reruns."""
    return MusicGen.get_pretrained('small')


# Streamlit re-executes this script from the top on every widget interaction.
# Without caching, the client and both models would be rebuilt each time the
# slider moves or the prompt changes. The cached objects are shared, so any
# parameters set on `model` afterwards apply to that shared instance.
client = _get_qdrant_client()
classifier = _get_genre_classifier()
model = _get_music_generator()

# Clip length in seconds, chosen by the user (5-30 s in half-second steps).
val1 = st.slider(
    label="How many seconds?",
    min_value=5.0,
    max_value=30.0,
    value=5.0,
    step=0.5,
)

# Sampling settings handed to the generator; only the duration is user-driven.
generation_settings = {
    "use_sampling": True,
    "top_k": 250,
    "duration": val1,
}
model.set_generation_params(**generation_settings)

# Free-text description of the music to generate, pre-filled with an example.
music_prompt = st.text_input("Music Prompt", "Fast-paced bachata in the style of Romeo Santos.")

# Main interaction: generate a clip from the prompt, classify its genre,
# embed it with the classifier's backbone and look up similar real tracks.
if st.button("Generate Some Music!"):
    with st.spinner("Wait for it..."):
        # Keep a single 1-D waveform from the generated output as a NumPy array
        # (presumably indexed as batch, channel, samples -- TODO confirm).
        output = model.generate(descriptions=[music_prompt], progress=True)[0, 0, :].cpu().numpy()
        st.success("Done! :)")

    st.audio(output, sample_rate=32000)

    genres = classifier(output)

    if genres:
        # The first entry is shown as the headline prediction, the rest below it.
        top_genre, *other_genres = genres

        st.markdown("## Best Prediction")
        col1, col2 = st.columns(2, gap="small")
        col1.subheader(top_genre['label'])
        col2.metric(label="Score", value=f"{top_genre['score']*100:.2f}%")

        st.markdown("### Other Predictions")
        prediction_cols = st.columns(2, gap="small")
        for idx, genre in enumerate(other_genres):
            # Alternate between the left and right column.
            prediction_cols[idx % 2].metric(label=genre['label'], value=f"{genre['score']*100:.2f}%")

    # NOTE(review): the clip is played back at 32 kHz above but is described
    # as 16 kHz here (and passed to the classifier without a sampling rate).
    # Confirm whether it should be resampled first; changing this would alter
    # the query embedding relative to whatever the collection was indexed with.
    # max_length=16_000 with truncation=True keeps at most 16,000 samples.
    features = classifier.feature_extractor(
        output, sampling_rate=16_000, return_tensors="pt", padding=True,
        return_attention_mask=True, max_length=16_000, truncation=True
    )

    # Query vector: last hidden state, mean over dim 1, first item in the batch.
    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        vectr = classifier.model(**features, output_hidden_states=True).hidden_states[-1].mean(dim=1)[0]

    results = client.search(
        collection_name="music_vectors",
        query_vector=vectr.tolist(),
        limit=10
    )

    st.markdown("## Real Recommendations")

    recommendation_cols = st.columns(2)

    for idx, result in enumerate(results):
        # Alternate between the left and right column.
        col = recommendation_cols[idx % 2]
        # Tolerate points stored without (or with a partial) payload.
        payload = result.payload or {}

        col.header(f"Genre: {payload.get('genre', 'Unknown')}")
        col.markdown(f"### Artist: {payload.get('artist', 'Unknown')}")
        col.markdown(f"#### Song name: {payload.get('name', 'Unknown')}")
        try:
            col.audio(payload["urls"])
        except Exception:
            # Best effort: a missing or unplayable preview must not abort the
            # remaining recommendations, but tell the user instead of failing
            # silently. `Exception` (not a bare except) lets BaseException
            # subclasses such as KeyboardInterrupt propagate.
            col.caption("Audio preview unavailable.")