# WALT / infer.py
# Your Name
# update demo
# a56642d
from argparse import ArgumentParser
from mmdet.apis import inference_detector, init_detector, show_result_pyplot
from mmdet.core.mask.utils import encode_mask_results
import numpy as np
import mmcv
import torch
from imantics import Polygons, Mask
import json
import os
import cv2, glob
class detections():
    """Wrap an mmdet detector and accumulate its results in a JSON-friendly dict.

    Typical use: call ``run_on_image`` to run inference (the raw result is kept
    on ``self.result``), then ``process_output`` to append that result to
    ``self.record_dict['results']``, and finally ``write_json`` to dump it.
    """

    def __init__(self, cfg_path, device, model_path='data/models/walt_vehicle.pth', threshold=0.85):
        """Build the detector.

        Args:
            cfg_path: Config path passed to ``init_detector``; also stored
                under ``record_dict['model']``.
            device: Device string forwarded to ``init_detector``.
            model_path: Checkpoint path passed to ``init_detector``.
            threshold: Score threshold forwarded to ``model.show_result`` when
                drawing detections in ``run_on_image``.
        """
        self.model = init_detector(cfg_path, model_path, device=device)
        self.all_preds = []
        self.all_scores = []
        self.index = []
        self.score_thr = threshold
        self.result = []
        self.record_dict = {'model': cfg_path, 'results': []}
        self.detect_count = []

    def run_on_image(self, image):
        """Run inference on ``image`` and return ``model.show_result``'s output.

        The raw inference result is stored on ``self.result`` so that
        ``process_output`` can be called afterwards.
        """
        self.result = inference_detector(self.model, image)
        image_labelled = self.model.show_result(image, self.result, score_thr=self.score_thr)
        return image_labelled

    def process_output(self, count, min_score=0.3):
        """Append the latest inference result to ``self.record_dict['results']``.

        Args:
            count: Stored verbatim under the ``'url'`` key of the new record.
            min_score: Detections whose score (last bbox column) is below this
                value are left out of the record. Defaults to 0.3.

        ``self.result`` may be a ``(bbox_result, segm_result)`` tuple or a bare
        ``bbox_result`` list (detector without a mask output). When no mask is
        available for a detection, an empty list is stored as its segmentation
        so that all per-detection lists stay aligned.
        """
        result = self.result
        if isinstance(result, tuple):
            bbox_result, segm_result = result
            if isinstance(segm_result, tuple):
                segm_result = segm_result[0]  # ms rcnn
        else:
            # Box-only result: there is no segmentation part to unpack.
            bbox_result, segm_result = result, None

        bboxes = np.vstack(bbox_result)
        # bbox_result holds one array per class; tag each row with its class index.
        labels = np.concatenate(
            [np.full(bbox.shape[0], class_id, dtype=np.int32)
             for class_id, bbox in enumerate(bbox_result)]
        )

        segms = None
        if segm_result is not None and len(labels) > 0:  # non empty
            segms = mmcv.concat_list(segm_result)
            if isinstance(segms[0], torch.Tensor):
                segms = torch.stack(segms, dim=0).detach().cpu().numpy()
            else:
                segms = np.stack(segms, axis=0)
        if segms is None:
            # Keep zip() below well-defined when there are no masks.
            segms = [None] * len(labels)

        infer_result = {'url': count,
                        'boxes': [],
                        'scores': [],
                        'keypoints': [],
                        'segmentation': [],
                        'label_ids': [],
                        'track': [],
                        'labels': []}

        for bbox, label, segm in zip(bboxes, labels, segms):
            score = bbox[-1].item()
            if score < min_score:
                continue
            box = [bbox[0].item(), bbox[1].item(), bbox[2].item(), bbox[3].item()]
            segmentation = [] if segm is None else Mask(segm).polygons().segmentation
            # Plain int: numpy integers are rejected by json.dump in write_json.
            label_id = int(label)
            infer_result['boxes'].append(box)
            infer_result['segmentation'].append(segmentation)
            infer_result['scores'].append(score)
            infer_result['labels'].append(self.model.CLASSES[label_id])
            infer_result['label_ids'].append(label_id)

        self.record_dict['results'].append(infer_result)
        self.detect_count = labels

    def write_json(self, filename):
        """Dump ``self.record_dict`` as JSON to ``filename + '.json'``."""
        with open(filename + '.json', 'w') as f:
            json.dump(self.record_dict, f)
def main():
    """Run the people and vehicle detectors over ``demo/images/*``.

    Each readable image is passed through both detectors, the annotated image
    is written under ``demo/results/`` mirroring the input's sub-path, and the
    vehicle detector's result is accumulated via ``process_output``.
    Per-image failures are reported and skipped so one bad file does not stop
    the run.
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    detect_people = detections('configs/walt/walt_people.py', device, model_path='data/models/walt_people.pth')
    detect = detections('configs/walt/walt_vehicle.py', device, model_path='data/models/walt_vehicle.pth')

    max_images = 30000  # stop after this many successfully processed images
    count = 0
    for filename in sorted(glob.glob('demo/images/*')):
        img = cv2.imread(filename)
        if img is None:
            # cv2.imread returns None for files it cannot decode.
            print('skipping unreadable image: {}'.format(filename))
            continue
        try:
            # NOTE(review): the vehicle detector receives the image returned by
            # the people detector's run_on_image, not the raw frame - confirm
            # this chaining is intended.
            img = detect_people.run_on_image(img)
            img = detect.run_on_image(img)
        except Exception as err:
            print('inference failed for {}: {}'.format(filename, err))
            continue
        count += 1

        # Mirror the path below 'demo' into 'demo/results' without touching
        # any other occurrence of "demo" in the file name.
        out_path = os.path.join('demo', 'results', os.path.relpath(filename, 'demo'))
        try:
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
        except OSError as err:
            print('could not create output directory for {}: {}'.format(out_path, err))
        cv2.imwrite(out_path, img)

        if count == max_images:
            break
        try:
            detect.process_output(count)
        except Exception as err:
            print('could not record detections for {}: {}'.format(filename, err))
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == '__main__':
    main()