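"""
Variable-aspect-ratio (VAR) batch transform utilities.

Samples are bucketed by the target aspect ratio closest to their own,
resized and cropped in worker processes, and emitted as a batch once any
bucket holds enough samples; square crops at the base resolution act as a
fallback so a batch can always be returned.
"""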
import copy
import random
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import torch
import torchvision.transforms.functional as tvtf
from PIL import Image

from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.models.vae import uint82fp

def center_crop_arr(pil_image, width, height):
    """
    Resize-then-crop implementation adapted from ADM's center cropping,
    except the final crop offset is sampled at random rather than centered.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # Repeatedly halve with a box filter while the image is still at least
    # twice the target size, then finish with a single bicubic resize.
    while pil_image.size[0] >= 2 * width and pil_image.size[1] >= 2 * height:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )
    scale = max(width / pil_image.size[0], height / pil_image.size[1])
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )
    arr = np.array(pil_image)
    crop_y = random.randint(0, arr.shape[0] - height)
    crop_x = random.randint(0, arr.shape[1] - width)
    return Image.fromarray(arr[crop_y: crop_y + height, crop_x: crop_x + width])

def process_fn(width, height, data, hflip=0.5):
    image, label = data
    if random.uniform(0, 1) > hflip:  # flip with probability 1 - hflip (0.5 by default)
        image = tvtf.hflip(image)
    image = center_crop_arr(image, width, height)  # resize + random crop
    image = np.array(image).transpose(2, 0, 1)  # HWC -> CHW
    return image, label

class VARCandidate:
    """A bucket of pending samples for one target aspect ratio."""
    def __init__(self, aspect_ratio, width, height, buffer, max_buffer_size=1024):
        self.aspect_ratio = aspect_ratio
        self.width = int(width)
        self.height = int(height)
        self.buffer = buffer
        self.max_buffer_size = max_buffer_size

    def add_sample(self, data):
        # Keep only the most recent max_buffer_size samples.
        self.buffer.append(data)
        self.buffer = self.buffer[-self.max_buffer_size:]

    def ready(self, batch_size):
        return len(self.buffer) >= batch_size

    def get_batch(self, batch_size):
        # Pop the oldest batch_size futures and materialize their results.
        batch = self.buffer[:batch_size]
        self.buffer = self.buffer[batch_size:]
        batch = [copy.deepcopy(b.result()) for b in batch]
        x, y = zip(*batch)
        x = torch.stack([torch.from_numpy(im).cuda() for im in x], dim=0)
        x = list(map(uint82fp, x))  # per-image uint8 -> float conversion
        return x, y

class VARTransformEngine:
    """Buckets samples by nearest aspect ratio and batches them per bucket."""
    def __init__(self,
                 base_image_size,
                 num_aspect_ratios,
                 min_aspect_ratio,
                 max_aspect_ratio,
                 num_workers=8,
                 ):
        self.base_image_size = base_image_size
        self.num_aspect_ratios = num_aspect_ratios
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        self.aspect_ratios = np.linspace(self.min_aspect_ratio, self.max_aspect_ratio, self.num_aspect_ratios)
        self.aspect_ratios = self.aspect_ratios.tolist()
        self.candidates_pool = []
        for i in range(self.num_aspect_ratios):
            # Choose width/height so width/height ~= aspect_ratio while the area
            # stays close to base_image_size**2, snapped down to multiples of 16.
            candidate = VARCandidate(
                aspect_ratio=self.aspect_ratios[i],
                width=int(self.base_image_size * self.aspect_ratios[i] ** 0.5 // 16 * 16),
                height=int(self.base_image_size * self.aspect_ratios[i] ** -0.5 // 16 * 16),
                buffer=[],
                max_buffer_size=1024,
            )
            self.candidates_pool.append(candidate)
        # Square fallback candidate at the base resolution.
        self.default_candidate = VARCandidate(
            aspect_ratio=1.0,
            width=self.base_image_size,
            height=self.base_image_size,
            buffer=[],
            max_buffer_size=1024,
        )
        self.executor_pool = ProcessPoolExecutor(max_workers=num_workers)
        # For the first _prefill_count calls, also feed the square fallback
        # buffer so it can serve batches before the ratio buckets fill up.
        self._prefill_count = 100
    def find_candidate(self, data):
        # Return the candidate whose target aspect ratio is closest to the image's.
        image = data[0]
        aspect_ratio = image.size[0] / image.size[1]
        min_distance = float("inf")
        min_candidate = None
        for candidate in self.candidates_pool:
            dis = abs(aspect_ratio - candidate.aspect_ratio)
            if dis < min_distance:
                min_distance = dis
                min_candidate = candidate
        return min_candidate
    def __call__(self, batch_data):
        self._prefill_count -= 1
        if isinstance(batch_data[0], torch.Tensor):
            batch_data[0] = batch_data[0].unbind(0)
        batch_data = list(zip(*batch_data))  # [(image, label), ...]
        for data in batch_data:
            candidate = self.find_candidate(data)
            future = self.executor_pool.submit(process_fn, candidate.width, candidate.height, data)
            candidate.add_sample(future)
            # While prefilling, also queue a square crop of the same sample.
            if self._prefill_count >= 0:
                future = self.executor_pool.submit(process_fn,
                                                   self.default_candidate.width,
                                                   self.default_candidate.height,
                                                   data)
                self.default_candidate.add_sample(future)
        batch_size = len(batch_data)
        # Shuffle so ready buckets are drained in random order.
        random.shuffle(self.candidates_pool)
        for candidate in self.candidates_pool:
            if candidate.ready(batch_size=batch_size):
                return candidate.get_batch(batch_size=batch_size)
        # Fallback: no bucket is full yet, so queue square crops of this batch
        # and return them at the base (square) resolution.
        for data in batch_data:
            future = self.executor_pool.submit(process_fn,
                                               self.default_candidate.width,
                                               self.default_candidate.height,
                                               data)
            self.default_candidate.add_sample(future)
        return self.default_candidate.get_batch(batch_size=batch_size)
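

# Minimal usage sketch (illustrative, not part of the original module). It
# assumes the src.* imports above resolve within the repo and that a CUDA
# device is available, since VARCandidate.get_batch moves tensors to the GPU.
if __name__ == "__main__":
    engine = VARTransformEngine(
        base_image_size=256,
        num_aspect_ratios=5,
        min_aspect_ratio=0.5,
        max_aspect_ratio=2.0,
        num_workers=2,
    )
    # Four random 400x300 RGB images (aspect ratio 4:3) with integer labels;
    # all land in the same bucket, so the first call already yields a batch.
    images = [
        Image.fromarray(np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8))
        for _ in range(4)
    ]
    labels = list(range(4))
    x, y = engine([images, labels])
    print(len(x), x[0].shape, y)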