{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "from SegTracker import SegTracker\n",
    "from model_args import aot_args,sam_args,segtracker_args\n",
    "from PIL import Image\n",
    "from aot_tracker import _palette\n",
    "import numpy as np\n",
    "import torch\n",
    "import imageio\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.ndimage import binary_dilation\n",
    "import gc\n",
    "def save_prediction(pred_mask,output_dir,file_name):\n",
    "    save_mask = Image.fromarray(pred_mask.astype(np.uint8))\n",
    "    save_mask = save_mask.convert(mode='P')\n",
    "    save_mask.putpalette(_palette)\n",
    "    save_mask.save(os.path.join(output_dir,file_name))\n",
    "def colorize_mask(pred_mask):\n",
    "    save_mask = Image.fromarray(pred_mask.astype(np.uint8))\n",
    "    save_mask = save_mask.convert(mode='P')\n",
    "    save_mask.putpalette(_palette)\n",
    "    save_mask = save_mask.convert(mode='RGB')\n",
    "    return np.array(save_mask)\n",
    "def draw_mask(img, mask, alpha=0.7, id_countour=False):\n",
    "    img_mask = np.zeros_like(img)\n",
    "    img_mask = img\n",
    "    if id_countour:\n",
    "        # very slow ~ 1s per image\n",
    "        obj_ids = np.unique(mask)\n",
    "        obj_ids = obj_ids[obj_ids!=0]\n",
    "\n",
    "        for id in obj_ids:\n",
    "            # Overlay color on  binary mask\n",
    "            if id <= 255:\n",
    "                color = _palette[id*3:id*3+3]\n",
    "            else:\n",
    "                color = [0,0,0]\n",
    "            foreground = img * (1-alpha) + np.ones_like(img) * alpha * np.array(color)\n",
    "            binary_mask = (mask == id)\n",
    "\n",
    "            # Compose image\n",
    "            img_mask[binary_mask] = foreground[binary_mask]\n",
    "\n",
    "            countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask\n",
    "            img_mask[countours, :] = 0\n",
    "    else:\n",
    "        binary_mask = (mask!=0)\n",
    "        countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask\n",
    "        foreground = img*(1-alpha)+colorize_mask(mask)*alpha\n",
    "        img_mask[binary_mask] = foreground[binary_mask]\n",
    "        img_mask[countours,:] = 0\n",
    "        \n",
    "    return img_mask.astype(img.dtype)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set parameters for input and output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "video_name = 'cars'\n",
    "io_args = {\n",
    "    'input_video': f'./assets/{video_name}.mp4',\n",
    "    'output_mask_dir': f'./assets/{video_name}_masks', # save pred masks\n",
    "    'output_video': f'./assets/{video_name}_seg.mp4', # mask+frame vizualization, mp4 or avi, else the same as input video\n",
    "    'output_gif': f'./assets/{video_name}_seg.gif', # mask visualization\n",
    "}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tuning Grounding-DINO and SAM on the First Frame for Good Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# choose good parameters in sam_args based on the first frame segmentation result\n",
    "# other arguments can be modified in model_args.py\n",
    "# note the object number limit is 255 by default, which requires < 10GB GPU memory with amp\n",
    "sam_args['generator_args'] = {\n",
    "        'points_per_side': 30,\n",
    "        'pred_iou_thresh': 0.8,\n",
    "        'stability_score_thresh': 0.9,\n",
    "        'crop_n_layers': 1,\n",
    "        'crop_n_points_downscale_factor': 2,\n",
    "        'min_mask_region_area': 200,\n",
    "    }\n",
    "\n",
    "# Set Text args\n",
    "'''\n",
    "parameter:\n",
    "    grounding_caption: Text prompt to detect objects in key-frames\n",
    "    box_threshold: threshold for box \n",
    "    text_threshold: threshold for label(text)\n",
    "    box_size_threshold: If the size ratio between the box and the frame is larger than the box_size_threshold, the box will be ignored. This is used to filter out large boxes.\n",
    "    reset_image: reset the image embeddings for SAM\n",
    "'''\n",
    "grounding_caption = \"car.suv\"\n",
    "box_threshold, text_threshold, box_size_threshold, reset_image = 0.35, 0.5, 0.5, True\n",
    "\n",
    "cap = cv2.VideoCapture(io_args['input_video'])\n",
    "frame_idx = 0\n",
    "segtracker = SegTracker(segtracker_args,sam_args,aot_args)\n",
    "segtracker.restart_tracker()\n",
    "with torch.cuda.amp.autocast():\n",
    "    while cap.isOpened():\n",
    "        ret, frame = cap.read()\n",
    "        frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)\n",
    "        pred_mask, annotated_frame = segtracker.detect_and_seg(frame, grounding_caption, box_threshold, text_threshold, box_size_threshold)\n",
    "        torch.cuda.empty_cache()\n",
    "        obj_ids = np.unique(pred_mask)\n",
    "        obj_ids = obj_ids[obj_ids!=0]\n",
    "        print(\"processed frame {}, obj_num {}\".format(frame_idx,len(obj_ids)),end='\\n')\n",
    "        break\n",
    "    cap.release()\n",
    "    init_res = draw_mask(annotated_frame, pred_mask,id_countour=False)\n",
    "    plt.figure(figsize=(10,10))\n",
    "    plt.axis('off')\n",
    "    plt.imshow(init_res)\n",
    "    plt.show()\n",
    "    plt.figure(figsize=(10,10))\n",
    "    plt.axis('off')\n",
    "    plt.imshow(colorize_mask(pred_mask))\n",
    "    plt.show()\n",
    "\n",
    "    del segtracker\n",
    "    torch.cuda.empty_cache()\n",
    "    gc.collect()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate Results for the Whole Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every sam_gap frames, we use SAM to find new objects and add them for tracking\n",
    "# larger sam_gap is faster but may not spot new objects in time\n",
    "segtracker_args = {\n",
    "    'sam_gap': 49, # the interval to run sam to segment new objects\n",
    "    'min_area': 200, # minimal mask area to add a new mask as a new object\n",
    "    'max_obj_num': 255, # maximal object number to track in a video\n",
    "    'min_new_obj_iou': 0.8, # the area of a new object in the background should > 80% \n",
    "}\n",
    "\n",
    "# source video to segment\n",
    "cap = cv2.VideoCapture(io_args['input_video'])\n",
    "fps = cap.get(cv2.CAP_PROP_FPS)\n",
    "# output masks\n",
    "output_dir = io_args['output_mask_dir']\n",
    "if not os.path.exists(output_dir):\n",
    "    os.makedirs(output_dir)\n",
    "pred_list = []\n",
    "masked_pred_list = []\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "gc.collect()\n",
    "sam_gap = segtracker_args['sam_gap']\n",
    "frame_idx = 0\n",
    "segtracker = SegTracker(segtracker_args, sam_args, aot_args)\n",
    "segtracker.restart_tracker()\n",
    "\n",
    "with torch.cuda.amp.autocast():\n",
    "    while cap.isOpened():\n",
    "        ret, frame = cap.read()\n",
    "        if not ret:\n",
    "            break\n",
    "        frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)\n",
    "        if frame_idx == 0:\n",
    "            pred_mask, _ = segtracker.detect_and_seg(frame, grounding_caption, box_threshold, text_threshold, box_size_threshold, reset_image)\n",
    "            # pred_mask = cv2.imread('./debug/first_frame_mask.png', 0)\n",
    "            torch.cuda.empty_cache()\n",
    "            gc.collect()\n",
    "            segtracker.add_reference(frame, pred_mask)\n",
    "        elif (frame_idx % sam_gap) == 0:\n",
    "            seg_mask, _ = segtracker.detect_and_seg(frame, grounding_caption, box_threshold, text_threshold, box_size_threshold, reset_image)\n",
    "            save_prediction(seg_mask, './debug/seg_result', str(frame_idx)+'.png')\n",
    "            torch.cuda.empty_cache()\n",
    "            gc.collect()\n",
    "            track_mask = segtracker.track(frame)\n",
    "            save_prediction(track_mask, './debug/aot_result', str(frame_idx)+'.png')\n",
    "            # find new objects, and update tracker with new objects\n",
    "            new_obj_mask = segtracker.find_new_objs(track_mask, seg_mask)\n",
    "            if np.sum(new_obj_mask > 0) >  frame.shape[0] * frame.shape[1] * 0.4:\n",
    "                new_obj_mask = np.zeros_like(new_obj_mask)\n",
    "            save_prediction(new_obj_mask,output_dir,str(frame_idx)+'_new.png')\n",
    "            pred_mask = track_mask + new_obj_mask\n",
    "            # segtracker.restart_tracker()\n",
    "            segtracker.add_reference(frame, pred_mask)\n",
    "        else:\n",
    "            pred_mask = segtracker.track(frame,update_memory=True)\n",
    "        torch.cuda.empty_cache()\n",
    "        gc.collect()\n",
    "        \n",
    "        save_prediction(pred_mask,output_dir,str(frame_idx)+'.png')\n",
    "        # masked_frame = draw_mask(frame,pred_mask)\n",
    "        # masked_pred_list.append(masked_frame)\n",
    "        # plt.imshow(masked_frame)\n",
    "        # plt.show() \n",
    "        \n",
    "        pred_list.append(pred_mask)\n",
    "        \n",
    "        \n",
    "        print(\"processed frame {}, obj_num {}\".format(frame_idx,segtracker.get_obj_num()),end='\\r')\n",
    "        frame_idx += 1\n",
    "    cap.release()\n",
    "    print('\\nfinished')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save results for visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# draw pred mask on frame and save as a video\n",
    "cap = cv2.VideoCapture(io_args['input_video'])\n",
    "fps = cap.get(cv2.CAP_PROP_FPS)\n",
    "width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n",
    "height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
    "num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
    "\n",
    "if io_args['input_video'][-3:]=='mp4':\n",
    "    fourcc =  cv2.VideoWriter_fourcc(*\"mp4v\")\n",
    "elif io_args['input_video'][-3:] == 'avi':\n",
    "    fourcc =  cv2.VideoWriter_fourcc(*\"MJPG\")\n",
    "    # fourcc = cv2.VideoWriter_fourcc(*\"XVID\")\n",
    "else:\n",
    "    fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))\n",
    "out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))\n",
    "\n",
    "frame_idx = 0\n",
    "while cap.isOpened():\n",
    "    ret, frame = cap.read()\n",
    "    if not ret:\n",
    "        break\n",
    "    frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)\n",
    "    pred_mask = pred_list[frame_idx]\n",
    "    masked_frame = draw_mask(frame,pred_mask)\n",
    "    # masked_frame = masked_pred_list[frame_idx]\n",
    "    masked_frame = cv2.cvtColor(masked_frame,cv2.COLOR_RGB2BGR)\n",
    "    out.write(masked_frame)\n",
    "    print('frame {} writed'.format(frame_idx),end='\\r')\n",
    "    frame_idx += 1\n",
    "out.release()\n",
    "cap.release()\n",
    "print(\"\\n{} saved\".format(io_args['output_video']))\n",
    "print('\\nfinished')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save colorized masks as a gif\n",
    "imageio.mimsave(io_args['output_gif'],pred_list,fps=fps)\n",
    "print(\"{} saved\".format(io_args['output_gif']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "21"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# manually release memory (after cuda out of memory)\n",
    "del segtracker\n",
    "torch.cuda.empty_cache()\n",
    "gc.collect()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.5 64-bit ('ldm': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5-final"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "536611da043600e50719c9460971b5220bad26cd4a87e5994bfd4c9e9e5e7fb0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}