# Source: ciaochaos, commit 5416a13 ("inference logic"), 2.61 kB.
from PIL import Image
import gradio as gr
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
import torch
# Load the brightness ControlNet and attach it to the Stable Diffusion v1.5 base
# pipeline; both are loaded in float32.
controlnet = ControlNetModel.from_pretrained(
    "ioclab/control_v1p_sd15_brightness",
    torch_dtype=torch.float32,
    use_safetensors=True,
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float32,
)
# Replace the default scheduler with UniPC, reusing the existing scheduler config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Optional speed-up when xformers is installed:
#     pipe.enable_xformers_memory_efficient_attention()

# Model CPU offload moves each sub-model onto the GPU when it runs, so only
# enable it when CUDA is present; on CPU-only hosts the pipeline stays on CPU.
if torch.cuda.is_available():
    pipe.enable_model_cpu_offload()
def infer(prompt, negative_prompt, num_inference_steps, conditioning_image, seed=1500):
    """Generate one 512x512 image guided by a brightness conditioning image.

    Args:
        prompt: Text prompt describing the desired image.
        negative_prompt: Text describing what the generation should avoid.
        num_inference_steps: Number of denoising steps. It is coerced to
            ``int`` because UI slider values may arrive as floats.
        conditioning_image: Array-like image (e.g. the numpy array produced by
            a default ``gr.Image`` input). It is converted to a PIL image
            before being handed to the pipeline.
        seed: Seed for the CPU random generator. The default (1500) matches
            the previously hard-coded seed, so results stay reproducible.

    Returns:
        The first ``PIL.Image`` from the pipeline output.

    Raises:
        gr.Error: If no conditioning image was provided.
    """
    if conditioning_image is None:
        # Show a readable message in the UI instead of a traceback from fromarray.
        raise gr.Error("Please provide a conditioning image.")
    control_image = Image.fromarray(conditioning_image)
    # A fixed-seed generator makes repeated runs with the same inputs reproducible.
    generator = torch.Generator(device="cpu").manual_seed(seed)
    output_image = pipe(
        prompt,
        image=control_image,
        height=512,
        width=512,
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        negative_prompt=negative_prompt,
        controlnet_conditioning_scale=1.0,
    ).images[0]
    return output_image
# Build the Gradio UI: input controls on the left, the generated image on the right.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# ControlNet on Brightness
This is a demo on ControlNet based on brightness.
""")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
            )
            # Positional args are (minimum, maximum, initial value).
            num_inference_steps = gr.Slider(
                10, 40, 20,
                step=1,
                label="Steps",
            )
            conditioning_image = gr.Image(
                label="Conditioning Image",
            )
            submit_btn = gr.Button(
                value="Submit",
                variant="primary",
            )
        with gr.Column(min_width=300):
            output = gr.Image(
                label="Result",
            )
    # The inputs are passed to infer() positionally, in this order.
    submit_btn.click(
        fn=infer,
        inputs=[
            prompt, negative_prompt, num_inference_steps, conditioning_image,
        ],
        outputs=output,
    )
    # Each example row fills only the prompt, negative prompt and conditioning image.
    gr.Examples(
        examples=[
            ["a painting of a village in the mountains", "monochrome", "./conditioning_images/conditioning_image_1.jpg"],
            ["three people walking in an alleyway with hats and pants", "monochrome", "./conditioning_images/conditioning_image_2.jpg"],
        ],
        inputs=[
            prompt, negative_prompt, conditioning_image,
        ],
    )

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    demo.launch()