import random

import librosa
import moviepy.editor as mpy
import numpy as np
import torch
from tqdm import tqdm

import stylegan3.dnnlib
import stylegan3.legacy


def visualize(audio_file, network, truncation, batch_size, *args, **kwargs):
    """Generate an audio-reactive StyleGAN3 video for the given audio file."""
    if audio_file:
        print('\nReading audio\n')
        y, sr = librosa.load(audio_file.name)
    else:
        raise ValueError("you must enter an audio file name in the --song argument")

    duration = None
    frame_length = 512
    tempo_sensitivity = 0.25 * frame_length / 512
    jitter = 0.5
    outfile = "output.mp4"

    # Set device and load the pre-trained generator
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with stylegan3.dnnlib.open_url(network) as f:
        G = stylegan3.legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
    G.eval()

    # Warm-up / sanity check: run one forward pass (output is discarded)
    with torch.no_grad():
        z = torch.randn([1, G.z_dim], device=device)  # latent code
        c = None                                      # class labels (not used in this example)
        G(z, c)                                       # NCHW, float32, dynamic range [-1, +1]

    # Create mel spectrogram
    spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=512, fmax=8000, hop_length=frame_length)

    # Get mean power at each time point
    specm = np.mean(spec, axis=0)

    # Compute power gradient across time points
    gradm = np.gradient(specm)

    # Set max to 1
    gradm = gradm / np.max(gradm)

    # Set negative gradient time points to zero
    gradm = gradm.clip(min=0)

    # Normalize mean power between 0 and 1
    specm = (specm - np.min(specm)) / np.ptp(specm)

    # Initialize first noise vector
    nv1 = torch.randn([G.z_dim], device=device)

    # Initialize list of noise vectors
    noise_vectors = [nv1]

    # Track the previous frame's noise vector
    nvlast = nv1

    # Initialize the direction of noise vector unit updates
    update_dir = np.zeros(512)
    for ni, n in enumerate(nv1):
        update_dir[ni] = 1 if n < 0 else -1

    # Initialize noise unit update
    update_last = np.zeros(512)

    # Get new jitters: ~half of the noise units keep full sensitivity, the rest are damped
    def new_jitters(jitter):
        jitters = np.zeros(512)
        for j in range(512):
            if random.uniform(0, 1) < 0.5:
                jitters[j] = 1
            else:
                jitters[j] = 1 - jitter
        return jitters

    # Get new update directions: reverse a unit's direction when it approaches the truncation bounds
    def new_update_dir(nv2, update_dir):
        for ni, n in enumerate(nv2):
            if n >= 2 * truncation - tempo_sensitivity:
                update_dir[ni] = -1
            elif n < -2 * truncation + tempo_sensitivity:
                update_dir[ni] = 1
        return update_dir

    print('\nGenerating input vectors\n')
    for i in tqdm(range(len(gradm))):
        # Update the jitter vector every 200 frames, setting ~half of the noise units to lower sensitivity
        if i % 200 == 0:
            jitters = new_jitters(jitter)

        # Get last noise vector
        nv1 = nvlast

        # Set the noise vector update from direction, sensitivity, jitter, and the combined overall power and power gradient
        update = np.full(512, tempo_sensitivity) * (gradm[i] + specm[i]) * update_dir * jitters

        # Smooth the update with the previous update (to avoid overly sharp frame transitions)
        update = (update + update_last * 3) / 4

        # Set last update
        update_last = update

        # Update the noise vector (cast to float32 so all stacked vectors share the generator's dtype)
        nv2 = nv1.cpu() + torch.from_numpy(update).float()

        # Append to noise vectors and track it as the previous frame
        noise_vectors.append(nv2)
        nvlast = nv2

        # Update the direction of the noise units
        update_dir = new_update_dir(nv2, update_dir)

    noise_vectors = torch.stack([nv.to(device) for nv in noise_vectors])

    print('\n\nGenerating frames\n')
    frames = []
    for i in tqdm(range(0, noise_vectors.shape[0], batch_size)):
        noise_vector = noise_vectors[i:i + batch_size]
        c = None  # class labels (not used in this example)
        with torch.no_grad():
            # NCHW, float32, dynamic range [-1, +1]
            img = G(noise_vector, c, truncation_psi=truncation, noise_mode='const').cpu().numpy()
        img = np.transpose(img, (0, 2, 3, 1))  # NCHW -> NHWC
        img = np.clip(img * 127.5 + 128, 0, 255).astype(np.uint8)

        # Add to frames
        for im in img:
            frames.append(im)

    # Save video
    aud = mpy.AudioFileClip(audio_file.name, fps=44100)
    if duration:
        aud = aud.set_duration(duration)
    fps = sr / frame_length  # librosa.load resamples to 22050 Hz by default, so this matches the hop rate
    clip = mpy.ImageSequenceClip(frames, fps=fps)
    clip = clip.set_audio(aud)
    clip.write_videofile(
        outfile,
        audio_codec='aac',
        ffmpeg_params=["-vf", "scale=-1:2160:flags=lanczos", "-bf", "2",
                       "-g", f"{round(fps / 2)}", "-crf", "18", "-movflags", "faststart"],
    )
    return outfile
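

# --- Usage sketch (assumption, not part of the original module) ---
# Minimal CLI wrapper showing how visualize() might be invoked. The function reads
# audio_file.name, so argparse.FileType is used to supply a file-like object. The
# --song flag matches the error message above; the other flag names and defaults
# are illustrative assumptions, not confirmed by the source.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Audio-reactive StyleGAN3 visualizer")
    parser.add_argument("--song", type=argparse.FileType("rb"), required=True,
                        help="path to the input audio file")
    parser.add_argument("--network", type=str, required=True,
                        help="URL or path of a StyleGAN3 .pkl checkpoint")
    parser.add_argument("--truncation", type=float, default=1.0,
                        help="truncation psi passed to the generator")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="number of frames generated per forward pass")
    cli_args = parser.parse_args()

    print(visualize(cli_args.song, cli_args.network, cli_args.truncation, cli_args.batch_size))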