# Spaces:
# Runtime error
# Runtime error
# Hacked together using the code from https://github.com/nikhilsinghmus/image2reverb | |
import os, types | |
import numpy as np | |
import gradio as gr | |
import soundfile as sf | |
import scipy | |
import librosa.display | |
from PIL import Image | |
import matplotlib | |
matplotlib.use("Agg") | |
import matplotlib.pyplot as plt | |
import torch | |
from torch.utils.data import Dataset | |
import torchvision.transforms as transforms | |
from pytorch_lightning import Trainer | |
from image2reverb.model import Image2Reverb | |
from image2reverb.stft import STFT | |
# Results of the most recent inference. ``test_step`` writes these and
# ``predict`` reads them; all start out empty.
predicted_ir = predicted_spectrogram = predicted_depthmap = None
def test_step(self, batch, batch_idx):
    """Run Image2Reverb inference on one batch and publish the results for the UI.

    Written to be bound onto a loaded ``Image2Reverb`` instance with
    ``types.MethodType``, replacing that instance's own ``test_step``.

    Args:
        batch: ``(spec, label, paths)``. ``spec`` is ignored; ``label`` is the
            image batch passed to ``self.enc``; ``paths`` is a pair of path
            sequences whose second element supplies the example names.
        batch_idx: Unused; required by the Lightning hook signature.

    Returns:
        ``{"test_audio": y_f, "test_examples": examples}`` where ``y_f`` holds
        one ``STFT().inverse`` waveform per generated spectrogram.

    Side effects:
        Sets the module globals ``predicted_ir`` (first waveform),
        ``predicted_spectrogram`` (generator output passed through the exp
        mapping below) and ``predicted_depthmap`` (last channel of the
        encoder's image output after ``(img + 1) * 0.5``).
    """
    # TODO: bit hacky -- results reach ``predict`` through module globals.
    global predicted_ir, predicted_spectrogram, predicted_depthmap

    spec, label, paths = batch
    examples = [os.path.splitext(os.path.basename(s))[0] for _, s in zip(*paths)]

    f, img = self.enc.forward(label)
    # Noise channels concatenated to the features: enough to fill the latent
    # dimension when the features are narrower, otherwise as many channels as
    # the features themselves.
    noise_channels = (
        self._latent_dimension - f.shape[1]
        if f.shape[1] < self._latent_dimension
        else f.shape[1]
    )
    shape = (f.shape[0], noise_channels, f.shape[2], f.shape[3])
    # Use this module's own device instead of the global ``model`` so the step
    # does not depend on a module-level name being defined.
    z = torch.cat((f, torch.randn(shape, device=self.device)), 1)
    fake_spec = self.g(z)

    stft = STFT()
    y_f = [stft.inverse(s.squeeze()) for s in fake_spec]

    predicted_ir = y_f[0]
    s = fake_spec.squeeze().cpu().numpy()
    # Maps values in [-1, 1] to exp([-17.5, 2.0]) minus a 1e-8 epsilon;
    # presumably undoes the training-time log-magnitude normalization.
    predicted_spectrogram = np.exp((((s + 1) * 0.5) * 19.5) - 17.5) - 1e-8
    img = (img + 1) * 0.5
    predicted_depthmap = img.cpu().squeeze().permute(1, 2, 0)[:, :, -1].squeeze().numpy()

    return {"test_audio": y_f, "test_examples": examples}
def test_epoch_end(self, outputs): | |
if not self.test_callback: | |
return | |
examples = [] | |
audio = [] | |
for output in outputs: | |
for i in range(len(output["test_examples"])): | |
audio.append(output["test_audio"][i]) | |
examples.append(output["test_examples"][i]) | |
self.test_callback(examples, audio) | |
# --- Model setup: runs once, at import time ---------------------------------
# Checkpoint whose "state_dict" entry is loaded into the model below.
checkpoint_path = "./checkpoints/image2reverb_f22.ckpt"
# First argument to Image2Reverb; what the constructor does with None is not
# visible from this file -- TODO confirm.
encoder_path = None
# Second argument to Image2Reverb; presumably the monodepth weights used to
# predict the depth channel -- verify against image2reverb.model.
depthmodel_path = "./checkpoints/mono_odom_640x192"
# NOTE(review): constant_depth and latent_dimension are not referenced in the
# visible source; latent_dimension is not passed to Image2Reverb.
constant_depth = None
latent_dimension = 512
model = Image2Reverb(encoder_path, depthmodel_path)
# Load the checkpoint onto the device the model was created on, then restore
# the weights from its "state_dict" entry.
# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which
# may reject a full training checkpoint -- confirm against the pinned torch.
m = torch.load(checkpoint_path, map_location=model.device)
model.load_state_dict(m["state_dict"])
# Swap in the demo's test hooks, bound to this instance so ``self`` is ``model``.
model.test_step = types.MethodType(test_step, model)
model.test_epoch_end = types.MethodType(test_epoch_end, model)
# Preprocessing for every uploaded image: bicubic resize to 224x224, conversion
# to a tensor, then per-channel (x - 0.5) / 0.5, i.e. values in [0, 1] are
# mapped to [-1, 1].
image_transforms = transforms.Compose(
    [
        transforms.Resize(
            [224, 224],
            interpolation=transforms.functional.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
class Image2ReverbDemoDataset(Dataset):
    """Single-item dataset that feeds one uploaded image through the test loop.

    Each item mirrors the ``(spec, label, paths)`` layout that ``test_step``
    unpacks: a zero placeholder in the ``spec`` slot (which ``test_step``
    ignores), the preprocessed image, and a pair of empty path strings.
    """

    def __init__(self, image):
        # ``image`` is the array delivered by the Gradio image input.
        self.image = Image.fromarray(image)
        # NOTE(review): not used by this class.
        self.stft = STFT()

    def __len__(self):
        return 1

    def __getitem__(self, index):
        # Every index yields the same sample; ``index`` is not consulted.
        rgb_image = self.image.convert("RGB")
        image_tensor = image_transforms(rgb_image)
        # int(5.94 * 22050) zeros -- presumably 5.94 s at the 22050 Hz rate
        # used elsewhere in this file.
        placeholder = torch.zeros(1, int(5.94 * 22050))
        return placeholder, image_tensor, ("", "")

    def name(self):
        return "Image2ReverbDemo"
def convolve(audio, reverb):
    """Apply an impulse response to a mono signal and normalize the result.

    The dry signal is zero-padded by ``len(reverb)`` samples so the reverb tail
    is kept, convolved with ``reverb`` (overlap-add), DC-corrected, and scaled
    to a peak of 0.9 (roughly -1 dBFS).

    Args:
        audio: 1-D float array, the dry signal.
        reverb: 1-D float array, the impulse response.

    Returns:
        1-D float array of length ``len(audio) + len(reverb)``. An all-zero
        result is returned as zeros instead of being divided by zero (NaN).
    """
    padded = np.concatenate((audio, np.zeros(reverb.shape)))
    wet_audio = scipy.signal.oaconvolve(padded, reverb, "full")[: len(padded)]
    # Remove the DC offset *before* peak-normalizing; doing it afterwards can
    # push the peak past 1.0 and wrap around in a later int16 conversion.
    wet_audio -= np.mean(wet_audio)
    peak = np.max(np.abs(wet_audio))
    if peak > 0:
        wet_audio /= peak
    wet_audio *= 0.9
    return wet_audio
def _render_figures():
    """Plot the depth map, IR spectrogram and IR waveform of the last prediction.

    Reads the module globals ``predicted_depthmap``, ``predicted_spectrogram``
    and ``predicted_ir``. Each figure is closed after drawing so pyplot does not
    keep it open between requests.
    """
    depthmap_fig = plt.figure()
    plt.imshow(predicted_depthmap)
    plt.close(depthmap_fig)

    spectrogram_fig = plt.figure()
    librosa.display.specshow(predicted_spectrogram, sr=22050, x_axis="time", y_axis="hz")
    plt.close(spectrogram_fig)

    waveform_fig = plt.figure()
    librosa.display.waveshow(predicted_ir, sr=22050, alpha=0.5)
    plt.close(waveform_fig)

    return depthmap_fig, spectrogram_fig, waveform_fig


def _peak_normalize(signal):
    """Scale ``signal`` so its largest absolute sample is 1; all-zero input is returned unchanged."""
    peak = np.max(np.abs(signal))
    return signal / peak if peak > 0 else signal


def _to_int16(signal):
    """Convert float samples nominally in [-1, 1] to 16-bit PCM, clipping overshoot instead of wrapping."""
    return (np.clip(signal, -1.0, 1.0) * 32767).astype(np.int16)


def predict(image, audio):
    """Predict a room's reverb from ``image`` and apply it to ``audio``.

    Args:
        image: numpy image array (height, width, channels) from the Gradio
            image input.
        audio: ``(sample_rate, samples)`` from the Gradio audio input, where
            ``samples`` is ``(frames,)`` for mono or ``(frames, channels)``.

    Returns:
        ``(depthmap_fig, spectrogram_fig, waveform_fig, ir, wet, mixed)``:
        three matplotlib figures, the impulse response as ``(22050, int16)``,
        and the 100%-wet and 50/50 dry-wet renders as
        ``(sample_rate, int16 array)``.
    """
    # One pass through Lightning's test loop; the patched ``test_step`` stores
    # its results in the module-level ``predicted_*`` globals.
    test_loader = torch.utils.data.DataLoader(
        Image2ReverbDemoDataset(image), num_workers=0, batch_size=1
    )
    # logger=False: don't create a new lightning_logs/version_N dir per request.
    trainer = Trainer(limit_test_batches=1, logger=False)
    trainer.test(model, test_loader, verbose=True)

    depthmap_fig, spectrogram_fig, waveform_fig = _render_figures()

    ir = (22050, _to_int16(predicted_ir))

    sample_rate, original_audio = audio
    # Peak normalization makes the incoming integer scale irrelevant, so int16,
    # int32 and float PCM are all handled alike (and silence doesn't give NaN).
    original_audio = _peak_normalize(original_audio.astype(np.float32))

    # Resample the 22050 Hz IR to the input's sample rate, then normalize it.
    reverb = _peak_normalize(
        scipy.signal.resample_poly(predicted_ir, up=sample_rate, down=22050)
    )

    if original_audio.ndim > 1:
        # Apply the same mono IR to every channel independently; this covers
        # (frames, 1) arrays as well as stereo and wider layouts.
        wet_audio = np.stack(
            [convolve(channel, reverb) for channel in original_audio.T], axis=1
        )
    else:
        wet_audio = convolve(original_audio, reverb)

    # 50% dry / 50% wet; the dry part gets the same 0.9 factor ``convolve``
    # applies to the wet signal.
    mixed_audio = wet_audio * 0.5
    mixed_audio[: len(original_audio), ...] += original_audio * 0.9 * 0.5

    convolved_audio_100 = (sample_rate, _to_int16(wet_audio))
    convolved_audio_50 = (sample_rate, _to_int16(mixed_audio))
    return depthmap_fig, spectrogram_fig, waveform_fig, ir, convolved_audio_100, convolved_audio_50
# Text shown above (title, description) and below (article) the demo UI.
title = "Image2Reverb: Cross-Modal Reverb Impulse Response Synthesis"

description = """
<b>Image2Reverb</b> predicts the acoustic reverberation of a given environment from a 2D image. <a href="https://arxiv.org/abs/2103.14201">Read the paper</a>

How to use: Choose an image of a room or other environment and an audio file.
The model will predict what the reverb of the room sounds like and apply it to the audio file.

First, the image is resized to 224×224. The monodepth model is used to predict a depthmap, which is added as an
additional channel to the image input. A ResNet-based encoder then converts the image into features, and
finally a GAN predicts the spectrogram of the reverb's impulse response.

<center><img src="file/model.jpg" width="870" height="297" alt="model architecture"></center>

The predicted impulse response is mono at 22050 Hz. It is resampled to the sampling rate of the audio
file and applied to both channels if the audio is stereo.

Generating the impulse response involves a certain amount of randomness, making it sound a little
different every time you try it.
"""

article = """
<div style='margin:20px auto;'>
<p>Based on original work by Nikhil Singh, Jeff Mentch, Jerry Ng, Matthew Beveridge, Iddo Drori.
<a href="https://web.media.mit.edu/~nsingh1/image2reverb/">Project Page</a> |
<a href="https://arxiv.org/abs/2103.14201">Paper</a> |
<a href="https://github.com/nikhilsinghmus/image2reverb">GitHub</a></p>
<pre>
@InProceedings{Singh_2021_ICCV,
    author    = {Singh, Nikhil and Mentch, Jeff and Ng, Jerry and Beveridge, Matthew and Drori, Iddo},
    title     = {Image2Reverb: Cross-Modal Reverb Impulse Response Synthesis},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {286-295}
}
</pre>
<p>Example images from <a href="https://web.media.mit.edu/~nsingh1/image2reverb/">the original project page</a>.</p>
<p>🎶 Example sound from <a href="https://freesound.org/people/ashesanddreams/sounds/610414/">Ashes and Dreams @ freesound.org</a> (CC BY 4.0 license). This is a mono 48 kHz recording that has no reverb on it.</p>
</div>
"""
audio_example = "examples/ashesanddreams.wav"

# Each Gradio example pairs one room image with the shared dry audio clip.
examples = [
    [f"examples/input.{image_id}.png", audio_example]
    for image_id in (
        "4e2f71f6",
        "321eef38",
        "2238dc21",
        "4d280b40",
        "0c3f5013",
        "98773b90",
        "ac61500f",
        "5416407f",
    )
]
# Web UI: the two inputs map to predict(image, audio); the six outputs are in
# the same order as predict's return tuple.
# Top-level components replace the gr.inputs / gr.outputs namespaces, which
# were deprecated in Gradio 3 and removed in Gradio 4. The audio input's
# source="upload" is dropped: it is already the Gradio 3 default, and Gradio 4
# renamed that argument.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Upload Image"),
        gr.Audio(label="Upload Audio"),
    ],
    outputs=[
        gr.Plot(label="Depthmap"),
        gr.Plot(label="Impulse Response Spectrogram"),
        gr.Plot(label="Impulse Response Waveform"),
        gr.Audio(label="Impulse Response"),
        gr.Audio(label="Output Audio (100% Wet)"),
        gr.Audio(label="Output Audio (50% Dry, 50% Wet)"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()