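# Gradio demo: generative fill (inpainting) with BRIA-2.3 and its ControlNet,
# accelerated with an LCM scheduler and the fused BRIA-2.3-FAST-LORA weights.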
import gradio as gr
from diffusers import AutoencoderKL, LCMScheduler
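# pipeline_controlnet_sd_xl.py and controlnet.py are the custom files that ship
# with the briaai/BRIA-2.3-ControlNet-Generative-Fill repo; they are assumed to
# sit next to this script.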
from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from controlnet import ControlNetModel
import torch
import numpy as np
from PIL import Image
from torchvision import transforms

# Utility functions
def resize_image_to_retain_ratio(image):
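    """Resize to roughly one megapixel, keeping the aspect ratio; both sides are rounded down to multiples of 8."""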
    pixel_number = 1024 * 1024
    granularity_val = 8
    ratio = image.width / image.height
    width = int((pixel_number * ratio) ** 0.5)
    width -= width % granularity_val
    height = int(pixel_number / width)
    height -= height % granularity_val
    return image.resize((width, height))

def get_masked_image(image, mask):
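    """Gray out the masked region; returns the fill image, a display copy, and the float mask array."""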
    image = np.array(image).astype(np.float32) / 255.0
    mask = np.array(mask.convert("L")).astype(np.float32) / 255.0
    # Gray-fill the masked pixels (0.5 maps to 0.0 after [-1, 1] normalization)
    image[mask > 0.5] = 0.5
    masked = Image.fromarray((image * 255).astype(np.uint8))
    return masked, masked.copy(), mask

# Load model once
device = "cuda" if torch.cuda.is_available() else "cpu"
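# Note: the float16 weights below assume a CUDA device; the CPU fallback is
# untested here and will be very slow if it runs at all.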
controlnet = ControlNetModel.from_pretrained(
    "briaai/BRIA-2.3-ControlNet-Generative-Fill", torch_dtype=torch.float16
).to(device)
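# The fp16-fix VAE sidesteps the NaN issues the stock SDXL VAE has in float16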
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "briaai/BRIA-2.3",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16
).to(device)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("briaai/BRIA-2.3-FAST-LORA")
pipe.fuse_lora()
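# The LCM scheduler plus the fused FAST LoRA is what makes the low step count
# (12 by default in the UI below) viable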

# Image transforms
image_transforms = transforms.Compose([transforms.ToTensor()])

def inference(init_img, mask_img, prompt, neg_prompt,
              steps, guidance, control_scale, seed):
    # Resize and prepare
    init_img = resize_image_to_retain_ratio(init_img)
    masked_img, vis_img, mask_arr = get_masked_image(init_img, mask_img)
    
    # Encode the masked image; the VAE expects inputs in [-1, 1], while
    # ToTensor produces [0, 1], so rescale before encoding
    tensor = image_transforms(masked_img).unsqueeze(0).to(device)
    tensor = tensor * 2.0 - 1.0
    latents = vae.encode(tensor.to(vae.dtype)).latent_dist.sample() * vae.config.scaling_factor
    
    # Downsample the mask to latent resolution (nearest keeps it binary) and
    # match the latents' dtype so the concatenation below stays in float16
    mask_t = torch.tensor(mask_arr)[None, None, ...].to(device=device, dtype=latents.dtype)
    mask_resized = torch.nn.functional.interpolate(mask_t, size=(latents.shape[2], latents.shape[3]), mode='nearest')
    
    # Control input: 4 masked-image latent channels + 1 mask channel
    control = torch.cat([latents, mask_resized], dim=1)
    
    generator = torch.Generator(device=device).manual_seed(int(seed))
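    # The keyword names below (image=, init_image=, mask_image=) follow the
    # custom pipeline in pipeline_controlnet_sd_xl.py, not stock diffusers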
    output = pipe(
        prompt=prompt,
        negative_prompt=neg_prompt,
        controlnet_conditioning_scale=control_scale,
        num_inference_steps=int(steps),
        guidance_scale=guidance,
        image=control,
        init_image=init_img,
        mask_image=mask_t[:, 0],
        generator=generator,
        height=init_img.height,
        width=init_img.width,
    ).images[0]
    return output

# Build Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## BRIA-2.3 ControlNet Inpainting Demo")
    with gr.Row():
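        # `source=` is the Gradio 3.x API; Gradio 4 renamed it to `sources=[...]`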
        inp = gr.Image(source="upload", type="pil", label="Input Image")
        msk = gr.Image(source="upload", type="pil", label="Mask Image")
    prompt = gr.Textbox(label="Prompt", placeholder="Describe the desired content")
    neg = gr.Textbox(label="Negative Prompt", value="blurry")
    steps = gr.Slider(1, 50, value=12, step=1, label="Inference Steps")
    guidance = gr.Slider(0.0, 10.0, value=1.2, step=0.1, label="Guidance Scale")
    scale = gr.Slider(0.0, 5.0, value=1.0, step=0.1, label="ControlNet Scale")
    seed = gr.Number(label="Seed", value=123456)
    btn = gr.Button("Generate")
    out = gr.Image(type="pil", label="Output")
    btn.click(
        fn=inference,
        inputs=[inp, msk, prompt, neg, steps, guidance, scale, seed],
        outputs=out,
    )
demo.launch(server_name="0.0.0.0", server_port=7860)