BLIP-Diffusion / app_stylization.py
hysts's picture
hysts HF staff
gradio==4.1.1
15b0708
raw
history blame
No virus
4.29 kB
#!/usr/bin/env python
import gradio as gr
import PIL.Image
import spaces
import torch
from controlnet_aux import CannyDetector
from diffusers.pipelines import BlipDiffusionControlNetPipeline
from settings import CACHE_EXAMPLES, DEFAULT_NEGATIVE_PROMPT, MAX_INFERENCE_STEPS
from utils import MAX_SEED, randomize_seed_fn
# Canny edge detector that turns the condition image into a control map.
canny_detector = CannyDetector()

# Use the first CUDA device when one is visible; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# The fp16 BLIP-Diffusion ControlNet pipeline is only loaded on a GPU.
# Without one, `pipe` stays None and no images can be generated.
pipe = (
    BlipDiffusionControlNetPipeline.from_pretrained(
        "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
    ).to(device)
    if device.type == "cuda"
    else None
)
@spaces.GPU
def run(
    condition_image: PIL.Image.Image,
    style_image: PIL.Image.Image,
    condition_subject: str,
    style_subject: str,
    prompt: str,
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    seed: int = 0,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 25,
) -> PIL.Image.Image:
    """Generate a 512x512 image that follows the edges of ``condition_image``
    in the style of ``style_image``.

    The condition image is converted to a Canny edge map, which is passed to
    the BLIP-Diffusion ControlNet pipeline together with the style image and
    the two subject names.

    Args:
        condition_image: Image whose structure (edges) the output follows.
        style_image: Reference image supplying the style.
        condition_subject: Name of the subject to render (e.g. "teapot").
        style_subject: Name of the subject shown in ``style_image``.
        prompt: Extra text prompt (e.g. "on a marble table").
        negative_prompt: Forwarded to the pipeline as ``neg_prompt``.
        seed: Seed for the torch generator, so results are reproducible.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Denoising steps, in ``[1, MAX_INFERENCE_STEPS]``.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If the pipeline is not loaded (no CUDA device at startup),
            an input image is missing, or the step count is out of range.
    """
    # Without CUDA the pipeline is never loaded (`pipe` is None). Fail with a
    # readable message instead of "'NoneType' object is not callable".
    if pipe is None:
        raise gr.Error("This demo requires a GPU: the BLIP-Diffusion pipeline is not loaded.")
    # Clicking Run before uploading both images would otherwise crash inside
    # the edge detector or the pipeline with an opaque error.
    if condition_image is None or style_image is None:
        raise gr.Error("Please provide both a condition image and a style image.")

    # UI sliders and API clients may deliver numbers as floats.
    num_inference_steps = int(num_inference_steps)
    # The slider permits MAX_INFERENCE_STEPS itself, so the upper bound is inclusive.
    if not 1 <= num_inference_steps <= MAX_INFERENCE_STEPS:
        raise gr.Error(f"Number of inference steps must be between 1 and {MAX_INFERENCE_STEPS}")

    # Canny edge map (low/high thresholds 30/70) used as the control image.
    control_image = canny_detector(condition_image, 30, 70, output_type="pil")
    # Positional order follows diffusers' BlipDiffusionControlNetPipeline:
    # prompt, reference image, conditioning image, source subject, target subject.
    return pipe(
        prompt,
        style_image,
        control_image,
        style_subject,
        condition_subject,
        generator=torch.Generator(device=device).manual_seed(int(seed)),
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        neg_prompt=negative_prompt,
        height=512,
        width=512,
    ).images[0]
# Gradio UI. Inputs go in the left column and the result in the right column.
# Layout follows the order in which components are created inside these
# context managers, so statement order here is significant.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            condition_image = gr.Image(label="Condition Image")
            style_image = gr.Image(label="Style Image")
            condition_subject = gr.Textbox(label="Condition Subject")
            style_subject = gr.Textbox(label="Style Subject")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button()
            # Less commonly changed generation parameters, collapsed by default.
            with gr.Accordion(label="Advanced options", open=False):
                negative_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                # Read by randomize_seed_fn in the first step of the event chain below.
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0,
                    maximum=10,
                    step=0.1,
                    value=7.5,
                )
                # The slider maximum matches the upper bound that `run` accepts.
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=MAX_INFERENCE_STEPS,
                    step=1,
                    value=25,
                )
        with gr.Column():
            result = gr.Image(label="Result")

    # Each example fills only the first five inputs. When examples are cached,
    # `run` uses its own defaults for the remaining parameters.
    gr.Examples(
        examples=[
            [
                "images/kettle.jpg",
                "images/flower.jpg",
                "teapot",
                "flower",
                "on a marble table",
            ],
        ],
        inputs=[
            condition_image,
            style_image,
            condition_subject,
            style_subject,
            prompt,
        ],
        outputs=result,
        fn=run,
        cache_examples=CACHE_EXAMPLES,
    )

    # Argument list for `run`, in the same order as its parameters.
    inputs = [
        condition_image,
        style_image,
        condition_subject,
        style_subject,
        prompt,
        negative_prompt,
        seed,
        guidance_scale,
        num_inference_steps,
    ]
    # Pressing Enter in a text box or clicking the button runs a two-step chain.
    # Step 1 updates the seed slider via randomize_seed_fn (from utils;
    # presumably it picks a random seed when the checkbox is ticked, so
    # confirm there). This step is hidden from the API and has no concurrency
    # limit. Step 2 then runs generation, exposed as the "run-stylization" API
    # endpoint. It shares the "gpu" concurrency group, limited to one job at a
    # time.
    gr.on(
        triggers=[
            condition_subject.submit,
            style_subject.submit,
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        api_name=False,
        concurrency_limit=None,
    ).then(
        fn=run,
        inputs=inputs,
        outputs=result,
        api_name="run-stylization",
        concurrency_id="gpu",
        concurrency_limit=1,
    )
if __name__ == "__main__":
    # Enable request queuing (at most 20 waiting requests), then start the server.
    demo.queue(max_size=20)
    demo.launch()