Spaces:
Build error
Build error
Create run_diarization.py
Browse files- run_diarization.py +45 -0
run_diarization.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
from pathlib import Path

import torch
from pyannote.audio import Pipeline
from pydub import AudioSegment

# Load the pretrained speaker-diarization pipeline once, at import time, so
# every call to run_diarization() reuses the same model instance.
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")

# Send the pipeline to the GPU when one is available; otherwise stay on the
# CPU instead of crashing on hosts that have no CUDA device.
pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
11 |
+
|
12 |
+
def run_diarization(input_file):
    """Diarize a WAV file, save the RTTM, and export one clip per speaker turn.

    Side effects, all relative to the current working directory:

    * writes ``<stem>.rttm`` with the RTTM text produced by the pipeline;
    * prints one ``start=... stop=... speaker_...`` line per speaker turn;
    * creates ``<stem>_segments/`` and exports each turn as
      ``<n>_<speaker>_[<start>_<end>].wav``, where ``n`` counts from 1 and
      the times are in seconds rounded to one decimal place.

    Args:
        input_file: Path to the WAV file to process. It is passed to the
            module-level ``pipeline`` and to ``AudioSegment.from_wav``.

    Returns:
        None.
    """
    stem = Path(input_file).stem

    # Apply the pretrained pipeline.
    diarization = pipeline(input_file)

    # The context manager guarantees the RTTM file is flushed and closed.
    with open(stem + ".rttm", "w") as rttm_file:
        rttm_file.write(diarization.to_rttm())

    # Collect (start, end, speaker) per turn. Times are rounded to 0.1 s so
    # that clip names and cut points match the values that are printed.
    turns = []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        print(f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}")
        turns.append((round(turn.start, 1), round(turn.end, 1), speaker))

    sound = AudioSegment.from_wav(input_file)

    output_directory = stem + "_segments"
    # exist_ok lets the function be re-run on the same input without failing.
    os.makedirs(output_directory, exist_ok=True)

    for index, (start, end, speaker) in enumerate(turns, start=1):
        # pydub slices audio by milliseconds.
        clip = sound[start * 1000:end * 1000]
        clip_name = f"{index}_{speaker}_[{start}_{end}].wav"
        clip.export(os.path.join(output_directory, clip_name), format="wav")
|