# File size: 2,081 bytes
# 29fa6f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# CS5330 Lab 4
# This file includes functions for image segmentation

import torch
import numpy as np
from torchvision import models, transforms
from PIL import Image

# Pretrained DeepLabV3 segmentation network with a ResNet-101 backbone.
# It is built once at import time and switched to inference mode; nn.Module.eval()
# returns the module itself, so the call can be chained.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; the legacy flag is kept so older torchvision versions still work.
model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()

# Input pipeline: PIL image -> float tensor (ToTensor) -> per-channel
# normalization. These mean/std values are the statistics torchvision documents
# for its pretrained models.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Extract the person out of the image
def extract_person(image, target_width=800, person_class=15):
    """Cut the person(s) out of ``image`` onto a transparent background.

    The image is segmented with the module-level DeepLabV3 ``model``. Every pixel
    whose predicted label equals ``person_class`` is kept; all other pixels become
    fully transparent. The result is resized to ``target_width`` pixels wide,
    preserving the aspect ratio.

    Args:
        image: A PIL image. It is converted to RGB before segmentation, so RGBA,
            grayscale and palette images are accepted (the 3-channel
            normalization step would otherwise fail on them).
        target_width: Width in pixels of the returned image (default 800).
        person_class: Label index treated as "person". The torchvision
            DeepLabV3 weights predict the 21 Pascal VOC categories (index 0 is
            background), in which 15 is "person".

    Returns:
        An RGBA ``PIL.Image`` containing only the person pixels, or ``None`` if
        no pixel was classified as ``person_class``.
    """
    # The network and the normalization expect exactly 3 channels.
    rgb_image = image.convert("RGB")
    input_tensor = preprocess(rgb_image).unsqueeze(0)  # add a batch dimension

    # Perform segmentation (inference only, so no autograd bookkeeping).
    with torch.no_grad():
        output = model(input_tensor)['out'][0]
    # Per-pixel label: argmax over the class dimension (dim 0) of the output.
    output_predictions = output.argmax(0)

    person_mask = (output_predictions == person_class).cpu().numpy()
    # Handle the case where there's no person detected
    if not person_mask.any():
        return None

    # Build the cut-out in one vectorized step instead of a per-pixel
    # getpixel/putpixel loop: keep person pixels as-is and set everything
    # else to (0, 0, 0, 0), i.e. transparent.
    rgba = np.array(image.convert("RGBA"))
    rgba[~person_mask] = 0
    extracted_person = Image.fromarray(rgba)

    # Resize for consistency: fixed width, height scaled to keep the aspect
    # ratio. Clamp to at least 1 px, because Pillow rejects a zero-sized resize
    # (possible for extremely wide images).
    aspect_ratio = image.size[1] / image.size[0]
    target_height = max(1, int(target_width * aspect_ratio))
    extracted_person = extracted_person.resize((target_width, target_height))

    # RGBA image; save it as PNG to keep the transparency.
    return extracted_person

# ==========Test case==========
# input = Image.open('false_img/ManyPeople.jpg')
# image = extract_person(input)
# if image is None:
#     print('Non-person')
# else:
#     image.show()