{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "from SegTracker import SegTracker\n",
    "from model_args import aot_args,sam_args,segtracker_args\n",
    "from PIL import Image\n",
    "from aot_tracker import _palette\n",
    "import numpy as np\n",
    "import torch\n",
    "import imageio\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.ndimage import binary_dilation\n",
    "import gc\n",
    "\n",
    "\n",
    "def _to_palette_image(pred_mask):\n",
    "    \"\"\"Return `pred_mask` as a mode-'P' PIL image that uses `_palette`.\n",
    "\n",
    "    The mask is cast to uint8 before the conversion.\n",
    "    \"\"\"\n",
    "    palette_img = Image.fromarray(pred_mask.astype(np.uint8))\n",
    "    palette_img = palette_img.convert(mode='P')\n",
    "    palette_img.putpalette(_palette)\n",
    "    return palette_img\n",
    "\n",
    "\n",
    "def save_prediction(pred_mask,output_dir,file_name):\n",
    "    \"\"\"Save `pred_mask` as a palettised image at `output_dir/file_name`.\"\"\"\n",
    "    _to_palette_image(pred_mask).save(os.path.join(output_dir,file_name))\n",
    "\n",
    "\n",
    "def colorize_mask(pred_mask):\n",
    "    \"\"\"Return `pred_mask` coloured through `_palette` as an RGB numpy array.\"\"\"\n",
    "    return np.array(_to_palette_image(pred_mask).convert(mode='RGB'))\n",
    "\n",
    "\n",
    "def draw_mask(img, mask, alpha=0.5, id_countour=False):\n",
    "    \"\"\"Blend the coloured mask over `img` and outline the objects in black.\n",
    "\n",
    "    Args:\n",
    "        img: image array, indexed here as (row, col, channel). It is not modified.\n",
    "        mask: id array aligned with the first two axes of `img`; 0 is background.\n",
    "        alpha: weight of the mask colour in the blend (image weight is 1-alpha).\n",
    "        id_countour: if True, blend and outline every object id separately\n",
    "            (slow); otherwise outline the union of all non-zero ids.\n",
    "\n",
    "    Returns:\n",
    "        A new array with the same dtype as `img`.\n",
    "    \"\"\"\n",
    "    # Work on a copy. Aliasing `img` would overwrite the caller's frame and\n",
    "    # make later blends read pixels that were already painted or zeroed.\n",
    "    img_mask = img.copy()\n",
    "    if id_countour:\n",
    "        # very slow ~ 1s per image\n",
    "        obj_ids = np.unique(mask)\n",
    "        obj_ids = obj_ids[obj_ids!=0]\n",
    "\n",
    "        for obj_id in obj_ids:\n",
    "            # Plain int: the palette offset obj_id*3 must not be computed in a\n",
    "            # narrow mask dtype (e.g. uint8), where it could overflow.\n",
    "            obj_id = int(obj_id)\n",
    "            # Overlay color on binary mask\n",
    "            if obj_id <= 255:\n",
    "                color = _palette[obj_id*3:obj_id*3+3]\n",
    "            else:\n",
    "                color = [0,0,0]\n",
    "            foreground = img * (1-alpha) + np.ones_like(img) * alpha * np.array(color)\n",
    "            binary_mask = (mask == obj_id)\n",
    "\n",
    "            # Compose image\n",
    "            img_mask[binary_mask] = foreground[binary_mask]\n",
    "\n",
    "            # One-pixel ring just outside the object.\n",
    "            countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask\n",
    "            img_mask[countours, :] = 0\n",
    "    else:\n",
    "        binary_mask = (mask!=0)\n",
    "        # One-pixel ring just outside the union of all objects.\n",
    "        countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask\n",
    "        foreground = img*(1-alpha)+colorize_mask(mask)*alpha\n",
    "        img_mask[binary_mask] = foreground[binary_mask]\n",
    "        img_mask[countours,:] = 0\n",
    "\n",
    "    return img_mask.astype(img.dtype)"
   ]
  },
  {
"attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set parameters for input and output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "video_name = 'cell'\n",
    "io_args = {\n",
    "    'input_video': f'./assets/{video_name}.mp4',\n",
    "    'output_mask_dir': f'./assets/{video_name}_masks', # save pred masks\n",
    "    'output_video': f'./assets/{video_name}_seg.mp4', # mask+frame visualization, mp4 or avi, else the same as input video\n",
    "    'output_gif': f'./assets/{video_name}_seg.gif', # mask visualization\n",
    "}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tuning SAM on the First Frame for Good Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# choose good parameters in sam_args based on the first frame segmentation result\n",
    "# other arguments can be modified in model_args.py\n",
    "# note the object number limit is 255 by default, which requires < 10GB GPU memory with amp\n",
    "sam_args['generator_args'] = {\n",
    "        'points_per_side': 30,\n",
    "        'pred_iou_thresh': 0.8,\n",
    "        'stability_score_thresh': 0.9,\n",
    "        'crop_n_layers': 1,\n",
    "        'crop_n_points_downscale_factor': 2,\n",
    "        'min_mask_region_area': 200,\n",
    "    }\n",
    "\n",
    "# Only the first frame is needed. Read it before building the tracker and\n",
    "# check `ret`, so an unreadable video fails here with a clear message\n",
    "# instead of inside cv2.cvtColor or as a NameError further down.\n",
    "cap = cv2.VideoCapture(io_args['input_video'])\n",
    "ret, frame = cap.read()\n",
    "cap.release()\n",
    "if not ret:\n",
    "    raise RuntimeError('could not read the first frame of {}'.format(io_args['input_video']))\n",
    "frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)\n",
    "\n",
    "frame_idx = 0\n",
    "segtracker = SegTracker(segtracker_args,sam_args,aot_args)\n",
    "segtracker.restart_tracker()\n",
    "with torch.cuda.amp.autocast():\n",
    "    pred_mask = segtracker.seg(frame)\n",
    "    torch.cuda.empty_cache()\n",
    "    obj_ids = np.unique(pred_mask)\n",
    "    obj_ids = obj_ids[obj_ids!=0]\n",
    "    print(\"processed frame {}, obj_num {}\".format(frame_idx,len(obj_ids)),end='\\n')\n",
    "\n",
    "    # first frame with the mask overlaid\n",
    "    init_res = draw_mask(frame,pred_mask,id_countour=False)\n",
    "    plt.figure(figsize=(10,10))\n",
    "    plt.axis('off')\n",
    "    plt.imshow(init_res)\n",
    "    plt.show()\n",
    "    # colorized mask on its own\n",
    "    plt.figure(figsize=(10,10))\n",
    "    plt.axis('off')\n",
    "    plt.imshow(colorize_mask(pred_mask))\n",
    "    plt.show()\n",
    "\n",
    "    # free the tuning tracker before the full run builds a new one\n",
    "    del segtracker\n",
    "    torch.cuda.empty_cache()\n",
    "    gc.collect()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate Results for the Whole Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every sam_gap frames, we use SAM to find new objects and add them for tracking\n",
    "# larger sam_gap is faster but may not spot new objects in time\n",
    "segtracker_args = {\n",
    "    'sam_gap': 5, # the interval to run sam to segment new objects\n",
    "    'min_area': 200, # minimal mask area to add a new mask as a new object\n",
    "    'max_obj_num': 255, # maximal object number to track in a video\n",
    "    'min_new_obj_iou': 0.8, # the area of a new object in the background should > 80% \n",
    "}\n",
    "\n",
    "# source video to segment\n",
    "cap = cv2.VideoCapture(io_args['input_video'])\n",
    "fps = cap.get(cv2.CAP_PROP_FPS)\n",
    "# output masks\n",
    "output_dir = io_args['output_mask_dir']\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "pred_list = []\n",
    "masked_pred_list = []\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "gc.collect()\n",
    "sam_gap = segtracker_args['sam_gap']\n",
    "frame_idx = 0\n",
    "segtracker = SegTracker(segtracker_args,sam_args,aot_args)\n",
    "segtracker.restart_tracker()\n",
    "\n",
    "# try/finally: release the capture even if a frame fails (e.g. CUDA OOM)\n",
    "try:\n",
    "    with torch.cuda.amp.autocast():\n",
    "        while cap.isOpened():\n",
    "            ret, frame = cap.read()\n",
    "            if not ret:\n",
    "                break\n",
    "            frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)\n",
    "            if frame_idx == 0:\n",
    "                # first frame: segment and use the result as the reference\n",
    "                pred_mask = segtracker.seg(frame)\n",
    "                torch.cuda.empty_cache()\n",
    "                gc.collect()\n",
    "                segtracker.add_reference(frame, pred_mask)\n",
    "            elif (frame_idx % sam_gap) == 0:\n",
    "                seg_mask = segtracker.seg(frame)\n",
    "                torch.cuda.empty_cache()\n",
    "                gc.collect()\n",
    "                track_mask = segtracker.track(frame)\n",
    "                # find new objects, and update tracker with new objects\n",
    "                new_obj_mask = segtracker.find_new_objs(track_mask,seg_mask)\n",
    "                save_prediction(new_obj_mask,output_dir,str(frame_idx)+'_new.png')\n",
    "                pred_mask = track_mask + new_obj_mask\n",
    "                # segtracker.restart_tracker()\n",
    "                segtracker.add_reference(frame, pred_mask)\n",
    "            else:\n",
    "                pred_mask = segtracker.track(frame,update_memory=True)\n",
    "            torch.cuda.empty_cache()\n",
    "            gc.collect()\n",
    "            save_prediction(pred_mask,output_dir,str(frame_idx)+'.png')\n",
    "            # masked_frame = draw_mask(frame,pred_mask)\n",
    "            # masked_pred_list.append(masked_frame)\n",
    "            # plt.imshow(masked_frame)\n",
    "            # plt.show()\n",
    "\n",
    "            pred_list.append(pred_mask)\n",
    "\n",
    "            print(\"processed frame {}, obj_num {}\".format(frame_idx,segtracker.get_obj_num()),end='\\r')\n",
    "            frame_idx += 1\n",
    "finally:\n",
    "    cap.release()\n",
    "print('\\nfinished')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save results for visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# draw pred mask on frame and save as a video\n",
    "cap = cv2.VideoCapture(io_args['input_video'])\n",
    "fps = cap.get(cv2.CAP_PROP_FPS)\n",
    "width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n",
    "height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
    "num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
    "\n",
    "# The codec has to match the container being written, so choose it from the\n",
    "# output file's extension (not the input's); for any other extension reuse\n",
    "# the input video's codec.\n",
    "output_ext = os.path.splitext(io_args['output_video'])[1].lower()\n",
    "if output_ext == '.mp4':\n",
    "    fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n",
    "elif output_ext == '.avi':\n",
    "    fourcc = cv2.VideoWriter_fourcc(*\"MJPG\")\n",
    "    # fourcc = cv2.VideoWriter_fourcc(*\"XVID\")\n",
    "else:\n",
    "    fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))\n",
    "out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))\n",
    "\n",
    "frame_idx = 0\n",
    "# try/finally: close the writer and the capture even if a frame fails\n",
    "try:\n",
    "    while cap.isOpened():\n",
    "        ret, frame = cap.read()\n",
    "        if not ret:\n",
    "            break\n",
    "        frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)\n",
    "        pred_mask = pred_list[frame_idx]\n",
    "        masked_frame = draw_mask(frame,pred_mask)\n",
    "        # masked_frame = masked_pred_list[frame_idx]\n",
    "        # VideoWriter expects BGR\n",
    "        masked_frame = cv2.cvtColor(masked_frame,cv2.COLOR_RGB2BGR)\n",
    "        out.write(masked_frame)\n",
    "        print('frame {} writed'.format(frame_idx),end='\\r')\n",
    "        frame_idx += 1\n",
    "finally:\n",
    "    out.release()\n",
    "    cap.release()\n",
    "print(\"\\n{} saved\".format(io_args['output_video']))\n",
    "print('\\nfinished')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save colorized masks as a gif\n",
    "# pred_list holds raw id masks, so map them through the palette first\n",
    "imageio.mimsave(io_args['output_gif'],[colorize_mask(mask) for mask in pred_list],fps=fps)\n",
    "print(\"{} saved\".format(io_args['output_gif']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "301"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# manually release memory (after cuda out of memory)\n",
    "del segtracker\n",
    "torch.cuda.empty_cache()\n",
    "gc.collect()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.5 64-bit ('ldm': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "536611da043600e50719c9460971b5220bad26cd4a87e5994bfd4c9e9e5e7fb0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}