# LayoutPainter / app.py
# (apparently page-header residue kept from the original listing:
#  "HBDing's picture", commit message "upd", commit 7cfa686)
import gradio as gr
from gradio_image_annotation import image_annotator
from diffusers import EulerDiscreteScheduler
import torch
import os
import random
from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore
from migc.migc_utils import seed_everything, load_migc
from huggingface_hub import hf_hub_download
# Download the model files: the MIGC checkpoint first, then the two base
# checkpoints ("realistic" and "anime"), in that order.
migc_ckpt_path, RV_path, anime_path = (
    hf_hub_download(repo_id=repo, filename=checkpoint_name)
    for repo, checkpoint_name in (
        ("limuloo1999/MIGC", "MIGC_SD14.ckpt"),
        ("SG161222/Realistic_Vision_V6.0_B1_noVAE", "Realistic_Vision_V6.0_NV_B1.safetensors"),
        ("ckpt/cetus-mix", "cetusMix_v4.safetensors"),
    )
)
# -------- Style switcher class --------
class StyleSwitcher:
    """Hold at most one loaded MIGC pipeline and swap it when the style changes."""

    def __init__(self):
        # Currently loaded pipeline, or None while nothing is loaded.
        self.pipe = None
        # Attention store attached to every pipeline this switcher loads.
        self.attn_store = AttentionStore()
        # Style name -> local path of the base checkpoint for that style.
        self.styles = {
            "realistic": RV_path,
            "anime": anime_path,
        }
        # Style that `self.pipe` was loaded for, or None.
        self.current_style = None

    def load_model(self, style):
        """Return the pipeline for ``style``, loading it unless it is already current.

        Args:
            style: A key of ``self.styles``.

        Returns:
            The pipeline for ``style``. When ``style`` equals
            ``self.current_style`` the stored ``self.pipe`` is returned as is.

        Raises:
            KeyError: If ``style`` is not a key of ``self.styles``. The
                pipeline that is currently loaded is left untouched.
        """
        if style == self.current_style:
            return self.pipe

        # Resolve the checkpoint BEFORE unloading anything, so that an unknown
        # style cannot cost us the pipeline that is already working.
        model_path = self.styles[style]

        if self.pipe is not None:
            previous_style = self.current_style
            # Reset the bookkeeping first: if the load below fails, the
            # switcher must report "nothing loaded" instead of keeping a stale
            # `current_style` (or, as with `del self.pipe`, no `pipe`
            # attribute at all, which broke every later call).
            self.pipe = None
            self.current_style = None
            torch.cuda.empty_cache()
            print(f"[Info] Switched from {previous_style} to {style}.")

        print(f"[Info] Loading {style} model...")
        pipe = StableDiffusionMIGCPipeline.from_single_file(
            model_path,
            torch_dtype=torch.float32,
        )
        pipe.safety_checker = None
        pipe.attention_store = self.attn_store
        load_migc(pipe.unet, self.attn_store, migc_ckpt_path, attn_processor=MIGCProcessor)
        pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

        # Publish the new pipeline only once it is fully set up.
        self.pipe = pipe
        self.current_style = style
        return pipe
# Module-level switcher used by the generation callback; constructing it does
# not load a model (that happens on the first `load_model` call).
style_switcher = StyleSwitcher()
# ⬇️ Added function: return a random seed
def generate_random_seed() -> int:
    """Return a random integer seed in the inclusive range [0, 2**32 - 1]."""
    return random.randint(0, 2**32 - 1)
# Generation function
def _collect_labeled_boxes(boxes, width, height):
    """Return (labels, boxes) with each box as [xmin, ymin, xmax, ymax] divided by the image size."""
    labels = []
    normalized = []
    for box in boxes:
        label = box.get("label")
        if not label:
            # A box without a label carries no text to render; skip it
            # instead of failing with KeyError on box["label"].
            continue
        labels.append(label)
        # Build fresh values rather than `/=` on the dict, so the caller's
        # annotation data is never modified.
        normalized.append([
            box["xmin"] / width,
            box["ymin"] / height,
            box["xmax"] / width,
            box["ymax"] / height,
        ])
    return labels, normalized


def get_boxes_json(annotations, seed_value, edit_mode, style_selection):
    """Generate an image from the labeled boxes drawn in the annotator.

    Args:
        annotations: Value of the annotator component; read as a mapping with
            an "image" array (height/width taken from ``shape[0]``/``shape[1]``)
            and a "boxes" list of dicts holding "xmin", "ymin", "xmax", "ymax"
            and, optionally, "label".
        seed_value: Seed passed to ``seed_everything``. ``None`` (an empty
            seed field) selects a random seed in [0, 2**32 - 1].
        edit_mode: Forwarded to the pipeline as ``use_sa_preserve``.
        style_selection: Style name handed to ``style_switcher.load_model``.

    Returns:
        The first image of the pipeline result (``.images[0]``).

    Raises:
        ValueError: If ``annotations`` is empty or contains no image.
    """
    # Validate before seeding or loading a model, so a bad request is cheap.
    image = annotations.get("image") if annotations else None
    if image is None:
        raise ValueError("An annotated image is required to generate a result.")

    if seed_value is None:
        # The seed field is optional; fall back to a random seed.
        seed_value = random.randint(0, 2**32 - 1)
    seed_everything(int(seed_value))

    pipe = style_switcher.load_model(style_selection)

    height, width = image.shape[0], image.shape[1]
    labels, normalized = _collect_labeled_boxes(annotations.get("boxes") or [], width, height)

    # The pipeline receives one batch entry: the joined prompt followed by the
    # individual labels, and the matching list of boxes.
    prompt_final = [[", ".join(labels), *labels]]
    bboxes = [normalized]

    negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry'
    result = pipe(
        prompt_final,
        bboxes,
        num_inference_steps=30,
        guidance_scale=7.5,
        MIGCsteps=15,
        aug_phase_with_and=False,
        negative_prompt=negative_prompt,
        sa_preserve=True,
        use_sa_preserve=edit_mode,
    )
    return result.images[0]
# Example annotation: the background.png that sits next to this file, with no
# boxes drawn yet.
example_annotation = dict(
    image=os.path.join(os.path.dirname(__file__), "background.png"),
    boxes=[],
)
# ------------- Gradio UI -------------
# NOTE(review): this copy of the file has lost its indentation, so the nesting
# of the components under the Tab/Row/Column contexts below cannot be read off
# the text - restore it from the original layout before running. TODO confirm.
with gr.Blocks() as demo:
with gr.Tab("DreamRenderer", id="DreamRenderer"):
with gr.Row():
with gr.Column(scale=1):
# Annotator pre-loaded with `example_annotation`; its value is the first
# input of the generation callback.
annotator = image_annotator(example_annotation, height=512, width=512)
with gr.Column(scale=1):
# Output slot for the image returned by `get_boxes_json`.
generated_image = gr.Image(label="Generated Image", height=512, width=512)
# Integer seed field (precision=0); filled by the random-seed button below.
seed_input = gr.Number(label="Seed (Optional)", precision=0)
seed_random_btn = gr.Button("🎲 Random Seed")
# Passed to `get_boxes_json` as `edit_mode`.
edit_mode_toggle = gr.Checkbox(label="Edit Mode")
# Style choice (label reads "Style selection"); the choices match the keys of
# `StyleSwitcher.styles`.
style_selector = gr.Radio(choices=["realistic", "anime"], label="风格选择", value="realistic")
# Generate button (label reads "Generate image").
button_get = gr.Button("生成图像")
# Clicking it runs `get_boxes_json` on (annotation, seed, edit-mode flag,
# style) and shows the returned image in `generated_image`.
button_get.click(
fn=get_boxes_json,
inputs=[annotator, seed_input, edit_mode_toggle, style_selector],
outputs=generated_image
)
# The random-seed button writes the result of `generate_random_seed` into the
# seed field.
seed_random_btn.click(fn=generate_random_seed, inputs=[], outputs=seed_input)
# Launch the app, with a public share link, only when this file is run as a
# script (not when it is imported).
if __name__ == "__main__":
    demo.launch(share=True)