cuda
Browse files
app.py
CHANGED
@@ -113,9 +113,15 @@ def run(duration, resolution, pitch, glissando, vibrato, stiffness, tension, plu
|
|
113 |
checkpoint = filter_state_dict(checkpoint['state_dict'])
|
114 |
model = Synthesizer(**configs)
|
115 |
model.load_state_dict(checkpoint)
|
|
|
|
|
116 |
|
117 |
params, f_0, u_0 = get_data( \
|
118 |
duration, resolution, pitch, glissando, vibrato, stiffness, tension, pluck, amplitude)
|
|
|
|
|
|
|
|
|
119 |
|
120 |
with torch.no_grad():
|
121 |
ut, mode_input, mode_output = model(params, f_0, u_0)
|
|
|
113 |
checkpoint = filter_state_dict(checkpoint['state_dict'])
|
114 |
model = Synthesizer(**configs)
|
115 |
model.load_state_dict(checkpoint)
|
116 |
+
if torch.cuda.is_available():
|
117 |
+
model = model.cuda()
|
118 |
|
119 |
params, f_0, u_0 = get_data( \
|
120 |
duration, resolution, pitch, glissando, vibrato, stiffness, tension, pluck, amplitude)
|
121 |
+
if torch.cuda.is_available():
|
122 |
+
params = [p.cuda() for p in params]
|
123 |
+
f_0 = f_0.cuda()
|
124 |
+
u_0 = u_0.cuda()
|
125 |
|
126 |
with torch.no_grad():
|
127 |
ut, mode_input, mode_output = model(params, f_0, u_0)
|