File size: 2,429 Bytes
33216f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e29651b
 
 
33216f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import torch
from pyannote.audio import Pipeline

def extract_files(files):
    """Return the ``name`` attribute of each item in *files*, in order.

    Args:
        files: Iterable of objects exposing a ``name`` attribute.

    Returns:
        A new list holding one ``name`` value per input item.
    """
    collected = []
    for entry in files:
        collected.append(entry.name)
    return collected

class Diarizer:
    """Speaker-diarization wrapper around a pretrained pyannote.audio pipeline.

    The pipeline is loaded once when the object is constructed; moving it to
    the compute device is deferred until audio is actually diarized.
    """

    def __init__(self, conf):
        """Store the configuration and load the diarization pipeline.

        Args:
            conf: Nested mapping; ``conf["model"]["diarizer"]`` is handed to
                ``Pipeline.from_pretrained``.

        Raises:
            KeyError: If the configuration key or the ``HUGGINGFACE_TOKEN``
                environment variable is missing.
        """
        self.conf = conf
        self.pipeline = self.pyannote_pipeline()

    def pyannote_pipeline(self):
        """Load the pretrained pipeline named in the configuration.

        The access token is read from the ``HUGGINGFACE_TOKEN`` environment
        variable.

        Returns:
            Whatever ``Pipeline.from_pretrained`` returns.
        """
        # NOTE(review): some pyannote.audio releases appear to return None
        # (instead of raising) when the pipeline cannot be downloaded —
        # confirm, and consider failing fast here if so.
        pipeline = Pipeline.from_pretrained(
            self.conf["model"]["diarizer"],
            use_auth_token=os.environ["HUGGINGFACE_TOKEN"],
        )
        return pipeline

    def get_pipeline(self):
        """Return the pipeline loaded at construction time."""
        return self.pipeline

    def add_device(self, pipeline):
        """Move *pipeline* to CUDA when available, otherwise to the CPU.

        Offloaded to allow for best timing when working with GPUs: the
        transfer happens here rather than when the pipeline is loaded.

        Args:
            pipeline: Object exposing a ``to(device)`` method.

        Returns:
            The same *pipeline* object that was passed in.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipeline.to(device)
        return pipeline

    def diarize_audio(self, temp_file, num_speakers):
        """Run the pipeline on *temp_file* and return its output as text.

        Args:
            temp_file: Audio input forwarded unchanged to the pipeline.
            num_speakers: Forwarded to the pipeline as ``num_speakers``.

        Returns:
            ``str()`` of the pipeline result.
        """
        pipeline = self.add_device(self.pipeline)
        diarization = pipeline(temp_file, num_speakers=num_speakers)
        # NOTE: temp_file is intentionally not deleted here; the caller owns
        # its cleanup.
        return str(diarization)

    def extract_seconds(self, timestamp):
        """Convert an ``H:M:S`` timestamp string into seconds.

        Args:
            timestamp: Colon-separated hours, minutes and seconds; each field
                is parsed with ``float`` (so fractional seconds are allowed).

        Returns:
            Total number of seconds as a float.

        Raises:
            ValueError: If there are not exactly three fields or a field is
                not a valid number.
        """
        fields = timestamp.split(':')
        if len(fields) != 3:
            raise ValueError(f"Expected 'H:M:S' timestamp, got {timestamp!r}")
        hours, minutes, seconds = map(float, fields)
        return 3600 * hours + 60 * minutes + seconds

    def _parse_segment_line(self, line):
        """Split one ``[start --> end] ... label`` line into (start, end, label).

        Raises:
            ValueError: If the brackets, the ``' --> '`` separator, the label
                or a timestamp is missing or malformed.
        """
        window, closing, remainder = line.partition(']')
        if not closing or not window.startswith('['):
            raise ValueError("Unexpected format in diarized output")

        bounds = window[1:].split(' --> ')
        if len(bounds) != 2:
            raise ValueError("Unexpected format in diarized output")

        tokens = remainder.split()
        if not tokens:
            raise ValueError("Unexpected format in diarized output")

        start_seconds = self.extract_seconds(bounds[0].strip())
        end_seconds = self.extract_seconds(bounds[1].strip())
        # Only the last whitespace-separated token is used as the label.
        return start_seconds, end_seconds, tokens[-1]

    def generate_labels_from_diarization(self, diarized_output, labels_path='labels.txt'):
        """Write a tab-separated label file from textual diarization output.

        Each well-formed input line becomes ``start<TAB>end<TAB>label`` with
        start and end expressed in seconds. Blank lines are ignored.
        Malformed lines are reported on stdout and skipped, so one bad line
        does not discard the rest.

        Args:
            diarized_output: Text with one ``[start --> end] ... label``
                segment per line.
            labels_path: Destination file; overwritten if it exists.

        Returns:
            *labels_path*.
        """
        rows = []
        for raw_line in diarized_output.strip().split('\n'):
            line = raw_line.strip()
            if not line:
                # An empty diarization (or a stray blank line) is not an error.
                continue
            try:
                start_seconds, end_seconds, label = self._parse_segment_line(line)
            except ValueError as e:
                print(f"Error processing line: '{line}'. Error: {e}")
                continue
            rows.append(f"{start_seconds}\t{end_seconds}\t{label}\n")

        with open(labels_path, "w", encoding="utf-8") as file:
            file.write("".join(rows))

        return labels_path

    def run(self, temp_file, num_speakers, labels_path='labels.txt'):
        """Diarize *temp_file* and write the matching label file.

        Args:
            temp_file: Audio input forwarded to ``diarize_audio``.
            num_speakers: Forwarded to ``diarize_audio``.
            labels_path: Destination of the generated label file.

        Returns:
            Tuple ``(diarization_text, label_file_path)``.
        """
        diarization_result = self.diarize_audio(temp_file, num_speakers)
        label_file = self.generate_labels_from_diarization(
            diarization_result, labels_path
        )
        return diarization_result, label_file