# Provenance (copied from the Hugging Face file view): uploaded by svjack,
# commit f070657 ("Upload 23 files", verified).
import gradio as gr
import torch
from diffusers import DDIMScheduler, DiffusionPipeline
from masactrl.diffuser_utils import MasaCtrlPipeline
from masactrl.masactrl_utils import AttentionBase, regiter_attention_editor_diffusers
from masactrl.masactrl import MutualSelfAttentionControl
from pytorch_lightning import seed_everything
import os
import re
# Initialise the compute device and the diffusion pipeline once, at import time.
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DDIM sampler with a "scaled_linear" beta schedule; these values are assumed
# to match how the checkpoint was trained — confirm if swapping checkpoints.
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)
# Load the svjack/GenshinImpact_XL_Base pipeline with the DDIM scheduler
# swapped in, then move it to the selected device.
model = DiffusionPipeline.from_pretrained(
    "svjack/GenshinImpact_XL_Base",
    scheduler=scheduler,
).to(device)
def pathify(s):
    """Return a filesystem-safe version of *s*.

    The text is lower-cased, then every character that is not an ASCII
    letter or digit is replaced by a single underscore (one per character).
    """
    lowered = s.lower()
    return "".join(ch if (ch.isascii() and ch.isalnum()) else "_" for ch in lowered)
def consistent_synthesis(prompt1, prompt2, guidance_scale, seed, starting_step, starting_layer):
    """Render a prompt pair with and without MasaCtrl and save the results.

    Both prompts are denoised from one shared initial latent twice: first
    with plain attention (``AttentionBase``), then with
    ``MutualSelfAttentionControl`` registered on the global ``model``.

    Args:
        prompt1: Source prompt (first image of each batch).
        prompt2: Target prompt (second image); also names the output folder.
        guidance_scale: Classifier-free guidance scale passed to ``model``.
        seed: RNG seed. Negative values (the UI allows -1) mean "pick a
            random seed"; the seed actually used is written to prompts.txt.
        starting_step: Passed to ``MutualSelfAttentionControl`` as its
            starting step; also used in the output file names.
        starting_layer: Passed to ``MutualSelfAttentionControl`` as its
            starting layer; also used in the output file names.

    Returns:
        A list of three pipeline output images: the source image, the target
        image without MasaCtrl, and the target image with MasaCtrl.
    """
    # Gradio sliders can deliver floats (e.g. 4.0); the attention controller
    # and the output file names need plain integers.
    seed = int(seed)
    starting_step = int(starting_step)
    starting_layer = int(starting_layer)
    if seed < 0:
        # -1 is the conventional "random seed" value. Draw a non-negative
        # 32-bit seed explicitly so it can be recorded for reproducibility.
        seed = int.from_bytes(os.urandom(4), "little")
    seed_everything(seed)

    # The output folder is named after the target prompt. Truncate it so long
    # prompts stay under the common 255-byte file-name limit (pathify() emits
    # ASCII only). Fall back to a fixed name so an empty prompt does not write
    # straight into the experiment root.
    out_dir_ori = os.path.join("masactrl_exp", pathify(prompt2)[:200] or "empty_prompt")
    os.makedirs(out_dir_ori, exist_ok=True)
    prompts = [prompt1, prompt2]

    # One initial latent shared by both prompts, so the two images differ only
    # through the prompt and the attention editor. 4x128x128 presumably matches
    # SDXL's default 1024x1024 output — TODO confirm if the resolution changes.
    start_code = torch.randn([1, 4, 128, 128], device=device)
    start_code = start_code.expand(len(prompts), -1, -1, -1)

    # Baseline pass: plain attention, no MasaCtrl.
    editor = AttentionBase()
    regiter_attention_editor_diffusers(model, editor)
    image_ori = model(prompts, latents=start_code, guidance_scale=guidance_scale).images

    # Hook MutualSelfAttentionControl into the attention layers and rerun
    # from the same latent.
    editor = MutualSelfAttentionControl(starting_step, starting_layer, model_type="SDXL")
    regiter_attention_editor_diffusers(model, editor)
    image_masactrl = model(prompts, latents=start_code, guidance_scale=guidance_scale).images

    # Start from the entry count, but skip past any sample_N that already
    # exists. Counting alone could reuse, and overwrite, an earlier folder
    # after some samples were deleted.
    sample_count = len(os.listdir(out_dir_ori))
    while os.path.exists(os.path.join(out_dir_ori, f"sample_{sample_count}")):
        sample_count += 1
    out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
    os.makedirs(out_dir, exist_ok=True)

    suffix = f"step{starting_step}_layer{starting_layer}.png"
    image_ori[0].save(os.path.join(out_dir, f"source_{suffix}"))
    image_ori[1].save(os.path.join(out_dir, f"without_{suffix}"))
    image_masactrl[-1].save(os.path.join(out_dir, f"masactrl_{suffix}"))

    # Write UTF-8 explicitly: prompts may contain non-ASCII text, and the
    # platform default encoding is not guaranteed to support it.
    with open(os.path.join(out_dir, "prompts.txt"), "w", encoding="utf-8") as f:
        for p in prompts:
            f.write(p + "\n")
        f.write(f"seed: {seed}\n")
        f.write(f"starting_step: {starting_step}\n")
        f.write(f"starting_layer: {starting_layer}\n")
        f.write(f"guidance_scale: {guidance_scale}\n")
    print("Synthesized images are saved in", out_dir)
    return [image_ori[0], image_ori[1], image_masactrl[-1]]
def create_demo_synthesis():
    """Build (but do not launch) the Gradio Blocks UI for MasaCtrl synthesis.

    The page has a title, an input section and an output row. The input
    section's left column holds the two prompt boxes, the starting step and
    layer sliders, and the Run button. Its right column holds the guidance
    scale and seed sliders. The output row shows three images.

    Clicking Run calls ``consistent_synthesis``. Two example rows pre-fill
    the prompts, seed, starting step and starting layer.

    Returns:
        The constructed ``gr.Blocks`` app.
    """
    # Example rows: [prompt 1, prompt 2, seed, starting step, starting layer].
    example_rows = [
        [
            "solo,ZHONGLI(genshin impact),1boy,highres,",
            "solo,ZHONGLI drink tea use chinese cup (genshin impact),1boy,highres,",
            42, 4, 64,
        ],
        [
            "solo,KAMISATO AYATO(genshin impact),1boy,highres,",
            "solo,KAMISATO AYATO smiling (genshin impact),1boy,highres,",
            42, 4, 55,
        ],
    ]
    # The first example row doubles as the default prompts.
    default_source, default_target = example_rows[0][0], example_rows[0][1]

    with gr.Blocks() as app:
        gr.Markdown("# **Genshin Impact XL MasaCtrl Image Synthesis**")  # page title
        gr.Markdown("## **Input Settings**")
        with gr.Row():
            with gr.Column():
                source_box = gr.Textbox(label="Prompt 1", value=default_source)
                target_box = gr.Textbox(label="Prompt 2", value=default_target)
                with gr.Row():
                    step_slider = gr.Slider(label="Starting Step", minimum=0, maximum=999, value=4, step=1)
                    layer_slider = gr.Slider(label="Starting Layer", minimum=0, maximum=999, value=64, step=1)
                run_button = gr.Button("Run")
            with gr.Column():
                cfg_slider = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                seed_slider = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=42, step=1)

        gr.Markdown("## **Output**")
        with gr.Row():
            result_images = [
                gr.Image(label="Source Image"),
                gr.Image(label="Image without MasaCtrl"),
                gr.Image(label="Image with MasaCtrl"),
            ]

        # The order must match consistent_synthesis's positional parameters.
        run_button.click(
            fn=consistent_synthesis,
            inputs=[source_box, target_box, cfg_slider, seed_slider, step_slider, layer_slider],
            outputs=result_images,
        )
        gr.Examples(
            examples=example_rows,
            inputs=[source_box, target_box, seed_slider, step_slider, layer_slider],
        )
    return app
if __name__ == "__main__":
    # Build the UI and serve it, also requesting a public Gradio share link.
    create_demo_synthesis().launch(share=True)