Trang Dang commited on
Commit
d152f7f
1 Parent(s): 3614812
Files changed (2) hide show
  1. app.py +19 -11
  2. run.py +3 -73
app.py CHANGED
@@ -21,9 +21,7 @@ app_ui = ui.page_fillable(
21
  ui.input_file("image_input", "Upload image: ", multiple=True),
22
  ),
23
  ui.output_image("image"),
24
- ui.output_image("image_output"),
25
- ui.output_image("single_patch_prediction"),
26
- ui.output_image("single_patch_prob")
27
  ),
28
  )
29
 
@@ -32,7 +30,6 @@ def server(input: Inputs, output: Outputs, session: Session):
32
  @output
33
  @render.image
34
  def image():
35
- here = Path(__file__).parent
36
  if input.image_input():
37
  # print(input.image_input())
38
  src = input.image_input()[0]['datapath']
@@ -41,15 +38,26 @@ def server(input: Inputs, output: Outputs, session: Session):
41
  return None
42
 
43
  @output
44
- @render.image
45
- def image_output():
46
- here = Path(__file__).parent
47
  if input.image_input():
48
  src = input.image_input()[0]['datapath']
49
- img = {"src": src, "width": "500px"}
50
- x = run.pred(src)
51
- print(x)
52
- return img
 
 
 
 
 
 
 
 
 
 
 
 
53
  return None
54
 
55
 
 
21
  ui.input_file("image_input", "Upload image: ", multiple=True),
22
  ),
23
  ui.output_image("image"),
24
+ ui.output_plot("plot_output"),
 
 
25
  ),
26
  )
27
 
 
30
  @output
31
  @render.image
32
  def image():
 
33
  if input.image_input():
34
  # print(input.image_input())
35
  src = input.image_input()[0]['datapath']
 
38
  return None
39
 
40
  @output
41
+ @render.plot
42
+ def plot_output():
 
43
  if input.image_input():
44
  src = input.image_input()[0]['datapath']
45
+ single_patch_prob, single_patch_prediction = run.pred(src)
46
+ fig, axes = plt.subplots(1, 2, figsize=(15, 5))
47
+
48
+ axes[0].imshow(single_patch_prob, cmap='gray')
49
+ axes[0].set_title("Probability Map")
50
+
51
+ im = axes[1].imshow(single_patch_prediction)
52
+ axes[1].set_title("Prediction")
53
+ cbar = fig.colorbar(im, ax=axes[1])
54
+
55
+ for ax in axes:
56
+ ax.set_xticks([])
57
+ ax.set_yticks([])
58
+ ax.set_xticklabels([])
59
+ ax.set_yticklabels([])
60
+ return fig
61
  return None
62
 
63
 
run.py CHANGED
@@ -5,22 +5,7 @@ import app
5
  import os
6
  from PIL import Image
7
 
8
-
9
- # def customized_patchify(large_image):
10
- # print(len(large_image), len(large_image[0]))
11
- # patch_size = 256
12
- # all_img_patches = []
13
- # patches_img = patchify(large_image, (patch_size, patch_size), step=256) #Step=256 for 256 patches means no overlap
14
- # for i in range(patches_img.shape[0]):
15
- # for j in range(patches_img.shape[1]):
16
- # single_patch_img = patches_img[i,j,:,:]
17
- # all_img_patches.append(single_patch_img)
18
- # # image = np.array(all_img_patches)
19
- # return all_img_patches
20
-
21
-
22
  def pred(src):
23
- # os.environ['HUGGINGFACE_HUB_HOME'] = './.cache'
24
  # Load the model configuration
25
  cache_dir = "/code/cache"
26
 
@@ -31,8 +16,8 @@ def pred(src):
31
  my_sam_model = SamModel(config=model_config)
32
  # #Update the model by loading the weights from saved file.
33
  my_sam_model.load_state_dict(torch.load("sam_model.pth", map_location=torch.device('cpu')))
34
- print(src)
35
- new_image = np.array(Image.open(src))
36
  inputs = processor(new_image, return_tensors="pt")
37
  my_sam_model.eval()
38
 
@@ -45,60 +30,5 @@ def pred(src):
45
  # # convert soft mask to hard mask
46
  single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
47
  single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
48
- # patches = customized_patchify(new_image)
49
-
50
- # # Define the size of your array
51
- # array_size = 256
52
-
53
- # # Define the size of your grid
54
- # grid_size = 10
55
-
56
- # # Generate the grid points
57
- # x = np.linspace(0, array_size-1, grid_size)
58
- # y = np.linspace(0, array_size-1, grid_size)
59
-
60
- # # Generate a grid of coordinates
61
- # xv, yv = np.meshgrid(x, y)
62
-
63
- # # Convert the numpy arrays to lists
64
- # xv_list = xv.tolist()
65
- # yv_list = yv.tolist()
66
-
67
- # # Combine the x and y coordinates into a list of list of lists
68
- # input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv_list, yv_list)]
69
- # input_points = torch.tensor(input_points).view(1, 1, grid_size*grid_size, 2)
70
-
71
- # i, j = 1, 2
72
-
73
- # # Selectelected patch for segmentation
74
- # random_array = patches[i, j]
75
 
76
- # single_patch = Image.fromarray(random_array)
77
- # inputs = processor(single_patch, input_points=input_points, return_tensors="pt")
78
-
79
- # inputs = {k: v.to(device) for k, v in inputs.items()}
80
-
81
- # my_mito_model.eval()
82
-
83
- # # forward pass
84
- # with torch.no_grad():
85
- # outputs = my_mito_model(**inputs, multimask_output=False)
86
-
87
- # # apply sigmoid
88
- # single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
89
- # # convert soft mask to hard mask
90
- # single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
91
- # single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
92
-
93
- x = 1
94
- # my_sam_model.eval()
95
- # # forward pass
96
- # with torch.no_grad():
97
- # outputs = my_sam_model(**inputs, multimask_output=False)
98
-
99
- # # apply sigmoid
100
- # single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
101
- # # convert soft mask to hard mask
102
- # single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
103
- # single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
104
- return x
 
5
  import os
6
  from PIL import Image
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def pred(src):
 
9
  # Load the model configuration
10
  cache_dir = "/code/cache"
11
 
 
16
  my_sam_model = SamModel(config=model_config)
17
  # #Update the model by loading the weights from saved file.
18
  my_sam_model.load_state_dict(torch.load("sam_model.pth", map_location=torch.device('cpu')))
19
+
20
+ new_image = np.array(Image.open(src).convert("RGB"))
21
  inputs = processor(new_image, return_tensors="pt")
22
  my_sam_model.eval()
23
 
 
30
  # # convert soft mask to hard mask
31
  single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
32
  single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ return single_patch_prob, single_patch_prediction