{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ultralytics import YOLO\n",
    "\n",
    "# Path of the weights file to load. Pretrained checkpoints are available in\n",
    "# three sizes: nano, medium and xlarge (presumably stored next to nano.pt --\n",
    "# verify the file names in ./trained_models before switching).\n",
    "WEIGHTS_PATH = './trained_models/nano.pt'\n",
    "\n",
    "# To initialize a model from scratch instead of loading weights:\n",
    "# model = YOLO('yolov8n-segreg.yaml')\n",
    "\n",
    "# Load the pretrained model; `model` is used by the training and inference cells below.\n",
    "model = YOLO(WEIGHTS_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Train the model on a dataset (GPU recommended).\n",
    "# The values below run a single epoch on the CPU; for a real training run,\n",
    "# raise `epochs` and point `device` at a GPU (e.g. device=0).\n",
    "model.train(\n",
    "    # --- data and loading ---\n",
    "    data='./datasets/mmc_random/data.yaml',  # (str) path to the dataset description file\n",
    "    imgsz=640,  # (int) input image size\n",
    "    batch=6,  # (int) number of images per batch\n",
    "    workers=12,  # (int) number of worker threads for data loading\n",
    "    cache='ram',  # (bool | str) cache images for data loading: True/'ram', 'disk' or False\n",
    "    rect=False,  # (bool) rectangular training\n",
    "    single_cls=True,  # (bool) train multi-class data as single-class\n",
    "    # --- augmentation ---\n",
    "    mosaic=1.0,  # (float) image mosaic (probability)\n",
    "    hsv_h=0.5,  # (float) image HSV-Hue augmentation (fraction)\n",
    "    hsv_s=0.5,  # (float) image HSV-Saturation augmentation (fraction)\n",
    "    hsv_v=0.5,  # (float) image HSV-Value augmentation (fraction)\n",
    "    degrees=25.0,  # (float) image rotation (+/- deg)\n",
    "    translate=0.2,  # (float) image translation (+/- fraction)\n",
    "    scale=0.75,  # (float) image scale (+/- gain)\n",
    "    shear=10.0,  # (float) image shear (+/- deg)\n",
    "    perspective=0.001,  # (float) image perspective (+/- fraction), range 0-0.001\n",
    "    flipud=0.5,  # (float) image flip up-down (probability)\n",
    "    fliplr=0.5,  # (float) image flip left-right (probability)\n",
    "    # --- optimization ---\n",
    "    epochs=1,  # (int) number of epochs to train for\n",
    "    patience=100,  # (int) epochs to wait without improvement before early stopping\n",
    "    optimizer='AdamW',  # (str) optimizer to use\n",
    "    lr0=3e-4,  # (float) initial learning rate\n",
    "    cos_lr=True,  # (bool) use a cosine learning-rate scheduler\n",
    "    warmup_epochs=0,  # (float) warmup epochs\n",
    "    weight_decay=1e-2,  # (float) optimizer weight decay\n",
    "    amp=True,  # (bool) Automatic Mixed Precision training\n",
    "    # --- loss and masks ---\n",
    "    # Weight of the design-variable regression loss in the total loss function.\n",
    "    # Presumably 0.0 deactivates that term -- confirm in the loss implementation.\n",
    "    reg_gain=1.0,\n",
    "    overlap_mask=False,  # (bool) masks should overlap during training (segment only)\n",
    "    mask_ratio=1,  # (int) mask downsample ratio (segment only)\n",
    "    # --- run control ---\n",
    "    pretrained=False,  # (bool) whether to use a pretrained model\n",
    "    resume=False,  # (bool) resume training from the last checkpoint\n",
    "    val=True,  # (bool) validate during training\n",
    "    plots=True,  # (bool) save plots during train/val\n",
    "    device='cpu',  # (int | str) device to run on, e.g. 0 for the first CUDA GPU or 'cpu'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the local helper module importable. The path is relative to the working\n",
    "# directory, so run the notebook from the directory that contains `utils/`.\n",
    "UTILS_DIR = './utils'\n",
    "if UTILS_DIR not in sys.path:  # do not stack duplicate entries when the cell is re-run\n",
    "    sys.path.append(UTILS_DIR)\n",
    "\n",
    "from yolo_utils import preprocess_image, run_model, process_results, plot_results\n",
    "\n",
    "# Inference settings, collected here so they can be tuned in one place.\n",
    "IMAGE_PATH = './test.png'\n",
    "OUTPUT_PLOT = 'output_plot.png'\n",
    "THRESHOLD_VALUE = 0.9  # forwarded to preprocess_image as threshold_value\n",
    "CONF_THRESHOLD = 0.1  # detection confidence threshold\n",
    "IOU_THRESHOLD = 0.3  # IoU threshold for non-max suppression\n",
    "IMAGE_SIZE = 640  # inference image size\n",
    "\n",
    "# Load and preprocess the image\n",
    "image = preprocess_image(IMAGE_PATH, threshold_value=THRESHOLD_VALUE, upscale=False)\n",
    "\n",
    "# Run the model (`model` comes from the first cell)\n",
    "results = run_model(model, image, conf=CONF_THRESHOLD, iou=IOU_THRESHOLD, imgsz=IMAGE_SIZE)\n",
    "\n",
    "# Process the results\n",
    "(input_image_array_tensor, seg_result, pred_Phi, sum_pred_H,\n",
    " final_H, dice_loss, tversky_loss) = process_results(results, image)\n",
    "\n",
    "# Plot the results (OUTPUT_PLOT is passed through as `filename`)\n",
    "plot_results(\n",
    "    input_image_array_tensor, seg_result, pred_Phi, sum_pred_H,\n",
    "    final_H, dice_loss, tversky_loss,\n",
    "    filename=OUTPUT_PLOT,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "customyolo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}