breadlicker45 commited on
Commit
44b72bb
1 Parent(s): 0307f8d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import gradio as gr
4
+ from sentence_transformers import SentenceTransformer
5
+
6
+ import httpx
7
+ import json
8
+
9
+ from utils import get_tags_for_prompts, get_mubert_tags_embeddings, get_pat
10
+
11
+ minilm = SentenceTransformer('all-MiniLM-L6-v2')
12
+ mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)
13
+
14
+
15
+ def get_track_by_tags(tags, pat, duration, maxit=20, loop=False):
16
+ if loop:
17
+ mode = "loop"
18
+ else:
19
+ mode = "track"
20
+ r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',
21
+ json={
22
+ "method": "RecordTrackTTM",
23
+ "params": {
24
+ "pat": pat,
25
+ "duration": duration,
26
+ "tags": tags,
27
+ "mode": mode
28
+ }
29
+ })
30
+
31
+ rdata = json.loads(r.text)
32
+ assert rdata['status'] == 1, rdata['error']['text']
33
+ trackurl = rdata['data']['tasks'][0]['download_link']
34
+
35
+ print('Generating track ', end='')
36
+ for i in range(maxit):
37
+ r = httpx.get(trackurl)
38
+ if r.status_code == 200:
39
+ return trackurl
40
+ time.sleep(1)
41
+
42
+
43
+ def generate_track_by_prompt(email, prompt, duration, loop=False):
44
+ try:
45
+ pat = get_pat(email)
46
+ _, tags = get_tags_for_prompts(minilm, mubert_tags_embeddings, [prompt, ])[0]
47
+ return get_track_by_tags(tags, pat, int(duration), loop=loop), "Success", ",".join(tags)
48
+ except Exception as e:
49
+ return None, str(e), ""
50
+
51
+
52
+ block = gr.Blocks()
53
+
54
+ with block:
55
+ gr.HTML(
56
+ """
57
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
58
+ <div
59
+ style="
60
+ display: inline-flex;
61
+ align-items: center;
62
+ gap: 0.8rem;
63
+ font-size: 1.75rem;
64
+ "
65
+ >
66
+ <h1 style="font-weight: 900; margin-bottom: 7px;">
67
+ Mubert
68
+ </h1>
69
+ </div>
70
+ <p style="margin-bottom: 10px; font-size: 94%">
71
+ All music is generated by Mubert API – <a href="https://mubert.com" style="text-decoration: underline;" target="_blank">www.mubert.com</a>
72
+ </p>
73
+ </div>
74
+ """
75
+ )
76
+ with gr.Group():
77
+ with gr.Box():
78
+ email = gr.Textbox(label="email")
79
+ prompt = gr.Textbox(label="prompt")
80
+ duration = gr.Slider(label="duration (seconds)", value=30)
81
+ is_loop = gr.Checkbox(label="Generate loop")
82
+ out = gr.Audio()
83
+ result_msg = gr.Text(label="Result message")
84
+ tags = gr.Text(label="Tags")
85
+ btn = gr.Button("Submit").style(full_width=True)
86
+
87
+ btn.click(fn=generate_track_by_prompt, inputs=[email, prompt, duration, is_loop], outputs=[out, result_msg, tags])
88
+ gr.HTML('''
89
+ <div class="footer" style="text-align: center; max-width: 700px; margin: 0 auto;">
90
+ <p>Demo by <a href="https://huggingface.co/Mubert" style="text-decoration: underline;" target="_blank">Mubert</a>
91
+ </p>
92
+ </div>
93
+ ''')
94
+
95
+ block.launch()