vovahimself committed on
Commit
5fb0595
1 Parent(s): e789c49
Files changed (1) hide show
  1. app.py +101 -4
app.py CHANGED
@@ -2,9 +2,106 @@ from transformers import JukeboxModel , JukeboxTokenizer
2
  from transformers.models.jukebox import convert_jukebox
3
 
4
  import gradio as gr
 
5
 
6
- def greet(name):
7
- return "Hello " + name + "!!"
 
 
 
 
 
8
 
9
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
10
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from transformers.models.jukebox import convert_jukebox
3
 
4
  import gradio as gr
5
+ import torch as t
6
 
7
+ model_id = 'openai/jukebox-1b-lyrics' #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']
8
+ sample_rate = 44100
9
+ total_duration_in_seconds = 200
10
+ raw_to_tokens = 128
11
+ chunk_size = 32
12
+ max_batch_size = 16
13
+ cache_path = '~/.cache/'
14
 
15
+ def tokens_to_seconds(tokens, level = 2):
16
+
17
+ global sample_rate, raw_to_tokens
18
+ return tokens * raw_to_tokens / sample_rate / 4 ** (2 - level)
19
+
20
+ def seconds_to_tokens(sec, level = 2):
21
+
22
+ global sample_rate, raw_to_tokens, chunk_size
23
+
24
+ tokens = sec * sample_rate // raw_to_tokens
25
+ tokens = ( (tokens // chunk_size) + 1 ) * chunk_size
26
+
27
+ # For levels 1 and 0, multiply by 4 and 16 respectively
28
+ tokens *= 4 ** (2 - level)
29
+
30
+ return int(tokens)
31
+
32
+ # Init is ran on server startup
33
+ # Load your model to GPU as a global variable here using the variable name "model"
34
+ def init():
35
+ global model
36
+
37
+ print(f"Loading model from/to {cache_path}...")
38
+ model = JukeboxModel.from_pretrained(
39
+ model_id,
40
+ device_map = "auto",
41
+ torch_dtype = t.float16,
42
+ cache_dir = f"{cache_path}/jukebox/models",
43
+ resume_download = True,
44
+ min_duration = 0
45
+ ).eval()
46
+ print("Model loaded: ", model)
47
+
48
+ # Inference is ran for every server call
49
+ # Reference your preloaded global model variable here.
50
+ def inference(artist, genres, lyrics):
51
+ global model, zs
52
+
53
+ n_samples = 4
54
+ generation_length = seconds_to_tokens(1)
55
+ offset = 0
56
+ level = 0
57
+
58
+ model.total_length = seconds_to_tokens(total_duration_in_seconds)
59
+
60
+ sampling_kwargs = dict(
61
+ temp = 0.98,
62
+ chunk_size = chunk_size,
63
+ )
64
+
65
+ metas = dict(
66
+ artist = artist,
67
+ genres = genres,
68
+ lyrics = lyrics,
69
+ )
70
+
71
+ labels = JukeboxTokenizer.from_pretrained(model_id)(**metas)['input_ids'][level].repeat(n_samples, 1).cuda()
72
+ print(f"Labels: {labels.shape}")
73
+
74
+ zs = [ t.zeros(n_samples, 0, dtype=t.long, device='cuda') for _ in range(3) ]
75
+ print(f"Zs: {[z.shape for z in zs]}")
76
+
77
+ zs = model.sample_partial_window(
78
+ zs, labels, offset, sampling_kwargs, level = level, tokens_to_sample = generation_length, max_batch_size = max_batch_size
79
+ )
80
+ print(f"Zs after sampling: {[z.shape for z in zs]}")
81
+
82
+ # Convert to numpy array
83
+ return zs.cpu().numpy()
84
+
85
+
86
+ with gr.Blocks() as ui:
87
+
88
+ # Define UI components
89
+ title = gr.Textbox(lines=1, label="Title")
90
+ artist = gr.Textbox(lines=1, label="Artist")
91
+ genres = gr.Textbox(lines=1, label="Genre(s)", placeholder="Separate with spaces")
92
+ lyrics = gr.Textbox(lines=5, label="Lyrics", placeholder="Shift+Enter for new line")
93
+ submit = gr.Button(label="Generate")
94
+
95
+ output_zs = gr.Dataframe(label="zs")
96
+
97
+ submit.click(
98
+ inference,
99
+ inputs = [ artist, genres, lyrics ],
100
+ outputs = output_zs,
101
+ )
102
+
103
+ if __name__ == "__main__":
104
+
105
+ init()
106
+
107
+ gr.launch()