Spaces:
Runtime error
Runtime error
File size: 3,347 Bytes
e875957 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import argparse
import os
import sys
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from networks.drn_seg import DRNSeg, DRNSub
from utils.tools import *
from utils.visualize import *
def load_classifier(model_path, gpu_id):
    """Construct the global ``DRNSub(1)`` classifier and restore its weights.

    Args:
        model_path: Checkpoint path for ``torch.load``; its ``'model'`` entry
            must hold the ``DRNSub`` state dict.
        gpu_id: CUDA device index to use. ``-1`` forces the CPU, and the CPU is
            also used whenever CUDA is unavailable.

    Returns:
        The network in eval mode, moved to the selected device. The device
        string is also stored on the ``device`` attribute so callers can place
        their inputs next to it.
    """
    target = 'cuda:{}'.format(gpu_id) if torch.cuda.is_available() and gpu_id != -1 else 'cpu'

    classifier = DRNSub(1)
    # Deserialise on the CPU first; the module is moved only after the
    # weights have been applied.
    checkpoint = torch.load(model_path, map_location='cpu')
    classifier.load_state_dict(checkpoint['model'])

    classifier.to(target)
    classifier.device = target
    classifier.eval()
    return classifier
# --- Configuration -----------------------------------------------------------
local_model_path = 'weights/local.pth'    # checkpoint for the local DRNSeg model
global_model_path = 'weights/global.pth'  # checkpoint for the global DRNSub classifier
gpu_id = 0                                # CUDA index used when a GPU is available

# The local model runs on GPU `gpu_id` when CUDA is present, else on the CPU.
device = 'cuda:{}'.format(gpu_id) if torch.cuda.is_available() else 'cpu'

# --- Local model: 2 output channels, later used as a flow field --------------
local_model = DRNSeg(2)
state_dict = torch.load(local_model_path, map_location=device)
local_model.load_state_dict(state_dict['model'])
local_model.to(device)
local_model.eval()

# --- Global model ------------------------------------------------------------
global_model = load_classifier(global_model_path, gpu_id)

# --- Preprocessing shared by both models -------------------------------------
# PIL image -> tensor, then per-channel normalisation with ImageNet mean/std.
tf = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
)
def classify_fake(img, no_crop=False, global_model=global_model,
                  model_file='utils/dlib_face_detector/mmod_human_face_detector.dat'):
    """Score an image (or its first detected face) with the global classifier.

    Args:
        img: Input image (assumed to be a PIL image, as expected by ``tf`` and
            ``resize_shorter_side``).
        no_crop: If True, score the whole image. Otherwise score only the
            first face returned by ``face_detection``.
        global_model: Classifier to run. It must expose a ``device`` attribute,
            as set by ``load_classifier``. The default is the module-level
            ``global_model``, bound when this function is defined.
        model_file: Face-detector model file forwarded to ``face_detection``.

    Returns:
        A Python float: the sigmoid of the model's output for the single input.
        Presumably this is the probability that the face was manipulated;
        confirm against the model's training setup.

    Note:
        If no face is detected, a message is printed and ``sys.exit()`` is
        called, which raises ``SystemExit``.
    """
    if no_crop:
        face = img
    else:
        faces = face_detection(img, verbose=False, model_file=model_file)
        if len(faces) == 0:
            print("no face detected by dlib, exiting")
            # NOTE(review): SystemExit will tear down a long-running host
            # (e.g. a web app). Kept for compatibility; consider raising a
            # regular exception instead.
            sys.exit()
        # Each detection unpacks as (face_image, box); only the crop is needed.
        face, _box = faces[0]

    # Index 0 of resize_shorter_side's result holds the resized image.
    face = resize_shorter_side(face, 400)[0]
    face_tens = tf(face).to(global_model.device)

    with torch.no_grad():
        # Add a batch dimension, take the single result, and map it to (0, 1).
        prob = global_model(face_tens.unsqueeze(0))[0].sigmoid().cpu().item()
    return prob
def heatmap_analysis(img, no_crop=False, model_file=None):
    """Predict a warp field for a face, undo it, and render a magnitude heatmap.

    Args:
        img: Input image (assumed to be a PIL image, as expected by ``tf`` and
            ``resize_shorter_side``).
        no_crop: If True, use the whole image. Otherwise use the first face
            returned by ``face_detection``.
        model_file: Optional face-detector model file forwarded to
            ``face_detection``. When None, ``face_detection``'s own default is
            used, which matches the previous behaviour.

    Returns:
        A tuple ``(modified, reverse, heatmap)`` of PIL images:
        - modified: the face resized to the predicted flow field's resolution.
        - reverse: ``modified`` after ``warp`` is applied with the predicted flow.
        - heatmap: ``get_heatmap_cv`` rendering of the per-pixel flow magnitude
          over ``modified``.

    Note:
        If no face is detected, a message is printed and ``sys.exit()`` is
        called, which raises ``SystemExit``.
    """
    if no_crop:
        face = img
    else:
        if model_file is None:
            faces = face_detection(img, verbose=False)
        else:
            faces = face_detection(img, verbose=False, model_file=model_file)
        if len(faces) == 0:
            print("no face detected by dlib, exiting")
            # NOTE(review): SystemExit will tear down a long-running host
            # (e.g. a web app). Kept for compatibility; consider raising a
            # regular exception instead.
            sys.exit()
        # Each detection unpacks as (face_image, box); only the crop is needed.
        face, _box = faces[0]

    # Index 0 of resize_shorter_side's result holds the resized image.
    face = resize_shorter_side(face, 400)[0]
    face_tens = tf(face).to(device)

    # Warping field prediction.
    with torch.no_grad():
        flow = local_model(face_tens.unsqueeze(0))[0].cpu().numpy()
    # Channels-first -> channels-last, so flow[:, :, 0] and flow[:, :, 1] are
    # the two per-pixel components.
    flow = np.transpose(flow, (1, 2, 0))
    h, w, _ = flow.shape

    # Undoing the warps: resize the face to the flow's resolution first so
    # the two line up pixel-for-pixel.
    modified = face.resize((w, h), Image.BICUBIC)
    modified_np = np.asarray(modified)
    reverse = Image.fromarray(warp(modified_np, flow))

    # The heatmap visualises the Euclidean length of the flow at each pixel.
    flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2)
    heatmap = Image.fromarray(get_heatmap_cv(modified_np, flow_magn, 7))
    return modified, reverse, heatmap
# Saving the results
# modified.save(
# os.path.join(dest_folder, 'cropped_input.jpg'),
# quality=90)
# reverse.save(
# os.path.join(dest_folder, 'warped.jpg'),
# quality=90)
# flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2)
# save_heatmap_cv(
# modified_np, flow_magn,
# os.path.join(dest_folder, 'heatmap.jpg'))
|