{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5d1f3c2a-title",
   "metadata": {},
   "source": [
    "# Compile Stable Diffusion 2.1 for AWS Neuron\n",
    "\n",
    "Traces the UNet, text encoder, VAE decoder and VAE `post_quant_conv` of `stabilityai/stable-diffusion-2-1` with `torch_neuronx.trace` and saves each compiled module with `torch.jit.save`.\n",
    "\n",
    "Every module is traced with fixed-shape example inputs, so all latent-shaped inputs below are derived from the same `HEIGHT` / `WIDTH` constants in the configuration cell. Run the notebook top to bottom on a fresh kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c66a150-b2f7-4c34-b93a-ca70a0855169",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Neuron setting; set before torch / torch_neuronx are imported\n",
    "os.environ[\"NEURON_FUSE_SOFTMAX\"] = \"1\"\n",
    "\n",
    "import copy\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch_neuronx\n",
    "from IPython.display import clear_output\n",
    "from matplotlib import image as mpimg\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline\n",
    "from diffusers.models.cross_attention import CrossAttention\n",
    "from diffusers.models.unet_2d_condition import UNet2DConditionOutput"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b7e2f10-config",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data type used for tracing and for all example inputs\n",
    "DTYPE = torch.float32\n",
    "\n",
    "# Model ID for the SD version being compiled\n",
    "MODEL_ID = \"stabilityai/stable-diffusion-2-1\"\n",
    "\n",
    "# Target image resolution; every compiled module below is traced for this size\n",
    "HEIGHT = 1080\n",
    "WIDTH = 1920\n",
    "\n",
    "# For saving compiler artifacts\n",
    "COMPILER_WORKDIR_ROOT = f\"sd2_compile_dir_{WIDTH}x{HEIGHT}\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a2c9e41-wrappers-md",
   "metadata": {},
   "source": [
    "## Wrappers and attention patch\n",
    "\n",
    "`UNetWrap` / `NeuronUNet` and `NeuronTextEncoder` adapt the custom return types of the diffusers modules. `get_attention_scores` replaces `CrossAttention.get_attention_scores` with a `torch.bmm`-based version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54c2839b-44b5-4d27-8e83-7cc3d69a53df",
   "metadata": {},
   "outputs": [],
   "source": [
    "class UNetWrap(nn.Module):\n",
    "    \"\"\"Trace-friendly UNet wrapper that returns a plain tuple.\n",
    "\n",
    "    Calling the diffusers UNet with ``return_dict=False`` avoids its custom\n",
    "    ``UNet2DConditionOutput`` return type in the traced graph.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, unet):\n",
    "        super().__init__()\n",
    "        self.unet = unet\n",
    "\n",
    "    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None):\n",
    "        # cross_attention_kwargs is accepted for signature compatibility and ignored.\n",
    "        out_tuple = self.unet(sample, timestep, encoder_hidden_states, return_dict=False)\n",
    "        return out_tuple\n",
    "\n",
    "\n",
    "class NeuronUNet(nn.Module):\n",
    "    \"\"\"Pipeline-facing wrapper that restores the ``UNet2DConditionOutput`` return type.\n",
    "\n",
    "    Mirrors ``config``, ``in_channels`` and ``device`` from the wrapped UNet so the\n",
    "    object can stand in for ``pipe.unet``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, unetwrap):\n",
    "        super().__init__()\n",
    "        self.unetwrap = unetwrap\n",
    "        self.config = unetwrap.unet.config\n",
    "        self.in_channels = unetwrap.unet.in_channels\n",
    "        self.device = unetwrap.unet.device\n",
    "\n",
    "    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None):\n",
    "        # Cast the timestep to DTYPE and broadcast it to shape (batch,), matching the\n",
    "        # float, 1-D timestep example input used when tracing.\n",
    "        timestep = timestep.to(dtype=DTYPE).expand((sample.shape[0],))\n",
    "        sample = self.unetwrap(sample, timestep, encoder_hidden_states)[0]\n",
    "        return UNet2DConditionOutput(sample=sample)\n",
    "\n",
    "\n",
    "class NeuronTextEncoder(nn.Module):\n",
    "    \"\"\"Text-encoder wrapper returning ``[last_hidden_state]`` instead of a model output object.\n",
    "\n",
    "    Mirrors ``config``, ``dtype`` and ``device`` from the wrapped encoder.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, text_encoder):\n",
    "        super().__init__()\n",
    "        self.neuron_text_encoder = text_encoder\n",
    "        self.config = text_encoder.config\n",
    "        self.dtype = text_encoder.dtype\n",
    "        self.device = text_encoder.device\n",
    "\n",
    "    def forward(self, emb, attention_mask = None):\n",
    "        # attention_mask is accepted for signature compatibility and ignored.\n",
    "        return [self.neuron_text_encoder(emb)['last_hidden_state']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e8b1d0-attention-patch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimized attention\n",
    "def get_attention_scores(self, query, key, attention_mask=None):\n",
    "    \"\"\"Drop-in replacement for ``CrossAttention.get_attention_scores``.\n",
    "\n",
    "    Computes ``softmax(self.scale * query @ key^T + attention_mask)`` over the key\n",
    "    axis, like the stock method, but with ``torch.bmm`` instead of ``torch.baddbmm``.\n",
    "    When ``query`` and ``key`` have the same shape, the scores are computed transposed\n",
    "    (``key @ query^T``), normalised over dim 1 and permuted back. This gives the same\n",
    "    probabilities in a different matmul layout, used here for Neuron performance.\n",
    "\n",
    "    Args:\n",
    "        query: 3-D tensor (batch, q_len, head_dim); batch is presumably batch * heads.\n",
    "        key: 3-D tensor (batch, k_len, head_dim).\n",
    "        attention_mask: optional additive mask broadcastable to (batch, q_len, k_len);\n",
    "            ``None`` adds nothing.\n",
    "\n",
    "    Returns:\n",
    "        Attention probabilities of shape (batch, q_len, k_len) in query's original dtype.\n",
    "    \"\"\"\n",
    "    dtype = query.dtype\n",
    "\n",
    "    if self.upcast_attention:\n",
    "        query = query.float()\n",
    "        key = key.float()\n",
    "\n",
    "    if query.size() == key.size():\n",
    "        # Square case: scores are (batch, k_len, q_len), so the softmax runs over\n",
    "        # dim=1 (keys) and the result is permuted back to (batch, q_len, k_len).\n",
    "        attention_scores = custom_badbmm(key, query.transpose(-1, -2), self.scale)\n",
    "        if attention_mask is not None:\n",
    "            attention_scores = attention_scores + attention_mask.transpose(-1, -2)\n",
    "\n",
    "        if self.upcast_softmax:\n",
    "            attention_scores = attention_scores.float()\n",
    "\n",
    "        attention_probs = attention_scores.softmax(dim=1).permute(0, 2, 1)\n",
    "    else:\n",
    "        attention_scores = custom_badbmm(query, key.transpose(-1, -2), self.scale)\n",
    "        if attention_mask is not None:\n",
    "            attention_scores = attention_scores + attention_mask\n",
    "\n",
    "        if self.upcast_softmax:\n",
    "            attention_scores = attention_scores.float()\n",
    "\n",
    "        attention_probs = attention_scores.softmax(dim=-1)\n",
    "\n",
    "    return attention_probs.to(dtype)\n",
    "\n",
    "\n",
    "# In the original baddbmm the bias is all zeros, so only the scale is applied\n",
    "def custom_badbmm(a, b, scale=0.125):\n",
    "    \"\"\"Return ``scale * torch.bmm(a, b)``.\n",
    "\n",
    "    Equivalent to ``torch.baddbmm(input, a, b, beta=0, alpha=scale)``. The default\n",
    "    0.125 is the value previously hard-coded here (``64 ** -0.5``).\n",
    "    \"\"\"\n",
    "    return torch.bmm(a, b) * scale"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f6d2b83-load-md",
   "metadata": {},
   "source": [
    "## Load the pipeline\n",
    "\n",
    "`StableDiffusionPipeline` has no `reshape` method. The compiled resolution is set only by the shapes of the example inputs passed to `torch_neuronx.trace`, so the latent sizes are derived here from `HEIGHT`, `WIDTH` and the VAE scale factor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e3a5c17-load-pipeline",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE)\n",
    "\n",
    "# Latent tensors are the image size divided by the VAE scale factor\n",
    "vae_scale = pipe.vae_scale_factor\n",
    "assert HEIGHT % vae_scale == 0 and WIDTH % vae_scale == 0, \"HEIGHT and WIDTH must be multiples of the VAE scale factor\"\n",
    "LATENT_HEIGHT = HEIGHT // vae_scale\n",
    "LATENT_WIDTH = WIDTH // vae_scale\n",
    "LATENT_CHANNELS = pipe.vae.config.latent_channels\n",
    "\n",
    "# Text-conditioning shapes (sequence length, hidden size) taken from the model configs\n",
    "TEXT_SEQ_LEN = pipe.tokenizer.model_max_length\n",
    "TEXT_HIDDEN_SIZE = pipe.text_encoder.config.hidden_size\n",
    "\n",
    "# Replace original cross-attention module with custom cross-attention module for better performance\n",
    "CrossAttention.get_attention_scores = get_attention_scores\n",
    "\n",
    "# Apply double wrapper to deal with custom return type\n",
    "pipe.unet = NeuronUNet(UNetWrap(pipe.unet))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b91f4a06-unet-md",
   "metadata": {},
   "source": [
    "## Compile UNet and save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1eb8d1b-7b4e-4d55-996e-482e8f18d5e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trace a deep copy of the wrapped UNet rather than the module held by `pipe`\n",
    "unet = copy.deepcopy(pipe.unet.unetwrap)\n",
    "\n",
    "# Example inputs fix the traced shapes: batch 1, latents at LATENT_HEIGHT x LATENT_WIDTH\n",
    "sample_1b = torch.randn([1, pipe.unet.in_channels, LATENT_HEIGHT, LATENT_WIDTH], dtype=DTYPE)\n",
    "timestep_1b = torch.tensor(999, dtype=DTYPE).expand((1,))\n",
    "encoder_hidden_states_1b = torch.randn([1, TEXT_SEQ_LEN, TEXT_HIDDEN_SIZE], dtype=DTYPE)\n",
    "example_inputs = sample_1b, timestep_1b, encoder_hidden_states_1b\n",
    "\n",
    "# Compile unet in DTYPE precision\n",
    "unet_neuron = torch_neuronx.trace(\n",
    "    unet,\n",
    "    example_inputs,\n",
    "    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'unet'),\n",
    "    compiler_args=[\"--model-type=unet-inference\", \"--enable-fast-loading-neuron-binaries\"]\n",
    ")\n",
    "\n",
    "# Enable asynchronous and lazy loading to speed up model load\n",
    "torch_neuronx.async_load(unet_neuron)\n",
    "torch_neuronx.lazy_load(unet_neuron)\n",
    "\n",
    "# Save compiled unet\n",
    "unet_filename = 'unet.pt'\n",
    "torch.jit.save(unet_neuron, unet_filename)\n",
    "\n",
    "# Delete unused objects\n",
    "del unet\n",
    "del unet_neuron"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d27c8e55-text-encoder-md",
   "metadata": {},
   "source": [
    "## Compile text encoder and save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1301369-2008-496f-a52f-65309ab138ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "text_encoder = copy.deepcopy(pipe.text_encoder)\n",
    "\n",
    "# Apply the wrapper to deal with custom return type\n",
    "text_encoder = NeuronTextEncoder(text_encoder)\n",
    "\n",
    "# Compile text encoder\n",
    "# This is used for indexing a lookup table in torch.nn.Embedding,\n",
    "# so using random numbers may give errors (out of range).\n",
    "emb = torch.tensor([[49406, 18376, 525, 7496, 49407, 0, 0, 0, 0, 0,\n",
    "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
    "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
    "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
    "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
    "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
    "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
    "        0, 0, 0, 0, 0, 0, 0]])\n",
    "assert emb.shape[1] == TEXT_SEQ_LEN, \"example token ids must have the tokenizer's max length\"\n",
    "\n",
    "text_encoder_neuron = torch_neuronx.trace(\n",
    "    text_encoder.neuron_text_encoder,\n",
    "    emb,\n",
    "    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'text_encoder'),\n",
    "    compiler_args=[\"--enable-fast-loading-neuron-binaries\"]\n",
    ")\n",
    "\n",
    "# Enable asynchronous loading to speed up model load\n",
    "torch_neuronx.async_load(text_encoder_neuron)\n",
    "\n",
    "# Save the compiled text encoder\n",
    "text_encoder_filename = 'text_encoder.pt'\n",
    "torch.jit.save(text_encoder_neuron, text_encoder_filename)\n",
    "\n",
    "# Delete unused objects\n",
    "del text_encoder\n",
    "del text_encoder_neuron"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4a0b9c2-vae-md",
   "metadata": {},
   "source": [
    "## Compile VAE decoder and `post_quant_conv` and save\n",
    "\n",
    "Both consume the UNet's latent output, so they are traced at the same `LATENT_HEIGHT` x `LATENT_WIDTH` as the UNet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a63d7e28-vae-decoder",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trace a deep copy of the decoder rather than the module held by `pipe`\n",
    "decoder = copy.deepcopy(pipe.vae.decoder)\n",
    "\n",
    "# Compile vae decoder at the same latent resolution as the UNet\n",
    "decoder_in = torch.randn([1, LATENT_CHANNELS, LATENT_HEIGHT, LATENT_WIDTH], dtype=DTYPE)\n",
    "decoder_neuron = torch_neuronx.trace(\n",
    "    decoder,\n",
    "    decoder_in,\n",
    "    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder'),\n",
    "    compiler_args=[\"--enable-fast-loading-neuron-binaries\"]\n",
    ")\n",
    "\n",
    "# Enable asynchronous loading to speed up model load\n",
    "torch_neuronx.async_load(decoder_neuron)\n",
    "\n",
    "# Save the compiled vae decoder\n",
    "decoder_filename = 'vae_decoder.pt'\n",
    "torch.jit.save(decoder_neuron, decoder_filename)\n",
    "\n",
    "# Delete unused objects\n",
    "del decoder\n",
    "del decoder_neuron"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b85e2f39-vae-post-quant-conv",
   "metadata": {},
   "outputs": [],
   "source": [
    "post_quant_conv = copy.deepcopy(pipe.vae.post_quant_conv)\n",
    "\n",
    "# Compile vae post_quant_conv at the same latent resolution as the UNet\n",
    "post_quant_conv_in = torch.randn([1, LATENT_CHANNELS, LATENT_HEIGHT, LATENT_WIDTH], dtype=DTYPE)\n",
    "post_quant_conv_neuron = torch_neuronx.trace(\n",
    "    post_quant_conv,\n",
    "    post_quant_conv_in,\n",
    "    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv'),\n",
    ")\n",
    "\n",
    "# Enable asynchronous loading to speed up model load\n",
    "torch_neuronx.async_load(post_quant_conv_neuron)\n",
    "\n",
    "# Save the compiled vae post_quant_conv\n",
    "post_quant_conv_filename = 'vae_post_quant_conv.pt'\n",
    "torch.jit.save(post_quant_conv_neuron, post_quant_conv_filename)\n",
    "\n",
    "# Delete unused objects\n",
    "del post_quant_conv\n",
    "del post_quant_conv_neuron"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (torch-neuronx)",
   "language": "python",
   "name": "aws_neuron_venv_pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}