|
import os |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import numpy as np |
|
|
|
from tqdm import tqdm |
|
from einops import rearrange, repeat |
|
from omegaconf import OmegaConf |
|
from diffusers import DDIMScheduler, StableDiffusionPipeline, DiffusionPipeline |
|
from torchvision.utils import save_image |
|
from torchvision.io import read_image |
|
from pytorch_lightning import seed_everything |
|
|
|
from masactrl.masactrl_processor import register_attention_processor |
|
|
|
# Select the compute device once. torch.cuda.set_device is only called when
# CUDA is present: calling it unconditionally fails on CPU-only machines and
# would make the CPU fallback below unreachable.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    device = torch.device("cuda")
    weight_dtype = torch.float16
else:
    device = torch.device("cpu")
    # Full precision on CPU, where half-precision kernels are not reliably
    # available.
    weight_dtype = torch.float32

model_path = "stabilityai/stable-diffusion-xl-base-1.0"

# DDIM scheduler configuration that replaces the pipeline's default scheduler.
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

# Load the pipeline at import time; consistent_synthesis() below reads the
# module-level `pipe`, `device` and `weight_dtype`.
pipe = DiffusionPipeline.from_pretrained(
    model_path,
    scheduler=scheduler,
    torch_dtype=weight_dtype,
).to(device)
|
|
|
|
|
def consistent_synthesis():
    """Render a prompt pair with and without the MasaCtrl attention processor.

    A baseline pair of images is generated once from a shared initial latent.
    Then, for every entry in ``LAYER_LIST``, the MasaCtrl attention processor
    is registered on ``pipe.unet`` with that start layer and the same prompts
    are rendered again from the same latent.

    For each layer a new ``sample_<n>`` directory is created under the
    experiment folder containing:

    * ``source_step{STEP}_layer{LAYER}.png``   - baseline image of prompt 0
    * ``without_step{STEP}_layer{LAYER}.png``  - baseline image of prompt 1
    * ``masactrl_step{STEP}_layer{LAYER}.png`` - last image of the MasaCtrl run
    * ``prompts.txt``                          - the prompts and the seed

    Relies on the module-level ``pipe``, ``device`` and ``weight_dtype``.
    Returns nothing; results are written to disk.
    """
    seed = 42
    seed_everything(seed)

    out_dir_ori = "./workdir/masactrl_exp/oldman_smiling"
    os.makedirs(out_dir_ori, exist_ok=True)

    prompts = [
        "A portrait of an old man, facing camera, best quality",
        "A portrait of an old man, facing camera, smiling, best quality",
    ]

    # Settings handed to the attention processor.
    STEP = 4
    LAYER_LIST = [44, 54, 64]
    MODEL_TYPE = "SDXL"

    # One noise sample, expanded (not re-sampled) across the batch so that
    # every prompt starts from the same initial latent.
    start_code = torch.randn([1, 4, 128, 128], dtype=weight_dtype, device=device)
    start_code = start_code.expand(len(prompts), -1, -1, -1)

    # Baseline images, rendered before any processor is registered here.
    image_ori = pipe(prompts, latents=start_code, guidance_scale=7.5).images

    for LAYER in LAYER_LIST:
        processor_args = {
            "start_step": STEP,
            "start_layer": LAYER,
            "model_type": MODEL_TYPE,
        }
        # Forward the sweep settings to the processor. Previously this dict
        # was built but never passed on, so LAYER had no influence on the
        # registered processor and every iteration ran the same configuration.
        # NOTE(review): assumes register_attention_processor accepts these as
        # keyword arguments - confirm against masactrl.masactrl_processor.
        register_attention_processor(
            pipe.unet,
            processor_type="MasaCtrlProcessor",
            **processor_args,
        )

        image_masactrl = pipe(prompts, latents=start_code, guidance_scale=7.5).images

        # Number the output folder by how many entries already exist in the
        # experiment directory, so each layer gets its own sample folder.
        sample_count = len(os.listdir(out_dir_ori))
        out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
        os.makedirs(out_dir, exist_ok=True)

        image_ori[0].save(os.path.join(out_dir, f"source_step{STEP}_layer{LAYER}.png"))
        image_ori[1].save(os.path.join(out_dir, f"without_step{STEP}_layer{LAYER}.png"))
        image_masactrl[-1].save(os.path.join(out_dir, f"masactrl_step{STEP}_layer{LAYER}.png"))

        # Record what produced these images alongside them.
        with open(os.path.join(out_dir, "prompts.txt"), "w", encoding="utf-8") as f:
            for p in prompts:
                f.write(p + "\n")
            f.write(f"seed: {seed}\n")

        print("Syntheiszed images are saved in", out_dir)
|
|
|
|
|
# Script entry point: run the synthesis demo only on direct execution,
# not when this module is imported.
if __name__ == "__main__":
    consistent_synthesis()
|
|