vumichien commited on
Commit
d27d313
1 Parent(s): deb639a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -66,7 +66,10 @@ is_cuda = torch.cuda.is_available()
66
  device = torch.device("cuda" if is_cuda else "cpu")
67
  print(device)
68
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False, download_root='./') # Must set jit=False for training
69
- clip.model.convert_weights(clip_model) # Actually this line is unnecessary since clip by default already on float16
 
 
 
70
  clip_model.eval()
71
  for p in clip_model.parameters():
72
  p.requires_grad = False
@@ -101,15 +104,16 @@ print('loading transformer checkpoint from {}'.format(args.resume_trans))
101
  ckpt = torch.load(args.resume_trans, map_location='cpu')
102
  trans_encoder.load_state_dict(ckpt['trans'], strict=True)
103
  trans_encoder.eval()
 
104
  mean = torch.from_numpy(np.load('./checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta/mean.npy'))
105
  std = torch.from_numpy(np.load('./checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta/std.npy'))
 
106
  if is_cuda:
107
  net.cuda()
108
  trans_encoder.cuda()
109
  mean = mean.cuda()
110
  std = std.cuda()
111
 
112
-
113
  def render(motions, device_id=0, name='test_vis'):
114
  frames, njoints, nfeats = motions.shape
115
  MINS = motions.min(axis=0).min(axis=0)
 
66
  device = torch.device("cuda" if is_cuda else "cpu")
67
  print(device)
68
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False, download_root='./') # Must set jit=False for training
69
+
70
+ if is_cuda:
71
+ clip.model.convert_weights(clip_model)
72
+
73
  clip_model.eval()
74
  for p in clip_model.parameters():
75
  p.requires_grad = False
 
104
  ckpt = torch.load(args.resume_trans, map_location='cpu')
105
  trans_encoder.load_state_dict(ckpt['trans'], strict=True)
106
  trans_encoder.eval()
107
+
108
  mean = torch.from_numpy(np.load('./checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta/mean.npy'))
109
  std = torch.from_numpy(np.load('./checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta/std.npy'))
110
+
111
  if is_cuda:
112
  net.cuda()
113
  trans_encoder.cuda()
114
  mean = mean.cuda()
115
  std = std.cuda()
116
 
 
117
  def render(motions, device_id=0, name='test_vis'):
118
  frames, njoints, nfeats = motions.shape
119
  MINS = motions.min(axis=0).min(axis=0)