naotokui committed on
Commit 06f2f15
1 Parent(s): 57ede1f

added: basic commands

Files changed (2)
  1. app.py +219 -4
  2. requirements.txt +0 -0
app.py CHANGED
@@ -1,8 +1,223 @@
- import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+ #%%
+ import os
+ os.system("git clone https://github.com/v-iashin/SpecVQGAN")
+ os.system("pip install pytorch-lightning==1.2.10 omegaconf==2.0.6 streamlit==0.80 matplotlib==3.4.1 albumentations==0.5.2 SoundFile torch torchvision librosa gdown")
+
+
+ # %%
+
+ import sys
+ sys.path.append('./SpecVQGAN')
+ import time
+ from pathlib import Path
+
+ import IPython.display as display_audio
+ import soundfile
+ import torch
+ from IPython import display
+ from matplotlib import pyplot as plt
+ from torch.utils.data.dataloader import default_collate
+ from torchvision.utils import make_grid
+ from tqdm import tqdm
+
+ from feature_extraction.demo_utils import (ExtractResNet50, check_video_for_audio,
+                                            extract_melspectrogram, load_model,
+                                            show_grid, trim_video)
+ from sample_visualization import (all_attention_to_st, get_class_preditions,
+                                   last_attention_to_st, spec_to_audio_to_st,
+                                   tensor_to_plt)
+ from specvqgan.data.vggsound import CropImage
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # load model
+ model_name = '2021-07-30T21-34-25_vggsound_transformer'
+ log_dir = './logs'
+ os.chdir("./SpecVQGAN/")
+ config, sampler, melgan, melception = load_model(model_name, log_dir, device)
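+ # sampler: conditional transformer over spectrogram codebook indices, melgan: neural vocoder,
+ # melception: audio classifier used below to report top-k class predictions for the generated sample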
+ # %%
+
+ def extract_thumbnails(video_path):
+     # Trim the video
+     start_sec = 0  # to start with 01:35 use 95 seconds
+     video_path = trim_video(video_path, start_sec, trim_duration=10)
+
+     # Extract Features
+     extraction_fps = 21.5
+     feature_extractor = ExtractResNet50(extraction_fps, config.data.params, device)
+     visual_features, resampled_frames = feature_extractor(video_path)
+
+     # Show the selected frames to extract features for
+     if not config.data.params.replace_feats_with_random:
+         fig = show_grid(make_grid(resampled_frames))
+         fig.show()
+
+     # Prepare Input
+     batch = default_collate([visual_features])
+     batch['feature'] = batch['feature'].to(device)
+     c = sampler.get_input(sampler.cond_stage_key, batch)
+     return c, video_path
+
+ # %%
+ import numpy as np
+
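+ # generate_audio: given a video, extracts visual features, autoregressively samples a spectrogram
+ # with the transformer, vocodes it with MelGAN, and muxes the result back onto the video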
+ def generate_audio(video_path, temperature = 1.0):
+     # Define Sampling Parameters
+     W_scale = 1
+     mode = 'full'
+     top_x = sampler.first_stage_model.quantize.n_e // 2
+     update_every = 0  # use > 0 value, e.g. 15, to see the progress of generation (slows down the sampling speed)
+     full_att_mat = True
+
+     c, video_path = extract_thumbnails(video_path)
+
+     # Start sampling
+     with torch.no_grad():
+         start_t = time.time()
+
+         quant_c, c_indices = sampler.encode_to_c(c)
+         # crec = sampler.cond_stage_model.decode(quant_c)
+
+         patch_size_i = 5
+         patch_size_j = 53
+
+         B, D, hr_h, hr_w = sampling_shape = (1, 256, 5, 53*W_scale)
+
+         z_pred_indices = torch.zeros((B, hr_h*hr_w)).long().to(device)
+
+         if mode == 'full':
+             start_step = 0
+         else:
+             start_step = (patch_size_j // 2) * patch_size_i
+             z_pred_indices[:, :start_step] = z_indices[:, :start_step]  # NOTE: z_indices is undefined in this script; dead code since mode == 'full'
+
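+         # Sliding-window sampling: the codebook grid (5 rows x 53 columns for the 10-second clip) is
+         # filled column by column; each position is predicted from a local patch around it, conditioned
+         # on the visual features, with temperature scaling and top-k filtering of the logits.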
+         pbar = tqdm(range(start_step, hr_w * hr_h), desc='Sampling Codebook Indices')
+         for step in pbar:
+             i = step % hr_h
+             j = step // hr_h

+             i_start = min(max(0, i - (patch_size_i // 2)), hr_h - patch_size_i)
+             j_start = min(max(0, j - (patch_size_j // 2)), hr_w - patch_size_j)
+             i_end = i_start + patch_size_i
+             j_end = j_start + patch_size_j
+
+             local_i = i - i_start
+             local_j = j - j_start
+
+             patch_2d_shape = (B, D, patch_size_i, patch_size_j)
+
+             pbar.set_postfix(
+                 Step=f'({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})'
+             )
+
+             patch = z_pred_indices \
+                 .reshape(B, hr_w, hr_h) \
+                 .permute(0, 2, 1)[:, i_start:i_end, j_start:j_end].permute(0, 2, 1) \
+                 .reshape(B, patch_size_i * patch_size_j)
+
+             # assuming we don't crop the conditioning and just use the whole c, if not desired uncomment the above
+             cpatch = c_indices
+             logits, _, attention = sampler.transformer(patch[:, :-1], cpatch)
+             # remove conditioning
+             logits = logits[:, -patch_size_j*patch_size_i:, :]
+
+             local_pos_in_flat = local_j * patch_size_i + local_i
+             logits = logits[:, local_pos_in_flat, :]
+
+             logits = logits / temperature
+             logits = sampler.top_k_logits(logits, top_x)
+
+             # apply softmax to convert to probabilities
+             probs = torch.nn.functional.softmax(logits, dim=-1)
+
+             # sample from the distribution
+             ix = torch.multinomial(probs, num_samples=1)
+             z_pred_indices[:, j * hr_h + i] = ix
+
+             if update_every > 0 and step % update_every == 0:
+                 z_pred_img = sampler.decode_to_img(z_pred_indices, sampling_shape)
+                 # flipping the spectrogram just for illustration purposes (low freqs to bottom, high to top)
+                 z_pred_img_st = tensor_to_plt(z_pred_img, flip_dims=(2,))
+                 display.clear_output(wait=True)
+                 display.display(z_pred_img_st)
+
+                 if full_att_mat:
+                     att_plot = all_attention_to_st(attention, placeholders=None, scale_by_prior=True)
+                     display.display(att_plot)
+                     plt.close()
+                 else:
+                     quant_z_shape = sampling_shape
+                     c_length = cpatch.shape[-1]
+                     quant_c_shape = quant_c.shape
+                     c_att_plot, z_att_plot = last_attention_to_st(
+                         attention, local_pos_in_flat, c_length, sampler.first_stage_permuter,
+                         sampler.cond_stage_permuter, quant_c_shape, patch_2d_shape,
+                         placeholders=None, flip_c_dims=None, flip_z_dims=(2,))
+                     display.display(c_att_plot)
+                     display.display(z_att_plot)
+                     plt.close()
+                     plt.close()
+                     plt.close()
+
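+         # Decode the completed index grid to a mel spectrogram, then synthesise the waveform with the MelGAN vocoder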
+         # quant_z_shape = sampling_shape
+         z_pred_img = sampler.decode_to_img(z_pred_indices, sampling_shape)
+
+         # showing the final image
+         z_pred_img_st = tensor_to_plt(z_pred_img, flip_dims=(2,))
+         display.clear_output(wait=True)
+         display.display(z_pred_img_st)
+
+         if full_att_mat:
+             att_plot = all_attention_to_st(attention, placeholders=None, scale_by_prior=True)
+             display.display(att_plot)
+             plt.close()
+         else:
+             quant_z_shape = sampling_shape
+             c_length = cpatch.shape[-1]
+             quant_c_shape = quant_c.shape
+             c_att_plot, z_att_plot = last_attention_to_st(
+                 attention, local_pos_in_flat, c_length, sampler.first_stage_permuter,
+                 sampler.cond_stage_permuter, quant_c_shape, patch_2d_shape,
+                 placeholders=None, flip_c_dims=None, flip_z_dims=(2,)
+             )
+             display.display(c_att_plot)
+             display.display(z_att_plot)
+             plt.close()
+             plt.close()
+             plt.close()
+
+         print(f'Sampling Time: {time.time() - start_t:3.2f} seconds')
+         waves = spec_to_audio_to_st(z_pred_img, config.data.params.spec_dir_path,
+                                     config.data.params.sample_rate, show_griffin_lim=False,
+                                     vocoder=melgan, show_in_st=False)
+         print(f'Sampling Time (with vocoder): {time.time() - start_t:3.2f} seconds')
+         print(f'Generated: {len(waves["vocoder"]) / config.data.params.sample_rate:.2f} seconds')
+
+         # Melception opinion on the class distribution of the generated sample
+         topk_preds = get_class_preditions(z_pred_img, melception)
+         print(topk_preds)
+
+     audio_path = os.path.join(log_dir, Path(video_path).stem + '.wav')
+     audio = waves['vocoder']
+     audio = np.repeat([audio], 2, axis=0).T
+     print(audio.shape)
+     soundfile.write(audio_path, audio, config.data.params.sample_rate, 'PCM_24')
+     print(f'The sample has been saved @ {audio_path}')
+
+
+     video_out_path = os.path.join(log_dir, Path(video_path).stem + '_audio.mp4')
+     print(video_path, audio_path, video_out_path)
+     os.system("ffmpeg -i %s -i %s -map 0:v -map 1:a -c:v copy -shortest %s" % (video_path, audio_path, video_out_path))
+
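+     # ffmpeg copies the video stream, takes audio from the generated .wav, and cuts to the shorter of the two;
+     # the muxed video path is returned so the Gradio "playable_video" output can render it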
+     return video_out_path
+     # return config.data.params.sample_rate, audio
+
+ # %%
+ generate_audio("../kiss.avi")
+ #%%
+ import gradio as gr

+ iface = gr.Interface(generate_audio, "video", "playable_video")
  iface.launch()

+ # %%
requirements.txt ADDED
File without changes