Spaces:
Runtime error
Runtime error
File size: 6,069 Bytes
fd1c028 092fcaa fd1c028 092fcaa fd1c028 420a964 092fcaa fd1c028 092fcaa 420a964 092fcaa d639c7d fd1c028 420a964 fd1c028 092fcaa fd1c028 420a964 d639c7d 420a964 fd1c028 4b1b707 d639c7d 420a964 d639c7d 420a964 d639c7d 420a964 d639c7d 420a964 d639c7d 420a964 d639c7d 420a964 092fcaa 420a964 d639c7d 092fcaa d639c7d 092fcaa fd1c028 420a964 33739c5 47782a8 fc4a6c8 47782a8 fc4a6c8 33739c5 420a964 fd1c028 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import gradio as gr
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers.utils import load_image
from transformers import DPTImageProcessor, DPTForDepthEstimation
import torch
import sa_handler
import pipeline_calls
# Initialize models
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
controlnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0",
variant="fp16",
use_safetensors=True,
torch_dtype=torch.float16,
).to("cuda")
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
controlnet=controlnet,
vae=vae,
variant="fp16",
use_safetensors=True,
torch_dtype=torch.float16,
).to("cuda")
# Configure pipeline for CPU offloading and VAE slicing
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()
# Initialize style-aligned handler
sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
share_layer_norm=False,
share_attention=True,
adain_queries=True,
adain_keys=True,
adain_values=False,
)
handler = sa_handler.Handler(pipeline)
handler.register(sa_args, )
# Function to run ControlNet depth with StyleAligned
def style_aligned_controlnet(ref_style_prompt, depth_map, ref_image, img_generation_prompt):
try:
if depth_map == True:
image = load_image(ref_image)
depth_image = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)
else:
depth_image = load_image(ref_image).resize((1024, 1024))
controlnet_conditioning_scale = 0.8
num_images_per_prompt = 3 # adjust according to VRAM size
latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
images = pipeline_calls.controlnet_call(pipeline, [ref_style_prompt, img_generation_prompt],
image=depth_image,
num_inference_steps=50,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_images_per_prompt=num_images_per_prompt,
latents=latents)
return [images[0], depth_image] + images[1:], gr.Image(value=images[0], visible=True)
except Exception as e:
raise gr.Error(f"Error in generating images:{e}")
# Create a Gradio UI
with gr.Blocks() as demo:
gr.HTML('<h1 style="text-align: center;">Style-aligned with ControlNet Depth</h1>')
with gr.Row():
with gr.Column(variant='panel'):
# Textbox for reference style prompt
ref_style_prompt = gr.Textbox(
label='Reference style prompt',
info="Enter a Prompt to generate the reference image", placeholder='a poster in <style name> style'
)
# Checkbox for using controller depth-map
depth_map = gr.Checkbox(label='Depth-map',)
# Image display for the generated reference style image
ref_style_image = gr.Image(visible=False, label='Reference style image')
with gr.Column(variant='panel'):
# Image upload option for uploading a reference image for controlnet
ref_image = gr.Image(label="Upload the reference image",
type='filepath' )
# Textbox for ControlNet prompt
img_generation_prompt = gr.Textbox(
label='ControlNet Prompt',
info="Enter a Prompt to generate images using ControlNet and Style-aligned",
)
# Button to trigger image generation
btn = gr.Button("Generate", size='sm')
# Gallery to display generated images
gallery = gr.Gallery(label="Style-Aligned ControlNet - Generated images",
elem_id="gallery",
columns=5,
rows=1,
object_fit="contain",
height="auto",
)
btn.click(fn=style_aligned_controlnet,
inputs=[ref_style_prompt, depth_map, ref_image, img_generation_prompt],
outputs=[gallery, ref_style_image],
api_name="style_aligned_controlnet")
# Example inputs for the Gradio interface
gr.Examples(
examples=[
['A poster in a papercut art style.', False, 'example_image/A.png', 'Letter A in a papercut art style.'],
['A couple sitting a wooden bench, in colorful clay animation, claymation style.', True, 'example_image/train.jpg', 'A train in colorful clay animation, claymation style.'],
['A couple sitting a wooden bench, in clay animation, claymation style.', True, 'example_image/sun.png', 'Sun in clay animation, claymation style.'],
['A bull in a low-poly, colorful origami style.', True, 'example_image/whale.png', 'A whale in a low-poly, colorful origami style.'],
['A house in a painterly, digital illustration style.', True, 'example_image/camel.jpg', 'A camel in a painterly, digital illustration style.'],
['An image in ancient egyptian art style, hieroglyphics style.', True, 'example_image/whale.png', 'A whale in ancient egyptian art style, hieroglyphics style.'],
],
inputs=[ref_style_prompt, depth_map, ref_image, img_generation_prompt],
outputs=[gallery, ref_style_image],
fn=style_aligned_controlnet,
)
# Launch the Gradio demo
demo.launch()
|