ramonpzg commited on
Commit
6204178
1 Parent(s): 7d054f9

create app

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from qdrant_client import QdrantClient
3
+ from transformers import pipeline
4
+ from audiocraft.models import MusicGen
5
+ import os
6
+ # import baseten
7
+
8
+ st.title("Music Recommendation App")
9
+ st.subheader("A :red[Generative AI]-to-Real Music Approach")
10
+
11
+ st.markdown("""
12
+ The purpose of this app is to help creative people explore the possibilities of Generative AI in the music
13
+ domain, while comparing their creations to music made by people with all sorts of instruments.
14
+
15
+ There are several moving parts to this app and the most important ones are `transformers`, `audiocraft`, and
16
+ Qdrant for our vector database.
17
+ """)
18
+
19
+ client = QdrantClient(
20
+ "https://394294d5-30bb-4958-ad1a-15a3561edce5.us-east-1-0.aws.cloud.qdrant.io:6333",
21
+ api_key=os.environ['QDRANT_API_KEY'],
22
+ )
23
+
24
+ # classifier = baseten.deployed_model_id('20awxxq')
25
+ classifier = pipeline("audio-classification", model="ramonpzg/wav2musicgenre")#.to(device)
26
+ model = MusicGen.get_pretrained('small')
27
+
28
+ val1 = st.slider("How many seconds?", 5.0, 30.0, value=5.0, step=0.5)
29
+
30
+ model.set_generation_params(
31
+ use_sampling=True,
32
+ top_k=250,
33
+ duration=val1
34
+ )
35
+
36
+ music_prompt = st.text_input(
37
+ label="Music Prompt",
38
+ value="Fast-paced bachata in the style of Romeo Santos."
39
+ )
40
+
41
+ if st.button("Generate Some Music!"):
42
+ with st.spinner("Wait for it..."):
43
+ output = model.generate(descriptions=[music_prompt],progress=True)[0, 0, :].cpu().numpy()
44
+ st.success("Done! :)")
45
+
46
+ st.audio(output, sample_rate=32000)
47
+
48
+ genres = classifier(output)
49
+
50
+ if genres:
51
+ st.markdown("## Best Prediction")
52
+ col1, col2 = st.columns(2, gap="small")
53
+ col1.subheader(genres[0]['label'])
54
+ col2.metric(label="Score", value=f"{genres[0]['score']*100:.2f}%")
55
+
56
+ st.markdown("### Other Predictions")
57
+ col3, col4 = st.columns(2, gap="small")
58
+ for idx, genre in enumerate(genres[1:]):
59
+ if idx % 2 == 0:
60
+ col3.metric(label=genre['label'], value=f"{genre['score']*100:.2f}%")
61
+ else:
62
+ col4.metric(label=genre['label'], value=f"{genre['score']*100:.2f}%")
63
+
64
+ features = classifier.feature_extractor(output)
65
+
66
+ with torch.no_grad():
67
+ vectr = classifier.model(**features, output_hidden_states=True).hidden_states[-1].mean(dim=1)[0]
68
+
69
+ # vectr = embedding[0]
70
+
71
+ results = client.search(
72
+ collection_name="music_vectors",
73
+ query_vector=vectr,
74
+ limit=10
75
+ )
76
+
77
+ st.markdown("## Real Recommendations")
78
+
79
+ col5, col6 = st.columns(2)
80
+
81
+ for idx, result in enumerate(results):
82
+ if idx % 2 == 0:
83
+ col5.header(f"Genre: {result.payload['genre']}")
84
+ col5.markdown(f"### Artist: {result.payload['artist']}")
85
+ col5.markdown(f"#### Song name: {result.payload['name']}")
86
+ col5.audio(result.payload["urls"])
87
+ else:
88
+ col6.header(f"Genre: {result.payload['genre']}")
89
+ col6.markdown(f"### Artist: {result.payload['artist']}")
90
+ col6.markdown(f"#### Song name: {result.payload['name']}")
91
+ col6.audio(result.payload["urls"])