Spaces:
Running
Running
flosstradamus
commited on
Commit
•
ca26365
1
Parent(s):
502807d
Update app.py
Browse files
app.py
CHANGED
@@ -11,6 +11,7 @@ import numpy as np
|
|
11 |
import re
|
12 |
import requests
|
13 |
import time
|
|
|
14 |
|
15 |
# Import necessary functions and classes
|
16 |
from utils import load_t5, load_clap
|
@@ -69,9 +70,10 @@ def unload_current_model():
|
|
69 |
global global_model, current_model_name
|
70 |
if global_model is not None:
|
71 |
del global_model
|
72 |
-
torch.cuda.empty_cache()
|
73 |
global_model = None
|
74 |
current_model_name = None
|
|
|
|
|
75 |
|
76 |
def load_model(model_name, device, model_url=None):
|
77 |
global global_model, current_model_name
|
@@ -121,8 +123,7 @@ def load_model(model_name, device, model_url=None):
|
|
121 |
load_time = end_time - start_time
|
122 |
return f"Successfully loaded model: {model_name} in {load_time:.2f} seconds"
|
123 |
except Exception as e:
|
124 |
-
|
125 |
-
current_model_name = None
|
126 |
print(f"Error loading model {model_name}: {str(e)}")
|
127 |
return f"Failed to load model: {model_name}. Error: {str(e)}"
|
128 |
|
@@ -229,6 +230,11 @@ def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=
|
|
229 |
|
230 |
all_waveforms.append(waveform)
|
231 |
|
|
|
|
|
|
|
|
|
|
|
232 |
# Concatenate all waveforms
|
233 |
final_waveform = np.concatenate(all_waveforms)
|
234 |
|
|
|
11 |
import re
|
12 |
import requests
|
13 |
import time
|
14 |
+
import gc
|
15 |
|
16 |
# Import necessary functions and classes
|
17 |
from utils import load_t5, load_clap
|
|
|
70 |
global global_model, current_model_name
|
71 |
if global_model is not None:
|
72 |
del global_model
|
|
|
73 |
global_model = None
|
74 |
current_model_name = None
|
75 |
+
torch.cuda.empty_cache()
|
76 |
+
gc.collect()
|
77 |
|
78 |
def load_model(model_name, device, model_url=None):
|
79 |
global global_model, current_model_name
|
|
|
123 |
load_time = end_time - start_time
|
124 |
return f"Successfully loaded model: {model_name} in {load_time:.2f} seconds"
|
125 |
except Exception as e:
|
126 |
+
unload_current_model()
|
|
|
127 |
print(f"Error loading model {model_name}: {str(e)}")
|
128 |
return f"Failed to load model: {model_name}. Error: {str(e)}"
|
129 |
|
|
|
230 |
|
231 |
all_waveforms.append(waveform)
|
232 |
|
233 |
+
# Clear some memory after each segment
|
234 |
+
del images, latents, mel_spectrogram, x_i
|
235 |
+
torch.cuda.empty_cache()
|
236 |
+
gc.collect()
|
237 |
+
|
238 |
# Concatenate all waveforms
|
239 |
final_waveform = np.concatenate(all_waveforms)
|
240 |
|