# Spaces: Sleeping (page-header residue from the hosting site; not part of the program)
import torch | |
from transformers import SamConfig, SamProcessor, SamModel | |
from shiny import App, Inputs, Outputs, Session, render, ui, reactive | |
import numpy as np | |
from PIL import Image | |
import io | |
import base64 | |
# --- Model setup ------------------------------------------------------------
# The configuration and the processor are both fetched for the same pretrained
# identifier; the network weights come from a local checkpoint file instead.
_SAM_PRETRAINED_ID = "facebook/sam-vit-base"
_CHECKPOINT_PATH = "models/model_checkpoint_trained_on_train.pth"

model_config = SamConfig.from_pretrained(_SAM_PRETRAINED_ID)
processor = SamProcessor.from_pretrained(_SAM_PRETRAINED_ID)

# Instantiate the model from the configuration, then load the checkpoint's
# state dict into it, mapping all tensors to the CPU.
# NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
my_model = SamModel(config=model_config)
_checkpoint_state = torch.load(_CHECKPOINT_PATH, map_location=torch.device("cpu"))
my_model.load_state_dict(_checkpoint_state)
my_model.eval()  # put the module in evaluation mode
def predict(image):
    """Segment *image* with the loaded model, prompting with a full-frame box.

    Args:
        image: Input image; it is converted with ``np.array`` before being
            handed to the processor.

    Returns:
        A 3-tuple ``(None, hard_mask, prob_image)``:

        * ``None`` -- placeholder in the first slot (no value is computed).
        * ``hard_mask`` -- ``uint8`` array, 1 where the sigmoid of the
          predicted mask exceeds 0.5 and 0 elsewhere.
        * ``prob_image`` -- PIL image built from the sigmoid output scaled by
          255 and repeated over three channels.
    """
    pixels = np.array(image)
    height, width = pixels.shape[0], pixels.shape[1]

    # Box prompt covering the whole frame: [0, 0, width, height].
    full_frame_box = [0, 0, width, height]
    model_inputs = processor(
        images=pixels,
        input_boxes=[[full_frame_box]],
        return_tensors="pt",
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        model_outputs = my_model(**model_inputs, multimask_output=False)

    mask_logits = model_outputs.pred_masks.squeeze()
    probabilities = torch.sigmoid(mask_logits).cpu().numpy()

    # Soft -> hard mask via a fixed 0.5 threshold.
    binary_mask = (probabilities > 0.5).astype(np.uint8)

    # Three identical channels so the probability map can be shown as an image.
    scaled = probabilities * 255
    stacked = np.stack((scaled, scaled, scaled), axis=-1).astype(np.uint8)
    probability_image = Image.fromarray(stacked)

    return None, binary_mask, probability_image
def display_images_as_data_url(img):
    """Serialize *img* as PNG and return it as a base64 ``data:`` URL.

    Args:
        img: Object exposing ``save(file_obj, format=...)``; the callers in
            this file pass PIL images.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    with io.BytesIO() as png_stream:
        img.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    payload = base64.b64encode(png_bytes).decode()
    return "data:image/png;base64," + payload
def server(input: Inputs, output: Outputs, Session: Session):
    """Shiny server function: turn an uploaded image into segmentation panels.

    Args:
        input: Shiny inputs; ``input.upload_image()`` is read for the upload.
        output: Shiny outputs; used to register the ``segmented_result``
            renderer.
        Session: The Shiny session object (unused here; name kept as-is for
            interface stability).
    """

    def segmented_data():
        """Open the uploaded file and run ``predict`` on it.

        Returns:
            ``(image, gt_mask, pred_mask, prob_image)`` when a file has been
            uploaded, otherwise ``None``.
        """
        uploads = input.upload_image()
        if not uploads:
            return None
        # Close the file handle as soon as the RGB copy has been made.
        with Image.open(uploads[0]['datapath']) as source_image:
            image = source_image.convert('RGB')
        gt_mask, pred_mask, prob_image = predict(image)
        return image, gt_mask, pred_mask, prob_image

    # Register this function as the renderer for the "segmented_result"
    # output placeholder in the UI; without the registration the placeholder
    # is never filled.
    @output
    @render.ui
    def segmented_result():
        """Render the original image, predicted mask and probability map."""
        result = segmented_data()
        # Explicit sentinel check instead of comparing a tuple that may hold
        # images/arrays against a tuple of Nones.
        if result is None:
            return "Upload an image to start segmentation."

        image, _gt_mask, pred_mask, prob_image = result
        panels = (
            (image.resize((256, 256)), "Original Image"),
            # pred_mask holds 0/1 values; scale to 0/255 for display.
            (Image.fromarray(pred_mask * 255), "Predicted Mask"),
            (prob_image, "Probability Map"),
        )
        return ui.div(
            *(
                ui.div(
                    ui.img(src=display_images_as_data_url(picture), width="100%"),
                    ui.h5(caption),
                )
                for picture, caption in panels
            ),
            style="display: flex;"
        )
# Color palette (HEX) used in the styles below:
#   purple  #800080
#   white   #FFFFFF
#   orange  #FFA500

# Inline CSS for the sidebar, the sidebar layout, and the page itself.
_SIDEBAR_STYLE = """
background: linear-gradient(135deg, #FFA500 0%, #800080 100%);
color: white;
padding: 20px;
border-radius: 0;
min-height: 100vh;
box-shadow: none;
"""
_LAYOUT_STYLE = """
background: #f8f9fa;
padding: 20px;
height: 100vh;
"""
_PAGE_STYLE = """
background: linear-gradient(135deg, #800080 0%, #FFFFFF 100%);
height: 100vh;
"""

# Sidebar: title plus the file-upload control whose id the server reads.
_sidebar = ui.sidebar(
    ui.panel_title("Image Segmentation"),
    ui.input_file("upload_image", "Upload an Image"),
    style=_SIDEBAR_STYLE
)

# Main area: a single UI output placeholder with id "segmented_result".
_layout = ui.layout_sidebar(
    _sidebar,
    ui.output_ui("segmented_result"),
    style=_LAYOUT_STYLE
)

app_ui = ui.page_fillable(_layout, style=_PAGE_STYLE)
# Combine the UI definition and the server function into the application object.
app = App(ui=app_ui, server=server)

if __name__ == "__main__":
    # Launch the app when this file is executed directly.
    app.run()