File size: 830 Bytes
03da825 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
import numpy as np
from additional_modules.modnet.modnet import MODNet
# Checkpoint file whose state dict ForegroundExtractor loads into MODNet.
# The path is relative, so it is resolved against the process's working directory.
PRETRAINED_PATH = 'pretrained_models/modnet_photographic_portrait_matting.ckpt'
class ForegroundExtractor:
    """Callable wrapper that runs a pretrained MODNet on a single image.

    The network is built once in ``__init__``, put in eval mode, moved to the
    requested device and initialised from ``PRETRAINED_PATH``. Calling the
    instance with an image returns the network output as a NumPy array.
    """

    def __init__(self, device):
        """Create the MODNet model on ``device`` and load its weights.

        Args:
            device: Torch device (or device string) the model is moved to and
                on which input tensors are placed at call time.
        """
        self.fg_extractor = MODNet().eval().to(device)
        # Deserialize onto the CPU regardless of where the checkpoint was
        # saved; load_state_dict copies the values into the existing
        # parameters, which already live on ``device``.
        state_dict = torch.load(PRETRAINED_PATH, map_location='cpu')
        # Remove the first 'module.' occurrence from every key (presumably the
        # prefix added when the checkpoint was saved from a DataParallel
        # wrapper -- confirm against the checkpoint).
        state_dict = {
            key.replace('module.', '', 1): value
            for key, value in state_dict.items()
        }
        self.fg_extractor.load_state_dict(state_dict)
        self.device = device

    @torch.no_grad()
    def __call__(self, img):
        """Run the model on one image and return its output channels-last.

        Args:
            img: Anything ``np.array`` turns into a 3-D array laid out as
                (height, width, channels). Values are divided by 255 before
                being fed to the model (presumably 0-255 pixel data).

        Returns:
            NumPy array holding the first batch element of the model output,
            rearranged from (channels, height, width) to
            (height, width, channels).
        """
        # HWC -> CHW, add a leading batch axis, scale by 1/255. The division
        # promotes integer input to float64.
        batch = np.transpose(np.array(img), (2, 0, 1))[None, ...] / 255.
        # Narrow to float32 on the host BEFORE moving to the device: backends
        # without float64 support (e.g. MPS) reject a float64 transfer, and on
        # other devices this halves the amount of data copied.
        tensor = torch.from_numpy(batch).float().to(self.device)
        matte = self.fg_extractor(tensor)
        # Drop the batch axis and go back to channels-last for the caller.
        return np.transpose(matte.cpu().numpy()[0], (1, 2, 0))
|