radames's picture
radames HF staff
update reqs
5c511d7
raw history blame
No virus
4.77 kB
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
import time
import json
import numpy as np
import torch
from torch.utils.data import DataLoader
from lib.options import BaseOptions
from lib.mesh_util import *
from lib.sample_util import *
from lib.train_util import *
from lib.model import *
from PIL import Image
import torchvision.transforms as transforms
import trimesh
from datetime import datetime
# Build the module-level run configuration. `opt` is read by the `__main__`
# block below (results_path, name, img_path) and handed to `Evaluator`.
# NOTE(review): presumably `BaseOptions.parse()` reads command-line arguments;
# if so, importing this module has that side effect - confirm in lib/options.
opt = BaseOptions().parse()
class Evaluator:
    """Single-image inference wrapper around the PIFu networks.

    On construction it builds the geometry network (``HGPIFuNet``) and,
    when a colour checkpoint is configured, the colour network
    (``ResBlkPIFuNet``), restores their weights and prepares the output
    directory ``<opt.results_path>/<opt.name>``.
    """

    def __init__(self, opt, projection_mode='orthogonal'):
        """Create the networks and the output directory.

        :param opt: parsed options object. The attributes read here are
            ``loadSize``, ``gpu_id``, ``load_netG_checkpoint_path``,
            ``load_netC_checkpoint_path``, ``results_path`` and ``name``.
        :param projection_mode: forwarded unchanged to ``HGPIFuNet``.

        Side effects: creates ``<results_path>/<name>`` and writes the
        options as JSON to ``<results_path>/<name>/opt.txt``.
        """
        self.opt = opt
        self.load_size = self.opt.loadSize
        # Resize, convert to a tensor and map each RGB channel from [0, 1]
        # to [-1, 1] (mean 0.5, std 0.5).
        self.to_tensor = transforms.Compose([
            transforms.Resize(self.load_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # set cuda
        has_cuda = torch.cuda.is_available()
        cuda = torch.device('cuda:%d' % opt.gpu_id) if has_cuda else torch.device('cpu')
        # Report the device that is actually selected (opt.gpu_id) instead of
        # always querying GPU 0.
        print("CUDDAAAAA ???", torch.cuda.get_device_name(cuda) if has_cuda else "NO ONLY CPU")

        # create net
        netG = HGPIFuNet(opt, projection_mode).to(device=cuda)
        print('Using Network: ', netG.name)

        # An empty/None path means "keep the freshly initialised weights".
        if opt.load_netG_checkpoint_path:
            netG.load_state_dict(torch.load(opt.load_netG_checkpoint_path, map_location=cuda))

        # The colour network is optional; without it only geometry is produced.
        if opt.load_netC_checkpoint_path is not None:
            print('loading for net C ...', opt.load_netC_checkpoint_path)
            netC = ResBlkPIFuNet(opt).to(device=cuda)
            netC.load_state_dict(torch.load(opt.load_netC_checkpoint_path, map_location=cuda))
        else:
            netC = None

        # makedirs creates intermediate directories, so a single call covers
        # both results_path and results_path/name.
        os.makedirs('%s/%s' % (opt.results_path, opt.name), exist_ok=True)

        # Keep a record of the options used for this run next to the results.
        opt_log = os.path.join(opt.results_path, opt.name, 'opt.txt')
        with open(opt_log, 'w') as outfile:
            outfile.write(json.dumps(vars(opt), indent=2))

        self.cuda = cuda
        self.netG = netG
        self.netC = netC

    def load_image(self, image_path, mask_path):
        """Load an image and its mask and pack them into a data dict.

        :param image_path: path of the input image (converted to RGB).
        :param mask_path: path of the mask image (converted to grayscale).
        :return: dict with keys ``'name'`` (image file name without
            extension), ``'img'`` (masked, normalised image with a leading
            batch dimension), ``'calib'`` (4x4 calibration matrix with a
            leading batch dimension), ``'mask'`` (mask with a leading batch
            dimension), ``'b_min'`` and ``'b_max'`` (bounding-box corners as
            numpy arrays).

        NOTE(review): image and mask go through the same ``Resize``; they are
        assumed to have the same original size so ``expand_as`` succeeds.
        """
        # Name
        img_name = os.path.splitext(os.path.basename(image_path))[0]

        # Calib: identity projection with the Y axis negated; the reconstruction
        # volume is the cube [-1, 1]^3.
        B_MIN = np.array([-1, -1, -1])
        B_MAX = np.array([1, 1, 1])
        projection_matrix = np.identity(4)
        projection_matrix[1, 1] = -1
        calib = torch.Tensor(projection_matrix).float()

        # Mask
        mask = Image.open(mask_path).convert('L')
        mask = transforms.Resize(self.load_size)(mask)
        mask = transforms.ToTensor()(mask).float()

        # image: normalise first, then multiply by the mask so pixels where the
        # mask is 0 become exactly 0.
        image = Image.open(image_path).convert('RGB')
        image = self.to_tensor(image)
        image = mask.expand_as(image) * image

        return {
            'name': img_name,
            'img': image.unsqueeze(0),
            'calib': calib.unsqueeze(0),
            'mask': mask.unsqueeze(0),
            'b_min': B_MIN,
            'b_max': B_MAX,
        }

    def eval(self, data, use_octree=False):
        '''
        Evaluate a data point and write the resulting mesh to
        ``<results_path>/<name>/result_<data['name']>.obj``.

        :param data: a dict as returned by ``load_image``; ``data['name']`` is
            read here, the rest (``'img'``, ``'calib'``, ``'mask'``,
            ``'b_min'``, ``'b_max'``) is passed on to the mesh generator.
        :param use_octree: forwarded to ``gen_mesh`` / ``gen_mesh_color``.
        :return: None
        '''
        opt = self.opt
        # Explicit None check: do not rely on the truthiness of an nn.Module.
        has_color_net = self.netC is not None

        with torch.no_grad():
            self.netG.eval()
            if has_color_net:
                self.netC.eval()

            save_path = '%s/%s/result_%s.obj' % (opt.results_path, opt.name, data['name'])
            if has_color_net:
                gen_mesh_color(opt, self.netG, self.netC, self.cuda, data, save_path, use_octree=use_octree)
            else:
                gen_mesh(opt, self.netG, self.cuda, data, save_path, use_octree=use_octree)
if __name__ == '__main__':
    # Script entry point: reconstruct a mesh for `opt.img_path` and export it
    # as both .obj (written by the evaluator) and .glb (converted below).
    evaluator = Evaluator(opt)

    results_path = opt.results_path
    name = opt.name
    test_image_path = opt.img_path
    # The mask is expected next to the image as "<stem>_mask.png". Strip the
    # real extension instead of a fixed 4 characters so ".jpeg"/".tiff" and
    # similar inputs resolve to the right mask file.
    test_mask_path = os.path.splitext(test_image_path)[0] + '_mask.png'
    test_img_name = os.path.splitext(os.path.basename(test_image_path))[0]
    print("test_image: ", test_image_path)
    print("test_mask: ", test_mask_path)

    # Best effort: a failure is reported but does not abort the script.
    try:
        # Named `start_time` so the imported `time` module is not shadowed.
        start_time = datetime.now()
        print("evaluating" , start_time)
        data = evaluator.load_image(test_image_path, test_mask_path)
        evaluator.eval(data, False)
        print("done evaluating" , datetime.now() - start_time)
    except Exception as e:
        print("error:", e.args)

    # Convert the generated .obj to .glb, negating the Z coordinate on the way.
    # NOTE(review): presumably the Z flip matches the target viewer's
    # coordinate convention - confirm.
    try:
        mesh = trimesh.load(f'{results_path}/{name}/result_{test_img_name}.obj')
        mesh.apply_transform([[1, 0, 0, 0],
                              [0, 1, 0, 0],
                              [0, 0, -1, 0],
                              [0, 0, 0, 1]])
        mesh.export(file_obj=f'{results_path}/{name}/result_{test_img_name}.glb')
    except Exception as e:
        print("error generating MESH", e)