Trang Dang commited on
Commit
83fe5b0
1 Parent(s): 989d8bc
Files changed (1) hide show
  1. app.py +26 -0
app.py CHANGED
@@ -7,6 +7,8 @@ import seaborn as sns
7
  import shinyswatch
8
 
9
  from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
 
 
10
 
11
  sns.set_theme()
12
 
@@ -42,6 +44,30 @@ def server(input: Inputs, output: Outputs, session: Session):
42
  if input.image_input():
43
  src = input.image_input()[0]['datapath']
44
  img = {"src": src, "width": "500px"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  return img
46
  return None
47
 
 
7
  import shinyswatch
8
 
9
  from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
10
+ from transformers import SamModel, SamConfig, SamProcessor
11
+ import torch
12
 
13
  sns.set_theme()
14
 
 
44
  if input.image_input():
45
  src = input.image_input()[0]['datapath']
46
  img = {"src": src, "width": "500px"}
47
+
48
+ # Load the model configuration
49
+ model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
50
+ processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
51
+
52
+ # Create an instance of the model architecture with the loaded configuration
53
+ my_sam_model = SamModel(config=model_config)
54
+ #Update the model by loading the weights from saved file.
55
+ my_sam_model.load_state_dict(torch.load("sam_model.pth", map_location=torch.device('cpu')))
56
+
57
+ new_image = np.array(Image.open(src))
58
+ inputs = processor(new_image, return_tensors="pt")
59
+ inputs = {k: v.to(device) for k, v in inputs.items()}
60
+ my_sam_model.eval()
61
+ # forward pass
62
+ with torch.no_grad():
63
+ outputs = my_sam_model(**inputs, multimask_output=False)
64
+
65
+ # apply sigmoid
66
+ single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
67
+ # convert soft mask to hard mask
68
+ single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
69
+ single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
70
+
71
  return img
72
  return None
73