{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9bca0f41",
   "metadata": {},
   "source": [
    "# Simple inference example with CroCo-Stereo or CroCo-Flow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80653ef7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
    "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f033862",
   "metadata": {},
   "source": [
    "First download the model(s) of your choice by running\n",
    "```\n",
    "bash stereoflow/download_model.sh crocostereo.pth\n",
    "bash stereoflow/download_model.sh crocoflow.pth\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fb2e392",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Third-party imports used throughout the notebook.\n",
    "# numpy and PIL are needed by the image-loading cells below (np.asarray / Image.open).\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image\n",
    "\n",
    "# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n",
    "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
    "device = torch.device('cuda:0' if use_gpu else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0e25d77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project-local helpers from the stereoflow package.\n",
    "from stereoflow.test import _load_model_and_criterion\n",
    "from stereoflow.engine import tiled_pred\n",
    "from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n",
    "from stereoflow.datasets_flow import flowToColor\n",
    "\n",
    "# Passed as `overlap` to tiled_pred in both examples below.\n",
    "tile_overlap=0.7 # recommended value, higher value can be slightly better but slower"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86a921f5",
   "metadata": {},
   "source": [
    "### CroCo-Stereo example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64e483cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill in the paths of the two images of the stereo pair before running this cell;\n",
    "# the empty placeholders cannot be opened as-is.\n",
    "image1 = np.asarray(Image.open(''))\n",
    "image2 = np.asarray(Image.open(''))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d04303",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the stereo checkpoint together with the settings that tiled_pred needs below.\n",
    "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47dc14b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert both images to tensors on the target device and add a leading batch dimension.\n",
    "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
    "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
    "\n",
    "# Tiled inference without autograd bookkeeping; only the first returned value is used.\n",
    "with torch.inference_mode():\n",
    "    pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
    "\n",
    "# Drop the two leading singleton dimensions (presumably batch and channel - TODO confirm)\n",
    "# and move the result to a NumPy array on the CPU.\n",
    "pred = pred.squeeze(0).squeeze(0).cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "583b9f16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the predicted disparity rendered by vis_disparity; the trailing ';' hides the repr of plt.axis.\n",
    "plt.imshow(vis_disparity(pred))\n",
    "plt.axis('off');"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2df5d70",
   "metadata": {},
   "source": [
    "### CroCo-Flow example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ee257a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill in the paths of the two images of the flow pair before running this cell;\n",
    "# the empty placeholders cannot be opened as-is.\n",
    "image1 = np.asarray(Image.open(''))\n",
    "image2 = np.asarray(Image.open(''))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5edccf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the flow checkpoint; this rebinds the model and settings loaded in the stereo example.\n",
    "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b19692c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert both images to tensors on the target device and add a leading batch dimension.\n",
    "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
    "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
    "\n",
    "# Tiled inference without autograd bookkeeping; only the first returned value is used.\n",
    "with torch.inference_mode():\n",
    "    pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
    "\n",
    "# Drop the leading batch dimension and move the first remaining axis to the end\n",
    "# (presumably channels-first -> channels-last for flowToColor - TODO confirm), then convert to NumPy.\n",
    "pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26f79db3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the predicted flow rendered by flowToColor; the trailing ';' hides the repr of plt.axis.\n",
    "plt.imshow(flowToColor(pred))\n",
    "plt.axis('off');"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}