File size: 4,574 Bytes
bd199cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83e18ee
bd199cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83e18ee
 
 
 
 
 
 
 
 
e0fc0eb
 
83e18ee
 
 
 
 
 
 
 
 
 
 
bd199cf
 
83e18ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import torch
from torchvision import transforms
from SDXL.diff_pipe import StableDiffusionXLDiffImg2ImgPipeline
from diffusers import DPMSolverMultistepScheduler

# Number of scheduler steps passed as `num_inference_steps` to both the base
# and the refiner pipeline calls in `inference`.
NUM_INFERENCE_STEPS = 50
# Target device for both pipelines and for the preprocessed input tensors.
device = "cuda"

# Base pipeline: SDXL base 1.0 weights loaded in half precision (fp16 variant).
base = StableDiffusionXLDiffImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to(device)

# Refiner pipeline: SDXL refiner 1.0 weights; it is handed the base pipeline's
# second text encoder and VAE objects instead of loading its own.
refiner = StableDiffusionXLDiffImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to(device)

# Replace the schedulers that came with the checkpoints by DPM-Solver
# multistep schedulers built from a scheduler config.
base.scheduler = DPMSolverMultistepScheduler.from_config(base.scheduler.config)
# NOTE(review): the refiner's scheduler is built from `base.scheduler.config`
# (which at this point is already the DPM-Solver instance assigned above), not
# from `refiner.scheduler.config` -- confirm this is intentional.
refiner.scheduler = DPMSolverMultistepScheduler.from_config(base.scheduler.config)


def preprocess_image(image):
    """Turn a PIL image into the batched tensor fed to the pipelines.

    The image is converted to RGB, center-cropped so that both sides are
    multiples of 64 pixels, converted to a tensor with ``ToTensor`` and then
    mapped through ``x * 2 - 1``. A leading batch dimension is added and the
    result is moved to the module-level ``device``.

    Args:
        image: A PIL image (anything exposing ``convert`` and ``size``).

    Returns:
        The preprocessed image tensor on ``device`` with a batch dimension of 1.
    """
    rgb = image.convert("RGB")
    # PIL reports size as (width, height); CenterCrop expects (height, width).
    width, height = rgb.size
    crop = transforms.CenterCrop((height // 64 * 64, width // 64 * 64))
    tensor = transforms.ToTensor()(crop(rgb))
    # Shift ToTensor's output range so it is centered on zero.
    tensor = tensor * 2 - 1
    return tensor.unsqueeze(0).to(device)


def preprocess_map(map):
    """Turn a PIL change map into the tensor fed to the pipelines.

    The map is converted to single-channel grayscale (mode ``"L"``),
    center-cropped so that both sides are multiples of 64 pixels, converted
    with ``ToTensor`` and moved to the module-level ``device``. Unlike
    ``preprocess_image``, no value rescaling and no batch dimension are added.

    Args:
        map: A PIL image (anything exposing ``convert`` and ``size``).

    Returns:
        The grayscale map tensor on ``device``.
    """
    gray = map.convert("L")
    # PIL reports size as (width, height); CenterCrop expects (height, width).
    width, height = gray.size
    cropped = transforms.CenterCrop((height // 64 * 64, width // 64 * 64))(gray)
    return transforms.ToTensor()(cropped).to(device)


def inference(image, map, gs, prompt, negative_prompt):
    """Run the base pipeline followed by the refiner and return one image.

    The inputs are validated and preprocessed, the base pipeline runs up to
    ``denoising_end=0.8`` and emits latents, and the refiner picks up from
    ``denoising_start=0.8`` using those latents.

    Args:
        image: Input PIL image.
        map: PIL change map.
        gs: Guidance scale for the base pipeline only; the refiner call uses
            a fixed guidance scale of 7.5.
        prompt: Text prompt passed to both pipelines.
        negative_prompt: Negative prompt passed to both pipelines.

    Returns:
        The first entry of the refiner output's ``images``.

    Raises:
        gr.Error: If ``image`` or ``map`` is ``None`` (via ``validate_inputs``).
    """
    validate_inputs(image, map)
    image_tensor = preprocess_image(image)
    map_tensor = preprocess_map(map)

    # Keyword arguments that are identical for the base and the refiner call.
    shared = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        original_image=image_tensor,
        map=map_tensor,
        strength=1,
        num_images_per_prompt=1,
        num_inference_steps=NUM_INFERENCE_STEPS,
    )

    # Stage 1: base pipeline, stopped at 0.8 and returning latents.
    latents = base(
        image=image_tensor,
        guidance_scale=gs,
        denoising_end=0.8,
        output_type="latent",
        **shared,
    ).images

    # Stage 2: refiner continues from 0.8 on the base latents.
    refined = refiner(
        image=latents,
        guidance_scale=7.5,
        denoising_start=0.8,
        **shared,
    )
    return refined.images[0]


def validate_inputs(image, map):
    """Ensure both the input image and the change map were supplied.

    Args:
        image: The input image, or ``None`` if none was provided.
        map: The change map, or ``None`` if none was provided.

    Raises:
        gr.Error: With "Missing image" or "Missing map" for the first
            argument found to be ``None`` (the image is checked first).
    """
    required = ((image, "Missing image"), (map, "Missing map"))
    for value, message in required:
        if value is None:
            raise gr.Error(message)


# Preset examples for the demo. Each entry lists, in the same order as the
# `inputs` given to `gr.Examples` below: input image path, change map path,
# guidance scale, prompt, negative prompt.
example1 = ["assets/input2.jpg", "assets/map2.jpg", 17.5,
            "Tree of life under the sea, ethereal, glittering, lens flares, cinematic lighting, artwork by Anna Dittmann & Carne Griffiths, 8k, unreal engine 5, hightly detailed, intricate detailed",
            "bad anatomy, poorly drawn face, out of frame, gibberish, lowres, duplicate, morbid, darkness, maniacal, creepy, fused, blurry background, crosseyed, extra limbs, mutilated, dehydrated, surprised, poor quality, uneven, off-centered, bird illustration, painting, cartoons"]
example2 = ["assets/input3.jpg", "assets/map4.png", 21,
            "overgrown atrium, nature, ancient black marble columns and terracotta tile floors, waterfall, ultra-high quality, octane render, corona render, UHD, 64k",
            "Two bodies, Two heads, doll, extra nipples, bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated, out of frame, cloned face, watermark, text, lowres, disfigured, ostentatious, ugly, oversaturated, grain, low resolution, blurry, bad anatomy, poorly drawn face, mutant, mutated,  blurred, out of focus, long neck, long body, ugly, disgusting, bad drawing, childish"]
# Build the Gradio UI: an outer row holding the input column and the output image.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Input image and change map, both delivered to `inference` as PIL images.
            with gr.Row():
                input_image = gr.Image(label="input image", type="pil")
                change_map = gr.Image(label="change map", type="pil")
            # Slider range 0-28, initial value 7.5; passed to `inference` as `gs`.
            gs = gr.Slider(0, 28, value=7.5, label="Guidance Scale")
            prompt = gr.Textbox(label="Prompt")
            neg_prompt = gr.Textbox(label="Negative Prompt")
            with gr.Row():
                # Clear button is registered for all five input components.
                clr_btn=gr.ClearButton(components=[input_image, change_map, gs, prompt, neg_prompt])
                run_btn = gr.Button("Run",variant="primary")

        output = gr.Image(label="output image")
    # Preset examples fill the five input components.
    gr.Examples(examples=[example1, example2],inputs=[input_image, change_map, gs, prompt, neg_prompt])
    # "Run" calls `inference` with the inputs in its positional parameter order.
    run_btn.click(inference, inputs=[input_image, change_map, gs, prompt, neg_prompt], outputs=output)
    # The output image is created after the clear button, so it is added to it here.
    clr_btn.add(output)
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()