avfranco commited on
Commit
3733074
1 Parent(s): 5777262

torch.cuda device automatic detection

Browse files
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -18,9 +18,12 @@ def asr_transcriber(audio_file):
18
 
19
  audio_file_wav = audio_converter(audio_file)
20
 
21
- device_id = "cpu"
22
- flash = False
23
-
 
 
 
24
  # Initialize the ASR pipeline
25
  pipe = pipeline(
26
  "automatic-speech-recognition",
@@ -28,11 +31,7 @@ def asr_transcriber(audio_file):
28
  torch_dtype=torch.float16,
29
  device=device_id
30
  )
31
- if device_id == "mps":
32
- torch.mps.empty_cache()
33
- elif not flash:
34
- pipe.model = pipe.model.to_bettertransformer()
35
-
36
  ts = True
37
  language = None
38
  task = "transcribe"
 
18
 
19
  audio_file_wav = audio_converter(audio_file)
20
 
21
+ # Check for CUDA availability (GPU)
22
+ if torch.cuda.is_available():
23
+ device_id = torch.device('cuda')
24
+ else:
25
+ device_id = torch.device('cpu')
26
+
27
  # Initialize the ASR pipeline
28
  pipe = pipeline(
29
  "automatic-speech-recognition",
 
31
  torch_dtype=torch.float16,
32
  device=device_id
33
  )
34
+
 
 
 
 
35
  ts = True
36
  language = None
37
  task = "transcribe"