mta122
first
ca4133a
import torch
import numpy as np
import urllib
from PIL import Image
from torchvision import transforms
def load_model():
    """Load the pretrained DeepLabV3-ResNet101 model and its input transform.

    The model is fetched through ``torch.hub`` (this may download weights on
    first use), switched to evaluation mode and moved to the GPU when CUDA is
    available.

    Returns:
        A ``(model, preprocess)`` tuple. ``preprocess`` converts a PIL image
        to a tensor and normalises it with the per-channel mean / std below.
    """
    model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
    model.eval()

    # Per-channel normalisation statistics applied to the input tensor.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    if torch.cuda.is_available():
        model.to('cuda')
    return model, preprocess
def remove_background(img, model, preprocess):
    """Return an RGBA copy of ``img`` whose background is made transparent.

    The image is run through ``model``; the class scores are turned into
    per-pixel probabilities with a softmax, and every pixel whose channel-0
    probability exceeds 0.98 gets alpha 0. All other pixels get alpha 255.
    (Channel 0 is presumably the background class of the segmentation
    model -- confirm against the model's documentation.)

    Args:
        img: PIL image in any mode; it is converted to RGB internally and is
            never modified.
        model: callable returning a mapping whose ``'out'`` entry holds the
            per-class score maps for a batch.
        preprocess: callable turning a PIL image into a normalised tensor.

    Returns:
        A new RGBA ``PIL.Image`` of the same size as ``img``.
    """
    # Work on a converted copy: the normalisation expects three channels, and
    # the caller's image must not be altered as a side effect.
    rgb = img.convert('RGB')

    input_batch = preprocess(rgb)[None, ...]  # add the batch dimension
    # Mirrors load_model(), which places the model on the GPU when available.
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')

    with torch.no_grad():
        output = model(input_batch)['out'][0]

    probabilities = torch.nn.functional.softmax(output, dim=0)
    # Only pixels that are confidently channel 0 are cut out.
    is_background = (probabilities[0] > 0.98).cpu().numpy()

    result_np = np.array(rgb.convert('RGBA'))
    result_np[..., 3] = np.where(is_background, 0, 255)
    return Image.fromarray(result_np.astype('uint8'))
import os
def main(path_in="/localhome/mta122/PycharmProjects/logo_ai/final_nocherry_score/one/DRAGON/G"):
    """Remove the background from every image found directly in ``path_in``.

    For each input file ``<name>.<ext>`` the result is written next to it as
    ``<name>_result_deeplab.png``.

    Args:
        path_in: directory containing the images to process. Defaults to the
            directory this script was originally written for.
    """
    result_suffix = '_result_deeplab.png'
    model, preprocess = load_model()

    for fname in os.listdir(path_in):
        fpath = os.path.join(path_in, fname)

        # Results are written into the same directory, so skip the outputs of
        # a previous run (and anything that is not a regular file).
        if fname.endswith(result_suffix) or not os.path.isfile(fpath):
            continue

        try:
            # convert() loads the pixel data and yields a plain 3-channel
            # image regardless of the source mode (RGBA, palette, grayscale).
            with Image.open(fpath) as source:
                img = source.convert('RGB')
        except OSError as err:
            # Unreadable or non-image file: report it and keep going.
            print(f'skipping {fpath}: {err}')
            continue

        result = remove_background(img, model, preprocess)
        # splitext only strips the final extension, so dots elsewhere in the
        # path do not truncate the output name.
        result.save(os.path.splitext(fpath)[0] + result_suffix)

    print('finished')
# Run the batch job only when this file is executed as a script, so that
# importing the module does not trigger the model load and processing.
if __name__ == "__main__":
    main()