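"""Audio-reactive StyleGAN3 visualizer.

Drives a bounded random walk through StyleGAN3's latent space with the
mel-spectrogram power of an input audio file (and its gradient), renders one
generator frame per spectrogram hop, and muxes the frames with the audio into
an MP4.
"""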
import librosa
import numpy as np
import moviepy.editor as mpy
import random
import torch
from tqdm import tqdm
import stylegan3.dnnlib
import stylegan3.legacy


def visualize(audio_file, network, truncation, batch_size, *args, **kwargs):
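    """Render an audio-reactive video for `audio_file` using the generator at `network`.

    audio_file: object exposing a .name attribute with the audio path
                (e.g. a file object or an uploaded-file wrapper).
    network:    path or URL of a StyleGAN3 network pickle.
    truncation: truncation psi passed to the generator.
    batch_size: number of latent vectors rendered per forward pass.

    Returns the path of the written video file.
    """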

    if audio_file:
        print('\nReading audio \n')
        y, sr = librosa.load(audio_file.name)  # resamples to librosa's default 22050 Hz
    else:
        raise ValueError("you must provide an audio file via the --song argument")

    duration = None       # optional cap on the audio duration, in seconds

    frame_length = 512    # spectrogram hop length in samples; one video frame per hop

    # how far each latent unit moves per frame, rescaled to the hop length
    tempo_sensitivity = 0.25
    tempo_sensitivity = tempo_sensitivity * frame_length / 512

    jitter = 0.5          # per-unit sensitivity jitter (see new_jitters below)

    outfile = "output.mp4"

    #set device once, before loading the model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load pre-trained model
    with stylegan3.dnnlib.open_url(network) as f:
        G = stylegan3.legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
        G.eval()

    # Warm-up forward pass (output discarded); on CUDA setups the first call
    # also triggers compilation of StyleGAN3's custom ops
    with torch.no_grad():
        z = torch.randn([1, G.z_dim], device=device)  # latent code
        c = None                                      # class labels (not used; unconditional model)
        G(z, c)                                       # NCHW, float32, dynamic range [-1, +1]

    #create spectrogram
    spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=512, fmax=8000, hop_length=frame_length)

    #get mean power at each time point
    specm = np.mean(spec, axis=0)

    #compute power gradient across time points
    gradm = np.gradient(specm)

    #normalize the gradient so its max is 1
    gradm = gradm / np.max(gradm)

    #set negative gradient time points to zero
    gradm = gradm.clip(min=0)

    #normalize mean power between 0-1
    specm = (specm - np.min(specm)) / np.ptp(specm)

    #initialize first noise vector
    nv1 = torch.randn([G.z_dim], device=device)

    #initialize list of noise vectors
    noise_vectors = [nv1]

    #initialize previous vector (used to track the previous frame)
    nvlast = nv1

    #initialize the direction of noise vector unit updates:
    #negative units move up (+1), positive units move down (-1), toward the origin
    update_dir = np.where(nv1.cpu().numpy() < 0, 1.0, -1.0)

    #initialize noise unit update
    update_last = np.zeros(G.z_dim)

    #get new jitters: randomly drop ~half of the noise units to lower sensitivity
    def new_jitters(jitter):
        jitters = np.zeros(G.z_dim)
        for j in range(G.z_dim):
            if random.uniform(0, 1) < 0.5:
                jitters[j] = 1
            else:
                jitters[j] = 1 - jitter
        return jitters


    #get new update directions: reverse a unit once it drifts past roughly
    #+/- 2*truncation, keeping the walk bounded
    def new_update_dir(nv2, update_dir):
        for ni, n in enumerate(nv2):
            if n >= 2 * truncation - tempo_sensitivity:
                update_dir[ni] = -1
            elif n < -2 * truncation + tempo_sensitivity:
                update_dir[ni] = 1
        return update_dir

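    # Each latent unit performs a bounded random walk: per frame it moves by
    # tempo_sensitivity in its current direction, scaled by the audio's mean
    # power plus its positive gradient, so louder or rising audio moves the
    # latent vector faster; new_update_dir reflects units off the bounds.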
    print('\nGenerating input vectors \n')
    for i in tqdm(range(len(gradm))):

        #update jitter vector every 200 frames by setting ~half of noise vector units to lower sensitivity
        if i % 200 == 0:
            jitters = new_jitters(jitter)

        #get last noise vector
        nv1 = nvlast

        #set noise vector update based on direction, sensitivity, jitter, and combination of overall power and gradient of power
        update = tempo_sensitivity * (gradm[i] + specm[i]) * update_dir * jitters

        #smooth the update with the previous update (to avoid overly sharp frame transitions)
        update = (update + update_last * 3) / 4

        #set last update
        update_last = update

        #update noise vector
        nv2 = nv1.cpu() + torch.from_numpy(update).float()

        #append to noise vectors
        noise_vectors.append(nv2)

        #set last noise vector
        nvlast = nv2

        #update the direction of noise units
        update_dir = new_update_dir(nv2, update_dir)

    noise_vectors = torch.stack([nv.to(device) for nv in noise_vectors])


    print('\n\nGenerating frames \n')
    frames = []
    c = None  # class labels (not used; unconditional model)
    for i in tqdm(range(0, noise_vectors.shape[0], batch_size)):

        #take the next batch of latent vectors (the final batch may be smaller)
        noise_vector = noise_vectors[i:i + batch_size]

        with torch.no_grad():
            img = G(noise_vector, c, truncation_psi=truncation, noise_mode='const').cpu().numpy()  # NCHW, float32, dynamic range [-1, +1]
            img = np.transpose(img, (0, 2, 3, 1))  # NCHW -> NHWC
            img = np.clip(img * 127.5 + 128, 0, 255).astype(np.uint8)

        #add to frames
        for im in img:
            frames.append(im)


    #Save video
    aud = mpy.AudioFileClip(audio_file.name, fps=44100)

    if duration:
        aud = aud.set_duration(duration)

    #one video frame per spectrogram hop
    fps = sr / frame_length
    clip = mpy.ImageSequenceClip(frames, fps=fps)
    clip = clip.set_audio(aud)
    clip.write_videofile(outfile, audio_codec='aac',
                         ffmpeg_params=["-vf", "scale=-1:2160:flags=lanczos", "-bf", "2",
                                        "-g", str(round(fps / 2)), "-crf", "18",
                                        "-movflags", "faststart"])


    return outfile
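

# Minimal usage sketch (illustrative: the audio path, network pickle, and
# hyperparameter values below are assumptions, not values from this script):
if __name__ == "__main__":
    class _AudioFile:
        name = "song.mp3"  # visualize() only reads the .name attribute

    visualize(_AudioFile(), network="stylegan3-r-afhqv2-512x512.pkl",
              truncation=1.0, batch_size=8)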