{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "id": "OaBIb0WNma-U", "colab": { "base_uri": "https://localhost:8080/", "height": 383 }, "outputId": "33f26e87-f2b5-4de8-dfd8-dca7e8973a01" }, "outputs": [ { "output_type": "error", "ename": "ModuleNotFoundError", "evalue": "No module named 'decoder'", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mfunctional\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mdecoder\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mVAE_AttentionBlock\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mVAE_ResidualBlock\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mVAE_Encoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSequential\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'decoder'", "", "\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n" ], "errorDetails": { "actions": [ { "action": "open_url", "actionText": "Open Examples", "url": "/notebooks/snippets/importing_libraries.ipynb" } ] } } ], "source": [ "import torch\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from decoder import VAE_AttentionBlock, VAE_ResidualBlock\n", "\n", "class VAE_Encoder(nn.Sequential):\n", " def __init__(self):\n", " super().__init__(\n", " nn.Conv2d(3, 128, kernel_size=3, padding=1),\n", " VAE_ResidualBlock(128, 128),\n", " VAE_ResidualBlock(128, 128),\n", " nn.Conv2d(128, 128, kernel_size=3,stride=2, padding=1),\n", " VAE_ResidualBlock(128,256),\n", " VAE_ResidualBlock(256, 256),\n", " nn.Conv2d(3, 128, kernel_size=3,stride=2, padding=1),\n", " VAE_ResidualBlock(256, 512),\n", " VAE_ResidualBlock(512, 512),\n", " VAE_AttentionBlock(512),\n", " VAE_ResidualBlock(512, 512),\n", " nn.GroupNorm(32, 512),\n", " nn.SiLU(),\n", " nn.Conv2d(512, 8, kernel_size=3, padding=1),\n", " nn.Conv2d(8,8,kernel_size=3, padding=0),\n", " )\n", " def forward(self, x , noise):\n", " for module in self:\n", " if getattr(module, \"stride\", None) ==(2,2):\n", " x= F.pad(x, (0,1,0,1))\n", " x= module(x)\n", " mean , log_variance = torch.chunk(x, 2, dim=1)\n", " log_variance= torch.clamp(log_variance, -20, 30)\n", " variance= log_variance.exp()\n", " stdev = variance.sqrt()\n", " x= stdev*noise +mean\n", " x*=0.18215\n", " return x\n" ] } ] }