zejunyang committed
Commit fab87df
1 Parent(s): d947e9b
Files changed (1)
  1. src/create_modules.py +3 -2
src/create_modules.py CHANGED
@@ -35,6 +35,7 @@ class Processer():
     def __init__(self):
         self.create_models()
 
+    @spaces.GPU
     def create_models(self):
 
         self.lmk_extractor = LMKExtractor()
@@ -50,8 +51,8 @@ class Processer():
         audio_infer_config = OmegaConf.load(config.audio_inference_config)
         # prepare model
         self.a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
-        self.a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
-        self.a2m_model.cuda().eval()
+        self.a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
+        self.a2m_model.to("cuda").eval()
 
         self.vae = AutoencoderKL.from_pretrained(
             config.pretrained_vae_path,
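
For context, a minimal sketch of the pattern this commit applies, assuming the app runs on a Hugging Face ZeroGPU Space: the checkpoint is deserialized onto the CPU via map_location="cpu" and the module is moved to CUDA only inside a function decorated with @spaces.GPU, which is when ZeroGPU actually attaches a GPU. DemoModel, load_model, and ckpt_path below are placeholder names for illustration, not part of this repository.

import torch
import spaces  # Hugging Face ZeroGPU helper package


class DemoModel(torch.nn.Module):  # placeholder module standing in for Audio2MeshModel
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


@spaces.GPU  # ZeroGPU grants CUDA access only while this function runs
def load_model(ckpt_path: str) -> DemoModel:
    model = DemoModel()
    # Deserialize on the CPU so torch.load never touches CUDA outside the GPU window
    state = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state, strict=False)
    # Move to the GPU explicitly, then switch to inference mode
    return model.to("cuda").eval()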