{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "code", "execution_count": 8, "metadata": { "id": "ClbDA89uqYc2", "colab": { "base_uri": "https://localhost:8080/", "height": 383 }, "outputId": "0dac9131-1d0c-416b-98fb-9dcc7e0dcf5b" }, "outputs": [ { "output_type": "error", "ename": "ModuleNotFoundError", "evalue": "No module named 'attention'", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mfunctional\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mattention\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSelfAttention\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mVAE_AttentionBlock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'attention'", "", "\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n" ], "errorDetails": { "actions": [ { "action": "open_url", "actionText": "Open Examples", "url": "/notebooks/snippets/importing_libraries.ipynb" } ] } } ], "source": [ "import torch\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from attention import SelfAttention\n", "\n", "class VAE_AttentionBlock(nn.Module):\n", " def __init__(self, channels):\n", " super.__init__()\n", " self.groupnorm = nn.GroupNorm(32, channels)\n", " self.attention = Attention(1, channels)\n", " def forward(self, x):\n", " residue=x\n", " x=self.groupnorm(x)\n", " n,c,h,w =x.shape\n", " x=x.view(n,c,h*w)\n", " x=x.transpose(-1,-2)\n", " x=self.attention(x)\n", " x=x.view((n,c,h,w))\n", " x+=residue\n", " return x\n", "class VAE_ResidualBlock(nn.Module):\n", " def __init__(self, in_channels, out_channels):\n", " super.__init__()\n", " self.groupnorm_1= nn.GroupNorm(32, in_channels)\n", " self.conv_1=nn.Conv2d(in_channels, out_channels , kernel_size=3, padding=1)\n", "\n", " self.groupnorm_2= nn.GroupNorm(32, in_channels)\n", " self.conv_2= nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)\n", "\n", " if in_channels==out_channels:\n", " self.residual_layer=nn.Identity()\n", " else:\n", " self.conv_2= nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0)\n", 
" def forward(self, x):\n", " residue=x\n", " x=self.groupnorm_1(x)\n", " x=F.silu(x)\n", " x=self.conv_2(x)\n", " return x+self.residual_layer(residue)\n", "class VAE_Decoder(nn.Sequential):\n", " def __init__(self):\n", " super.__init__(\n", " nn.Conv2d(4,4, kernel_size=1, padding=0),\n", " nn.Conv2d(4, 512, kernel_size=3, padding=1),\n", " VAE_ResidualBlock(512, 512),\n", " VAE_AttentionBlock(512),\n", " VAE_ResidualBlock(512, 512),\n", " VAE_ResidualBlock(512, 512),\n", " VAE_ResidualBlock(512, 512),\n", " VAE_ResidualBlock(512, 512),\n", " VAE_ResidualBlock(512, 512),\n", "\n", " nn.Upsample(scale_factor=2),\n", " nn.Conv2d(512, 512,kernel_size=3, padding=1),\n", " VAE_ResidualBlock(512, 512),\n", " VAE_ResidualBlock(512, 512),\n", " VAE_ResidualBlock(512, 512),\n", " nn.Upsample(scale_factor=2),\n", " VAE_ResidualBlock(512, 256),\n", " VAE_ResidualBlock(256, 256),\n", " VAE_ResidualBlock(256, 256),\n", " nn.Upsample(scale_factor=2),\n", " nn.Conv2d(256, 256, kernel_size=3, padding=1),\n", " VAE_ResidualBlock(256, 128),\n", " VAE_ResidualBlock(128, 128),\n", " VAE_ResidualBlock(128, 128),\n", " nn.GroupNorm(32, 128),\n", " nn.SiLU(),\n", " nn.Conv2d(128, 3, kernel_size=3,padding=1)\n", " )\n", " def forward(self, x):\n", " x/=0.18125\n", " for module in self:\n", " x=module(x)\n", " return x\n", "\n", "\n" ] } ] }