pcuenq HF staff commited on
Commit
4b5364c
1 Parent(s): d4aa90a

Experimental: keep prior and model in float32

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -120,14 +120,14 @@ vqmodel.load_state_dict(torch.load(vqgan_path, map_location=device))
120
  vqmodel.eval().requires_grad_(False)
121
 
122
  prior_path = hf_hub_download(repo_id=model_repo, filename=prior_file)
123
- prior = PriorModel().to(device).half()
124
  prior.load_state_dict(torch.load(prior_path, map_location=device))
125
  prior.eval().requires_grad_(False)
126
 
127
  model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
128
  model = Paella(byt5_embd=2560)
129
  model.load_state_dict(torch.load(model_path, map_location=device))
130
- model.eval().requires_grad_().half()
131
  replace_attention_layers(model)
132
  model.to(device)
133
 
 
120
  vqmodel.eval().requires_grad_(False)
121
 
122
  prior_path = hf_hub_download(repo_id=model_repo, filename=prior_file)
123
+ prior = PriorModel().to(device)#.half()
124
  prior.load_state_dict(torch.load(prior_path, map_location=device))
125
  prior.eval().requires_grad_(False)
126
 
127
  model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
128
  model = Paella(byt5_embd=2560)
129
  model.load_state_dict(torch.load(model_path, map_location=device))
130
+ model.eval().requires_grad_()#.half()
131
  replace_attention_layers(model)
132
  model.to(device)
133