# sam-model/datasets/augmentations.py
# (Hugging Face file view: author "Denis", commit 2302223, stored via LFS, 624 bytes)
import numpy as np
import torch
class AgeTransformer(object):
    """Append a constant "target age" channel to an image tensor.

    The age value is divided by 100, broadcast over the two trailing
    dimensions of the image, and concatenated as an extra channel along
    dim 0.

    Args:
        target_age: either a value accepted by ``int()`` (e.g. ``30`` or
            ``"30"``), or the string ``"uniform_random"`` to draw a fresh
            integer age in [0, 100] on every call.
    """

    def __init__(self, target_age):
        self.target_age = target_age

    def __call__(self, img):
        """Return ``img`` with the age channel appended."""
        return self.add_aging_channel(img)

    def add_aging_channel(self, img):
        """Concatenate a constant age channel to ``img`` along dim 0.

        Args:
            img: 3-D tensor, presumably laid out as (C, H, W) -- TODO confirm
                against the dataset pipeline.

        Returns:
            Tensor of shape ``(img.shape[0] + 1, img.shape[1], img.shape[2])``
            whose last channel is filled with ``int(target_age) / 100``.
        """
        # int() also accepts numeric strings (e.g. CLI arguments); fractional
        # ages are truncated. Ages in [0, 100] map to [0, 1] (not [-1, 1]).
        age_value = int(self.__get_target_age()) / 100
        # Build the channel on img's device, and in img's dtype when img is a
        # floating-point tensor, so torch.cat neither fails for CUDA inputs
        # nor silently changes the precision of half/double images. Integer
        # images keep the previous behaviour (default float dtype channel).
        if img.is_floating_point():
            channel_dtype = img.dtype
        else:
            channel_dtype = torch.get_default_dtype()
        age_channel = torch.full(
            (1, img.shape[1], img.shape[2]),
            age_value,
            dtype=channel_dtype,
            device=img.device,
        )
        return torch.cat((img, age_channel))

    def __get_target_age(self):
        """Return the configured age, or a random integer in [0, 100]."""
        if self.target_age == "uniform_random":
            # `high` is exclusive, so 101 makes an age of 100 reachable.
            return np.random.randint(0, 101)
        return self.target_age