# humtrans/infer.py
import os

import librosa
import soundfile as sf
import torch
import torchaudio

from train import LitTranscriber
from utils import rms_normalize_wav, token_seg_list_to_midi, unpack_sequence

BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # points to backend/src
PTH_PATH = os.path.join(BASE_DIR, "model.pth")  # plain state_dict checkpoint (.pth)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model():
    """Build the transcriber with its training-time hyperparameters and load the weights."""
    args = {
        "n_mels": 128,
        "sample_rate": 16000,
        "n_fft": 1024,
        "hop_length": 128,
    }
    model = LitTranscriber(transcriber_args=args, lr=1e-4, lr_decay=0.99)
    state_dict = torch.load(PTH_PATH, map_location=device)  # load the .pth state_dict
    model.load_state_dict(state_dict)
    model.to(device)  # keep the model on the same device as the input waveform
    model.eval()
    return model

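# A minimal sketch (not part of this repo) of how a plain .pth state_dict like
# model.pth can be produced from a Lightning checkpoint; "last.ckpt" is a
# hypothetical filename:
#
#   ckpt = torch.load("last.ckpt", map_location="cpu")
#   torch.save(ckpt["state_dict"], "model.pth")
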
def convert_to_pcm_wav(input_path, output_path):
    # librosa decodes any supported format to 16 kHz mono float PCM
    y, sr = librosa.load(input_path, sr=16000, mono=True)
    sf.write(output_path, y, sr)

def infer_midi_from_wav(input_wav_path: str) -> str:
    model = load_model()

    # Decode to 16 kHz mono PCM, then RMS-normalize so the loudness matches training.
    converted_path = os.path.join(BASE_DIR, "converted_input.wav")
    convert_to_pcm_wav(input_wav_path, converted_path)
    normalized_path = os.path.join(BASE_DIR, "tmp_normalized.wav")
    rms_normalize_wav(converted_path, normalized_path, target_rms=0.1)

    waveform, sr = torchaudio.load(normalized_path)
    waveform = waveform.mean(0).to(device)  # collapse channels to mono
    if sr != model.transcriber.sr:
        waveform = torchaudio.functional.resample(
            waveform, sr, model.transcriber.sr
        ).to(device)

    with torch.no_grad():
        output_tokens = model(waveform)

    # Split the padded output back into per-segment token lists, drop each
    # segment's leading token, and decode the tokens into a MIDI object.
    unpadded_tokens = unpack_sequence(output_tokens.cpu().numpy())
    unpadded_tokens = [t[1:] for t in unpadded_tokens]
    est_midi = token_seg_list_to_midi(unpadded_tokens)

    midi_path = os.path.join(BASE_DIR, "output.mid")
    est_midi.write(midi_path)
    print(f"MIDI saved at: {midi_path}")
    return midi_path
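

# Minimal CLI entry point, a sketch not present in the original upload: pass a
# recording path and print the path of the transcribed MIDI file. The fallback
# filename "humming.wav" is a placeholder.
if __name__ == "__main__":
    import sys

    input_path = sys.argv[1] if len(sys.argv) > 1 else "humming.wav"
    print(infer_midi_from_wav(input_path))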