# %%
import os
os.system("git clone https://github.com/v-iashin/SpecVQGAN")
os.system("conda env create -f ./SpecVQGAN/conda_env.yml")
# Note: activating a conda env via os.system() only affects the child shell, not this process.
os.system("conda activate specvqgan")
os.system("pip install pytorch-lightning==1.2.10 omegaconf==2.0.6 streamlit==0.80 matplotlib==3.4.1 albumentations==0.5.2 SoundFile torch torchvision librosa gdown")
# %%
import sys
sys.path.append('./SpecVQGAN')

import time
from pathlib import Path

import IPython.display as display_audio
import soundfile
import torch
from IPython import display
from matplotlib import pyplot as plt
from torch.utils.data.dataloader import default_collate
from torchvision.utils import make_grid
from tqdm import tqdm

from feature_extraction.demo_utils import (ExtractResNet50, check_video_for_audio,
                                           extract_melspectrogram, load_model,
                                           show_grid, trim_video)
from sample_visualization import (all_attention_to_st, get_class_preditions,
                                  last_attention_to_st, spec_to_audio_to_st,
                                  tensor_to_plt)
from specvqgan.data.vggsound import CropImage

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the pretrained VGGSound transformer, the MelGAN vocoder and the Melception classifier
model_name = '2021-07-30T21-34-25_vggsound_transformer'
log_dir = './logs'
os.chdir("./SpecVQGAN/")
config, sampler, melgan, melception = load_model(model_name, log_dir, device)
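# A quick sanity check after loading (a minimal sketch; assumes `load_model` returned the
# config/sampler/vocoder/classifier objects as above and that the config exposes the same
# fields this script uses further down):
print('Device:', device)
print('Codebook size:', sampler.first_stage_model.quantize.n_e)
print('Sample rate:', config.data.params.sample_rate)
print('Spectrogram dir:', config.data.params.spec_dir_path)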
# %%
def extract_thumbnails(video_path):
    # Trim the video to a 10-second clip
    start_sec = 0  # e.g. use 95 to start at 01:35
    video_path = trim_video(video_path, start_sec, trim_duration=10)

    # Extract visual features from the trimmed clip
    extraction_fps = 21.5
    feature_extractor = ExtractResNet50(extraction_fps, config.data.params, device)
    visual_features, resampled_frames = feature_extractor(video_path)

    # Show the frames selected for feature extraction
    if not config.data.params.replace_feats_with_random:
        fig = show_grid(make_grid(resampled_frames))
        fig.show()

    # Prepare the conditioning input for the transformer
    batch = default_collate([visual_features])
    batch['feature'] = batch['feature'].to(device)
    c = sampler.get_input(sampler.cond_stage_key, batch)
    return c, video_path
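# Example usage of `extract_thumbnails` (a sketch; '../kiss.avi' is only a placeholder,
# mirroring the commented example near the end of this script, so substitute a real clip).
# `c_demo` is expected to be the visual-feature tensor used as conditioning for the sampler.
demo_clip = '../kiss.avi'
if os.path.exists(demo_clip):
    c_demo, trimmed_clip = extract_thumbnails(demo_clip)
    print('Conditioning features:', tuple(c_demo.shape), '| trimmed clip:', trimmed_clip)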
# %%
import numpy as np

def generate_audio(video_path, temperature=1.0):
    # Define sampling parameters
    W_scale = 1  # width multiplier for the sampled latent (1 = default clip length)
    mode = 'full'  # 'full' samples every codebook index from scratch
    top_x = sampler.first_stage_model.quantize.n_e // 2  # top-k filtering: keep the most likely half of the codebook
    update_every = 0  # use a value > 0, e.g. 15, to visualize generation progress (slows down sampling)
    full_att_mat = True
    c, video_path = extract_thumbnails(video_path)

    # Start sampling
    with torch.no_grad():
        start_t = time.time()

        quant_c, c_indices = sampler.encode_to_c(c)
        # crec = sampler.cond_stage_model.decode(quant_c)

        patch_size_i = 5
        patch_size_j = 53

        B, D, hr_h, hr_w = sampling_shape = (1, 256, 5, 53*W_scale)
        z_pred_indices = torch.zeros((B, hr_h*hr_w)).long().to(device)

        if mode == 'full':
            start_step = 0
        else:
            # Note: this branch expects `z_indices` (codes of a ground-truth spectrogram),
            # which is not computed in this script; only mode == 'full' is supported here.
            start_step = (patch_size_j // 2) * patch_size_i
            z_pred_indices[:, :start_step] = z_indices[:, :start_step]

        pbar = tqdm(range(start_step, hr_w * hr_h), desc='Sampling Codebook Indices')
        for step in pbar:
            i = step % hr_h
            j = step // hr_h

            i_start = min(max(0, i - (patch_size_i // 2)), hr_h - patch_size_i)
            j_start = min(max(0, j - (patch_size_j // 2)), hr_w - patch_size_j)
            i_end = i_start + patch_size_i
            j_end = j_start + patch_size_j

            local_i = i - i_start
            local_j = j - j_start

            patch_2d_shape = (B, D, patch_size_i, patch_size_j)

            pbar.set_postfix(
                Step=f'({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})'
            )

            patch = z_pred_indices \
                .reshape(B, hr_w, hr_h) \
                .permute(0, 2, 1)[:, i_start:i_end, j_start:j_end].permute(0, 2, 1) \
                .reshape(B, patch_size_i * patch_size_j)

            # The conditioning is not cropped: the whole feature sequence `c_indices` is used
            cpatch = c_indices

            logits, _, attention = sampler.transformer(patch[:, :-1], cpatch)
            # remove conditioning
            logits = logits[:, -patch_size_j*patch_size_i:, :]

            local_pos_in_flat = local_j * patch_size_i + local_i
            logits = logits[:, local_pos_in_flat, :]

            logits = logits / temperature
            logits = sampler.top_k_logits(logits, top_x)

            # apply softmax to convert to probabilities
            probs = torch.nn.functional.softmax(logits, dim=-1)

            # sample from the distribution
            ix = torch.multinomial(probs, num_samples=1)
            z_pred_indices[:, j * hr_h + i] = ix

            if update_every > 0 and step % update_every == 0:
                z_pred_img = sampler.decode_to_img(z_pred_indices, sampling_shape)
                # flip the spectrogram for illustration only (low freqs at the bottom, high at the top)
                z_pred_img_st = tensor_to_plt(z_pred_img, flip_dims=(2,))
                display.clear_output(wait=True)
                display.display(z_pred_img_st)

                if full_att_mat:
                    att_plot = all_attention_to_st(attention, placeholders=None, scale_by_prior=True)
                    display.display(att_plot)
                    plt.close()
                else:
                    quant_z_shape = sampling_shape
                    c_length = cpatch.shape[-1]
                    quant_c_shape = quant_c.shape
                    c_att_plot, z_att_plot = last_attention_to_st(
                        attention, local_pos_in_flat, c_length, sampler.first_stage_permuter,
                        sampler.cond_stage_permuter, quant_c_shape, patch_2d_shape,
                        placeholders=None, flip_c_dims=None, flip_z_dims=(2,))
                    display.display(c_att_plot)
                    display.display(z_att_plot)
                    plt.close()
                    plt.close()
                    plt.close()
        # quant_z_shape = sampling_shape
        z_pred_img = sampler.decode_to_img(z_pred_indices, sampling_shape)

        # show the final spectrogram
        z_pred_img_st = tensor_to_plt(z_pred_img, flip_dims=(2,))
        display.clear_output(wait=True)
        display.display(z_pred_img_st)

        if full_att_mat:
            att_plot = all_attention_to_st(attention, placeholders=None, scale_by_prior=True)
            display.display(att_plot)
            plt.close()
        else:
            quant_z_shape = sampling_shape
            c_length = cpatch.shape[-1]
            quant_c_shape = quant_c.shape
            c_att_plot, z_att_plot = last_attention_to_st(
                attention, local_pos_in_flat, c_length, sampler.first_stage_permuter,
                sampler.cond_stage_permuter, quant_c_shape, patch_2d_shape,
                placeholders=None, flip_c_dims=None, flip_z_dims=(2,)
            )
            display.display(c_att_plot)
            display.display(z_att_plot)
            plt.close()
            plt.close()
            plt.close()

        print(f'Sampling Time: {time.time() - start_t:3.2f} seconds')
        waves = spec_to_audio_to_st(z_pred_img, config.data.params.spec_dir_path,
                                    config.data.params.sample_rate, show_griffin_lim=False,
                                    vocoder=melgan, show_in_st=False)
        print(f'Sampling Time (with vocoder): {time.time() - start_t:3.2f} seconds')
        print(f'Generated: {len(waves["vocoder"]) / config.data.params.sample_rate:.2f} seconds')

        # Melception's opinion on the class distribution of the generated sample
        topk_preds = get_class_preditions(z_pred_img, melception)
        print(topk_preds)

        # Save the generated waveform and mux it back into the video
        audio_path = os.path.join(log_dir, Path(video_path).stem + '.wav')
        audio = waves['vocoder']
        audio = np.repeat([audio], 2, axis=0).T  # duplicate the mono track into two channels
        print(audio.shape)
        soundfile.write(audio_path, audio, config.data.params.sample_rate, 'PCM_24')
        print(f'The sample has been saved @ {audio_path}')

        video_out_path = os.path.join(log_dir, Path(video_path).stem + '_audio.mp4')
        print(video_path, audio_path, video_out_path)
        os.system("ffmpeg -i %s -i %s -map 0:v -map 1:a -c:v copy -shortest %s" % (video_path, audio_path, video_out_path))
        return video_out_path
        # return config.data.params.sample_rate, audio
# %%
# generate_audio("../kiss.avi")
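# To preview the generated clip inline (a sketch; assumes an IPython/Jupyter session and a
# real video at the path below):
# out_path = generate_audio("../kiss.avi", temperature=1.0)
# display.display(display_audio.Video(out_path, embed=True))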
# %%
import gradio as gr

iface = gr.Interface(generate_audio, "video", ["playable_video"],
                     description="Generate audio based on the video input")
iface.launch()
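# If a publicly shareable URL is needed (e.g. when running on a remote machine or Colab),
# Gradio can tunnel the app; a sketch:
# iface.launch(share=True)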
# %%