File size: 2,051 Bytes
ca4133a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import numpy as np
import urllib
from PIL import Image
from torchvision import transforms

def load_model():
    """Load a pretrained DeepLabV3-ResNet101 model and its input transform.

    The model is fetched through ``torch.hub`` from ``pytorch/vision:v0.6.0``
    (network access or a populated hub cache is required), put into
    evaluation mode, and moved to CUDA when a GPU is available.

    Returns:
        tuple: ``(model, preprocess)``. ``preprocess`` converts a PIL image to
        a float tensor and normalizes it per channel with the mean/std below.
    """
    model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
    # Inference only: disables dropout and uses running batch-norm statistics.
    model.eval()

    # Standard ImageNet per-channel mean/std used by torchvision's
    # pretrained models; Normalize accepts plain sequences.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    if torch.cuda.is_available():
        model.to('cuda')
    return model, preprocess

def _model_device(model):
    """Return the device of ``model``'s first parameter, else CUDA-if-available."""
    params = getattr(model, 'parameters', None)
    if params is not None:
        first = next(params(), None)
        if first is not None:
            return first.device
    # Parameterless / non-Module callables: keep the original heuristic.
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def remove_background(img, model, preprocess, threshold=0.98):
    """Return an RGBA copy of ``img`` whose confident-background pixels are transparent.

    The segmentation ``model`` is run on ``img``; the softmax probability of
    class 0 (background in the torchvision DeepLabV3 label set) is compared
    against ``threshold``. Pixels above it get alpha 0, all others alpha 255.

    Args:
        img: PIL image of any mode. It is converted to RGB for processing and
            is NOT modified.
        model: callable returning a dict whose ``'out'`` entry holds batched
            per-class logits, indexed here as ``out[0][class]``; the spatial
            size must match ``img`` (assumed, as with torchvision DeepLabV3).
        preprocess: transform mapping a PIL RGB image to a normalized
            channel-first tensor, e.g. the one returned by ``load_model``.
        threshold: background-probability cut-off (default 0.98).

    Returns:
        A new ``PIL.Image`` in RGBA mode with the same size as ``img``.
    """
    # convert() returns a new image, so the caller's image is never mutated,
    # and it yields the 3 channels the normalization transform expects.
    rgb = img.convert('RGB')

    # Send the input to wherever the model's weights live, so a CPU model
    # still works on a machine that has CUDA.
    device = _model_device(model)
    input_batch = preprocess(rgb).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_batch)['out'][0]
        # Softmax across the class axis (dim 0), then keep only class 0.
        background_prob = torch.nn.functional.softmax(logits, dim=0)[0]
        background_mask = (background_prob > threshold).cpu().numpy()

    result_np = np.array(rgb.convert('RGBA'))
    result_np[..., 3] = np.where(background_mask, 0, 255)
    return Image.fromarray(result_np)

import os
_DEFAULT_INPUT_DIR = "/localhome/mta122/PycharmProjects/logo_ai/final_nocherry_score/one/DRAGON/G"
_RESULT_SUFFIX = '_result_deeplab.png'


def main(path_in=_DEFAULT_INPUT_DIR):
    """Write a background-removed copy of every image file in ``path_in``.

    Each regular file in the directory (in sorted order) is opened, converted
    to RGB, passed through ``remove_background``, and saved next to the input
    as ``<name-without-extension>_result_deeplab.png``. Sub-directories and
    files already carrying that suffix (outputs of an earlier run) are
    skipped.

    Args:
        path_in: directory to process; defaults to the original dataset path.
    """
    model, preprocess = load_model()

    for fname in sorted(os.listdir(path_in)):
        fpath = os.path.join(path_in, fname)
        # Skip directories and results written by a previous run.
        if not os.path.isfile(fpath) or fname.endswith(_RESULT_SUFFIX):
            continue

        # convert('RGB') gives 3 channels for any source mode (RGBA drops its
        # alpha, L/P are expanded), which is what the model input expects.
        # The with-block closes the underlying file handle.
        with Image.open(fpath) as src:
            img = src.convert('RGB')

        result = remove_background(img, model, preprocess)
        # splitext only strips the final extension, so dots elsewhere in the
        # path (directory names, multi-dot file names) are preserved.
        result.save(os.path.splitext(fpath)[0] + _RESULT_SUFFIX)
        print('finished')


# Run the batch job only when executed as a script, not when imported.
if __name__ == "__main__":
    main()