ML4 / app.py
AndresZarta's picture
Models cache in tmp in the Dockerfile
de85f6a
import torch
from transformers import SamConfig, SamProcessor, SamModel
from shiny import App, Inputs, Outputs, Session, render, ui, reactive
import numpy as np
from PIL import Image
import io
import base64
# Hub identifier used for both the SAM configuration and its processor, and
# the local checkpoint whose weights replace the freshly initialised ones.
_BASE_MODEL_ID = "facebook/sam-vit-base"
_CHECKPOINT_PATH = "models/model_checkpoint_trained_on_train.pth"

# Load model configuration
model_config = SamConfig.from_pretrained(_BASE_MODEL_ID)
processor = SamProcessor.from_pretrained(_BASE_MODEL_ID)

# Create an instance of the model from the configuration alone, then fill it
# with the weights stored in the local checkpoint.
my_model = SamModel(config=model_config)
my_model.load_state_dict(
    torch.load(
        _CHECKPOINT_PATH,
        map_location=torch.device("cpu"),
        # The checkpoint is consumed as a state dict, so restrict unpickling
        # to tensors/containers instead of allowing arbitrary pickled objects.
        weights_only=True,
    )
)
# Inference only: put layers such as dropout into evaluation behaviour.
my_model.eval()
def predict(image, box=None, threshold=0.5):
    """Segment *image* with the module-level SAM model using one box prompt.

    Args:
        image: Image convertible with ``np.array`` into an array whose first
            two axes are (height, width); the caller in this file passes an
            RGB PIL image.
        box: Optional ``[x_min, y_min, x_max, y_max]`` prompt in pixel
            coordinates. When omitted, the whole frame is used.
        threshold: Probability above which a pixel is marked as foreground
            in the hard mask.

    Returns:
        A 3-tuple ``(None, hard_mask, prob_image)``. The first slot is a
        placeholder for a ground-truth mask, which is not available here.
        ``hard_mask`` is a ``uint8`` array of 0/1 values and ``prob_image``
        is a PIL image showing the sigmoid probabilities as grey levels
        replicated over three channels.
    """
    image_array = np.array(image)

    if box is None:
        # Default bounding box to entire frame
        height, width = image_array.shape[:2]
        box = [0, 0, width, height]

    # Nesting is: one image -> its list of boxes -> a single box.
    inputs = processor(
        images=image_array,
        input_boxes=[[list(box)]],
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = my_model(**inputs, multimask_output=False)
        # Drop the singleton batch/prompt/mask axes so a 2-D map remains.
        pred_masks = outputs.pred_masks.squeeze()
        seg_prob = torch.sigmoid(pred_masks).cpu().numpy()

    # Convert soft to hard mask with a threshold
    hard_mask = (seg_prob > threshold).astype(np.uint8)

    # Prepare probability map in color: scale to 0-255 and repeat the single
    # channel three times so it can be rendered as an RGB image.
    seg_prob_scaled = np.stack((seg_prob * 255,) * 3, axis=-1).astype(np.uint8)
    prob_image = Image.fromarray(seg_prob_scaled)

    return None, hard_mask, prob_image
def display_images_as_data_url(img):
    """Return *img* encoded as PNG inside a base64 ``data:`` URL.

    ``img`` only needs a ``save(file_object, format=...)`` method, which is
    what PIL images provide. The result can be used directly as the ``src``
    of an HTML ``<img>`` tag.
    """
    with io.BytesIO() as png_stream:
        img.save(png_stream, format="PNG")
        # getvalue() yields the full buffer regardless of the write position.
        encoded = base64.b64encode(png_stream.getvalue())
    return "data:image/png;base64," + encoded.decode()
def server(input: Inputs, output: Outputs, Session: Session):
    """Shiny server function: segment the uploaded image and render results.

    Args:
        input: Reactive inputs; ``input.upload_image()`` is read for the
            uploaded file information.
        output: Output registry used to register ``segmented_result``.
        Session: The session object (unused here). The capitalised parameter
            name is kept unchanged for compatibility; inside this function it
            shadows the imported ``Session`` class.
    """

    @reactive.Calc
    def segmented_data():
        """Return ``(image, gt_mask, pred_mask, prob_image)`` for the upload.

        All four entries are ``None`` while no file has been uploaded.
        """
        uploaded_file = input.upload_image()
        if not uploaded_file:
            return None, None, None, None

        uploaded_file_path = uploaded_file[0]['datapath']
        # Image.open keeps the file open lazily; convert() produces an
        # independent in-memory copy, so the handle can be closed right away.
        with Image.open(uploaded_file_path) as source_image:
            image = source_image.convert('RGB')

        gt_mask, pred_mask, prob_image = predict(image)
        return image, gt_mask, pred_mask, prob_image

    def _captioned_panel(src, caption):
        """Build one result column: an image with a heading underneath."""
        return ui.div(
            ui.img(src=src, width="100%"),
            ui.h5(caption),
        )

    @output
    @render.ui
    def segmented_result():
        """Render the original image, predicted mask and probability map."""
        image, gt_mask, pred_mask, prob_image = segmented_data()
        # Test the image slot directly rather than comparing the whole tuple
        # with ``==``, which would end up comparing a NumPy mask against None.
        if image is None:
            return "Upload an image to start segmentation."

        # NOTE(review): 256x256 presumably matches the mask resolution
        # produced by the model - confirm.
        resized_image = image.resize((256, 256))
        original_image_data = display_images_as_data_url(resized_image)
        # The mask holds 0/1 values; scale to 0/255 so it shows as black/white.
        predicted_mask_data = display_images_as_data_url(Image.fromarray(pred_mask * 255))
        prob_map_data = display_images_as_data_url(prob_image)

        return ui.div(
            _captioned_panel(original_image_data, "Original Image"),
            _captioned_panel(predicted_mask_data, "Predicted Mask"),
            _captioned_panel(prob_map_data, "Probability Map"),
            style="display: flex;"
        )
# Colors in HEX format used by the inline styles below:
#   Purple: #800080
#   White:  #FFFFFF
#   Orange: #FFA500

# Inline CSS for the sidebar: orange-to-purple gradient with white text.
_SIDEBAR_STYLE = """
background: linear-gradient(135deg, #FFA500 0%, #800080 100%);
color: white;
padding: 20px;
border-radius: 0;
min-height: 100vh;
box-shadow: none;
"""

# Inline CSS for the sidebar layout that holds the sidebar and the results.
_LAYOUT_STYLE = """
background: #f8f9fa;
padding: 20px;
height: 100vh;
"""

# Inline CSS for the page itself: purple-to-white gradient.
_PAGE_STYLE = """
background: linear-gradient(135deg, #800080 0%, #FFFFFF 100%);
height: 100vh;
"""

# Sidebar with the page title and the file-upload control; the input id
# "upload_image" is what the server reads via input.upload_image().
_sidebar = ui.sidebar(
    ui.panel_title("Image Segmentation"),
    ui.input_file("upload_image", "Upload an Image"),
    style=_SIDEBAR_STYLE,
)

# Sidebar plus the main area, which shows the server's "segmented_result".
_layout = ui.layout_sidebar(
    _sidebar,
    ui.output_ui("segmented_result"),
    style=_LAYOUT_STYLE,
)

app_ui = ui.page_fillable(_layout, style=_PAGE_STYLE)

# Bind the UI definition to the server logic.
app = App(app_ui, server)

# Start the app when this file is executed directly.
if __name__ == "__main__":
    app.run()