# ZHIJI_cv_web_ui / NTED / NTED_module.py
# zejunyang
# update
# b79bc32
import cv2
import numpy as np
import torch
import random
# import mediapipe as mp
from lite_openpose.body_bbox_detector import BodyPoseEstimator
from NTED.extraction_distribution_model import Generator
from NTED.demo_dataset import DemoDataset
from NTED.base_function import accumulate
from NTED.config import Config
def set_random_seed(seed):
    r"""Seed every random number generator this module relies on.

    The same value is handed, in order, to Python's ``random``, to NumPy and
    to PyTorch (``torch.manual_seed`` plus both ``torch.cuda`` seeders).

    Args:
        seed (int): Random seed.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
class NTED:
    """Inference wrapper around the NTED ``Generator``.

    On construction it builds a body-pose estimator, loads the generator
    weights stored under the ``'net_G_ema'`` key of a checkpoint and reads a
    fixed reference image. ``inference`` then detects the body pose in a
    frame and runs the generator on the reference image and that pose.

    Everything runs on the CPU (the device is hard-coded to ``'cpu'``).
    """

    # (width, height) passed to ``cv2.resize`` for the reference and input images.
    INPUT_SIZE = (352, 512)
    # (width, height) of the generated image returned by ``inference``.
    OUTPUT_SIZE = (288, 480)
    # Number of landmarks read per detected hand in ``hand_pose_est``.
    HAND_JOINTS = 21

    def __init__(self,
                 config_path='NTED/fashion_512.yaml',
                 checkpoint_path='NTED/nted_checkpoint.pt',
                 ref_img_path='example/ref_img.png'):
        """Build the pose estimator, the generator and the reference image.

        Args:
            config_path (str): YAML file handed to ``Config``.
            checkpoint_path (str): Checkpoint containing a ``'net_G_ema'``
                state dict.
            ref_img_path (str): Reference image read with ``cv2.imread``.

        Raises:
            OSError: If the reference image is missing or cannot be decoded.
        """
        super().__init__()
        self.openpose_module = BodyPoseEstimator('cpu')
        set_random_seed(0)
        self.opt = Config(config_path, is_train=False)

        net_G = Generator(**self.opt.gen.param).to('cpu')
        net_G_ema = Generator(**self.opt.gen.param).to('cpu')
        net_G_ema.eval()
        # NOTE(review): presumably copies net_G's weights into the EMA network
        # (decay 0); the checkpoint load below overwrites them anyway. Kept so
        # that construction consumes the seeded RNG exactly as before.
        accumulate(net_G_ema, net_G, 0)
        # The identity map_location keeps every tensor on CPU storage.
        # torch.load unpickles the file, so only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        net_G_ema.load_state_dict(checkpoint['net_G_ema'])
        self.net_G = net_G_ema.eval()
        self.data_loader = DemoDataset()

        # Hand tracking is disabled: the mediapipe import at the top of the
        # file is commented out. To enable it, restore that import and assign
        #   self.hands = mp.solutions.hands.Hands(
        #       static_image_mode=True, max_num_hands=2, min_detection_confidence=0.1)
        self.hands = None

        # cv2.imread signals failure by returning None instead of raising.
        ref_img = cv2.imread(ref_img_path)
        if ref_img is None:
            raise OSError(
                "Reference image is missing or unreadable: {!r}".format(ref_img_path))
        self.ref_img = cv2.resize(ref_img, self.INPUT_SIZE)

    def hand_pose_est(self, img):
        """Estimate 2D hand landmarks for up to two hands in ``img``.

        Detection runs on a horizontally flipped RGB copy of ``img``; the x
        coordinates are mirrored back so they refer to the original image.

        Args:
            img: Image array whose ``shape`` is ``(height, width, channels)``;
                it is converted with ``cv2.COLOR_BGR2RGB``.

        Returns:
            numpy.ndarray: ``int32`` pixel coordinates ``[x, y]``. With zero,
            one or two detected hands the array has 42 rows: the hand labelled
            ``'Left'`` first, then the other one, 21 landmarks each. Landmarks
            of a hand that was not detected are ``[-1, -1]``.

        Raises:
            RuntimeError: If no hand detector has been assigned to
                ``self.hands``.
        """
        hands = getattr(self, 'hands', None)
        if hands is None:
            raise RuntimeError(
                "Hand detector is not initialised: assign a mediapipe Hands "
                "instance to `self.hands` before calling hand_pose_est().")

        results = hands.process(cv2.cvtColor(cv2.flip(img, 1), cv2.COLOR_BGR2RGB))
        image_height, image_width, _ = img.shape
        joints = self.HAND_JOINTS
        missing_hand = [[-1, -1] for _ in range(joints)]

        detected = results.multi_hand_landmarks
        if not detected:
            # No hands at all: both slots are filled with the sentinel.
            pose_data = [[-1, -1] for _ in range(2 * joints)]
        else:
            # ``image_width - x`` undoes the horizontal flip applied above.
            pose_data = [
                [image_width - hand.landmark[idx].x * image_width,
                 hand.landmark[idx].y * image_height]
                for hand in detected
                for idx in range(joints)
            ]
            hand_count = len(detected)
            if hand_count in (1, 2):
                first_label = results.multi_handedness[0].classification[0].label
                if hand_count == 2:
                    if first_label == 'Right':
                        # Swap so the left hand comes first, then the right hand.
                        pose_data = pose_data[joints:] + pose_data[:joints]
                elif first_label == 'Left':
                    pose_data = pose_data + missing_hand
                else:
                    pose_data = missing_hand + pose_data

        return np.array(pose_data, dtype=np.int32)

    def inference(self, img):
        """Generate an image of the reference person in the pose found in ``img``.

        Args:
            img: Input image; it is resized to ``INPUT_SIZE`` before pose
                detection.

        Returns:
            tuple: ``(skeleton_img, fake_image)`` where ``skeleton_img`` is
            ``data['skeleton_img']`` as produced by the data loader and
            ``fake_image`` is the generator output converted by
            ``tensor2im`` and resized to ``OUTPUT_SIZE``.
        """
        img = cv2.resize(img, self.INPUT_SIZE)
        body_pose, _ = self.openpose_module.detect_body_pose(img.copy())
        # hand_pose = self.hand_pose_est(img.copy())
        # NOTE(review): only the first detected pose is used; if the detector
        # can return an empty result this indexing fails — confirm and guard.
        data = self.data_loader.load_item(self.ref_img, body_pose[0], None)

        # Pure inference: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            output = self.net_G(
                data['reference_image'],
                data['target_skeleton'],
            )

        fake_image = self.data_loader.tensor2im(output['fake_image'][0])
        fake_image = cv2.resize(fake_image, self.OUTPUT_SIZE)
        return data['skeleton_img'], fake_image