|
|
|
|
|
|
|
|
import logging
import os
import sys

import torch
import cv2
import numpy as np

from model.detector import Detector
from model.Model_RGB import EstimateRGB
from configs.sar_convnext_config import rgb_opt
from configs.yolo_config import yolo_opt
from configs.test_config import opt
from thirdparty.utils.preprocessing import convert_bbox
|
|
|
|
|
class Tester:
    """Bundle the YOLO hand detector and the SAR RGB estimator.

    Both models are built from the module-level ``yolo_opt`` / ``rgb_opt``
    configs, with their checkpoint paths joined onto ``config.ckpt_root``.
    """

    # Class-level so every method can log, including on instances created
    # without ``__init__``. Previously ``self.logger`` was referenced but never
    # assigned, so the no-detection path raised AttributeError.
    logger = logging.getLogger(__name__)

    def __init__(self, config):
        """Build the SAR estimator and the YOLO detector.

        Args:
            config: options object; ``estimate_method`` and ``ckpt_root`` are
                read here (``load_model`` is read by ``_load_model``).
        """
        self.config = config
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.estimate_method = self.config.estimate_method
        self.sar_model = self._build_model()
        # NOTE(review): this mutates the shared module-level ``yolo_opt``.
        # Building a second Tester joins ckpt_root again, which is harmless only
        # when ckpt_root is absolute.
        yolo_opt.weights = os.path.join(self.config.ckpt_root, yolo_opt.weights)
        self.detect = Detector(yolo_opt)

    def _build_model(self):
        """Create the ``EstimateRGB`` model on ``self.device``.

        Returns:
            The ``EstimateRGB`` instance built from ``rgb_opt``.
        """
        # NOTE(review): mutates the shared module-level ``rgb_opt`` in place
        # (same caveat as ``yolo_opt`` in ``__init__``).
        rgb_opt.checkpoint = os.path.join(self.config.ckpt_root, rgb_opt.checkpoint)
        sar_model = EstimateRGB(rgb_opt).to(self.device)
        return sar_model

    def _load_model(self):
        """Load weights from ``config.load_model`` when the method is 'single_hand'.

        Checkpoints with a ``'net'`` entry are loaded strictly; otherwise the
        ``'network'`` entry is loaded with ``strict=False``.
        """
        if self.estimate_method == 'single_hand':
            self.logger.info('Loading the {} model from {}...'.format(self.estimate_method, self.config.load_model))
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(self.config.load_model, map_location=torch.device('cpu'))
            # NOTE(review): ``self.model`` is never assigned in this class (the
            # estimator is stored as ``self.sar_model``). This method is not
            # called in the visible code; confirm the intended target before use.
            if 'net' in checkpoint:
                self.model.load_state_dict(checkpoint['net'])
            else:
                self.model.load_state_dict(checkpoint['network'], strict=False)

    @torch.no_grad()
    def preprocessing(self, data, dets):
        """Turn detector output into one input dict per detected hand.

        Args:
            data: dict with ``'rgb'`` and ``'depth'`` entries; ``'depth'`` may be None.
            dets: detections from ``Detector.detect``. Each item is unpacked as
                ``(det_type, bbox)``; an outer list of length 1 is unwrapped first.

        Returns:
            A list of dicts with keys ``rgb``, ``depth``, ``rgb_bbox``,
            ``depth_bbox`` and ``hand_type``; or ``None`` (after logging a
            warning) when ``dets`` is empty.
        """
        if len(dets) < 1:
            self.logger.warning('No YOLO detections found for image')
            return None

        # NOTE(review): presumably dets is [[(type, bbox), ...]] for a single
        # image. A flat list holding exactly one (type, bbox) pair would be
        # unwrapped to the pair itself; confirm Detector's output format.
        if len(dets) == 1:
            dets = dets[0]

        output_data_list = []
        for det_type, bbox in dets:
            depth_bbox = None
            if data['depth'] is not None:
                # Map the RGB-frame box into the depth frame.
                depth_bbox = convert_bbox('depth', data['rgb'], data['depth'], bbox)
            output_data_list.append({'rgb': data['rgb'],
                                     'depth': data['depth'],
                                     'rgb_bbox': bbox,
                                     'depth_bbox': depth_bbox,
                                     'hand_type': det_type})
        return output_data_list
|
|
|
|
|
|
|
|
|
|
|
def infer(input_img_path, tester=None, local_repo_path=None):
    """Run hand detection followed by SAR estimation on one image.

    Args:
        input_img_path: despite its name, the image data itself (anything
            ``np.array`` accepts); the ``__main__`` block passes the array
            returned by ``cv2.imread``.
        tester: a previously built ``Tester`` to reuse. When ``None``, the
            ``ckpt/*`` files of the ``sinkers/SAR`` Hugging Face repo are
            downloaded and a new ``Tester`` is built from ``opt``.
        local_repo_path: checkpoint root that goes with ``tester``. It is
            returned unchanged when ``tester`` is given and replaced by the
            download location otherwise.

    Returns:
        ``(pose_img_rgb, tester, local_repo_path)``. ``pose_img_rgb`` is
        ``meta_info['pose_img_rgb']`` from the estimator's ``run``, or
        ``None`` when no hands are detected.
    """
    if tester is None:
        from huggingface_hub import snapshot_download

        repo_id = "sinkers/SAR"
        print("downloading~")
        local_repo_path = snapshot_download(repo_id=repo_id, allow_patterns=["ckpt/*"])
        print("Done!")
        os.environ["LOCAL_REPO_PATH"] = local_repo_path

        opt.ckpt_root = local_repo_path
        tester = Tester(opt)

    yolo_detector = tester.detect
    sar_estimator = tester.sar_model

    rgb = np.array(input_img_path)
    data = {
        'rgb': rgb,
        'depth': None,
    }

    # The first return value (raw predictions) is not needed here.
    _, dets = yolo_detector.detect(rgb)

    # Previously an empty result ([[]]) left pose_img_rgb unbound
    # (UnboundLocalError), and [] was forwarded to preprocessing. Both now
    # short-circuit with a None image.
    if len(dets) == 0 or dets == [[]]:
        return None, tester, local_repo_path

    output_data_list = tester.preprocessing(data, dets)
    if not output_data_list:
        return None, tester, local_repo_path

    meta_info, _ = sar_estimator.run(output_data_list)
    return meta_info['pose_img_rgb'], tester, local_repo_path
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke run: process the bundled sample image and save the rendered result.
    input_img = cv2.imread("test/wilor_rgb.png")
    if input_img is None:
        # cv2.imread returns None (rather than raising) for missing/unreadable files.
        sys.exit("Could not read test/wilor_rgb.png")

    # infer returns a 3-tuple and requires local_repo_path; the previous call
    # omitted it and unpacked only two values.
    pose_img_rgb, tester, local_repo_path = infer(input_img, tester=None, local_repo_path=None)

    if pose_img_rgb is None:
        print("No hands detected; test/out.png not written.")
    else:
        cv2.imwrite('test/out.png', pose_img_rgb)