File size: 3,789 Bytes
f0a085b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4691b00
f0a085b
 
 
4691b00
f0a085b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b825d1e
f0a085b
b825d1e
f0a085b
 
 
 
 
 
 
 
 
21fcf42
f0a085b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfd7673
 
f0a085b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b09827
 
 
f0a085b
6e2b473
a702d26
 
 
 
f5abfaa
cfd7673
 
 
 
815053b
 
cfd7673
21fcf42
 
f0a085b
c98790c
cfd7673
 
f0a085b
cfd7673
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright      2022-2023  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import subprocess
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional, Tuple

import numpy as np
import sherpa_onnx

from model import sample_rate


@dataclass
class Segment:
    """A recognized speech segment that renders as an SRT cue (without index).

    ``str(segment)`` yields ``"HH:MM:SS,mmm --> HH:MM:SS,mmm\\n<text>"``.
    """

    start: float  # start time in seconds
    duration: float  # length in seconds
    text: str = ""

    @property
    def end(self) -> float:
        """End time of the segment in seconds."""
        return self.start + self.duration

    @staticmethod
    def _format_time(seconds: float) -> str:
        """Format ``seconds`` as an SRT timestamp ``HH:MM:SS,mmm``.

        The value is first rounded to whole microseconds (as ``timedelta``
        does) and then truncated to milliseconds. Unlike slicing
        ``str(timedelta)``, this also works when the fractional part is zero.
        """
        total_us = timedelta(seconds=seconds) // timedelta(microseconds=1)
        total_ms = total_us // 1000
        hours, rem = divmod(total_ms, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        secs, millis = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def __str__(self) -> str:
        start = self._format_time(self.start)
        end = self._format_time(self.end)
        return f"{start} --> {end}\n{self.text}"


def decode(
    recognizer: sherpa_onnx.OfflineRecognizer,
    vad: sherpa_onnx.VoiceActivityDetector,
    punct: Optional[sherpa_onnx.OfflinePunctuation],
    filename: str,
) -> Tuple[str, str]:
    """Transcribe a media file into SRT subtitles and a joined transcript.

    ``ffmpeg`` decodes the file to mono signed 16-bit PCM at ``sample_rate``.
    The audio is fed to ``vad`` in fixed-size windows, and every speech
    segment the VAD emits is transcribed with ``recognizer``.

    Args:
      recognizer: Offline recognizer used to transcribe each speech segment.
      vad: Voice activity detector that splits the audio into segments.
      punct: Optional punctuation model. When given, it is applied to each
        segment text and to the joined transcript; ``None`` disables it.
      filename: Path of the input file, passed to ``ffmpeg -i``.

    Returns:
      A tuple ``(srt, text)``. ``srt`` holds the numbered SRT cues separated
      by blank lines. ``text`` is the concatenation of all non-empty segment
      texts, with a space only between two single-byte characters.
    """
    ffmpeg_cmd = [
        "ffmpeg",
        "-i",
        filename,
        "-f",
        "s16le",
        "-acodec",
        "pcm_s16le",
        "-ac",
        "1",
        "-ar",
        str(sample_rate),
        "-",
    ]

    process = subprocess.Popen(
        ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
    )

    frames_per_read = int(sample_rate * 100)  # 100 seconds of audio per read
    window_size = 512  # samples passed to the VAD per accept_waveform call

    # Start as float32 so np.concatenate does not upcast to float64.
    buffer = np.zeros(0, dtype=np.float32)
    segment_list = []
    all_text = []

    logging.info("Started!")

    try:
        while True:
            # *2 because int16_t has two bytes
            data = process.stdout.read(frames_per_read * 2)
            if not data:
                break

            samples = np.frombuffer(data, dtype=np.int16)
            samples = samples.astype(np.float32) / 32768

            buffer = np.concatenate([buffer, samples])
            # ">=" so an exactly full window is fed now instead of being
            # held back (and lost if the input ends right here).
            while len(buffer) >= window_size:
                vad.accept_waveform(buffer[:window_size])
                buffer = buffer[window_size:]

            _decode_vad_segments(recognizer, vad, punct, segment_list, all_text)
    finally:
        # Close the pipe and reap ffmpeg so no zombie process is left behind.
        process.stdout.close()
        process.wait()

    # Speech that is still ongoing when the input ends is only emitted by the
    # VAD after a flush. flush() is looked up dynamically because some
    # sherpa_onnx releases may not provide it.
    flush = getattr(vad, "flush", None)
    if flush is not None:
        flush()
        _decode_vad_segments(recognizer, vad, punct, segment_list, all_text)

    text = "".join(all_text)
    if punct is not None:
        text = punct.add_punctuation(text)

    srt = "\n\n".join(f"{i}\n{seg}" for i, seg in enumerate(segment_list, 1))
    return srt, text


def _decode_vad_segments(
    recognizer: sherpa_onnx.OfflineRecognizer,
    vad: sherpa_onnx.VoiceActivityDetector,
    punct: Optional[sherpa_onnx.OfflinePunctuation],
    segment_list: list,
    all_text: list,
) -> None:
    """Transcribe and pop every segment queued in ``vad``.

    For each non-empty result:
    - a ``Segment`` (text punctuated if ``punct`` is given) is appended to
      ``segment_list``;
    - the raw text is appended to ``all_text``, preceded by ``" "`` when
      both sides of the join are single-byte characters.
    """
    segments = []
    streams = []
    while not vad.empty():
        front = vad.front
        segments.append(
            Segment(
                start=front.start / sample_rate,
                duration=len(front.samples) / sample_rate,
            )
        )

        stream = recognizer.create_stream()
        stream.accept_waveform(sample_rate, front.samples)
        streams.append(stream)

        vad.pop()

    for stream in streams:
        recognizer.decode_stream(stream)

    for seg, stream in zip(segments, streams):
        seg.text = stream.result.text.strip()
        if not seg.text:
            logging.info("Skip empty segment")
            continue

        # Separate two segments with a space only when the previous text
        # ENDS and the new text starts with a single-byte (ASCII) character,
        # e.g. between English words. Multi-byte characters (e.g. CJK) are
        # joined directly. all_text[-1] is never empty: empty texts are
        # skipped above.
        if (
            all_text
            and len(all_text[-1][-1].encode()) == 1
            and len(seg.text[0].encode()) == 1
        ):
            all_text.append(" ")
        all_text.append(seg.text)

        if punct is not None:
            seg.text = punct.add_punctuation(seg.text)
        segment_list.append(seg)