|
import torch
|
|
import torchvision.transforms as T
|
|
from safetensors.torch import load_file, save_file
|
|
# Manual "surgery" on an SD3 checkpoint: rescale a handful of attention QKV
# weight tensors in selected joint blocks and write the result to a new file.
from safetensors import safe_open

file_path = "./sd3_medium_incl_clips_t5xxlfp8.safetensors"
output_path = "sd3_manual_surgery-test1_medium_incl_clips_t5xxlfp8.safetensors"

# Target scale per joint-block index for the x_block attention QKV weight.
# Each target is applied at half strength: the tensor is multiplied by
# (target + 1) / 2, the midpoint between 1 (unchanged) and the target.
QKV_SCALE_BY_BLOCK = {
    2: 0.9,
    3: 0.9,
    10: 0.8,
    12: 0.85,
    13: 0.8,
    19: 1.15,
    20: 0.9,
}


def _qkv_key_pattern(block_index):
    """Return the key substring that identifies block *block_index*'s x_block QKV weight."""
    return f"model.diffusion_model.joint_blocks.{block_index}.x_block.attn.qkv.weight"


def scale_qkv_weights(state_dict, scale_by_block):
    """Rescale matching QKV weight tensors of *state_dict* in place.

    Args:
        state_dict: mapping of tensor name -> tensor, as returned by
            ``safetensors.torch.load_file``.
        scale_by_block: mapping of joint-block index -> target scale. A key
            matches when it contains that block's QKV key pattern; each match
            is multiplied by ``(target + 1) / 2``.

    Returns:
        The set of block indices for which at least one key was scaled.
    """
    # Dict order is insertion order, so the first matching pattern wins,
    # mirroring the original if/elif chain.
    patterns = [(_qkv_key_pattern(block), block) for block in scale_by_block]
    matched = set()
    # Only existing keys are reassigned (dict size never changes), so
    # iterating over a snapshot of the keys is safe.
    for key in list(state_dict):
        block = next((b for pattern, b in patterns if pattern in key), None)
        if block is None:
            continue
        factor = (scale_by_block[block] + 1) / 2
        print(key, factor)
        tensor = state_dict[key]
        # Multiply in float32 and cast back, so the stored dtype is kept and
        # dtypes with limited arithmetic support (e.g. float8) still work.
        state_dict[key] = (tensor.float() * factor).to(tensor.dtype)
        matched.add(block)
    return matched


loaded = load_file(file_path)

# load_file returns tensors only; read the header metadata separately so it
# can be carried over to the output file (None when the source has none).
with safe_open(file_path, framework="pt") as source:
    source_metadata = source.metadata()

matched_blocks = scale_qkv_weights(loaded, QKV_SCALE_BY_BLOCK)

# Refuse to write a "surgery" output that silently skipped some targets.
missing_blocks = sorted(set(QKV_SCALE_BY_BLOCK) - matched_blocks)
if missing_blocks:
    raise KeyError(
        f"no x_block QKV weight found for joint block(s) {missing_blocks} in {file_path}"
    )

save_file(loaded, output_path, metadata=source_metadata)
|
|
|
|
|