# Spaces: Running  (page-status residue from the copied listing; not part of the program)
import gradio as gr | |
from gradio_image_annotation import image_annotator | |
from diffusers import EulerDiscreteScheduler | |
import torch | |
import os | |
import random | |
from migc.migc_pipeline import StableDiffusionMIGCPipeline, MIGCProcessor, AttentionStore | |
from migc.migc_utils import seed_everything, load_migc | |
from huggingface_hub import hf_hub_download | |
# Download the model files: the MIGC checkpoint plus one base checkpoint per style.
migc_ckpt_path = hf_hub_download(
    repo_id="limuloo1999/MIGC",
    filename="MIGC_SD14.ckpt",
)
RV_path = hf_hub_download(
    repo_id="SG161222/Realistic_Vision_V6.0_B1_noVAE",
    filename="Realistic_Vision_V6.0_NV_B1.safetensors",
)
anime_path = hf_hub_download(
    repo_id="ckpt/cetus-mix",
    filename="cetusMix_v4.safetensors",
)
# -------- Style switcher --------
class StyleSwitcher:
    """Hold one MIGC-enabled pipeline at a time and swap it by style name."""

    def __init__(self):
        # Currently loaded pipeline, or None when nothing is loaded.
        self.pipe = None
        # One AttentionStore instance, attached to every pipeline this object loads.
        self.attn_store = AttentionStore()
        # Style name -> path of the base checkpoint for that style.
        self.styles = {
            "realistic": RV_path,
            "anime": anime_path,
        }
        # Style name self.pipe was built from, or None.
        self.current_style = None

    def load_model(self, style):
        """Return the pipeline for *style*, loading it if it is not the current one.

        Args:
            style: A key of ``self.styles``.

        Returns:
            The pipeline stored in ``self.pipe`` after the call.

        Raises:
            KeyError: If *style* is not a key of ``self.styles``. The
                previously loaded pipeline is left untouched in that case.
        """
        if style == self.current_style:
            return self.pipe

        # Resolve the checkpoint path first, so an unknown style fails before
        # the working pipeline is thrown away.
        model_path = self.styles[style]

        if self.pipe is not None:
            previous_style = self.current_style
            # Rebind to None rather than `del self.pipe`: deleting the
            # attribute would make every later `self.pipe` access raise
            # AttributeError if the load below fails. Clearing current_style
            # as well keeps the early return above from handing back a
            # pipeline that no longer exists.
            self.pipe = None
            self.current_style = None
            torch.cuda.empty_cache()
            print(f"[Info] Switched from {previous_style} to {style}.")

        print(f"[Info] Loading {style} model...")
        # Build the pipeline in a local and publish it only once it is fully
        # configured, so a failure part-way never leaves a half-set-up
        # pipeline on the instance.
        pipe = StableDiffusionMIGCPipeline.from_single_file(
            model_path,
            torch_dtype=torch.float32,
        )
        pipe.safety_checker = None
        pipe.attention_store = self.attn_store
        load_migc(pipe.unet, self.attn_store, migc_ckpt_path, attn_processor=MIGCProcessor)
        pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

        self.pipe = pipe
        self.current_style = style
        return self.pipe
style_switcher = StyleSwitcher() | |
# New function: return a random seed
def generate_random_seed():
    """Return a random integer in the inclusive range 0 .. 2**32 - 1."""
    upper_bound = 2**32 - 1
    return random.randint(0, upper_bound)
# Generation callback
def get_boxes_json(annotations, seed_value, edit_mode, style_selection):
    """Generate one image from the labelled boxes drawn in the annotator.

    Args:
        annotations: Mapping with an ``"image"`` entry exposing ``.shape``
            (``shape[0]`` is used as height, ``shape[1]`` as width) and a
            ``"boxes"`` list of dicts with ``"xmin"``, ``"ymin"``, ``"xmax"``,
            ``"ymax"`` and ``"label"`` entries.
        seed_value: Seed passed to ``seed_everything``. ``None`` means
            "pick one at random".
        edit_mode: Forwarded to the pipeline as ``use_sa_preserve``.
        style_selection: Style name handed to ``style_switcher.load_model``.

    Returns:
        The first image of the pipeline output.
    """
    # The seed field is labelled optional, so tolerate it being left empty
    # instead of forwarding None to seed_everything.
    if seed_value is None:
        seed_value = random.randint(0, 2**32 - 1)
    seed_everything(seed_value)

    pipe = style_switcher.load_model(style_selection)

    image = annotations["image"]
    width = image.shape[1]
    height = image.shape[0]

    # Normalise box corners by the image size into fresh lists; the caller's
    # box dicts are only read, never modified.
    labels = []
    normalized_boxes = []
    for box in annotations["boxes"]:
        labels.append(box["label"])
        normalized_boxes.append([
            box["xmin"] / width,
            box["ymin"] / height,
            box["xmax"] / width,
            box["ymax"] / height,
        ])

    # First prompt entry is all labels joined; the per-box labels follow it.
    prompt = ", ".join(labels)
    prompt_final = [[prompt, *labels]]
    bboxes = [normalized_boxes]

    negative_prompt = 'worst quality, low quality, bad anatomy, watermark, text, blurry'
    output_image = pipe(
        prompt_final,
        bboxes,
        num_inference_steps=30,
        guidance_scale=7.5,
        MIGCsteps=15,
        aug_phase_with_and=False,
        negative_prompt=negative_prompt,
        sa_preserve=True,
        use_sa_preserve=edit_mode,
    ).images[0]
    return output_image
# Example annotation: background.png from this file's directory, with no boxes.
example_annotation = dict(
    image=os.path.join(os.path.dirname(__file__), "background.png"),
    boxes=[],
)
# ------------- Gradio UI -------------
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation had been flattened; confirm which column the controls belong to.
with gr.Blocks() as demo:
    with gr.Tab("DreamRenderer", id="DreamRenderer"):
        with gr.Row():
            # Annotation canvas.
            with gr.Column(scale=1):
                annotator = image_annotator(
                    example_annotation,
                    height=512,
                    width=512,
                )
            # Result image and generation controls.
            with gr.Column(scale=1):
                generated_image = gr.Image(
                    label="Generated Image",
                    height=512,
                    width=512,
                )
                seed_input = gr.Number(label="Seed (Optional)", precision=0)
                seed_random_btn = gr.Button("🎲 Random Seed")
                edit_mode_toggle = gr.Checkbox(label="Edit Mode")
                style_selector = gr.Radio(
                    choices=["realistic", "anime"],
                    label="风格选择",
                    value="realistic",
                )
                button_get = gr.Button("生成图像")

        # Generate button: run the pipeline and show its image.
        button_get.click(
            fn=get_boxes_json,
            inputs=[annotator, seed_input, edit_mode_toggle, style_selector],
            outputs=generated_image,
        )
        # Random-seed button: write a fresh seed into the seed field.
        seed_random_btn.click(
            fn=generate_random_seed,
            inputs=[],
            outputs=seed_input,
        )

if __name__ == "__main__":
    demo.launch(share=True)