robin-courant commited on
Commit
c008abb
1 Parent(s): c395a63

Update utils/common_viz.py

Browse files
Files changed (1) hide show
  1. utils/common_viz.py +3 -1
utils/common_viz.py CHANGED
@@ -88,8 +88,10 @@ def get_batch(
88
  batch = collate_fn([to_device(raw_batch, device)])
89
 
90
  # Encode text
 
91
  caption_seq, caption_tokens = encode_text([prompt], clip_model, None, device)
92
-
 
93
  if seq_feat:
94
  caption_feat = caption_seq[0]
95
  caption_feat = F.pad(caption_feat, (0, 0, 0, 77 - caption_feat.shape[0]))
 
88
  batch = collate_fn([to_device(raw_batch, device)])
89
 
90
  # Encode text
91
+ clip_model.to(device)
92
  caption_seq, caption_tokens = encode_text([prompt], clip_model, None, device)
93
+ print(caption_seq[0].device)
94
+
95
  if seq_feat:
96
  caption_feat = caption_seq[0]
97
  caption_feat = F.pad(caption_feat, (0, 0, 0, 77 - caption_feat.shape[0]))