Open in github.dev
Open in a new github.dev tab
Permalink
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
ControlNet/gradio_seg2image.py /
Go to file
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
92 lines (72 sloc)
4.26 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from share import * | |
import config | |
import cv2 | |
import einops | |
import gradio as gr | |
import numpy as np | |
import torch | |
import random | |
from pytorch_lightning import seed_everything | |
from annotator.util import resize_image, HWC3 | |
from annotator.uniformer import apply_uniformer | |
from cldm.model import create_model, load_state_dict | |
from ldm.models.diffusion.ddim import DDIMSampler | |
# Build the ControlNet model from its config on the CPU, load the segmentation
# ControlNet checkpoint into it, then move the whole model to the GPU once.
model = create_model('./models/cldm_v15.yaml').cpu()
# The model is still on the CPU at this point, so map the checkpoint tensors to
# the CPU too. Mapping them to 'cuda' (as previously done) would stage a second
# full copy of the weights in GPU memory only to copy it back into the CPU
# parameters. (Assumes `location` is the map location used by load_state_dict's
# loader -- verify against cldm/model.py.)
model.load_state_dict(load_state_dict('./models/control_sd15_seg.pth', location='cpu'))
model = model.cuda()
# Shared DDIM sampler wrapping the model; used by `process` for every request.
ddim_sampler = DDIMSampler(model)
def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, scale, seed, eta):
    """Generate images conditioned on a segmentation map of ``input_image``.

    The input is segmented with ``apply_uniformer``. The resulting map is
    resized to the generation resolution and used as the ControlNet control
    signal (``c_concat``). DDIM sampling then runs with classifier-free
    guidance.

    Args:
        input_image: Image from the ``gr.Image(type="numpy")`` component, or
            ``None`` when nothing has been uploaded.
        prompt: Positive text prompt.
        a_prompt: Extra text appended to ``prompt`` (joined with ", ").
        n_prompt: Negative prompt, used for the unconditional branch.
        num_samples: Number of images to generate.
        image_resolution: Resolution passed to ``resize_image`` for generation.
        detect_resolution: Resolution passed to ``resize_image`` before
            segmentation.
        ddim_steps: Number of DDIM sampling steps.
        scale: Classifier-free guidance scale.
        seed: RNG seed; ``-1`` picks a random seed in [0, 65535].
        eta: DDIM eta.

    Returns:
        A list: the segmentation map (resized to the output size) followed by
        ``num_samples`` generated uint8 images in HWC layout.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        # Surface a readable message in the UI instead of handing None to HWC3.
        raise gr.Error('Please upload an input image.')

    # Slider values may arrive as floats depending on the Gradio version;
    # range() and the sampler need integers.
    num_samples = int(num_samples)
    ddim_steps = int(ddim_steps)
    seed = int(seed)

    with torch.no_grad():
        input_image = HWC3(input_image)
        detected_map = apply_uniformer(resize_image(input_image, detect_resolution))
        img = resize_image(input_image, image_resolution)
        H, W = img.shape[:2]

        # Nearest-neighbour interpolation keeps segmentation colours exact
        # instead of blending neighbouring labels.
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)

        # Control tensor: scale to [0, 1], replicate once per sample, and
        # convert to channels-first (b, c, h, w).
        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        # When config.save_memory is set, low_vram_shift presumably moves
        # sub-models between devices around the diffusion phase (see cldm model).
        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
        # Latent shape handed to the sampler: 4 channels at 1/8 spatial size.
        shape = (4, H // 8, W // 8)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)

        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                     shape, cond, verbose=False, eta=eta,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=un_cond)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        x_samples = model.decode_first_stage(samples)
        # Map decoder output (assumed to lie in [-1, 1]) to uint8 [0, 255], HWC.
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

        results = list(x_samples)
    return [detected_map] + results
# Gradio UI: an input image plus prompt on the left, an output gallery on the
# right. `process` is wired to the Run button.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown("## Control Stable Diffusion with Segmentation Maps")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy")
            prompt = gr.Textbox(label="Prompt")
            # A Button's caption is its `value`; `label` is not the caption
            # parameter and only worked because "Run" is the default value.
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
                image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
                detect_resolution = gr.Slider(label="Segmentation Resolution", minimum=128, maximum=1024, value=512, step=1)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                # -1 asks `process` to draw a random seed.
                seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                eta = gr.Number(label="eta (DDIM)", value=0.0)
                a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                n_prompt = gr.Textbox(label="Negative Prompt",
                                      value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    # Order must match the positional parameters of `process`.
    ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, scale, seed, eta]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])

# NOTE: binding to 0.0.0.0 makes the demo reachable from any network interface.
block.launch(server_name='0.0.0.0')