File size: 4,694 Bytes
6817d94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "OaBIb0WNma-U",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 383
},
"outputId": "33f26e87-f2b5-4de8-dfd8-dca7e8973a01"
},
"outputs": [
{
"output_type": "error",
"ename": "ModuleNotFoundError",
"evalue": "No module named 'decoder'",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-1-140de5a90328>\u001b[0m in \u001b[0;36m<cell line: 4>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mfunctional\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mdecoder\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mVAE_AttentionBlock\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mVAE_ResidualBlock\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mVAE_Encoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSequential\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'decoder'",
"",
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n"
],
"errorDetails": {
"actions": [
{
"action": "open_url",
"actionText": "Open Examples",
"url": "/notebooks/snippets/importing_libraries.ipynb"
}
]
}
}
],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"from decoder import VAE_AttentionBlock, VAE_ResidualBlock\n",
"\n",
"class VAE_Encoder(nn.Sequential):\n",
"    \"\"\"Convolutional VAE encoder that returns a sampled, scaled latent.\n",
"\n",
"    The layers are registered through ``nn.Sequential``; ``forward`` is\n",
"    overridden so that it can pad before the stride-2 convolutions and draw a\n",
"    sample from the predicted Gaussian instead of returning the raw output.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__(\n",
"            # Stem: 3 input channels -> 128 feature channels, same spatial size.\n",
"            nn.Conv2d(3, 128, kernel_size=3, padding=1),\n",
"            VAE_ResidualBlock(128, 128),\n",
"            VAE_ResidualBlock(128, 128),\n",
"            # First stride-2 down-sampling convolution.\n",
"            # NOTE(review): with padding=1 here plus the extra right/bottom pad\n",
"            # added in forward(), an input of height H comes out as H // 2 + 1,\n",
"            # not H // 2 -- confirm this is intended.\n",
"            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),\n",
"            VAE_ResidualBlock(128, 256),\n",
"            VAE_ResidualBlock(256, 256),\n",
"            # Second stride-2 down-sampling convolution. It must take the 256\n",
"            # channels produced above and hand 256 channels to the next block\n",
"            # (this was Conv2d(3, 128, ...), which fails on a 256-channel input).\n",
"            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),\n",
"            VAE_ResidualBlock(256, 512),\n",
"            VAE_ResidualBlock(512, 512),\n",
"            VAE_AttentionBlock(512),\n",
"            VAE_ResidualBlock(512, 512),\n",
"            nn.GroupNorm(32, 512),\n",
"            nn.SiLU(),\n",
"            # Project to 8 channels; forward() splits them into two 4-channel halves.\n",
"            nn.Conv2d(512, 8, kernel_size=3, padding=1),\n",
"            # NOTE(review): a 3x3 kernel with padding=0 trims 2 from each spatial\n",
"            # dimension -- confirm this is intended.\n",
"            nn.Conv2d(8, 8, kernel_size=3, padding=0),\n",
"        )\n",
"\n",
"    def forward(self, x, noise):\n",
"        \"\"\"Encode ``x`` and sample a latent using the supplied ``noise``.\n",
"\n",
"        Args:\n",
"            x: Input tensor. Assumed to be (batch, 3, height, width): the first\n",
"                convolution expects 3 channels and the split below is on dim 1\n",
"                -- TODO confirm against the caller.\n",
"            noise: Tensor that broadcasts against the 4-channel mean / standard\n",
"                deviation; it is multiplied by the standard deviation.\n",
"\n",
"        Returns:\n",
"            ``(mean + stdev * noise) * 0.18215``.\n",
"        \"\"\"\n",
"        for module in self:\n",
"            # Only layers exposing stride == (2, 2) get the extra padding:\n",
"            # one column on the right and one row at the bottom.\n",
"            if getattr(module, \"stride\", None) == (2, 2):\n",
"                x = F.pad(x, (0, 1, 0, 1))\n",
"            x = module(x)\n",
"\n",
"        # The 8 output channels are two equal halves along dim 1.\n",
"        mean, log_variance = torch.chunk(x, 2, dim=1)\n",
"        # Bound the log-variance before exponentiating.\n",
"        log_variance = torch.clamp(log_variance, -20, 30)\n",
"        variance = log_variance.exp()\n",
"        stdev = variance.sqrt()\n",
"\n",
"        # Sample: shift and scale the caller-provided noise.\n",
"        x = stdev * noise + mean\n",
"        # Fixed output scaling constant.\n",
"        # NOTE(review): presumably the Stable Diffusion latent scale -- confirm.\n",
"        x *= 0.18215\n",
"        return x\n"
]
}
]
} |