"""Perturb the diffusion-model weights of a safetensors checkpoint.

Loads every tensor from SOURCE_PATH, adds zero-mean Gaussian noise to each
tensor whose key contains all substrings in KEY_PARTS, and writes all tensors
(together with the source file's metadata) to OUTPUT_PATH.
"""
from typing import Sequence

# NOTE(review): numpy and matplotlib are not used by the code below; the
# imports are kept unchanged in case something else relies on them.
import numpy as np
import torch
import safetensors
from safetensors.torch import save_file
import matplotlib.pyplot as plt

SOURCE_PATH = 'sd3_medium_incl_clips_t5xxlfp16.safetensors'
OUTPUT_PATH = 'sd3_medium_incl_clips_t5xxlfp16.safetensors_perturbed3.safetensors'
# A tensor is perturbed only if its key contains every one of these substrings.
KEY_PARTS = ('diffusion_model',)
# Noise standard deviation, expressed as a fraction of each tensor's own std.
NOISE_SCALE = 0.02


def _perturb_in_place(name: str, tensor: torch.Tensor, noise_scale: float) -> bool:
    """Add zero-mean Gaussian noise with std ``noise_scale * tensor.std()``.

    The tensor is modified in place. Prints ``name`` and the tensor's std, as
    the original script did.

    Returns:
        True if noise was added; False if the tensor was skipped because its
        standard deviation is undefined or non-finite.
    """
    # std() is not defined for integer/bool tensors and is NaN for tensors
    # with fewer than two elements, so there is no meaningful noise scale.
    if not tensor.is_floating_point() or tensor.numel() < 2:
        print(f'{name}: skipped (std is undefined for this tensor)')
        return False

    std = tensor.std()
    print(f'{name}: {std}')
    # A NaN/inf std would turn the whole tensor into NaN/inf values.
    if not torch.isfinite(std):
        print(f'{name}: skipped (non-finite std)')
        return False

    # randn_like draws N(0, 1) in the tensor's own dtype/device; scaling by a
    # 0-d tensor avoids allocating full-size mean and std tensors.
    tensor.add_(torch.randn_like(tensor) * (std * noise_scale))
    return True


def perturb_checkpoint(
    source_path: str = SOURCE_PATH,
    output_path: str = OUTPUT_PATH,
    key_parts: Sequence[str] = KEY_PARTS,
    noise_scale: float = NOISE_SCALE,
) -> int:
    """Copy a checkpoint to ``output_path`` with noise added to selected tensors.

    Args:
        source_path: safetensors file to read.
        output_path: safetensors file to write.
        key_parts: substrings that must all occur in a tensor's key for that
            tensor to be perturbed.
        noise_scale: noise std as a fraction of each tensor's std.

    Returns:
        The number of tensors that were perturbed (also printed).
    """
    # The context manager closes the file handle once everything is loaded;
    # metadata must be read while the handle is still open.
    with safetensors.safe_open(source_path, 'pt') as checkpoint:
        metadata = checkpoint.metadata()
        tensors = {key: checkpoint.get_tensor(key) for key in checkpoint.keys()}

    count = 0
    for key, tensor in tensors.items():
        if not all(part in key for part in key_parts):
            continue
        if _perturb_in_place(key, tensor, noise_scale):
            count += 1
    print(count)

    save_file(tensors, output_path, metadata)
    return count


if __name__ == '__main__':
    perturb_checkpoint()