File size: 4,157 Bytes
c68160d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a39e35
c68160d
 
 
 
 
 
6dcf304
c68160d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dcf304
c68160d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import librosa
import numpy as np
import gradio as gr
import soundfile as sf

from moviepy.editor import *


# Fixed pool of 50 scratch WAV paths (/tmp/00.wav ... /tmp/49.wav) together
# with the iterator that hands them out one at a time. Consumers re-create
# the iterator from the list once it is exhausted, so old files get reused.
cache_wav_path = [f'/tmp/{str(slot).zfill(2)}.wav' for slot in range(50)]
wave_path_iter = iter(cache_wav_path)

# Same scheme for the rendered MP4 outputs (/tmp/00.mp4 ... /tmp/49.mp4).
cache_mp4_path = [f'/tmp/{str(slot).zfill(2)}.mp4' for slot in range(50)]
path_iter = iter(cache_mp4_path)

def merge_times(times, times2):
    """Merge two sets of time stamps, letting ``times2`` win on near-collisions.

    Every entry of ``times`` that lies strictly within 0.2 (in the inputs'
    own units) of any entry of ``times2`` is dropped. The surviving entries
    are concatenated with all of ``times2`` and sorted ascending.

    Args:
        times: 1-D array-like of time stamps; entries may be dropped.
        times2: 1-D array-like of time stamps; always kept in full.

    Returns:
        A sorted 1-D ``numpy.ndarray`` containing the merged time stamps.
    """
    times = np.asarray(times)
    times2 = np.asarray(times2)

    # Pairwise distances: rows index ``times2``, columns index ``times``.
    too_close = np.abs(times2[..., None] - times[None]) < 0.2
    ids = np.unique(np.where(too_close)[1])

    # Builtin ``bool`` instead of ``np.bool``: the alias was removed in
    # NumPy 1.24 and raises AttributeError there.
    mask = np.ones_like(times, dtype=bool)
    mask[ids] = False

    return np.sort(np.concatenate([times[mask], times2]))


def beat_interpolator(wave_path, generator, latent_dim, seed, fps=30, batch_size=1, strength=1, max_duration=None, use_peak=False):
    """Render a video whose latent vectors move in sync with the audio's beats.

    The audio is loaded at 24 kHz, beat times (and optionally onset peaks) are
    detected with librosa, and a new random latent target is drawn at every
    detected time. Frames between two consecutive times are linear blends of
    latents, ``generator`` turns the latents into images, and moviepy muxes
    the images with the audio into an MP4 taken from the module's cache pool.

    Args:
        wave_path: Path of the audio file, passed to ``librosa.load``.
        generator: Callable taking a list of latent vectors (1-D arrays of
            length ``latent_dim``) and returning a list of frames; its result
            is appended to the frame list handed to ``ImageSequenceClip``.
        latent_dim: Length of each random latent vector.
        seed: Seed for the ``numpy.random.RandomState`` that draws latents.
        fps: Frames per second of the output; values below 10 are raised to 10.
        batch_size: Number of latents passed to ``generator`` per call;
            values below 1 are raised to 1.
        strength: Weight of the fresh random draw in each new target, clipped
            to ``[0, 1]``.
        max_duration: If not ``None``, the audio is truncated to this many
            seconds and re-written to a cached WAV path which is then used as
            the video's audio track.
        use_peak: If true, onset-strength peaks are merged into the beat
            times via ``merge_times``.

    Returns:
        The path of the written MP4 file.
    """
    global wave_path_iter, path_iter

    fps = max(10, fps)
    batch_size = max(1, int(batch_size))  # 0 would never advance the batching
    strength = np.clip(strength, 0, 1)
    hop_length = 512
    y, sr = librosa.load(wave_path, sr=24000)
    duration = librosa.get_duration(y=y, sr=sr)

    if max_duration is not None:
        # Keep the leading fraction of samples that corresponds to max_duration.
        y_len = y.shape[0]
        y_idx = int(y_len * max_duration / duration)
        y = y[:y_idx]

        # Take the next scratch WAV path, wrapping around once the pool is used up.
        try:
            wave_path = next(wave_path_iter)
        except StopIteration:
            wave_path_iter = iter(cache_wav_path)
            wave_path = next(wave_path_iter)
        sf.write(wave_path, y, sr, subtype='PCM_24')
        y, sr = librosa.load(wave_path, sr=24000)
        duration = librosa.get_duration(y=y, sr=sr)

    # Per-frame loudness: max over frequency bins of the power spectrogram in dB.
    S = np.abs(librosa.stft(y, hop_length=hop_length))
    db = librosa.power_to_db(S**2, ref=np.median).max(0)
    db_mean = np.mean(db)
    db_max = np.max(db)
    db_min = np.min(db)
    db_times = librosa.frames_to_time(np.arange(len(db)), sr=sr, hop_length=hop_length)

    rng = np.random.RandomState(seed)
    onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length, aggregate=np.median)
    _, beats = librosa.beat.beat_track(y=y, sr=sr, onset_envelope=onset_env, hop_length=hop_length, units='time')
    times = np.asarray(beats)
    if use_peak:
        # Keyword arguments: these parameters are keyword-only in librosa >= 0.10.
        peaks = librosa.util.peak_pick(
            onset_env, pre_max=1, post_max=1, pre_avg=1, post_avg=1, delta=0.8, wait=5
        )
        times2 = librosa.frames_to_time(np.arange(len(onset_env)), sr=sr, hop_length=hop_length)[peaks]
        # Keep the peak times; the original overwrote them with the beat
        # times here, which made use_peak a no-op.
        times2 = np.asarray(times2)
        times = merge_times(times, times2)

    # Always include the start and end of the clip, then quantize every time
    # to an even frame index and drop duplicates.
    times = np.concatenate([np.asarray([0.]), times, np.asarray([duration])], 0)
    times = list(np.unique(np.int64(np.floor(times * fps / 2))) * 2)

    latents = []
    time0 = 0
    latent0 = rng.randn(latent_dim)
    for time1 in times:
        latent1 = latent0 * (1 - strength) + rng.randn(latent_dim) * strength

        # Look up the loudness at this frame; the quieter it is relative to
        # the clip's min/mean/max, the closer the target stays to latent0.
        db_cur_index = np.argmin(np.abs(db_times - time1 / fps))
        db_cur = db[db_cur_index]
        if db_cur < db_min + (db_mean - db_min) / 3:
            latent1 = latent0 * 0.8 + latent1 * 0.2
        elif db_cur < db_min + 2 * (db_mean - db_min) / 3:
            latent1 = latent0 * 0.6 + latent1 * 0.4
        elif db_cur < db_mean + (db_max - db_mean) / 3:
            latent1 = latent0 * 0.4 + latent1 * 0.6
        elif db_cur < db_mean + 2 * (db_max - db_mean) / 3:
            latent1 = latent0 * 0.2 + latent1 * 0.8

        # Never emit frames past the end of the audio.
        if time1 > duration * fps:
            time1 = int(duration * fps)
        t1 = time1 - time0

        # Within the segment the frames move linearly from latent0 to the
        # midpoint between latent0 and latent1; the next segment then starts
        # from latent1 itself.
        latent2 = latent0 * 0.5 + latent1 * 0.5
        for j in range(t1):
            alpha = j / t1
            latents.append(latent0 * (1 - alpha) + latent2 * alpha)

        time0 = time1
        latent0 = latent1

    # Run the generator over the latents in chunks of batch_size.
    outs = []
    for ix in range(0, len(latents), batch_size):
        outs += generator(latents[ix:ix + batch_size])

    # Take the next scratch MP4 path, wrapping around once the pool is used up.
    try:
        video_path = next(path_iter)
    except StopIteration:
        path_iter = iter(cache_mp4_path)
        video_path = next(path_iter)

    video = ImageSequenceClip(outs, fps=fps)
    audioclip = AudioFileClip(wave_path)
    try:
        video = video.set_audio(audioclip)
        video.write_videofile(video_path, fps=fps)
    finally:
        # Release the audio reader held by the clip.
        audioclip.close()

    return video_path