pop2piano / transformer_wrapper.py
sweetcocoa's picture
move to gradio
7a3b53b
raw
history blame
No virus
11.2 kB
import os
import random
import numpy as np
import librosa
import torch
import pytorch_lightning as pl
import soundfile as sf
from torch.nn.utils.rnn import pad_sequence
from transformers import T5Config, T5ForConditionalGeneration
from midi_tokenizer import MidiTokenizer, extrapolate_beat_times
from layer.input import LogMelSpectrogram, ConcatEmbeddingToMel
from preprocess.beat_quantizer import extract_rhythm, interpolate_beat_times
from utils.dsp import get_stereo
# Composer table substituted by TransformerWrapper when mel input is used
# without composer conditioning: one generic entry mapped to its token id.
DEFAULT_COMPOSERS = {
    "various composer": 2052,
}
class TransformerWrapper(pl.LightningModule):
    """Lightning module wrapping a T5 model that generates MIDI tokens from audio.

    The wrapper owns a ``MidiTokenizer``, a ``T5ForConditionalGeneration``
    built from the ``t5-small`` config with overrides from ``config.t5``, and,
    when ``config.dataset.use_mel`` is set, a ``LogMelSpectrogram`` front end
    with an optional ``ConcatEmbeddingToMel`` composer conditioner.
    """

    def __init__(self, config):
        """Build tokenizer, transformer and (optionally) the mel front end.

        Args:
            config: Configuration object. This constructor reads
                ``config.tokenizer``, ``config.t5`` (iterated with
                ``.items()``), ``config.dataset.use_mel``,
                ``config.dataset.mel_is_conditioned``,
                ``config.composer_to_feature_token`` and
                ``config.training.lr``.
        """
        super().__init__()
        self.config = config

        self.tokenizer = MidiTokenizer(config.tokenizer)

        # Start from the stock t5-small config and apply project overrides.
        self.t5config = T5Config.from_pretrained("t5-small")
        for key, value in config.t5.items():
            setattr(self.t5config, key, value)
        self.transformer = T5ForConditionalGeneration(self.t5config)

        self.use_mel = self.config.dataset.use_mel
        self.mel_is_conditioned = self.config.dataset.mel_is_conditioned
        self.composer_to_feature_token = config.composer_to_feature_token

        # Without composer conditioning, fall back to the generic table.
        if self.use_mel and not self.mel_is_conditioned:
            self.composer_to_feature_token = DEFAULT_COMPOSERS

        if self.use_mel:
            self.spectrogram = LogMelSpectrogram()
            if self.mel_is_conditioned:
                n_dim = 512
                composer_n_vocab = len(self.composer_to_feature_token)
                # The smallest composer token value is passed as the
                # conditioner's embedding offset.
                embedding_offset = min(self.composer_to_feature_token.values())
                self.mel_conditioner = ConcatEmbeddingToMel(
                    embedding_offset=embedding_offset,
                    n_vocab=composer_n_vocab,
                    n_dim=n_dim,
                )
        else:
            self.spectrogram = None

        self.lr = config.training.lr

    def forward(self, input_ids, labels):
        """Run the wrapped transformer on token ids and labels.

        Deprecated.

        Args:
            input_ids: Forwarded to the transformer as ``input_ids``.
            labels: Forwarded to the transformer as ``labels``.

        Returns:
            Whatever the wrapped ``T5ForConditionalGeneration`` call returns.
        """
        return self.transformer(input_ids=input_ids, labels=labels)

    @torch.no_grad()
    def single_inference(
        self,
        feature_tokens=None,
        audio=None,
        beatstep=None,
        max_length=256,
        max_batch_size=64,
        n_bars=None,
        composer_value=None,
    ):
        """Generate tokens and MIDI for one long audio sequence.

        The audio is split into segments of ``n_bars`` bars, the segments are
        run through ``transformer.generate`` in chunks of at most
        ``max_batch_size``, and the padded outputs are decoded by the
        tokenizer.

        Args:
            feature_tokens: Optional 1-D input; only its rank is validated
                here. The token-input path is not implemented.
            audio: 1-D waveform (shape ``(time,)``). Required in practice,
                since only the mel path is implemented.
            beatstep: Beat-step times aligned with the input, with the
                offset removed (``beatstep[0] == 0``); ``beatstep[-1]`` is
                the time at which the input ends
                (i.e. ``beatstep[-1] == len(y) // sr``).
            max_length: ``max_length`` passed to ``transformer.generate``.
            max_batch_size: Maximum number of segments generated per call,
                to bound accelerator memory use.
            n_bars: Bars per segment; defaults to ``config.dataset.n_bars``.
            composer_value: Composer token forwarded to
                ``prepare_inference_mel``.

        Returns:
            Tuple ``(relative_tokens, notes, pm)``: the concatenated,
            PAD-padded token array, followed by the ``notes`` and ``pm``
            values returned by
            ``tokenizer.relative_batch_tokens_to_midi``.

        Raises:
            AssertionError: If both inputs are missing, ``beatstep`` is
                missing, or a supplied input is not 1-D.
            NotImplementedError: If ``self.use_mel`` is false.
        """
        assert (
            feature_tokens is not None or audio is not None
        ), "either feature_tokens or audio is required"
        assert beatstep is not None, "beatstep is required"
        if feature_tokens is not None:
            assert len(feature_tokens.shape) == 1, "feature_tokens must be 1-D"
        if audio is not None:
            assert len(audio.shape) == 1, "audio must be 1-D"

        config = self.config
        PAD = self.t5config.pad_token_id
        n_bars = config.dataset.n_bars if n_bars is None else n_bars

        if beatstep[0] > 0.01:
            # The message interpolates beatstep[0], so it must be an f-string.
            print(
                f"inference warning : beatstep[0] is not 0 ({beatstep[0]}). all beatstep will be shifted."
            )
            beatstep = beatstep - beatstep[0]

        if self.use_mel:
            input_ids = None
            inputs_embeds, ext_beatstep = self.prepare_inference_mel(
                audio,
                beatstep,
                n_bars=n_bars,
                padding_value=PAD,
                composer_value=composer_value,
            )
            batch_size = inputs_embeds.shape[0]
        else:
            raise NotImplementedError

        # Considering GPU capacity, some sequences cannot be generated at once,
        # so the batch is processed in chunks of max_batch_size.
        relative_tokens = []
        for start in range(0, batch_size, max_batch_size):
            end = min(batch_size, start + max_batch_size)
            if input_ids is None:
                _input_ids = None
                _inputs_embeds = inputs_embeds[start:end]
            else:
                _input_ids = input_ids[start:end]
                _inputs_embeds = None

            _relative_tokens = self.transformer.generate(
                input_ids=_input_ids,
                inputs_embeds=_inputs_embeds,
                max_length=max_length,
            )
            relative_tokens.append(_relative_tokens.cpu().numpy())

        # Chunks may have different generated lengths; right-pad each one with
        # PAD to the longest before concatenating along the batch axis.
        longest = max(rt.shape[-1] for rt in relative_tokens)
        relative_tokens = np.concatenate(
            [
                np.pad(
                    rt,
                    [(0, 0), (0, longest - rt.shape[-1])],
                    constant_values=PAD,
                )
                for rt in relative_tokens
            ]
        )

        pm, notes = self.tokenizer.relative_batch_tokens_to_midi(
            relative_tokens,
            beatstep=ext_beatstep,
            bars_per_batch=n_bars,
            cutoff_time_idx=(n_bars + 1) * 4,
        )

        return relative_tokens, notes, pm

    def prepare_inference_mel(self, audio, beatstep, n_bars, padding_value, composer_value=None):
        """Split audio on beat boundaries and convert it to model inputs.

        Args:
            audio: 1-D waveform tensor, sliced by sample index.
            beatstep: Beat-step times; multiplied by
                ``config.dataset.sample_rate`` to obtain sample indices.
            n_bars: Bars per segment. Each segment spans ``n_bars * 4``
                beat steps (NOTE(review): presumably 4 steps per bar).
            padding_value: Value used to right-pad shorter segments.
            composer_value: Composer token; used only when
                ``self.mel_is_conditioned`` is true.

        Returns:
            Tuple ``(inputs_embeds, ext_beatstep)``: the spectrogram output
            with its last two axes swapped (and the composer conditioner
            applied when enabled), and the beat steps returned by
            ``extrapolate_beat_times``.
        """
        n_steps = n_bars * 4
        n_target_step = len(beatstep)
        sample_rate = self.config.dataset.sample_rate
        ext_beatstep = extrapolate_beat_times(beatstep, (n_bars + 1) * 4 + 1)

        # Split audio on the corresponding beat intervals. Segment lengths
        # differ because the beat interval durations differ.
        segments = []
        for start_idx in range(0, n_target_step, n_steps):
            # end_idx may equal n_target_step, so the extrapolated beat steps
            # are indexed rather than the original ones.
            end_idx = min(start_idx + n_steps, n_target_step)
            start_sample = int(ext_beatstep[start_idx] * sample_rate)
            end_sample = int(ext_beatstep[end_idx] * sample_rate)
            segments.append(audio[start_sample:end_sample])

        batch = pad_sequence(segments, batch_first=True, padding_value=padding_value)

        inputs_embeds = self.spectrogram(batch).transpose(-1, -2)
        if self.mel_is_conditioned:
            # One copy of the composer token per segment in the batch.
            composer_value = torch.tensor(composer_value).to(self.device)
            composer_value = composer_value.repeat(inputs_embeds.shape[0])
            inputs_embeds = self.mel_conditioner(inputs_embeds, composer_value)
        return inputs_embeds, ext_beatstep

    @torch.no_grad()
    def generate(
        self,
        audio_path=None,
        composer=None,
        model="generated",
        steps_per_beat=2,
        stereo_amp=0.5,
        n_bars=2,
        ignore_duplicate=True,
        show_plot=False,
        save_midi=False,
        save_mix=False,
        midi_path=None,
        mix_path=None,
        click_amp=0.2,
        add_click=False,
        max_batch_size=None,
        beatsteps=None,
        mix_sample_rate=None,
        audio_y=None,
        audio_sr=None,
    ):
        """Generate MIDI for an audio file or a pre-loaded waveform.

        Args:
            audio_path: Path of the input audio. Loaded with ``librosa``
                whenever beat extraction or waveform loading is needed, and
                used to derive default output paths.
            composer: Key into ``self.composer_to_feature_token``. A random
                key is drawn when ``None``.
            model: Label embedded in the default output file names.
            steps_per_beat: Forwarded to ``interpolate_beat_times``.
            stereo_amp: Forwarded to ``get_stereo`` as ``pop_scale``.
            n_bars: Bars per generated segment.
            ignore_duplicate: When false and ``midi_path`` already exists,
                nothing is generated and ``None`` is returned.
            show_plot: Plot the result with ``note_seq``.
            save_midi: Write the MIDI to ``midi_path``.
            save_mix: Write the audio/MIDI mix to ``mix_path`` as WAV.
            midi_path: Output MIDI path; derived from ``audio_path`` if
                omitted.
            mix_path: Output mix path; derived from ``audio_path`` if
                omitted.
            click_amp: Gain applied to the click track.
            add_click: Add a click track on the beat steps to the mix.
            max_batch_size: Segments per generation call; defaults to
                ``64 // n_bars``.
            beatsteps: Precomputed beat steps. When given, rhythm
                extraction is skipped.
            mix_sample_rate: Sample rate of the mix; defaults to
                ``config.dataset.sample_rate``.
            audio_y: Pre-loaded waveform (NumPy array); takes precedence
                over loading from ``audio_path`` for the model input.
            audio_sr: Sample rate of ``audio_y``.

        Returns:
            Tuple ``(pm, composer, mix_path, midi_path)``, or ``None`` when
            generation is skipped because the MIDI file already exists.

        Raises:
            NotImplementedError: If ``self.use_mel`` is false.

        Note:
            Default output paths are built before a random composer is
            drawn, so they contain the literal ``None`` when ``composer``
            is omitted.
        """
        config = self.config
        device = self.device
        target_sr = config.dataset.sample_rate

        if audio_path is not None:
            # Build on the splitext stem rather than str.replace(): replace()
            # rewrites every occurrence of the extension text in the path and
            # corrupts the path when the extension is empty.
            stem = os.path.splitext(audio_path)[0]
            if mix_path is None:
                mix_path = f"{stem}.{model}.{composer}.wav"
            if midi_path is None:
                midi_path = f"{stem}.{model}.{composer}.mid"

        if max_batch_size is None:
            max_batch_size = 64 // n_bars

        composer_to_feature_token = self.composer_to_feature_token
        if composer is None:
            composer = random.choice(list(composer_to_feature_token))
        composer_value = composer_to_feature_token[composer]

        if mix_sample_rate is None:
            mix_sample_rate = target_sr

        # midi_path is None when neither audio_path nor midi_path was given;
        # os.path.exists(None) would raise TypeError, so guard against it.
        if not ignore_duplicate and midi_path is not None and os.path.exists(midi_path):
            return None

        ESSENTIA_SAMPLERATE = 44100

        y = None
        sr = None
        if beatsteps is None:
            # Rhythm extraction works on audio loaded at ESSENTIA_SAMPLERATE.
            y, sr = librosa.load(audio_path, sr=ESSENTIA_SAMPLERATE)
            (
                _bpm,
                beat_times,
                _confidence,
                _estimates,
                _essentia_beat_intervals,
            ) = extract_rhythm(audio_path, y=y)
            beat_times = np.array(beat_times)
            beatsteps = interpolate_beat_times(beat_times, steps_per_beat, extend=True)

        if not self.use_mel:
            raise NotImplementedError

        # Obtain the waveform at the model sample rate. Every combination of
        # inputs must leave both `y` and `sr` defined.
        if audio_y is not None:
            if audio_sr != target_sr:
                audio_y = librosa.core.resample(
                    audio_y, orig_sr=audio_sr, target_sr=target_sr
                )
                audio_sr = target_sr
            y = audio_y
            sr = audio_sr
        elif y is None:
            # Beat steps were supplied, so nothing has been loaded yet.
            y, sr = librosa.load(audio_path, sr=target_sr)
        elif target_sr != ESSENTIA_SAMPLERATE:
            y = librosa.core.resample(
                y,
                orig_sr=ESSENTIA_SAMPLERATE,
                target_sr=target_sr,
            )
            sr = target_sr

        # Only the audio between the first and last beat step is fed in.
        start_sample = int(beatsteps[0] * sr)
        end_sample = int(beatsteps[-1] * sr)
        _audio = torch.from_numpy(y)[start_sample:end_sample].to(device)

        relative_tokens, notes, pm = self.single_inference(
            feature_tokens=None,
            audio=_audio,
            beatstep=beatsteps - beatsteps[0],
            max_length=config.dataset.target_length * max(1, (n_bars // config.dataset.n_bars)),
            max_batch_size=max_batch_size,
            n_bars=n_bars,
            composer_value=composer_value,
        )

        # Inference ran on zero-based beat steps; restore the original offset.
        for note in pm.instruments[0].notes:
            note.start += beatsteps[0]
            note.end += beatsteps[0]

        if show_plot or save_mix:
            if mix_sample_rate != sr:
                y = librosa.core.resample(y, orig_sr=sr, target_sr=mix_sample_rate)
                sr = mix_sample_rate
            if add_click:
                clicks = librosa.clicks(times=beatsteps, sr=sr, length=len(y)) * click_amp
                y = y + clicks
            pm_y = pm.fluidsynth(sr)
            stereo = get_stereo(y, pm_y, pop_scale=stereo_amp)

        if show_plot:
            # Imported lazily so note_seq is only needed when plotting.
            import note_seq

            note_seq.plot_sequence(note_seq.midi_to_note_sequence(pm))

        if save_mix:
            sf.write(
                file=mix_path,
                data=stereo.T,
                samplerate=sr,
                format="wav",
            )

        if save_midi:
            pm.write(midi_path)

        return pm, composer, mix_path, midi_path