# File size: 2,006 Bytes
# Commit: 656b5ab
import os
import torch
import argparse
import numpy as np
from skimage import io
from ormbg.models.ormbg import ORMBG
import torch.nn.functional as F


def parse_args():
    """Parse command-line options for the loss-evaluation script.

    Returns:
        argparse.Namespace with:
          * ``prediction``: list of prediction-mask image paths (defaults to
            ``examples/loss/loss01.png`` .. ``loss05.png``).
          * ``gt``: path to the ground-truth mask (defaults to
            ``examples/loss/gt.png``).
    """
    parser = argparse.ArgumentParser(
        description="Compute the ORMBG loss of prediction masks against a ground-truth mask."
    )
    # nargs="+" collects one or more paths into a list. The previous
    # type=list applied list() to a single path string, splitting it into
    # individual characters whenever --prediction was given on the CLI.
    parser.add_argument(
        "--prediction",
        type=str,
        nargs="+",
        default=[
            os.path.join("examples", "loss", "loss01.png"),
            os.path.join("examples", "loss", "loss02.png"),
            os.path.join("examples", "loss", "loss03.png"),
            os.path.join("examples", "loss", "loss04.png"),
            os.path.join("examples", "loss", "loss05.png"),
        ],
        help="Paths to one or more prediction mask images.",
    )
    parser.add_argument(
        "--gt",
        type=str,
        default=os.path.join("examples", "loss", "gt.png"),
        help="Path to the ground-truth mask image.",
    )
    return parser.parse_args()


def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
    """Turn an image array into a resized batch tensor scaled by 1/255.

    Args:
        im: Image array indexed as (height, width[, channels]); an array with
            fewer than three dimensions gets a trailing channel axis added.
        model_input_size: Target ``[height, width]`` passed to
            ``F.interpolate``.

    Returns:
        Float tensor of shape ``(1, channels, height, width)``. After bilinear
        resizing the values are cast to uint8 (dropping fractional parts) and
        then divided by 255.
    """
    # Give single-channel input an explicit channel axis so the HWC -> CHW
    # permute below works for both grayscale and multi-channel images.
    pixels = im[:, :, None] if im.ndim < 3 else im

    # HWC -> CHW, plus a leading batch axis as F.interpolate expects.
    chw = torch.tensor(pixels, dtype=torch.float32).permute(2, 0, 1)
    resized = F.interpolate(chw.unsqueeze(0), size=model_input_size, mode="bilinear")

    # The uint8 cast quantizes the interpolated values before scaling.
    quantized = resized.to(torch.uint8)
    return quantized / 255.0


def inference(args):
    """Print the ORMBG loss of each prediction mask against one ground-truth mask.

    Args:
        args: Namespace with ``prediction`` (iterable of image paths) and
            ``gt`` (path to the ground-truth mask), as returned by
            ``parse_args``.
    """
    prediction_paths = args.prediction
    gt_path = args.gt

    net = ORMBG()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prediction and ground truth are resized to the same resolution so the
    # tensors handed to compute_loss line up.
    model_input_size = [1024, 1024]

    # The ground truth is identical for every prediction, so it is loaded and
    # resized once instead of on each loop iteration.
    # NOTE(review): assumes net.compute_loss does not modify ground_truth
    # in place — confirm against ORMBG.compute_loss.
    gt_image = io.imread(gt_path)
    ground_truth = preprocess_image(gt_image, model_input_size).to(device)

    for pred_path in prediction_paths:
        pred_image = io.imread(pred_path)
        prediction = preprocess_image(pred_image, model_input_size).to(device)

        # compute_loss returns a pair; only the second element is reported.
        _, loss = net.compute_loss([prediction], ground_truth)

        print(f"Loss: {pred_path} {loss}")


if __name__ == "__main__":
    # Script entry point: read CLI options, then report each prediction's loss.
    cli_args = parse_args()
    inference(cli_args)