{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| default_exp utils\n", "\n", "import os\n", "from PIL import Image\n", "\n", "def make_square(image, min_size=512, fill_color=(255, 255, 255, 0)):\n", " '''\n", " Make a square image with signature in the center and black (transparent)\n", " strips on top and bottom. Cycle GAN is trained with images of this format. \n", " '''\n", " x, y = image.size\n", " size = max(min_size, x, y)\n", " new_im = Image.new('RGBA', (size, size), fill_color)\n", " new_im.paste(image, (int((size - x) / 2), int((size - y) / 2)))\n", " new_im = new_im.resize((512, 512))\n", " return new_im\n", "\n", "def resize_images(path):\n", " '''\n", " Resize all the images present in path that matches the ips used in cyclegan\n", " training\n", " '''\n", " dirs = os.listdir(path)\n", " for item in dirs:\n", " if os.path.isfile(path+item):\n", " image = Image.open(path+item)\n", " image = make_square(image)\n", " image = image.convert('RGB')\n", " image.save(path+item)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "ename": "ModuleNotFoundError", "evalue": "No module named 'torch'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[15], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#| export\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m \n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnn\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnn\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mIdentity\u001b[39;00m(nn\u001b[38;5;241m.\u001b[39mModule):\n", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch'" ] } ], "source": [ "#| export\n", "import torch \n", "import torch.nn as nn\n", "\n", "class Identity(nn.Module):\n", " def __init__(self):\n", " super(Identity, self).__init__()\n", " \n", " def forward(self, x):\n", " return x\n", " \n", "\n", "class SiameseNetwork(nn.Module):\n", " def __init__(self, model, embedding_size=2000):\n", " super(SiameseNetwork, self).__init__()\n", "\n", " self.backbone = model\n", " self.rgb_grayscale = nn.Conv2d(1,3,kernel_size=3,stride=1,padding=1)\n", " self.a = nn.Sigmoid()\n", " self.fc1 = nn.Linear(in_features=1280, out_features=embedding_size) \n", " self.fc2 = nn.Linear(in_features=embedding_size,out_features=1) \n", " self.relu = nn.ReLU()\n", " self.backbone.classifier = Identity()\n", "\n", " def forward_once(self, x):\n", " \n", " x = self.rgb_grayscale(x)\n", " resnet_output = self.backbone(x)\n", " \n", " return resnet_output\n", " \n", " def forward(self, x1,x2):\n", " \n", " x1 = self.forward_once(x1)\n", " x2 = self.forward_once(x2)\n", " out1, out2 = self.relu(self.fc1(x1)), self.relu(self.fc1(x2))\n", " dis = torch.abs((out1 - out2))\n", " pred = self.a(self.fc2(dis)).squeeze()\n", " return pred" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "255\n", "0\n" ] }, { "data": { "text/plain": [ "True" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import cv2\n", "import numpy as np\n", "p = \"/Users/markus/code/explore/signify/results/media/documents/ajy01c00.png\"\n", "pnew = \"/Users/markus/code/explore/signify/results/media/documents/3.png\"\n", "\n", "img = 255 - cv2.imread(p) \n", "print(img.max())\n", "print(img.min())\n", "cv2.imwrite(pnew, img)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#| hide\n", "import nbdev; nbdev.nbdev_export()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.10.8 ('base')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "a2b499d73ef2075ded5123ca15bce54816687eb05fa98ad819942f2bd35191e9" } } }, "nbformat": 4, "nbformat_minor": 2 }