naotokui's picture
added conda command
e0ce7d2
#%%
import os
# One-time environment bootstrap: fetch the SpecVQGAN sources, create its conda
# environment and install the Python dependencies. The commands run in order and
# their exit statuses are ignored.
# NOTE(review): os.system executes each command in its own subshell, so the
# `conda activate` below cannot influence the `pip install` that follows (nor
# this Python process) - confirm which environment the packages should land in.
_SETUP_COMMANDS = (
    "git clone https://github.com/v-iashin/SpecVQGAN",
    "conda env create -f ./SpecVQGAN/conda_env.yml",
    "conda activate specvqgan",
    "pip install pytorch-lightning==1.2.10 omegaconf==2.0.6 streamlit==0.80 matplotlib==3.4.1 albumentations==0.5.2 SoundFile torch torchvision librosa gdown",
)
for _setup_command in _SETUP_COMMANDS:
    os.system(_setup_command)
# %%
import sys
sys.path.append('./SpecVQGAN')
import time
from pathlib import Path
import IPython.display as display_audio
import soundfile
import torch
from IPython import display
from matplotlib import pyplot as plt
from torch.utils.data.dataloader import default_collate
from torchvision.utils import make_grid
from tqdm import tqdm
from feature_extraction.demo_utils import (ExtractResNet50, check_video_for_audio,
extract_melspectrogram, load_model,
show_grid, trim_video)
from sample_visualization import (all_attention_to_st, get_class_preditions,
last_attention_to_st, spec_to_audio_to_st,
tensor_to_plt)
from specvqgan.data.vggsound import CropImage
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the pretrained model identified by `model_name` from `log_dir`.
model_name = '2021-07-30T21-34-25_vggsound_transformer'
log_dir = './logs'
# Switch into the cloned repository first: `log_dir` is relative, so from here on
# it (and every other relative path in this script) resolves inside ./SpecVQGAN/.
# NOTE(review): './SpecVQGAN' was appended to sys.path as a relative entry before
# this chdir - verify that imports performed after this point still resolve.
os.chdir("./SpecVQGAN/")
# load_model returns four objects, unpacked here as module-level globals that the
# functions below rely on (presumably config, sampler, vocoder and classifier,
# going by the names - confirm against feature_extraction.demo_utils).
config, sampler, melgan, melception = load_model(model_name, log_dir, device)
# %%
def extract_thumbnails(video_path):
    """Turn a video file into the conditioning input expected by the sampler.

    The video is passed through ``trim_video`` (start second 0,
    ``trim_duration=10``), visual features are extracted from the result with
    ``ExtractResNet50`` at 21.5 fps, and the collated features are handed to
    ``sampler.get_input``.

    Args:
        video_path: path of the input video.

    Returns:
        A ``(c, video_path)`` tuple: the conditioning returned by
        ``sampler.get_input`` and the path returned by ``trim_video``.
    """
    clip_start_sec = 0  # e.g. use 95 to start at 01:35
    clip_path = trim_video(video_path, clip_start_sec, trim_duration=10)

    fps = 21.5
    extractor = ExtractResNet50(fps, config.data.params, device)
    features, frames = extractor(clip_path)

    # Preview the frames the features come from, unless the config asks for
    # the features to be replaced with random ones.
    if not config.data.params.replace_feats_with_random:
        show_grid(make_grid(frames)).show()

    # Collate into a batch and move the features to the target device.
    collated = default_collate([features])
    collated['feature'] = collated['feature'].to(device)
    return sampler.get_input(sampler.cond_stage_key, collated), clip_path
# %%
import numpy as np
def _decode_and_show(z_pred_indices, sampling_shape, attention, full_att_mat,
                     cpatch, quant_c, local_pos_in_flat, patch_2d_shape):
    """Decode the current code grid to an image and display it with its attention.

    Shared by the in-loop progress preview and the final visualisation in
    ``generate_audio``. Returns the decoded image so the caller can reuse it.
    """
    z_pred_img = sampler.decode_to_img(z_pred_indices, sampling_shape)
    # flipping the spectrogram just for illustration purposes (low freqs to bottom, high - top)
    z_pred_img_st = tensor_to_plt(z_pred_img, flip_dims=(2,))
    display.clear_output(wait=True)
    display.display(z_pred_img_st)

    # One plt.close() per plot shown (presumably each helper opens its own
    # matplotlib figure - confirm against sample_visualization).
    if full_att_mat:
        att_plot = all_attention_to_st(attention, placeholders=None, scale_by_prior=True)
        display.display(att_plot)
        plt.close()
    else:
        c_att_plot, z_att_plot = last_attention_to_st(
            attention, local_pos_in_flat, cpatch.shape[-1], sampler.first_stage_permuter,
            sampler.cond_stage_permuter, quant_c.shape, patch_2d_shape,
            placeholders=None, flip_c_dims=None, flip_z_dims=(2,))
        display.display(c_att_plot)
        display.display(z_att_plot)
        plt.close()
        plt.close()
    plt.close()
    return z_pred_img


def generate_audio(video_path, temperature = 1.0):
    """Generate a soundtrack for a video and mux it into a new mp4 file.

    The video is converted to a conditioning input by ``extract_thumbnails``;
    codebook indices are then sampled one position per step with
    ``sampler.transformer``, decoded to a spectrogram image, turned into a
    waveform by the vocoder, written as a stereo wav file and finally combined
    with the (trimmed) video by ffmpeg.

    Args:
        video_path: path of the input video.
        temperature: divisor applied to the logits before top-k filtering and
            softmax; must be non-zero.

    Returns:
        Path of the muxed video, ``<log_dir>/<video stem>_audio.mp4``.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.

    Side effects:
        Writes ``<log_dir>/<video stem>.wav`` and the returned mp4 (overwriting
        existing files of the same name) and prints timing/debug information.
    """
    import subprocess  # local import: only needed for the final ffmpeg call

    # Define Sampling Parameters
    W_scale = 1
    top_x = sampler.first_stage_model.quantize.n_e // 2
    update_every = 0  # use > 0 value, e.g. 15, to see the progress of generation (slows down the sampling speed)
    full_att_mat = True

    c, video_path = extract_thumbnails(video_path)

    # Start sampling
    with torch.no_grad():
        start_t = time.time()

        quant_c, c_indices = sampler.encode_to_c(c)

        patch_size_i = 5
        patch_size_j = 53
        B, D, hr_h, hr_w = sampling_shape = (1, 256, 5, 53*W_scale)
        patch_2d_shape = (B, D, patch_size_i, patch_size_j)

        z_pred_indices = torch.zeros((B, hr_h*hr_w), dtype=torch.long, device=device)

        # the conditioning is not cropped: the whole c is used at every step
        cpatch = c_indices

        # Sampling always starts from an empty grid. The original alternative
        # branch (priming the first half) referenced an undefined `z_indices`
        # and could never run, so it has been dropped.
        start_step = 0

        pbar = tqdm(range(start_step, hr_w * hr_h), desc='Sampling Codebook Indices')
        for step in pbar:
            # the flat step index walks the (hr_h, hr_w) grid column by column
            i = step % hr_h
            j = step // hr_h

            # window of size (patch_size_i, patch_size_j) around (i, j), clamped to the grid
            i_start = min(max(0, i - (patch_size_i // 2)), hr_h - patch_size_i)
            j_start = min(max(0, j - (patch_size_j // 2)), hr_w - patch_size_j)
            i_end = i_start + patch_size_i
            j_end = j_start + patch_size_j

            # position of (i, j) inside the window
            local_i = i - i_start
            local_j = j - j_start

            pbar.set_postfix(
                Step=f'({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})'
            )

            patch = z_pred_indices \
                .reshape(B, hr_w, hr_h) \
                .permute(0, 2, 1)[:, i_start:i_end, j_start:j_end].permute(0, 2, 1) \
                .reshape(B, patch_size_i * patch_size_j)

            logits, _, attention = sampler.transformer(patch[:, :-1], cpatch)
            # remove conditioning
            logits = logits[:, -patch_size_j*patch_size_i:, :]

            local_pos_in_flat = local_j * patch_size_i + local_i
            logits = logits[:, local_pos_in_flat, :]

            logits = logits / temperature
            logits = sampler.top_k_logits(logits, top_x)

            # apply softmax to convert to probabilities
            probs = torch.nn.functional.softmax(logits, dim=-1)

            # sample from the distribution
            ix = torch.multinomial(probs, num_samples=1)
            z_pred_indices[:, j * hr_h + i] = ix

            if update_every > 0 and step % update_every == 0:
                _decode_and_show(z_pred_indices, sampling_shape, attention, full_att_mat,
                                 cpatch, quant_c, local_pos_in_flat, patch_2d_shape)

        # showing the final image
        z_pred_img = _decode_and_show(z_pred_indices, sampling_shape, attention, full_att_mat,
                                      cpatch, quant_c, local_pos_in_flat, patch_2d_shape)

        print(f'Sampling Time: {time.time() - start_t:3.2f} seconds')
        waves = spec_to_audio_to_st(z_pred_img, config.data.params.spec_dir_path,
                                    config.data.params.sample_rate, show_griffin_lim=False,
                                    vocoder=melgan, show_in_st=False)
        print(f'Sampling Time (with vocoder): {time.time() - start_t:3.2f} seconds')
        print(f'Generated: {len(waves["vocoder"]) / config.data.params.sample_rate:.2f} seconds')

        # Melception opinion on the class distribution of the generated sample
        topk_preds = get_class_preditions(z_pred_img, melception)
        print(topk_preds)

    video_stem = Path(video_path).stem

    # Save the vocoder output; the waveform is duplicated into two channels
    # (assumes waves['vocoder'] is a 1-D array - TODO confirm).
    audio_path = os.path.join(log_dir, video_stem + '.wav')
    audio = waves['vocoder']
    audio = np.repeat([audio], 2, axis=0).T
    print(audio.shape)
    soundfile.write(audio_path, audio, config.data.params.sample_rate, 'PCM_24')
    print(f'The sample has been saved @ {audio_path}')

    video_out_path = os.path.join(log_dir, video_stem + '_audio.mp4')
    print(video_path, audio_path, video_out_path)
    # Argument list instead of a shell string: paths with spaces or shell
    # metacharacters are passed through verbatim. `-y` overwrites the output of
    # a previous run; check=True surfaces ffmpeg failures instead of returning
    # a path to a missing or stale file.
    subprocess.run(
        [
            "ffmpeg", "-y",
            "-i", str(video_path),
            "-i", audio_path,
            "-map", "0:v", "-map", "1:a",
            "-c:v", "copy", "-shortest",
            video_out_path,
        ],
        check=True,
    )
    return video_out_path
# %%
#generate_audio("../kiss.avi")
#%%
import gradio as gr
# Web UI: one video input, fed to generate_audio, whose returned file path is
# shown as a playable video. launch() starts serving the interface.
iface = gr.Interface(
    fn=generate_audio,
    inputs="video",
    outputs=["playable_video"],
    description="Generate audio based on the video input",
)
iface.launch()
# %%