from torch.utils.data import Dataset
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import torchaudio
import os
import logging
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
from time import strftime
import math
import numpy as np
import moviepy.editor as mpe
class VideoDataset(Dataset):
def __init__(self, data_dir):
self.data_dir = data_dir
self.data_map = []
dir_map = os.listdir(data_dir)
for d in dir_map:
name, extension = os.path.splitext(d)
if extension == ".mp4":
self.data_map.append({"video": os.path.join(data_dir, d)})
def __len__(self):
return len(self.data_map)
def __getitem__(self, idx):
return self.data_map[idx]["video"]
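# Example usage (sketch, kept as a comment so it does not run on import; the "videos" directory is hypothetical):
#   from torch.utils.data import DataLoader
#   dataset = VideoDataset("videos")
#   loader = DataLoader(dataset, batch_size=4, shuffle=True)
#   for video_paths in loader:
#       ...  # each batch is a list of .mp4 file paths, e.g. for VideoToT5 below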
# input: video paths, output: T5-style embeddings used to condition music generation
class VideoToT5(nn.Module):
def __init__(self,
device: str,
video_extraction_framerate: int,
encoder_input_dimension: int,
encoder_output_dimension: int,
encoder_heads: int,
encoder_dim_feedforward: int,
encoder_layers: int
):
super().__init__()
self.video_extraction_framerate = video_extraction_framerate
self.video_feature_extractor = VideoFeatureExtractor(video_extraction_framerate=video_extraction_framerate,
device=device)
self.video_encoder = VideoEncoder(
device,
encoder_input_dimension,
encoder_output_dimension,
encoder_heads,
encoder_dim_feedforward,
encoder_layers
)
    def forward(self, video_paths: list[str]):
image_embeddings = []
for video_path in video_paths:
video = mpe.VideoFileClip(video_path)
video_embedding = self.video_feature_extractor(video)
image_embeddings.append(video_embedding)
        # resulting shape: [batch_size, 30 * video_extraction_framerate, resnet_output_dimension]
        video_embeddings = torch.stack(image_embeddings)
        # averaging all image embeddings into a single video embedding is not used, it gives worse results:
        # video_embeddings = torch.mean(video_embeddings, 0, True)
        t5_embeddings = self.video_encoder(video_embeddings)  # output shape: [batch_size, num_tokens, t5_embedding_size]
return t5_embeddings
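# Example usage (sketch; the parameter values are illustrative and assume the ResNet50 feature size of 2048
# and a T5 embedding size of 1024, neither of which is fixed by this file; "example.mp4" is a hypothetical
# clip of at least 30 seconds):
#   video_to_t5 = VideoToT5(device="cuda", video_extraction_framerate=1,
#                           encoder_input_dimension=2048, encoder_output_dimension=1024,
#                           encoder_heads=8, encoder_dim_feedforward=2048, encoder_layers=4)
#   t5_embeddings = video_to_t5(["example.mp4"])  # shape: [1, 30 * framerate, 1024]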
class VideoEncoder(nn.Module):
def __init__(self,
device: str,
encoder_input_dimension: int,
encoder_output_dimension: int,
encoder_heads: int,
encoder_dim_feedforward: int,
encoder_layers: int
):
super().__init__()
self.device = device
self.encoder = (nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=encoder_input_dimension,
nhead=encoder_heads,
dim_feedforward=encoder_dim_feedforward
),
num_layers=encoder_layers,
)
).to(device)
# linear layer to match T5 embedding dimension
self.linear = (nn.Linear(
in_features=encoder_input_dimension,
out_features=encoder_output_dimension)
.to(device))
def forward(self, x):
assert x.dim() == 3
x = torch.transpose(x, 0, 1) # encoder expects [sequence_length, batch_size, embedding_dimension]
x = self.encoder(x) # encoder forward pass
x = self.linear(x) # forward pass through the linear layer
x = torch.transpose(x, 0, 1) # shape: [batch_size, sequence_length, embedding_dimension]
return x
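# Shape sanity check (sketch; the dimensions are illustrative, not values confirmed by this file):
#   encoder = VideoEncoder(device="cpu", encoder_input_dimension=2048, encoder_output_dimension=1024,
#                          encoder_heads=8, encoder_dim_feedforward=2048, encoder_layers=2)
#   out = encoder(torch.rand(4, 30, 2048))  # [batch_size, sequence_length, encoder_input_dimension]
#   assert out.shape == (4, 30, 1024)       # projected to the T5 embedding dimension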
class VideoFeatureExtractor(nn.Module):
def __init__(self,
device: str,
video_extraction_framerate: int = 1,
resnet_output_dimension: int = 2048):
super().__init__()
self.device = device
# using a ResNet trained on ImageNet
self.resnet = resnet50(weights="IMAGENET1K_V2").eval()
        self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1])).to(device)  # remove the final classification layer
self.resnet_preprocessor = ResNet50_Weights.DEFAULT.transforms().to(device)
self.video_extraction_framerate = video_extraction_framerate # setting the fps at which the video is processed
self.positional_encoder = PositionalEncoding(resnet_output_dimension).to(device)
def forward(self, video: mpe.VideoFileClip):
embeddings = []
        # sample the first 30 seconds of the video at the configured framerate
        for i in range(30 * self.video_extraction_framerate):
            frame = video.get_frame(i / self.video_extraction_framerate)  # get frame at time t (in seconds) as numpy array
            frame = Image.fromarray(frame)  # create PIL image from numpy array
            frame = self.resnet_preprocessor(frame)  # preprocess image
            frame = frame.to(self.device)
            frame = frame.unsqueeze(0)  # add a batch dimension
            frame_embedding = self.resnet(frame).squeeze()  # ResNet forward pass, shape: [resnet_output_dimension]
            embeddings.append(frame_embedding)  # collect frame embeddings
embeddings = torch.stack(embeddings) # concatenate all frame embeddings into one video embedding
embeddings = embeddings.unsqueeze(1)
        embeddings = self.positional_encoder(embeddings)  # apply positional encoding over the frame sequence (length 30 * framerate)
embeddings = embeddings.squeeze()
return embeddings
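# Example usage (sketch; "example.mp4" is a hypothetical clip of at least 30 seconds):
#   extractor = VideoFeatureExtractor(device="cpu", video_extraction_framerate=1)
#   clip = mpe.VideoFileClip("example.mp4")
#   frame_embeddings = extractor(clip)  # expected shape: [30 * framerate, 2048]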
# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, dropout: float = 0.1, max_length: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_length, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x: Tensor) -> Tensor:
x = x + self.pe[:x.size(0)]
return self.dropout(x)
def freeze_model(model: nn.Module):
for param in model.parameters():
param.requires_grad = False
model.eval()
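# Example usage (sketch): freeze the pretrained feature extractor so only the transformer encoder is trained;
# "video_to_t5" refers to the hypothetical instance from the VideoToT5 sketch above.
#   freeze_model(video_to_t5.video_feature_extractor)
#   trainable_params = [p for p in video_to_t5.parameters() if p.requires_grad]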
def split_dataset_randomly(dataset, validation_split: float, test_split: float, seed: int = None):
dataset_size = len(dataset)
indices = list(range(dataset_size))
datapoints_validation = int(np.floor(validation_split * dataset_size))
datapoints_testing = int(np.floor(test_split * dataset_size))
    if seed is not None:
np.random.seed(seed)
np.random.shuffle(indices) # in-place operation
    testing = indices[:datapoints_testing]
    validation = indices[datapoints_testing:datapoints_testing + datapoints_validation]
    training = indices[datapoints_testing + datapoints_validation:]
    assert len(validation) == datapoints_validation, "Validation set length incorrect"
    assert len(testing) == datapoints_testing, "Testing set length incorrect"
    assert len(training) == dataset_size - (datapoints_testing + datapoints_validation), "Training set length incorrect"
assert not any([item in training for item in validation]), "Training and Validation overlap"
assert not any([item in training for item in testing]), "Training and Testing overlap"
assert not any([item in validation for item in testing]), "Validation and Testing overlap"
return training, validation, testing
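# Example usage (sketch): the returned index lists can drive PyTorch samplers; the split ratios are illustrative.
#   from torch.utils.data import DataLoader, SubsetRandomSampler
#   train_idx, val_idx, test_idx = split_dataset_randomly(dataset, validation_split=0.1, test_split=0.1, seed=42)
#   train_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(train_idx))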
### private function from audiocraft.solver.musicgen.py => _compute_cross_entropy
def compute_cross_entropy(logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor):
"""Compute cross entropy between multi-codebook targets and model's logits.
The cross entropy is computed per codebook to provide codebook-level cross entropy.
Valid timesteps for each of the codebook are pulled from the mask, where invalid
timesteps are set to 0.
Args:
logits (torch.Tensor): Model's logits of shape [B, K, T, card].
targets (torch.Tensor): Target codes, of shape [B, K, T].
mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
Returns:
ce (torch.Tensor): Cross entropy averaged over the codebooks
ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
"""
B, K, T = targets.shape
assert logits.shape[:-1] == targets.shape
assert mask.shape == targets.shape
ce = torch.zeros([], device=targets.device)
ce_per_codebook = []
for k in range(K):
logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T]
mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T]
ce_targets = targets_k[mask_k]
ce_logits = logits_k[mask_k]
q_ce = F.cross_entropy(ce_logits, ce_targets)
ce += q_ce
ce_per_codebook.append(q_ce.detach())
# average cross entropy across codebooks
ce = ce / K
return ce, ce_per_codebook
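# Example usage (sketch with random tensors; K=4 codebooks and a codebook size ("card") of 2048 are illustrative):
#   B, K, T, card = 2, 4, 1500, 2048
#   logits = torch.rand(B, K, T, card)
#   targets = torch.randint(0, card, (B, K, T))
#   mask = torch.ones(B, K, T, dtype=torch.bool)
#   ce, ce_per_codebook = compute_cross_entropy(logits, targets, mask)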
def generate_audio_codes(audio_paths: list[str],
audiocraft_compression_model: torch.nn.Module,
device: str) -> torch.Tensor:
audio_duration = 30
encodec_sample_rate = audiocraft_compression_model.sample_rate
torch_audios = []
for audio_path in audio_paths:
wav, original_sample_rate = torchaudio.load(audio_path) # load audio from file
wav = torchaudio.functional.resample(wav, original_sample_rate,
encodec_sample_rate) # cast audio to model sample rate
        wav = wav[:, :encodec_sample_rate * audio_duration]  # truncate audio to at most 30 seconds
        assert len(wav.shape) == 2, "audio data is not of shape [channels, duration]"
        assert wav.shape[0] == 2, "audio data should be stereo, but does not have 2 channels"
torch_audios.append(wav)
torch_audios = torch.stack(torch_audios)
torch_audios = torch_audios.to(device)
with torch.no_grad():
gen_audio = audiocraft_compression_model.encode(torch_audios)
codes, scale = gen_audio
assert scale is None
return codes
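# Example usage (sketch; assumes audiocraft's pretrained MusicGen exposes its EnCodec compression model; the
# model name is illustrative, and "example.wav" stands for a stereo recording of at least 30 seconds):
#   from audiocraft.models import MusicGen
#   musicgen = MusicGen.get_pretrained("facebook/musicgen-stereo-small", device="cuda")
#   codes = generate_audio_codes(["example.wav"], musicgen.compression_model, "cuda")
#   # codes shape: [batch_size, num_codebooks, num_timesteps]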
def create_condition_tensors(
video_embeddings: torch.Tensor,
batch_size: int,
video_extraction_framerate: int,
device: str
):
    # mask standing in for the T5 conditioning mask: all video timesteps are valid
mask = torch.ones((batch_size, video_extraction_framerate * 30), dtype=torch.int).to(device)
condition_tensors = {
'description': (video_embeddings, mask)
}
return condition_tensors
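# Example usage (sketch): package the video embeddings as the "description" conditioning expected downstream;
# "t5_embeddings" refers to the hypothetical output of the VideoToT5 sketch above.
#   condition_tensors = create_condition_tensors(t5_embeddings, batch_size=t5_embeddings.shape[0],
#                                                video_extraction_framerate=1, device="cuda")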
def get_current_timestamp():
return strftime("%Y_%m_%d___%H_%M_%S")
def configure_logging(output_dir: str, filename: str, log_level):
# create logs folder, if not existing
os.makedirs(output_dir, exist_ok=True)
level = getattr(logging, log_level)
    file_path = os.path.join(output_dir, filename)
logging.basicConfig(filename=file_path, encoding='utf-8', level=level)
logger = logging.getLogger()
# only add a StreamHandler if it is not present yet
if len(logger.handlers) <= 1:
logger.addHandler(logging.StreamHandler())
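# Example usage (sketch; the "logs" directory and log level are illustrative):
#   configure_logging("logs", f"training_{get_current_timestamp()}.log", "INFO")
#   logging.info("logging configured")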