import torch | |
import numpy as np | |
from additional_modules.modnet.modnet import MODNet | |
# Checkpoint file passed to torch.load() when a ForegroundExtractor is built.
# The path is relative, so it is resolved against the current working directory.
PRETRAINED_PATH = 'pretrained_models/modnet_photographic_portrait_matting.ckpt'
class ForegroundExtractor:
    """Callable wrapper that runs a pretrained MODNet model on a single image.

    The instance owns the network (``self.fg_extractor``) and remembers the
    device it lives on (``self.device``) so inputs can be moved there.
    """

    def __init__(self, device):
        """Build MODNet in eval mode on *device* and load the pretrained weights.

        Args:
            device: Torch device (or device string) for the model and its inputs.
        """
        self.fg_extractor = MODNet().eval().to(device)
        # Deserialize onto CPU; load_state_dict() then copies the values into
        # the parameters that already live on `device`.
        state_dict = torch.load(PRETRAINED_PATH, map_location='cpu')
        # Drop the first 'module.' occurrence from every key (presumably the
        # prefix added by nn.DataParallel when the checkpoint was saved).
        state_dict = {
            key.replace('module.', '', 1): value
            for key, value in state_dict.items()
        }
        self.fg_extractor.load_state_dict(state_dict)
        self.device = device

    def __call__(self, img):
        """Run the model on *img* and return its output as a NumPy array.

        Args:
            img: Anything ``np.array`` turns into a 3-D array; the transpose
                below assumes channels-last (H, W, C) layout with values in
                0..255 -- TODO confirm against callers.

        Returns:
            The first batch element of the model output, moved to CPU and
            transposed from channels-first to channels-last.
        """
        # Channels-last -> channels-first, add a batch axis, scale by 1/255.
        batch = np.transpose(np.array(img), (2, 0, 1))[None, ...] / 255.
        batch = torch.from_numpy(batch).to(self.device).float()
        # Inference only: without no_grad() the output requires grad, so the
        # .numpy() call below raises RuntimeError and an unused autograd graph
        # is kept alive.
        with torch.no_grad():
            matte = self.fg_extractor(batch)
        matte = np.transpose(matte.cpu().numpy()[0], (1, 2, 0))
        return matte