# (Hugging Face file-page residue, commented out so the module parses as Python)
# roubaofeipi's picture
# Upload 100 files
# 5231633 verified
# raw
# history blame
# No virus
# 5.32 kB
import PIL
import torch
import requests
import torchvision
from math import ceil
from io import BytesIO
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
import math
from tqdm import tqdm
def download_image(url, timeout=30):
    """Download ``url`` and return it as an RGB ``PIL.Image.Image``.

    Args:
        url: HTTP(S) URL of the image.
        timeout: seconds to wait on connect / between received bytes, passed
            to ``requests.get``. Without it a stalled server hangs forever.

    Raises:
        requests.HTTPError: on a non-2xx response, rather than handing an
            error page to PIL.
    """
    # ``import PIL`` alone does not guarantee the ``PIL.Image`` submodule is loaded.
    from PIL import Image

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # Use the content-decoded body instead of ``response.raw``, which is the
    # raw socket stream and is not decoded for gzip/deflate Content-Encoding.
    with Image.open(BytesIO(response.content)) as image:
        # convert() returns a new, fully loaded image, so closing the source is safe.
        return image.convert("RGB")
def resize_image(image, size=768):
    """Turn ``image`` into a tensor and resize it with antialiasing.

    ``image`` goes through ``F.to_tensor``. ``size`` is forwarded unchanged to
    ``F.resize`` (torchvision), so it follows torchvision's ``size`` semantics.
    """
    return F.resize(F.to_tensor(image), size, antialias=True)
def downscale_images(images, factor=3/4):
    """Shrink ``images`` by ``factor`` using nearest-neighbour interpolation.

    Each of the last two dimensions is scaled by ``factor`` and then rounded
    down to a multiple of 32.
    """
    def _snap_to_32(length):
        # Floor of (length * factor) to the nearest lower multiple of 32.
        return int(((length * factor) // 32) * 32)

    target_size = (_snap_to_32(images.size(-2)), _snap_to_32(images.size(-1)))
    nearest = torchvision.transforms.InterpolationMode.NEAREST
    return torchvision.transforms.functional.resize(images, target_size, interpolation=nearest)
def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
    """Return the latent tensor shapes for Stage C and Stage B.

    Note the naming: ``compression_factor_b`` determines the *Stage C* latent
    resolution, and ``compression_factor_a`` determines the *Stage B* one.
    Each spatial side is ``ceil(pixels / factor)``.

    Args:
        height, width: target image size in pixels.
        batch_size: leading dimension of both shapes.
        compression_factor_b: pixel-to-latent divisor for the Stage C shape.
        compression_factor_a: pixel-to-latent divisor for the Stage B shape.

    Returns:
        ``(stage_c_latent_shape, stage_b_latent_shape)`` where
        ``stage_c_latent_shape = (batch_size, 16, ceil(height / cf_b), ceil(width / cf_b))``
        and ``stage_b_latent_shape = (batch_size, 4, ceil(height / cf_a), ceil(width / cf_a))``.
    """
    def _shape(channels, factor):
        return (batch_size, channels, ceil(height / factor), ceil(width / factor))

    stage_c_latent_shape = _shape(16, compression_factor_b)
    stage_b_latent_shape = _shape(4, compression_factor_a)
    return stage_c_latent_shape, stage_b_latent_shape
def get_views(H, W, window_size=64, stride=16):
    '''
    Enumerate overlapping square windows over an H x W grid, in row-major order.

    - H, W: height and width of the latent
    - window_size: side length of every window
    - stride: step between consecutive window origins

    Returns a list of ``(h_start, h_end, w_start, w_end)`` tuples with
    ``h_end = h_start + window_size`` and ``w_end = w_start + window_size``.

    Window origins are multiples of ``stride`` only. When ``H - window_size``
    (or ``W - window_size``) is not a multiple of ``stride``, the trailing
    remainder (< stride) is not covered by any window.

    A dimension smaller than ``window_size`` gets exactly one window starting
    at 0. Its end then exceeds the dimension, which slicing-based callers
    tolerate because slices are clipped.
    '''
    # Clamp each count to >= 1. Without the clamp, a dimension smaller than the
    # window gives a non-positive count. Two negative counts then multiplied
    # into a spurious positive total, and one negative count gave an empty list.
    num_blocks_height = max(1, int((H - window_size) // stride + 1))
    num_blocks_width = max(1, int((W - window_size) // stride + 1))
    return [
        (int(row * stride), int(row * stride) + window_size,
         int(col * stride), int(col * stride) + window_size)
        for row in range(num_blocks_height)
        for col in range(num_blocks_width)
    ]
def show_images(images, rows=None, cols=None, **kwargs):
    """Convert a batch of image tensors into a list of PIL images.

    Channel handling:
    - a batch with one channel (``images.size(1) == 1``) is repeated to 3 channels;
    - a batch with more than 3 channels keeps only the first 3.

    Values are clamped to [0, 1] before conversion with
    ``torchvision.transforms.functional.to_pil_image``.

    ``rows``, ``cols`` and ``**kwargs`` are accepted for backward compatibility
    but are unused. This function only returns the images; it draws nothing.
    """
    if images.size(1) == 1:
        images = images.repeat(1, 3, 1, 1)
    elif images.size(1) > 3:
        images = images[:, :3]
    to_pil = torchvision.transforms.functional.to_pil_image
    return [to_pil(img.clamp(0, 1)) for img in images]
def _sample_stage_b(conditions_b, unconditions_b, models_b, bshape, extras_b, device):
    """Drain the Stage B sampler with the generator on GPU; return the last latent."""
    sampling_b = extras_b.gdf.sample(
        models_b.generator.half(), conditions_b, bshape,
        unconditions_b, device=device,
        **extras_b.sampling_configs,
    )
    # NOTE(review): the generator is moved to GPU *after* the sampler is built.
    # This presumably relies on gdf.sample being lazy (a generator) — confirm.
    models_b.generator.cuda()
    sampled_b = None
    produced = False
    try:
        for sampled_b, _, _ in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
            produced = True
    finally:
        # Release GPU memory for Stage A, even if sampling raised.
        models_b.generator.cpu()
        torch.cuda.empty_cache()
    if not produced:
        raise RuntimeError("Stage B sampler yielded no steps")
    return sampled_b


def _decode_stage_a_tiled(stage_a, sampled_b, patch_size, stride):
    """Decode ``sampled_b`` in overlapping tiles and average the overlaps."""
    scale = 4  # decoded pixels per latent pixel, as hard-coded by this pipeline
    pad = stride * 2
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        # Reflect-pad so border tiles get context; the padding is cropped off below.
        sampled_b = torch.nn.functional.pad(sampled_b, (pad, pad, pad, pad), mode='reflect')
        out_shape = (sampled_b.shape[0], 3, sampled_b.shape[-2] * scale, sampled_b.shape[-1] * scale)
        sampled = torch.zeros(out_shape, requires_grad=False, device=sampled_b.device)
        count = torch.zeros(out_shape, requires_grad=False, device=sampled_b.device)
        views = get_views(sampled_b.shape[-2], sampled_b.shape[-1], window_size=patch_size, stride=stride)
        for h_start, h_end, w_start, w_end in tqdm(views):
            tile = sampled_b[:, :, h_start:h_end, w_start:w_end]
            out_h = slice(h_start * scale, h_end * scale)
            out_w = slice(w_start * scale, w_end * scale)
            sampled[:, :, out_h, out_w] += stage_a.decode(tile).float()
            count[:, :, out_h, out_w] += 1
        # NOTE(review): pixels not covered by any view have count 0 and become NaN.
        # get_views leaves at most a < stride remainder uncovered, which should lie
        # inside the cropped padding — confirm for unusual sizes.
        sampled /= count
        crop = pad * scale
        sampled = sampled[:, :, crop:-crop, crop:-crop]
    return sampled


def decode_b(conditions_b, unconditions_b, models_b, bshape, extras_b, device,
             stage_a_tiled=False, num_instance=4, patch_size=256, stride=24):
    """Run Stage B sampling, then decode the result with Stage A.

    Args:
        conditions_b, unconditions_b: conditioning passed to ``extras_b.gdf.sample``.
        models_b: object exposing ``generator`` (the sampled model) and
            ``stage_a`` (the decoder).
        bshape: shape passed to the sampler (presumably the Stage B latent shape).
        extras_b: object exposing ``gdf.sample`` and ``sampling_configs``.
            ``sampling_configs`` must contain ``'timesteps'``.
        device: forwarded to the sampler.
        stage_a_tiled: if True, decode in overlapping ``patch_size`` tiles
            (spaced by ``stride``) and average where tiles overlap.
        num_instance: unused.
        patch_size: tile side in latent pixels (tiled decoding only).
        stride: tile step in latent pixels. It also sets the reflect padding of
            ``2 * stride`` per side (tiled decoding only).

    Returns:
        The decoded tensor, cast to float32.

    Raises:
        RuntimeError: if the Stage B sampler yields no steps.
    """
    sampled_b = _sample_stage_b(conditions_b, unconditions_b, models_b, bshape, extras_b, device)
    if stage_a_tiled:
        sampled = _decode_stage_a_tiled(models_b.stage_a, sampled_b, patch_size, stride)
    else:
        # stage_a_tiled is always False on this branch.
        sampled = models_b.stage_a.decode(sampled_b, tiled_decoding=stage_a_tiled)
    return sampled.float()
def generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions=None, unconditions=None):
    """Run Stage C sampling to completion and return the final sample.

    If ``conditions`` or ``unconditions`` is None, it is built with
    ``core.get_conditions(..., is_eval=True, eval_image_embeds=False)``.
    ``is_unconditional`` is False for conditions and True for unconditions.

    The sampler is expected to yield 4-tuples. The first element of the last
    tuple is returned. ``extras.sampling_configs`` must contain ``'timesteps'``;
    it is used as the progress-bar total.

    Raises:
        RuntimeError: if the sampler yields no steps.
    """
    if conditions is None:
        conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
    if unconditions is None:
        unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
    sampling_c = extras.gdf.sample(
        models.generator, conditions, stage_c_latent_shape, stage_c_latent_shape_lr,
        unconditions, device=device, **extras.sampling_configs,
    )
    sampled_c = None
    produced = False
    # Drain the sampler; only the final step's sample is kept.
    for sampled_c, _, _, _ in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
        produced = True
    if not produced:
        raise RuntimeError("Stage C sampler yielded no steps")
    return sampled_c
def get_target_lr_size(ratio, std_size=24):
    """Return a ``(height, width)`` pixel size for aspect ``ratio``, in multiples of 32.

    The height side is ``int(std_size * sqrt(ratio))`` units and the width side
    is ``int(std_size / sqrt(ratio))`` units. Each unit is 32 pixels.
    """
    root = math.sqrt(ratio)
    width_units = int(std_size / root)
    height_units = int(std_size * root)
    return (height_units * 32, width_units * 32)