Politrees commited on
Commit
75d3f3c
1 Parent(s): 610bb28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -0
app.py CHANGED
@@ -1,10 +1,14 @@
1
  import os
 
2
  import shutil
3
  import logging
4
  import gradio as gr
5
 
6
  from audio_separator.separator import Separator
7
 
 
 
 
8
  # Model lists
9
  ROFORMER_MODELS = {
10
  'BS-Roformer-Viperx-1297.ckpt': 'model_bs_roformer_ep_317_sdr_12.9755.ckpt',
@@ -148,6 +152,7 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
148
  output_format=out_format,
149
  normalization_threshold=norm_thresh,
150
  amplification_threshold=amp_thresh,
 
151
  mdxc_params={
152
  "batch_size": 1,
153
  "segment_size": seg_size,
@@ -183,6 +188,7 @@ def mdx23c_separator(audio, model, seg_size, override_seg_size, overlap, pitch_s
183
  output_format=out_format,
184
  normalization_threshold=norm_thresh,
185
  amplification_threshold=amp_thresh,
 
186
  mdxc_params={
187
  "batch_size": 1,
188
  "segment_size": seg_size,
@@ -218,6 +224,7 @@ def mdx_separator(audio, model, hop_length, seg_size, overlap, denoise, model_di
218
  output_format=out_format,
219
  normalization_threshold=norm_thresh,
220
  amplification_threshold=amp_thresh,
 
221
  mdx_params={
222
  "batch_size": 1,
223
  "hop_length": hop_length,
@@ -253,6 +260,7 @@ def vr_separator(audio, model, window_size, aggression, tta, post_process, post_
253
  output_format=out_format,
254
  normalization_threshold=norm_thresh,
255
  amplification_threshold=amp_thresh,
 
256
  vr_params={
257
  "batch_size": 1,
258
  "window_size": window_size,
@@ -290,6 +298,7 @@ def demucs_separator(audio, model, seg_size, shifts, overlap, segments_enabled,
290
  output_format=out_format,
291
  normalization_threshold=norm_thresh,
292
  amplification_threshold=amp_thresh,
 
293
  demucs_params={
294
  "segment_size": seg_size,
295
  "shifts": shifts,
 
1
  import os
2
+ import torch
3
  import shutil
4
  import logging
5
  import gradio as gr
6
 
7
  from audio_separator.separator import Separator
8
 
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ use_autocast = device == "cuda"
11
+
12
  # Model lists
13
  ROFORMER_MODELS = {
14
  'BS-Roformer-Viperx-1297.ckpt': 'model_bs_roformer_ep_317_sdr_12.9755.ckpt',
 
152
  output_format=out_format,
153
  normalization_threshold=norm_thresh,
154
  amplification_threshold=amp_thresh,
155
+ use_autocast=use_autocast,
156
  mdxc_params={
157
  "batch_size": 1,
158
  "segment_size": seg_size,
 
188
  output_format=out_format,
189
  normalization_threshold=norm_thresh,
190
  amplification_threshold=amp_thresh,
191
+ use_autocast=use_autocast,
192
  mdxc_params={
193
  "batch_size": 1,
194
  "segment_size": seg_size,
 
224
  output_format=out_format,
225
  normalization_threshold=norm_thresh,
226
  amplification_threshold=amp_thresh,
227
+ use_autocast=use_autocast,
228
  mdx_params={
229
  "batch_size": 1,
230
  "hop_length": hop_length,
 
260
  output_format=out_format,
261
  normalization_threshold=norm_thresh,
262
  amplification_threshold=amp_thresh,
263
+ use_autocast=use_autocast,
264
  vr_params={
265
  "batch_size": 1,
266
  "window_size": window_size,
 
298
  output_format=out_format,
299
  normalization_threshold=norm_thresh,
300
  amplification_threshold=amp_thresh,
301
+ use_autocast=use_autocast,
302
  demucs_params={
303
  "segment_size": seg_size,
304
  "shifts": shifts,