# Gradio demo: MODNet portrait matting followed by RCF edge detection.
import os
import numpy as np
import os.path as osp
import cv2
import argparse
import torch
#from torch.utils.data import DataLoader
import torchvision
from RCFPyTorch0.dataset import BSDS_Dataset
from RCFPyTorch0.models import RCF
import gradio as gr
from PIL import Image
import sys
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from MODNet.src.models.modnet import MODNet
# Web page construction (Gradio UI)
import cv2
# Lazily-initialized model singletons. Loading the checkpoints is expensive,
# so do it once per process instead of on every Gradio request.
_MODNET = None
_RCF_MODEL = None


def _get_modnet():
    """Return the cached MODNet matting model (GPU, eval mode)."""
    global _MODNET
    if _MODNET is None:
        net = nn.DataParallel(MODNet(backbone_pretrained=False)).cuda()
        net.load_state_dict(
            torch.load('MODNet/pretrained/modnet_photographic.ckpt'))
        net.eval()
        _MODNET = net
    return _MODNET


def _get_rcf():
    """Return the cached RCF edge-detection model (GPU, eval mode)."""
    global _RCF_MODEL
    if _RCF_MODEL is None:
        net = RCF().cuda()
        net.load_state_dict(torch.load('RCFPyTorch0/bsds500_pascal_model.pth'))
        net.eval()
        _RCF_MODEL = net
    return _RCF_MODEL


def _to_rgb_array(image):
    """Coerce the Gradio input to an HxWx3 RGB numpy array."""
    image = np.asarray(image)
    if len(image.shape) == 2:      # grayscale -> add a channel axis
        image = image[:, :, None]
    if image.shape[2] == 1:        # single channel -> replicate to 3 channels
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:      # RGBA -> drop the alpha channel
        image = image[:, :, 0:3]
    return image


def _matting_input(image, ref_size=512):
    """Normalize and resize an RGB array into a MODNet input batch.

    Returns (tensor, (orig_h, orig_w)). The resized sides are snapped down
    to multiples of 32, as required by MODNet's encoder strides.
    """
    im_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    tensor = im_transform(Image.fromarray(image))[None, :, :, :]
    _, _, im_h, im_w = tensor.shape
    # Rescale so the shorter side equals ref_size, unless the image already
    # straddles ref_size (then keep the native resolution).
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        if im_w >= im_h:
            im_rh = ref_size
            im_rw = int(im_w / im_h * ref_size)
        else:
            im_rw = ref_size
            im_rh = int(im_h / im_w * ref_size)
    else:
        im_rh, im_rw = im_h, im_w
    im_rw = im_rw - im_rw % 32
    im_rh = im_rh - im_rh % 32
    tensor = F.interpolate(tensor, size=(im_rh, im_rw), mode='area')
    return tensor, (im_h, im_w)


def single_scale_test(image):
    """Matte the foreground with MODNet, then run single-scale RCF edges.

    Parameters
    ----------
    image : numpy array / PIL image supplied by the Gradio "image" input.

    Returns
    -------
    PIL.Image.Image
        The inverted single-scale RCF edge map of the matted foreground
        (also written to RCFPyTorch0/results/RCF/result_ss.png).
    """
    image = _to_rgb_array(image)
    im_org = image  # keep the original array for foreground compositing

    # --- Stage 1: MODNet matting -------------------------------------
    tensor, (im_h, im_w) = _matting_input(image)
    with torch.no_grad():  # inference only; no autograd graph needed
        _, _, matte = _get_modnet()(tensor.cuda(), True)
    # Upsample the matte back to the original resolution; values in [0, 1].
    matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
    matte = matte[0][0].data.cpu().numpy()
    matte = np.repeat(matte[:, :, None], 3, axis=2)
    # Composite the foreground over a white background.
    foreground = im_org * matte + np.full(im_org.shape, 255) * (1 - matte)

    os.makedirs('MODNet/output-img', exist_ok=True)
    fg_path = os.path.join('MODNet/output-img', 'fg_name.png')
    Image.fromarray(foreground.astype('uint8'), mode='RGB').save(fg_path)
    image = np.array(Image.open(fg_path))

    # --- Stage 2: RCF edge detection ---------------------------------
    # RCF takes a float CHW batch at the original (unnormalized) 0-255 scale.
    rcf_in = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0)
    rcf_in = rcf_in.cuda()
    _, _, H, W = rcf_in.shape
    with torch.no_grad():
        results = _get_rcf()(rcf_in)
    all_res = torch.zeros((len(results), 1, H, W))
    for i in range(len(results)):
        all_res[i, 0, :, :] = results[i]
    os.makedirs('RCFPyTorch0/results/RCF', exist_ok=True)
    torchvision.utils.save_image(
        1 - all_res, osp.join('RCFPyTorch0/results/RCF', 'result.jpg'))
    # NOTE(review): the original code takes results[1] as the "fused" map;
    # RCF conventionally fuses in the last output — confirm index is intended.
    fuse_res = torch.squeeze(results[1].detach()).cpu().numpy()
    fuse_res = ((1 - fuse_res) * 255).astype(np.uint8)
    cv2.imwrite(osp.join('RCFPyTorch0/results/RCF', 'result_ss.png'), fuse_res)
    return Image.open(os.path.join('RCFPyTorch0/results/RCF', 'result_ss.png'))
# --- Script entry: CLI args, GPU selection, Gradio launch ---------------
parser = argparse.ArgumentParser(description='PyTorch Testing')
parser.add_argument('--gpu', default='0', type=str, help='GPU ID')
args = parser.parse_args()

# Pin CUDA device enumeration to the PCI bus order and restrict visibility
# to the requested GPU; this must run before any CUDA context is created.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

# Gradio web UI: a single image-in / image-out endpoint backed by
# single_scale_test, with a publicly shareable link.
interface = gr.Interface(fn=single_scale_test, inputs="image", outputs="image")
interface.launch(share=True)