silentchen commited on
Commit
725545d
1 Parent(s): 3e5a852

update space

Browse files
Files changed (1) hide show
  1. app.py +18 -3
app.py CHANGED
@@ -5,7 +5,7 @@ from functools import partial
5
  from typing import Optional
6
  from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
7
  from shap_e.diffusion.sample import sample_latents
8
- from shap_e.models.download import load_model, load_config
9
  from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
10
  import trimesh
11
  import torch.nn as nn
@@ -275,10 +275,25 @@ def main():
275
  """
276
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
277
  print("device:", device)
278
- latent_model = load_model('text300M', device=device)
 
 
 
 
 
 
279
  print("loaded latent model")
280
- xm = load_model('transmitter', device=device)
 
 
 
 
 
 
 
281
  print("loaded transmitter")
 
 
282
  diffusion = diffusion_from_config(load_config('diffusion'))
283
  freeze_params(xm.parameters())
284
  models = dict()
 
5
  from typing import Optional
6
  from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
7
  from shap_e.diffusion.sample import sample_latents
8
+ from shap_e.models.download import load_model, load_config, load_checkpoint
9
  from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
10
  import trimesh
11
  import torch.nn as nn
 
275
  """
276
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
277
  print("device:", device)
278
+ # latent_model = load_model('text300M', device=device)
279
+
280
+ latent_model = model_from_config(load_config('text300M'), device=device)
281
+ # print(model_name, kwargs)
282
+ # print(model)
283
+ latent_model.load_state_dict(load_checkpoint('text300M', device='cpu'))
284
+ latent_model.eval()
285
  print("loaded latent model")
286
+ latent_model.to(device)
287
+ # xm = load_model('transmitter', device=device)
288
+
289
+ xm = model_from_config(load_config('transmitter'), device=device)
290
+ # print(model_name, kwargs)
291
+ # print(model)
292
+ xm.load_state_dict(load_checkpoint('transmitter', device='cpu'))
293
+ xm.eval()
294
  print("loaded transmitter")
295
+ xm.to(device)
296
+
297
  diffusion = diffusion_from_config(load_config('diffusion'))
298
  freeze_params(xm.parameters())
299
  models = dict()