Pixart-Sigma / diffusion /data /transforms.py
artificialguybr's picture
Hi
eadd7b4
raw
history blame
No virus
745 Bytes
import torchvision.transforms as T
# Registry of transform factories, keyed by each factory function's __name__.
TRANSFORMS = {}


def register_transform(transform):
    """Register a transform factory in ``TRANSFORMS`` under its ``__name__``.

    Intended for use as a decorator (``@register_transform``).

    Args:
        transform: Callable to register; ``transform.__name__`` is the key.

    Returns:
        The same ``transform`` object, unchanged.

    Raises:
        RuntimeError: If a transform with the same name is already registered.
    """
    name = transform.__name__
    if name in TRANSFORMS:
        raise RuntimeError(f'Transform {name} has already been registered.')
    TRANSFORMS[name] = transform
    # A decorator must return the decorated object; returning None would
    # rebind the decorated module-level name (e.g. `default_train`) to None.
    return transform
def get_transform(type, resolution):
    """Build the registered transform pipeline named ``type``.

    Looks up the factory registered under ``type`` in ``TRANSFORMS``, calls it
    with ``resolution``, and wraps the result in ``T.Compose``. The requested
    ``resolution`` is attached to the returned pipeline as ``image_size``.

    Args:
        type: Registry key, i.e. the factory's function name
            (e.g. ``'default_train'``). Shadows the builtin ``type``; the
            name is kept so keyword callers keep working.
        resolution: Passed unchanged to the factory.

    Returns:
        A ``T.Compose`` instance with an extra ``image_size`` attribute.

    Raises:
        KeyError: If no transform is registered under ``type``.
    """
    try:
        factory = TRANSFORMS[type]
    except KeyError:
        # Same exception type as before, but tell the caller what is valid.
        raise KeyError(
            f'Unknown transform {type!r}; registered transforms: '
            f'{sorted(TRANSFORMS)}'
        ) from None
    # NOTE(review): factories are expected to return a list of transforms
    # (as default_train does), which is what T.Compose accepts.
    transform = T.Compose(factory(resolution))
    transform.image_size = resolution
    return transform
@register_transform
def default_train(n_px):
    """Return the default training transform steps for ``n_px`` images.

    Steps, in order: convert to RGB, resize (shorter edge to ``n_px``),
    center-crop to ``n_px`` x ``n_px``, convert to a tensor, and normalize
    every channel with mean 0.5 and std 0.5.

    Args:
        n_px: Target side length in pixels.

    Returns:
        A list of torchvision transforms (not yet composed).
    """
    def _to_rgb(img):
        # Assumes img is a PIL image (uses Image.convert).
        return img.convert('RGB')

    steps = [T.Lambda(_to_rgb)]
    # NOTE(review): no interpolation argument is passed, so torchvision's
    # default interpolation is used (an older comment mentioned
    # Image.BICUBIC) — confirm which one is intended.
    steps.append(T.Resize(n_px))
    steps.append(T.CenterCrop(n_px))
    # T.RandomHorizontalFlip() is disabled for this pipeline.
    steps.append(T.ToTensor())
    # Single-value mean/std broadcasts over all channels: [0, 1] -> [-1, 1].
    steps.append(T.Normalize([.5], [.5]))
    return steps