rodrigomasini committed on
Commit
255d880
1 Parent(s): 34615b5

Update app.py

Files changed (1)
  1. app.py +0 -97
app.py CHANGED
@@ -88,103 +88,6 @@ class AttnProcessor(nn.Module):
         return hidden_states


-class AttnProcessor2_0(torch.nn.Module):
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-    def __init__(
-        self,
-        hidden_size=None,
-        cross_attention_dim=None,
-    ):
-        super().__init__()
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
-def is_torch2_available():
-    return hasattr(F, "scaled_dot_product_attention")
-
-if is_torch2_available():
-    from utils.gradio_utils import \
-        AttnProcessor2_0 as AttnProcessor
-    # from utils.gradio_utils import SpatialAttnProcessor2_0
-else:
-    from utils.gradio_utils import AttnProcessor

 import diffusers
 from diffusers import StableDiffusionXLPipeline
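For context, the removed AttnProcessor2_0 wrapped PyTorch 2.0's fused F.scaled_dot_product_attention. Below is a minimal, self-contained sketch of that attention path, with made-up tensor sizes; it is illustrative only and not part of this commit or repository: project to query/key/value, split into heads, run fused scaled dot-product attention, then merge the heads back.

```python
import torch
import torch.nn.functional as F  # requires PyTorch >= 2.0 for scaled_dot_product_attention

# Hypothetical sizes for illustration only.
batch_size, seq_len, hidden_size, heads = 2, 16, 64, 8
head_dim = hidden_size // heads

hidden_states = torch.randn(batch_size, seq_len, hidden_size)
to_q = torch.nn.Linear(hidden_size, hidden_size)
to_k = torch.nn.Linear(hidden_size, hidden_size)
to_v = torch.nn.Linear(hidden_size, hidden_size)

def split_heads(x):
    # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
    return x.view(batch_size, -1, heads, head_dim).transpose(1, 2)

query = split_heads(to_q(hidden_states))
key = split_heads(to_k(hidden_states))
value = split_heads(to_v(hidden_states))

# Fused attention; output shape is (batch, heads, seq, head_dim).
out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

# Merge heads back to (batch, seq, hidden).
out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
print(out.shape)  # torch.Size([2, 16, 64])
```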
 