File size: 1,086 Bytes
9006ba3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import torch.distributed as dist
import torch
import sys

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '1253'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def create_smplx_model(fast_smplx_path,
                       model_path,
                       model_type,
                       gender,
                       ext,
                       batch_size,
                       device):
    """Import ``smplx`` from ``fast_smplx_path`` and build a model in eval mode.

    ``fast_smplx_path`` is placed at the front of ``sys.path`` so that the
    ``smplx`` package found there takes precedence over any other copy on
    the path. If ``smplx`` was already imported earlier in the process,
    Python's module cache is used and the path has no effect.

    Args:
        fast_smplx_path: Directory that contains the ``smplx`` package;
            a ``str`` or an ``os.PathLike`` object.
        model_path: Forwarded to ``smplx.create`` as ``model_path``.
        model_type: Forwarded to ``smplx.create`` as ``model_type``.
        gender: Forwarded to ``smplx.create`` as ``gender``.
        ext: Forwarded to ``smplx.create`` as ``ext``.
        batch_size: Forwarded to ``smplx.create`` as ``batch_size``.
        device: Passed to ``.to(...)`` on the created model.

    Returns:
        The object produced by ``smplx.create(...).to(device)`` after
        ``eval()`` has been called on it.
    """
    # The import machinery skips sys.path entries that are not plain
    # strings, so a pathlib.Path would be silently ignored.
    if isinstance(fast_smplx_path, os.PathLike):
        fast_smplx_path = os.fspath(fast_smplx_path)

    # Keep the path at the head of sys.path, but do not prepend another copy
    # on every call when it is already there.
    if not sys.path or sys.path[0] != fast_smplx_path:
        sys.path.insert(0, fast_smplx_path)

    # Deferred import: smplx only becomes resolvable once the path is set.
    import smplx

    smpl_model = smplx.create(model_path=model_path,
                              model_type=model_type,
                              gender=gender,
                              ext=ext,
                              batch_size=batch_size).to(device)
    smpl_model.eval()
    return smpl_model


def load_checkpoint(model, optimizer, checkpoint_path, map_location=None):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is expected to be a mapping with the keys
    ``'model_state_dict'`` and ``'optimizer_state_dict'``. Both objects are
    updated in place; nothing is returned.

    Args:
        model: Object whose ``load_state_dict`` receives
            ``checkpoint['model_state_dict']``.
        optimizer: Object whose ``load_state_dict`` receives
            ``checkpoint['optimizer_state_dict']``. Pass ``None`` to restore
            only the model (e.g. for inference).
        checkpoint_path: Location of the checkpoint, passed to ``torch.load``.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps
            torch's own behaviour of restoring tensors to the device they
            were saved from; pass e.g. ``'cpu'`` or ``f'cuda:{rank}'`` so
            each process loads onto its own device.

    Raises:
        KeyError: If a required key is missing from the checkpoint.
    """
    # NOTE: torch.load unpickles the file. Only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])