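"""
Parallel transcription: split the merged VAD timestamps of an audio file across
multiple devices, transcribe each chunk in its own process, and merge the results.
"""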
import multiprocessing
from src.vad import AbstractTranscription, TranscriptionConfig
from src.whisperContainer import WhisperCallback

from typing import List
import os

class ParallelTranscriptionConfig(TranscriptionConfig):
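    """
    Transcription config for a single worker process. In addition to the values copied
    from the base config, it carries the device ID to run on and the pre-computed
    timestamps (override_timestamps) that this worker should transcribe.
    """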
    def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None):
        super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
        self.device_id = device_id
        self.override_timestamps = override_timestamps
    
class ParallelTranscription(AbstractTranscription):
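    """
    Runs a transcription across multiple devices by splitting the merged VAD timestamps
    into chunks and transcribing each chunk in a separate process (one per device).
    """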
    def __init__(self, sampling_rate: int = 16000):
        super().__init__(sampling_rate=sampling_rate)

    
    def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig, devices: List[str]):
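        """
        Transcribe the given audio by first computing the merged timestamps with the
        supplied transcription instance, then distributing the timestamp chunks across
        the given devices and merging the per-device results into a single result dict.
        """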
        # First, get the timestamps for the original audio
        merged = transcription.get_merged_timestamps(audio, config)

        # Split into a list for each device
        # TODO: Split by time instead of by number of chunks
        # Use ceiling division for the chunk size, so trailing segments are not dropped
        # and the size is never zero when there are fewer segments than devices.
        merged_split = self._chunks(merged, max(1, -(-len(merged) // len(devices))))

        # Parameters that will be passed to the transcribe function
        parameters = []
        segment_index = config.initial_segment_index

        for i in range(min(len(devices), len(merged_split))):
            device_segment_list = merged_split[i]

            # Create a new config with the given device ID
            device_config = ParallelTranscriptionConfig(devices[i], device_segment_list, segment_index, config)
            segment_index += len(device_segment_list)

            parameters.append([audio, whisperCallable, device_config])

        merged = {
            'text': '',
            'segments': [],
            'language': None
        }

        # Spawn a separate process for each device
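        # The 'spawn' start method is used rather than 'fork', since CUDA generally
        # cannot be re-initialized in a forked child process.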
        context = multiprocessing.get_context('spawn')

        with context.Pool(len(devices)) as p:
            # Run the transcription in parallel
            results = p.starmap(self.transcribe, parameters)

            for result in results:
                # Merge the results
                if (result['text'] is not None):
                    merged['text'] += result['text']
                if (result['segments'] is not None):
                    merged['segments'].extend(result['segments'])
                if (result['language'] is not None):
                    merged['language'] = result['language']

        return merged

    def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig):
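        # Not used in the worker processes - the timestamps are computed in the parent
        # process and passed to each worker via config.override_timestamps instead.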
        return []

    def get_merged_timestamps(self, audio: str, config: ParallelTranscriptionConfig):
        # Override timestamps that will be processed
        if (config.override_timestamps is not None):
            print("Using override timestamps of size " + str(len(config.override_timestamps)))
            return config.override_timestamps
        return super().get_merged_timestamps(audio, config)

    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig):
        # Override the device ID by setting CUDA_VISIBLE_DEVICES, so this worker only
        # sees its assigned GPU. This must happen before CUDA is initialized.
        if (config.device_id is not None):
            print("Using device " + config.device_id)
            os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
        return super().transcribe(audio, whisperCallable, config)

    def _chunks(self, lst, n):
        """Yield successive n-sized chunks from lst."""
        return [lst[i:i + n] for i in range(0, len(lst), n)]