Spaces:
Sleeping
Sleeping
File size: 2,571 Bytes
5464cad 289faff 5464cad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from model import U2Net
# Compute device shared by the whole script: the GPU when CUDA is usable,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def preprocess_image(image_path):
    """Load an image file and convert it into a model-ready batch tensor.

    The image is opened as RGB, resized to 512x512, converted to a tensor,
    normalized per channel, given a leading batch axis of size 1 and moved
    to the module-level ``device``.

    Args:
        image_path: Location of the image, passed straight to ``Image.open``.

    Returns:
        A tensor of shape (1, 3, 512, 512) living on ``device``.
    """
    rgb_image = Image.open(image_path).convert('RGB')

    # Resize -> tensor -> per-channel normalization, applied in that order.
    steps = [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)

    tensor = pipeline(rgb_image)
    batch = tensor.unsqueeze(0)  # add the batch axis in front
    return batch.to(device)
def run_inference(model, image_path, threshold=0.5):
    """Run ``model`` on one image and return its prediction as a uint8 map.

    The image is loaded with ``preprocess_image``. The first output of the
    model is passed through a sigmoid, the first batch element is taken and
    min-max rescaled to [0, 1], then converted to 8-bit values.

    Args:
        model: Callable that accepts the preprocessed batch and returns an
            iterable whose first item is the tensor used for the prediction.
        image_path: Path of the input image (forwarded to
            ``preprocess_image``).
        threshold: Cut-off applied to the rescaled map. Values strictly
            greater than it become 255 and all others 0. Pass ``None`` to
            get the continuous map scaled to 0-255 instead.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8``.
    """
    input_img = preprocess_image(image_path)

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        d1, *_ = model(input_img)
        pred = torch.sigmoid(d1)[0, :, :].cpu().numpy()

    lowest = pred.min()
    span = pred.max() - lowest
    if span > 0:
        pred = (pred - lowest) / span
    else:
        # A constant prediction has zero span; dividing by it would fill the
        # map with NaNs and make the uint8 cast below undefined. Treat it as
        # "nothing detected" and return an all-zero map instead.
        pred = np.zeros_like(pred)

    if threshold is not None:
        return (pred > threshold).astype(np.uint8) * 255
    return (pred * 255).astype(np.uint8)
def overlay_segmentation(original_image, binary_mask, alpha=0.5):
    """Blend a mask, drawn in the red channel, over an image.

    Args:
        original_image: Path of the image file; it is opened, converted to
            RGB and resized to 512x512 with bilinear resampling.
        binary_mask: Array written into the red channel of the overlay
            layer. It must be assignable to a 512x512 slice.
        alpha: Weight of the overlay layer; the image gets ``1 - alpha``.

    Returns:
        The blended picture as a ``uint8`` array with the same shape as the
        resized RGB image.
    """
    rgb = Image.open(original_image).convert('RGB')
    base = np.array(rgb.resize((512, 512), Image.BILINEAR))

    # Only the red channel carries the mask; green and blue stay zero.
    tint = np.zeros_like(base)
    tint[:, :, 0] = binary_mask

    blended = (1 - alpha) * base + alpha * tint
    return blended.astype(np.uint8)
if __name__ == '__main__':
    # Script entry point: load the checkpoint, run inference twice (raw map
    # and thresholded mask) and save a 2x2 comparison figure.
    # ---
    model_path = 'results/inter-u2net-duts.pt'
    image_path = 'images/ladies.jpg'
    # ---

    model = U2Net().to(device)
    # The model is wrapped in DataParallel before the weights are loaded.
    # NOTE(review): presumably the checkpoint was saved from a DataParallel
    # model (keys prefixed with "module.") - confirm against the training code.
    model = nn.DataParallel(model)
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects. Only load checkpoints from a trusted source, or switch to
    # weights_only=True if the file contains a plain state dict.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=False))
    model.eval()

    mask = run_inference(model, image_path, threshold=None)
    mask_with_threshold = run_inference(model, image_path, threshold=0.7)

    fig = plt.figure(figsize=(10, 10))
    gs = GridSpec(2, 2, figure=fig, wspace=0, hspace=0)

    # (image, colormap) per panel, laid out row by row.
    # The input is converted to RGB so grayscale/palette files are not drawn
    # through matplotlib's default colormap and match the overlay panel.
    # np.squeeze drops any singleton axis from the masks so imshow receives a
    # 2-D array. NOTE(review): whether such an axis exists depends on the
    # model's output shape, which is not visible in this file.
    panels = [
        (Image.open(image_path).convert('RGB').resize((512, 512)), None),
        (np.squeeze(mask), 'gray'),
        (overlay_segmentation(image_path, mask_with_threshold), None),
        (np.squeeze(mask_with_threshold), 'gray'),
    ]

    for i, (img, cmap) in enumerate(panels):
        ax = fig.add_subplot(gs[i // 2, i % 2])
        ax.imshow(img, cmap=cmap)
        ax.axis('off')

    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.savefig('inference-output.jpg', format='jpg', bbox_inches='tight', pad_inches=0)
    # Release the figure once it has been written to disk.
    plt.close(fig)
|