vovahimself commited on
Commit
659a5e1
1 Parent(s): 631e673

init fix + don't reinit + buttons

Browse files
Files changed (2) hide show
  1. app.py +11 -8
  2. jukwi-vqvae.ipynb +11 -8
app.py CHANGED
@@ -52,12 +52,15 @@ class Convert:
52
  def init():
53
  global model
54
 
55
- model = JukeboxVQVAE.from_pretrained(
56
- model_id,
57
- device_map = "auto",
58
- torch_dtype = t.float16,
59
- cache_dir = f"{cache_path}/jukebox/models"
60
- )
 
 
 
61
 
62
  def validate_tokens_list(tokens_list):
63
  # Make sure that:
@@ -85,8 +88,8 @@ with gr.Blocks() as ui:
85
  audio = gr.Audio(label='audio')
86
 
87
  # Buttons to convert from music tokens to audio (primary) and vice versa (secondary)
88
- gr.Button(label="Convert tokens to audio", primary=True).click(Convert.TokensFile.to_audio, tokens, audio)
89
- gr.Button(label="Convert audio to tokens", primary=False).click(Convert.Audio.to_tokens_file, audio, tokens)
90
 
91
  if __name__ == '__main__':
92
  init()
 
52
  def init():
53
  global model
54
 
55
+ try:
56
+ model
57
+ print("Model already initialized")
58
+ except NameError:
59
+ model = JukeboxVQVAE.from_pretrained(
60
+ model_id,
61
+ torch_dtype = t.float16,
62
+ cache_dir = f"{cache_path}/jukebox/models"
63
+ )
64
 
65
  def validate_tokens_list(tokens_list):
66
  # Make sure that:
 
88
  audio = gr.Audio(label='audio')
89
 
90
  # Buttons to convert from music tokens to audio (primary) and vice versa (secondary)
91
+ gr.Button("Convert tokens to audio", variant='primary').click(Convert.TokensFile.to_audio, tokens, audio)
92
+ gr.Button("Convert audio to tokens", variant='secondary').click(Convert.Audio.to_tokens_file, audio, tokens)
93
 
94
  if __name__ == '__main__':
95
  init()
jukwi-vqvae.ipynb CHANGED
@@ -81,12 +81,15 @@
81
  "def init():\n",
82
  " global model\n",
83
  "\n",
84
- " model = JukeboxVQVAE.from_pretrained(\n",
85
- " model_id,\n",
86
- " device_map = \"auto\",\n",
87
- " torch_dtype = t.float16,\n",
88
- " cache_dir = f\"{cache_path}/jukebox/models\"\n",
89
- " )\n",
 
 
 
90
  "\n",
91
  "def validate_tokens_list(tokens_list):\n",
92
  " # Make sure that:\n",
@@ -114,8 +117,8 @@
114
  " audio = gr.Audio(label='audio')\n",
115
  "\n",
116
  " # Buttons to convert from music tokens to audio (primary) and vice versa (secondary)\n",
117
- " gr.Button(label=\"Convert tokens to audio\", primary=True).click(Convert.TokensFile.to_audio, tokens, audio)\n",
118
- " gr.Button(label=\"Convert audio to tokens\", primary=False).click(Convert.Audio.to_tokens_file, audio, tokens)\n",
119
  "\n",
120
  "if __name__ == '__main__':\n",
121
  " init()\n",
 
81
  "def init():\n",
82
  " global model\n",
83
  "\n",
84
+ " try:\n",
85
+ " model\n",
86
+ " print(\"Model already initialized\")\n",
87
+ " except NameError:\n",
88
+ " model = JukeboxVQVAE.from_pretrained(\n",
89
+ " model_id,\n",
90
+ " torch_dtype = t.float16,\n",
91
+ " cache_dir = f\"{cache_path}/jukebox/models\"\n",
92
+ " )\n",
93
  "\n",
94
  "def validate_tokens_list(tokens_list):\n",
95
  " # Make sure that:\n",
 
117
  " audio = gr.Audio(label='audio')\n",
118
  "\n",
119
  " # Buttons to convert from music tokens to audio (primary) and vice versa (secondary)\n",
120
+ " gr.Button(\"Convert tokens to audio\", variant='primary').click(Convert.TokensFile.to_audio, tokens, audio)\n",
121
+ " gr.Button(\"Convert audio to tokens\", variant='secondary').click(Convert.Audio.to_tokens_file, audio, tokens)\n",
122
  "\n",
123
  "if __name__ == '__main__':\n",
124
  " init()\n",