In [1]:
from audio_diffusion_pytorch import AudioDiffusionModel
import torch
from tqdm import tqdm
from IPython.display import Audio
from pathlib import Path
import torchaudio
import torchaudio.transforms as T
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader, Dataset
import torch.nn.functional as F
import wandb


In [2]:
# Open a Weights & Biases run; the training loop logs its loss and generated audio to it.
wandb.init(entity="mattricesound", project="RemFX")

[34m[1mwandb[0m: Currently logged in as: [33mmattricesound[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [18]:
SAMPLE_RATE = 22050  # Hz; GuitarDataset resamples every clip to this rate
# Clip length in samples: 2**17 = 131072 samples ≈ 5.94 s at 22050 Hz.
# A power of two rather than round(5 * SAMPLE_RATE) = 110250, which is not
# divisible by 16 (presumably the model's total downsampling factor — TODO confirm).
LENGTH = 2**17

# Seed torch's global RNG so model init, noise and sampling are reproducible.
SEED = 0
torch.manual_seed(SEED)

In [19]:
# Diffusion model over single-channel waveforms: in_channels=1 matches the
# (1, 131072) clip shape the dataset returns for g[10].
model = AudioDiffusionModel(in_channels=1)

In [20]:
class GuitarDataset(Dataset):
    """Fixed-length mono clips loaded from every ``.wav`` file under ``root``.

    Each item is a float tensor of shape ``(1, length)``: the file is
    downmixed to mono, resampled from its own sample rate to ``sample_rate``,
    then cropped or right-padded with zeros to exactly ``length`` samples.

    Args:
        root: directory searched recursively for ``*.wav`` files.
        length: number of samples per returned clip (default ``LENGTH``).
        sample_rate: target sample rate in Hz (default ``SAMPLE_RATE``).
    """

    def __init__(self, root, length=LENGTH, sample_rate=SAMPLE_RATE):
        # Sorted so indices (and therefore random_split) are stable across
        # runs; raw glob order depends on the filesystem. Globbing relative to
        # Path(root) also works when root is an absolute path.
        self.files = sorted(Path(root).glob("**/*.wav"))
        self.length = length
        self.sample_rate = sample_rate
        # One Resample module per source rate, built lazily, rather than
        # assuming every file is 48 kHz.
        self._resamplers = {}

    def __len__(self):
        return len(self.files)

    def _resample(self, x, sr):
        """Resample ``x`` from ``sr`` Hz to ``self.sample_rate`` (no-op if equal)."""
        if sr == self.sample_rate:
            return x
        if sr not in self._resamplers:
            self._resamplers[sr] = T.Resample(sr, self.sample_rate)
        return self._resamplers[sr](x)

    def __getitem__(self, idx):
        x, sr = torchaudio.load(self.files[idx])
        # Downmix to mono: the model is built with in_channels=1, and mixed
        # mono/stereo files would otherwise fail to collate into one batch.
        x = x.mean(dim=0, keepdim=True)
        x = self._resample(x, sr)
        # Crop, then zero-pad on the right, to exactly self.length samples.
        x = x[:, : self.length]
        if x.shape[1] < self.length:
            x = F.pad(x, (0, self.length - x.shape[1]))
        return x

In [21]:
# Dataset over the clean (unprocessed) guitar recordings in ./Clean.
g = GuitarDataset(root=Path("Clean"))

In [237]:
# Sanity check: every item should be mono and exactly LENGTH samples long.
x = g[10]
assert x.shape == (1, LENGTH), x.shape
x.shape

torch.Size([1, 131072])


In [22]:
class RemFXDataModule(pl.LightningDataModule):
    """Lightning data module wrapping GuitarDataset with a seeded train/val split.

    Args:
        data_dir: directory passed to GuitarDataset (searched for ``.wav`` files).
        batch_size: batch size used by both dataloaders.
        val_fraction: fraction of clips held out for validation. The default
            (5000 / 60000) keeps the 55000/5000 ratio but works for any
            dataset size. At least one clip lands in each split when the
            dataset has two or more files.
        seed: seed for the split's generator, so train/val membership is
            reproducible across runs.
    """

    def __init__(
        self,
        data_dir: str = "path/to/dir",
        batch_size: int = 32,
        val_fraction: float = 5000 / 60000,
        seed: int = 0,
    ):
        super().__init__()
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.val_fraction = val_fraction
        self.seed = seed

    def setup(self, stage: str):
        guitar_full = GuitarDataset(self.data_dir)
        n_total = len(guitar_full)
        # Derive split sizes from the dataset length: random_split requires
        # integer lengths that sum exactly to len(dataset).
        n_val = round(n_total * self.val_fraction)
        if n_total > 1:
            # Keep both splits non-empty.
            n_val = min(max(n_val, 1), n_total - 1)
        self.guitar_train, self.guitar_val = random_split(
            guitar_full,
            [n_total - n_val, n_val],
            generator=torch.Generator().manual_seed(self.seed),
        )

    def train_dataloader(self):
        # Reshuffle training batches every epoch; validation order stays fixed.
        return DataLoader(self.guitar_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.guitar_val, batch_size=self.batch_size)

    # TODO: test_dataloader / predict_dataloader are not implemented yet.

    def teardown(self, stage: str):
        pass

In [23]:
# Training loader over the clean guitar clips. shuffle=True because the
# training loop iterates `data` once per epoch; without it every epoch sees
# identical batches in identical order.
data = DataLoader(GuitarDataset(Path("Clean")), batch_size=32, shuffle=True)

In [24]:
# Peek at a single batch to sanity-check the collated shape before training.
dataiter = iter(data)
x = next(dataiter)

In [25]:
# Per-example shape of the batch (batch dimension dropped).
x.shape[1:]

torch.Size([1, 131072])

In [26]:
# Disabled: would upload the first clip of the peeked batch `x` to W&B as an
# audio sanity check of the input data.
# wandb.log({"Audio": wandb.Audio(x[0].view(-1).numpy(), sample_rate=SAMPLE_RATE)})

In [28]:
# Training loop: one optimizer step per batch, mean epoch loss logged every
# epoch, and an unconditional audio sample logged every `sample_every` epochs.
epochs = 50
sample_every = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Without an optimizer, loss.backward() only accumulates gradients and the
# weights never change. The lr below is a starting guess, not a tuned value — TODO tune.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for i in tqdm(range(epochs)):
    model.train()
    running_loss, num_batches = 0.0, 0
    for batch in data:
        optimizer.zero_grad()
        loss = model(batch.to(device))
        loss.backward()
        optimizer.step()
        # .item() logs a plain float instead of a graph-holding tensor.
        running_loss += loss.item()
        num_batches += 1
    if num_batches:
        wandb.log({"loss": running_loss / num_batches, "epoch": i})
    if i % sample_every == 0:
        model.eval()
        with torch.no_grad():
            noise = torch.randn(1, 1, LENGTH, device=device)
            sampled = model.sample(noise=noise, num_steps=40)
            z = sampled.view(-1).cpu()
            wandb.log({f"Audio_{i}": wandb.Audio(z.numpy(), sample_rate=SAMPLE_RATE)})
            
            
        

 14%|█████████▏                                                        | 7/50 [7:29:41<59:56:20, 5018.16s/it]wandb: Network error (ConnectionError), entering retry loop.
 14%|█████████▏                                                        | 7/50 [8:13:48<50:33:21, 4232.58s/it]


KeyboardInterrupt: 

In [259]:
# Draw one unconditional sample of LENGTH samples from the model.
model.eval()
# Put the noise on whatever device the model lives on (training may have moved it).
device = next(model.parameters()).device
with torch.no_grad():  # inference only; don't build an autograd graph
    noise = torch.randn(1, 1, LENGTH, device=device)
    # Move to CPU so the printing/playback cells below work on any device.
    sampled = model.sample(noise=noise, num_steps=50).cpu()

In [260]:
# Show the sample's shape followed by a (torch-truncated) preview of its values.
print(f"{sampled.shape} {sampled}")

torch.Size([1, 1, 131072]) tensor([[[-0.4879, -0.4534, -0.4094,  ..., -1.0000,  0.8554, -0.9605]]])


In [261]:
# Flatten the sample to a 1-D waveform for playback.
z = sampled.view(-1)
# (disabled, and stale: z is already 1-D here, so this would reduce it to a scalar)
# z = z.mean(axis=0)

In [262]:
# Play the sample inline. Convert explicitly to a CPU numpy array (a CUDA or
# grad-tracking tensor would otherwise fail to convert), and use the
# configured SAMPLE_RATE rather than repeating the literal.
Audio(z.detach().cpu().numpy(), rate=SAMPLE_RATE)

tensor(0.6213, grad_fn=<MseLossBackward0>)


In [164]:
# Scratch: 110250 = 5 * 22050, i.e. a 5 s clip. Checks whether it divides
# evenly by 16 (presumably the model's downsampling factor — TODO confirm); it does not.
110250 / 16

6890.625

In [165]:
# Scratch: number of samples in a 12 s clip at 22050 Hz.
12 * 22050

264600

In [166]:
# Scratch: the 12 s clip length (264600 samples) divided by 16 — also not an integer.
264600 / 16

16537.5