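# NOTE: the text below appears to be web-page residue captured with the file
# (a hosting status banner), not part of the program:
#   Spaces: Sleeping / Sleeping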
import functools
import os

import cv2
import numpy as np
import streamlit as st
import torch
from PIL import Image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

from Pages.streamlit_img_label import st_img_label
from Pages.streamlit_img_label.manage import ImageManager
@functools.lru_cache(maxsize=1)
def _load_sam_model():
    """Load the SAM ``vit_h`` model once per process and return it.

    The checkpoint is read from ``weights/sam_vit_h_4b8939.pth`` and the model
    is moved to ``cuda:0`` when CUDA is available, otherwise to the CPU.
    Caching matters because ``get_masks`` is called once per rectangle and on
    every Streamlit rerun; reloading the checkpoint each time is very slow.
    """
    checkpoint_path = os.path.join("weights", "sam_vit_h_4b8939.pth")
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model_type = "vit_h"
    return sam_model_registry[model_type](checkpoint=checkpoint_path).to(device=device)


def get_masks(rect, img_path):
    """Segment the object inside ``rect`` in the image stored at ``img_path``.

    Args:
        rect: Mapping with ``'left'``, ``'top'``, ``'width'`` and ``'height'``
            entries describing the box prompt in pixel coordinates.
        img_path: Path of the image file to segment.

    Returns:
        The ``masks`` array produced by ``SamPredictor.predict`` for a single
        box prompt with ``multimask_output=False``.

    Raises:
        FileNotFoundError: If the image is missing or cannot be decoded.
    """
    # cv2.imread signals failure by returning None instead of raising, so
    # check it here to fail with a message that names the offending path.
    image_bgr = cv2.imread(img_path)
    if image_bgr is None:
        raise FileNotFoundError(f"Image is missing or unreadable: {img_path}")

    # Convert left/top/width/height into the corner form [x0, y0, x1, y1].
    box = np.array([
        rect['left'],
        rect['top'],
        rect['left'] + rect['width'],
        rect['top'] + rect['height'],
    ])

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    # Only the heavy model is shared; a fresh predictor per call keeps the
    # per-image state set by set_image() private to this invocation.
    mask_predictor = SamPredictor(_load_sam_model())
    mask_predictor.set_image(image_rgb)
    masks, _scores, _logits = mask_predictor.predict(
        box=box,
        multimask_output=False,
    )
    return masks
def _save_inverted_masks(masks, save_dir, start_index=0):
    """Write each mask as an inverted 8-bit PNG; return the next unused index."""
    os.makedirs(save_dir, exist_ok=True)
    next_index = start_index
    for mask in masks:
        inverted_mask = 255 - (mask * 255).astype(np.uint8)
        file_path = os.path.join(save_dir, f"inverted_mask_{next_index}.png")
        cv2.imwrite(file_path, inverted_mask)
        next_index += 1
    return next_index


def _segment_rects(rects, img_path, save_dir="saved_masks"):
    """Run ``get_masks`` for every rectangle, save the results, return the last masks.

    File indices keep counting across rectangles so that masks from one
    rectangle no longer overwrite those of the previous one. Returns ``None``
    when ``rects`` is empty.
    """
    last_masks = None
    next_index = 0
    for rect in rects:
        with st.spinner('Please wait while the product image is being extracted...'):
            last_masks = get_masks(rect, img_path)
        next_index = _save_inverted_masks(last_masks, save_dir, next_index)
    if next_index:
        print(f"Inverted masks saved to directory: {save_dir}")
    return last_masks


def _composite_on_background(image_bgr, background_bgr, mask):
    """Blend the masked part of ``image_bgr`` over ``background_bgr``; return BGRA.

    The background is resized to the image size, so the cut-out always covers
    the full frame and needs no offset. ``mask`` is a 2-D array the same
    height and width as the image; non-zero entries mark the foreground.
    """
    height, width = image_bgr.shape[:2]
    background_bgr = cv2.resize(background_bgr, (width, height))

    # Binary mask -> 8-bit alpha, lightly blurred to soften the cut-out edge.
    alpha = cv2.GaussianBlur((mask > 0).astype(np.uint8) * 255, (3, 3), 0)

    # The square root widens the soft edge. The background weight is derived
    # from the *adjusted* foreground weight so both always sum to 1; otherwise
    # edge pixels are over-bright and can overflow uint8.
    weight = np.sqrt(alpha / 255.0)[:, :, np.newaxis]
    blended = weight * image_bgr[:, :, :3] + (1.0 - weight) * background_bgr

    result = cv2.cvtColor(background_bgr, cv2.COLOR_BGR2BGRA)
    result[:, :, :3] = np.clip(blended, 0, 255).astype(np.uint8)

    # Sum the alphas in a wider dtype: uint8 + uint8 wraps around before
    # np.clip can saturate it.
    alpha_sum = alpha.astype(np.uint16) + result[:, :, 3]
    result[:, :, 3] = np.clip(alpha_sum, 0, 255).astype(np.uint8)
    return result


def _render_template_section(img_path, mask, bg_dir="bgImages",
                             output_path='output/overlay_image.png'):
    """Show the background picker and, on request, build and display the composite."""
    st.header("Template Selection")
    if not os.path.isdir(bg_dir):
        st.warning(f"Background directory not found: {bg_dir}")
        return

    image_files = [f for f in os.listdir(bg_dir) if os.path.isfile(os.path.join(bg_dir, f))]
    # Create a dropdown with the list of image files
    selected_image = st.selectbox("Select an image file", image_files)
    if not selected_image:
        return

    # Display the selected image
    image_pathBG = os.path.join(bg_dir, selected_image)
    with Image.open(image_pathBG) as image:
        st.image(image, width=300, caption=f"Selected image: {selected_image}")

    if not st.button("Create Image"):
        return

    # cv2.imread returns None on failure; report it instead of crashing later.
    image_bgr = cv2.imread(img_path)
    background_bgr = cv2.imread(image_pathBG)
    if image_bgr is None or background_bgr is None:
        st.error("Could not read the product image or the selected background image.")
        return

    overlay_image = _composite_on_background(image_bgr, background_bgr, mask)

    # cv2.imwrite does not create missing directories and only reports
    # failure through its return value.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    if not cv2.imwrite(output_path, overlay_image):
        st.error(f"Could not write the overlay image to {output_path}")
        return

    # Display the overlay image
    st.image(output_path, caption="Overlay Image", use_column_width=True, width=300)


def run(img_path):
    """Render the annotate -> segment -> composite page for one image.

    Until the "Save" button has been clicked, the image is shown in the
    rectangle-labelling widget and the drawn rectangles are kept in
    ``st.session_state.rects``. After saving, each rectangle is segmented with
    ``get_masks``, the inverted masks are written to ``saved_masks/``, and the
    user can pick a background from ``bgImages/`` to composite the cut-out on.

    Args:
        img_path: Path of the image to annotate and segment.
    """
    st.set_option("deprecation.showfileUploaderEncoding", False)
    im = ImageManager(img_path)
    resized_img = im.resizing_img()
    resized_rects = im.get_resized_rects()
    if "rects" not in st.session_state:
        st.session_state.rects = resized_rects

    # Only display st_img_label if Save button hasn't been clicked
    if not st.session_state.get("saved"):
        st.session_state.rects = st_img_label(
            resized_img, box_color="red", rects=st.session_state.rects
        )
    else:
        st.image(resized_img, caption="Annotated Image", width=300, use_column_width=True)
        # NOTE(review): the rectangles are drawn on the resized image but
        # get_masks() reads the file at img_path; if ImageManager actually
        # downsizes the image the boxes may need rescaling - confirm.
        masks = _segment_rects(st.session_state.rects, img_path)
        if masks is None or len(masks) == 0:
            st.warning("No mask could be extracted from the annotated rectangles.")
        else:
            # NOTE(review): as before, only the last rectangle's first mask is
            # used for compositing; confirm whether all should be combined.
            _render_template_section(img_path, masks[0])

    def annotate():
        """Button callback: mark the annotation as saved for the next rerun."""
        st.session_state.saved = True

    if st.session_state.rects:
        st.button(label="Save", on_click=annotate)