File size: 1,920 Bytes
1fd7780
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import argparse
from Project.models.psp import pSp
import torch.onnx
import onnx
import onnxruntime

def setup_model(checkpoint_path, device='cpu', onnx_path="e4e.onnx"):
    """Load the exported ONNX encoder plus the options and latent average of a checkpoint.

    Args:
        checkpoint_path: Path to the PyTorch checkpoint. It must contain an
            ``'opts'`` mapping and a ``'latent_avg'`` entry.
        device: Target device. Checkpoint tensors (including ``latent_avg``)
            are mapped onto it, it is stored as ``opts.device``, and a device
            whose name starts with ``'cuda'`` requests the CUDA execution
            provider for the ONNX session when this onnxruntime build has it.
        onnx_path: Path of the ONNX encoder (e.g. as written by
            ``export_encoder_to_onnx``). Defaults to ``"e4e.onnx"`` relative to
            the current working directory.

    Returns:
        A ``(session, opts, latent_avg)`` tuple: the onnxruntime
        ``InferenceSession``, the checkpoint options as an
        ``argparse.Namespace`` (with ``checkpoint_path`` and ``device`` set),
        and the checkpoint's ``latent_avg`` value.

    Raises:
        KeyError: If the checkpoint lacks ``'opts'`` or ``'latent_avg'``.
        onnx.checker.ValidationError: If the ONNX model fails validation.
    """
    # Validate the graph before handing it to onnxruntime so a corrupt or
    # stale export fails with a clear checker error.
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    # onnxruntime-gpu >= 1.9 refuses to build a session without an explicit
    # provider list; request only providers this build actually offers.
    if str(device).startswith('cuda'):
        preferred = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    else:
        preferred = ['CPUExecutionProvider']
    available = set(onnxruntime.get_available_providers())
    providers = [name for name in preferred if name in available]
    session = onnxruntime.InferenceSession(onnx_path, providers=providers)

    # map_location keeps a GPU-saved checkpoint loadable on a CPU-only host
    # and puts latent_avg on the requested device.
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
    ckpt = torch.load(checkpoint_path, map_location=device)
    opts = dict(ckpt['opts'])  # copy so the loaded checkpoint is not mutated
    opts['checkpoint_path'] = checkpoint_path
    opts['device'] = device
    latent_avg = ckpt['latent_avg']
    return session, argparse.Namespace(**opts), latent_avg


def export_encoder_to_onnx(checkpoint_path, onnx_path="e4e.onnx", device='cpu'):
    """Build the pSp encoder from a checkpoint and export it to ONNX.

    Produces the file that ``setup_model`` loads.

    Args:
        checkpoint_path: Path to the PyTorch checkpoint; its ``'opts'`` mapping
            is used to construct ``pSp``.
        onnx_path: Destination path for the exported ONNX model.
        device: Value stored as ``opts.device`` before building ``pSp``.

    Returns:
        ``(net, opts)``: the encoder module (in eval mode) and the
        ``argparse.Namespace`` it was built from.
    """
    # Only the options are read here, so CPU mapping is always safe.
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    opts = dict(ckpt['opts'])
    opts['checkpoint_path'] = checkpoint_path
    opts['device'] = device
    opts = argparse.Namespace(**opts)

    # NOTE(review): presumably pSp reads weights from opts.checkpoint_path in
    # its constructor -- confirm against Project.models.psp.
    net = pSp(opts).encoder
    net.eval()

    # Dummy trace input of shape (1, 3, 256, 256); the exported graph is
    # traced with this shape and no dynamic axes are declared.
    dummy_input = torch.randn(1, 3, 256, 256, requires_grad=False)
    torch.onnx.export(
        net,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=11,  # the ONNX opset to export the model to
        do_constant_folding=True,  # fold constants for optimization
        input_names=['input'],
        output_names=['output'],
    )
    return net, opts