# Uploaded by Sanshruth ("Upload 3 files", commit c3acf88, verified).
import PIL
import numpy as np
import copy
import cv2
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
import torch
from PIL import Image
import matplotlib
matplotlib.use('Agg')
def show_anns(anns, ax=None, *, size=512, alpha=0.35, rng=None):
    """Overlay segmentation masks on a matplotlib axes and number each one.

    Masks are drawn largest ``'area'`` first, so smaller masks are painted on
    top and stay visible. Each mask gets a random RGB colour at ``alpha``
    opacity. When the mask has an outer contour with non-zero area, it also
    gets a text label at that contour's centroid. The label is the mask's
    index in the ORIGINAL ``anns`` order, not its position after sorting.

    Args:
        anns: Iterable of annotation dicts. Each must provide
            ``'segmentation'`` (a 2-D mask array, bool or 0..1 float) and
            ``'area'`` (used only for draw ordering). Presumably the output of
            an automatic mask generator -- TODO confirm against callers.
        ax: Axes to draw on. Defaults to ``plt.gca()``.
        size: Canvas every mask is resized to before drawing. Either an int
            (square) or a ``(height, width)`` tuple. The default of 512 keeps
            the previously hard-coded behaviour.
        alpha: Overlay opacity where the mask value is 1.
        rng: Optional object with a ``random(n)`` method (e.g. a
            ``numpy.random.Generator``) for reproducible colours. ``None``
            draws from the global ``np.random`` state, as before.

    Returns:
        None. Draws onto ``ax`` as a side effect. An empty ``anns`` draws
        nothing and does not touch the current axes.
    """
    # Sorting materialises the iterable, so generators work as well as lists.
    sorted_anns = sorted(enumerate(anns), key=lambda item: item[1]['area'], reverse=True)
    if not sorted_anns:
        return
    if ax is None:
        ax = plt.gca()
    height, width = (size, size) if np.isscalar(size) else size
    random = np.random.random if rng is None else rng.random

    for original_idx, ann in sorted_anns:
        m = ann['segmentation']
        if m.shape != (height, width):
            # cv2.resize takes dsize as (width, height), the reverse of .shape.
            m = cv2.resize(m.astype(float), (width, height))

        # One RGBA layer: constant random colour, alpha proportional to the mask.
        rgba = np.empty((height, width, 4))
        rgba[..., :3] = random(3)
        rgba[..., 3] = m * alpha
        ax.imshow(rgba)

        # findContours returns (image, contours, hierarchy) in OpenCV 3 and
        # (contours, hierarchy) in OpenCV 4; [-2] selects contours in both.
        contours = cv2.findContours((m * 255).astype(np.uint8),
                                    cv2.RETR_EXTERNAL,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]
        if not contours:
            continue

        # Label at the centroid of the largest outer contour. For non-convex
        # shapes this point may fall outside the mask itself.
        cnt = max(contours, key=cv2.contourArea)
        M = cv2.moments(cnt)
        if M["m00"] == 0:
            continue
        cx = int(M["m10"] / M["m00"])
        cy = int(M["m01"] / M["m00"])
        # White bold text on a semi-transparent black box, so it stays
        # readable over any mask colour.
        ax.text(cx, cy, str(original_idx),
                color='white',
                fontsize=16,
                ha='center',
                va='center',
                fontweight='bold',
                bbox=dict(facecolor='black',
                          alpha=0.5,
                          edgecolor='none',
                          pad=1))
def create_image_grid(original_image, images, names, rows, columns, *, figsize=(20, 20)):
    """Lay out the original image followed by named result images in a grid.

    Images whose name is empty or whitespace-only are dropped. The original
    image is always placed in the first cell, titled ``'Original'``. The
    caller's ``images`` and ``names`` are never modified.

    Args:
        original_image: Image for the first cell; anything ``Axes.imshow``
            accepts (e.g. a PIL image or a numpy array).
        images: Images paired positionally with ``names``. If the two differ
            in length, the extra items are ignored (``zip`` semantics).
        names: Titles for ``images``. Blank titles drop their image.
        rows: Number of grid rows passed to ``Figure.add_subplot``.
        columns: Number of grid columns passed to ``Figure.add_subplot``.
        figsize: Figure size in inches. The default keeps the previously
            hard-coded value.

    Returns:
        The new ``matplotlib.figure.Figure``. It is created through pyplot,
        which keeps it registered until closed, so callers should
        ``plt.close(fig)`` when done.

    Raises:
        ValueError: If ``rows * columns`` is smaller than the number of cells
            needed (the original plus every kept image). The check runs before
            any figure is created.
    """
    kept = [(img, name) for img, name in zip(images, names) if name.strip()]
    panels = [(original_image, 'Original')] + kept

    capacity = rows * columns
    if capacity < len(panels):
        raise ValueError(
            f"a {rows}x{columns} grid has {capacity} cells but "
            f"{len(panels)} images (including the original) need to be shown"
        )

    fig = plt.figure(figsize=figsize)
    for idx, (img, name) in enumerate(panels, start=1):
        ax = fig.add_subplot(rows, columns, idx)
        ax.imshow(img)
        ax.set_title(name, fontsize=12, pad=10)
        ax.axis('off')
    # Lay out this figure explicitly instead of relying on pyplot's
    # implicit "current figure".
    fig.tight_layout()
    return fig