# jojo_fork / e4e / utils / model_utils.py
# Last commit e6b8f5d by algomuffin: "fork"
import torch
import argparse
from models.psp import pSp
from models.encoders.psp_encoders import Encoder4Editing
def setup_model(checkpoint_path, device='cuda'):
    """Rebuild a full pSp network from a saved checkpoint.

    The checkpoint's stored ``opts`` mapping is extended with
    ``checkpoint_path`` and ``device``, wrapped in an ``argparse.Namespace``,
    and handed to ``pSp``. The network is then switched to eval mode and
    moved to ``device``.

    Args:
        checkpoint_path: Path of the checkpoint file read with ``torch.load``.
        device: Device the constructed network is moved to.

    Returns:
        A ``(net, opts)`` tuple: the pSp network (eval mode, on ``device``)
        and the Namespace of options it was built from.
    """
    # Deserialize onto the CPU first; the model is moved to ``device`` below.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # NOTE(review): pSp presumably reads opts.checkpoint_path to load its
    # weights -- confirm in models/psp.py.
    options = argparse.Namespace(**{
        **checkpoint['opts'],
        'checkpoint_path': checkpoint_path,
        'device': device,
    })

    net = pSp(options)
    net.eval()
    return net.to(device), options
def load_e4e_standalone(checkpoint_path, device='cuda'):
    """Load only the e4e encoder from a pSp/e4e checkpoint.

    Builds an ``Encoder4Editing`` (50 layers, ``'ir_se'`` mode) from the
    checkpoint's ``opts``, loads the ``encoder.``-prefixed entries of its
    ``state_dict``, and registers a forward hook that adds the checkpoint's
    ``latent_avg`` to every encoder output.

    Args:
        checkpoint_path: Path of the checkpoint file read with ``torch.load``.
        device: Device the encoder and ``latent_avg`` are moved to.

    Returns:
        The ``Encoder4Editing`` module in eval mode on ``device``.

    Raises:
        KeyError: if the checkpoint has no ``opts``, ``state_dict`` or
            ``latent_avg`` entry.
    """
    # Deserialize onto the CPU first; tensors are moved to ``device`` below.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    opts = argparse.Namespace(**ckpt['opts'])
    e4e = Encoder4Editing(50, 'ir_se', opts)

    # Keep only the encoder's weights and strip the leading 'encoder.' prefix.
    # Slicing removes just the prefix; str.replace would also rewrite any
    # later occurrence of 'encoder.' inside a key and corrupt it.
    prefix = 'encoder.'
    e4e_dict = {
        k[len(prefix):]: v
        for k, v in ckpt['state_dict'].items()
        if k.startswith(prefix)
    }
    e4e.load_state_dict(e4e_dict)
    e4e.eval()
    e4e = e4e.to(device)

    latent_avg = ckpt['latent_avg'].to(device)

    def add_latent_avg(model, inputs, outputs):
        # Forward hook: the returned tensor replaces the encoder's output.
        # The average latent is tiled over the batch dimension and added.
        # NOTE(review): assumes outputs is (batch, n_styles, dim) and
        # latent_avg is (n_styles, dim) -- confirm against Encoder4Editing.
        # Follow the output's device so the hook still works if the encoder
        # is moved after loading (.to() is a no-op when already there).
        avg = latent_avg.to(outputs.device)
        return outputs + avg.repeat(outputs.shape[0], 1, 1)

    e4e.register_forward_hook(add_latent_avg)
    return e4e