# Spaces: Runtime error
import torch | |
import torch.nn.functional as F | |
from torch.utils import data | |
from torch import nn, autograd | |
import os | |
import matplotlib.pyplot as plt | |
# Checkpoint filenames that can be fetched automatically, mapped to their
# Google Drive download URLs.
google_drive_paths = {
    "GNR_checkpoint.pt": "https://drive.google.com/uc?id=1IMIVke4WDaGayUa7vk_xVw1uqIHikGtC",
}


def ensure_checkpoint_exists(model_weights_filename):
    """Make sure a model checkpoint file is present on local disk.

    If the file is missing and its name is a key of ``google_drive_paths``,
    try to download it with ``gdown``. If ``gdown`` is not installed, or the
    file name has no known URL, print instructions instead of raising.

    :param model_weights_filename: local path of the checkpoint; also used as
        the lookup key in ``google_drive_paths``.
    :return: True if the file exists when the function returns, else False.
        (Previously returned None; callers that ignore the result are unaffected.)
    """
    if os.path.isfile(model_weights_filename):
        return True

    gdrive_url = google_drive_paths.get(model_weights_filename)
    if gdrive_url is None:
        print(
            f"{model_weights_filename} not found, "
            "you may need to manually download the model weights."
        )
        return False

    # Guard only the import: failures of the download itself (network errors,
    # bad URL, ...) should propagate rather than be reported as a missing module.
    try:
        from gdown import download as drive_download
    except ImportError:
        print(
            "gdown module not found.",
            "pip3 install gdown or, manually download the checkpoint file:",
            gdrive_url,
        )
        return False

    drive_download(gdrive_url, model_weights_filename, quiet=False)
    return os.path.isfile(model_weights_filename)
def shuffle_batch(x):
    """Return ``x`` with its entries along dimension 0 in a random order."""
    order = torch.randperm(x.size(0))
    shuffled = x[order]
    return shuffled
def data_sampler(dataset, shuffle, distributed):
    """Pick a ``torch.utils.data`` sampler for ``dataset``.

    :param dataset: dataset to sample from.
    :param shuffle: randomize the order if truthy, otherwise iterate sequentially.
    :param distributed: if truthy, return a ``DistributedSampler`` (which
        receives ``shuffle`` itself) instead of a single-process sampler.
    :return: the sampler instance.
    """
    if not distributed:
        sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
        return sampler_cls(dataset)
    return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
def accumulate(model1, model2, decay=0.999):
    """Blend ``model2``'s parameters into ``model1`` in place (moving average).

    For every parameter name of ``model1``:
        p1 <- decay * p1 + (1 - decay) * p2
    ``model2`` is only read.

    :param model1: module whose parameters are updated in place.
    :param model2: module providing the new values; must have every
        parameter name that ``model1`` has (otherwise ``KeyError``).
    :param decay: weight kept from ``model1``'s current values.
    """
    par2 = dict(model2.named_parameters())
    for name, p1 in model1.named_parameters():
        # Work on ``.data`` so the update is not tracked by autograd.
        # ``add_(tensor, alpha=...)`` replaces the deprecated
        # ``add_(scalar, tensor)`` overload used previously.
        p1.data.mul_(decay).add_(par2[name].data, alpha=1 - decay)
def sample_data(loader):
    """Yield items from ``loader`` endlessly, re-iterating it each time it runs out."""
    while True:
        yield from loader
def d_logistic_loss(real_pred, fake_pred):
    """Logistic (softplus) discriminator loss over paired predictions.

    For each ``(real, fake)`` pair taken with ``zip`` the term is
    ``mean(softplus(-real)) + mean(softplus(fake))``. The terms are summed,
    not averaged. With no pairs the result is the integer 0.
    """
    pair_terms = (
        F.softplus(-r).mean() + F.softplus(f).mean()
        for r, f in zip(real_pred, fake_pred)
    )
    return sum(pair_terms, 0)
def d_r1_loss(real_pred, real_img):
    """Gradient penalty on real inputs, summed over all prediction tensors.

    For each tensor in ``real_pred`` the gradient of its mean with respect to
    ``real_img`` is taken with ``create_graph=True``, so the penalty itself
    can be back-propagated. The gradient is flattened per sample (dim 0),
    squared and summed, and the batch mean is added to the total. With an
    empty ``real_pred`` the result is the integer 0.
    """
    total = 0
    for pred in real_pred:
        (grad,) = autograd.grad(
            outputs=pred.mean(),
            inputs=real_img,
            create_graph=True,
            only_inputs=True,
        )
        sq_norm_per_sample = grad.pow(2).view(grad.shape[0], -1).sum(1)
        total += sq_norm_per_sample.mean()
    return total
def g_nonsaturating_loss(fake_pred, weights):
    """Weighted non-saturating generator loss, averaged over ``fake_pred``.

    Computes ``sum(w * mean(softplus(-fake)))`` over ``zip(fake_pred, weights)``
    and divides by ``len(fake_pred)``. An empty ``fake_pred`` therefore raises
    ``ZeroDivisionError``.
    """
    weighted_terms = (w * F.softplus(-f).mean() for f, w in zip(fake_pred, weights))
    return sum(weighted_terms, 0) / len(fake_pred)
def display_image(image, size=None, mode='nearest', unnorm=False, title=''):
    """Show an image tensor in a new matplotlib figure.

    :param image: tensor of shape [3, h, w] or [1, 3, h, w] with values in
        [0, 1]; for a 4-D input only the first item is shown.
    :param size: if given, resize the image to ``size`` x ``size`` before display.
    :param mode: interpolation mode passed to ``F.interpolate``.
    :param unnorm: accepted for compatibility but not used by this function.
    :param title: figure title.
    """
    # .cpu() is a no-op for CPU tensors; detaching first lets .numpy() work on
    # tensors that require grad.
    image = image.detach().cpu()
    if size is not None and tuple(image.shape[-2:]) != (size, size):
        # A 2-D output size needs a 4-D [N, C, H, W] input, so give a single
        # [3, h, w] image a temporary batch dimension.
        batched = image.unsqueeze(0) if image.dim() == 3 else image
        image = F.interpolate(batched, size=(size, size), mode=mode)
    if image.dim() == 4:
        image = image[0]
    image = image.permute(1, 2, 0).numpy()
    plt.figure()
    plt.title(title)
    plt.axis('off')
    plt.imshow(image)
def normalize(x):
    """Map values from [-1, 1] to [0, 1], clamping anything outside that range."""
    shifted = (x + 1) / 2
    return shifted.clamp(min=0, max=1)
def get_boundingbox(face, width, height, scale=1.3, minsize=None):
    """Compute a square crop box around a detected face.

    :param face: dlib face rectangle (any object with ``left()``, ``top()``,
        ``right()`` and ``bottom()`` methods works).
    :param width: frame width.
    :param height: frame height.
    :param scale: multiplier applied to the face's larger side to enlarge the box.
    :param minsize: if truthy, lower bound on the box side before clipping.
    :return: ``(x, y, size)`` - top-left corner and side length (OpenCV style),
        clipped so the box stays inside the frame.
    """
    left, top, right, bottom = face.left(), face.top(), face.right(), face.bottom()

    side = int(max(right - left, bottom - top) * scale)
    if minsize and side < minsize:
        side = minsize

    center_x = (left + right) // 2
    center_y = (top + bottom) // 2

    # Top-left corner, never past the frame's top/left edge.
    x = max(int(center_x - side // 2), 0)
    y = max(int(center_y - side // 2), 0)

    # Shrink the side so the box does not run past the right/bottom edge.
    side = min(height - y, min(width - x, side))
    return x, y, side
def preprocess_image(image, cuda=True):
    """Turn an OpenCV image into a batched tensor ready for the network.

    The image is converted from BGR to RGB, wrapped as a PIL image, and passed
    through ``xception_default_data_transforms['test']``.

    :param image: numpy image in OpenCV layout (BGR channel order).
    :param cuda: if truthy, move the result to the GPU with ``.cuda()``.
    :return: the transformed image with a leading batch dimension of 1
        (presumably [1, 3, image_size, image_size] - depends on the transform).
    """
    # NOTE(review): cv2, pil_image and xception_default_data_transforms are not
    # imported or defined in the visible part of this file; calling this
    # function raises NameError unless they are provided elsewhere.
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    transform = xception_default_data_transforms['test']
    batch = transform(pil_image.fromarray(rgb)).unsqueeze(0)
    return batch.cuda() if cuda else batch
def truncate(x, truncation, mean_style):
    """Blend ``x`` toward ``mean_style``.

    Returns ``truncation * x + (1 - truncation) * mean_style``: 1 keeps ``x``
    unchanged and 0 returns ``mean_style``.
    """
    kept = truncation * x
    pulled = (1 - truncation) * mean_style
    return kept + pulled