Spaces:
Running
Running
import torch | |
import numpy as np | |
from tqdm import tqdm | |
from omegaconf import OmegaConf | |
import safetensors | |
import os | |
import einops | |
import cv2 | |
from PIL import Image, ImageFilter, ImageOps | |
from utils.io_utils import resize_pad2divisior | |
import os | |
from utils.io_utils import submit_request, img2b64 | |
import json | |
# Debug by Francis | |
# from ldm.util import instantiate_from_config | |
# from ldm.models.diffusion.ddpm import LatentDiffusion | |
# from ldm.models.diffusion.ddim import DDIMSampler | |
# from ldm.modules.diffusionmodules.util import noise_like | |
import io | |
import base64 | |
from requests.auth import HTTPBasicAuth | |
# Debug by Francis | |
# def create_model(config_path): | |
# config = OmegaConf.load(config_path) | |
# model = instantiate_from_config(config.model).cpu() | |
# return model | |
# | |
# def get_state_dict(d): | |
# return d.get('state_dict', d) | |
# | |
# def load_state_dict(ckpt_path, location='cpu'): | |
# _, extension = os.path.splitext(ckpt_path) | |
# if extension.lower() == ".safetensors": | |
# import safetensors.torch | |
# state_dict = safetensors.torch.load_file(ckpt_path, device=location) | |
# else: | |
# state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) | |
# state_dict = get_state_dict(state_dict) | |
# return state_dict | |
# | |
# | |
# def load_ldm_sd(model, path) : | |
# if path.endswith('.safetensor') : | |
# sd = safetensors.torch.load_file(path) | |
# else : | |
# sd = load_state_dict(path) | |
# model.load_state_dict(sd, strict = False) | |
# | |
# def fill_mask_input(image, mask): | |
# """fills masked regions with colors from image using blur. Not extremely effective.""" | |
# | |
# image_mod = Image.new('RGBA', (image.width, image.height)) | |
# | |
# image_masked = Image.new('RGBa', (image.width, image.height)) | |
# image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L'))) | |
# | |
# image_masked = image_masked.convert('RGBa') | |
# | |
# for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]: | |
# blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA') | |
# for _ in range(repeats): | |
# image_mod.alpha_composite(blurred) | |
# | |
# return image_mod.convert("RGB") | |
# | |
# | |
# def get_inpainting_image_condition(model, image, mask) : | |
# conditioning_mask = np.array(mask.convert("L")) | |
# conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 | |
# conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) | |
# conditioning_mask = torch.round(conditioning_mask) | |
# conditioning_mask = conditioning_mask.to(device=image.device, dtype=image.dtype) | |
# conditioning_image = torch.lerp( | |
# image, | |
# image * (1.0 - conditioning_mask), | |
# 1 | |
# ) | |
# conditioning_image = model.get_first_stage_encoding(model.encode_first_stage(conditioning_image)) | |
# conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=conditioning_image.shape[-2:]) | |
# conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) | |
# image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) | |
# return image_conditioning | |
# | |
# | |
# class GuidedLDM(LatentDiffusion): | |
# def __init__(self, *args, **kwargs): | |
# super().__init__(*args, **kwargs) | |
# | |
# @torch.no_grad() | |
# def img2img_inpaint( | |
# self, | |
# image: Image.Image, | |
# c_text: str, | |
# uc_text: str, | |
# mask: Image.Image, | |
# ddim_steps = 50, | |
# mask_blur: int = 0, | |
# use_cuda: bool = True, | |
# **kwargs) -> Image.Image : | |
# ddim_sampler = GuidedDDIMSample(self) | |
# if use_cuda : | |
# self.cond_stage_model.cuda() | |
# self.first_stage_model.cuda() | |
# c_text = self.get_learned_conditioning([c_text]) | |
# uc_text = self.get_learned_conditioning([uc_text]) | |
# cond = {"c_crossattn": [c_text]} | |
# uc_cond = {"c_crossattn": [uc_text]} | |
# | |
# if use_cuda : | |
# device = torch.device('cuda:0') | |
# else : | |
# device = torch.device('cpu') | |
# | |
# image_mask = mask | |
# image_mask = image_mask.convert('L') | |
# image_mask = image_mask.filter(ImageFilter.GaussianBlur(mask_blur)) | |
# latent_mask = image_mask | |
# # image = fill_mask_input(image, latent_mask) | |
# # image.save('image_fill.png') | |
# image = np.array(image).astype(np.float32) / 127.5 - 1.0 | |
# image = np.moveaxis(image, 2, 0) | |
# image = torch.from_numpy(image).to(device)[None] | |
# init_latent = self.get_first_stage_encoding(self.encode_first_stage(image)) | |
# init_mask = latent_mask | |
# latmask = init_mask.convert('RGB').resize((init_latent.shape[3], init_latent.shape[2])) | |
# latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 | |
# latmask = latmask[0] | |
# latmask = np.around(latmask) | |
# latmask = np.tile(latmask[None], (4, 1, 1)) | |
# nmask = torch.asarray(latmask).to(init_latent.device).float() | |
# init_latent = (1 - nmask) * init_latent + nmask * torch.randn_like(init_latent) | |
# | |
# denoising_strength = 1 | |
# if self.model.conditioning_key == 'hybrid' : | |
# image_cdt = get_inpainting_image_condition(self, image, image_mask) | |
# cond["c_concat"] = [image_cdt] | |
# uc_cond["c_concat"] = [image_cdt] | |
# | |
# steps = ddim_steps | |
# t_enc = int(min(denoising_strength, 0.999) * steps) | |
# eta = 0 | |
# | |
# noise = torch.randn_like(init_latent) | |
# ddim_sampler.make_schedule(ddim_num_steps=steps, ddim_eta=eta, ddim_discretize="uniform", verbose=False) | |
# x1 = ddim_sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * int(init_latent.shape[0])).to(device), noise=noise) | |
# | |
# if use_cuda : | |
# self.cond_stage_model.cpu() | |
# self.first_stage_model.cpu() | |
# | |
# if use_cuda : | |
# self.model.cuda() | |
# decoded = ddim_sampler.decode(x1, cond,t_enc,init_latent=init_latent,nmask=nmask,unconditional_guidance_scale=7,unconditional_conditioning=uc_cond) | |
# if use_cuda : | |
# self.model.cpu() | |
# | |
# if mask is not None : | |
# decoded = init_latent * (1 - nmask) + decoded * nmask | |
# | |
# if use_cuda : | |
# self.first_stage_model.cuda() | |
# with torch.cuda.amp.autocast(enabled=False): | |
# x_samples = self.decode_first_stage(decoded.to(torch.float32)) | |
# if use_cuda : | |
# self.first_stage_model.cpu() | |
# return torch.clip(x_samples, -1, 1) | |
# | |
# | |
# | |
# class GuidedDDIMSample(DDIMSampler) : | |
# def __init__(self, *args, **kwargs): | |
# super().__init__(*args, **kwargs) | |
# | |
# @torch.no_grad() | |
# def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, | |
# temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, | |
# unconditional_guidance_scale=1., unconditional_conditioning=None, | |
# dynamic_threshold=None): | |
# b, *_, device = *x.shape, x.device | |
# | |
# if unconditional_conditioning is None or unconditional_guidance_scale == 1.: | |
# model_output = self.model.apply_model(x, t, c) | |
# else: | |
# x_in = torch.cat([x] * 2) | |
# t_in = torch.cat([t] * 2) | |
# if isinstance(c, dict): | |
# assert isinstance(unconditional_conditioning, dict) | |
# c_in = dict() | |
# for k in c: | |
# if isinstance(c[k], list): | |
# c_in[k] = [torch.cat([ | |
# unconditional_conditioning[k][i], | |
# c[k][i]]) for i in range(len(c[k]))] | |
# else: | |
# c_in[k] = torch.cat([ | |
# unconditional_conditioning[k], | |
# c[k]]) | |
# elif isinstance(c, list): | |
# c_in = list() | |
# assert isinstance(unconditional_conditioning, list) | |
# for i in range(len(c)): | |
# c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) | |
# else: | |
# c_in = torch.cat([unconditional_conditioning, c]) | |
# model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) | |
# model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) | |
# | |
# e_t = model_output | |
# | |
# alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas | |
# alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev | |
# sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas | |
# sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas | |
# # select parameters corresponding to the currently considered timestep | |
# a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) | |
# a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) | |
# sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) | |
# sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) | |
# | |
# # current prediction for x_0 | |
# pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() | |
# | |
# # direction pointing to x_t | |
# dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t | |
# noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature | |
# if noise_dropout > 0.: | |
# noise = torch.nn.functional.dropout(noise, p=noise_dropout) | |
# x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise | |
# return x_prev, pred_x0 | |
# | |
# @torch.no_grad() | |
# def decode(self, x_latent, cond, t_start, init_latent=None, nmask=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, | |
# use_original_steps=False, callback=None): | |
# | |
# timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps | |
# total_steps = len(timesteps) | |
# timesteps = timesteps[:t_start] | |
# | |
# time_range = np.flip(timesteps) | |
# total_steps = timesteps.shape[0] | |
# print(f"Running Guided DDIM Sampling with {len(timesteps)} timesteps, t_start={t_start}") | |
# iterator = tqdm(time_range, desc='Decoding image', total=total_steps) | |
# x_dec = x_latent | |
# for i, step in enumerate(iterator): | |
# p = (i + (total_steps - t_start) + 1) / (total_steps) | |
# index = total_steps - i - 1 | |
# ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) | |
# if nmask is not None : | |
# noised_input = self.model.q_sample(init_latent.to(x_latent.device), ts.to(x_latent.device)) | |
# x_dec = (1 - nmask) * noised_input + nmask * x_dec | |
# x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, | |
# unconditional_guidance_scale=unconditional_guidance_scale, | |
# unconditional_conditioning=unconditional_conditioning) | |
# if callback: callback(i) | |
# return x_dec | |
# | |
# | |
# def ldm_inpaint(model, img, mask, inpaint_size=720, pos_prompt='', neg_prompt = '', use_cuda=True): | |
# img_original = np.copy(img) | |
# im_h, im_w = img.shape[:2] | |
# img_resized, (pad_h, pad_w) = resize_pad2divisior(img, inpaint_size) | |
# | |
# mask_original = np.copy(mask) | |
# mask_original[mask_original < 127] = 0 | |
# mask_original[mask_original >= 127] = 1 | |
# mask_original = mask_original[:, :, None] | |
# mask, _ = resize_pad2divisior(mask, inpaint_size) | |
# | |
# # cv2.imwrite('img_resized.png', img_resized) | |
# # cv2.imwrite('mask_resized.png', mask) | |
# | |
# | |
# if use_cuda : | |
# with torch.autocast(enabled = True, device_type = 'cuda') : | |
# img = model.img2img_inpaint( | |
# image = Image.fromarray(img_resized), | |
# c_text = pos_prompt, | |
# uc_text = neg_prompt, | |
# mask = Image.fromarray(mask), | |
# use_cuda = True | |
# ) | |
# else : | |
# img = model.img2img_inpaint( | |
# image = Image.fromarray(img_resized), | |
# c_text = pos_prompt, | |
# uc_text = neg_prompt, | |
# mask = Image.fromarray(mask), | |
# use_cuda = False | |
# ) | |
# | |
# img_inpainted = (einops.rearrange(img, '1 c h w -> h w c').cpu().numpy() * 127.5 + 127.5).astype(np.uint8) | |
# if pad_h != 0: | |
# img_inpainted = img_inpainted[:-pad_h] | |
# if pad_w != 0: | |
# img_inpainted = img_inpainted[:, :-pad_w] | |
# | |
# | |
# if img_inpainted.shape[0] != im_h or img_inpainted.shape[1] != im_w: | |
# img_inpainted = cv2.resize(img_inpainted, (im_w, im_h), interpolation = cv2.INTER_LINEAR) | |
# ans = img_inpainted * mask_original + img_original * (1 - mask_original) | |
# ans = img_inpainted | |
# return ans | |
import requests | |
from PIL import Image | |
def ldm_inpaint_webui( | |
img, mask, resolution: int, url: str, prompt: str = '', neg_prompt: str = '', | |
**inpaint_ldm_options): | |
if isinstance(img, np.ndarray): | |
img = Image.fromarray(img) | |
im_h, im_w = img.height, img.width | |
if img.height > img.width: | |
W = resolution | |
H = (img.height / img.width * resolution) // 32 * 32 | |
H = int(H) | |
else: | |
H = resolution | |
W = (img.width / img.height * resolution) // 32 * 32 | |
W = int(W) | |
auth = None | |
if 'username' in inpaint_ldm_options: | |
username = inpaint_ldm_options.pop('username') | |
password = inpaint_ldm_options.pop('password') | |
auth = HTTPBasicAuth(username, password) | |
img_b64 = img2b64(img) | |
mask_b64 = img2b64(mask) | |
data = { | |
"init_images": [img_b64], | |
"mask": mask_b64, | |
"prompt": prompt, | |
"negative_prompt": neg_prompt, | |
"width": W, | |
"height": H, | |
**inpaint_ldm_options, | |
} | |
data = json.dumps(data) | |
response = submit_request(url, data, auth=auth) | |
inpainted_b64 = response.json()['images'][0] | |
inpainted = Image.open(io.BytesIO(base64.b64decode(inpainted_b64))) | |
if inpainted.height != im_h or inpainted.width != im_w: | |
inpainted = inpainted.resize((im_w, im_h), resample=Image.Resampling.LANCZOS) | |
inpainted = np.array(inpainted) | |
return inpainted | |