{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import requests\n",
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "\n",
    "from models.clipseg import CLIPDensePredT\n",
    "\n",
    "# checkpoint loaded below; it is extracted from the downloaded archive\n",
    "WEIGHTS_PATH = 'weights/rd64-uni.pth'\n",
    "\n",
    "# download and unpack the weights only when they are missing, so that\n",
    "# re-running the notebook does not fetch the archive again\n",
    "if not os.path.exists(WEIGHTS_PATH):\n",
    "    ! wget https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download -O weights.zip\n",
    "    # -o: overwrite without prompting (a notebook cell cannot answer the prompt)\n",
    "    ! unzip -o -d weights -j weights.zip\n",
    "\n",
    "# load model\n",
    "model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)\n",
    "model.eval()\n",
    "\n",
    "# non-strict, because we only stored decoder weights (not CLIP weights)\n",
    "# NOTE: torch.load unpickles the file, which can execute arbitrary code;\n",
    "# only load checkpoints obtained from a source you trust.\n",
    "model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=torch.device('cpu')), strict=False);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load and normalize `example_image.jpg`. You can also load the image from a URL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the image; force 3-channel RGB so that the per-channel normalization\n",
    "# below also works for grayscale, RGBA or palette images\n",
    "input_image = Image.open('example_image.jpg').convert('RGB')\n",
    "\n",
    "# or load from URL...\n",
    "# image_url = 'https://farm5.staticflickr.com/4141/4856248695_03475782dc_z.jpg'\n",
    "# input_image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')\n",
    "\n",
    "# convert to a tensor, normalize each channel and resize to 352x352\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    transforms.Resize((352, 352)),\n",
    "])\n",
    "\n",
    "# add a leading batch dimension\n",
    "img = transform(input_image).unsqueeze(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Predict and visualize (this might take a few seconds if running without GPU support)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# edit this list freely: the prediction and the plot adapt to its length\n",
    "prompts = ['a glass', 'something to fill', 'wood', 'a jar']\n",
    "\n",
    "# predict: the image is repeated so that there is one copy per prompt\n",
    "with torch.no_grad():\n",
    "    preds = model(img.repeat(len(prompts), 1, 1, 1), prompts)[0]\n",
    "\n",
    "# visualize prediction: the input image followed by one panel per prompt\n",
    "n_panels = len(prompts) + 1\n",
    "fig, ax = plt.subplots(1, n_panels, figsize=(3 * n_panels, 4))\n",
    "for a in ax.flatten():\n",
    "    a.axis('off')\n",
    "ax[0].imshow(input_image)\n",
    "for i, prompt in enumerate(prompts):\n",
    "    # sigmoid squashes the raw model output into the range (0, 1)\n",
    "    ax[i + 1].imshow(torch.sigmoid(preds[i][0]))\n",
    "    ax[i + 1].text(0, -15, prompt)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "800ed241f7db2bd3aa6942aa3be6809cdb30ee6b0a9e773dfecfa9fef1f4c586"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}