tykiww commited on
Commit
6386953
1 Parent(s): c19e19e

Create asr.py

Browse files
Files changed (1) hide show
  1. services/asr.py +24 -0
services/asr.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import pipeline
3
+
4
+ class Transcriber:
5
+ def __init__(self, conf):
6
+ self.conf = conf
7
+ self.pipeline = self.asr_pipeline()
8
+
9
+ def asr_pipeline(self):
10
+ return pipeline(
11
+ self.conf["model"]["asr"]["type"],
12
+ model=self.conf["model"]["asr"]["transcriber"],
13
+ device=0 if torch.cuda.is_available() else -1 # Use 0 for GPU, -1 for CPU
14
+ )
15
+
16
+ def run(self, file_path):
17
+ kwargs = {"max_new_tokens": self.conf["model"]["asr"]["max_new_tokens"]}
18
+ output = self.pipeline(
19
+ file_path,
20
+ generate_kwargs=kwargs,
21
+ return_timestamps=True,
22
+ )
23
+ print(output)
24
+ return output.get("chunks", output) # Use .get to avoid key errors