DeepLearning101's picture
Upload 17 files
109bb65
raw
history blame
5.65 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez
import argparse
import sys
import sounddevice as sd
import torch
from .demucs import DemucsStreamer
from .pretrained import add_model_flags, get_model
from .utils import bold
def get_parser():
    """Build the argument parser for the live speech-enhancement command.

    Besides the options declared here, `add_model_flags` (imported from
    `.pretrained`) registers its own flags on the same parser. It is called
    between `--out` and `--sample_rate`, which fixes where those flags show
    up in the `--help` output.

    Returns:
        The configured `argparse.ArgumentParser`.
    """
    description = (
        "Performs live speech enhancement, reading audio from "
        "the default mic (or interface specified by --in) and "
        "writing the enhanced version to 'Soundflower (2ch)' "
        "(or the interface specified by --out)."
    )
    cli = argparse.ArgumentParser("denoiser.live", description=description)
    add = cli.add_argument

    # Audio interfaces. `in` is a Python keyword, hence the `in_` destination.
    add("-i", "--in", dest="in_", help="name or index of input interface.")
    add("-o", "--out", default="Soundflower (2ch)",
        help="name or index of output interface.")

    add_model_flags(cli)

    add("--sample_rate", type=int, default=16_000, help="Sample rate")
    # `store_false`: `args.compressor` is True unless the flag is given.
    add("--no_compressor", action="store_false", dest="compressor",
        help="Deactivate compressor on output, might lead to clipping.")
    add("--device", default="cpu")
    add("--dry", type=float, default=0.04,
        help="Dry/wet knob, between 0 and 1. 0=maximum noise removal "
             "but it might cause distortions. Default is 0.04")
    add("-t", "--num_threads", type=int,
        help="Number of threads. If you have DDR3 RAM, setting -t 1 can "
             "improve performance.")
    add("-f", "--num_frames", type=int, default=1,
        help="Number of frames to process at once. Larger values increase "
             "the overall lag, but will improve speed.")
    return cli
def parse_audio_device(device):
    """Normalize an audio interface given on the command line.

    Args:
        device: `None`, or the interface as typed by the user.

    Returns:
        `None` unchanged, an `int` when `device` converts with `int()`
        (an interface index), otherwise `device` itself (an interface name).
    """
    if device is not None:
        try:
            device = int(device)
        except ValueError:
            # Not a number: keep it as an interface name.
            pass
    return device
def query_devices(device, kind):
    """Query an audio interface, exiting with a helpful message if invalid.

    Args:
        device: interface name, interface index, or `None`; forwarded as-is
            to `sd.query_devices`.
        kind: forwarded as `kind=` to `sd.query_devices`; callers in this
            file pass "input" or "output".

    Returns:
        The value returned by `sd.query_devices(device, kind=kind)`.

    Exits the process with status 1 (after printing to stderr) when the
    interface cannot be resolved.
    """
    try:
        return sd.query_devices(device, kind=kind)
    except (ValueError, sd.PortAudioError):
        # ValueError covers unknown/ambiguous names and wrong-kind devices;
        # sounddevice signals an out-of-range numeric index with
        # PortAudioError, which previously escaped as a raw traceback.
        message = bold(f"Invalid {kind} audio interface {device}.\n")
        message += (
            "If you are on Mac OS X, try installing Soundflower "
            "(https://github.com/mattingalls/Soundflower).\n"
            "You can list available interfaces with `python3 -m sounddevice` on Linux and OS X, "
            "and `python.exe -m sounddevice` on Windows. You must have at least one loopback "
            "audio interface to use this.")
        print(message, file=sys.stderr)
        sys.exit(1)
def main():
    """Run live speech enhancement from an input to an output interface.

    Parses the command line, loads the model, opens one input and one output
    stream, then loops forever: read a chunk, feed it to the streamer, write
    the result. Ctrl-C (KeyboardInterrupt) ends the loop.

    The streams are used as context managers so that they are stopped and
    closed on every exit path, including unexpected errors inside the loop.
    """
    args = get_parser().parse_args()
    if args.num_threads:
        torch.set_num_threads(args.num_threads)

    model = get_model(args).to(args.device)
    model.eval()
    print("Model loaded.")
    streamer = DemucsStreamer(model, dry=args.dry, num_frames=args.num_frames)

    # Use at most two channels on either side, fewer if the device has fewer.
    device_in = parse_audio_device(args.in_)
    caps = query_devices(device_in, "input")
    channels_in = min(caps['max_input_channels'], 2)
    stream_in = sd.InputStream(
        device=device_in,
        samplerate=args.sample_rate,
        channels=channels_in)

    device_out = parse_audio_device(args.out)
    caps = query_devices(device_out, "output")
    channels_out = min(caps['max_output_channels'], 2)
    stream_out = sd.OutputStream(
        device=device_out,
        samplerate=args.sample_rate,
        channels=channels_out)

    # Entering starts the input stream then the output stream; leaving stops
    # and closes them in reverse order (output first, then input).
    with stream_in, stream_out:
        first = True
        # All times below are in seconds of audio read so far (samples read
        # divided by the sample rate), not wall-clock time.
        current_time = 0
        last_log_time = 0
        last_error_time = 0
        cooldown_time = 2   # minimum gap between two "too slow" warnings
        log_delta = 10      # gap between two timing reports
        sr_ms = args.sample_rate / 1000   # samples per millisecond
        stride_ms = streamer.stride / sr_ms
        print(f"Ready to process audio, total lag: {streamer.total_length / sr_ms:.1f}ms.")

        try:
            while True:
                if current_time > last_log_time + log_delta:
                    last_log_time = current_time
                    tpf = streamer.time_per_frame * 1000
                    # Processing time per frame relative to the stride length.
                    rtf = tpf / stride_ms
                    print(f"time per frame: {tpf:.1f}ms, ", end='')
                    print(f"RTF: {rtf:.1f}")
                    streamer.reset_time_per_frame()

                # The first read is `total_length` samples, later ones `stride`.
                length = streamer.total_length if first else streamer.stride
                first = False
                current_time += length / args.sample_rate

                frame, overflow = stream_in.read(length)
                # Average over dim 1 (channels, given sounddevice's
                # (frames, channels) layout) before feeding the model.
                frame = torch.from_numpy(frame).mean(dim=1).to(args.device)
                with torch.no_grad():
                    out = streamer.feed(frame[None])[0]
                if not out.numel():
                    # Nothing produced for this chunk; go read more input.
                    continue

                if args.compressor:
                    # tanh keeps the output strictly inside (-0.99, 0.99).
                    out = 0.99 * torch.tanh(out)
                # Duplicate the signal on every output channel.
                out = out[:, None].repeat(1, channels_out)
                mx = out.abs().max().item()
                if mx > 1:
                    print("Clipping!!")
                out.clamp_(-1, 1)
                out = out.cpu().numpy()

                underflow = stream_out.write(out)
                too_slow = overflow or underflow
                if too_slow and current_time >= last_error_time + cooldown_time:
                    last_error_time = current_time
                    tpf = 1000 * streamer.time_per_frame
                    print(f"Not processing audio fast enough, time per frame is {tpf:.1f}ms "
                          f"(should be less than {stride_ms:.1f}ms).")
        except KeyboardInterrupt:
            print("Stopping")
# Script entry point: run the live denoiser when this module is executed
# directly rather than imported.
if __name__ == "__main__":
    main()