"""Onset-detection CNN inference: load audio, detect taiko hits, emit a TJA chart."""

import os
import tempfile
from pathlib import Path
from typing import Tuple

import numpy as np
import soundfile as sf
import torch

from model import convNet
from preprocess import Audio, fft_and_melscale
from synthesize import create_tja, detect, synthesize


def trim_silence(data: np.ndarray, sr: int) -> np.ndarray:
    """Trim leading/trailing silence from a mono signal, keeping 3 s of padding.

    Leading samples with |amplitude| < 0.2 and trailing samples with
    |amplitude| < 0.1 are treated as silence; 3 seconds of context are
    retained on each side of the detected active region.

    Args:
        data: 1-D mono audio samples.
        sr: Sample rate in Hz.

    Returns:
        The trimmed slice of ``data``, or ``data`` unchanged when no
        non-silent region is found.
    """
    start = 0
    end = len(data) - 1
    while start < len(data) and np.abs(data[start]) < 0.2:
        start += 1
    while end > 0 and np.abs(data[end]) < 0.1:
        end -= 1
    # Keep 3 seconds of padding on each side of the active region.
    start = max(start - sr * 3, 0)
    end = min(end + sr * 3, len(data))
    if start >= end:
        # Whole signal is below the silence thresholds; leave it untouched
        # rather than returning an empty (or inverted) slice.
        return data
    # BUGFIX: the original message printed end/sr — the *position* of the end
    # cut — instead of the number of seconds actually removed from the end.
    print(
        f"Trimming {start / sr} seconds from the start and "
        f"{(len(data) - end) / sr} seconds from the end"
    )
    return data[start:end]


class ODCNN:
    """Runs a pair of onset-detection CNNs ("don"/"ka") over a song.

    Produces a synthesized preview track and a multi-difficulty TJA chart.
    """

    def __init__(self, don_model: str, ka_model: str, device: torch.device = "cpu"):
        """Load the two convNet checkpoints and move them to ``device``.

        Args:
            don_model: Path to the "don" model state-dict file.
            ka_model: Path to the "ka" model state-dict file.
            device: Torch device used for inference. Checkpoints are loaded
                with ``map_location="cpu"`` so GPU-saved files also load on
                CPU-only hosts.
        """
        def _load(path: str) -> convNet:
            # Shared loading logic for both checkpoints.
            net = convNet().to(device)
            net.load_state_dict(torch.load(path, map_location="cpu"))
            return net

        self.donNet = _load(don_model)
        self.kaNet = _load(ka_model)
        self.device = device

    def run(self, file: str, delta=0.05, trim=True) -> Tuple[str, str]:
        """Detect onsets in ``file``; synthesize a preview and build a TJA chart.

        Args:
            file: Path to the input audio file.
            delta: Detection threshold for the hardest ("ura") difficulty;
                the other four difficulties use fixed thresholds.
            trim: Whether to strip leading/trailing silence before analysis.

        Returns:
            Tuple of (path to the synthesized .mp3 preview, TJA chart).
        """
        data, sr = sf.read(file, always_2d=True)
        song = Audio(data, sr)
        # Downmix to mono; the feature extractor expects a 1-D signal.
        song.data = song.data.mean(axis=1)
        if trim:
            song.data = trim_silence(song.data, sr)
        song.feats = fft_and_melscale(
            song,
            nhop=512,
            nffts=[1024, 2048, 4096],
            mel_nband=80,
            mel_freqlo=27.5,
            mel_freqhi=16000.0,
        )

        don_inference = self.donNet.infer(song.feats, self.device, minibatch=4192)
        don_inference = np.reshape(don_inference, (-1))
        ka_inference = self.kaNet.infer(song.feats, self.device, minibatch=4192)
        ka_inference = np.reshape(ka_inference, (-1))

        # One detection pass per difficulty, easiest to hardest; a smaller
        # delta makes detection more sensitive (more notes).
        detections = [
            detect(don_inference, ka_inference, delta=d)
            for d in (0.25, 0.2, 0.15, 0.075, delta)
        ]
        hard_detection = detections[2]

        # BUGFIX: NamedTemporaryFile(delete=False) leaked an open file
        # descriptor; mkstemp + os.close keeps the path without the leak.
        fd, synthesized_path = tempfile.mkstemp(suffix=".mp3")
        os.close(fd)
        synthesize(*hard_detection, song, synthesized_path)

        src = Path(file)
        tja = create_tja(
            song,
            timestamps=detections,
            title=src.stem,
            wave=src.name,
        )
        return synthesized_path, tja