# NOTE: page-scrape residue ("Spaces: Runtime error") -- not part of the module source.
import gc
import os
import shutil
import subprocess

import cv2
import imageio
import numpy as np
import torch
from PIL import Image
from scipy.ndimage import binary_dilation

from aot_tracker import _palette
from model_args import segtracker_args,sam_args,aot_args
def save_prediction(pred_mask, output_dir, file_name):
    """Save a label mask as a palette-indexed image.

    The mask is cast to ``uint8``, wrapped in a PIL image, converted to
    mode ``P`` with ``_palette`` attached as its colour table, and written
    to ``output_dir/file_name``.

    Args:
        pred_mask: array of object ids (anything exposing ``astype``).
        output_dir: directory the image is written into.
        file_name: name of the image file inside ``output_dir``.
    """
    indexed = Image.fromarray(pred_mask.astype(np.uint8)).convert(mode='P')
    indexed.putpalette(_palette)
    indexed.save(os.path.join(output_dir, file_name))
def colorize_mask(pred_mask):
    """Return an RGB rendering of a label mask.

    The mask is cast to ``uint8``, turned into a mode-``P`` PIL image that
    uses ``_palette`` as its colour table, expanded to RGB and handed back
    as a NumPy array.

    Args:
        pred_mask: array of object ids (anything exposing ``astype``).

    Returns:
        ``np.ndarray`` built from the RGB-converted image.
    """
    paletted = Image.fromarray(pred_mask.astype(np.uint8)).convert(mode='P')
    paletted.putpalette(_palette)
    return np.array(paletted.convert(mode='RGB'))
def draw_mask(img, mask, alpha=0.5, id_countour=False):
    """Blend a label mask over an image and blacken object outlines.

    Args:
        img: image array. The indexing used here (``img_mask[countours, :]``)
            assumes a trailing channel axis after the mask's axes.
        mask: array of object ids; ``0`` is treated as background.
        alpha: blend weight of the overlay colour (``1 - alpha`` is kept
            from the image).
        id_countour: when true, every object id is blended and outlined
            separately (slow); otherwise all non-zero pixels are treated as
            one region and coloured via ``colorize_mask``.

    Returns:
        A new array with the dtype of ``img``. ``img`` itself is left
        untouched.
    """
    # Draw on a copy: aliasing ``img`` would overwrite the caller's frame and,
    # in the per-id branch, make later objects blend against pixels that an
    # earlier object already painted or zeroed.
    img_mask = img.copy()
    if id_countour:
        # very slow ~ 1s per image
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]
        for obj_id in obj_ids:
            # Plain Python int: with a narrow integer mask dtype (e.g. uint8)
            # ``obj_id * 3`` can wrap around under NumPy 2 scalar promotion
            # and select the wrong palette entry.
            obj_id = int(obj_id)
            # Overlay color on binary mask
            if obj_id <= 255:
                color = _palette[obj_id * 3:obj_id * 3 + 3]
            else:
                color = [0, 0, 0]
            foreground = img * (1 - alpha) + np.ones_like(img) * alpha * np.array(color)
            binary_mask = (mask == obj_id)

            # Compose image
            img_mask[binary_mask] = foreground[binary_mask]

            # One-pixel ring just outside the object: dilation XOR object.
            countours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
            img_mask[countours, :] = 0
    else:
        binary_mask = (mask != 0)
        countours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
        foreground = img * (1 - alpha) + colorize_mask(mask) * alpha
        img_mask[binary_mask] = foreground[binary_mask]
        img_mask[countours, :] = 0

    return img_mask.astype(img.dtype)
def create_dir(dir_path):
    """Create ``dir_path`` as an empty directory, discarding any old contents.

    An existing directory (or symlink to one) at ``dir_path`` is removed
    first; missing parent directories are created.

    Args:
        dir_path: path of the directory to (re)create.

    Raises:
        OSError: if the old directory cannot be removed or the new one cannot
            be created (e.g. ``dir_path`` exists as a regular file).
    """
    if os.path.isdir(dir_path):
        # Remove in-process rather than through a shell ``rm -r`` string:
        # that broke on paths containing spaces or shell metacharacters and
        # is unavailable on non-POSIX systems.
        if os.path.islink(dir_path):
            # rmtree refuses symlinks; drop the link itself, as ``rm -r`` did.
            os.unlink(dir_path)
        else:
            shutil.rmtree(dir_path)
    os.makedirs(dir_path)
# Checkpoint file path for each AOT tracker variant, keyed by model name.
# NOTE(review): the "deaotl" path has no ".pth" suffix, unlike its siblings --
# confirm against the actual file name under ./ckpt before relying on it.
aot_model2ckpt = dict(
    deaotb="./ckpt/DeAOTB_PRE_YTB_DAV.pth",
    deaotl="./ckpt/DeAOTL_PRE_YTB_DAV",
    r50_deaotl="./ckpt/R50_DeAOTL_PRE_YTB_DAV.pth",
)
# Tracking Object In Video | |
def tracking_objects_in_video(SegTracker, input_video, input_img_seq, fps):
    """Dispatch tracking to the video or the image-sequence pipeline.

    A video input takes precedence when both inputs are supplied. The result
    directory ``tracking_results/<name>`` next to this file is recreated from
    scratch before tracking starts.

    Args:
        SegTracker: tracker object forwarded to the selected pipeline.
        input_video: path of a video file, or ``None``.
        input_img_seq: object whose ``.name`` is the path of an uploaded
            archive; the frames are read from ``./assets/<stem>``. May be
            ``None``.
        fps: frame rate forwarded to the image-sequence pipeline only.

    Returns:
        Whatever the selected pipeline returns, or ``(None, None)`` when
        neither input is given.
    """
    use_video = input_video is not None
    if not use_video and input_img_seq is None:
        return None, None

    if use_video:
        video_name = os.path.basename(input_video).split('.')[0]
    else:
        seq_name = input_img_seq.name.split('/')[-1].split('.')[0]
        seq_dir = f'./assets/{seq_name}'
        imgs_path = sorted(
            os.path.join(seq_dir, entry) for entry in os.listdir(seq_dir)
        )
        video_name = seq_name

    # create dir to save result
    result_root = os.path.join(
        os.path.dirname(__file__), "tracking_results", video_name
    )
    create_dir(result_root)

    io_args = {
        'tracking_result_dir': result_root,
        'output_mask_dir': f'{result_root}/{video_name}_masks',
        'output_masked_frame_dir': f'{result_root}/{video_name}_masked_frames',
        # keep same format as input video
        'output_video': f'{result_root}/{video_name}_seg.mp4',
        'output_gif': f'{result_root}/{video_name}_seg.gif',
    }

    if use_video:
        return video_type_input_tracking(SegTracker, input_video, io_args, video_name)
    return img_seq_type_input_tracking(SegTracker, io_args, video_name, imgs_path, fps)
def video_type_input_tracking(SegTracker, input_video, io_args, video_name):
    """Track objects through a video file and export the results.

    Pass 1 decodes the video and produces one mask per frame: frame 0 uses
    ``SegTracker.first_frame_mask``; every ``SegTracker.sam_gap``-th frame is
    re-segmented so new objects can be merged into the tracker; all other
    frames are tracked only. Pass 2 decodes the video again, overlays the
    masks and writes per-frame PNGs, an mp4 and a GIF. Finally the mask
    directory is zipped.

    Args:
        SegTracker: tracker exposing ``sam_gap``, ``first_frame_mask``,
            ``seg``, ``track``, ``find_new_objs``, ``add_reference`` and
            ``get_obj_num``.
        input_video: path of the video to process.
        io_args: dict with the keys ``tracking_result_dir``,
            ``output_mask_dir``, ``output_masked_frame_dir``,
            ``output_video`` and ``output_gif``.
        video_name: stem used for the zip file name.

    Returns:
        Tuple ``(output video path, mask zip path)``.
    """
    # source video to segment
    cap = cv2.VideoCapture(input_video)
    try:
        # create dir to save predicted mask and masked frame
        output_mask_dir = io_args['output_mask_dir']
        create_dir(io_args['output_mask_dir'])
        create_dir(io_args['output_masked_frame_dir'])

        pred_list = []
        masked_pred_list = []

        torch.cuda.empty_cache()
        gc.collect()
        sam_gap = SegTracker.sam_gap
        frame_idx = 0

        with torch.cuda.amp.autocast():
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                if frame_idx == 0:
                    pred_mask = SegTracker.first_frame_mask
                    torch.cuda.empty_cache()
                    gc.collect()
                elif (frame_idx % sam_gap) == 0:
                    seg_mask = SegTracker.seg(frame)
                    torch.cuda.empty_cache()
                    gc.collect()
                    track_mask = SegTracker.track(frame)
                    # find new objects, and update tracker with new objects
                    new_obj_mask = SegTracker.find_new_objs(track_mask, seg_mask)
                    save_prediction(new_obj_mask, output_mask_dir, str(frame_idx).zfill(5) + '_new.png')
                    pred_mask = track_mask + new_obj_mask
                    # segtracker.restart_tracker()
                    SegTracker.add_reference(frame, pred_mask)
                else:
                    pred_mask = SegTracker.track(frame, update_memory=True)
                torch.cuda.empty_cache()
                gc.collect()

                save_prediction(pred_mask, output_mask_dir, str(frame_idx).zfill(5) + '.png')
                pred_list.append(pred_mask)

                print("processed frame {}, obj_num {}".format(frame_idx, SegTracker.get_obj_num()), end='\r')
                frame_idx += 1
    finally:
        # Release the decoder even when tracking raises (e.g. CUDA OOM).
        cap.release()
    print('\nfinished')

    ##################
    # Visualization
    ##################

    # draw pred mask on frame and save as a video
    cap = cv2.VideoCapture(input_video)
    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Output is always encoded as mp4v, whatever the input container was.
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))

        frame_idx = 0
        while cap.isOpened():
            ret, frame = cap.read()
            # Also stop if this pass decodes more frames than pass 1 produced
            # masks for, instead of failing with an IndexError.
            if not ret or frame_idx >= len(pred_list):
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pred_mask = pred_list[frame_idx]
            masked_frame = draw_mask(frame, pred_mask)
            # cv2.imwrite expects BGR, hence the reversed channel axis.
            cv2.imwrite(f"{io_args['output_masked_frame_dir']}/{str(frame_idx).zfill(5)}.png", masked_frame[:, :, ::-1])

            masked_pred_list.append(masked_frame)
            masked_frame = cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR)
            out.write(masked_frame)
            print('frame {} writed'.format(frame_idx), end='\r')
            frame_idx += 1
    finally:
        # Finalise the writer and decoder on every exit path.
        if out is not None:
            out.release()
        cap.release()
    print("\n{} saved".format(io_args['output_video']))
    print('\nfinished')

    # save colorized masks as a gif
    # NOTE(review): ``duration`` is a per-frame display time whose unit depends
    # on the installed imageio version, yet it is given the frame rate here --
    # confirm the intended value (likely derived from 1 / fps).
    imageio.mimsave(io_args['output_gif'], masked_pred_list, duration=fps)
    print("{} saved".format(io_args['output_gif']))

    # zip predicted mask
    zip_path = f"{io_args['tracking_result_dir']}/{video_name}_pred_mask.zip"
    try:
        # Argument list, no shell: the paths contain a name derived from the
        # input file, so a shell string would break on spaces and could be
        # abused through shell metacharacters.
        subprocess.run(["zip", "-r", zip_path, io_args['output_mask_dir']], check=False)
    except FileNotFoundError:
        # No ``zip`` executable: stay best-effort like the former shell call.
        print("zip executable not found, {} not created".format(zip_path))

    # manually release memory (after cuda out of memory)
    # Only this function's reference is dropped; the caller's stays alive.
    del SegTracker
    torch.cuda.empty_cache()
    gc.collect()

    return io_args['output_video'], zip_path
def img_seq_type_input_tracking(SegTracker, io_args, video_name, imgs_path, fps):
    """Track objects through an image sequence and export the results.

    Pass 1 reads every image and produces one mask per frame: frame 0 uses
    ``SegTracker.first_frame_mask``; every ``SegTracker.sam_gap``-th frame is
    re-segmented so new objects can be merged into the tracker; all other
    frames are tracked only. Pass 2 re-reads the images, overlays the masks
    and writes per-frame PNGs, an mp4 and a GIF. Finally the mask directory
    is zipped.

    Args:
        SegTracker: tracker exposing ``sam_gap``, ``first_frame_mask``,
            ``seg``, ``track``, ``find_new_objs``, ``add_reference`` and
            ``get_obj_num``.
        io_args: dict with the keys ``tracking_result_dir``,
            ``output_mask_dir``, ``output_masked_frame_dir``,
            ``output_video`` and ``output_gif``.
        video_name: stem used for the zip file name.
        imgs_path: ordered list of image file paths, one per frame.
        fps: frame rate of the written video (also passed to the GIF writer).

    Returns:
        Tuple ``(output video path, mask zip path)``.

    Raises:
        IndexError: if ``imgs_path`` is empty (no mask to take the frame
            size from).
    """
    # create dir to save predicted mask and masked frame
    output_mask_dir = io_args['output_mask_dir']
    create_dir(io_args['output_mask_dir'])
    create_dir(io_args['output_masked_frame_dir'])

    pred_list = []
    masked_pred_list = []

    torch.cuda.empty_cache()
    gc.collect()
    sam_gap = SegTracker.sam_gap

    with torch.cuda.amp.autocast():
        for frame_idx, img_path in enumerate(imgs_path):
            frame_name = os.path.basename(img_path).split('.')[0]
            frame = cv2.imread(img_path)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            if frame_idx == 0:
                pred_mask = SegTracker.first_frame_mask
                torch.cuda.empty_cache()
                gc.collect()
            elif (frame_idx % sam_gap) == 0:
                seg_mask = SegTracker.seg(frame)
                torch.cuda.empty_cache()
                gc.collect()
                track_mask = SegTracker.track(frame)
                # find new objects, and update tracker with new objects
                new_obj_mask = SegTracker.find_new_objs(track_mask, seg_mask)
                save_prediction(new_obj_mask, output_mask_dir, f'{frame_name}_new.png')
                pred_mask = track_mask + new_obj_mask
                # segtracker.restart_tracker()
                SegTracker.add_reference(frame, pred_mask)
            else:
                pred_mask = SegTracker.track(frame, update_memory=True)
            torch.cuda.empty_cache()
            gc.collect()

            save_prediction(pred_mask, output_mask_dir, f'{frame_name}.png')
            pred_list.append(pred_mask)

            print("processed frame {}, obj_num {}".format(frame_idx, SegTracker.get_obj_num()), end='\r')
    print('\nfinished')

    ##################
    # Visualization
    ##################

    # draw pred mask on frame and save as a video
    height, width = pred_list[0].shape
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))
    try:
        # pred_list holds exactly one mask per entry of imgs_path (pass 1).
        for img_path, pred_mask in zip(imgs_path, pred_list):
            frame_name = os.path.basename(img_path).split('.')[0]
            frame = cv2.imread(img_path)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            masked_frame = draw_mask(frame, pred_mask)
            masked_pred_list.append(masked_frame)
            # cv2.imwrite expects BGR, hence the reversed channel axis.
            cv2.imwrite(f"{io_args['output_masked_frame_dir']}/{frame_name}.png", masked_frame[:, :, ::-1])

            masked_frame = cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR)
            out.write(masked_frame)
            print('frame {} writed'.format(frame_name), end='\r')
    finally:
        # Finalise the video file even when a frame fails to load or draw.
        out.release()
    print("\n{} saved".format(io_args['output_video']))
    print('\nfinished')

    # save colorized masks as a gif
    # NOTE(review): ``duration`` is a per-frame display time whose unit depends
    # on the installed imageio version, yet it is given the frame rate here --
    # confirm the intended value (likely derived from 1 / fps).
    imageio.mimsave(io_args['output_gif'], masked_pred_list, duration=fps)
    print("{} saved".format(io_args['output_gif']))

    # zip predicted mask
    zip_path = f"{io_args['tracking_result_dir']}/{video_name}_pred_mask.zip"
    try:
        # Argument list, no shell: the paths contain a name derived from the
        # uploaded file, so a shell string would break on spaces and could be
        # abused through shell metacharacters.
        subprocess.run(["zip", "-r", zip_path, io_args['output_mask_dir']], check=False)
    except FileNotFoundError:
        # No ``zip`` executable: stay best-effort like the former shell call.
        print("zip executable not found, {} not created".format(zip_path))

    # manually release memory (after cuda out of memory)
    # Only this function's reference is dropped; the caller's stays alive.
    del SegTracker
    torch.cuda.empty_cache()
    gc.collect()

    return io_args['output_video'], zip_path