Text2Human / Text2Human /ui_demo.py
yumingj's picture
update
bde71cb
import sys
import cv2
import numpy as np
import torch
from PIL import Image
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from models.sample_model import SampleFromPoseModel
from ui.mouse_event import GraphicsScene
from ui.ui import Ui_Form
from utils.language_utils import (generate_shape_attributes,
generate_texture_attributes)
from utils.options import dict_to_nonedict, parse
# RGB colour of every parsing label: colour ``color_list[i]`` is converted
# back to label ``i`` in ``Ex.generate_human``. The label names below follow
# the ``Ex.*_mode`` handlers, which set ``ref_scene.mode`` to these indices.
color_list = [
    (0, 0, 0),        # 0  background
    (255, 250, 250),  # 1  top
    (220, 220, 220),  # 2  outer
    (250, 235, 215),  # 3  skirt
    (255, 250, 205),  # 4  dress
    (211, 211, 211),  # 5  pants
    (70, 130, 180),   # 6  leggings
    (127, 255, 212),  # 7  headwear
    (0, 100, 0),      # 8  eyeglass
    (50, 205, 50),    # 9  neckwear
    (255, 255, 0),    # 10 belt
    (245, 222, 179),  # 11 footwear
    (255, 140, 0),    # 12 bag
    (255, 0, 0),      # 13 hair
    (16, 78, 139),    # 14 face
    (144, 238, 144),  # 15 skin
    (50, 205, 174),   # 16 ring
    (50, 155, 250),   # 17 wrist
    (160, 140, 88),   # 18 socks
    (213, 140, 88),   # 19 glove
    (90, 140, 90),    # 20 necklace
    (185, 210, 205),  # 21 rompers
    (130, 165, 180),  # 22 earstuds
    (225, 141, 151),  # 23 tie
]
class Ex(QWidget, Ui_Form):
    """Main window of the Text2Human UI demo.

    Canvases (every displayed image is scaled to ``self.graphicsView``'s
    size):

    * ``self.scene`` on ``graphicsView`` - the loaded DensePose image.
    * ``self.ref_scene`` on ``graphicsView_2`` - the generated parsing map.
      It is a ``GraphicsScene`` whose recorded strokes
      (``mask_points`` / ``size_points``) are painted onto the map by
      ``generate_human``.
    * ``self.result_scene`` on ``graphicsView_3`` - the generated human.

    The methods depend on each other's state in this order:
    ``open_densepose`` -> ``generate_parsing`` -> (optional strokes, label
    chosen with the ``*_mode`` methods) -> ``generate_human`` -> ``save_img``.
    """

    def __init__(self, opt):
        """Build the UI, configure the three canvases and load the model.

        Args:
            opt: Parsed options, forwarded unchanged to
                ``SampleFromPoseModel``.
        """
        super(Ex, self).__init__()
        self.setupUi(self)
        self.show()
        self.output_img = None  # last generated RGB image, used by save_img
        self.mat_img = None
        self.mode = 0
        self.size = 6
        self.mask = None
        self.mask_m = None  # colour-coded parsing map from generate_parsing
        self.img = None
        self.pose_img = None  # pose tensor last fed to the model
        self.colored_segm = None  # colour-coded parsing map as displayed

        # about UI
        self.mouse_clicked = False

        # DensePose image (display only).
        self.scene = QGraphicsScene()
        self._setup_view(self.graphicsView, self.scene)
        # Parsing map; strokes recorded here are replayed by generate_human.
        self.ref_scene = GraphicsScene(self.mode, self.size)
        self._setup_view(self.graphicsView_2, self.ref_scene)
        # Generated human.
        self.result_scene = QGraphicsScene()
        self._setup_view(self.graphicsView_3, self.result_scene)

        self.dlg = QColorDialog(self.graphicsView)
        self.color = None
        self.sample_model = SampleFromPoseModel(opt)

    @staticmethod
    def _setup_view(view, scene):
        """Attach ``scene`` to ``view``, top-left aligned, with no scroll bars."""
        view.setScene(scene)
        view.setAlignment(Qt.AlignTop | Qt.AlignLeft)
        view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        view.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)

    def _show_pixmap(self, scene, pixmap):
        """Scale ``pixmap`` to the main view's size and add it to ``scene``.

        The scene's last item (``scene.items()[-1]``), if any, is removed
        first.
        """
        pixmap = pixmap.scaled(self.graphicsView.size(), Qt.IgnoreAspectRatio)
        items = scene.items()
        if items:
            scene.removeItem(items[-1])
        scene.addPixmap(pixmap)

    def _show_rgb_array(self, scene, array):
        """Display an ``(H, W, 3)`` RGB array on ``scene`` (see _show_pixmap)."""
        array = np.ascontiguousarray(array, dtype=np.uint8)
        height, width = array.shape[:2]
        # Keep the buffer in a local variable until QPixmap has copied it.
        # Pass the row stride explicitly; otherwise QImage assumes 32-bit
        # aligned rows and skews images whose width * 3 is not a multiple
        # of 4.
        buffer = array.tobytes()
        qim = QImage(buffer, width, height, 3 * width, QImage.Format_RGB888)
        self._show_pixmap(scene, QPixmap.fromImage(qim))

    def open_densepose(self):
        """Ask for a DensePose image, display it and feed it to the model.

        The parsing-map and result canvases are cleared. Nothing changes if
        the dialog is cancelled or the file cannot be decoded.
        """
        fileName, _ = QFileDialog.getOpenFileName(self, "Open File",
                                                  QDir.currentPath())
        if not fileName:  # dialog cancelled
            return
        image = QPixmap(fileName)
        if image.isNull():
            QMessageBox.information(self, "Image Viewer",
                                    "Cannot load %s." % fileName)
            return
        # Decode with PIL before touching the UI or model state, so an
        # unreadable file leaves everything as it was.
        try:
            with Image.open(fileName) as pose_file:
                # resize() loads the pixels, so the file may be closed after.
                pose = pose_file.resize(size=(256, 512),
                                        resample=Image.LANCZOS)
        except OSError:
            QMessageBox.information(self, "Image Viewer",
                                    "Cannot load %s." % fileName)
            return

        self._show_pixmap(self.scene, image)
        self.ref_scene.clear()
        self.result_scene.clear()

        # load pose to model: keep channels 2.. (only channel 2 for an RGB
        # file) as (C, H, W) float32 and map the values 0..24 linearly onto
        # [-1, 1]. NOTE(review): presumably channel 2 holds the DensePose
        # part index - confirm against the dataset format.
        pose = np.array(pose)[:, :, 2:].transpose(2, 0, 1).astype(np.float32)
        pose = pose / 12. - 1
        self.pose_img = torch.from_numpy(pose).unsqueeze(1)
        self.sample_model.feed_pose_data(self.pose_img)

    def generate_parsing(self):
        """Generate a parsing map from the shape text and show it.

        Reads the shape description from ``message_box_1`` and runs the
        model. The colour-coded result is stored in ``self.colored_segm`` and
        shown on ``self.ref_scene``. A separate copy is stored in
        ``self.mask_m`` for ``generate_human``. The canvas is reset first
        (``reset_items`` / ``reset``), and the result canvas is cleared.
        """
        self.ref_scene.reset_items()
        self.ref_scene.reset()
        shape_texts = self.message_box_1.text()
        shape_attributes = generate_shape_attributes(shape_texts)
        shape_attributes = torch.LongTensor(shape_attributes).unsqueeze(0)
        self.sample_model.feed_shape_attributes(shape_attributes)
        self.sample_model.generate_parsing_map()
        self.sample_model.generate_quantized_segm()
        self.colored_segm = self.sample_model.palette_result(
            self.sample_model.segm[0].cpu())
        # An independent copy is all that is needed, so that mask_m never
        # aliases the displayed colored_segm.
        self.mask_m = self.colored_segm.copy()
        self._show_rgb_array(self.ref_scene, self.colored_segm)
        self.result_scene.clear()

    def generate_human(self):
        """Render the final human image from the (edited) parsing map.

        Replays the strokes recorded on ``self.ref_scene`` onto a copy of
        ``self.mask_m`` and converts the colours back to label indices. The
        label map and the texture text (``message_box_2``) are fed to the
        model. The result is shown on ``self.result_scene`` and kept in
        ``self.output_img``.

        Raises:
            RuntimeError: if the edited map contains a colour that is not in
                ``color_list`` (which should not happen).
        """
        if self.mask_m is None:
            QMessageBox.information(self, "Image Viewer",
                                    "Please generate a parsing map first.")
            return
        # Draw on a copy so strokes that are later undone or cleared do not
        # stay baked into the parsing map for subsequent generations.
        mask = self.mask_m.copy()
        for label, color in enumerate(color_list):
            mask = self.make_mask(mask, self.ref_scene.mask_points[label],
                                  self.ref_scene.size_points[label], color)

        # convert rgb to num. Use int64 explicitly: the default integer
        # dtype is platform dependent (int32 on Windows).
        seg_map = np.full(mask.shape[:-1], -1, dtype=np.int64)
        for index, color in enumerate(color_list):
            seg_map[np.all(mask == color, axis=-1)] = index
        if (seg_map == -1).any():
            raise RuntimeError(
                'parsing map contains colours that are not in color_list')
        self.sample_model.segm = torch.from_numpy(seg_map).unsqueeze(
            0).unsqueeze(0).to(self.sample_model.device)
        self.sample_model.generate_quantized_segm()

        texture_texts = self.message_box_2.text()
        texture_attributes = generate_texture_attributes(texture_texts)
        texture_attributes = torch.LongTensor(texture_attributes)
        self.sample_model.feed_texture_attributes(texture_attributes)
        self.sample_model.generate_texture_map()
        result = self.sample_model.sample_and_refine()

        # Take the first image as (H, W, C). NOTE(review): assumes an
        # (N, C, H, W) tensor with values in [0, 1].
        result = result.permute(0, 2, 3, 1).detach().cpu().numpy()
        # Clip before the cast: float -> uint8 conversion of out-of-range
        # values is undefined and typically wraps around.
        result = np.clip(result[0] * 255, 0, 255).astype(np.uint8)
        self.output_img = result
        self._show_rgb_array(self.result_scene, result)

    # ---- Label selectors ----------------------------------------------------
    # Each handler sets ``ref_scene.mode`` to a label index (0..23, the indices
    # of ``color_list``). NOTE(review): presumably GraphicsScene records new
    # strokes under this index in mask_points / size_points - confirm in
    # ui/mouse_event.py.

    def top_mode(self):
        self.ref_scene.mode = 1

    def skin_mode(self):
        self.ref_scene.mode = 15

    def outer_mode(self):
        self.ref_scene.mode = 2

    def face_mode(self):
        self.ref_scene.mode = 14

    def skirt_mode(self):
        self.ref_scene.mode = 3

    def hair_mode(self):
        self.ref_scene.mode = 13

    def dress_mode(self):
        self.ref_scene.mode = 4

    def headwear_mode(self):
        self.ref_scene.mode = 7

    def pants_mode(self):
        self.ref_scene.mode = 5

    def eyeglass_mode(self):
        self.ref_scene.mode = 8

    def rompers_mode(self):
        self.ref_scene.mode = 21

    def footwear_mode(self):
        self.ref_scene.mode = 11

    def leggings_mode(self):
        self.ref_scene.mode = 6

    def ring_mode(self):
        self.ref_scene.mode = 16

    def belt_mode(self):
        self.ref_scene.mode = 10

    def neckwear_mode(self):
        self.ref_scene.mode = 9

    def wrist_mode(self):
        self.ref_scene.mode = 17

    def socks_mode(self):
        self.ref_scene.mode = 18

    def tie_mode(self):
        self.ref_scene.mode = 23

    def earstuds_mode(self):
        self.ref_scene.mode = 22

    def necklace_mode(self):
        self.ref_scene.mode = 20

    def bag_mode(self):
        self.ref_scene.mode = 12

    def glove_mode(self):
        self.ref_scene.mode = 19

    def background_mode(self):
        self.ref_scene.mode = 0

    def make_mask(self, mask, pts, sizes, color):
        """Draw recorded stroke segments onto ``mask`` in place and return it.

        Args:
            mask: Image array drawn on with ``cv2.line`` (modified in place).
            pts: Segments, each a dict with ``'prev'`` and ``'curr'`` points
                (presumably integer ``(x, y)`` pixel tuples).
            sizes: Line thickness for each entry of ``pts``, by position.
            color: Colour tuple used for every segment.

        Returns:
            The same ``mask`` object.
        """
        for idx, pt in enumerate(pts):
            cv2.line(mask, pt['prev'], pt['curr'], color, sizes[idx])
        return mask

    def save_img(self):
        """Save the last generated image as a PNG at a user-chosen path.

        Does nothing if no image has been generated or the dialog is
        cancelled. ``.png`` is appended unless the name already ends in it.
        """
        if self.output_img is None:
            return
        fileName, _ = QFileDialog.getSaveFileName(self, "Save File",
                                                  QDir.currentPath())
        if not fileName:  # dialog cancelled
            return
        if not fileName.lower().endswith('.png'):
            fileName += '.png'
        # output_img is RGB (shown as Format_RGB888); cv2.imwrite expects BGR.
        cv2.imwrite(fileName,
                    np.ascontiguousarray(self.output_img[:, :, ::-1]))

    def undo(self):
        """Undo the latest strokes on the parsing-map canvas.

        Strokes live on ``self.ref_scene`` (the GraphicsScene), not on
        ``self.scene``, which is a plain QGraphicsScene without ``undo``.
        NOTE(review): assumes GraphicsScene provides ``undo()``; confirm in
        ui/mouse_event.py.
        """
        self.ref_scene.undo()

    def clear(self):
        """Reset and empty the parsing-map canvas and the result canvas."""
        self.ref_scene.reset_items()
        self.ref_scene.reset()
        self.ref_scene.clear()
        self.result_scene.clear()
if __name__ == '__main__':
    # Start Qt, load the sampling options and open the demo window. The
    # window is bound to a name so it stays alive while the event loop runs.
    qt_app = QApplication(sys.argv)
    config = dict_to_nonedict(
        parse('./configs/sample_from_pose.yml', is_train=False))
    window = Ex(config)
    sys.exit(qt_app.exec_())