File size: 5,715 Bytes
7ce4544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1731cb5
2ff37ed
7ce4544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ff37ed
7ce4544
1731cb5
1591fde
7ce4544
 
1591fde
 
7ce4544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ff37ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import numpy as np
import os.path as osp
import cv2
import argparse
import torch
#from torch.utils.data import DataLoader
import torchvision
from RCFPyTorch0.dataset import BSDS_Dataset
from RCFPyTorch0.models import RCF
import gradio as gr
from PIL import Image
import sys
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from MODNet.src.models.modnet import MODNet
# 网页制作
import cv2


# Per-process cache of the heavyweight networks, keyed by name, so each
# checkpoint is read from disk and moved to the GPU once instead of on every
# request.
_MODEL_CACHE = {}


def _cached_model(key, build):
    """Return the model cached under *key*, calling ``build()`` on first use."""
    try:
        return _MODEL_CACHE[key]
    except KeyError:
        model = _MODEL_CACHE[key] = build()
        return model


def _build_modnet():
    """Create MODNet on the GPU, load the photographic checkpoint, set eval mode."""
    modnet = MODNet(backbone_pretrained=False)
    # The weights are loaded into the DataParallel wrapper, so the checkpoint
    # keys presumably carry the wrapper's 'module.' prefix.
    modnet = nn.DataParallel(modnet).cuda()
    modnet.load_state_dict(torch.load('MODNet/pretrained/modnet_photographic.ckpt'))
    modnet.eval()
    return modnet


def _build_rcf():
    """Create RCF on the GPU, load 'bsds500_pascal_model.pth', set eval mode."""
    model = RCF().cuda()
    model.load_state_dict(torch.load("RCFPyTorch0/bsds500_pascal_model.pth"))
    model.eval()
    return model


def _to_rgb(image):
    """Return *image* as an (H, W, 3) numpy array.

    2-D and single-channel images are replicated to three channels; a fourth
    (alpha) channel is dropped.
    """
    image = np.asarray(image)
    if image.ndim == 2:
        image = image[:, :, None]
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = image[:, :, 0:3]
    return image


def _modnet_input_size(im_h, im_w, ref_size):
    """Return the (height, width) MODNet is run at for an im_h x im_w image.

    If both sides are below *ref_size*, or both are above it, the image is
    rescaled so its short side equals *ref_size* (aspect ratio kept).
    Otherwise the size is kept as is. Each side is then rounded down to a
    multiple of 32, but never below 32.
    """
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        if im_w >= im_h:
            im_rh, im_rw = ref_size, int(im_w / im_h * ref_size)
        else:
            im_rh, im_rw = int(im_h / im_w * ref_size), ref_size
    else:
        im_rh, im_rw = im_h, im_w
    # Round down to a multiple of 32. The floor of 32 keeps a very thin image
    # from collapsing to a zero-sized tensor, which F.interpolate rejects.
    return max(32, im_rh - im_rh % 32), max(32, im_rw - im_rw % 32)


def single_scale_test(image, *, ref_size=512, side_output=1):
    """Cut out the foreground of *image* with MODNet, then edge-detect it with RCF.

    This is the callback served by the Gradio interface in this script.

    Args:
        image: Array-like image. (H, W) and (H, W, 1) inputs are expanded to
            3 channels and (H, W, 4) inputs lose their 4th channel. Presumably
            uint8 RGB as delivered by Gradio; TODO confirm for other dtypes.
        ref_size: Short-side length MODNet's input is rescaled to when the
            image is entirely smaller or entirely larger than it.
        side_output: Index into the list of maps returned by RCF. The default
            of 1 keeps the existing behaviour. NOTE(review): the variable was
            called "fuse_res", and upstream RCF appears to put its fused map
            last (index -1); confirm that 1 is intended.

    Returns:
        PIL.Image.Image: An (H, W) uint8 map computed as ``(1 - rcf_output) * 255``.
        Strong RCF responses therefore come out dark on a light background.
    """
    im_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    im_org = _to_rgb(image)

    # Inference only: no_grad stops both networks from building autograd graphs.
    with torch.no_grad():
        # --- MODNet: predict an alpha matte for the whole image ---
        modnet = _cached_model('modnet', _build_modnet)
        tensor = im_transform(Image.fromarray(im_org))[None, :, :, :]  # (1, 3, H, W)
        _, _, im_h, im_w = tensor.shape
        im_rh, im_rw = _modnet_input_size(im_h, im_w, ref_size)
        tensor = F.interpolate(tensor, size=(im_rh, im_rw), mode='area')
        _, _, matte = modnet(tensor.cuda(), True)
        # Bring the matte back to the input resolution, presumably in [0, 1].
        matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
        matte = matte[0][0].cpu().numpy()  # (H, W)

        # Composite the cut-out onto white: where matte is 0 the pixel is 255.
        matte3 = np.repeat(matte[:, :, None], 3, axis=2)  # (H, W, 3)
        foreground = im_org * matte3 + np.full(im_org.shape, 255) * (1 - matte3)
        foreground = foreground.astype('uint8')

        # --- RCF: edge detection on the composited image ---
        # NOTE(review): RCF receives raw RGB values in [0, 255]. If RCFPyTorch0
        # follows upstream RCF, training used mean-subtracted BGR input;
        # confirm this preprocessing matches the checkpoint.
        rcf = _cached_model('rcf', _build_rcf)
        rcf_in = torch.from_numpy(foreground).float().permute(2, 0, 1).unsqueeze(0).cuda()
        results = rcf(rcf_in)
        edge = torch.squeeze(results[side_output]).cpu().numpy()  # (H, W)

    # Invert so strong responses are dark on a light background.
    edge = ((1 - edge) * 255).astype('uint8')
    return Image.fromarray(edge)

# Command-line options: only the GPU selection is used by this script.
parser = argparse.ArgumentParser(description='PyTorch Testing')
parser.add_argument('--gpu', default='0', type=str, help='GPU ID')
args = parser.parse_args()

# Expose only the requested GPU, numbering devices by PCI bus ID.
# single_scale_test only touches CUDA once a request arrives, i.e. after these
# variables are set.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': args.gpu,
})

# Serve single_scale_test as a web demo: one image in, one image out.
interface = gr.Interface(fn=single_scale_test, inputs="image", outputs="image")
interface.launch()