Deadmon's picture
Update app.py
7f15638 verified
raw
history blame
3.9 kB
import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
from einops import rearrange
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.utils import load_image
from gradio_imageslider import ImageSlider # Import ImageSlider
# Load the FLUX.1-dev base pipeline together with the InstantX "Union"
# ControlNet (both given as Hugging Face repo ids). This runs once at import
# time, and generate_image() reuses the resulting module-level `pipe`.
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Union'
# NOTE(review): CUDA is hard-coded with no CPU fallback; pipe.to(device)
# below presumably fails on a machine without a GPU.
device = torch.device("cuda")
# Both models are loaded in bfloat16 to halve memory versus float32.
controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to(device)
def preprocess_image(image, target_width, target_height, crop=True):
    """Resize a PIL image to exactly ``target_width`` x ``target_height``.

    Args:
        image: Source ``PIL.Image.Image``.
        target_width: Output width in pixels.
        target_height: Output height in pixels.
        crop: When True, take a centered square crop, scale it so it covers
            the target in both dimensions, then center-crop to the target
            (no stretching). When False, resize straight to the target size,
            which stretches the image if the aspect ratios differ.

    Returns:
        A new ``PIL.Image.Image`` of size ``(target_width, target_height)``.
    """
    if not crop:
        return image.resize((target_width, target_height), Image.LANCZOS)

    # Centered square crop, so the middle of the frame is kept and the
    # discarded margin is split evenly between both sides.
    width, height = image.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    image = image.crop((left, top, left + side, top + side))

    # A square covers a (target_width x target_height) box exactly when its
    # side equals the larger target dimension. Compute that in integers: a
    # float round trip such as int(target / side * side) can truncate to
    # target - 1 (e.g. 128 / 49 * 49), and the crop below would then run one
    # pixel past the edge, which PIL pads with black.
    cover = max(target_width, target_height)
    image = image.resize((cover, cover), Image.LANCZOS)

    # Center crop to the exact target dimensions.
    left = (cover - target_width) // 2
    top = (cover - target_height) // 2
    return image.crop((left, top, left + target_width, top + target_height))
def preprocess_canny_image(image, target_width, target_height, crop=True,
                           low_threshold=100, high_threshold=200):
    """Resize ``image`` to the target size and convert it to a Canny edge map.

    Args:
        image: Source ``PIL.Image.Image``.
        target_width: Output width in pixels.
        target_height: Output height in pixels.
        crop: Forwarded to ``preprocess_image`` (crop-to-fit vs. stretch).
        low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
        high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.

    Returns:
        A ``PIL.Image.Image`` in RGB mode holding the edge map
        (white edges on black).
    """
    # The module's top-level import block never imports OpenCV, so the
    # original call to cv2.Canny raised NameError. Import it where it is used.
    import cv2

    image = preprocess_image(image, target_width, target_height, crop=crop)
    gray = np.array(image.convert('L'))  # Canny operates on a single channel
    edges = cv2.Canny(gray, low_threshold, high_threshold)
    # Replicate to 3 channels. diffusers' ControlNet canny examples feed RGB
    # edge maps, and the control image presumably goes through the RGB VAE.
    # TODO confirm whether single-channel input is accepted.
    return Image.fromarray(edges).convert("RGB")
def generate_image(prompt, control_image, num_steps=24, guidance=3.5, width=512, height=512, seed=42, random_seed=False, control_mode=0):
    """Generate an image from ``prompt``, conditioned on a Canny map of ``control_image``.

    Args:
        prompt: Text prompt for the pipeline.
        control_image: ``PIL.Image.Image`` supplied by the user. It is
            resized to ``width`` x ``height`` and converted to a Canny edge map.
        num_steps: Passed to the pipeline as ``num_inference_steps``.
        guidance: Passed to the pipeline as ``guidance_scale``.
        width: Output width in pixels.
        height: Output height in pixels.
        seed: RNG seed. It is ignored when ``random_seed`` is True.
        random_seed: When True, draw a fresh seed from ``[0, 10000)``.
        control_mode: Passed through to the pipeline as ``control_mode``.
            NOTE(review): the control image is always a Canny map,
            whatever mode is selected.

    Returns:
        ``[control_image, image]``: the edge map and the generated image,
        for the before/after slider.

    Raises:
        gr.Error: If no control image was provided.
    """
    if control_image is None:
        raise gr.Error("Please upload a control image.")

    # Gradio number/slider components can deliver floats. PIL resize and
    # torch.Generator.manual_seed require ints.
    num_steps = int(num_steps)
    width = int(width)
    height = int(height)
    if random_seed:
        seed = np.random.randint(0, 10000)
    seed = int(seed)

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs("./controlnet_results/", exist_ok=True)

    torch.manual_seed(seed)
    # A request-local generator makes the sampled latents depend only on
    # `seed`, not on whatever else has consumed the process-wide RNG.
    generator = torch.Generator(device=device).manual_seed(seed)

    control_image = preprocess_canny_image(control_image, width, height)
    controlnet_conditioning_scale = 0.5  # ControlNet conditioning strength

    image = pipe(
        prompt,
        control_image=control_image,
        control_mode=control_mode,
        width=width,
        height=height,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_steps,
        guidance_scale=guidance,
        generator=generator,
    ).images[0]
    return [control_image, image]
# Gradio UI. The inputs map positionally onto generate_image's parameters
# (prompt, control_image, num_steps, guidance, width, height, seed,
# random_seed, control_mode). The output slider shows the edge map next to
# the generated image.
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Image(type="pil", label="Control Image"),
        gr.Slider(step=1, minimum=1, maximum=64, value=24, label="Num Steps"),
        gr.Slider(minimum=0.1, maximum=10, value=3.5, label="Guidance"),
        gr.Slider(minimum=128, maximum=2048, step=128, value=1024, label="Width"),
        gr.Slider(minimum=128, maximum=2048, step=128, value=1024, label="Height"),
        # precision=0 makes Gradio round the value and pass an int, not a float.
        gr.Number(value=42, precision=0, label="Seed"),
        gr.Checkbox(label="Random Seed"),
        gr.Radio(choices=[0, 1, 2, 3, 4, 5, 6], value=0, label="Control Mode"),
    ],
    outputs=ImageSlider(label="Before / After"),
    title="FLUX.1 Controlnet Canny",
    description="Generate images using ControlNet and a text prompt.\n[[non-commercial license, Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)]",
)

if __name__ == "__main__":
    interface.launch()