AndresZarta commited on
Commit
f512a51
1 Parent(s): 8db830a
Files changed (2) hide show
  1. app.py +33 -20
  2. requirements.txt +6 -0
app.py CHANGED
@@ -1,50 +1,61 @@
 
 
1
  from shiny import App, render, ui
2
- from shiny.ui import output_image, input_file, panel_main, panel_sidebar, layout_sidebar
3
- import matplotlib.pyplot as plt
4
  import numpy as np
5
- import torch # Assuming the model is a PyTorch model
6
  from PIL import Image
7
  import io
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # Load your pre-trained model
10
- model = torch.load('models/model_checkpoint_trained_on_train.pth')
11
- model.eval()
12
 
13
  def predict(image_file):
14
- # Open image
15
  image = Image.open(image_file)
16
- # Process image as per model requirements (resize, normalize, etc.)
17
- # Convert to tensor, e.g., if model expects tensor input
18
- image_tensor = torch.Tensor(np.array(image) / 255).unsqueeze(0) # adjust processing as needed
 
 
 
19
  with torch.no_grad():
20
- output = model(image_tensor)
21
- # Post-process output, e.g., binarize, segment
22
- segmentation = (output > 0.5).squeeze().numpy() # adjust as needed
 
 
23
  return segmentation
24
 
25
  def server(input, output, session):
26
  @output
27
  @render.image
28
  def segmented_image():
29
- # Check if a file has been uploaded
30
- if not input.image_file:
31
  return None
32
 
33
- # Predict segmentation
34
  uploaded_file = input.image_file()[0]
35
  segmentation = predict(uploaded_file)
36
 
37
- # Create figure to show segmentation
38
  plt.imshow(segmentation, cmap='gray')
39
  plt.axis('off')
40
 
41
- # Save to buffer and return
42
  buf = io.BytesIO()
43
  plt.savefig(buf, format='png')
44
  buf.seek(0)
45
  return buf
46
 
47
- print("Creating app UI")
48
  app_ui = ui.page_fluid(
49
  layout_sidebar(
50
  panel_sidebar(
@@ -55,6 +66,8 @@ app_ui = ui.page_fluid(
55
  ),
56
  )
57
  )
58
- print("Creating app")
59
  app = App(app_ui, server)
60
 
 
 
 
1
+ import torch
2
+ from transformers import SamConfig, SamProcessor, SamModel
3
  from shiny import App, render, ui
 
 
4
  import numpy as np
 
5
  from PIL import Image
6
  import io
7
+ import matplotlib.pyplot as plt
8
+
9
+ # Load the model configuration
10
+ model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
11
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
12
+
13
+ # Create an instance of the model architecture with the loaded configuration
14
+ my_model = SamModel(config=model_config)
15
+
16
+ # Update the model by loading the weights from a saved file
17
+ my_model.load_state_dict(torch.load("models/model_checkpoint_trained_on_train.pth"))
18
 
19
+ # Set model to evaluation mode
20
+ my_model.eval()
 
21
 
22
  def predict(image_file):
23
+ # Open and preprocess the image
24
  image = Image.open(image_file)
25
+ image_array = np.array(image)
26
+
27
+ # Process the image
28
+ inputs = processor(images=image_array, return_tensors="pt")
29
+
30
+ # Make a prediction with the model
31
  with torch.no_grad():
32
+ outputs = my_model(**inputs)
33
+
34
+ # Extract the mask or segmentation map
35
+ segmentation = outputs[0].squeeze().numpy() # Adjust to extract necessary outputs
36
+
37
  return segmentation
38
 
39
  def server(input, output, session):
40
  @output
41
  @render.image
42
  def segmented_image():
43
+ if not input.image_file():
 
44
  return None
45
 
46
+ # Get the uploaded file
47
  uploaded_file = input.image_file()[0]
48
  segmentation = predict(uploaded_file)
49
 
50
+ # Visualize the segmentation
51
  plt.imshow(segmentation, cmap='gray')
52
  plt.axis('off')
53
 
 
54
  buf = io.BytesIO()
55
  plt.savefig(buf, format='png')
56
  buf.seek(0)
57
  return buf
58
 
 
59
  app_ui = ui.page_fluid(
60
  layout_sidebar(
61
  panel_sidebar(
 
66
  ),
67
  )
68
  )
69
+
70
  app = App(app_ui, server)
71
 
72
+ if __name__ == "__main__":
73
+ app.run()
requirements.txt CHANGED
@@ -3,3 +3,9 @@ torch
3
  numpy
4
  pillow
5
  matplotlib
 
 
 
 
 
 
 
3
  numpy
4
  pillow
5
  matplotlib
6
+ segment-geospatial
7
+ groundingdino-py
8
+ leafmap
9
+ localtileserver
10
+ datasets
11
+ transformers