LemonPit commited on
Commit
e82a2b7
1 Parent(s): 251b56f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -39
app.py CHANGED
@@ -1,46 +1,50 @@
1
- from shiny import App, ui, render, reactive
2
-
3
- import os
 
4
  import numpy as np
5
  import torch
6
- from PIL import Image
7
  from transformers import SamModel, SamProcessor
8
 
9
- # Load the processor and the finetuned model
10
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
11
- model_path = "mito_model_checkpoint.pth"
12
  model = SamModel.from_pretrained("facebook/sam-vit-base")
 
13
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  model.to(device)
16
  model.eval()
17
 
 
 
 
 
 
 
 
 
 
18
  def process_image(image_path):
19
- # Open and prepare the image
20
- image = Image.open(image_path).convert("RGB") # Ensure RGB format for consistency
21
  image_np = np.array(image)
22
-
23
- # Prepare the image for the model using the processor
24
  inputs = processor(images=image_np, return_tensors="pt")
25
  inputs = {k: v.to(device) for k, v in inputs.items()}
26
-
27
- # Perform inference
28
  with torch.no_grad():
29
  outputs = model(**inputs, multimask_output=False)
30
 
31
- # Process the prediction to create a binary mask
32
  pred_masks = torch.sigmoid(outputs.pred_masks).cpu().numpy()
33
- segmented_image = (pred_masks[0] > .99).astype(np.uint8) * 255
34
- print(segmented_image)
35
- # Save the segmented image
36
- root, ext = os.path.splitext(image_path)
37
- output_path = f"{root}_segmented.png"
38
- segmented_image_pil = Image.fromarray(segmented_image.squeeze(), mode="L")
39
- segmented_image_pil.save(output_path)
40
-
41
- return output_path
42
 
43
- # Define the Shiny app UI layout
44
  app_ui = ui.page_fluid(
45
  ui.layout_sidebar(
46
  ui.panel_sidebar(
@@ -48,7 +52,7 @@ app_ui = ui.page_fluid(
48
  ),
49
  ui.panel_main(
50
  ui.output_image("uploaded_image", "Uploaded Image"),
51
- ui.output_image("segmented_image", "Segmented Image")
52
  )
53
  )
54
  )
@@ -59,30 +63,24 @@ def server(input, output, session):
59
  def uploaded_image():
60
  file_info = input.image_upload()
61
  if file_info:
62
- if isinstance(file_info, list):
63
- file_path = file_info[0].get('datapath')
64
- if file_path:
65
- return {'src': file_path}
66
- else:
67
- file_path = file_info.get('datapath')
68
- if file_path:
69
- return {'src': file_path}
70
- return None
71
 
72
  @output
73
- @render.image
74
  def segmented_image():
75
  file_info = input.image_upload()
76
  if file_info:
77
  try:
78
- file_path = file_info[0].get('datapath') if isinstance(file_info, list) else file_info.get('datapath')
79
  if file_path:
80
- segmented_path = process_image(file_path)
81
- return {'src': segmented_path}
 
82
  except Exception as e:
83
  print(f"Error processing image: {e}")
84
- return None
85
 
86
  # Create and run the Shiny app
87
  app = App(app_ui, server)
88
- app.run(port=7860)
 
1
+ from shiny import App, ui, render
2
+ import base64
3
+ from io import BytesIO
4
+ from PIL import Image, ImageOps
5
  import numpy as np
6
  import torch
 
7
  from transformers import SamModel, SamProcessor
8
 
9
+ # Load the processor and model
10
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
 
11
  model = SamModel.from_pretrained("facebook/sam-vit-base")
12
+ model_path = "SAM/mito_model_checkpoint.pth"
13
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  model.to(device)
16
  model.eval()
17
 
18
+ def preprocess_image(image, target_size=(256, 256)):
19
+ """ Resize the image to a standard dimension """
20
+ image = ImageOps.contain(image, target_size)
21
+ return image
22
+
23
+ def postprocess_mask(mask, threshold=0.95):
24
+ """ Apply threshold to clean up mask """
25
+ return (mask > threshold).astype(np.uint8) * 255
26
+
27
  def process_image(image_path):
28
+ image = Image.open(image_path).convert("RGB")
29
+ image = preprocess_image(image) # Resize image before processing
30
  image_np = np.array(image)
31
+
 
32
  inputs = processor(images=image_np, return_tensors="pt")
33
  inputs = {k: v.to(device) for k, v in inputs.items()}
34
+
 
35
  with torch.no_grad():
36
  outputs = model(**inputs, multimask_output=False)
37
 
 
38
  pred_masks = torch.sigmoid(outputs.pred_masks).cpu().numpy()
39
+ # Ensure we only use the first mask and squeeze out any singleton dimensions
40
+ segmented_image = postprocess_mask(pred_masks.squeeze(), threshold=0.95) # Apply postprocessing
41
+
42
+ pil_img = Image.fromarray(segmented_image, mode="L")
43
+ buffered = BytesIO()
44
+ pil_img.save(buffered, format="PNG")
45
+ img_str = base64.b64encode(buffered.getvalue()).decode()
46
+ return f"data:image/png;base64,{img_str}"
 
47
 
 
48
  app_ui = ui.page_fluid(
49
  ui.layout_sidebar(
50
  ui.panel_sidebar(
 
52
  ),
53
  ui.panel_main(
54
  ui.output_image("uploaded_image", "Uploaded Image"),
55
+ ui.output_ui("segmented_image", "Segmented Image") # Use output_ui for HTML content
56
  )
57
  )
58
  )
 
63
  def uploaded_image():
64
  file_info = input.image_upload()
65
  if file_info:
66
+ file_path = file_info[0]['datapath'] if isinstance(file_info, list) else file_info['datapath']
67
+ return {'src': file_path}
 
 
 
 
 
 
 
68
 
69
  @output
70
+ @render.ui # Use render.ui for direct HTML output
71
  def segmented_image():
72
  file_info = input.image_upload()
73
  if file_info:
74
  try:
75
+ file_path = file_info[0]['datapath'] if isinstance(file_info, list) else file_info['datapath']
76
  if file_path:
77
+ base64_img = process_image(file_path)
78
+ # Return an HTML image tag with the base64 data URI
79
+ return ui.tags.img(src=base64_img, style="max-width: 100%; height: auto;")
80
  except Exception as e:
81
  print(f"Error processing image: {e}")
82
+ return "No image processed."
83
 
84
  # Create and run the Shiny app
85
  app = App(app_ui, server)
86
+ app.run()