import gradio as gr
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, LMSDiscreteScheduler
from my_model import unet_2d_condition
import json
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from functools import partial
import math
from utils import compute_ca_loss
from gradio import processing_utils
from typing import Optional
import spaces

import warnings
import sys

# Print no traceback frames for uncaught exceptions, only the exception itself.
sys.tracebacklimit = 0


class Blocks(gr.Blocks):
    """`gr.Blocks` subclass that adds extra metadata to the app config.

    The ``thumbnail``, ``url`` and ``creator`` keyword arguments are consumed
    here (not forwarded to ``gr.Blocks``) and merged into the dict returned by
    :meth:`get_config_file`.
    """

    def __init__(
        self,
        theme: str = "default",
        analytics_enabled: Optional[bool] = None,
        mode: str = "blocks",
        title: str = "Gradio",
        css: Optional[str] = None,
        **kwargs,
    ):
        self.extra_configs = {
            'thumbnail': kwargs.pop('thumbnail', ''),
            'url': kwargs.pop('url', 'https://gradio.app/'),
            'creator': kwargs.pop('creator', '@teamGradio'),
        }
        # Pass by keyword so this does not depend on gr.Blocks' positional order.
        super().__init__(
            theme=theme,
            analytics_enabled=analytics_enabled,
            mode=mode,
            title=title,
            css=css,
            **kwargs,
        )
        warnings.filterwarnings("ignore")

    def get_config_file(self):
        """Return the parent config with ``extra_configs`` merged in (extra keys win)."""
        config = super().get_config_file()
        config.update(self.extra_configs)
        return config


def draw_box(boxes=None, texts=None, img=None):
    """Draw labelled bounding boxes onto an image.

    Args:
        boxes: sequence of ``(x1, y1, x2, y2)`` pixel boxes.
        texts: one label per box; must be at least as long as ``boxes``.
        img: PIL image to draw on (modified in place). A blank white 512x512
            RGB canvas is created when omitted.

    Returns:
        The annotated image, or ``None`` when there are no boxes and no image.
    """
    boxes = [] if boxes is None else boxes
    texts = [] if texts is None else texts
    if len(boxes) == 0 and img is None:
        return None
    if img is None:
        img = Image.new('RGB', (512, 512), (255, 255, 255))
    colors = ["red", "olive", "blue", "green", "orange", "brown", "cyan", "purple"]
    painter = ImageDraw.Draw(img)
    font = ImageFont.truetype("DejaVuSansMono.ttf", size=18)
    for bid, box in enumerate(boxes):
        color = colors[bid % len(colors)]
        x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
        painter.rectangle([x1, y1, x2, y2], outline=color, width=4)
        anno_text = texts[bid]
        # Filled label tab on the bottom-left edge of the box; its width appears
        # to be estimated as ~0.6 * font size per character (monospace font).
        painter.rectangle(
            [x1, y2 - int(font.size * 1.2), x1 + int((len(anno_text) + 0.8) * font.size * 0.6), y2],
            outline=color, fill=color, width=4)
        painter.text([x1 + int(font.size * 0.2), y2 - int(font.size * 1.2)], anno_text,
                     font=font, fill=(255, 255, 255))
    return img


def get_concat(ims):
    """Tile images (assumed equally sized) into a grid of at most two columns.

    Returns a white-background RGB image, or ``None`` for an empty list.
    """
    if not ims:
        return None
    n_col = 1 if len(ims) == 1 else 2
    n_row = math.ceil(len(ims) / n_col)
    dst = Image.new('RGB', (ims[0].width * n_col, ims[0].height * n_row), color="white")
    for i, im in enumerate(ims):
        row_id, col_id = divmod(i, n_col)
        dst.paste(im, (im.width * col_id, im.height * row_id))
    return dst


def binarize(x):
    """Map every non-zero entry of array ``x`` to 255 and zeros to 0 (uint8)."""
    return (x != 0).astype('uint8') * 255


def sized_center_crop(img, cropx, cropy):
    """Return the centred ``cropy`` x ``cropx`` window of array ``img`` (a view)."""
    y, x = img.shape[:2]
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    return img[starty:starty + cropy, startx:startx + cropx]


def sized_center_fill(img, fill, cropx, cropy):
    """Set the centred ``cropy`` x ``cropx`` window of ``img`` to ``fill`` in place and return ``img``."""
    y, x = img.shape[:2]
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    img[starty:starty + cropy, startx:startx + cropx] = fill
    return img


def sized_center_mask(img, cropx, cropy):
    """Darken ``img`` to 20% intensity everywhere except the centred ``cropy`` x ``cropx`` window."""
    y, x = img.shape[:2]
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    center_region = img[starty:starty + cropy, startx:startx + cropx].copy()
    img = (img * 0.2).astype('uint8')
    img[starty:starty + cropy, startx:startx + cropx] = center_region
    return img


def center_crop(img, HW=None, tgt_size=(512, 512)):
    """Centre-crop array ``img`` to an ``HW`` x ``HW`` square and resize it to ``tgt_size``.

    ``HW`` defaults to the shorter image side. Returns a numpy array.
    """
    if HW is None:
        H, W = img.shape[:2]
        HW = min(H, W)
    img = sized_center_crop(img, HW, HW)
    img = Image.fromarray(img)
    img = img.resize(tgt_size)
    return np.array(img)


def draw(input, grounding_texts, new_image_trigger, state):
    """Update the box state from the sketch-pad drawing and render labelled boxes.

    The current drawing is binarised and compared with the last accepted mask.
    Newly set pixels whose bounding box exceeds 5 px in both dimensions are
    recorded as a new box.

    Args:
        input: sketch-pad value; either a dict whose ``'composite'`` entry holds
            the drawing, or the drawing itself (array-like, possibly ``None``).
        grounding_texts: ``;``-separated phrases used as box labels; boxes
            without a phrase are labelled ``Obj. N``.
        new_image_trigger: passed through unchanged.
        state: dict holding ``'boxes'`` and ``'masks'`` (mutated in place);
            replaced by a fresh dict when the drawing is empty.

    Returns:
        ``[box_image, new_image_trigger, image_scale, state]``.
    """
    mask = input['composite'] if isinstance(input, dict) else input
    if mask is None:
        # A cleared sketch pad behaves like an empty drawing.
        mask = np.zeros((1, 1), dtype=np.uint8)
    mask = np.asarray(mask)
    if mask.ndim == 3:
        # Strokes are dark on a light background: invert the first channel.
        mask = 255 - mask[..., 0]

    image_scale = 1.0
    mask = binarize(mask)

    if state is None or mask.sum() == 0:
        state = {}
    # No background image is used; boxes are drawn on a blank canvas.
    image = None

    state.setdefault('boxes', [])
    if not state.get('masks'):
        state['masks'] = []
        last_mask = np.zeros_like(mask)
    else:
        last_mask = state['masks'][-1]

    if mask.size > 1:
        # Keep only newly set pixels. Signed arithmetic avoids uint8 wrap-around,
        # under which erased pixels (0 - 255 -> 1) were counted as new strokes.
        diff_mask = np.clip(mask.astype(np.int16) - last_mask.astype(np.int16), 0, None)
    else:
        diff_mask = np.zeros([])

    if diff_mask.sum() > 0:
        x1x2 = np.where(diff_mask.max(0) != 0)[0]
        y1y2 = np.where(diff_mask.max(1) != 0)[0]
        y1, y2 = y1y2.min(), y1y2.max()
        x1, x2 = x1x2.min(), x1x2.max()
        # Ignore tiny strokes (accidental clicks).
        if (x2 - x1 > 5) and (y2 - y1 > 5):
            state['masks'].append(mask.copy())
            state['boxes'].append((x1, y1, x2, y2))

    grounding_texts = [x.strip() for x in grounding_texts.split(';')]
    grounding_texts = [x for x in grounding_texts if len(x) > 0]
    if len(grounding_texts) < len(state['boxes']):
        grounding_texts += [f'Obj. {bid + 1}' for bid in range(len(grounding_texts), len(state['boxes']))]

    box_image = draw_box(state['boxes'], grounding_texts, image)
    return [box_image, new_image_trigger, image_scale, state]


def clear(task, sketch_pad_trigger, batch_size, state, switch_task=False):
    """Reset the UI outputs.

    Bumps ``sketch_pad_trigger`` unless ``task`` is ``'Grounded Inpainting'``.
    Returns ``[None, sketch_pad_trigger, None, 1.0]``, then one emptied, visible
    image update per sample in ``batch_size``, then a fresh empty state dict.
    """
    if task != 'Grounded Inpainting':
        sketch_pad_trigger = sketch_pad_trigger + 1
    # `gr.update` builds a component update; `gr.Image.change` is an event
    # listener and cannot be used to produce output values.
    out_images = [gr.update(value=None, visible=True) for _ in range(int(batch_size))]
    return [None, sketch_pad_trigger, None, 1.0] + out_images + [{}]


def main():
    """Load the Stable Diffusion v1.5 components and launch the Gradio demo."""
    css = """
#img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
{
    height: var(--height) !important;
    max-height: var(--height) !important;
    min-height: var(--height) !important;
}
#paper-info a {
    color:#008AD7;
    text-decoration: none;
}
#paper-info a:hover {
    cursor: pointer;
    text-decoration: none;
}
.tooltip {
    color: #555;
    position: relative;
    display: inline-block;
    cursor: pointer;
}
.tooltip .tooltiptext {
    visibility: hidden;
    width: 400px;
    background-color: #555;
    color: #fff;
    text-align: center;
    padding: 5px;
    border-radius: 5px;
    position: absolute;
    z-index: 1; /* Set z-index to 1 */
    left: 10px;
    top: 100%;
    opacity: 0;
    transition: opacity 0.3s;
}
.tooltip:hover .tooltiptext {
    visibility: visible;
    opacity: 1;
    z-index: 9999; /* Set a high z-index value when hovering */
}
"""

    rescale_js = """
function(x) {
    const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
    let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
    const image_width = root.querySelector('#img2img_image').clientWidth;
    const target_height = parseInt(image_width * image_scale);
    document.body.style.setProperty('--height', `${target_height}px`);
    root.querySelectorAll('button.justify-center.rounded')[0].style.display='none';
    root.querySelectorAll('button.justify-center.rounded')[1].style.display='none';
    return x;
}
"""

    # `from_pretrained` is a classmethod that reads the hub config itself, so
    # instantiating the UNet from a local config first only allocates a
    # throwaway model.
    unet = unet_2d_condition.UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="unet")
    tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for model in (unet, text_encoder, vae):
        # Guidance needs gradients w.r.t. the latents only; freezing the weights
        # keeps autograd from building/storing a graph for every parameter.
        model.requires_grad_(False)
        model.to(device)

    def generate(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
                 loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter, state):
        """Validate the UI inputs, run layout-guided inference and build image updates.

        Raises:
            ValueError: if the number of drawn boxes differs from the number of
                ``;``-separated phrases, or a phrase word is not a word of the
                prompt.

        Returns:
            Image updates padded to four slots (extra slots hidden), followed by
            ``state``.
        """
        if 'boxes' not in state:
            state['boxes'] = []
        boxes = state['boxes']

        # Parse phrases exactly like `draw` so both agree on the phrase count.
        grounding_texts = [x.strip() for x in grounding_texts.split(';')]
        grounding_texts = [x for x in grounding_texts if len(x) > 0]
        if len(boxes) != len(grounding_texts):
            raise ValueError("""The number of boxes should be equal to the number of grounding objects. Number of boxes drawn: {}, number of grounding tokens: {}.
Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))

        # UI widgets may deliver floats; these values are used as counts/seeds.
        batch_size = int(batch_size)
        max_iter = int(max_iter)
        max_step = int(max_step)
        rand_seed = int(rand_seed)

        # Boxes are drawn on a 512x512 canvas: normalise to [0, 1] and wrap each
        # in its own list (one box per phrase) for `inference`/`compute_ca_loss`.
        boxes = (np.asarray(boxes) / 512).tolist()
        boxes = [[box] for box in boxes]

        # Map each phrase word to its word index in the prompt + 1 (presumably to
        # skip the tokenizer's start token). NOTE(review): this assumes one token
        # per word. Whitespace-insensitive splitting matches the tokenizer and
        # avoids shifted indices from stray spaces.
        prompt_words = language_instruction.strip().strip('.').split()
        object_positions = []
        for obj in grounding_texts:
            obj_position = []
            for word in obj.split():
                if word not in prompt_words:
                    raise ValueError(
                        f"Grounding word '{word}' (from phrase '{obj}') does not appear as a "
                        f"separate word in the prompt: '{language_instruction}'.")
                obj_position.append(prompt_words.index(word) + 1)
            object_positions.append(obj_position)

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        gen_images = inference(device, unet, vae, tokenizer, text_encoder, language_instruction, boxes,
                               object_positions, batch_size, loss_scale, loss_threshold, max_iter, max_step,
                               rand_seed, guidance_scale)

        # Assumes four output image slots: odd batches get one visible blank
        # placeholder, remaining slots are hidden.
        blank_samples = batch_size % 2 if batch_size > 1 else 0
        outputs = [gr.update(value=x, visible=True) for x in gen_images]
        outputs += [gr.update(value=None, visible=True) for _ in range(blank_samples)]
        outputs += [gr.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]
        return outputs + [state]

    @spaces.GPU(duration=180)
    def inference(device, unet, vae, tokenizer, text_encoder, prompt, bboxes, object_positions, batch_size,
                  loss_scale, loss_threshold, max_iter, max_index_step, rand_seed, guidance_scale):
        """Sample ``batch_size`` images with cross-attention layout guidance.

        For the first ``max_index_step`` timesteps, the latents are repeatedly
        moved along the negative gradient of ``compute_ca_loss``. This repeats
        while the unscaled loss exceeds ``loss_threshold``, up to ``max_iter``
        times per step. Each timestep then takes one classifier-free-guided LMS
        denoising step.

        Returns:
            A list of ``batch_size`` PIL images.
        """
        # Text embeddings need no gradients; encode once instead of on every
        # guidance iteration.
        with torch.no_grad():
            uncond_input = tokenizer(
                [""], padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
            )
            uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
            input_ids = tokenizer(
                prompt,
                padding="max_length",
                truncation=True,
                max_length=tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids.to(device)
            cond_embeddings = text_encoder(input_ids)[0]
        # Match the latent batch so batch_size > 1 works in cross-attention.
        uncond_embeddings = uncond_embeddings.repeat(batch_size, 1, 1)
        cond_embeddings = cond_embeddings.repeat(batch_size, 1, 1)
        # Order [uncond, cond] matches `torch.cat([latents] * 2)` + `chunk(2)` below.
        text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])

        # Seed the generator that creates the initial latent noise.
        generator = torch.manual_seed(rand_seed)
        latents = torch.randn(
            (batch_size, 4, 64, 64),
            generator=generator,
        ).to(device)

        noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
                                               beta_schedule="scaled_linear", num_train_timesteps=1000)
        noise_scheduler.set_timesteps(51)
        latents = latents * noise_scheduler.init_noise_sigma

        # Not reset per timestep: once the loss drops below the threshold,
        # later timesteps skip guidance until a new loss is computed.
        loss = torch.tensor(10000)

        for index, t in enumerate(noise_scheduler.timesteps):
            iteration = 0
            while loss.item() / loss_scale > loss_threshold and iteration < max_iter and index < max_index_step:
                latents = latents.requires_grad_(True)
                latent_model_input = noise_scheduler.scale_model_input(latents, t)
                _, attn_map_integrated_up, attn_map_integrated_mid, _ = \
                    unet(latent_model_input, t, encoder_hidden_states=cond_embeddings)

                # Update latents with guidance from the cross-attention loss.
                loss = compute_ca_loss(attn_map_integrated_mid, attn_map_integrated_up, bboxes=bboxes,
                                       object_positions=object_positions) * loss_scale
                print(loss.item() / loss_scale)

                grad_cond = torch.autograd.grad(loss.requires_grad_(True), [latents])[0]
                latents = latents - grad_cond * noise_scheduler.sigmas[index] ** 2
                iteration += 1
                torch.cuda.empty_cache()

            with torch.no_grad():
                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
                noise_pred, _, _, _ = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)
                noise_pred = noise_pred.sample

                # Classifier-free guidance.
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
                torch.cuda.empty_cache()

        # Decode the latents into images.
        with torch.no_grad():
            # Undo the latent scaling (presumably SD-v1's VAE scaling factor).
            latents = 1 / 0.18215 * latents
            image = vae.decode(latents).sample
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
            images = (image * 255).round().astype("uint8")
            pil_images = [Image.fromarray(img) for img in images]
            return pil_images

    with Blocks(
            css=css,
            analytics_enabled=False,
            title="Layout-Guidance demo",
    ) as demo:
        description = """
Layout Guidance
[Project Page]
[Paper]
[GitHub]
The source codes of the demo are modified based on the GlIGen. Thanks!
"""
        gr.HTML(description)

    demo.queue(api_open=False)
    demo.launch(share=False, show_api=False, show_error=True)


if __name__ == '__main__':
    main()