Spaces:
Running
Running
import os | |
import torch | |
from pyannote.audio import Pipeline | |
def extract_files(files):
    """Collect the ``name`` attribute of each uploaded-file object.

    Args:
        files: Iterable of objects exposing a ``name`` attribute (presumably
            upload wrappers whose ``name`` is a temp-file path — confirm with
            the caller).

    Returns:
        A new list holding each object's ``name``, in iteration order.
    """
    names = []
    for uploaded in files:
        names.append(uploaded.name)
    return names
class Diarizer:
    """Speaker-diarization wrapper around a pretrained pyannote ``Pipeline``.

    The pipeline is loaded once in ``__init__``. It is only moved to a device
    when :meth:`diarize_audio` runs (see :meth:`add_device`).
    """

    def __init__(self, conf):
        """Store *conf* and load the diarization pipeline it names.

        Args:
            conf: Nested mapping; ``conf["model"]["diarizer"]`` is the model id
                passed to ``Pipeline.from_pretrained``.
        """
        self.conf = conf
        self.pipeline = self.pyannote_pipeline()

    def pyannote_pipeline(self):
        """Load the pretrained pyannote pipeline configured in ``self.conf``.

        The Hugging Face access token is read from the ``HUGGINGFACE_TOKEN``
        environment variable.

        Returns:
            The loaded pyannote ``Pipeline``.

        Raises:
            KeyError: If ``conf["model"]["diarizer"]`` is missing or
                ``HUGGINGFACE_TOKEN`` is not set in the environment.
            RuntimeError: If ``Pipeline.from_pretrained`` returns ``None``
                instead of a pipeline.
        """
        model_id = self.conf["model"]["diarizer"]
        pipeline = Pipeline.from_pretrained(
            model_id,
            use_auth_token=os.environ["HUGGINGFACE_TOKEN"],
        )
        if pipeline is None:
            # pyannote.audio 3.x prints a message and returns None (rather
            # than raising) when the model cannot be fetched, e.g. a gated
            # model whose terms were not accepted. Fail here with a clear
            # message instead of an AttributeError on ``None.to(...)`` the
            # first time diarize_audio() runs.
            raise RuntimeError(
                f"Could not load diarization pipeline {model_id!r}; check that "
                "HUGGINGFACE_TOKEN is valid and has access to this model."
            )
        return pipeline

    def add_device(self, pipeline):
        """Move *pipeline* to CUDA when available, otherwise to the CPU.

        Offloaded to allow for best timing when working with GPUs: device
        placement is done at diarization time (from :meth:`diarize_audio`)
        rather than in ``__init__``.

        Args:
            pipeline: Object exposing a ``to(device)`` method.

        Returns:
            The same *pipeline* object that was passed in.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        pipeline.to(device)
        return pipeline

    def diarize_audio(self, temp_file, num_speakers):
        """Run the diarization pipeline on *temp_file* and return its text form.

        Args:
            temp_file: Audio input passed straight to the pipeline (presumably
                a file path — confirm with the caller).
            num_speakers: Forwarded to the pipeline as ``num_speakers``.

        Returns:
            ``str()`` of the pipeline output. :meth:`generate_labels_from_diarization`
            parses it expecting one segment per line, shaped like
            ``[ HH:MM:SS.mmm -->  HH:MM:SS.mmm] <track> <label>``.
        """
        pipeline = self.add_device(self.pipeline)
        diarization = pipeline(temp_file, num_speakers=num_speakers)
        # NOTE: temp_file is intentionally kept; call os.remove(temp_file)
        # here if it should be deleted after processing.
        return str(diarization)

    def extract_seconds(self, timestamp):
        """Convert an ``H:M:S`` timestamp (each field may be fractional) to seconds.

        Args:
            timestamp: String such as ``"00:01:02.500"``.

        Returns:
            Total seconds as a float.

        Raises:
            ValueError: If *timestamp* does not have exactly three
                colon-separated numeric fields.
        """
        h, m, s = map(float, timestamp.split(':'))
        return 3600 * h + 60 * m + s

    def generate_labels_from_diarization(self, diarized_output, labels_path="labels.txt"):
        """Write ``start<TAB>end<TAB>label`` rows parsed from *diarized_output*.

        Each non-blank line is parsed as a ``[start --> end] ... label``
        segment. Start and end are written in seconds and the label is the
        last whitespace-separated token. Blank lines are skipped. Malformed
        lines are reported on stdout and skipped.

        Args:
            diarized_output: Text produced by :meth:`diarize_audio`.
            labels_path: Destination file; overwritten if it exists. Defaults
                to ``labels.txt`` in the current working directory.

        Returns:
            *labels_path*.
        """
        rows = []
        for line in diarized_output.strip().split('\n'):
            stripped = line.strip()
            if not stripped:
                # Empty diarization (or a stray blank line) is not an error.
                continue
            try:
                # Drop the leading '[' and split the bracketed interval on the
                # arrow. The trimmed last character belongs to the label, which
                # is taken from the untrimmed line below.
                parts = stripped[1:-1].split(' --> ')
                if len(parts) != 2:
                    raise ValueError("Unexpected format in diarized output")
                label = stripped.split()[-1]
                start_seconds = self.extract_seconds(parts[0].strip())
                end_seconds = self.extract_seconds(parts[1].split(']')[0].strip())
            except ValueError as e:
                print(f"Error processing line: '{stripped}'. Error: {e}")
                continue
            rows.append(f"{start_seconds}\t{end_seconds}\t{label}\n")
        with open(labels_path, "w", encoding="utf-8") as fh:
            fh.write("".join(rows))
        return labels_path

    def run(self, temp_file, num_speakers, labels_path="labels.txt"):
        """Diarize *temp_file* and write the resulting label file.

        Args:
            temp_file: Audio input forwarded to :meth:`diarize_audio`.
            num_speakers: Forwarded to :meth:`diarize_audio`.
            labels_path: Where the label file is written.

        Returns:
            Tuple ``(diarization_text, labels_path)``.
        """
        diarization_result = self.diarize_audio(temp_file, num_speakers)
        label_file = self.generate_labels_from_diarization(diarization_result, labels_path)
        return diarization_result, label_file