import os
import urllib

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# Directory processed when main() is called without an explicit path.
DEFAULT_INPUT_DIR = "/localhome/mta122/PycharmProjects/logo_ai/final_nocherry_score/one/DRAGON/G"

# Appended to each input's path (minus extension) to name the output image.
RESULT_SUFFIX = "_result_deeplab.png"


def load_model():
    """Load a pretrained DeepLabV3-ResNet101 model and its input transform.

    The model is fetched via ``torch.hub`` (pinned to torchvision v0.6.0),
    put in eval mode, and moved to CUDA when a GPU is available.

    Returns:
        ``(model, preprocess)`` where ``preprocess`` turns a PIL RGB image
        into a tensor normalised with the ImageNet mean/std.
    """
    model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
    model.eval()
    # ImageNet normalisation statistics expected by the pretrained weights.
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    if torch.cuda.is_available():
        model.to('cuda')
    return model, preprocess


def _model_device(model):
    """Return the device holding *model*'s parameters (CPU if it has none)."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device('cpu')


def remove_background(img, model, preprocess, threshold=0.98):
    """Return a copy of *img* whose confidently-background pixels are transparent.

    Args:
        img: PIL image in any mode. It is converted to RGB before inference
            (an existing alpha channel is discarded); *img* is not modified.
        model: segmentation model returning a dict whose ``'out'`` entry holds
            per-class logits for the batch; channel 0 is treated as the
            background class.
        preprocess: callable mapping a PIL RGB image to a ``(C, H, W)`` tensor,
            e.g. the one returned by ``load_model``.
        threshold: a pixel is made fully transparent only when its softmax
            background probability exceeds this value; all other pixels stay
            fully opaque.

    Returns:
        A new RGBA PIL image.
    """
    rgb = img.convert('RGB')
    # Send the input to wherever the model lives instead of assuming CUDA,
    # so a CPU model still works on a machine that has a GPU.
    input_batch = preprocess(rgb)[None, ...].to(_model_device(model))
    with torch.no_grad():
        logits = model(input_batch)['out'][0]
    probs = torch.nn.functional.softmax(logits, dim=0)
    # Only pixels the model is very sure are background get cut out.
    background = (probs[0] > threshold).cpu().numpy()

    rgba = np.array(rgb.convert('RGBA'))
    rgba[..., 3] = np.where(background, 0, 255)
    return Image.fromarray(rgba.astype('uint8'))


def main(path_in=DEFAULT_INPUT_DIR):
    """Cut the background out of every image file in *path_in*.

    Each result is written next to its input as ``<name>_result_deeplab.png``.
    Subdirectories, outputs of a previous run, and files PIL cannot open are
    skipped.
    """
    model, preprocess = load_model()
    for name in sorted(os.listdir(path_in)):
        fpath = os.path.join(path_in, name)
        # Skip directories and our own outputs so re-runs don't cascade.
        if not os.path.isfile(fpath) or name.endswith(RESULT_SUFFIX):
            continue
        try:
            img = Image.open(fpath)
        except OSError as err:  # includes PIL's UnidentifiedImageError
            print(f'skipping {fpath}: {err}')
            continue
        with img:
            result = remove_background(img, model, preprocess)
        stem, _ = os.path.splitext(fpath)
        result.save(stem + RESULT_SUFFIX)
    print('finished')


if __name__ == "__main__":
    main()