# Set a headless Matplotlib backend BEFORE importing pyplot or any library that pulls it in.
import matplotlib
matplotlib.use('Agg')

import gradio as gr
import torch
from diffusers import EulerAncestralDiscreteScheduler
from DoodlePix_pipeline import StableDiffusionInstructPix2PixPipeline
from PIL import Image, ImageOps  # ImageOps is used to invert the sketch
import numpy as np
import os
import importlib
import traceback # For detailed error printing

# --- FidelityMLP: maps a scalar fidelity value to a text-embedding-sized vector ---
class FidelityMLP(torch.nn.Module):
    def __init__(self, hidden_size, output_size=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size or hidden_size
        self.net = torch.nn.Sequential(
            torch.nn.Linear(1, 128), torch.nn.LayerNorm(128), torch.nn.SiLU(),
            torch.nn.Linear(128, 256), torch.nn.LayerNorm(256), torch.nn.SiLU(),
            torch.nn.Linear(256, hidden_size), torch.nn.LayerNorm(hidden_size), torch.nn.Tanh()
        )
        self.output_proj = torch.nn.Linear(hidden_size, self.output_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None: module.bias.data.zero_()

    def forward(self, x, target_dim=None):
        features = self.net(x)
        outputs = self.output_proj(features)
        if target_dim is not None and target_dim != self.output_size:
            return self._adjust_dimension(outputs, target_dim)
        return outputs

    def _adjust_dimension(self, embeddings, target_dim):
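        # Zero-pad up to target_dim or truncate down, e.g. (1, 768) -> (1, 1024) or (1, 512).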
        current_dim = embeddings.shape[-1]
        if target_dim > current_dim:
            pad_size = target_dim - current_dim
            padding = torch.zeros((*embeddings.shape[:-1], pad_size), device=embeddings.device, dtype=embeddings.dtype)
            return torch.cat([embeddings, padding], dim=-1)
        elif target_dim < current_dim:
            return embeddings[..., :target_dim]
        return embeddings

    def save_pretrained(self, save_directory):
        os.makedirs(save_directory, exist_ok=True)
        config = {"hidden_size": self.hidden_size, "output_size": self.output_size}
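        # Note: despite the .json filename, the config is a torch-pickled dict;
        # from_pretrained below loads it with torch.load accordingly.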
        torch.save(config, os.path.join(save_directory, "config.json"))
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))

    @classmethod
    def from_pretrained(cls, pretrained_model_path):
        config_file = os.path.join(pretrained_model_path, "config.json")
        model_file = os.path.join(pretrained_model_path, "pytorch_model.bin")
        if not os.path.exists(config_file): raise FileNotFoundError(f"Config file not found at {config_file}")
        if not os.path.exists(model_file): raise FileNotFoundError(f"Model file not found at {model_file}")
        try:
            config = torch.load(config_file, map_location=torch.device('cpu'))
            if not isinstance(config, dict): raise TypeError(f"Expected config dict, got {type(config)}")
        except Exception as e: print(f"Error loading config {config_file}: {e}"); raise
        model = cls(hidden_size=config["hidden_size"], output_size=config.get("output_size", config["hidden_size"]))
        try:
            state_dict = torch.load(model_file, map_location=torch.device('cpu'))
            model.load_state_dict(state_dict)
            print(f"Successfully loaded FidelityMLP state dict from {model_file}")
        except Exception as e: print(f"Error loading state dict {model_file}: {e}"); raise
        return model
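
# --- Illustrative sketch (not called by the app): how FidelityMLP maps a scalar
# fidelity value to an embedding. hidden_size=768 is an assumption (SD 1.x
# text-encoder width); the real value comes from the checkpoint's fidelity_mlp config.
def _fidelity_mlp_example():
    mlp = FidelityMLP(hidden_size=768)
    fidelity = torch.tensor([[0.4]])     # e.g. slider value 4 scaled to [0, 1] (assumption)
    emb = mlp(fidelity, target_dim=768)  # -> torch.Size([1, 768])
    return emb.shape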

# --- Global Variables ---
pipeline = None
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "Scaryplasmon96/DoodlePixV1"

# --- Model Loading Function ---
def load_pipeline():
    global pipeline
    if pipeline is not None: return True
    print(f"Loading model {model_id} onto {device}...")
    try:
        hf_cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
        local_model_path = model_id # Let diffusers find/download

        # Load Fidelity MLP if possible
        fidelity_mlp_instance = None
        try:
            from huggingface_hub import snapshot_download, hf_hub_download
            # Attempt to download config first to check existence
            hf_hub_download(repo_id=model_id, filename="fidelity_mlp/config.json", cache_dir=hf_cache_dir)
            # If config exists, download the whole subfolder
            fidelity_mlp_path = snapshot_download(repo_id=model_id, allow_patterns="fidelity_mlp/*", local_dir_use_symlinks=False, cache_dir=hf_cache_dir)
            fidelity_mlp_instance = FidelityMLP.from_pretrained(os.path.join(fidelity_mlp_path, "fidelity_mlp"))
            fidelity_mlp_instance = fidelity_mlp_instance.to(device=device, dtype=torch.float16)
            print("Fidelity MLP loaded successfully.")
        except Exception as e:
            print(f"Fidelity MLP not found or failed to load for {model_id}: {e}. Proceeding without MLP.")
            fidelity_mlp_instance = None

        scheduler = EulerAncestralDiscreteScheduler.from_pretrained(local_model_path, subfolder="scheduler")
        pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            local_model_path, torch_dtype=torch.float16, scheduler=scheduler, safety_checker=None
        ).to(device)

        if fidelity_mlp_instance:
            pipeline.fidelity_mlp = fidelity_mlp_instance
            print("Attached Fidelity MLP to pipeline.")

        # Optimizations
        if device == "cuda" and hasattr(pipeline, "enable_xformers_memory_efficient_attention"):
            try: pipeline.enable_xformers_memory_efficient_attention(); print("Enabled xformers.")
            except Exception: print("Could not enable xformers. Using attention slicing."); pipeline.enable_attention_slicing()
        else: pipeline.enable_attention_slicing(); print("Enabled attention slicing.")

        print("Pipeline loaded successfully.")
        return True
    except Exception as e:
        print(f"Error loading pipeline: {e}"); traceback.print_exc()
        pipeline = None; raise gr.Error(f"Failed to load model: {e}")

# --- Image Generation Function ---
def generate_image(drawing_input, prompt, fidelity_slider, steps, guidance, image_guidance, seed_val):
    global pipeline
    if pipeline is None:
        if not load_pipeline(): return None, "Model not loaded. Check logs."

    # --- Input Processing: Sketchpad returns a dict; use the 'composite' PIL image ---
    print(f"DEBUG: Received drawing_input type: {type(drawing_input)}")
    if isinstance(drawing_input, dict): print(f"DEBUG: Received drawing_input keys: {drawing_input.keys()}")

    # Check if input is dict and get PIL image from 'composite' key
    if isinstance(drawing_input, dict) and "composite" in drawing_input and isinstance(drawing_input["composite"], Image.Image):
        input_image_pil = drawing_input["composite"].convert("RGB") # Get composite image
        print("DEBUG: Using PIL Image from 'composite' key.")
    else:
        err_msg = "Drawing input format unexpected. Expected dict with PIL Image under 'composite' key."
        print(f"ERROR: {err_msg} Input: {drawing_input}")
        return None, err_msg
    # --- End Input Processing ---

    try:
        # Invert the image: White bg -> Black bg, Black lines -> White lines
        input_image_inverted = ImageOps.invert(input_image_pil)
        # Optionally save the inverted image for debugging:
        # input_image_inverted.save("input_image_inverted.png")

        # Ensure image is 512x512
        if input_image_inverted.size != (512, 512):
            print(f"Resizing input image from {input_image_inverted.size} to (512, 512)")
            input_image_inverted = input_image_inverted.resize((512, 512), Image.Resampling.LANCZOS)

        # Prompt construction: prepend the fidelity token ("f0".."f9") the model expects
        final_prompt = f"f{int(fidelity_slider)}, {prompt}"
        if not final_prompt.endswith("background."): final_prompt += " background."
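        # e.g. fidelity=4, prompt="a red apple" -> "f4, a red apple background."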

        negative_prompt = "artifacts, blur, jpg, uncanny, deformed, glow, shadow, text, words, letters, signature, watermark"

        # Generation
        print(f"Generating with: Prompt='{final_prompt[:100]}...', Fidelity={int(fidelity_slider)}, Steps={steps}, Guidance={guidance}, ImageGuidance={image_guidance}, Seed={seed_val}")
        seed_val = int(seed_val)
        generator = torch.Generator(device=device).manual_seed(seed_val)

        with torch.no_grad():
             output = pipeline(
                 prompt=final_prompt, negative_prompt=negative_prompt, image=input_image_inverted,
                 num_inference_steps=int(steps), guidance_scale=float(guidance),
                 image_guidance_scale=float(image_guidance), generator=generator,
             ).images[0]

        print("Generation complete.")
        return output, "Generation Complete"

    except Exception as e:
        print(f"Error during generation: {e}"); traceback.print_exc()
        return None, f"Error during generation: {str(e)}"

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", secondary_hue="blue")) as demo:
    gr.Markdown("# DoodlePix Gradio App")
    gr.Markdown(f"Using model: `{model_id}`.")
    status_output = gr.Textbox(label="Status", interactive=False, value="App loading...")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 1. Draw Something (Black on White)")
            # Keep type="pil" as it provides the composite key
            drawing = gr.Sketchpad(
                label="Drawing Canvas",
                type="pil", # type="pil" gives dict output with 'composite' key
                height=512, width=512,
                brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=5),
                show_label=True
            )
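            # Note: the dict-with-'composite' return value matches Gradio 4.x Sketchpad
            # behavior; older Gradio versions may return a plain image instead.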
            prompt_input = gr.Textbox(label="2. Enter Prompt", placeholder="Describe the image you want...")
            fidelity = gr.Slider(0, 9, step=1, value=4, label="Fidelity (0=Creative, 9=Faithful)")
            num_steps = gr.Slider(10, 50, step=1, value=25, label="Inference Steps")
            guidance_scale = gr.Slider(1.0, 15.0, step=0.5, value=7.5, label="Guidance Scale (CFG)")
            image_guidance_scale = gr.Slider(0.5, 5.0, step=0.1, value=1.5, label="Image Guidance Scale")
            seed = gr.Number(label="Seed", value=42, precision=0)
            generate_button = gr.Button("🚀 Generate Image!", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown("## 3. Generated Image")
            output_image = gr.Image(label="Result", type="pil", height=512, width=512, show_label=True)

    generate_button.click(
        fn=generate_image,
        inputs=[drawing, prompt_input, fidelity, num_steps, guidance_scale, image_guidance_scale, seed],
        outputs=[output_image, status_output]
    )

# --- Launch App ---
if __name__ == "__main__":
    initial_status = "App loading..."
    print("Attempting to pre-load pipeline...")
    try:
        if load_pipeline(): initial_status = "Model pre-loaded successfully."
        else: initial_status = "Model pre-loading failed. Will retry on first generation."
    except Exception as e:
        print(f"Pre-loading failed: {e}")
        initial_status = f"Model pre-loading failed: {e}. Will retry on first generation."
    print(f"Pre-loading status: {initial_status}")

    demo.launch()