File size: 3,157 Bytes
e82a2b7
 
 
 
90baaac
 
 
 
e82a2b7
90baaac
 
2a3e290
90baaac
 
 
 
 
e82a2b7
 
 
 
 
 
 
 
 
90baaac
e82a2b7
 
90baaac
e82a2b7
90baaac
 
e82a2b7
90baaac
 
 
 
e82a2b7
 
 
 
 
 
 
 
90baaac
 
3dd227f
90baaac
 
3dd227f
90baaac
 
e82a2b7
90baaac
 
3dd227f
 
90baaac
3dd227f
90baaac
 
 
 
e82a2b7
 
3dd227f
 
e82a2b7
90baaac
 
 
 
e82a2b7
90baaac
e82a2b7
 
 
90baaac
 
e82a2b7
90baaac
 
31de6d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from shiny import App, ui, render
import base64
from io import BytesIO
from PIL import Image, ImageOps
import numpy as np
import torch
from transformers import SamModel, SamProcessor

# Load the processor and model
# Base SAM (Segment Anything) checkpoint from Hugging Face; its weights are
# then overwritten by the fine-tuned checkpoint below.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")
# Fine-tuned weights (relative path — presumably the app is run from the
# directory containing this file; confirm deployment working directory).
model_path = "mito_model_checkpoint.pth"
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm updates

def preprocess_image(image, target_size=(256, 256)):
    """Shrink *image* so it fits inside *target_size* while keeping its
    aspect ratio (delegates to ``PIL.ImageOps.contain``)."""
    return ImageOps.contain(image, target_size)

def postprocess_mask(mask, threshold=0.95):
    """Binarize a probability mask.

    Pixels strictly above *threshold* become 255, all others 0; the result
    is ``uint8`` so it can be saved directly as a grayscale image.
    """
    binary = np.where(mask > threshold, 255, 0)
    return binary.astype(np.uint8)

def process_image(image_path):
    """Segment the image at *image_path* with the fine-tuned SAM model.

    Returns the thresholded mask encoded as a base64 PNG data URI suitable
    for an ``<img src=...>`` attribute.
    """
    img = preprocess_image(Image.open(image_path).convert("RGB"))
    pixel_array = np.array(img)

    model_inputs = processor(images=pixel_array, return_tensors="pt")
    model_inputs = {name: tensor.to(device) for name, tensor in model_inputs.items()}

    with torch.no_grad():
        prediction = model(**model_inputs, multimask_output=False)

    probabilities = torch.sigmoid(prediction.pred_masks).cpu().numpy()
    # Single-mask output: squeeze away the singleton batch/mask axes, then
    # binarize at a high confidence cutoff.
    mask_array = postprocess_mask(probabilities.squeeze(), threshold=0.95)

    mask_img = Image.fromarray(mask_array, mode="L")
    buf = BytesIO()
    mask_img.save(buf, format="PNG")
    encoded = base64.b64encode(buf.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"

# Page layout: a sidebar holding the file-upload control, and a main panel
# showing the uploaded image and the model's segmentation output.
app_ui = ui.page_fluid(
    ui.layout_sidebar(
        ui.panel_sidebar(
            ui.input_file("image_upload", "Upload Satellite Image", accept=".jpg,.jpeg,.png,.tif")
        ),
        ui.panel_main(
            # NOTE(review): Shiny's output_image/output_ui do not take a label
            # as the second positional argument — confirm these extra strings
            # ("Uploaded Image" / "Segmented Image") are intended and accepted
            # by the installed shiny version.
            ui.output_image("uploaded_image", "Uploaded Image"),
            ui.output_ui("segmented_image", "Segmented Image")  # Use output_ui for HTML content
        )
    )
)

def server(input, output, session):
    """Wire the upload input to the image echo and segmentation outputs."""
    @output
    @render.image
    def uploaded_image():
        # Echo the uploaded file back to the browser. input.image_upload()
        # is falsy until a file is chosen, in which case nothing is rendered.
        file_info = input.image_upload()
        if file_info:
            # input_file yields a list of dicts with a 'datapath' key; the
            # isinstance check also tolerates a bare dict.
            file_path = file_info[0]['datapath'] if isinstance(file_info, list) else file_info['datapath']
            return {'src': file_path}

    @output
    @render.ui  # Use render.ui for direct HTML output
    def segmented_image():
        # Run the model on the uploaded file and embed the resulting mask
        # as an inline base64 <img>; fall back to placeholder text on
        # no upload or any processing error.
        file_info = input.image_upload()
        if file_info:
            try:
                file_path = file_info[0]['datapath'] if isinstance(file_info, list) else file_info['datapath']
                if file_path:
                    base64_img = process_image(file_path)
                    # Return an HTML image tag with the base64 data URI
                    return ui.tags.img(src=base64_img, style="max-width: 100%; height: auto;")
            except Exception as e:
                # Best-effort: log the failure and fall through to the
                # placeholder string below.
                print(f"Error processing image: {e}")
        return "No image processed."

# Create and run the Shiny app
# `app` is the entry point picked up by the runner (e.g. `shiny run app.py`).
app = App(app_ui, server)