File size: 4,971 Bytes
d7f12b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import argparse
import os
import sys
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from models.networks.drn_seg import DRNSeg
from utils.tools import *
from utils.visualize import *
from utils.preprocessing import generate_local_image
def predict_and_generate_heatmap(model, image, dest_folder=''):
    """Run the local detector on a face image and save its visualisations.

    The model's output for ``image`` is treated as a per-pixel flow field.
    Three JPEGs are written:

    * ``cropped_input.jpg`` -- the input resized to the flow resolution;
    * ``warped.jpg`` -- that resized input passed through ``warp`` with the
      predicted flow (i.e. the predicted warps undone);
    * ``heatmap.jpg`` -- ``save_heatmap_cv`` of the per-pixel flow magnitude.

    Args:
        model: the detector (e.g. as returned by ``load_local_detector``).
            The input tensor is moved to the device of the model's
            parameters, falling back to the CPU.
        image: a PIL image. It is converted to RGB when it has another mode,
            because the normalisation uses 3-channel statistics.
        dest_folder: directory receiving the three JPEGs. The default ``''``
            keeps the historical behaviour of writing into the current
            working directory. A non-empty folder is created if missing.

    Returns:
        tuple: ``(heatmap_path, score)``. ``heatmap_path`` is the path of the
        saved heatmap and ``score`` is the mean flow magnitude times 100.
    """
    # Standard ImageNet mean/std -- presumably what the DRN backbone was
    # trained with; TODO confirm against the training pipeline.
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    face = image if image.mode == 'RGB' else image.convert('RGB')

    # Follow the model's device: a hard-coded 'cpu' crashes with a device
    # mismatch as soon as load_local_detector places the model on a GPU.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cpu')
    face_tens = tf(face).unsqueeze(0).to(device)

    # Warping field prediction; transpose CHW -> HWC so channels are last.
    with torch.no_grad():
        flow = model(face_tens)[0].cpu().numpy()
    flow = np.transpose(flow, (1, 2, 0))
    h, w, _ = flow.shape

    # Undo the warps on the input resized to the flow resolution.
    modified = face.resize((w, h), Image.BICUBIC)
    modified_np = np.asarray(modified)
    reverse_np = warp(modified_np, flow)
    reverse = Image.fromarray(reverse_np)

    if dest_folder:
        os.makedirs(dest_folder, exist_ok=True)
    cropped_path = os.path.join(dest_folder, 'cropped_input.jpg')
    warped_path = os.path.join(dest_folder, 'warped.jpg')
    heatmap_path = os.path.join(dest_folder, 'heatmap.jpg')

    modified.save(cropped_path, quality=90)
    reverse.save(warped_path, quality=90)

    # Euclidean magnitude of the first two flow channels at every pixel.
    flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2)
    save_heatmap_cv(modified_np, flow_magn, heatmap_path)
    return heatmap_path, flow_magn.mean() * 100
def load_local_detector(model_path, gpu_id=-1):
    """Build a ``DRNSeg`` local detector and load its weights from disk.

    Args:
        model_path: path to a checkpoint readable by ``torch.load``. It must
            hold a dict whose ``'model'`` entry is the model's state dict.
        gpu_id: CUDA device index. Any negative value (default ``-1``)
            selects the CPU, as does CUDA being unavailable. String values
            such as the raw ``--gpu_id`` command-line argument (``'0'``,
            ``'-1'``) are accepted and converted with ``int``.

    Returns:
        The model, moved to the selected device and switched to eval mode.

    Raises:
        ValueError: if the checkpoint is not a dict or lacks a ``'model'`` key.
    """
    # argparse delivers --gpu_id as a string; comparing '-1' != -1 would be
    # True and wrongly select the non-existent device 'cuda:-1'.
    gpu_id = int(gpu_id)
    if torch.cuda.is_available() and gpu_id >= 0:
        device = f'cuda:{gpu_id}'
    else:
        device = 'cpu'

    # Presumably 2 output channels (the two flow components read by
    # predict_and_generate_heatmap) -- confirm against DRNSeg's signature.
    model = DRNSeg(2)

    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # checkpoints from trusted sources (or pass weights_only=True on torch
    # versions that support it).
    checkpoint = torch.load(model_path, map_location=device)
    if not isinstance(checkpoint, dict):
        raise ValueError(
            f"Invalid state_dict: expected a dict, got "
            f"{type(checkpoint).__name__}")
    if 'model' not in checkpoint:
        raise ValueError(f"Invalid state_dict: {list(checkpoint.keys())}")
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    model.eval()

    print("Model successfully loaded and moved to:", device)
    return model
if __name__ == '__main__':
    # Command-line entry point: detect (or take) a face, run the local
    # detector on it, and store the resulting images in --dest_folder.
    import shutil

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_path", required=True, help="the model input")
    parser.add_argument(
        "--dest_folder", required=True, help="folder to store the results")
    parser.add_argument(
        "--model_path", required=True, help="path to the drn model")
    parser.add_argument(
        "--gpu_id", default=0, type=int,
        help="the id of the gpu to run model on")
    parser.add_argument(
        "--no_crop",
        action="store_true",
        help="do not use a face detector, instead run on the full input image")
    args = parser.parse_args()

    model = load_local_detector(args.model_path, args.gpu_id)

    if args.no_crop:
        face = Image.open(args.input_path).convert('RGB')
    else:
        # NOTE(review): face_detection / resize_shorter_side are expected to
        # come from the star-imported utils.tools module -- confirm.
        faces = face_detection(args.input_path, verbose=False)
        if len(faces) == 0:
            print("no face detected by dlib, exiting")
            sys.exit(1)
        face, _box = faces[0]
    face = resize_shorter_side(face, 400)[0]

    heatmap_path, score = predict_and_generate_heatmap(model, face)

    # predict_and_generate_heatmap (called with its default arguments) writes
    # relative to the working directory; relocate the outputs into
    # --dest_folder.
    os.makedirs(args.dest_folder, exist_ok=True)
    for produced in ('cropped_input.jpg', 'warped.jpg', heatmap_path):
        target = os.path.join(args.dest_folder, os.path.basename(produced))
        if os.path.abspath(produced) != os.path.abspath(target):
            shutil.move(produced, target)
    print(f"Results saved to {args.dest_folder}; "
          f"mean flow magnitude x100: {score:.4f}")
|