EvanTHU commited on
Commit
54e887f
·
verified ·
1 Parent(s): d5e8fc2

Update models/unet.py

Browse files
Files changed (1) hide show
  1. models/unet.py +1 -1
models/unet.py CHANGED
@@ -808,7 +808,7 @@ class MotionCLR(nn.Module):
808
  # text encoder
809
  self.embed_text = nn.Linear(clip_dim, text_latent_dim)
810
  self.clip_version = clip_version
811
- self.clip_model = self.load_and_freeze_clip(clip_version)
812
  textTransEncoderLayer = nn.TransformerEncoderLayer(
813
  d_model=text_latent_dim,
814
  nhead=text_num_heads,
 
808
  # text encoder
809
  self.embed_text = nn.Linear(clip_dim, text_latent_dim)
810
  self.clip_version = clip_version
811
+ self.clip_model = self.load_and_freeze_clip(clip_version).cuda()
812
  textTransEncoderLayer = nn.TransformerEncoderLayer(
813
  d_model=text_latent_dim,
814
  nhead=text_num_heads,