# GANsNRoses / util.py
# Last commit by aliabd: "copied all files from repo" (revision bca104a, 5.03 kB).
import torch
import torch.nn.functional as F
from torch.utils import data
from torch import nn, autograd
import os
import matplotlib.pyplot as plt
# Checkpoint filenames that ensure_checkpoint_exists() can fetch automatically,
# mapped to their Google Drive download URLs.
google_drive_paths = {
    "GNR_checkpoint.pt": (
        "https://drive.google.com/uc?id=1IMIVke4WDaGayUa7vk_xVw1uqIHikGtC"
    ),
}
def ensure_checkpoint_exists(model_weights_filename):
    """Make sure a checkpoint file exists locally, downloading it if possible.

    If ``model_weights_filename`` is not an existing file and is a key of
    ``google_drive_paths``, it is downloaded from the mapped Google Drive URL
    with ``gdown``. A missing ``gdown`` module, or a missing file with no
    known URL, is reported with ``print`` instead of raising.

    :param model_weights_filename: local path of the checkpoint; also used as
        the lookup key into ``google_drive_paths``
    """
    if os.path.isfile(model_weights_filename):
        return

    if model_weights_filename not in google_drive_paths:
        print(
            model_weights_filename,
            " not found, you may need to manually download the model weights."
        )
        return

    gdrive_url = google_drive_paths[model_weights_filename]
    # Keep only the import inside the try: a ModuleNotFoundError raised while
    # downloading must not be misreported as "gdown module not found".
    try:
        from gdown import download as drive_download
    except ModuleNotFoundError:
        print(
            "gdown module not found.",
            "pip3 install gdown or, manually download the checkpoint file:",
            gdrive_url
        )
        return
    drive_download(gdrive_url, model_weights_filename, quiet=False)
def shuffle_batch(x):
    """Return the entries of ``x`` along dim 0 in a random order.

    The order comes from ``torch.randperm`` and so follows torch's global RNG.
    """
    order = torch.randperm(x.shape[0])
    return x[order]
def data_sampler(dataset, shuffle, distributed):
    """Pick the torch sampler matching the requested loading mode.

    :param dataset: dataset the sampler indexes into
    :param shuffle: randomize the index order
    :param distributed: use ``DistributedSampler`` (``shuffle`` is forwarded)
    :return: a ``DistributedSampler``, ``RandomSampler`` or ``SequentialSampler``
    """
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
    return sampler_cls(dataset)
def accumulate(model1, model2, decay=0.999):
    """Blend ``model2``'s parameters into ``model1`` in place (moving average).

    For every named parameter ``p1`` of ``model1`` with counterpart ``p2`` in
    ``model2``: ``p1 <- decay * p1 + (1 - decay) * p2``. Only parameters are
    updated, not buffers. The update works on ``.data`` so autograd does not
    record it.

    :param model1: module whose parameters are updated in place
    :param model2: module supplying the new parameter values
    :param decay: weight kept on ``model1``'s current values
    :raises KeyError: if ``model2`` lacks a parameter name that ``model1`` has
    """
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())
    for name, p1 in par1.items():
        # add_(other, alpha=...) replaces the deprecated add_(Number, Tensor)
        # overload; the arithmetic is unchanged.
        p1.data.mul_(decay).add_(par2[name].data, alpha=1 - decay)
def sample_data(loader):
    """Yield batches from ``loader`` endlessly, restarting it after each pass.

    :param loader: any re-iterable source of batches (e.g. a DataLoader)
    :raises ValueError: if a full pass over ``loader`` yields nothing, instead
        of looping forever without producing a batch (this also covers
        one-shot iterators once they are exhausted)
    """
    while True:
        exhausted_empty = True
        for batch in loader:
            exhausted_empty = False
            yield batch
        if exhausted_empty:
            raise ValueError(
                "loader yielded no batches; cannot sample from it indefinitely"
            )
def d_logistic_loss(real_pred, fake_pred):
    """Discriminator logistic loss summed over paired predictions.

    Each (real, fake) pair contributes ``mean(softplus(-real)) +
    mean(softplus(fake))``. Pairs are formed with ``zip``, so extra entries in
    the longer sequence are ignored; empty inputs give ``0``.
    """
    return sum(
        F.softplus(-real).mean() + F.softplus(fake).mean()
        for real, fake in zip(real_pred, fake_pred)
    )
def d_r1_loss(real_pred, real_img):
    """R1 gradient penalty summed over a sequence of predictions.

    For each prediction, take the gradient of its mean w.r.t. ``real_img``.
    Then compute the per-sample sum of squared gradient entries (first dim =
    sample) and average it over samples. The graph is kept (``create_graph``)
    so the penalty itself can be backpropagated.

    :param real_pred: iterable of outputs computed from ``real_img``
        (presumably one discriminator output per scale - confirm with caller)
    :param real_img: input tensor; must have ``requires_grad=True``
    :return: summed penalty (``0`` when ``real_pred`` is empty)
    """
    penalties = []
    for pred in real_pred:
        (grad,) = autograd.grad(
            outputs=pred.mean(), inputs=real_img, create_graph=True, only_inputs=True
        )
        per_sample = grad.pow(2).view(grad.shape[0], -1).sum(1)
        penalties.append(per_sample.mean())
    return sum(penalties)
def g_nonsaturating_loss(fake_pred, weights):
    """Weighted non-saturating generator loss, averaged over predictions.

    Computes ``sum(w * mean(softplus(-p))) / len(fake_pred)``. ``zip`` stops
    at the shorter of ``fake_pred`` and ``weights``, but the divisor is always
    ``len(fake_pred)``; an empty ``fake_pred`` raises ``ZeroDivisionError``.
    """
    weighted_total = sum(
        weight * F.softplus(-pred).mean()
        for pred, weight in zip(fake_pred, weights)
    )
    return weighted_total / len(fake_pred)
def display_image(image, size=None, mode='nearest', unnorm=False, title=''):
    """Show an image tensor in a new matplotlib figure.

    :param image: ``[3, h, w]`` or ``[1, 3, h, w]`` tensor with values in
        ``[0, 1]``; for a batched tensor only the first item is shown
    :param size: if given, the image is resized to ``size x size`` first
    :param mode: interpolation mode passed to ``F.interpolate``
    :param unnorm: accepted for backward compatibility but currently unused
    :param title: figure title
    """
    # NOTE(review): `unnorm` is never read; presumably it was meant to map
    # [-1, 1] inputs to [0, 1] (cf. normalize()) - confirm before wiring it up.
    # Detach and move to host memory up front so .numpy() works on any device.
    image = image.detach().cpu()
    # F.interpolate treats a 3-D tensor as (N, C, L) 1-D data and rejects a
    # 2-element size, so give [3, h, w] inputs a batch dimension first.
    if image.dim() == 3:
        image = image.unsqueeze(0)
    # Compare both spatial sides so non-square images are resized as well.
    if size is not None and tuple(image.shape[-2:]) != (size, size):
        image = F.interpolate(image, size=(size, size), mode=mode)
    # [3, h, w] -> [h, w, 3], the channel-last layout imshow expects.
    image = image[0].permute(1, 2, 0).numpy()
    plt.figure()
    plt.title(title)
    plt.axis('off')
    plt.imshow(image)
def normalize(x):
    """Map a tensor from ``[-1, 1]`` to ``[0, 1]``, clamping anything outside."""
    shifted = (x + 1) / 2
    return torch.clamp(shifted, min=0, max=1)
def get_boundingbox(face, width, height, scale=1.3, minsize=None):
    """
    Compute a square crop box around a face detection.

    The side length is the detection's larger side times ``scale`` (raised to
    ``minsize`` if that is given and larger). The box is centred on the
    detection, its top-left corner is clamped to the frame, and the side is
    shrunk so the box does not extend past the right or bottom edge.

    :param face: detection exposing left()/top()/right()/bottom(), e.g. a dlib
        rectangle
    :param width: frame width
    :param height: frame height
    :param scale: multiplier applied to the larger side of the detection
    :param minsize: optional lower bound for the side length
    :return: (x, y, size) - top-left corner and side length of the box
    """
    left, top, right, bottom = face.left(), face.top(), face.right(), face.bottom()
    side = int(max(right - left, bottom - top) * scale)
    if minsize:
        side = max(side, minsize)
    cx = (left + right) // 2
    cy = (top + bottom) // 2
    # Top-left corner, kept inside the frame.
    x = max(int(cx - side // 2), 0)
    y = max(int(cy - side // 2), 0)
    # Shrink the side so the box fits within both frame edges.
    side = min(height - y, width - x, side)
    return x, y, side
def preprocess_image(image, cuda=True):
    """
    Preprocess an image so that it can be fed into the network.

    The image is converted from BGR to RGB, wrapped as a PIL image and passed
    through ``xception_default_data_transforms['test']``.
    :param image: numpy image in OpenCV form, i.e. BGR channel order
        (presumably of shape (h, w, 3) - TODO confirm)
    :param cuda: if True, move the result to the default CUDA device
    :return: tensor with a leading batch dimension of 1; presumably of shape
        [1, 3, image_size, image_size], depending on the transform - TODO
        confirm. It is on CUDA only when ``cuda`` is True.
    """
    # NOTE(review): `cv2`, `pil_image` and `xception_default_data_transforms`
    # are not imported or defined anywhere in this file as shown, so calling
    # this function raises NameError unless they are provided elsewhere.
    # Revert from BGR
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Preprocess using the preprocessing function used during training and
    # casting it to PIL image
    preprocess = xception_default_data_transforms['test']
    preprocessed_image = preprocess(pil_image.fromarray(image))
    # Add first dimension as the network expects a batch
    preprocessed_image = preprocessed_image.unsqueeze(0)
    if cuda:
        preprocessed_image = preprocessed_image.cuda()
    return preprocessed_image
def truncate(x, truncation, mean_style):
    """Blend ``x`` toward ``mean_style``.

    Returns ``truncation * x + (1 - truncation) * mean_style``: ``truncation=1``
    keeps ``x`` unchanged, ``truncation=0`` returns ``mean_style``.
    """
    kept = truncation * x
    pulled = (1 - truncation) * mean_style
    return kept + pulled