import random

import cv2
import numpy as np
import torch
# import mediapipe as mp  # optional: only needed for NTED.hand_pose_est

from lite_openpose.body_bbox_detector import BodyPoseEstimator
from NTED.extraction_distribution_model import Generator
from NTED.demo_dataset import DemoDataset
from NTED.base_function import accumulate
from NTED.config import Config

# Number of landmarks read per detected hand in NTED.hand_pose_est.
_JOINTS_PER_HAND = 21


def set_random_seed(seed):
    r"""Set random seeds for everything.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices).

    Args:
        seed (int): Random seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


class NTED():
    """CPU demo wrapper: body-pose detection followed by the NTED generator.

    Construction loads the generator configuration and checkpoint, the body
    pose estimator, the demo data loader and a fixed reference image.
    ``inference`` then renders the reference image under the body pose
    detected in an input frame.
    """

    def __init__(self):
        super().__init__()
        self.openpose_module = BodyPoseEstimator('cpu')

        set_random_seed(0)
        self.opt = Config('NTED/fashion_512.yaml', is_train=False)

        # Two generator instances are built; only the EMA copy is kept.
        # `net_G` is used solely as the source for `accumulate`, and the EMA
        # weights are then replaced by the checkpoint's 'net_G_ema' entry.
        net_G = Generator(**self.opt.gen.param).to('cpu')
        net_G_ema = Generator(**self.opt.gen.param).to('cpu')
        net_G_ema.eval()
        accumulate(net_G_ema, net_G, 0)

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(
            'NTED/nted_checkpoint.pt',
            map_location=lambda storage, loc: storage,
        )
        net_G_ema.load_state_dict(checkpoint['net_G_ema'])
        self.net_G = net_G_ema.eval()

        self.data_loader = DemoDataset()

        # Hand detection is disabled by default. To enable `hand_pose_est`,
        # restore the mediapipe import and assign a detector here, e.g.:
        # mp_hands = mp.solutions.hands
        # self.hands = mp_hands.Hands(static_image_mode=True, max_num_hands=2, min_detection_confidence=0.1)
        self.hands = None

        # cv2.imread signals failure by returning None instead of raising,
        # so check explicitly before handing the result to cv2.resize.
        ref_img = cv2.imread('example/ref_img.png')
        if ref_img is None:
            raise FileNotFoundError(
                "Could not read reference image 'example/ref_img.png'"
            )
        self.ref_img = cv2.resize(ref_img, (352, 512))

    def hand_pose_est(self, img):
        """Estimate 2D hand key points for up to two hands in ``img``.

        Args:
            img: BGR image array (it is converted with ``COLOR_BGR2RGB``
                and its ``shape`` is unpacked as height, width, channels).

        Returns:
            ``np.int32`` array of pixel coordinates ``[x, y]``, laid out as
            21 left-hand joints followed by 21 right-hand joints. Joints of
            a hand that was not detected are ``[-1, -1]``; when no hand is
            detected at all the result is 42 rows of ``[-1, -1]``.

        Raises:
            RuntimeError: If no hand detector has been assigned to
                ``self.hands``.
        """
        hands = getattr(self, 'hands', None)
        if hands is None:
            raise RuntimeError(
                "Hand detector is not initialised: assign a hand detector "
                "to `self.hands` before calling hand_pose_est()."
            )

        # The detector is run on a horizontally flipped copy of the image.
        results = hands.process(cv2.cvtColor(cv2.flip(img, 1), cv2.COLOR_BGR2RGB))
        image_height, image_width, _ = img.shape

        detected = results.multi_hand_landmarks
        if detected is None:
            return np.full((2 * _JOINTS_PER_HAND, 2), -1, dtype=np.int32)

        # Landmark x/y are scaled by the image size; x is mirrored back
        # because detection ran on the flipped image.
        pose_data = [
            [
                image_width - hand.landmark[joint_idx].x * image_width,
                hand.landmark[joint_idx].y * image_height,
            ]
            for hand in detected
            for joint_idx in range(_JOINTS_PER_HAND)
        ]

        first_label = results.multi_handedness[0].classification[0].label \
            if len(detected) in (1, 2) else None

        if len(detected) == 2:
            if first_label == 'Right':
                # Swap so the left hand comes first, then the right hand.
                pose_data = pose_data[_JOINTS_PER_HAND:] + pose_data[:_JOINTS_PER_HAND]
        elif len(detected) == 1:
            miss_hand = [[-1, -1] for _ in range(_JOINTS_PER_HAND)]
            if first_label == 'Left':
                pose_data = pose_data + miss_hand
            else:
                pose_data = miss_hand + pose_data

        return np.array(pose_data, dtype=np.int32)

    def inference(self, img):
        """Render the reference image in the body pose detected in ``img``.

        Args:
            img: Input image array; it is resized to 352x512 (width x height)
                before pose detection.

        Returns:
            Tuple ``(skeleton_img, fake_image)``: the ``'skeleton_img'`` entry
            produced by the data loader, and the generated image converted
            with ``tensor2im`` and resized to 288x480 (width x height).
        """
        img = cv2.resize(img, (352, 512))
        body_pose, _ = self.openpose_module.detect_body_pose(img.copy())
        # hand_pose = self.hand_pose_est(img.copy())

        # NOTE(review): only the first detected pose is used; presumably
        # `body_pose` is empty when nobody is detected, in which case the
        # index below raises - confirm against BodyPoseEstimator.
        data = self.data_loader.load_item(self.ref_img, body_pose[0], None)

        # Pure inference: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            output = self.net_G(
                data['reference_image'],
                data['target_skeleton'],
            )

        fake_image = output['fake_image'][0]
        fake_image = self.data_loader.tensor2im(fake_image)
        fake_image = cv2.resize(fake_image, (288, 480))
        return data['skeleton_img'], fake_image