Trang Dang committed on
Commit 915d664
1 Parent(s): 70f7a82
Files changed (2)
  1. app.py +3 -31
  2. run.py +27 -22
app.py CHANGED
@@ -1,10 +1,10 @@
 from pathlib import Path
 from typing import List, Dict, Tuple
 import matplotlib.colors as mpl_colors
-import os
 import pandas as pd
 import seaborn as sns
 import shinyswatch
+import run

 from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
 from transformers import SamModel, SamConfig, SamProcessor
@@ -46,36 +46,8 @@ def server(input: Inputs, output: Outputs, session: Session):
     if input.image_input():
         src = input.image_input()[0]['datapath']
         img = {"src": src, "width": "500px"}
-
-        # Specify the cache directory
-        transformers_cache_dir = "/usr/local/lib/python3.9/site-packages/transformers"
-
-        # Set the TRANSFORMERS_CACHE environment variable
-        os.environ["TRANSFORMERS_CACHE"] = transformers_cache_dir
-
-        # Load the model configuration
-        model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
-        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
-
-        # Create an instance of the model architecture with the loaded configuration
-        my_sam_model = SamModel(config=model_config)
-        #Update the model by loading the weights from saved file.
-        my_sam_model.load_state_dict(torch.load("sam_model.pth", map_location=torch.device('cpu')))
-
-        new_image = np.array(Image.open(src))
-        inputs = processor(new_image, return_tensors="pt")
-        inputs = {k: v.to(device) for k, v in inputs.items()}
-        # my_sam_model.eval()
-        # # forward pass
-        # with torch.no_grad():
-        #     outputs = my_sam_model(**inputs, multimask_output=False)
-
-        # # apply sigmoid
-        # single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
-        # # convert soft mask to hard mask
-        # single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
-        # single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
-
+        x = run.pred(src)
+        print(x)
         return img
     return None
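One thing the new layout leaves open: run.pred() still rebuilds the processor and model on every upload. A minimal sketch of hoisting that one-time setup to module scope in run.py, assuming the same facebook/sam-vit-base config and sam_model.pth checkpoint the commit uses:

# Sketch (not part of this commit): do the expensive setup once at import,
# so each call to pred() only does per-image work.
import torch
from transformers import SamModel, SamConfig, SamProcessor

device = torch.device("cpu")  # the commit targets CPU via map_location

model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
my_sam_model = SamModel(config=model_config)
my_sam_model.load_state_dict(torch.load("sam_model.pth", map_location=device))
my_sam_model.eval()  # inference mode; pred() reuses this shared instance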
run.py CHANGED
@@ -1,26 +1,33 @@
-# from transformers import SamModel, SamConfig, SamProcessor
-# import torch
-# import numpy as np
-# import matplotlib.pyplot as plt
-# import app
-
-# # Load the model configuration
-# model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
-# processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
-
-# # Create an instance of the model architecture with the loaded configuration
-# my_sam_model = SamModel(config=model_config)
-# #Update the model by loading the weights from saved file.
-# my_sam_model.load_state_dict(torch.load("sam_model.pth", map_location=torch.device('cpu')))
-
-# load image
-# test_image = {"src": src, "width": "500px"}
-
-import os
-
-# Specify the cache directory
-transformers_cache_dir = "/usr/local/lib/python3.9/site-packages/transformers"
-
-# Set the TRANSFORMERS_CACHE environment variable
-os.environ["TRANSFORMERS_CACHE"] = transformers_cache_dir
+from transformers import SamModel, SamConfig, SamProcessor
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from PIL import Image  # needed for Image.open below; importing app here would be circular, since app imports run
+
+device = torch.device('cpu')  # CPU-only inference, matching map_location below
+
+def pred(src):
+    # Load the model configuration
+    model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
+    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+
+    # Create an instance of the model architecture with the loaded configuration
+    my_sam_model = SamModel(config=model_config)
+    # Update the model by loading the weights from the saved file.
+    my_sam_model.load_state_dict(torch.load("sam_model.pth", map_location=device))
+
+    new_image = np.array(Image.open(src))
+    inputs = processor(new_image, return_tensors="pt")
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+    x = 1  # placeholder result while the forward pass below stays disabled
+    # my_sam_model.eval()
+    # # forward pass
+    # with torch.no_grad():
+    #     outputs = my_sam_model(**inputs, multimask_output=False)
+
+    # # apply sigmoid
+    # single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
+    # # convert soft mask to hard mask
+    # single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
+    # single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
+    return x
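As committed, pred() returns the placeholder 1 because the forward pass is commented out. Building on the module-scope setup sketched above, an enabled version of pred would follow the commit's own commented-out lines (a sketch, not the committed code):

from PIL import Image
import numpy as np

def pred(src):
    # Per-image work only; processor, my_sam_model, and device come from
    # the module-scope setup sketched after the app.py diff.
    new_image = np.array(Image.open(src))
    inputs = processor(new_image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # forward pass (the block the commit leaves commented out)
    with torch.no_grad():
        outputs = my_sam_model(**inputs, multimask_output=False)

    # apply sigmoid to the mask logits, then threshold to a hard 0/1 mask
    single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
    single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
    return single_patch_prediction  # binary mask instead of the placeholder 1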