File size: 3,285 Bytes
631e673
 
156d0fd
631e673
 
156d0fd
 
631e673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659a5e1
 
 
 
 
 
 
 
 
631e673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659a5e1
 
631e673
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# A simple gradio app that converts music tokens to and from audio using JukeboxVQVAE as the model and Gradio as the UI

import sys

import torch as t
from transformers import JukeboxVQVAE
import gradio as gr

# Identifier of the pretrained checkpoint that init() passes to JukeboxVQVAE.from_pretrained.
# The trailing `#@param` markers look like Colab form annotations — keep them on the assignment lines.
model_id = 'openai/jukebox-5b-lyrics' #@param ['openai/jukebox-1b-lyrics', 'openai/jukebox-5b-lyrics']

# Choose the cache root depending on whether the google.colab module has been loaded.
if 'google.colab' in sys.modules:

  # Cache under the mounted drive (path sits below the mount point used a few lines down).
  cache_path = '/content/drive/My Drive/jukebox-webui/_data/' #@param {type:"string"}
  # Connect to your Google Drive
  from google.colab import drive
  drive.mount('/content/drive')

else:

  # NOTE(review): '~' is stored literally here; nothing in this block expands it,
  # so whoever consumes cache_path has to expand the home directory — confirm.
  cache_path = '~/.cache/'

class Convert:
  """Namespace grouping the converters between a tokens list, a tokens file and audio.

  Each nested class is named after the source representation and each function
  after the target one, e.g. ``Convert.TokensFile.to_audio``. Nothing here is
  meant to be instantiated. The functions rely on the module-level ``model``
  (assigned by ``init()``) and on ``validate_tokens_list``.
  """

  class TokenList:
    """Conversions whose input is an in-memory list of token tensors."""

    @staticmethod
    def to_tokens_file(tokens_list):
      """Validate ``tokens_list`` and store it with ``torch.save`` in a new ``.jt`` file.

      Args:
        tokens_list: List accepted by ``validate_tokens_list``.

      Returns:
        Path of the newly written file inside the ``tmp`` directory.

      Raises:
        AssertionError: If ``validate_tokens_list`` rejects the input.
      """
      # Local imports: only this function needs them.
      import os
      import tempfile

      # Validate before touching the disk so bad input leaves no file behind.
      checked = validate_tokens_list(tokens_list)

      # torch.randint needs an explicit `size`, so it cannot produce a bare
      # number for the name; mkstemp also guarantees the name is not in use.
      # The directory is created on demand because saving into a missing
      # directory fails.
      os.makedirs("tmp", exist_ok=True)
      handle, filename = tempfile.mkstemp(suffix=".jt", dir="tmp")
      os.close(handle)  # torch.save reopens the file by name

      t.save(checked, filename)
      return filename

    @staticmethod
    def to_audio(tokens_list):
      """Decode the level-2 entry of ``tokens_list`` with ``model.decode``.

      Args:
        tokens_list: List accepted by ``validate_tokens_list``.

      Returns:
        Whatever ``model.decode`` returns, after ``squeeze(-1)``.
        NOTE(review): this is handed straight to the Gradio audio component;
        confirm that component accepts this value as-is.
      """
      # TODO: Implement converting other levels besides 2
      last_level_only = validate_tokens_list(tokens_list)[2:]
      return model.decode(last_level_only, start_level=2).squeeze(-1)

  class TokensFile:
    """Conversions whose input is a file previously written with ``torch.save``."""

    @staticmethod
    def to_tokens_list(file):
      """Load ``file`` with ``torch.load`` and return the validated tokens list."""
      # SECURITY: torch.load unpickles its input and `file` may come from an
      # upload; consider restricting it (e.g. `weights_only=True`) if the
      # installed torch version supports that.
      return validate_tokens_list(t.load(file))

    @staticmethod
    def to_audio(file):
      """Load the tokens stored in ``file`` and decode them to audio."""
      tokens_list = Convert.TokensFile.to_tokens_list(file)
      return Convert.TokenList.to_audio(tokens_list)

  class Audio:
    """Conversions whose input is audio."""

    @staticmethod
    def to_tokens_list(audio):
      """Encode ``audio`` into a tokens list with ``model.encode``.

      Args:
        audio: Object providing ``unsqueeze`` (a torch tensor).
          NOTE(review): confirm that the UI component really supplies a
          tensor rather than its own audio format.

      Returns:
        The list returned by ``model.encode`` for a batch of one.
      """
      # Encode from level 0: the result is later checked by
      # validate_tokens_list, which requires exactly 3 tensors, whereas
      # start_level=2 keeps only the last level.
      # NOTE(review): relies on `encode` returning the levels from
      # `start_level` onwards — confirm for the transformers version in use.
      return model.encode(audio.unsqueeze(0), start_level=0)

    @staticmethod
    def to_tokens_file(audio):
      """Encode ``audio`` and save the resulting tokens list to a file."""
      tokens_list = Convert.Audio.to_tokens_list(audio)
      return Convert.TokenList.to_tokens_file(tokens_list)

def init():
  """Load the pretrained JukeboxVQVAE into the module-level ``model``, once.

  If ``model`` already exists the function only prints a notice. Otherwise
  the checkpoint named by ``model_id`` is loaded in float16, using
  ``<cache_path>/jukebox/models`` as the download cache.
  """
  global model

  # Local import: only this function needs it.
  import os

  try:
    model
  except NameError:
    pass  # first call: fall through and load the model
  else:
    print("Model already initialized")
    return

  # cache_path may start with '~'; expand it explicitly so the cache does not
  # end up in a directory literally named '~' under the working directory.
  models_dir = os.path.join(os.path.expanduser(cache_path), "jukebox", "models")

  model = JukeboxVQVAE.from_pretrained(
    model_id,
    torch_dtype=t.float16,
    cache_dir=models_dir,
  )

def validate_tokens_list(tokens_list):
  """Check that ``tokens_list`` has the layout the converters expect.

  The checks are written as explicit raises rather than ``assert`` statements
  so they still run when Python is started with ``-O``.

  Args:
    tokens_list: Sequence of 3 objects exposing ``.shape`` (torch tensors).

  Returns:
    The very same ``tokens_list`` object, unchanged.

  Raises:
    AssertionError: If any of the checks below fails.
  """
  # Make sure that:
  # - tokens_list is a list of exactly 3 torch tensors
  if len(tokens_list) != 3:
    raise AssertionError("Invalid file format: expecting a list of 3 tensors")

  shapes = [tokens.shape for tokens in tokens_list]
  ranks = [len(shape) for shape in shapes]

  # - each has the same number of dimensions
  if not ranks[0] == ranks[1] == ranks[2]:
    raise AssertionError("Invalid file format: each tensor in the list should have the same number of dimensions")

  # - dimensions 0 and 1 exist, since the checks below index them
  if ranks[0] < 2:
    raise AssertionError("Invalid file format: each tensor in the list should have at least 2 dimensions")

  # - the shape along dimension 0 is the same
  if not shapes[0][0] == shapes[1][0] == shapes[2][0]:
    raise AssertionError("Invalid file format: the shape along dimension 0 should be the same for all tensors in the list")

  # - the shape along dimension 1 decreases (or stays the same) as we go from 0 to 2
  if not shapes[0][1] >= shapes[1][1] >= shapes[2][1]:
    raise AssertionError("Invalid file format: the shape along dimension 1 should decrease (or stay the same) as we go from 0 to 2")

  return tokens_list


# Build the interface: one file slot, one audio slot and a button for each direction.
with gr.Blocks() as ui:

  # Holds the music tokens file (input of the first button, output of the second)
  tokens = gr.File(label='music_tokens_file')

  # Holds the audio (output of the first button, input of the second)
  audio = gr.Audio(label='audio')

  # Primary action: tokens file -> audio
  gr.Button("Convert tokens to audio", variant='primary').click(
    fn=Convert.TokensFile.to_audio,
    inputs=tokens,
    outputs=audio,
  )

  # Secondary action: audio -> tokens file
  gr.Button("Convert audio to tokens", variant='secondary').click(
    fn=Convert.Audio.to_tokens_file,
    inputs=audio,
    outputs=tokens,
  )

# When run as a script, load the model first and then serve the interface.
if __name__ == "__main__":
  init()
  ui.launch()