# SAR / infer.py
# sinkers's picture
# Update infer.py
# 473de36 verified
# -*- coding: utf-8 -*-
import logging
import os
import sys

import cv2
import numpy as np
import torch

# sys.path.append('../../')
from model.detector import Detector
from model.Model_RGB import EstimateRGB
from configs.sar_convnext_config import rgb_opt
from configs.yolo_config import yolo_opt
from configs.test_config import opt
from thirdparty.utils.preprocessing import convert_bbox
class Tester:
    """Hold the YOLO detector and the SAR RGB estimator used for inference.

    Construction builds an ``EstimateRGB`` model on the selected device and a
    ``Detector``; both read their checkpoint paths from module-level option
    objects (``rgb_opt`` / ``yolo_opt``) that are rewritten relative to
    ``config.ckpt_root``.

    Attributes:
        config: The option object passed to the constructor.
        device: ``cuda`` when available, otherwise ``cpu``.
        estimate_method: Copied from ``config.estimate_method``.
        sar_model: The ``EstimateRGB`` instance moved to ``device``.
        detect: The ``Detector`` instance.
    """

    # Class-level logger: the methods below log through ``self.logger``, which
    # the original never defined (AttributeError on first use).
    logger = logging.getLogger(__name__)

    def __init__(self, config):
        """Build both models from ``config``.

        Args:
            config: Object exposing ``estimate_method`` and ``ckpt_root``.
        """
        self.config = config
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.estimate_method = self.config.estimate_method
        self.sar_model = self._build_model()
        # NOTE(review): this mutates the shared module-level ``yolo_opt``; a
        # second Tester with a relative ``ckpt_root`` would prefix it twice.
        yolo_opt.weights = os.path.join(self.config.ckpt_root, yolo_opt.weights)
        self.detect = Detector(yolo_opt)

    def _build_model(self):
        """Create the ``EstimateRGB`` model and move it to ``self.device``.

        Returns:
            The constructed model.
        """
        # NOTE(review): same shared-config mutation caveat as ``yolo_opt``.
        rgb_opt.checkpoint = os.path.join(self.config.ckpt_root, rgb_opt.checkpoint)
        return EstimateRGB(rgb_opt).to(self.device)

    def _load_model(self):
        """Load weights from ``config.load_model`` into ``self.sar_model``.

        Does nothing unless ``estimate_method`` is ``'single_hand'``. The
        checkpoint is read onto the CPU; a ``'net'`` entry is loaded strictly,
        otherwise the ``'network'`` entry is loaded with ``strict=False``.
        """
        if self.estimate_method != 'single_hand':
            return

        self.logger.info(
            'Loading the %s model from %s...',
            self.estimate_method,
            self.config.load_model,
        )
        checkpoint = torch.load(self.config.load_model, map_location=torch.device('cpu'))
        # The original targeted ``self.model``, which is never assigned; the
        # model built by this class lives in ``self.sar_model``.
        if 'net' in checkpoint:
            self.sar_model.load_state_dict(checkpoint['net'])
        else:
            self.sar_model.load_state_dict(checkpoint['network'], strict=False)

    @torch.no_grad()
    def preprocessing(self, data, dets):
        """Turn detections into one input dict per detected hand.

        Args:
            data: Mapping with keys ``'rgb'`` and ``'depth'`` (``'depth'`` may
                be ``None``).
            dets: Detector output. A one-element outer list is unwrapped to
                its inner list; each entry is then unpacked as
                ``(hand_type, bbox)``.
                NOTE(review): an outer list longer than one is iterated
                as-is — confirm that shape against ``Detector.detect``.

        Returns:
            ``None`` (after logging a warning) when ``dets`` is empty;
            otherwise a list of dicts with keys ``'rgb'``, ``'depth'``,
            ``'rgb_bbox'``, ``'depth_bbox'`` and ``'hand_type'``. The list is
            empty when the single inner detection list is empty.
        """
        if len(dets) < 1:
            self.logger.warning('No YOLO detections found for image')
            return None

        if len(dets) == 1:  # single image: unwrap its detection list
            dets = dets[0]

        rgb = data['rgb']
        depth = data['depth']
        output_data_list = []
        for det_type, bbox in dets:
            # A depth-space box is only computed when a depth map is supplied.
            depth_bbox = None
            if depth is not None:
                depth_bbox = convert_bbox('depth', rgb, depth, bbox)
            output_data_list.append({
                'rgb': rgb,
                'depth': depth,
                'rgb_bbox': bbox,
                'depth_bbox': depth_bbox,
                'hand_type': det_type,
            })
        return output_data_list
def infer(input_img_path, tester, local_repo_path=None):
    """Detect hands in an image and run the SAR estimator on them.

    Args:
        input_img_path: Despite the name, this is image data: it is converted
            with ``np.array`` and handed to the detector (the ``__main__``
            block passes the result of ``cv2.imread``).
        tester: An existing ``Tester`` to reuse, or ``None`` to download the
            ``ckpt/*`` files of the ``sinkers/SAR`` repo and build one.
        local_repo_path: Checkpoint root from a previous call. It is
            overwritten when ``tester`` is ``None`` and otherwise passed
            through unchanged.

    Returns:
        A ``(pose_img_rgb, tester, local_repo_path)`` tuple. ``pose_img_rgb``
        is ``meta_info['pose_img_rgb']`` from the estimator, or ``None`` when
        no detections were produced (the original implicitly returned a bare
        ``None`` here, which broke tuple-unpacking callers and discarded the
        freshly built tester).
    """
    if tester is None:
        # Imported lazily so the hub client is only needed on first use.
        from huggingface_hub import snapshot_download

        repo_id = "sinkers/SAR"
        print("downloading~")
        local_repo_path = snapshot_download(repo_id=repo_id, allow_patterns=["ckpt/*"])
        print("Done!")
        os.environ["LOCAL_REPO_PATH"] = local_repo_path
        opt.ckpt_root = local_repo_path
        tester = Tester(opt)

    yolo_detector = tester.detect
    sar_estimater = tester.sar_model

    rgb = np.array(input_img_path)
    data = {
        'rgb': rgb,
        'depth': None
    }
    _pred, dets = yolo_detector.detect(rgb)

    # ``[[]]`` is the detector's "one image, zero boxes" result.
    output_data_list = tester.preprocessing(data, dets) if dets != [[]] else None
    if not output_data_list:
        # Nothing to estimate: keep the return shape stable for callers.
        return None, tester, local_repo_path

    meta_info, _output = sar_estimater.run(output_data_list)
    return meta_info['pose_img_rgb'], tester, local_repo_path
if __name__ == "__main__":
    # Smoke test: run inference on the bundled sample and save the rendering.
    input_image = cv2.imread("test/wilor_rgb.png")
    if input_image is None:
        # cv2.imread signals an unreadable/missing file by returning None.
        raise FileNotFoundError("test/wilor_rgb.png")

    # infer() takes three arguments and returns three values; the original
    # call omitted ``local_repo_path`` and unpacked only two results.
    result = infer(input_image, tester=None, local_repo_path=None)
    pose_img_rgb = result[0] if result is not None else None
    if pose_img_rgb is None:
        print("No hand detected; nothing written.")
    else:
        cv2.imwrite('test/out.png', pose_img_rgb)