Update models/unet.py
models/unet.py  +21 -1
@@ -9,6 +9,25 @@ from einops.layers.torch import Rearrange
 from einops import rearrange
 import matplotlib.pyplot as plt
 import os
+import torch.nn as nn
+
+# Custom LayerNorm class to handle fp16
+class CustomLayerNorm(nn.LayerNorm):
+    def forward(self, x: torch.Tensor):
+        if self.weight.dtype == torch.float32:
+            orig_type = x.dtype
+            ret = super().forward(x.type(torch.float32))
+            return ret.type(orig_type)
+        else:
+            return super().forward(x)
+
+# Function to replace LayerNorm in CLIP model with CustomLayerNorm
+def replace_layer_norm(model):
+    for name, module in model.named_children():
+        if isinstance(module, nn.LayerNorm):
+            setattr(model, name, CustomLayerNorm(module.normalized_shape, elementwise_affine=module.elementwise_affine))
+        else:
+            replace_layer_norm(module)  # Recursively apply to all submodules
 
 
 MONITOR_ATTN = []
@@ -808,7 +827,8 @@ class MotionCLR(nn.Module):
         # text encoder
         self.embed_text = nn.Linear(clip_dim, text_latent_dim)
         self.clip_version = clip_version
-        self.clip_model = self.load_and_freeze_clip(clip_version)
+        self.clip_model = self.load_and_freeze_clip(clip_version)
+        replace_layer_norm(self.clip_model)
         textTransEncoderLayer = nn.TransformerEncoderLayer(
             d_model=text_latent_dim,
             nhead=text_num_heads,
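What the change does: CustomLayerNorm keeps its parameters in float32 and, when it receives lower-precision input (e.g. float16), upcasts the input for the normalization and casts the result back to the original dtype. replace_layer_norm walks a module tree recursively and swaps every nn.LayerNorm for a CustomLayerNorm built from the same normalized_shape and elementwise_affine settings; calling it on self.clip_model right after load_and_freeze_clip applies this to the frozen CLIP text encoder. Below is a minimal sketch of that behaviour, not part of the patch: TextStub is a hypothetical stand-in for the CLIP model, and the import assumes the two helpers live in models/unet.py as shown above.

import torch
import torch.nn as nn

from models.unet import CustomLayerNorm, replace_layer_norm  # assumed import path


class TextStub(nn.Module):
    # Hypothetical stand-in for a text encoder: one nested LayerNorm and one top-level LayerNorm.
    def __init__(self, dim=512):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim))
        self.ln_final = nn.LayerNorm(dim)


stub = TextStub()
replace_layer_norm(stub)

# Both the top-level and the nested LayerNorm are swapped, since the recursion reaches submodules.
assert isinstance(stub.ln_final, CustomLayerNorm)
assert isinstance(stub.block[1], CustomLayerNorm)

# fp16 activations pass through: the weights stay fp32, the input is upcast for the
# normalization, and the output is cast back to the input dtype.
x = torch.randn(2, 77, 512).half()
out = stub.ln_final(x)
assert stub.ln_final.weight.dtype == torch.float32
assert out.dtype == torch.float16

One design point worth noting: replace_layer_norm constructs a fresh CustomLayerNorm rather than wrapping the existing module, so only normalized_shape and elementwise_affine carry over from the layer it replaces.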