flosstradamus commited on
Commit
ca26365
1 Parent(s): 502807d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -11,6 +11,7 @@ import numpy as np
11
  import re
12
  import requests
13
  import time
 
14
 
15
  # Import necessary functions and classes
16
  from utils import load_t5, load_clap
@@ -69,9 +70,10 @@ def unload_current_model():
69
  global global_model, current_model_name
70
  if global_model is not None:
71
  del global_model
72
- torch.cuda.empty_cache()
73
  global_model = None
74
  current_model_name = None
 
 
75
 
76
  def load_model(model_name, device, model_url=None):
77
  global global_model, current_model_name
@@ -121,8 +123,7 @@ def load_model(model_name, device, model_url=None):
121
  load_time = end_time - start_time
122
  return f"Successfully loaded model: {model_name} in {load_time:.2f} seconds"
123
  except Exception as e:
124
- global_model = None
125
- current_model_name = None
126
  print(f"Error loading model {model_name}: {str(e)}")
127
  return f"Failed to load model: {model_name}. Error: {str(e)}"
128
 
@@ -229,6 +230,11 @@ def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=
229
 
230
  all_waveforms.append(waveform)
231
 
 
 
 
 
 
232
  # Concatenate all waveforms
233
  final_waveform = np.concatenate(all_waveforms)
234
 
 
11
  import re
12
  import requests
13
  import time
14
+ import gc
15
 
16
  # Import necessary functions and classes
17
  from utils import load_t5, load_clap
 
70
  global global_model, current_model_name
71
  if global_model is not None:
72
  del global_model
 
73
  global_model = None
74
  current_model_name = None
75
+ torch.cuda.empty_cache()
76
+ gc.collect()
77
 
78
  def load_model(model_name, device, model_url=None):
79
  global global_model, current_model_name
 
123
  load_time = end_time - start_time
124
  return f"Successfully loaded model: {model_name} in {load_time:.2f} seconds"
125
  except Exception as e:
126
+ unload_current_model()
 
127
  print(f"Error loading model {model_name}: {str(e)}")
128
  return f"Failed to load model: {model_name}. Error: {str(e)}"
129
 
 
230
 
231
  all_waveforms.append(waveform)
232
 
233
+ # Clear some memory after each segment
234
+ del images, latents, mel_spectrogram, x_i
235
+ torch.cuda.empty_cache()
236
+ gc.collect()
237
+
238
  # Concatenate all waveforms
239
  final_waveform = np.concatenate(all_waveforms)
240