File size: 6,207 Bytes
dc9672a
a6e28db
f48deb6
a6e28db
f48deb6
e94cea2
 
5a48378
a6e28db
 
5a48378
1002305
f48deb6
 
 
 
 
e94cea2
1002305
 
 
f48deb6
1002305
 
e94cea2
1002305
 
 
 
d1f2a65
1002305
 
73decb9
 
d1f2a65
1002305
 
b74c847
f48deb6
1002305
a6e28db
e94cea2
 
 
 
 
 
 
 
 
 
 
 
 
 
152de82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e94cea2
 
 
 
 
 
 
 
 
 
 
 
dc9672a
2db0f69
1002305
 
 
 
f48deb6
 
 
 
 
1002305
152de82
 
 
e94cea2
152de82
a6e28db
e94cea2
 
a6e28db
e94cea2
 
 
 
73decb9
 
 
 
e94cea2
 
 
 
73decb9
2db0f69
dc9672a
a6e28db
2db0f69
dc9672a
f48deb6
 
 
 
 
dc9672a
 
f48deb6
 
e94cea2
 
f48deb6
1002305
dc9672a
 
f48deb6
 
e94cea2
 
f48deb6
1002305
 
e94cea2
 
1002305
e94cea2
 
 
1002305
dc9672a
e94cea2
 
 
 
 
 
dc9672a
f48deb6
e94cea2
 
 
 
 
 
 
 
 
 
 
f48deb6
 
e94cea2
dc9672a
 
 
 
 
e94cea2
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import gradio as gr
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from controlnet_aux import CannyDetector
import gc
import numpy as np
from PIL import Image

# Shared Canny edge detector, created once at import time and called by
# generate_image() to turn the preprocessed upload into an edge map.
canny = CannyDetector()

def create_pipeline():
    """Build the ControlNet-guided Stable Diffusion pipeline used by the app.

    Loads the ``lllyasviel/sd-controlnet-canny`` ControlNet and attaches it to
    the ``nitrosocke/Ghibli-Diffusion`` checkpoint with the safety checker
    disabled.

    Returns:
        The configured ``StableDiffusionControlNetPipeline``. With CUDA
        available the weights are loaded in float16 and model CPU offload
        plus attention slicing are enabled; without CUDA the weights are
        loaded in float32 and no memory optimisations are applied.
    """
    has_cuda = torch.cuda.is_available()

    # Half precision is only a safe default on GPU: float16 inference on CPU
    # is slow and not supported by every kernel, so fall back to float32.
    dtype = torch.float16 if has_cuda else torch.float32

    if has_cuda:
        # Collect garbage first so tensors it frees can be returned to the
        # driver by empty_cache().
        gc.collect()
        torch.cuda.empty_cache()

    # Load ControlNet
    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny",
        torch_dtype=dtype,
        use_safetensors=True,
    )

    # Load pipeline
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        "nitrosocke/Ghibli-Diffusion",
        controlnet=controlnet,
        torch_dtype=dtype,
        safety_checker=None,
    )

    if has_cuda:
        # Trade speed for a lower peak GPU memory footprint.
        pipeline.enable_model_cpu_offload()
        pipeline.enable_attention_slicing(1)

    return pipeline

# Build the pipeline once at import time; generate_image() reads this
# module-level instance for every request.
pipe = create_pipeline()

def enhance_prompt(base_prompt):
    """Surround the user's prompt with Ghibli style keywords and quality tags.

    The style keywords come first, then ``base_prompt`` verbatim, then a fixed
    quality suffix; all parts are joined with ``", "``.
    """
    style_prefix = ", ".join((
        "Studio Ghibli masterpiece",
        "hand-painted animation style",
        "Hayao Miyazaki inspired",
        "soft detailed lighting",
        "gentle color palette",
        "delicate line art",
        "atmospheric background",
    ))
    quality_suffix = "high quality, detailed features, smooth lines"

    return f"{style_prefix}, {base_prompt}, {quality_suffix}"

def preprocess_image(image, max_size=512):
    """Fit an image onto a white square canvas, keeping its aspect ratio.

    Args:
        image: A PIL image, or a NumPy array (converted with
            ``Image.fromarray``).
        max_size: Side length in pixels of the square output canvas.
            Defaults to 512.

    Returns:
        A new RGB PIL image of size ``(max_size, max_size)`` with the resized
        input centred on a white background.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Scale so the longer side equals max_size (smaller images are upscaled).
    ratio = max_size / max(image.size)
    # Clamp each side to at least 1px: for very elongated inputs the
    # truncation would otherwise yield 0, which Image.resize rejects.
    new_size = tuple(max(1, int(side * ratio)) for side in image.size)
    image = image.resize(new_size, Image.Resampling.LANCZOS)

    # Centre the resized image on a white square canvas.
    canvas = Image.new("RGB", (max_size, max_size), (255, 255, 255))
    offset = ((max_size - new_size[0]) // 2, (max_size - new_size[1]) // 2)
    canvas.paste(image, offset)

    return canvas

def process_image_for_canny(image):
    """Return the image as an ``(H, W, 3)`` NumPy array for edge detection.

    Args:
        image: A PIL image, a NumPy array, or any other array-like accepted
            by ``np.array``.

    Returns:
        A NumPy array with three channels. Greyscale input (2-D, or 3-D with
        one or two channels) has its first channel replicated three times;
        four-channel input has its last (alpha) channel dropped; anything
        else is returned as-is.
    """
    # Convert PIL images (and other array-likes) to a NumPy array.
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    # Plain greyscale: replicate the single plane into three channels.
    if image.ndim == 2:
        return np.stack([image] * 3, axis=-1)

    if image.ndim == 3:
        channels = image.shape[-1]
        if channels in (1, 2):
            # Grey or grey+alpha: keep the luminance plane only.
            return np.repeat(image[..., :1], 3, axis=-1)
        if channels == 4:
            # RGBA: drop alpha; copy so the result is contiguous.
            return np.ascontiguousarray(image[..., :3])

    return image

def _free_gpu_memory():
    """Release cached CUDA memory, if a CUDA device is available."""
    if torch.cuda.is_available():
        # Collect first so tensors freed by the GC can leave the CUDA cache.
        gc.collect()
        torch.cuda.empty_cache()


def generate_image(input_image, prompt):
    """Generate a Ghibli-styled image guided by the edges of the upload.

    Args:
        input_image: The uploaded image (PIL image or NumPy array), or
            ``None`` when nothing was uploaded.
        prompt: The user's text prompt.

    Returns:
        A tuple ``(output_image, enhanced_prompt)``: the first image produced
        by the pipeline and the style-enhanced prompt that was used.

    Raises:
        gr.Error: If no image was uploaded, the prompt is empty or only
            whitespace, or any step of the generation fails (the original
            exception is chained as the cause).
    """
    # Validate up front so the user-facing messages are raised directly
    # instead of passing through the generic error wrapper below.
    if input_image is None:
        raise gr.Error("Please upload an image")
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Please enter a prompt")

    try:
        _free_gpu_memory()

        # Square 512x512 RGB canvas, then a 3-channel array for the detector.
        preprocessed_image = preprocess_image(input_image)
        processed_image = process_image_for_canny(preprocessed_image)

        # Edge map that conditions the ControlNet.
        canny_image = canny(processed_image, low_threshold=100, high_threshold=200)

        # Enhance prompt with style elements
        enhanced_prompt = enhance_prompt(prompt)

        with torch.inference_mode():
            output_image = pipe(
                prompt=enhanced_prompt,
                image=canny_image,
                num_inference_steps=30,
                guidance_scale=8.5,
                controlnet_conditioning_scale=1.0,
                negative_prompt="blurry, low quality, broken lines, distorted features, asymmetrical",
            ).images[0]

        return output_image, enhanced_prompt

    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(str(e)) from e
    finally:
        _free_gpu_memory()

# Create Gradio interface.
# Layout: a header, then one row with two columns (inputs on the left,
# results on the right), then a static help section. The event handlers at the
# bottom wire the buttons to generate_image() and to a reset lambda.
# NOTE(review): css="style.css" and the elem_id values presumably refer to a
# stylesheet shipped alongside this script — confirm it exists.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("""
    # 🎨 Enhanced Ghibli Art Generator
    Transform your images into the magical style of Studio Ghibli with improved detail and quality
    """)
    
    with gr.Row():
        # Left column: user inputs and action buttons.
        with gr.Column():
            # type="pil" makes Gradio hand generate_image() a PIL image.
            input_image = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="input-image"
            )
            prompt = gr.Textbox(
                label="Enter your prompt",
                placeholder="A peaceful mountain cabin surrounded by nature...",
                elem_id="prompt-input"
            )
            with gr.Row():
                generate_btn = gr.Button("🎨 Generate", variant="primary", elem_id="generate-btn")
                clear_btn = gr.Button("🗑️ Clear", elem_id="clear-btn")
        
        # Right column: generated image and the prompt actually sent to the pipeline.
        with gr.Column():
            output_image = gr.Image(label="Generated Image", elem_id="output-image")
            # Read-only: filled from generate_image()'s second return value.
            used_prompt = gr.Textbox(
                label="Enhanced Prompt",
                elem_id="enhanced-prompt",
                interactive=False
            )
    
    gr.Markdown("""
    ## 🌟 Improved Features
    - Enhanced detail with 30 inference steps
    - Stronger style adherence with 8.5 guidance scale
    - Optimized edge detection
    - Rich Ghibli-style prompt enhancement
    
    ## 💡 Tips
    - Use clear, well-lit images
    - Be specific in your prompts
    - Include mood and atmosphere descriptions
    - Expect 15-20 seconds for generation
    """)
    
    # Set up event handlers
    # Generate: (uploaded image, prompt) -> (generated image, enhanced prompt).
    generate_btn.click(
        fn=generate_image,
        inputs=[input_image, prompt],
        outputs=[output_image, used_prompt]
    )
    
    # Clear: resets only the two output components; the uploaded image and the
    # typed prompt are left as they are.
    clear_btn.click(
        lambda: [None, ""],
        outputs=[output_image, used_prompt]
    )

# Enable request queuing with small limits (max_size=5, concurrency_count=1),
# then start the server locally (no public share link) with debug mode and
# error display turned on.
# NOTE(review): `concurrency_count` is a Gradio 3.x argument; Gradio 4 replaced
# it with `default_concurrency_limit` — confirm the pinned Gradio version.
demo.queue(max_size=5, concurrency_count=1).launch(
    share=False,
    debug=True,
    show_error=True
)