# Spaces status: Runtime error
import os
import sys

import torch
import torch.distributed as dist
def setup(rank, world_size): | |
os.environ['MASTER_ADDR'] = 'localhost' | |
os.environ['MASTER_PORT'] = '1253' | |
dist.init_process_group("nccl", rank=rank, world_size=world_size) | |
def create_smplx_model(fast_smplx_path,
                       model_path,
                       model_type,
                       gender,
                       ext,
                       batch_size,
                       device):
    """Build a body model from a local ``smplx`` package, in eval mode.

    ``fast_smplx_path`` is placed at the front of ``sys.path`` so that the
    subsequent ``import smplx`` resolves to the package in that directory.

    Args:
        fast_smplx_path: Directory containing the ``smplx`` package
            (``str`` or any path-like object).
        model_path: Forwarded to ``smplx.create(model_path=...)``.
        model_type: Forwarded to ``smplx.create(model_type=...)``.
        gender: Forwarded to ``smplx.create(gender=...)``.
        ext: Forwarded to ``smplx.create(ext=...)``.
        batch_size: Forwarded to ``smplx.create(batch_size=...)``.
        device: Target passed to ``.to(device)`` on the created model.

    Returns:
        The object returned by ``smplx.create(...).to(device)``, after
        ``.eval()`` has been called on it.
    """
    # The import system ignores non-str sys.path entries (e.g. pathlib.Path),
    # so normalise path-like input first.
    smplx_dir = os.fspath(fast_smplx_path)
    # Keep the directory at highest priority without appending a duplicate
    # entry to sys.path on every call.
    if sys.path[:1] != [smplx_dir]:
        sys.path.insert(0, smplx_dir)
    # Imported here because it only becomes importable after the sys.path
    # change above. NOTE(review): if ``smplx`` was already imported earlier in
    # this process, the cached module is reused and ``fast_smplx_path`` has
    # no effect on which implementation is used.
    import smplx

    smpl_model = smplx.create(model_path=model_path,
                              model_type=model_type,
                              gender=gender,
                              ext=ext,
                              batch_size=batch_size).to(device)
    smpl_model.eval()
    return smpl_model
def load_checkpoint(model, optimizer, checkpoint_path): | |
checkpoint = torch.load(checkpoint_path) | |
model.load_state_dict(checkpoint['model_state_dict']) | |
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) | |