Gertie01 committed
Commit 8645d6f · verified
1 Parent(s): e08be35

Deploy Gradio app with multiple files

Files changed (3)
  1. app.py +53 -0
  2. models.py +141 -0
  3. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,53 @@
+ import gradio as gr
+ import models
+
+ # Models are loaded and AoT-compiled exactly once, at startup.
+ # In this multi-file setup, models.py calls load_and_compile_models() at
+ # module level, so the `import models` above triggers it before the Gradio
+ # app is launched.
+ # It could instead be called from a gr.Blocks load event, but that fires
+ # per-session; AoT compilation must happen once, during startup.
+
+ with gr.Blocks(css=".container { max-width: 1200px; margin: auto; }") as demo:
+     gr.HTML("""
+     <div style="text-align: center; margin-bottom: 20px;">
+         <h1 style="font-size: 2.5em; color: #333;">🎨 SDXL IP-Adapter Image Remixer</h1>
+         <p style="font-size: 1.1em; color: #555;">Drag up to three reference images, add a text prompt, and let the AI remix them into something new!</p>
+         <p style="font-size: 0.9em; color: #777;">Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank" style="color: #007bff; text-decoration: none;">anycoder</a></p>
+     </div>
+     """)
+
+     with gr.Column(elem_classes="container"):
+         with gr.Row():
+             image_input_1 = gr.Image(label="Reference Image 1 (Optional)", type="pil", height=256, sources=["upload", "clipboard"], interactive=True)
+             image_input_2 = gr.Image(label="Reference Image 2 (Optional)", type="pil", height=256, sources=["upload", "clipboard"], interactive=True)
+             image_input_3 = gr.Image(label="Reference Image 3 (Optional)", type="pil", height=256, sources=["upload", "clipboard"], interactive=True)
+
+         prompt_input = gr.Textbox(
+             label="Prompt",
+             placeholder="A whimsical creature made of clouds and starlight, fantastical, vivid colors, highly detailed, 4k",
+             lines=2,
+             interactive=True,
+         )
+
+         generate_btn = gr.Button("Remix Images", variant="primary")
+
+         output_gallery = gr.Gallery(
+             label="Generated Images",
+             columns=2, rows=1, height=512, object_fit="contain",
+             allow_preview=True,
+             interactive=False,
+         )
+
+         # Event listener for the generate button
+         generate_btn.click(
+             fn=models.remix_images,
+             inputs=[prompt_input, image_input_1, image_input_2, image_input_3],
+             outputs=output_gallery,
+             api_name="remix_images",
+             queue=True,
+             show_progress="full",
+         )
+
+ if __name__ == "__main__":
+     demo.launch(max_threads=10)
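
Editor's note: since the click handler is registered with api_name="remix_images", the deployed Space also exposes a programmatic endpoint. A minimal client-side sketch using gradio_client; the Space ID and the local image path below are placeholders, not values from this commit:

from gradio_client import Client, handle_file

client = Client("Gertie01/sdxl-ip-adapter-remixer")  # placeholder Space ID, replace with the real one
result = client.predict(
    "A whimsical creature made of clouds and starlight",  # prompt
    handle_file("reference1.png"),  # reference image 1 (placeholder local path)
    None,                           # reference image 2 (optional)
    None,                           # reference image 3 (optional)
    api_name="/remix_images",
)
print(result)  # gallery payload: a list of generated image entries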
models.py ADDED
@@ -0,0 +1,141 @@
+ import spaces
+ import torch
+ from diffusers import DiffusionPipeline, AutoencoderKL
+ from ip_adapter import IPAdapter
+ from PIL import Image
+ import gradio as gr
+
+ # --- Configuration Constants ---
+ SDXL_BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
+ IP_ADAPTER_MODEL_ID = "h94/IP-Adapter-Plus-SDXL"
+ IP_ADAPTER_WEIGHT_NAME = "ip-adapter-plus_sdxl_vit-h.bin"
+
+ # --- Global Model Instances ---
+ # These are initialized and compiled once during startup.
+ pipe_global: DiffusionPipeline | None = None
+ ip_adapter_global: IPAdapter | None = None
+
+ @spaces.GPU(duration=1500)  # Allocate generous time for startup compilation
+ def load_and_compile_models():
+     """
+     Loads the SDXL and IP-Adapter models and performs Ahead-of-Time (AoT)
+     compilation of the UNet on ZeroGPU for faster inference.
+     This function is called once during application startup.
+     """
+     global pipe_global, ip_adapter_global
+
+     print("🚀 Starting model loading and compilation...")
+
+     # 1. Load the SDXL base pipeline
+     print(f"Loading SDXL base model: {SDXL_BASE_MODEL_ID}")
+     pipe_global = DiffusionPipeline.from_pretrained(
+         SDXL_BASE_MODEL_ID,
+         torch_dtype=torch.float16,
+         add_watermarker=False,  # Disable watermarking for a small speedup
+         variant="fp16",  # Use the fp16 weights variant where available
+     )
+     # Load the VAE separately, as recommended for the stabilityai SDXL models
+     pipe_global.vae = AutoencoderKL.from_pretrained(
+         "stabilityai/sdxl-vae", torch_dtype=torch.float16, variant="fp16"
+     )
+     pipe_global.to("cuda")
+     print("SDXL base model loaded and moved to CUDA.")
+
+     # 2. Load the IP-Adapter
+     print(f"Loading IP-Adapter from: {IP_ADAPTER_MODEL_ID}/{IP_ADAPTER_WEIGHT_NAME}")
+     ip_adapter_global = IPAdapter(
+         pipe_global,
+         image_encoder_path=IP_ADAPTER_MODEL_ID,
+         ip_ckpt=IP_ADAPTER_WEIGHT_NAME,
+         device="cuda",
+     )
+     print("IP-Adapter loaded and integrated into the pipeline.")
+
+     # 3. Perform AoT compilation of the UNet (the main generation component)
+     print("Starting Ahead-of-Time (AoT) compilation of pipe_global.unet with IP-Adapter...")
+
+     # Prepare dummy inputs to capture the UNet's forward pass. We need a call
+     # that exercises pipe_global.unet with the IP-Adapter embeddings wired in;
+     # ip_adapter_global.generate is designed for exactly that. Minimal steps
+     # keep the tracing run fast.
+     dummy_prompt = "a photorealistic image of a beautiful landscape"
+     dummy_ip_image = Image.new("RGB", (224, 224), color="red")  # IP-Adapter image encoders typically expect 224x224 or 256x256 inputs
+
+     with spaces.aoti_capture(ip_adapter_global.pipe.unet) as call:
+         # Run a minimal generation through the IP-Adapter so that the forward
+         # pass of pipe_global.unet, including all the necessary IP-Adapter
+         # embeddings, is traced by aoti_capture.
+         _ = ip_adapter_global.generate(
+             prompt=dummy_prompt,
+             images=[dummy_ip_image],  # A dummy image traces the IP-Adapter path
+             height=1024, width=1024,
+             num_inference_steps=2,  # Minimal steps for fast tracing
+             guidance_scale=7.5,
+             num_images_per_prompt=1,
+             output_type="pil",
+         ).images[0]
+
+     # Export the captured UNet module
+     print("Exporting UNet...")
+     exported_unet = torch.export.export(
+         ip_adapter_global.pipe.unet,
+         args=call.args,
+         kwargs=call.kwargs,
+     )
+
+     # Compile the exported UNet module
+     print("Compiling UNet...")
+     compiled_unet = spaces.aoti_compile(exported_unet)
+     print("UNet compilation complete.")
+
+     # Swap the compiled module back into the pipeline's UNet
+     spaces.aoti_apply(compiled_unet, ip_adapter_global.pipe.unet)
+     print("AoT compiled UNet applied to the pipeline.")
+     print("✅ Models loaded and compiled successfully!")
+
+ # Load and compile once, when this module is first imported
+ load_and_compile_models()
+
+ @spaces.GPU(duration=60)  # Allocate up to 60 seconds per generation request
+ def remix_images(
+     prompt: str,
+     image1: Image.Image | None,
+     image2: Image.Image | None,
+     image3: Image.Image | None,
+ ) -> list[Image.Image]:
+     """
+     Generates images from a text prompt and up to three reference images using SDXL with IP-Adapter.
+
+     Args:
+         prompt (str): The text prompt for image generation.
+         image1 (PIL.Image.Image | None): The first reference image.
+         image2 (PIL.Image.Image | None): The second reference image.
+         image3 (PIL.Image.Image | None): The third reference image.
+
+     Returns:
+         list[PIL.Image.Image]: A list of generated images.
+     """
+     if not prompt:
+         raise gr.Error("Prompt cannot be empty! Please provide a textual description.")
+
+     # Drop the None entries to get the list of provided reference images
+     input_images = [img for img in [image1, image2, image3] if img is not None]
+
+     print(f"Generating image(s) for prompt: '{prompt}'")
+     print(f"Using {len(input_images)} input image(s) for IP-Adapter.")
+
+     # Call the IP-Adapter's generate method. Its `images` argument may be an
+     # empty list, in which case it falls back to pure text-to-image generation.
+     generated_images = ip_adapter_global.generate(
+         prompt=prompt,
+         images=input_images,  # May be empty
+         height=1024, width=1024,
+         num_inference_steps=30,  # Standard number of inference steps
+         guidance_scale=7.5,  # Classifier-free guidance scale
+         num_images_per_prompt=1,  # One image per request
+         output_type="pil",  # Return PIL Image objects
+         # No seed is set; outputs are intentionally non-deterministic
+     ).images
+
+     return generated_images
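
Editor's note: because load_and_compile_models() runs at import time, models.py can be smoke-tested without the Gradio UI. A minimal sketch, assuming a CUDA/ZeroGPU environment where the import (and hence the compilation) succeeds; the prompt and output filename are arbitrary:

import models  # triggers load_and_compile_models() on import
from PIL import Image

ref = Image.new("RGB", (1024, 1024), color="gray")  # stand-in reference image
outputs = models.remix_images("a castle floating in the clouds", ref, None, None)
outputs[0].save("remix_smoke_test.png")
print(f"Generated {len(outputs)} image(s).")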
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ gradio
+ torch
+ git+https://github.com/huggingface/diffusers
+ git+https://github.com/huggingface/transformers
+ accelerate
+ Pillow
+ safetensors
+ xformers
+ spaces
+ ip-adapter