zejunyang
commited on
Commit
•
fab87df
1
Parent(s):
d947e9b
update
Browse files- src/create_modules.py +3 -2
src/create_modules.py
CHANGED
@@ -35,6 +35,7 @@ class Processer():
|
|
35 |
def __init__(self):
|
36 |
self.create_models()
|
37 |
|
|
|
38 |
def create_models(self):
|
39 |
|
40 |
self.lmk_extractor = LMKExtractor()
|
@@ -50,8 +51,8 @@ class Processer():
|
|
50 |
audio_infer_config = OmegaConf.load(config.audio_inference_config)
|
51 |
# prepare model
|
52 |
self.a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
|
53 |
-
self.a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
|
54 |
-
self.a2m_model.cuda
|
55 |
|
56 |
self.vae = AutoencoderKL.from_pretrained(
|
57 |
config.pretrained_vae_path,
|
|
|
35 |
def __init__(self):
|
36 |
self.create_models()
|
37 |
|
38 |
+
@spaces.GPU
|
39 |
def create_models(self):
|
40 |
|
41 |
self.lmk_extractor = LMKExtractor()
|
|
|
51 |
audio_infer_config = OmegaConf.load(config.audio_inference_config)
|
52 |
# prepare model
|
53 |
self.a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
|
54 |
+
self.a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt'], map_location="cpu"), strict=False)
|
55 |
+
self.a2m_model.to("cuda").eval()
|
56 |
|
57 |
self.vae = AutoencoderKL.from_pretrained(
|
58 |
config.pretrained_vae_path,
|