EvanTHU commited on
Commit
5949188
·
verified ·
1 Parent(s): 929cd61

Update models/unet.py

Browse files
Files changed (1) hide show
  1. models/unet.py +4 -0
models/unet.py CHANGED
@@ -864,6 +864,10 @@ class MotionCLR(nn.Module):
864
  self.unet = self.unet.cuda()
865
 
866
  def encode_text(self, raw_text, device):
 
 
 
 
867
  with torch.no_grad():
868
  texts = clip.tokenize(raw_text, truncate=True).to(
869
  device
 
864
  self.unet = self.unet.cuda()
865
 
866
  def encode_text(self, raw_text, device):
867
+ self.clip_model.token_embedding = self.clip_model.token_embedding.to(device)
868
+ self.clip_model.positional_embedding = self.clip_model.positional_embedding.to(device)
869
+ self.clip_model.transformer = self.clip_model.transformer.to(device)
870
+ self.clip_model.ln_final = self.clip_model.ln_final.to(device)
871
  with torch.no_grad():
872
  texts = clip.tokenize(raw_text, truncate=True).to(
873
  device