EvanTHU committed on
Commit 7e9ae62
1 Parent(s): 721b9c2

Update models/unet.py

Files changed (1)
models/unet.py: +21 -1
models/unet.py CHANGED
@@ -9,6 +9,25 @@ from einops.layers.torch import Rearrange
from einops import rearrange
import matplotlib.pyplot as plt
import os
+import torch.nn as nn
+
+# Custom LayerNorm class to handle fp16
+class CustomLayerNorm(nn.LayerNorm):
+    def forward(self, x: torch.Tensor):
+        if self.weight.dtype == torch.float32:
+            orig_type = x.dtype
+            ret = super().forward(x.type(torch.float32))
+            return ret.type(orig_type)
+        else:
+            return super().forward(x)
+
+# Function to replace LayerNorm in CLIP model with CustomLayerNorm
+def replace_layer_norm(model):
+    for name, module in model.named_children():
+        if isinstance(module, nn.LayerNorm):
+            setattr(model, name, CustomLayerNorm(module.normalized_shape, elementwise_affine=module.elementwise_affine))
+        else:
+            replace_layer_norm(module)  # Recursively apply to all submodules


MONITOR_ATTN = []

@@ -808,7 +827,8 @@ class MotionCLR(nn.Module):
        # text encoder
        self.embed_text = nn.Linear(clip_dim, text_latent_dim)
        self.clip_version = clip_version
-        self.clip_model = self.load_and_freeze_clip(clip_version).cuda()
+        self.clip_model = self.load_and_freeze_clip(clip_version)
+        replace_layer_norm(self.clip_model)
        textTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=text_latent_dim,
            nhead=text_num_heads,
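
The helpers added in the first hunk walk the loaded CLIP model and swap every nn.LayerNorm for a subclass that, when its own parameters are fp32, upcasts the input to fp32 for the normalization and casts the result back to the original dtype. Below is a minimal, self-contained sketch of that pattern, assuming only PyTorch; the TextBlock stand-in, the type hints, the eps/state-dict copy, and the __main__ demo are illustrative additions, not part of this commit.

import torch
import torch.nn as nn


class CustomLayerNorm(nn.LayerNorm):
    # Run the normalization in fp32 when this layer's parameters are fp32,
    # casting the (possibly fp16) input up and the result back down.
    def forward(self, x: torch.Tensor):
        if self.weight.dtype == torch.float32:
            orig_type = x.dtype
            ret = super().forward(x.type(torch.float32))
            return ret.type(orig_type)
        return super().forward(x)


def replace_layer_norm(model: nn.Module) -> None:
    # Recursively swap every nn.LayerNorm in the module tree for CustomLayerNorm.
    for name, module in model.named_children():
        if isinstance(module, nn.LayerNorm):
            new_ln = CustomLayerNorm(
                module.normalized_shape,
                eps=module.eps,
                elementwise_affine=module.elementwise_affine,
            )
            # Copy the original affine parameters into the fp32 replacement
            # (illustrative refinement; the committed function builds the
            # replacement with freshly initialized parameters).
            new_ln.load_state_dict(module.state_dict())
            setattr(model, name, new_ln)
        else:
            replace_layer_norm(module)


class TextBlock(nn.Module):
    # Toy stand-in for a CLIP transformer block (assumed for this demo only).
    def __init__(self, dim: int = 512):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(self.ln(x))


if __name__ == "__main__":
    block = TextBlock().half()          # simulate a model whose weights are stored in fp16
    replace_layer_norm(block)
    print(type(block.ln).__name__)      # CustomLayerNorm
    print(block.ln.weight.dtype)        # torch.float32
    x = torch.randn(2, 77, 512, dtype=torch.float16)
    print(block.ln(x).dtype)            # torch.float16: output dtype is preserved

One design note: the committed replace_layer_norm constructs each CustomLayerNorm with freshly initialized affine parameters, whereas the load_state_dict call in the sketch above carries the original layer's weight and bias over into the fp32 replacement.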