milestone4 / app.py
LemonPit's picture
Update app.py
90baaac verified
raw
history blame
No virus
2.96 kB
from shiny import App, ui, render, reactive
import os
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
# Fine-tuned SAM: the processor and base architecture come from the
# "facebook/sam-vit-base" Hub checkpoint, the weights from a local file.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "SAM/mito_model_checkpoint.pth"
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host;
# the model is moved to `device` afterwards.
# NOTE: torch.load unpickles the file, so only point model_path at trusted checkpoints.
model.load_state_dict(
    torch.load(model_path, map_location=torch.device('cpu'))
)
# nn.Module.to returns the module itself, so eval mode can be chained.
model.to(device).eval()
def process_image(image_path, threshold=0.99):
    """Segment an image with the fine-tuned SAM model and save the binary mask.

    The mask is written as a single-channel 8-bit PNG next to the input file,
    named ``<name>_segmented.png``. Pixels whose predicted probability
    (sigmoid of the mask logit) exceeds *threshold* are 255; all others are 0.

    Args:
        image_path: Path of the image to segment.
        threshold: Probability cut-off applied after the sigmoid. The default
            of 0.99 matches the value that was previously hard-coded.

    Returns:
        The path of the saved mask image.
    """
    # The context manager releases the file handle even for formats
    # (e.g. multi-frame TIFF) that PIL would otherwise keep open.
    with Image.open(image_path) as img:
        image_np = np.array(img.convert("RGB"))  # ensure 3-channel RGB input

    # Prepare the image for the model using the processor.
    inputs = processor(images=image_np, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # pred_masks are the decoder's low-resolution masks over the processor's
    # resized + padded input. post_process_masks removes the padding and
    # upsamples them to the original image size, so the mask lines up with
    # the uploaded image. binarize=False keeps the raw logits so the sigmoid
    # threshold below still applies.
    mask_logits = processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
        binarize=False,
    )[0]
    probs = torch.sigmoid(mask_logits).numpy()
    segmented_image = (probs > threshold).astype(np.uint8) * 255

    # Save the mask alongside the input image.
    root, _ = os.path.splitext(image_path)
    output_path = f"{root}_segmented.png"
    # squeeze drops the singleton prompt/mask axes (multimask_output=False
    # yields a single mask), leaving a 2-D array for a grayscale image.
    Image.fromarray(segmented_image.squeeze(), mode="L").save(output_path)
    return output_path
# Define the Shiny app UI layout.
# ui.output_image's second positional argument is the image's CSS *width*, not
# a caption. The panel titles are therefore rendered as separate headings.
app_ui = ui.page_fluid(
    ui.layout_sidebar(
        ui.panel_sidebar(
            ui.input_file(
                "image_upload",
                "Upload Satellite Image",
                accept=".jpg,.jpeg,.png,.tif,.tiff",
            ),
        ),
        ui.panel_main(
            ui.h4("Uploaded Image"),
            ui.output_image("uploaded_image"),
            ui.h4("Segmented Image"),
            ui.output_image("segmented_image"),
        ),
    ),
)
def server(input, output, session):
    """Wire the ``image_upload`` input to the two image outputs.

    ``uploaded_image`` displays the uploaded file as-is. ``segmented_image``
    runs ``process_image`` on it and displays the saved mask. Both render
    nothing until a file has been uploaded.
    """
    import traceback

    def _upload_path():
        """Return the temp path of the current upload, or None if there is none."""
        file_info = input.image_upload()
        if not file_info:
            return None
        # The upload is handled as a list of file-info dicts; a bare dict is
        # accepted too.
        first = file_info[0] if isinstance(file_info, list) else file_info
        return first.get('datapath') or None

    @output
    @render.image
    def uploaded_image():
        file_path = _upload_path()
        return {'src': file_path} if file_path else None

    @output
    @render.image
    def segmented_image():
        file_path = _upload_path()
        if not file_path:
            return None
        try:
            segmented_path = process_image(file_path)
        except Exception as e:
            # UI boundary: report the failure (with traceback) and render
            # nothing rather than breaking the session.
            print(f"Error processing image: {e}")
            traceback.print_exc()
            return None
        return {'src': segmented_path}
# Create the Shiny app. `app` stays a module-level name so tools such as
# `shiny run app.py` can import it.
app = App(app_ui, server)

# Start a server only when executed directly (`python app.py`). This avoids
# launching a blocking server as a side effect of importing the module.
if __name__ == "__main__":
    app.run(port=8000)