Spaces:
Sleeping
Sleeping
# CS5330 Lab 4
# This file includes functions for image segmentation
import torch | |
import numpy as np | |
from torchvision import models, transforms | |
from PIL import Image | |
# Per-channel mean and standard deviation used to normalise the input tensor
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Pretrained DeepLabV3 segmentation network, put into evaluation mode
# (``eval()`` returns the module itself, so it can be chained).
model = models.segmentation.deeplabv3_resnet101(pretrained=True).eval()

# Preprocessing pipeline: convert the image to a tensor, then normalise it
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
# Extract the person out of the image | |
def extract_person(image, target_width=800, person_class=15):
    """Cut the person out of ``image`` and return it on a transparent canvas.

    Runs the module-level segmentation ``model`` on the image, keeps only the
    pixels whose most likely class is ``person_class`` and makes every other
    pixel fully transparent. The result is resized to ``target_width`` pixels
    wide, keeping the aspect ratio of the input.

    Args:
        image: PIL image to segment. Any mode is accepted; the image is
            converted to RGB before it is fed to the network.
        target_width: Width in pixels of the returned image.
        person_class: Class index to keep. 15 is "person" in the Pascal VOC
            label set that torchvision's pretrained DeepLabV3 weights predict.

    Returns:
        An RGBA PIL image holding only the person pixels, or ``None`` when no
        pixel is classified as ``person_class``.

    Raises:
        ValueError: If ``target_width`` is not positive.
    """
    if target_width <= 0:
        raise ValueError("target_width must be a positive number of pixels")

    # The normalisation transform is defined for exactly three channels, so
    # grayscale, palette and RGBA inputs must be converted to RGB first.
    rgb_image = image.convert("RGB")
    input_tensor = preprocess(rgb_image).unsqueeze(0)

    # Perform segmentation; gradients are not needed for inference.
    with torch.no_grad():
        output = model(input_tensor)['out'][0]
    output_predictions = output.argmax(0)

    # Boolean (height, width) mask that is True wherever the person is.
    person_mask = (output_predictions == person_class).cpu().numpy()

    # Handle the case where there's no person detected
    if not person_mask.any():
        return None

    # Build the cut-out in one vectorised step instead of a per-pixel loop:
    # start from a fully transparent canvas, then copy the colour of every
    # person pixel and make just those pixels opaque.
    pixels = np.asarray(rgb_image)
    rgba = np.zeros(pixels.shape[:2] + (4,), dtype=np.uint8)
    rgba[person_mask, :3] = pixels[person_mask]
    rgba[person_mask, 3] = 255
    extracted_person = Image.fromarray(rgba)

    # Resize for consistency: fix the width and derive the height from the
    # aspect ratio. Clamp to 1 so extremely wide images cannot yield a
    # zero-height (invalid) target size.
    width, height = image.size
    aspect_ratio = height / width
    target_height = max(1, int(target_width * aspect_ratio))
    extracted_person = extracted_person.resize((target_width, target_height))

    # Return the extracted person as a transparent RGBA image
    return extracted_person
# ==========Test case==========
# input = Image.open('false_img/ManyPeople.jpg')
# image = extract_person(input)
# if image is None:
#     print('Non-person')
# else:
#     image.show()