OneAfterlife commited on
Commit
8c0b554
1 Parent(s): f1d2159

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +96 -0
  2. constants.py +7 -0
  3. requirements.txt +2 -0
  4. utils.py +50 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import gradio as gr
4
+ from sentence_transformers import SentenceTransformer
5
+
6
+ import httpx
7
+ import json
8
+
9
+ from utils import get_tags_for_prompts, get_mubert_tags_embeddings, get_pat
10
+
11
+ minilm = SentenceTransformer('all-MiniLM-L6-v2')
12
+ mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)
13
+
14
+
15
+ def get_track_by_tags(tags, pat, duration, maxit=20, loop=False):
16
+ if loop:
17
+ mode = "loop"
18
+ else:
19
+ mode = "track"
20
+ r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',
21
+ json={
22
+ "method": "RecordTrackTTM",
23
+ "params": {
24
+ "pat": pat,
25
+ "duration": duration,
26
+ "tags": tags,
27
+ "mode": mode
28
+ }
29
+ })
30
+
31
+ rdata = json.loads(r.text)
32
+ assert rdata['status'] == 1, rdata['error']['text']
33
+ trackurl = rdata['data']['tasks'][0]['download_link']
34
+
35
+ print('Generating track ', end='')
36
+ for i in range(maxit):
37
+ r = httpx.get(trackurl)
38
+ if r.status_code == 200:
39
+ return trackurl
40
+ time.sleep(1)
41
+
42
+
43
+ def generate_track_by_prompt(email, prompt, duration, loop=False):
44
+ try:
45
+ pat = get_pat(email)
46
+ _, tags = get_tags_for_prompts(minilm, mubert_tags_embeddings, [prompt, ])[0]
47
+ return get_track_by_tags(tags, pat, int(duration), loop=loop), "Success", ",".join(tags)
48
+ except Exception as e:
49
+ return None, str(e), ""
50
+
51
+
52
+ block = gr.Blocks()
53
+
54
+ with block:
55
+ gr.HTML(
56
+ """
57
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
58
+ <div
59
+ style="
60
+ display: inline-flex;
61
+ align-items: center;
62
+ gap: 0.8rem;
63
+ font-size: 1.75rem;
64
+ "
65
+ >
66
+ <h1 style="font-weight: 900; margin-bottom: 7px;">
67
+ TTM By Satwik
68
+ </h1>
69
+ </div>
70
+ < </div>
71
+ """
72
+ )
73
+ with gr.Group():
74
+ with gr.Box():
75
+ email = gr.Textbox(label="email")
76
+ #email="oneafterlif3@gmail.com"
77
+ prompt = gr.Textbox(label="prompt")
78
+ duration = gr.Slider(label="duration (seconds)", value=60, maximum=300)
79
+ is_loop = gr.Checkbox(label="Generate loop")
80
+ out = gr.Audio()
81
+ result_msg = gr.Text(label="Result message")
82
+ tags = gr.Text(label="Tags")
83
+ btn = gr.Button("Submit").style(full_width=True)
84
+
85
+ btn.click(fn=generate_track_by_prompt, inputs=[email, prompt, duration, is_loop], outputs=[out, result_msg, tags])
86
+
87
+ gr.HTML('''
88
+ <div class="footer" style="text-align: center; max-width: 700px; margin: 0 auto;">
89
+ </div>
90
+ </div>
91
+ <p style="margin-bottom: 10px; font-size: 94%">
92
+ if you put anything over 250 seconds, you will need to wait 10 or 30 second after it is done processing.
93
+ </div>
94
+ ''')
95
+
96
+ block.launch()
constants.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ MUBERT_TAGS_STRING = 'tribal,action,kids,neo-classic,run 130,pumped,jazz / funk,ethnic,dubtechno,reggae,acid jazz,liquidfunk,funk,witch house,tech house,underground,artists,mystical,disco,sensorium,r&b,agender,psychedelic trance / psytrance,peaceful,run 140,piano,run 160,setting,meditation,christmas,ambient,horror,cinematic,electro house,idm,bass,minimal,underscore,drums,glitchy,beautiful,technology,tribal house,country pop,jazz & funk,documentary,space,classical,valentines,chillstep,experimental,trap,new jack swing,drama,post-rock,tense,corporate,neutral,happy,analog,funky,spiritual,sberzvuk special,chill hop,dramatic,catchy,holidays,fitness 90,optimistic,orchestra,acid techno,energizing,romantic,minimal house,breaks,hyper pop,warm up,dreamy,dark,urban,microfunk,dub,nu disco,vogue,keys,hardcore,aggressive,indie,electro funk,beauty,relaxing,trance,pop,hiphop,soft,acoustic,chillrave / ethno-house,deep techno,angry,dance,fun,dubstep,tropical,latin pop,heroic,world music,inspirational,uplifting,atmosphere,art,epic,advertising,chillout,scary,spooky,slow ballad,saxophone,summer,erotic,jazzy,energy 100,kara mar,xmas,atmospheric,indie pop,hip-hop,yoga,reggaeton,lounge,travel,running,folk,chillrave & ethno-house,detective,darkambient,chill,fantasy,minimal techno,special,night,tropical house,downtempo,lullaby,meditative,upbeat,glitch hop,fitness,neurofunk,sexual,indie rock,future pop,jazz,cyberpunk,melancholic,happy hardcore,family / kids,synths,electric guitar,comedy,psychedelic trance & psytrance,edm,psychedelic rock,calm,zen,bells,podcast,melodic house,ethnic percussion,nature,heavy,bassline,indie dance,techno,drumnbass,synth pop,vaporwave,sad,8-bit,chillgressive,deep,orchestral,futuristic,hardtechno,nostalgic,big room,sci-fi,tutorial,joyful,pads,minimal 170,drill,ethnic 108,amusing,sleepy ambient,psychill,italo disco,lofi,house,acoustic guitar,bassline house,rock,k-pop,synthwave,deep house,electronica,gabber,nightlife,sport & fitness,road trip,celebration,electro,disco house,electronic'
4
+ MUBERT_TAGS = np.array(MUBERT_TAGS_STRING.split(','))
5
+ MUBERT_LICENSE = "ttmmubertlicense#f0acYBenRcfeFpNT4wpYGaTQIyDI4mJGv5MfIhBFz97NXDwDNFHmMRsBSzmGsJwbTpP1A6i07AXcIeAHo5"
6
+ MUBERT_MODE = "loop"
7
+ MUBERT_TOKEN = "4951f6428e83172a4f39de05d5b3ab10d58560b8"
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ httpx
2
+ sentence-transformers
utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import httpx
4
+
5
+ from constants import MUBERT_TAGS, MUBERT_LICENSE, MUBERT_MODE, MUBERT_TOKEN
6
+
7
+
8
+ def get_mubert_tags_embeddings(w2v_model):
9
+ return w2v_model.encode(MUBERT_TAGS)
10
+
11
+
12
+ def get_pat(email: str):
13
+ r = httpx.post('https://api-b2b.mubert.com/v2/GetServiceAccess',
14
+ json={
15
+ "method": "GetServiceAccess",
16
+ "params": {
17
+ "email": email,
18
+ "license": MUBERT_LICENSE,
19
+ "token": MUBERT_TOKEN,
20
+ "mode": MUBERT_MODE,
21
+ }
22
+ })
23
+
24
+ rdata = json.loads(r.text)
25
+ assert rdata['status'] == 1, "probably incorrect e-mail"
26
+ pat = rdata['data']['pat']
27
+ return pat
28
+
29
+
30
+ def find_similar(em, embeddings, method='cosine'):
31
+ scores = []
32
+ for ref in embeddings:
33
+ if method == 'cosine':
34
+ scores.append(1 - np.dot(ref, em) / (np.linalg.norm(ref) * np.linalg.norm(em)))
35
+ if method == 'norm':
36
+ scores.append(np.linalg.norm(ref - em))
37
+ return np.array(scores), np.argsort(scores)
38
+
39
+
40
+ def get_tags_for_prompts(w2v_model, mubert_tags_embeddings, prompts, top_n=3, debug=False):
41
+ prompts_embeddings = w2v_model.encode(prompts)
42
+ ret = []
43
+ for i, pe in enumerate(prompts_embeddings):
44
+ scores, idxs = find_similar(pe, mubert_tags_embeddings)
45
+ top_tags = MUBERT_TAGS[idxs[:top_n]]
46
+ top_prob = 1 - scores[idxs[:top_n]]
47
+ if debug:
48
+ print(f"Prompt: {prompts[i]}\nTags: {', '.join(top_tags)}\nScores: {top_prob}\n\n\n")
49
+ ret.append((prompts[i], list(top_tags)))
50
+ return ret