File size: 3,233 Bytes
f0a085b 4691b00 f0a085b 4691b00 f0a085b b825d1e f0a085b b825d1e f0a085b 21fcf42 f0a085b 8b09827 f0a085b 6e2b473 21fcf42 f0a085b b825d1e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import subprocess
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional
import numpy as np
import sherpa_onnx
from model import sample_rate
@dataclass
class Segment:
    """A timed piece of recognized speech, printable as a SubRip-style cue.

    ``str(segment)`` yields two lines: a timing line of the form
    ``HH:MM:SS,mmm --> HH:MM:SS,mmm`` followed by the segment text.
    """

    start: float  # start time of the segment, in seconds
    duration: float  # length of the segment, in seconds
    text: str = ""  # recognized text; empty until filled in by the caller

    @property
    def end(self):
        """Return the end time of the segment, in seconds."""
        return self.start + self.duration

    @staticmethod
    def _format_timestamp(seconds: float) -> str:
        """Format ``seconds`` as ``HH:MM:SS,mmm``.

        Slicing ``str(timedelta)`` is not reliable here: the fractional part
        is omitted for whole seconds and the hour field is not zero padded,
        so the fields are computed arithmetically instead.
        """
        # timedelta rounds to the nearest microsecond; the floor division
        # then truncates to whole milliseconds.
        total_ms = timedelta(seconds=seconds) // timedelta(milliseconds=1)
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        secs, millis = divmod(remainder, 1_000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def __str__(self):
        """Return the timing line and the text, separated by a newline."""
        begin = self._format_timestamp(self.start)
        finish = self._format_timestamp(self.end)
        return f"{begin} --> {finish}\n{self.text}"
def decode(
    recognizer: sherpa_onnx.OfflineRecognizer,
    vad: sherpa_onnx.VoiceActivityDetector,
    punct: Optional[sherpa_onnx.OfflinePunctuation],
    filename: str,
) -> str:
    """Transcribe the audio track of ``filename`` into numbered subtitle cues.

    The file is piped through ``ffmpeg``, which converts it to mono 16-bit
    PCM at ``sample_rate`` on stdout. The samples are fed to ``vad`` in
    fixed-size windows; each speech segment the VAD emits is decoded with
    ``recognizer`` and, when ``punct`` is given, passed through
    ``punct.add_punctuation``.

    Args:
      recognizer: Offline recognizer used to decode each speech segment.
      vad: Voice activity detector that splits the audio into segments.
      punct: Optional punctuation model; ``None`` leaves the text as is.
      filename: Path of the media file handed to ``ffmpeg -i``.

    Returns:
      The cues joined by blank lines, each one being its 1-based index on
      a line of its own followed by ``str(segment)``. The result is an empty
      string when no speech segment was produced.
    """
    ffmpeg_cmd = [
        "ffmpeg",
        "-i",
        filename,
        "-f",
        "s16le",
        "-acodec",
        "pcm_s16le",
        "-ac",
        "1",
        "-ar",
        str(sample_rate),
        "-",
    ]

    frames_per_read = int(sample_rate * 100)  # 100 seconds of audio per read
    window_size = 512  # number of samples passed to the VAD per call

    # Start from an empty float32 array: concatenating onto a plain list
    # would silently promote the whole buffer to float64.
    buffer = np.zeros(0, dtype=np.float32)
    segment_list = []

    logging.info("Started!")

    # The context manager closes the pipe and waits for ffmpeg on exit, so
    # neither a zombie process nor an open file descriptor is left behind.
    with subprocess.Popen(
        ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
    ) as process:
        is_last = False
        while not is_last:
            # *2 because int16_t has two bytes
            data = process.stdout.read(frames_per_read * 2)
            if data:
                samples = np.frombuffer(data, dtype=np.int16)
                # Scale int16 PCM to floats in [-1, 1).
                samples = samples.astype(np.float32) / 32768
            else:
                # End of stream: feed one second of silence so that speech
                # still in progress at the end of the file gets closed by
                # the VAD instead of being dropped.
                # NOTE(review): assumes the VAD's minimum silence duration
                # is below one second - confirm against the VAD config.
                is_last = True
                samples = np.zeros(int(sample_rate), dtype=np.float32)

            buffer = np.concatenate([buffer, samples])
            while len(buffer) > window_size:
                vad.accept_waveform(buffer[:window_size])
                buffer = buffer[window_size:]

            # Collect every segment the VAD has completed so far.
            streams = []
            segments = []
            while not vad.empty():
                segment = Segment(
                    start=vad.front.start / sample_rate,
                    duration=len(vad.front.samples) / sample_rate,
                )
                segments.append(segment)

                stream = recognizer.create_stream()
                stream.accept_waveform(sample_rate, vad.front.samples)
                streams.append(stream)

                vad.pop()

            for s in streams:
                recognizer.decode_stream(s)

            for seg, stream in zip(segments, streams):
                seg.text = stream.result.text.strip()
                if punct is not None:
                    seg.text = punct.add_punctuation(seg.text)
                segment_list.append(seg)

    if process.returncode != 0:
        # Keep the best-effort behavior (return whatever was decoded), but
        # do not hide the failure completely.
        logging.warning(
            "ffmpeg exited with code %s for %s", process.returncode, filename
        )

    return "\n\n".join(f"{i}\n{seg}" for i, seg in enumerate(segment_list, 1))
|