# File size: 4,971 Bytes
# d7f12b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import argparse
import os
import sys
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from models.networks.drn_seg import DRNSeg
from utils.tools import *
from utils.visualize import *
from utils.preprocessing import generate_local_image

def predict_and_generate_heatmap(model, image, dest_folder=''):
    """Predict the warping field for an image and save the visualisations.

    Runs ``model`` on ``image``, undoes the predicted warp with ``warp`` and
    writes three JPEGs (quality 90):

    * ``cropped_input.jpg`` - the input resized to the flow resolution,
    * ``warped.jpg``        - that resized input with the predicted warp undone,
    * ``heatmap.jpg``       - the flow magnitude rendered by ``save_heatmap_cv``.

    Args:
        model: Network returned by ``load_local_detector``. The input tensor is
            moved to the device of the model's parameters, so models loaded on
            the CPU or on a GPU both work.
        image: PIL image to analyse (presumably an already-cropped face -
            TODO confirm with callers). Non-RGB images are converted to RGB
            because the normalisation below expects three channels.
        dest_folder: Directory the three JPEGs are written to. Defaults to
            ``''``, i.e. the current working directory (the previous
            behaviour). Created if it does not exist.

    Returns:
        Tuple ``(heatmap_path, score)``: the path of the saved heatmap and the
        mean per-pixel flow magnitude multiplied by 100.
    """
    # ImageNet mean/std normalisation, as used for the network input.
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    face = image if image.mode == 'RGB' else image.convert('RGB')

    # Send the input to wherever the model lives; a hard-coded 'cpu' fails as
    # soon as the model has been loaded onto a GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    face_tens = tf(face).unsqueeze(0).to(device)

    # Warping field prediction: take the first (only) batch item and move the
    # channel axis last so flow[:, :, 0] / flow[:, :, 1] are the two components.
    with torch.no_grad():
        flow = model(face_tens)[0].cpu().numpy()
    flow = np.transpose(flow, (1, 2, 0))
    h, w, _ = flow.shape

    # Undoing the warps at the resolution of the predicted flow field.
    modified = face.resize((w, h), Image.BICUBIC)
    modified_np = np.asarray(modified)
    reverse_np = warp(modified_np, flow)
    reverse = Image.fromarray(reverse_np)

    # Saving the results.
    if dest_folder:
        os.makedirs(dest_folder, exist_ok=True)
    cropped_path = os.path.join(dest_folder, 'cropped_input.jpg')
    warped_path = os.path.join(dest_folder, 'warped.jpg')
    heatmap_path = os.path.join(dest_folder, 'heatmap.jpg')

    modified.save(cropped_path, quality=90)
    reverse.save(warped_path, quality=90)

    flow_magn = np.sqrt(flow[:, :, 0]**2 + flow[:, :, 1]**2)
    save_heatmap_cv(modified_np, flow_magn, heatmap_path)
    return heatmap_path, flow_magn.mean() * 100
    



def load_local_detector(model_path, gpu_id=-1):
    """Load a DRNSeg local-detector checkpoint and put it in eval mode.

    Args:
        model_path: Path to a file written by ``torch.save`` holding a dict
            whose ``'model'`` entry is the network's state dict.
        gpu_id: CUDA device index. ``None`` or any negative value (the default
            ``-1``) selects the CPU; the CPU is also used when CUDA is not
            available. Integer-like strings such as ``'0'`` (what argparse
            produces) are accepted.

    Returns:
        The ``DRNSeg`` model with the checkpoint weights, moved to the chosen
        device and switched to eval mode.

    Raises:
        ValueError: If the checkpoint is not a dict containing ``'model'``, or
            if ``gpu_id`` cannot be converted to an int.
    """
    # Coerce first: comparing the string '-1' against the int -1 would
    # otherwise pick the non-existent device 'cuda:-1'.
    gpu_index = -1 if gpu_id is None else int(gpu_id)
    if torch.cuda.is_available() and gpu_index >= 0:
        device = f'cuda:{gpu_index}'
    else:
        device = 'cpu'

    # NOTE(review): torch.load unpickles the file - only load checkpoints from
    # trusted sources.
    state_dict = torch.load(model_path, map_location=device)

    # Validate before building the network so a bad file fails fast, and keep
    # the error message from crashing on non-dict checkpoints.
    if not isinstance(state_dict, dict) or 'model' not in state_dict:
        found = (list(state_dict.keys()) if isinstance(state_dict, dict)
                 else type(state_dict).__name__)
        raise ValueError(f"Invalid state_dict: {found}")

    # 2 output channels; presumably the two flow components consumed by
    # predict_and_generate_heatmap - TODO confirm against DRNSeg.
    model = DRNSeg(2)
    model.load_state_dict(state_dict['model'])
    model.to(device)
    model.eval()

    print("Model successfully loaded and moved to:", device)
    return model


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_path", required=True, help="the model input")
    parser.add_argument(
        "--dest_folder", required=True, help="folder to store the results")
    parser.add_argument(
        "--model_path", required=True, help="path to the drn model")
    parser.add_argument(
        "--gpu_id", default='0', type=int,
        help="the id of the gpu to run model on")
    parser.add_argument(
        "--no_crop",
        action="store_true",
        help="do not use a face detector, instead run on the full input image")
    args = parser.parse_args()

    # Resolve input paths up front: the working directory changes below.
    img_path = os.path.abspath(args.input_path)
    model_path = os.path.abspath(args.model_path)
    dest_folder = args.dest_folder

    model = load_local_detector(model_path, args.gpu_id)

    if args.no_crop:
        face = Image.open(img_path).convert('RGB')
    else:
        faces = face_detection(img_path, verbose=False)
        if len(faces) == 0:
            # Non-zero exit status so callers can detect the failure.
            sys.exit("no face detected by dlib, exiting")
        face, _ = faces[0]
    face = resize_shorter_side(face, 400)[0]

    # predict_and_generate_heatmap writes its JPEGs into the current working
    # directory by default, so run it from inside the destination folder.
    os.makedirs(dest_folder, exist_ok=True)
    os.chdir(dest_folder)
    heatmap_path, score = predict_and_generate_heatmap(model, face)

    print(f"Heatmap saved to: {os.path.join(dest_folder, heatmap_path)}")
    print(f"Mean flow magnitude (x100): {score:.4f}")