In [None]:
import torch
import random
import numpy as np
from PIL import Image
from datasets import load_dataset
from IPython.display import Audio
from diffusers import AutoencoderKL, AudioDiffusionPipeline, Mel

In [None]:
# Seed every RNG the notebook draws from (`random.choice` picks dataset rows,
# torch drives posterior sampling and the random latent), so Restart & Run All
# reproduces the same clips. numpy is seeded too in case downstream audio
# conversion uses its global RNG (not visible here — harmless if it doesn't).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Local autoencoder checkpoint, relative to this notebook's working directory.
VAE_PATH = '../models/autoencoder-kl'

mel = Mel()
vae = AutoencoderKL.from_pretrained(VAE_PATH)

In [None]:
# Show the loaded autoencoder's configuration (the bare last expression is rendered as this cell's output).
vae.config

In [None]:
# Load the dataset whose 'train' split supplies the spectrogram images used below.
ds = load_dataset(path='teticio/audio-diffusion-256')

### Reconstruct audio

In [None]:
# Pick one training example at random, then preview its spectrogram and the audio it converts to.
train_split = ds['train']
image = random.choice(train_split)['image']
display(image)
clip_audio = mel.image_to_audio(image)
Audio(data=clip_audio, rate=mel.get_sample_rate())

In [None]:
# encode: PIL spectrogram -> (1, 1, H, W) float32 tensor in [-1, 1] -> sampled latents.
# NOTE(review): the reshape to (H, W, 1) assumes a single-channel, 8-bit image
# (e.g. PIL mode "L"); other modes will fail to reshape.
input_image = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
    (image.height, image.width, 1))
input_image = ((input_image / 255) * 2 - 1).transpose(2, 0, 1)  # (1, H, W)
# Add the batch axis on the ndarray itself: torch.tensor() on a *list* of
# ndarrays is slow and emits a UserWarning.
input_tensor = torch.tensor(input_image[np.newaxis], dtype=torch.float32)
with torch.no_grad():  # inference only; don't build an autograd graph
    posterior = vae.encode(input_tensor).latent_dist
    latents = posterior.sample()

In [None]:
# reconstruct: decode the latents back to a spectrogram image and play it.
with torch.no_grad():  # inference only; no autograd graph, so no .detach() needed
    decoded = vae.decode(latents)['sample']
decoded = (torch.clamp(decoded, -1., 1.) + 1.0) / 2.0  # -1,1 -> 0,1; (1, C, H, W)
# First channel of the first batch item as an (H, W) uint8 array — equivalent to
# the older transpose(0, 2, 3, 1)[0, :, :, 0].
output_image = Image.fromarray(
    (decoded.cpu().numpy() * 255).round().astype("uint8")[0, 0])
display(output_image)
Audio(data=mel.image_to_audio(output_image), rate=mel.get_sample_rate())

### Random sample from latent space
(This decodes pure Gaussian noise instead of latents encoded from real audio, so don't expect interesting results!)

In [None]:
# sample: decode Gaussian noise shaped like the encoded latents and play the result.
with torch.no_grad():  # inference only; no autograd graph, so no .detach() needed
    decoded = vae.decode(torch.randn_like(latents))['sample']
decoded = (torch.clamp(decoded, -1., 1.) + 1.0) / 2.0  # -1,1 -> 0,1; (1, C, H, W)
# First channel of the first batch item as an (H, W) uint8 array — equivalent to
# the older transpose(0, 2, 3, 1)[0, :, :, 0].
output_image = Image.fromarray(
    (decoded.cpu().numpy() * 255).round().astype("uint8")[0, 0])
display(output_image)
Audio(data=mel.image_to_audio(output_image), rate=mel.get_sample_rate())

### Interpolate between two audio clips in latent space

In [None]:
# Pick a second random training example to interpolate towards, and preview it.
second_row = random.choice(ds['train'])
image2 = second_row['image']
display(image2)
clip_audio2 = mel.image_to_audio(image2)
Audio(data=clip_audio2, rate=mel.get_sample_rate())

In [None]:
# encode: second PIL spectrogram -> (1, 1, H, W) float32 tensor in [-1, 1] -> sampled latents.
# NOTE(review): the reshape to (H, W, 1) assumes a single-channel, 8-bit image
# (e.g. PIL mode "L"); other modes will fail to reshape.
input_image2 = np.frombuffer(image2.tobytes(), dtype="uint8").reshape(
    (image2.height, image2.width, 1))
input_image2 = ((input_image2 / 255) * 2 - 1).transpose(2, 0, 1)  # (1, H, W)
# Add the batch axis on the ndarray itself: torch.tensor() on a *list* of
# ndarrays is slow and emits a UserWarning.
input_tensor2 = torch.tensor(input_image2[np.newaxis], dtype=torch.float32)
with torch.no_grad():  # inference only; don't build an autograd graph
    posterior2 = vae.encode(input_tensor2).latent_dist
    latents2 = posterior2.sample()

In [None]:
# interpolate: decode a spherical interpolation (AudioDiffusionPipeline.slerp)
# between the two clips' latents, then play both originals followed by the blend.
# Presumably alpha=0 reproduces `latents` and alpha=1 `latents2` — confirm against slerp.
alpha = 0.5 #@param {type:"slider", min:0, max:1, step:0.1}
with torch.no_grad():  # inference only; no autograd graph, so no .detach() needed
    blended = vae.decode(
        AudioDiffusionPipeline.slerp(latents, latents2, alpha))['sample']
blended = (torch.clamp(blended, -1., 1.) + 1.0) / 2.0  # -1,1 -> 0,1; (1, C, H, W)
# First channel of the first batch item as an (H, W) uint8 array — equivalent to
# the older transpose(0, 2, 3, 1)[0, :, :, 0].
output_image = Image.fromarray(
    (blended.cpu().numpy() * 255).round().astype("uint8")[0, 0])
display(output_image)
display(Audio(data=mel.image_to_audio(image), rate=mel.get_sample_rate()))
display(Audio(data=mel.image_to_audio(image2), rate=mel.get_sample_rate()))
display(
    Audio(data=mel.image_to_audio(output_image), rate=mel.get_sample_rate()))