schnik committed
Commit 3c3b47c
1 Parent(s): 9562105

Upload 3 files


Add Inference Files

Files changed (3)
  1. app.py +78 -0
  2. inference.py +117 -0
  3. inference_utils.py +283 -0
app.py ADDED
@@ -0,0 +1,78 @@
import gradio as gr
import os
import sys
sys.path.insert(1, '..')
import inference
import torch

# run on GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"


def generate_background_music(video_path, dataset, use_peft, musicgen_size):
    print(f"Start generating background music for {video_path} with model \"{'peft' if use_peft else 'audiocraft'}_{dataset}_{musicgen_size}\"")

    new_video_path = inference.generate_background_music(
        video_path=video_path,
        dataset=dataset,
        musicgen_size=musicgen_size,
        use_stereo=True,
        use_peft=use_peft,
        musicgen_temperature=1.0,
        musicgen_guidance_scale=3.0,
        top_k_sampling=250,
        device=device
    )
    return gr.Video(new_video_path)


interface = gr.Interface(fn=generate_background_music,
                         inputs=[
                             gr.Video(
                                 label="video input",
                                 min_length=5,
                                 max_length=20,
                                 sources=['upload'],
                                 show_download_button=True,
                                 include_audio=True
                             ),
                             gr.Radio(["nature", "symmv"],
                                      label="Video Encoder Version",
                                      value="nature",
                                      info="Choose one of the available Video Encoders."),
                             gr.Radio([False, True],
                                      label="Use MusicGen Audio Decoder Model trained with PEFT",
                                      value=False,
                                      info="If set to 'True', the MusicGen Audio Decoder models trained with LoRA "
                                           "(Low-Rank Adaptation) are used. If set to 'False', the original "
                                           "MusicGen models are used."),
                             gr.Radio(["small", "medium", "large"],
                                      label="MusicGen Audio Decoder Size",
                                      value="small",
                                      info="Choose the size of the MusicGen audio decoder."),
                         ],
                         outputs=[gr.Video(label="video output")],
                         examples=[
                             [os.path.abspath("../../../videos/originals/n_1.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/n_2.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/n_3.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/n_4.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/n_5.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/n_6.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/n_7.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/n_8.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/s_1.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/s_2.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/s_3.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/s_4.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/s_5.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/s_6.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/s_7.mp4"), "nature", True, "small"],
                             [os.path.abspath("../../../videos/originals/s_8.mp4"), "nature", True, "small"],
                         ],
                         cache_examples=False
                         )

if __name__ == "__main__":
    interface.launch(
        share=False
    )
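For reference, the Gradio callback above just forwards to inference.generate_background_music. A minimal sketch of calling that entry point directly, without the UI; the video path is a placeholder taken from the examples list, and the settings mirror the defaults hard-coded in the app:

# Minimal sketch: call the inference entry point directly, bypassing the Gradio UI.
# The video path is a placeholder; any 5-20 second .mp4 clip works.
import torch
import inference

device = "cuda" if torch.cuda.is_available() else "cpu"

result_path = inference.generate_background_music(
    video_path="../../../videos/originals/n_1.mp4",
    dataset="nature",             # or "symmv"
    musicgen_size="small",        # "small", "medium" or "large"
    use_stereo=True,
    use_peft=True,                # use the LoRA-adapted MusicGen decoder
    musicgen_temperature=1.0,
    musicgen_guidance_scale=3.0,
    top_k_sampling=250,
    device=device,
)
print("Rendered video written to:", result_path)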
inference.py ADDED
@@ -0,0 +1,117 @@
import os
import re
import time

import torch
import torchaudio
import moviepy.editor as mpe
from omegaconf import OmegaConf
from peft import PeftConfig, get_peft_model
from audiocraft.models import MusicGen
from moviepy.editor import AudioFileClip
from code.inference.inference_utils import *

# matches the file name, i.e. everything after the last "/"
re_file_name = re.compile('([^/]+$)')


def generate_background_music(video_path: str,
                              dataset: str,
                              musicgen_size: str,
                              use_stereo: bool,
                              use_peft: bool,
                              device: str,
                              musicgen_temperature: float = 1.0,
                              musicgen_guidance_scale: float = 3.0,
                              top_k_sampling: int = 250) -> str:
    start = time.time()

    # resolve the checkpoint directory, e.g. "../training/models_peft/nature_small"
    model_path = "../training/"
    model_path += "models_peft" if use_peft else "models_audiocraft"
    model_path += f"/{dataset}_{musicgen_size}"

    conf = OmegaConf.load(model_path + '/configuration.yml')
    use_sampling = top_k_sampling > 0
    video = mpe.VideoFileClip(video_path)

    musicgen_model_id = "facebook/musicgen-" + ("stereo-" if use_stereo else "")
    musicgen_model_id += musicgen_size

    result_dir = "./results"
    os.makedirs(result_dir, exist_ok=True)

    # the video encoder output has to match the T5 embedding dimension of the chosen MusicGen size
    encoder_output_dimension = None
    if "small" in conf.musicgen_model_id:
        encoder_output_dimension = 1024
    elif "medium" in conf.musicgen_model_id:
        encoder_output_dimension = 1536
    elif "large" in conf.musicgen_model_id:
        encoder_output_dimension = 2048
    assert encoder_output_dimension, f"Video Encoder output dimension could not be determined from {conf.musicgen_model_id}"

    musicgen_model = MusicGen.get_pretrained(musicgen_model_id)
    musicgen_model.lm.to(device)
    musicgen_model.compression_model.to(device)
    if use_peft:
        # wrap the language model with the trained LoRA adapter
        peft_path = model_path + "/musicgen_peft_final"
        peft_config = PeftConfig.from_pretrained(peft_path)
        musicgen_model.lm = get_peft_model(musicgen_model.lm, peft_config)
        musicgen_model.lm.load_adapter(peft_path, "default")

    print("MusicGen Model loaded.")

    video_to_t5 = VideoToT5(
        video_extraction_framerate=conf.video_extraction_framerate,
        encoder_input_dimension=conf.encoder_input_dimension,
        encoder_output_dimension=encoder_output_dimension,
        encoder_heads=conf.encoder_heads,
        encoder_dim_feedforward=conf.encoder_dim_feedforward,
        encoder_layers=conf.encoder_layers,
        device=device
    )

    video_to_t5.load_state_dict(torch.load(model_path + "/lm_final.pt", map_location=device))
    print("Video Encoder Model loaded.")

    print("Starting Video Feature Extraction.")
    video_embedding_t5 = video_to_t5(video_paths=[video_path])

    condition_tensors = create_condition_tensors(
        video_embeddings=video_embedding_t5,
        batch_size=1,
        video_extraction_framerate=video_to_t5.video_extraction_framerate,
        device=device
    )

    musicgen_model.generation_params = {
        'max_gen_len': int(video.duration * musicgen_model.frame_rate),
        'use_sampling': use_sampling,
        'temp': musicgen_temperature,
        'cfg_coef': musicgen_guidance_scale,
        'two_step_cfg': False,
    }
    if use_sampling:
        musicgen_model.generation_params['top_k'] = top_k_sampling

    print("Starting Audio Generation.")
    prompt_tokens = None
    with torch.no_grad():
        with musicgen_model.autocast:
            gen_tokens = musicgen_model.lm.generate(prompt_tokens, [], condition_tensors, callback=None,
                                                    **musicgen_model.generation_params)
            gen_audio = musicgen_model.compression_model.decode(gen_tokens)

    end = time.time()
    print("Elapsed time for generation: " + str(end - start))

    _, video_file_name = os.path.split(video_path)
    video_file_name = video_file_name[:-4]  # remove the .mp4 extension

    re_result = re_file_name.search(video_file_name)  # get the video file name
    result_path = f"{'peft' if use_peft else 'audiocraft'}_{dataset}_{musicgen_size}_{re_result.group(1)}"
    audio_result_path = f"{result_dir}/tmp.wav"
    video_result_path = f"{result_dir}/{result_path}_video.mp4"

    gen_audio = torch.squeeze(gen_audio.detach().cpu())  # remove the mini-batch dimension, move to CPU for saving
    sample_rate = musicgen_model.sample_rate
    torchaudio.save(audio_result_path, gen_audio, sample_rate)
    audio_file_clip = AudioFileClip(audio_result_path)
    video.audio = audio_file_clip

    print("Rendering Video.")
    video.write_videofile(video_result_path)

    return video_result_path
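generate_background_music() expects a configuration.yml next to each checkpoint and reads the fields used above (musicgen_model_id, video_extraction_framerate, encoder_input_dimension, encoder_heads, encoder_dim_feedforward, encoder_layers). A minimal sketch of writing such a file with OmegaConf; the concrete values and the output path are illustrative placeholders, not the settings shipped with the trained models:

# Sketch of the configuration fields generate_background_music() reads.
# All values and the target path below are placeholders.
from omegaconf import OmegaConf

conf = OmegaConf.create({
    "musicgen_model_id": "facebook/musicgen-stereo-small",  # used to infer the video encoder output dimension
    "video_extraction_framerate": 1,   # frames per second fed to the video feature extractor
    "encoder_input_dimension": 2048,   # ResNet-50 feature size
    "encoder_heads": 8,                # attention heads per transformer encoder layer
    "encoder_dim_feedforward": 4096,   # feed-forward width of each encoder layer
    "encoder_layers": 6,               # number of transformer encoder layers
})
OmegaConf.save(conf, "configuration.yml")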
inference_utils.py ADDED
@@ -0,0 +1,283 @@
from torch.utils.data import Dataset
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import torchaudio
import os
import logging
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
from time import strftime
import math
import numpy as np
import moviepy.editor as mpe


class VideoDataset(Dataset):
    """Dataset that lists all .mp4 files in a directory and returns their paths."""

    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.data_map = []

        dir_map = os.listdir(data_dir)
        for d in dir_map:
            name, extension = os.path.splitext(d)
            if extension == ".mp4":
                self.data_map.append({"video": os.path.join(data_dir, d)})

    def __len__(self):
        return len(self.data_map)

    def __getitem__(self, idx):
        return self.data_map[idx]["video"]


# input: video paths, output: T5-style conditioning embeddings
class VideoToT5(nn.Module):
    def __init__(self,
                 device: str,
                 video_extraction_framerate: int,
                 encoder_input_dimension: int,
                 encoder_output_dimension: int,
                 encoder_heads: int,
                 encoder_dim_feedforward: int,
                 encoder_layers: int
                 ):
        super().__init__()
        self.video_extraction_framerate = video_extraction_framerate
        self.video_feature_extractor = VideoFeatureExtractor(video_extraction_framerate=video_extraction_framerate,
                                                             device=device)
        self.video_encoder = VideoEncoder(
            device,
            encoder_input_dimension,
            encoder_output_dimension,
            encoder_heads,
            encoder_dim_feedforward,
            encoder_layers
        )

    def forward(self, video_paths: list):
        image_embeddings = []
        for video_path in video_paths:
            video = mpe.VideoFileClip(video_path)
            video_embedding = self.video_feature_extractor(video)
            image_embeddings.append(video_embedding)
        # resulting shape: [batch_size, video_extraction_framerate * 30, resnet_output_dimension]
        video_embedding = torch.stack(image_embeddings)

        # averaging all frame embeddings into one video embedding is not used, it gives worse results:
        # video_embeddings = torch.mean(video_embeddings, 0, True)

        # T5-style output shape: [batch_size, num_tokens, t5_embedding_size]
        t5_embeddings = self.video_encoder(video_embedding)
        return t5_embeddings


class VideoEncoder(nn.Module):
    def __init__(self,
                 device: str,
                 encoder_input_dimension: int,
                 encoder_output_dimension: int,
                 encoder_heads: int,
                 encoder_dim_feedforward: int,
                 encoder_layers: int
                 ):
        super().__init__()
        self.device = device
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=encoder_input_dimension,
                nhead=encoder_heads,
                dim_feedforward=encoder_dim_feedforward
            ),
            num_layers=encoder_layers,
        ).to(device)

        # linear layer to match the T5 embedding dimension
        self.linear = nn.Linear(
            in_features=encoder_input_dimension,
            out_features=encoder_output_dimension
        ).to(device)

    def forward(self, x):
        assert x.dim() == 3
        x = torch.transpose(x, 0, 1)  # encoder expects [sequence_length, batch_size, embedding_dimension]
        x = self.encoder(x)           # transformer encoder forward pass
        x = self.linear(x)            # project to the T5 embedding dimension
        x = torch.transpose(x, 0, 1)  # back to [batch_size, sequence_length, embedding_dimension]
        return x


class VideoFeatureExtractor(nn.Module):
    def __init__(self,
                 device: str,
                 video_extraction_framerate: int = 1,
                 resnet_output_dimension: int = 2048):
        super().__init__()
        self.device = device

        # a ResNet-50 trained on ImageNet, with the final classification layer removed
        self.resnet = resnet50(weights="IMAGENET1K_V2").eval()
        self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1])).to(device)
        self.resnet_preprocessor = ResNet50_Weights.DEFAULT.transforms().to(device)
        self.video_extraction_framerate = video_extraction_framerate  # fps at which the video is processed
        self.positional_encoder = PositionalEncoding(resnet_output_dimension).to(device)

    def forward(self, video: mpe.VideoFileClip):
        embeddings = []
        for i in range(0, 30 * self.video_extraction_framerate):
            frame = video.get_frame(i)                # get the frame as a numpy array
            frame = Image.fromarray(frame)            # create a PIL image from the numpy array
            frame = self.resnet_preprocessor(frame)   # preprocess the image
            frame = frame.to(self.device)
            frame = frame.unsqueeze(0)                # add a batch dimension
            embedding = self.resnet(frame).squeeze()  # ResNet forward pass
            embeddings.append(embedding)              # collect the frame embeddings

        embeddings = torch.stack(embeddings)  # concatenate all frame embeddings into one video embedding
        embeddings = embeddings.unsqueeze(1)
        embeddings = self.positional_encoder(embeddings)  # apply positional encoding with a sequence length of 30
        embeddings = embeddings.squeeze()
        return embeddings


# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_length: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # the positional encoding buffer is built for a fixed sequence length of 30 frames
        position = torch.arange(30).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(30, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)


def freeze_model(model: nn.Module):
    for param in model.parameters():
        param.requires_grad = False
    model.eval()


def split_dataset_randomly(dataset, validation_split: float, test_split: float, seed: int = None):
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    datapoints_validation = int(np.floor(validation_split * dataset_size))
    datapoints_testing = int(np.floor(test_split * dataset_size))

    if seed is not None:
        np.random.seed(seed)

    np.random.shuffle(indices)  # in-place operation
    testing = indices[:datapoints_testing]
    validation = indices[datapoints_testing:datapoints_testing + datapoints_validation]
    training = indices[datapoints_testing + datapoints_validation:]

    assert len(validation) == datapoints_validation, "Validation set length incorrect"
    assert len(testing) == datapoints_testing, "Testing set length incorrect"
    assert len(training) == dataset_size - (datapoints_validation + datapoints_testing), "Training set length incorrect"
    assert not any([item in training for item in validation]), "Training and Validation overlap"
    assert not any([item in training for item in testing]), "Training and Testing overlap"
    assert not any([item in validation for item in testing]), "Validation and Testing overlap"

    return training, validation, testing


# private function from audiocraft.solvers.musicgen => _compute_cross_entropy
def compute_cross_entropy(logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor):
    """Compute cross entropy between multi-codebook targets and model's logits.
    The cross entropy is computed per codebook to provide codebook-level cross entropy.
    Valid timesteps for each of the codebooks are pulled from the mask, where invalid
    timesteps are set to 0.

    Args:
        logits (torch.Tensor): Model's logits of shape [B, K, T, card].
        targets (torch.Tensor): Target codes, of shape [B, K, T].
        mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
    Returns:
        ce (torch.Tensor): Cross entropy averaged over the codebooks
        ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
    """
    B, K, T = targets.shape
    assert logits.shape[:-1] == targets.shape
    assert mask.shape == targets.shape
    ce = torch.zeros([], device=targets.device)
    ce_per_codebook = []
    for k in range(K):
        logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
        targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
        mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
        ce_targets = targets_k[mask_k]
        ce_logits = logits_k[mask_k]
        q_ce = F.cross_entropy(ce_logits, ce_targets)
        ce += q_ce
        ce_per_codebook.append(q_ce.detach())
    # average cross entropy across codebooks
    ce = ce / K
    return ce, ce_per_codebook


def generate_audio_codes(audio_paths: list,
                         audiocraft_compression_model: torch.nn.Module,
                         device: str) -> torch.Tensor:
    audio_duration = 30
    encodec_sample_rate = audiocraft_compression_model.sample_rate

    torch_audios = []
    for audio_path in audio_paths:
        wav, original_sample_rate = torchaudio.load(audio_path)  # load audio from file
        wav = torchaudio.functional.resample(wav, original_sample_rate,
                                             encodec_sample_rate)  # resample audio to the model sample rate
        wav = wav[:, :encodec_sample_rate * audio_duration]  # enforce an exact audio length of 30 seconds

        assert len(wav.shape) == 2, "audio data is not of shape [channels, duration]"
        assert wav.shape[0] == 2, "audio data should be stereo, but does not have 2 channels"

        torch_audios.append(wav)

    torch_audios = torch.stack(torch_audios)
    torch_audios = torch_audios.to(device)

    with torch.no_grad():
        gen_audio = audiocraft_compression_model.encode(torch_audios)

    codes, scale = gen_audio
    assert scale is None

    return codes


def create_condition_tensors(
        video_embeddings: torch.Tensor,
        batch_size: int,
        video_extraction_framerate: int,
        device: str
):
    # mask that mimics the attention mask of a T5 text conditioner
    mask = torch.ones((batch_size, video_extraction_framerate * 30), dtype=torch.int).to(device)

    condition_tensors = {
        'description': (video_embeddings, mask)
    }
    return condition_tensors


def get_current_timestamp():
    return strftime("%Y_%m_%d___%H_%M_%S")


def configure_logging(output_dir: str, filename: str, log_level):
    # create the logs folder if it does not exist yet
    os.makedirs(output_dir, exist_ok=True)
    level = getattr(logging, log_level)
    file_path = output_dir + "/" + filename
    logging.basicConfig(filename=file_path, encoding='utf-8', level=level)
    logger = logging.getLogger()
    # only add a StreamHandler if it is not present yet
    if len(logger.handlers) <= 1:
        logger.addHandler(logging.StreamHandler())
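As a quick sanity check of the shapes documented in compute_cross_entropy, a small sketch with random dummy tensors; the dimensions are arbitrary and the import assumes this module is importable as inference_utils:

# Shape sanity check for compute_cross_entropy using random dummy data.
# B sequences, K codebooks, T timesteps, `card` codebook entries (arbitrary numbers).
import torch
from inference_utils import compute_cross_entropy

B, K, T, card = 2, 4, 10, 2048
logits = torch.randn(B, K, T, card)           # model logits [B, K, T, card]
targets = torch.randint(0, card, (B, K, T))   # target codes [B, K, T]
mask = torch.ones(B, K, T, dtype=torch.bool)  # all timesteps marked valid

ce, ce_per_codebook = compute_cross_entropy(logits, targets, mask)
print(ce.item(), [c.item() for c in ce_per_codebook])  # averaged loss plus K per-codebook losses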