szin94 commited on
Commit
7e07e49
·
1 Parent(s): e1e5da9
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -113,9 +113,15 @@ def run(duration, resolution, pitch, glissando, vibrato, stiffness, tension, plu
113
  checkpoint = filter_state_dict(checkpoint['state_dict'])
114
  model = Synthesizer(**configs)
115
  model.load_state_dict(checkpoint)
 
 
116
 
117
  params, f_0, u_0 = get_data( \
118
  duration, resolution, pitch, glissando, vibrato, stiffness, tension, pluck, amplitude)
 
 
 
 
119
 
120
  with torch.no_grad():
121
  ut, mode_input, mode_output = model(params, f_0, u_0)
 
113
  checkpoint = filter_state_dict(checkpoint['state_dict'])
114
  model = Synthesizer(**configs)
115
  model.load_state_dict(checkpoint)
116
+ if torch.cuda.is_available():
117
+ model = model.cuda()
118
 
119
  params, f_0, u_0 = get_data( \
120
  duration, resolution, pitch, glissando, vibrato, stiffness, tension, pluck, amplitude)
121
+ if torch.cuda.is_available():
122
+ params = [p.cuda() for p in params]
123
+ f_0 = f_0.cuda()
124
+ u_0 = u_0.cuda()
125
 
126
  with torch.no_grad():
127
  ut, mode_input, mode_output = model(params, f_0, u_0)