import cv2
import numpy as np
import requests
from PIL import Image
from io import BytesIO
import torch
from pathlib import Path
import torch.nn.functional as F
from typing import Dict, Any, Union, Tuple
from torchvision.transforms.functional import normalize

INPUT_SIZE = [1200, 1800]  # (height, width) the model expects; also scales the area threshold below

def keep_large_components(a: np.ndarray) -> np.ndarray:
    """Remove small connected components from a binary mask, keeping only large regions.
    
    Args:
        a: Input alpha mask (0-255) as numpy array of shape (H,W) or (H,W,1)
        
    Returns:
        Processed mask with only large connected components remaining, shape (H,W,1)
    """
    dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
    # Binarize: treat alpha values above 25 as candidate foreground
    a_mask = (a > 25).astype(np.uint8) * 255

    # Label connected components (4-connectivity) and collect per-component stats
    total_labels, label_ids, stats, _centroids = cv2.connectedComponentsWithStats(a_mask, 4, cv2.CV_32S)

    # Keep only components whose area exceeds a threshold scaled to the image size
    h, w = a.shape[:2]
    area_limit = 50000 * (h * w) / (INPUT_SIZE[0] * INPUT_SIZE[1])
    i_to_keep = []
    for i in range(1, total_labels):  # label 0 is the background
        area = stats[i, cv2.CC_STAT_AREA]
        if area > area_limit:
            i_to_keep.append(i)

    if len(i_to_keep) > 0:
        # OR together the masks of all components that are kept
        final_mask = np.zeros((h, w), dtype=np.uint8)
        for i in i_to_keep:
            component_mask = (label_ids == i).astype(np.uint8) * 255
            final_mask = cv2.bitwise_or(final_mask, component_mask)

        # Dilate the kept mask so the soft edges of the alpha matte survive,
        # then zero out everything outside it
        final_mask = cv2.dilate(final_mask, dilate_kernel, iterations=2)
        a = cv2.bitwise_and(a.reshape(h, w), final_mask)
        a = a.reshape(h, w, 1)
        
    return a

def read_img(img: Union[str, Path]) -> np.ndarray:
    """Read an image from a URL or local path.
    
    Args:
        img: URL or file path to image
        
    Returns:
        Image as numpy array in RGB format with shape (H,W,3)
    """
    if str(img).startswith('http'):
        # Fetch over HTTP(S); convert to RGB so both branches return the same format
        response = requests.get(img)
        im = np.asarray(Image.open(BytesIO(response.content)).convert('RGB'))
    else:
        # cv2.imread returns BGR, so convert to RGB
        im = cv2.imread(str(img))
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    return im

def preprocess_input(im: np.ndarray) -> torch.Tensor:
    """Preprocess image for model input.
    
    Args:
        im: Input image as numpy array of shape (H,W,C)
        
    Returns:
        Preprocessed image as normalized torch tensor of shape (1,3,H,W)
    """
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]

    if im.shape[2] == 4:  # if image has alpha channel, remove it
        im = im[:, :, :3]

    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    # Resize to the model input size; F.interpolate replaces the deprecated F.upsample.
    # The cast back to uint8 intentionally quantizes the resized values before scaling.
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), INPUT_SIZE, mode="bilinear").type(torch.uint8)
    image = torch.divide(im_tensor, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    # Keep the tensor on CPU: it is handed to ONNX Runtime as a numpy array,
    # and .numpy() raises on CUDA tensors
    return image

def postprocess_output(result: np.ndarray, orig_im_shape: Tuple[int, int]) -> np.ndarray:
    """Postprocess ONNX model output.
    
    Args:
        result: Model output as numpy array of shape (1,H,W)
        orig_im_shape: Original image dimensions (height, width)
        
    Returns:
        Processed alpha mask (0-255) as numpy array of shape (H,W,1)
    """
    # Resize the prediction back to the original image size, then min-max normalize to [0, 1]
    result = torch.squeeze(F.interpolate(
        torch.from_numpy(result).unsqueeze(0), orig_im_shape, mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)

    # a is the alpha channel: 255 means foreground, 0 means background
    a = (result * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)

    # Postprocessing: drop small spurious components from the mask
    a = keep_large_components(a)

    return a

def process_image(src: Union[str, Path], ort_session: Any, outname: str) -> None:
    """Run an image through the ONNX model to generate an alpha mask and save the result.
    
    Args:
        src: Source image URL or path
        ort_session: ONNX Runtime inference session
        outname: Output filename for the saved RGBA image
        
    Returns:
        None
    """
    # Load and preprocess image
    image_orig = read_img(src)
    image = preprocess_input(image_orig)
    
    # Prepare ONNX input
    inputs: Dict[str, Any] = {ort_session.get_inputs()[0].name: image.numpy()}
    
    # Run the model and post-process its output into an alpha mask
    result = ort_session.run(None, inputs)[0][0]
    alpha = postprocess_output(result, (image_orig.shape[0], image_orig.shape[1]))
    
    # Stack the alpha mask onto the RGB image and save; image_orig is RGB and
    # cv2.imwrite expects BGR, hence the RGB2BGR conversion
    img_w_alpha = np.dstack((cv2.cvtColor(image_orig, cv2.COLOR_RGB2BGR), alpha))
    cv2.imwrite(outname, img_w_alpha)
    print(f"Saved: {outname}")