In [None]:
!pip install diffusers==0.14.0 transformers==4.26.1 accelerate==0.16.0 safetensors==0.3.1 matplotlib

In [None]:
import os
# This flag is set before torch_neuronx is imported below; keep that ordering.
# NOTE(review): presumably the Neuron stack reads it to enable softmax fusion
# and needs it in place before import/compile -- confirm against Neuron docs.
os.environ["NEURON_FUSE_SOFTMAX"] = "1"

import torch
import torch.nn as nn
import torch_neuronx
import numpy as np

from matplotlib import pyplot as plt
from matplotlib import image as mpimg
import time
import copy

from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.cross_attention import CrossAttention

# Define datatype
# DTYPE is used when loading the pipeline weights (torch_dtype=DTYPE) and when
# casting the timestep inside NeuronUNet.forward, so both stay in agreement.
DTYPE = torch.float32

In [None]:
class UNetWrap(nn.Module):
    """Adapter that always calls the wrapped UNet with ``return_dict=False``.

    The wrapped callable therefore returns its plain (tuple-style) output
    instead of an output dataclass.
    """

    def __init__(self, unet):
        """Store the UNet (or any compatible callable) to delegate to."""
        super().__init__()
        self.unet = unet

    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None):
        """Run the wrapped UNet and return its raw output unchanged.

        ``cross_attention_kwargs`` is accepted so the signature matches what
        callers pass, but it is intentionally not forwarded to the UNet.
        """
        return self.unet(
            sample,
            timestep,
            encoder_hidden_states,
            return_dict=False,
        )

class NeuronUNet(nn.Module):
    """Pipeline-facing UNet facade around a ``UNetWrap``-style module.

    It mirrors the ``config``, ``in_channels`` and ``device`` attributes of the
    original UNet and re-packages the wrapped module's first output into a
    ``UNet2DConditionOutput``.
    """

    def __init__(self, unetwrap):
        """Keep the wrapper and copy metadata from ``unetwrap.unet``."""
        super().__init__()
        self.unetwrap = unetwrap
        # Snapshot these now: ``unetwrap`` may later be swapped for a module
        # that no longer exposes the original ``.unet``.
        original_unet = unetwrap.unet
        self.config = original_unet.config
        self.in_channels = original_unet.in_channels
        self.device = original_unet.device

    def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None):
        """Run the wrapped UNet and return a ``UNet2DConditionOutput``.

        The timestep is cast to ``DTYPE`` and expanded to one entry per item
        in the batch (``sample.shape[0]``) before the call.
        ``cross_attention_kwargs`` is accepted but not used.
        """
        batch_size = sample.shape[0]
        per_sample_timestep = timestep.to(dtype=DTYPE).expand((batch_size,))
        outputs = self.unetwrap(sample, per_sample_timestep, encoder_hidden_states)
        return UNet2DConditionOutput(sample=outputs[0])

class NeuronTextEncoder(nn.Module):
    """Pipeline-facing facade around a text encoder.

    It mirrors the encoder's ``config``, ``dtype`` and ``device`` attributes
    and returns the encoder's ``'last_hidden_state'`` wrapped in a one-element
    list, so callers can index the result with ``[0]``.
    """

    def __init__(self, text_encoder):
        """Keep the encoder and copy its metadata attributes."""
        super().__init__()
        self.neuron_text_encoder = text_encoder
        self.config = text_encoder.config
        self.dtype = text_encoder.dtype
        self.device = text_encoder.device

    def forward(self, emb, attention_mask = None):
        """Encode ``emb``; ``attention_mask`` is accepted but not used."""
        encoder_output = self.neuron_text_encoder(emb)
        last_hidden_state = encoder_output['last_hidden_state']
        return [last_hidden_state]
# Fallback softmax scale used when the attention module exposes no ``scale``.
# 0.125 == 64 ** -0.5, i.e. the usual 1/sqrt(head_dim) for a head dimension
# of 64 (presumably what the model used here has -- confirm).
DEFAULT_ATTENTION_SCALE = 0.125


# Optimized attention
def get_attention_scores(self, query, key, attn_mask):
    """Compute softmax attention probabilities for 3-D ``query`` / ``key``.

    Intended to be used as a method of an attention module: ``self`` must
    provide the boolean flags ``upcast_attention`` and ``upcast_softmax``.
    If ``self`` has a non-None ``scale`` attribute it is used as the score
    scale; otherwise ``DEFAULT_ATTENTION_SCALE`` (0.125) is used, which is the
    value that was previously hard-coded.

    Args:
        self: attention module (see above).
        query: tensor of shape (batch, query_len, dim).
        key: tensor of shape (batch, key_len, dim).
        attn_mask: accepted for signature compatibility only; it is NOT
            applied to the scores.

    Returns:
        Tensor of shape (batch, query_len, key_len), normalised over the key
        axis and cast back to the dtype ``query`` had on entry.
    """
    dtype = query.dtype

    scale = getattr(self, "scale", None)
    if scale is None:
        scale = DEFAULT_ATTENTION_SCALE

    if self.upcast_attention:
        query = query.float()
        key = key.float()

    # Check for square matmuls
    if query.size() == key.size():
        # Same-shaped operands: compute the transposed score matrix
        # (batch, key_len, query_len), normalise over the key axis (dim=1),
        # then permute back to (batch, query_len, key_len).
        attention_scores = custom_badbmm(key, query.transpose(-1, -2), scale)

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=1).permute(0, 2, 1)
    else:
        attention_scores = custom_badbmm(query, key.transpose(-1, -2), scale)

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)

    return attention_probs.to(dtype)


# In the original badbmm the bias is all zeros, so only apply scale
def custom_badbmm(a, b, scale=DEFAULT_ATTENTION_SCALE):
    """Return ``torch.bmm(a, b) * scale``.

    Args:
        a: 3-D tensor (batch, n, m).
        b: 3-D tensor (batch, m, p).
        scale: multiplier applied to the product; defaults to 0.125 so
            existing two-argument calls behave exactly as before.
    """
    return torch.bmm(a, b) * scale

In [None]:
# Pretrained pipeline to load, and the TorchScript files that replace parts of
# it below. The file names are relative paths, so they are resolved against
# the current working directory.
# NOTE(review): presumably these .pt files come from a separate compilation
# step not shown in this section -- confirm they exist before running.
model_id = "stabilityai/stable-diffusion-2-1"
text_encoder_filename = 'text_encoder.pt'
decoder_filename = 'vae_decoder.pt'
unet_filename = 'unet.pt'
post_quant_conv_filename = 'vae_post_quant_conv.pt'

# Build the stock pipeline in DTYPE, then replace its scheduler with a
# DPMSolverMultistepScheduler created from the existing scheduler's config.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=DTYPE)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# Load the compiled UNet onto two neuron cores.
# Order matters: NeuronUNet's constructor copies config / in_channels / device
# from the original UNet, so the wrapper is built first and only afterwards is
# its inner `unetwrap` replaced by the module loaded from disk.
pipe.unet = NeuronUNet(UNetWrap(pipe.unet))
device_ids = [0,1]
# NOTE(review): set_dynamic_batching=False presumably means inputs are not
# re-split to arbitrary batch sizes across the cores -- confirm in the
# torch_neuronx.DataParallel documentation.
pipe.unet.unetwrap = torch_neuronx.DataParallel(torch.jit.load(unet_filename), device_ids, set_dynamic_batching=False)

# Load other compiled models onto a single neuron core.
# Same pattern for the text encoder: wrap first (to copy config / dtype /
# device), then replace the wrapped module with the one loaded from disk.
pipe.text_encoder = NeuronTextEncoder(pipe.text_encoder)
pipe.text_encoder.neuron_text_encoder = torch.jit.load(text_encoder_filename)
# The VAE decoder and post-quant conv are replaced directly, without a wrapper.
pipe.vae.decoder = torch.jit.load(decoder_filename)
pipe.vae.post_quant_conv = torch.jit.load(post_quant_conv_filename)