|
import gradio as gr
|
|
import torch
|
|
from diffusers import DDIMScheduler, DiffusionPipeline
|
|
from masactrl.diffuser_utils import MasaCtrlPipeline
|
|
from masactrl.masactrl_utils import AttentionBase, regiter_attention_editor_diffusers
|
|
from masactrl.masactrl import MutualSelfAttentionControl
|
|
from pytorch_lightning import seed_everything
|
|
import os
|
|
import re
|
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DDIM scheduler configured with a scaled-linear beta schedule; it replaces the
# pipeline's default scheduler below.
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

# Single global pipeline shared by every Gradio request.
model = DiffusionPipeline.from_pretrained(
    "svjack/GenshinImpact_XL_Base",
    scheduler=scheduler,
).to(device)
|
|
|
|
def pathify(s):
    """Return a filesystem-friendly version of *s*.

    The text is lower-cased first. Then every character that is not an ASCII
    letter or digit is replaced by one underscore each, so runs of spaces or
    punctuation become runs of underscores.
    """
    safe_chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
    return "".join(ch if ch in safe_chars else "_" for ch in s.lower())
|
|
|
|
def consistent_synthesis(prompt1, prompt2, guidance_scale, seed, starting_step, starting_layer):
    """Render two prompts with and without MasaCtrl and save the results.

    Both prompts are denoised from one shared initial latent, twice:
    - first with a plain ``AttentionBase`` editor registered on the global
      ``model``;
    - then with ``MutualSelfAttentionControl(starting_step, starting_layer,
      model_type="SDXL")`` registered instead.

    The three returned images and a ``prompts.txt`` describing the run are
    written to ``masactrl_exp/<pathified prompt2>/sample_<k>/``.

    Args:
        prompt1: Source prompt (first item of the batch).
        prompt2: Target prompt (second item of the batch). It also names the
            output directory.
        guidance_scale: Guidance scale forwarded to the pipeline call.
        seed: Seed passed to ``seed_everything``. The seed that function
            reports back is the one recorded in ``prompts.txt``. Per
            Lightning's docs, out-of-range values such as the slider's ``-1``
            are replaced by a random seed.
        starting_step: First positional argument of
            ``MutualSelfAttentionControl``.
        starting_layer: Second positional argument of
            ``MutualSelfAttentionControl``.

    Returns:
        ``[source image, target image without MasaCtrl, target image with
        MasaCtrl]``, taken from the pipeline's ``.images`` outputs.
    """
    # Gradio sliders may deliver floats (e.g. 4.0). Normalise to int so the
    # controller receives integer indices and filenames read "step4", not "step4.0".
    starting_step = int(starting_step)
    starting_layer = int(starting_layer)
    # Keep the seed seed_everything actually applied so the run is reproducible
    # from prompts.txt even when a random seed was substituted.
    seed = seed_everything(int(seed))

    # Truncate so long prompts cannot exceed filesystem name-length limits, and
    # avoid an empty directory name for an empty prompt.
    dir_name = pathify(prompt2)[:100] or "untitled"
    out_dir_ori = os.path.join("masactrl_exp", dir_name)
    os.makedirs(out_dir_ori, exist_ok=True)

    prompts = [prompt1, prompt2]

    # One latent broadcast to both batch items, so both prompts start from
    # identical noise.
    start_code = torch.randn([1, 4, 128, 128], device=device)
    start_code = start_code.expand(len(prompts), -1, -1, -1)

    # Baseline run with the plain AttentionBase editor.
    regiter_attention_editor_diffusers(model, AttentionBase())
    image_ori = model(prompts, latents=start_code, guidance_scale=guidance_scale).images

    # Same prompts and latent, now with the MasaCtrl editor registered.
    editor = MutualSelfAttentionControl(starting_step, starting_layer, model_type="SDXL")
    regiter_attention_editor_diffusers(model, editor)
    image_masactrl = model(prompts, latents=start_code, guidance_scale=guidance_scale).images

    # Each run goes into the next numbered sample_<k> sub-directory.
    sample_count = len(os.listdir(out_dir_ori))
    out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
    os.makedirs(out_dir, exist_ok=True)

    suffix = f"step{starting_step}_layer{starting_layer}.png"
    image_ori[0].save(os.path.join(out_dir, f"source_{suffix}"))
    image_ori[1].save(os.path.join(out_dir, f"without_{suffix}"))
    image_masactrl[-1].save(os.path.join(out_dir, f"masactrl_{suffix}"))

    # Explicit UTF-8 so prompts containing non-ASCII text are written on every
    # platform regardless of the locale's default encoding.
    with open(os.path.join(out_dir, "prompts.txt"), "w", encoding="utf-8") as f:
        f.writelines(p + "\n" for p in prompts)
        f.write(f"seed: {seed}\n")
        f.write(f"starting_step: {starting_step}\n")
        f.write(f"starting_layer: {starting_layer}\n")

    print("Synthesized images are saved in", out_dir)

    return [image_ori[0], image_ori[1], image_masactrl[-1]]
|
|
|
|
def create_demo_synthesis():
    """Assemble (but do not launch) the Gradio UI for MasaCtrl synthesis.

    The settings row has two columns:
    - left: the two prompts, the starting step/layer sliders and the Run button;
    - right: the guidance-scale and seed sliders.

    Below it is an output row with the three images returned by
    ``consistent_synthesis``, followed by clickable examples that fill in the
    prompts, seed, step and layer.

    Returns:
        The ``gr.Blocks`` demo.
    """
    default_source = "solo,ZHONGLI(genshin impact),1boy,highres,"
    default_target = "solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,"
    # Each row: prompt 1, prompt 2, seed, starting step, starting layer.
    example_rows = [
        [default_source, default_target, 42, 4, 64],
        [
            "solo,KAMISATO AYATO(genshin impact),1boy,highres,",
            "solo,KAMISATO AYATO smiling (genshin impact),1boy,highres,",
            42,
            4,
            55,
        ],
    ]

    with gr.Blocks() as demo:
        gr.Markdown("# **Genshin Impact XL MasaCtrl Image Synthesis**")
        gr.Markdown("## **Input Settings**")

        with gr.Row():
            with gr.Column():
                src_box = gr.Textbox(label="Prompt 1", value=default_source)
                tgt_box = gr.Textbox(label="Prompt 2", value=default_target)
                with gr.Row():
                    step_slider = gr.Slider(
                        label="Starting Step", minimum=0, maximum=999, value=4, step=1
                    )
                    layer_slider = gr.Slider(
                        label="Starting Layer", minimum=0, maximum=999, value=64, step=1
                    )
                run_button = gr.Button("Run")
            with gr.Column():
                cfg_slider = gr.Slider(
                    label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1
                )
                seed_slider = gr.Slider(
                    label="Seed", minimum=-1, maximum=2147483647, value=42, step=1
                )

        gr.Markdown("## **Output**")

        with gr.Row():
            out_source = gr.Image(label="Source Image")
            out_plain = gr.Image(label="Image without MasaCtrl")
            out_masactrl = gr.Image(label="Image with MasaCtrl")

        # Argument order must match consistent_synthesis' signature.
        run_button.click(
            fn=consistent_synthesis,
            inputs=[src_box, tgt_box, cfg_slider, seed_slider, step_slider, layer_slider],
            outputs=[out_source, out_plain, out_masactrl],
        )

        gr.Examples(
            examples=example_rows,
            inputs=[src_box, tgt_box, seed_slider, step_slider, layer_slider],
        )

    return demo
|
|
|
|
if __name__ == "__main__":
    # Build the UI and serve it, also exposing a public Gradio share link.
    create_demo_synthesis().launch(share=True)
|
|
|