# app.py (Hugging Face Space by AP123)
# Last commit: c8963b4 "Update app.py" (verified), 2.68 kB, virus scan: clean.
import spaces
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, EulerAncestralDiscreteScheduler
import torch
import gradio as gr
from PIL import Image
import numpy as np
# Load the models: the BRIA recoloring ControlNet and the BRIA-2.2 SDXL base
# pipeline, both in fp16 and placed on the GPU explicitly.
controlnet = ControlNetModel.from_pretrained(
    "briaai/BRIA-2.2-ControlNet-Recoloring",
    torch_dtype=torch.float16,
).to("cuda")

# NOTE: no ``device_map`` here. The pipeline is moved with ``.to("cuda")``
# below, and recent diffusers releases reject that combination (pipeline-level
# device_map only accepts "balanced", and ``.to()`` raises once a device map is
# active). ``offload_state_dict`` is a transformers/accelerate loading option
# that the diffusers pipeline loader does not use, so it is dropped as well.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "briaai/BRIA-2.2",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to("cuda")

# Replace the pipeline's default scheduler with Euler-Ancestral configured
# with an explicit scaled-linear beta schedule.
pipe.scheduler = EulerAncestralDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    steps_offset=1,
)

# Do not zero out the negative-prompt embeddings when no negative prompt is
# given. diffusers' SDXL pipelines read this flag from ``pipe.config``; a
# plain attribute assignment does not update the config, so register it there
# too. The attribute assignment is kept for backward compatibility.
pipe.force_zeros_for_empty_prompt = False
pipe.register_to_config(force_zeros_for_empty_prompt=False)
def resize_image(image, size=1024):
    """Return an RGB copy of *image* scaled to fit inside a ``size`` x ``size`` box.

    The aspect ratio is preserved: the longer side becomes ``size`` and the
    shorter side is scaled proportionally. Both sides are rounded to the
    nearest multiple of 8 (minimum 8), because SDXL pipelines reject
    dimensions that are not divisible by 8. Small images are upscaled and
    large images downscaled.

    Args:
        image: A PIL image in any mode.
        size: Target length of the longer side; should be a multiple of 8.
            Defaults to 1024, the resolution this helper has always targeted.

    Returns:
        A new PIL image in RGB mode, resized with Pillow's default filter.

    Raises:
        ValueError: If *image* has a zero width or height.
    """
    image = image.convert("RGB")
    width, height = image.size
    if min(width, height) <= 0:
        raise ValueError("Input image has zero width or height.")
    longest = max(width, height)

    def _snap(dim):
        # Scale proportionally, then round to the nearest multiple of 8.
        return max(8, round(dim * size / longest / 8) * 8)

    return image.resize((_snap(width), _snap(height)))
@spaces.GPU(enable_queue=True)
def generate_image(input_image, prompt, controlnet_conditioning_scale):
    """Recolor *input_image* according to *prompt* with the recoloring ControlNet.

    Args:
        input_image: PIL image from the upload widget, or ``None`` when the
            user clicked the button without uploading anything.
        prompt: Text prompt describing the desired result.
        controlnet_conditioning_scale: Slider value; converted with ``float``.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no image was uploaded; Gradio shows the message in the UI.
    """
    if input_image is None:
        # Fail with a readable UI message instead of an AttributeError deep
        # inside the resizing/conversion code.
        raise gr.Error("Please upload an image first.")

    # Always use a random seed for diversity in outputs.
    seed = np.random.randint(2147483647)
    generator = torch.Generator("cuda").manual_seed(seed)

    # Resize, then strip the colors ("L") and convert back to 3-channel RGB so
    # the ControlNet is conditioned only on the luminance structure.
    input_image = resize_image(input_image)
    grayscale_image = input_image.convert("L").convert("RGB")

    # Generate at the conditioning image's own size, snapped down to a multiple
    # of 8 (SDXL pipelines require it), so the aspect ratio is preserved instead
    # of the control image being stretched to the default square resolution.
    width, height = (max(8, dim - dim % 8) for dim in grayscale_image.size)

    # Generate the image with a fixed 30 steps.
    images = pipe(
        prompt=prompt,
        image=grayscale_image,
        height=height,
        width=width,
        num_inference_steps=30,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        generator=generator,
    ).images
    return images[0]
# ---------------------------------------------------------------------------
# Gradio user interface: an input column (image, prompt, ControlNet scale,
# button) next to an output column holding the transformed image.
# ---------------------------------------------------------------------------
description = (
    "Anything to Anything. Transform anything to anything. "
    "Allow an adjuster for controlnet scale."
)

with gr.Blocks() as demo:
    gr.Markdown(value="<h1><center>Image Transformation with Bria Recolor ControlNet</center></h1>")
    gr.Markdown(value=description)

    with gr.Row():
        # Left column: everything the user provides.
        with gr.Column():
            input_image = gr.Image(type="pil", label='Upload your image')
            prompt = gr.Textbox(label='Enter your prompt')
            controlnet_conditioning_scale = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                step=0.05,
                value=1.0,
                label="ControlNet conditioning scale",
            )
            submit_button = gr.Button('Transform Image')
        # Right column: the generated result.
        with gr.Column():
            output_image = gr.Image(label='Transformed Image')

    # Wire the button: inputs are passed to generate_image in this order.
    submit_button.click(
        generate_image,
        [input_image, prompt, controlnet_conditioning_scale],
        output_image,
    )

# Enable request queuing, then start the server.
demo.queue()
demo.launch()