import torch
import torchaudio
#Andy commented: torchaudio.set_audio_backend('soundfile')
#Andy commented: from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler
from audio_encoders_pytorch import MelE1d, TanhBottleneck
#Andy commented: from audiodiffusion.audio_encoder import AudioEncoder
#Andy commented: from IPython.display import Audio, display
#Andy commented: import matplotlib
#Andy commented: import matplotlib.pyplot as plt
import pandas as pd
#Andy commented: from archisound import ArchiSound

print(torch.cuda.is_available(), torch.cuda.device_count())

#Andy removed: import wandb
#Andy removed: wandb.init(project="audio_encoder_attack")
#Andy commented: from tqdm import tqdm
import auraloss
from transformers import EncodecModel, AutoProcessor
import cdpam
import audio_diffusion_attacks_forhf.src.losses as losses
#Andy edited: from audiotools import AudioSignal
#Andy edited step 2: from audiotools.audiotools.core.audio_signal.py import AudioSignal
from audiotools import AudioSignal
from audio_diffusion_attacks_forhf.src.balancer import Balancer
#Andy commented: from gradnorm_pytorch import (
#Andy commented:     GradNormLossWeighter,
#Andy commented:     MockNetworkWithMultipleLosses
#Andy commented: )
'''Andy commented:
from audiocraft.losses import (
    MelSpectrogramL1Loss,
    MultiScaleMelSpectrogramLoss,
    MRSTFTLoss,
    SISNR,
    STFTLoss,
)
'''
from audio_diffusion_attacks_forhf.src.music_gen import MusicGenEval
from audio_diffusion_attacks_forhf.src.speech_inference import XTTS_Eval


# From https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html#loading-audio-data-into-tensor
def print_stats(waveform, sample_rate=None, src=None):
    if src:
        print("-" * 10)
        print("Source:", src)
        print("-" * 10)
    if sample_rate:
        print("Sample Rate:", sample_rate)
    print("Shape:", tuple(waveform.shape))
    print("Dtype:", waveform.dtype)
    print(f" - Max:     {waveform.max().item():6.3f}")
    print(f" - Min:     {waveform.min().item():6.3f}")
    print(f" - Mean:    {waveform.mean().item():6.3f}")
    print(f" - Std Dev: {waveform.std().item():6.3f}")
    print()
    print(waveform)
    print()


def si_snr(estimate, reference, epsilon=1e-8):
    # Scale-invariant signal-to-noise ratio between an estimated and a reference waveform.
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()
    reference_pow = reference.pow(2).mean(axis=1, keepdim=True)
    mix_pow = (estimate * reference).mean(axis=1, keepdim=True)
    scale = mix_pow / (reference_pow + epsilon)

    reference = scale * reference
    error = estimate - reference

    reference_pow = reference.pow(2)
    error_pow = error.pow(2)

    reference_pow = reference_pow.mean(axis=1)
    error_pow = error_pow.mean(axis=1)

    si_snr = 10 * torch.log10(reference_pow) - 10 * torch.log10(error_pow)
    return si_snr.item()


# Train autoencoder with audio samples
#waveform = torch.randn(2, 2**10) # [batch, in_channels, length]
# loss.backward()

#andy edited: def poison_audio(audio_folder, encoders, audio_difference_weights=[1], method='encoder', weight=1, modality="music"):
def poison_audio(waveform, sample_rate, encoders, audio_difference_weights=[1], method='encoder', weight=1, modality="music"):
    '''
    Protect an audio waveform. The protected (perturbed) audio is returned rather than written to a folder.
    waveform: audio tensor to protect.
    sample_rate: sample rate of waveform.
    encoders: encoders to protect against. See initialization at end of file.
    '''
    for encoder in encoders:
        #Andy removed: encoder.to(device='cuda')
        encoder.eval()
        for p in encoder.parameters():
            p.requires_grad = False

    audio_len=1000000
    #Andy removed: waveform, sample_rate = torchaudio.load(f"test_audio/Texas Sun.mp3")

    if modality=="music":
        music_gen_eval=MusicGenEval(sample_rate, audio_len)
    elif modality=="speech":
        music_gen_eval=XTTS_Eval(sample_rate)

    processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")

    #Andy edited: loss_fn = cdpam.CDPAM(dev='cuda:0')
    my_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_fn = cdpam.CDPAM(dev=my_device)
    for p in loss_fn.model.parameters():
        p.requires_grad = False

    #Andy removed: for audio_file in tqdm(os.listdir(audio_folder)):
    for diff_weight in audio_difference_weights:
        #Andy edited: waveform, sample_rate = torchaudio.load(os.path.join(audio_folder, audio_file))

        # convert mono to stereo
        if waveform.shape[0]==1:
            stereo_waveform=torch.zeros((2, waveform.shape[1]))
            stereo_waveform[:,:]=waveform
            waveform=stereo_waveform
        waveform=waveform[:, :audio_len]

        inputs = processor(raw_audio=waveform, sampling_rate=processor.sampling_rate, return_tensors="pt")
        waveform=inputs['input_values'][0]
        #Andy removed: wandb.log({f"unperturbed {audio_name}": wandb.Audio(waveform[0].detach().numpy().flatten(), sample_rate=sample_rate)}, step=0)
        waveform=torch.reshape(waveform, (1, waveform.shape[0], waveform.shape[1]))
        #Andy removed: waveform=waveform.to(device='cuda')
        #Andy edited: inputs["padding_mask"]=inputs["padding_mask"].to(device='cuda')
        inputs["padding_mask"]=inputs["padding_mask"]

        if method=="encoder":
            unperturbed_waveform=waveform.clone().detach()
            unperturbed_latents=[]
            for encoder in encoders:
                unperturbed_latent=encoder(waveform, inputs["padding_mask"]).audio_values.detach()
                unperturbed_latents.append(unperturbed_latent)

        if method=="style_transfer":
            style_waveform, style_sample_rate = torchaudio.load(f"test_audio/Il Sogno Del Marinaio - Nanos' Waltz.mp3")
            style_waveform=style_waveform[:, :audio_len]
            style_inputs = processor(raw_audio=style_waveform, sampling_rate=processor.sampling_rate, return_tensors="pt")
            style_waveform=style_inputs['input_values'][0]
            #Andy removed: wandb.log({f"transfer style": wandb.Audio(style_waveform[0].detach().numpy().flatten(), sample_rate=sample_rate)}, step=0)
            style_waveform=torch.reshape(style_waveform, (1, style_waveform.shape[0], style_waveform.shape[1]))
            #Andy edited: style_waveform=style_waveform.to(device='cuda')
            style_waveform=style_waveform
            #Andy edited: style_inputs["padding_mask"]=style_inputs["padding_mask"].to(device='cuda')
            style_inputs["padding_mask"]=style_inputs["padding_mask"]
            # unperturbed_latent=encoder(waveform, inputs["padding_mask"]).audio_values.detach()
            unperturbed_waveform=style_waveform.clone().detach()
            unperturbed_latents=[]
            for encoder in encoders:
                unperturbed_latent=encoder(style_waveform, style_inputs["padding_mask"]).audio_values.detach()
                unperturbed_latents.append(unperturbed_latent)

        noise=torch.normal(torch.zeros(waveform.shape), 0.0)
        #Andy removed: noise=noise.to(device='cuda')
        noise.requires_grad=True
        # waveform=torch.nn.parameter.Parameter(waveform)

        weights = {'waveform_diff': weight, 'latent_diff': 1}
        balancer = Balancer(weights)

        l1loss = torch.nn.L1Loss()
        # for p in mel_loss.parameters():
        #     p.requires_grad = False
        optim = torch.optim.AdamW([noise], lr=0.002, weight_decay=0.005)
        #optim_diff = torch.optim.Adam([waveform], lr=0.02)
        # loss_weighter = GradNormLossWeighter(
        #     num_losses = 2,
        #     learning_rate = 0.00002,
        #     restoring_force_alpha = 0.,  # 0. is perfectly balanced losses, while anything greater than 1 would account for the relative training rates of each loss. in the paper, they go as high as 3.
        #     grad_norm_parameters = waveform
        # )

        downsample = torchaudio.transforms.Resample(sample_rate, 22050)
        #Andy removed: downsample=downsample.to(device='cuda')
        cos = torch.nn.CosineSimilarity()
        mrstft = auraloss.perceptual.FIRFilter()#auraloss.time.SISDRLoss()#torch.nn.functional.l1_loss
        #Andy removed: mrstft.to(device='cuda')
        waveform_loss = losses.L1Loss()
        stft_loss = losses.MultiScaleSTFTLoss()
        mel_loss = losses.MelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320], window_lengths=[32, 64, 128, 256, 512, 1024, 2048], mel_fmin=[0, 0, 0, 0, 0, 0, 0], pow=1.0, clamp_eps=1.0e-5, mag_weight=0.0)

        past_10_latent_losses=[]
        latent_weight=1000
        latent_diff=0
        #Andy edited for testing purposes: number_steps=500
        number_steps=5
        if diff_weight>-1:
            for step in range(number_steps):
                latent_diff=0
                perturbed_waveform=noise+waveform
                for encoder_ind in range(len(encoders)):
                    perturbed_latent = encoders[encoder_ind](perturbed_waveform, inputs["padding_mask"]).audio_values
                    latent_diff+=cos(torch.reshape(perturbed_latent, (1,-1)), torch.reshape(unperturbed_latents[encoder_ind], (1, -1)))
                    #latent_diff+=1-torch.mean(torch.abs((torch.reshape(perturbed_latent, (1,-1))-torch.reshape(unperturbed_latents[encoder_ind], (1, -1)))))
                    #latent_diff=-l1loss(perturbed_latent,unperturbed_latents[0])
                latent_diff=latent_diff/len(encoders)

                #waveform_diff=mrstft(waveform, unperturbed_waveform)
                #waveform_diff=mrstft(torch.reshape(waveform, (1,-1)), torch.reshape(unperturbed_waveform, (1,-1)))
                #waveform_diff=si_snr(waveform, unperturbed_waveform)
                a_waveform=AudioSignal(perturbed_waveform, sample_rate)
                a_uwaveform=AudioSignal(unperturbed_waveform, sample_rate)
                c_waveform_loss=waveform_loss(a_waveform, a_uwaveform)*100
                c_stft_loss=stft_loss(a_waveform, a_uwaveform)/6.0
                c_mel_loss=mel_loss(a_waveform, a_uwaveform)
                l1_loss=torch.mean(torch.abs(perturbed_waveform-unperturbed_waveform))
                waveform_diff=c_mel_loss#(c_waveform_loss+c_stft_loss+c_mel_loss)/3.0

                #loss=100*latent_diff+waveform_diff
                # loss=latent_weight*latent_diff+waveform_diff
                # past_10_latent_losses.append(latent_diff.detach().cpu().numpy().item())
                # if len(past_10_latent_losses)>10:
                #     mean=sum(past_10_latent_losses)/len(past_10_latent_losses)
                #     if mean<latent_diff*1.01:
                #         latent_weight=latent_weight/1.1
                #     past_10_latent_losses=past_10_latent_losses[1:]
                #     print('latent_weight', latent_weight)
                # if latent_diff>0.85:
                #     loss=1500*latent_diff+waveform_diff
                #loss=1000*latent_diff+waveform_diff
                # if method=='encoder':
                #     if latent_diff>0.75:
                #         loss=1000*latent_diff+waveform_diff
                #         print('latent')
                #     else:
                loss=waveform_diff+latent_diff
                #         print('waveform_diff')
                # elif method=='style_transfer':
                #     loss=latent_diff

                '''Andy removed:
                if step%10==0 or step==number_steps-1:
                    wandb.log({"loss": loss, "latent_diff": latent_diff, 'waveform_diff': waveform_diff}, step=step)
                if step%100==0 or step==number_steps-1:
                    audio_save=torch.reshape((noise+waveform), (2, waveform.shape[2]))[0, :audio_len].detach().cpu().numpy().flatten()
                    wandb.log({f"perturbed cos_dist_{latent_diff}_diff_weight_{diff_weight}_{audio_name}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step)
                if step%100==0 or step==number_steps-1:
                    music_gen_eval_dict, unprotected_gen, protected_gen=music_gen_eval.eval(waveform, noise+waveform)
                    audio_save=torch.reshape(unprotected_gen, (2, unprotected_gen.shape[1]))[0].detach().cpu().numpy().flatten()
wandb.log({f"unprotected_gen_{latent_diff}_diff_weight_{diff_weight}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step) audio_save=torch.reshape(protected_gen, (2, protected_gen.shape[1]))[0].detach().cpu().numpy().flatten() wandb.log({f"protected_gen_{latent_diff}_diff_weight_{diff_weight}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step) wandb.log(music_gen_eval_dict, step=step) ''' # if c_mel_loss>0.5: # loss=waveform_diff # else: # loss=latent_diff # noise=noise*0.99 loss_dict = {} loss_dict['waveform_diff'] = waveform_diff loss_dict['latent_diff'] = latent_diff[0] effective_loss = balancer.backward(loss_dict, noise) # loss=latent_diff # loss.backward() #loss_weighter.backward([latent_diff, c_mel_loss]) # torch.nn.utils.clip_grad_norm_(waveform, 10e-8) optim.step() optim.zero_grad() # if latent_diff>0.5: # latent_diff.backward() # optim_diff.step() # optim_diff.zero_grad() # else: # loss=waveform_diff # loss.backward() # optim_diff.step() # optim_diff.zero_grad() encoder.zero_grad() mel_loss.zero_grad() # with torch.no_grad(): # noise_clip=0.25 # noise.clamp_(-noise_clip, noise_clip) # print('noise max', torch.max(noise)) print('step', step, 'loss', loss.item(), 'latent loss', latent_diff.item(), 'audio loss', waveform_diff.item(), 'c_waveform_loss', c_waveform_loss.item(), 'c_stft_loss', c_stft_loss.item(), 'l1_loss', l1_loss.item()) latent_diff=latent_diff.detach().item() #Andy removed: audio_save=torch.reshape((noise+waveform), (2, waveform.shape[2]))[0, :audio_len].detach().cpu().numpy().flatten() #Andy removed: wandb.log({f"perturbed cos_dist_{latent_diff}_diff_weight_{diff_weight}_{audio_name}": wandb.Audio(audio_save, sample_rate=sample_rate)}, step=step) #Andy moved from inside the loop: music_gen_eval_dict, unprotected_gen, protected_gen=music_gen_eval.eval(waveform, noise+waveform) #Andy edited: torchaudio.save(os.path.join(audio_folder, f"protected_{audio_name}_{audio_len}_mel_{latent_diff}_diff_weight_{waveform_diff}"), torch.reshape((noise+waveform).detach().cpu(), (2, waveform.shape[2])), sample_rate) return (torch.reshape((noise+waveform).detach().cpu(), (2, waveform.shape[2]))), music_gen_eval_dict, unprotected_gen, protected_gen # encoders = [ArchiSound.from_pretrained('autoencoder1d-AT-v1'), # ArchiSound.from_pretrained("dmae1d-ATC64-v2"), # ArchiSound.from_pretrained("dmae1d-ATC32-v3"), # AudioEncoder.from_pretrained("teticio/audio-encoder"), encoders = [EncodecModel.from_pretrained("facebook/encodec_48khz")] audio_difference_weights=[1] #Andy commented out: poison_audio(, encoders, [1], method="encoder", weight=weight) #Andy removed: wandb.finish()