{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07b2bef9-bbaf-41b8-b960-7ac373ff3e8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the packages are installed into the running kernel's environment.\n",
    "%pip install diffusers==0.14.0 transformers==4.26.1 accelerate==0.16.0 safetensors==0.3.1 matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ebecb44-f796-4c76-8385-888a2f46fd6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# NOTE(review): set before torch_neuronx is imported; presumably the flag must already be\n",
    "# in the environment at that point - confirm against the Neuron documentation.\n",
    "os.environ[\"NEURON_FUSE_SOFTMAX\"] = \"1\"\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch_neuronx\n",
    "import numpy as np\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from matplotlib import image as mpimg\n",
    "import time\n",
    "import copy\n",
    "\n",
    "from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler\n",
    "from diffusers.models.unet_2d_condition import UNet2DConditionOutput\n",
    "from diffusers.models.cross_attention import CrossAttention\n",
    "\n",
    "# Datatype used when loading the pipeline weights and when casting the UNet timestep.\n",
    "DTYPE = torch.float32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9950025f-877a-4c11-b30e-9c32f0825e94",
   "metadata": {},
   "outputs": [],
   "source": [
    "class UNetWrap(nn.Module):\n",
    "    \"\"\"Wrapper that calls the underlying UNet with ``return_dict=False``.\"\"\"\n",
    "\n",
    "    def __init__(self, unet):\n",
    "        super().__init__()\n",
    "        self.unet = unet\n",
    "\n",
    "    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None):\n",
    "        \"\"\"Return the UNet's non-dict output for the given inputs.\n",
    "\n",
    "        ``cross_attention_kwargs`` is accepted for call compatibility but is not\n",
    "        forwarded to the UNet.\n",
    "        \"\"\"\n",
    "        return self.unet(sample, timestep, encoder_hidden_states, return_dict=False)\n",
    "\n",
    "\n",
    "class NeuronUNet(nn.Module):\n",
    "    \"\"\"Adapter around a ``UNetWrap``.\n",
    "\n",
    "    Mirrors ``config``, ``in_channels`` and ``device`` from the wrapped UNet and\n",
    "    returns a ``UNet2DConditionOutput`` from ``forward``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, unetwrap):\n",
    "        super().__init__()\n",
    "        self.unetwrap = unetwrap\n",
    "        self.config = unetwrap.unet.config\n",
    "        self.in_channels = unetwrap.unet.in_channels\n",
    "        self.device = unetwrap.unet.device\n",
    "\n",
    "    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None):\n",
    "        \"\"\"Run the wrapped UNet and package element 0 of its output.\n",
    "\n",
    "        ``cross_attention_kwargs`` is accepted but not used.\n",
    "        \"\"\"\n",
    "        # Cast the timestep to DTYPE and expand it to length sample.shape[0].\n",
    "        expanded_timestep = timestep.to(dtype=DTYPE).expand((sample.shape[0],))\n",
    "        sample = self.unetwrap(sample, expanded_timestep, encoder_hidden_states)[0]\n",
    "        return UNet2DConditionOutput(sample=sample)\n",
    "\n",
    "\n",
    "class NeuronTextEncoder(nn.Module):\n",
    "    \"\"\"Adapter around a text encoder that mirrors its ``config``, ``dtype`` and ``device``.\"\"\"\n",
    "\n",
    "    def __init__(self, text_encoder):\n",
    "        super().__init__()\n",
    "        self.neuron_text_encoder = text_encoder\n",
    "        self.config = text_encoder.config\n",
    "        self.dtype = text_encoder.dtype\n",
    "        self.device = text_encoder.device\n",
    "\n",
    "    def forward(self, emb, attention_mask=None):\n",
    "        \"\"\"Return a one-element list holding the encoder's ``last_hidden_state``.\n",
    "\n",
    "        ``attention_mask`` is accepted but not passed to the wrapped encoder.\n",
    "        \"\"\"\n",
    "        return [self.neuron_text_encoder(emb)['last_hidden_state']]\n",
    "\n",
    "\n",
    "# In the original badbmm the bias is all zeros, so only apply scale\n",
    "def custom_badbmm(a, b, scale=0.125):\n",
    "    \"\"\"Return ``torch.bmm(a, b)`` multiplied by ``scale``.\n",
    "\n",
    "    The default of 0.125 equals 64 ** -0.5.\n",
    "    NOTE(review): presumably this matches an attention head dimension of 64 - confirm\n",
    "    before reusing with a different model.\n",
    "    \"\"\"\n",
    "    return torch.bmm(a, b) * scale\n",
    "\n",
    "\n",
    "# Optimized attention\n",
    "def get_attention_scores(self, query, key, attn_mask=None):\n",
    "    \"\"\"Compute scaled softmax attention probabilities from ``query`` and ``key``.\n",
    "\n",
    "    When ``query`` and ``key`` have the same size, the scores are computed as\n",
    "    key @ query^T, softmax is taken over dim 1 and the result is permuted with\n",
    "    (0, 2, 1). Otherwise the scores are query @ key^T with softmax over the last\n",
    "    dim. The result is cast back to the dtype ``query`` had on entry.\n",
    "\n",
    "    ``attn_mask`` is accepted for signature compatibility and is never read.\n",
    "    ``self`` must provide ``upcast_attention`` and ``upcast_softmax``.\n",
    "\n",
    "    NOTE(review): this function is not attached to ``CrossAttention`` in any cell of\n",
    "    this notebook; the UNet used below is loaded from an already compiled file.\n",
    "    \"\"\"\n",
    "    dtype = query.dtype\n",
    "\n",
    "    if self.upcast_attention:\n",
    "        query = query.float()\n",
    "        key = key.float()\n",
    "\n",
    "    # Check for square matmuls\n",
    "    is_square = query.size() == key.size()\n",
    "\n",
    "    if is_square:\n",
    "        attention_scores = custom_badbmm(key, query.transpose(-1, -2))\n",
    "    else:\n",
    "        attention_scores = custom_badbmm(query, key.transpose(-1, -2))\n",
    "\n",
    "    if self.upcast_softmax:\n",
    "        attention_scores = attention_scores.float()\n",
    "\n",
    "    if is_square:\n",
    "        # Scores were built transposed, so normalise over dim 1 and swap the last two dims back.\n",
    "        attention_probs = attention_scores.softmax(dim=1).permute(0, 2, 1)\n",
    "    else:\n",
    "        attention_probs = attention_scores.softmax(dim=-1)\n",
    "\n",
    "    return attention_probs.to(dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffc64d14-f48c-488c-b60a-36e3ebfdab83",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_id = \"stabilityai/stable-diffusion-2-1\"\n",
    "text_encoder_filename = 'text_encoder.pt'\n",
    "decoder_filename = 'vae_decoder.pt'\n",
    "unet_filename = 'unet.pt'\n",
    "post_quant_conv_filename = 'vae_post_quant_conv.pt'\n",
    "\n",
    "# Fail fast, before the pipeline is fetched, if a compiled model file is missing.\n",
    "compiled_files = (text_encoder_filename, decoder_filename, unet_filename, post_quant_conv_filename)\n",
    "missing_files = [name for name in compiled_files if not os.path.isfile(name)]\n",
    "if missing_files:\n",
    "    raise FileNotFoundError(f'Compiled model file(s) not found: {missing_files}')\n",
    "\n",
    "pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=DTYPE)\n",
    "pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n",
    "\n",
    "# Load the compiled UNet onto two neuron cores.\n",
    "pipe.unet = NeuronUNet(UNetWrap(pipe.unet))\n",
    "device_ids = [0,1]\n",
    "pipe.unet.unetwrap = torch_neuronx.DataParallel(torch.jit.load(unet_filename), device_ids, set_dynamic_batching=False)\n",
    "\n",
    "# Load other compiled models onto a single neuron core.\n",
    "pipe.text_encoder = NeuronTextEncoder(pipe.text_encoder)\n",
    "pipe.text_encoder.neuron_text_encoder = torch.jit.load(text_encoder_filename)\n",
    "pipe.vae.decoder = torch.jit.load(decoder_filename)\n",
    "pipe.vae.post_quant_conv = torch.jit.load(post_quant_conv_filename)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (torch-neuronx)",
   "language": "python",
   "name": "aws_neuron_venv_pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}