import cv2
import numpy as np
import torch
import random

# import mediapipe as mp
from lite_openpose.body_bbox_detector import BodyPoseEstimator
from NTED.extraction_distribution_model import Generator
from NTED.demo_dataset import DemoDataset
from NTED.base_function import accumulate
from NTED.config import Config


def set_random_seed(seed):
    r"""Set random seeds for everything.

    Args:
        seed (int): Random seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

class NTED:
    def __init__(self):
        super().__init__()
        
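        # Lightweight OpenPose body keypoint detector, run on CPU.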
        self.openpose_module = BodyPoseEstimator('cpu')
        set_random_seed(0)
        self.opt = Config('NTED/fashion_512.yaml', is_train=False)
        
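        # Build the generator and its exponential-moving-average (EMA) copy; inference uses the EMA weights.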
        net_G = Generator(**self.opt.gen.param).to('cpu')
        net_G_ema = Generator(**self.opt.gen.param).to('cpu')
        net_G_ema.eval()
        accumulate(net_G_ema, net_G, 0)
        
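        # Load the pretrained EMA generator weights onto CPU.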
        checkpoint = torch.load('NTED/nted_checkpoint.pt', map_location=lambda storage, loc: storage)
        net_G_ema.load_state_dict(checkpoint['net_G_ema'])
        self.net_G = net_G_ema.eval()
        
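        # Helper used to assemble single-sample model inputs and convert output tensors back to images.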
        self.data_loader = DemoDataset()
        
        # mp_hands = mp.solutions.hands
        # self.hands = mp_hands.Hands(static_image_mode=True, max_num_hands=2, min_detection_confidence=0.1) 
        
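        # Reference (source appearance) image, resized to the 352x512 input resolution.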
        self.ref_img = cv2.imread('example/ref_img.png')
        self.ref_img = cv2.resize(self.ref_img, (352, 512))
    
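    # Hand keypoint estimation via MediaPipe Hands. It relies on the mediapipe setup commented
    # out above and is itself unused by inference(); it returns 42 (x, y) joints, left hand first,
    # with missing hands filled with (-1, -1).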
    def hand_pose_est(self, img):
        results = self.hands.process(cv2.cvtColor(cv2.flip(img, 1), cv2.COLOR_BGR2RGB))
        image_height, image_width, _ = img.shape
        pose_data = []

        if results.multi_hand_landmarks is not None:
            for hand_landmarks in results.multi_hand_landmarks:
                for joint_idx in range(21):
                    pose_data.append([image_width - hand_landmarks.landmark[joint_idx].x * image_width, hand_landmarks.landmark[joint_idx].y * image_height])
            if len(results.multi_hand_landmarks) == 2:
                if results.multi_handedness[0].classification[0].label == 'Right':
                    # Swap so the left-hand joints come first, then the right-hand joints
                    tmp = pose_data[:21].copy()
                    pose_data[:21] = pose_data[21:]
                    pose_data[21:] = tmp
            elif len(results.multi_hand_landmarks) == 1:
                miss_hand = [[-1, -1] for _ in range(21)]
                if results.multi_handedness[0].classification[0].label == 'Left':
                    pose_data += miss_hand
                else:
                    pose_data = miss_hand + pose_data
        else:
            for _ in range(42):
                pose_data.append([-1, -1])
        pose_data = np.array(pose_data, dtype=np.int32)
        
        return pose_data
    

    def inference(self, img):
        
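        # Resize the driving frame to the same 352x512 resolution as the reference image.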
        img = cv2.resize(img, (352, 512))
        
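        # Detect 2D body keypoints with the lightweight OpenPose model; the bbox is not used here.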
        body_pose, bbox = self.openpose_module.detect_body_pose(img.copy())
        
        # hand_pose = self.hand_pose_est(img.copy())
        
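        # Pack the reference image and detected skeleton into model inputs
        # (hand pose is passed as None since hand estimation is disabled).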
        data = self.data_loader.load_item(self.ref_img, body_pose[0], None)

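        # Run the EMA generator: reference appearance + target skeleton -> synthesized image.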
        output = self.net_G(
            data['reference_image'], 
            data['target_skeleton'], 
        )
        fake_image = output['fake_image'][0]
        
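        # Convert the output tensor to an image and resize it to the 288x480 output size.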
        fake_image = self.data_loader.tensor2im(fake_image)
        
        fake_image = cv2.resize(fake_image, (288, 480))
        
        return data['skeleton_img'], fake_image
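

# Example usage sketch. The paths below are illustrative assumptions; any BGR image
# readable by OpenCV can serve as the driving frame.
if __name__ == '__main__':
    model = NTED()
    driving_frame = cv2.imread('example/ref_img.png')  # stand-in driving frame; replace with your own image
    skeleton_img, fake_img = model.inference(driving_frame)
    # Channel order depends on tensor2im; convert RGB->BGR before writing if colors look swapped.
    cv2.imwrite('fake_out.png', fake_img)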