"""Shiny app that segments an uploaded image with a fine-tuned SAM model.

The user uploads an image; the app shows it next to a binary mask produced by
``process_image`` and written as ``<image stem>_segmented.png`` next to the
uploaded file.
"""
import logging
import os

import numpy as np
import torch
from PIL import Image
from shiny import App, ui, render, reactive
from transformers import SamModel, SamProcessor

logger = logging.getLogger(__name__)

# Load the processor and the fine-tuned model once, at import time.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model_path = "SAM/mito_model_checkpoint.pth"
model = SamModel.from_pretrained("facebook/sam-vit-base")
# NOTE(review): torch.load unpickles the checkpoint, which can execute code;
# only load trusted files (torch>=1.13 supports weights_only=True).
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()


def process_image(image_path, threshold=0.99):
    """Segment the image at *image_path* and save the binary mask as a PNG.

    The image is converted to RGB and run through SAM without point/box
    prompts, requesting a single mask. The mask logits are resized back to
    the original image size, passed through a sigmoid and thresholded.

    Args:
        image_path: Path of the image to segment.
        threshold: Probability above which a pixel is foreground (255);
            everything else becomes background (0).

    Returns:
        Path of the written mask: ``<root>_segmented.png`` where ``<root>`` is
        *image_path* without its extension.
    """
    # Ensure RGB so grayscale/RGBA uploads match what the processor expects.
    with Image.open(image_path) as img:
        image_np = np.array(img.convert("RGB"))

    inputs = processor(images=image_np, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # pred_masks are low-resolution logits on SAM's resized/padded input
    # frame; map them back onto the original image before thresholding so
    # the mask lines up with the upload for any size or aspect ratio.
    logits = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
        binarize=False,
    )[0]
    probs = torch.sigmoid(logits).numpy()
    mask = (probs > threshold).astype(np.uint8) * 255

    root, _ext = os.path.splitext(image_path)
    output_path = f"{root}_segmented.png"
    # A 2-D uint8 array is saved as an 8-bit grayscale ("L") image.
    Image.fromarray(mask.squeeze()).save(output_path)
    return output_path


# Shiny UI: upload control in the sidebar, original and mask in the main panel.
app_ui = ui.page_fluid(
    ui.layout_sidebar(
        ui.panel_sidebar(
            ui.input_file(
                "image_upload",
                "Upload Satellite Image",
                accept=".jpg,.jpeg,.png,.tif",
            ),
        ),
        ui.panel_main(
            # output_image's second positional argument is the CSS width,
            # so captions are rendered as separate headings.
            ui.h4("Uploaded Image"),
            ui.output_image("uploaded_image"),
            ui.h4("Segmented Image"),
            ui.output_image("segmented_image"),
        ),
    ),
)


def _uploaded_path(file_info):
    """Return the server-side path of the first uploaded file, or None."""
    if not file_info:
        return None
    first = file_info[0] if isinstance(file_info, list) else file_info
    return first.get("datapath")


def server(input, output, session):
    """Render the uploaded image and its segmentation mask."""

    @output
    @render.image
    def uploaded_image():
        path = _uploaded_path(input.image_upload())
        return {"src": path} if path else None

    @output
    @render.image
    def segmented_image():
        path = _uploaded_path(input.image_upload())
        if not path:
            return None
        try:
            segmented_path = process_image(path)
        except Exception:
            # Renderer boundary: keep the app alive, but log the traceback
            # instead of losing it.
            logger.exception("Error processing image %s", path)
            return None
        return {"src": segmented_path}


# Create the Shiny app; only start a server when executed as a script so
# `shiny run` can import this module without it blocking on app.run().
app = App(app_ui, server)

if __name__ == "__main__":
    app.run(port=8000)