Spaces:
Build error
Build error
Trang Dang
commited on
Commit
•
83fe5b0
1
Parent(s):
989d8bc
validate
Browse files
app.py
CHANGED
@@ -7,6 +7,8 @@ import seaborn as sns
|
|
7 |
import shinyswatch
|
8 |
|
9 |
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
|
|
|
|
|
10 |
|
11 |
sns.set_theme()
|
12 |
|
@@ -42,6 +44,30 @@ def server(input: Inputs, output: Outputs, session: Session):
|
|
42 |
if input.image_input():
|
43 |
src = input.image_input()[0]['datapath']
|
44 |
img = {"src": src, "width": "500px"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
return img
|
46 |
return None
|
47 |
|
|
|
7 |
import shinyswatch
|
8 |
|
9 |
from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
|
10 |
+
from transformers import SamModel, SamConfig, SamProcessor
|
11 |
+
import torch
|
12 |
|
13 |
sns.set_theme()
|
14 |
|
|
|
44 |
if input.image_input():
|
45 |
src = input.image_input()[0]['datapath']
|
46 |
img = {"src": src, "width": "500px"}
|
47 |
+
|
48 |
+
# Load the model configuration
|
49 |
+
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
|
50 |
+
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
|
51 |
+
|
52 |
+
# Create an instance of the model architecture with the loaded configuration
|
53 |
+
my_sam_model = SamModel(config=model_config)
|
54 |
+
#Update the model by loading the weights from saved file.
|
55 |
+
my_sam_model.load_state_dict(torch.load("sam_model.pth", map_location=torch.device('cpu')))
|
56 |
+
|
57 |
+
new_image = np.array(Image.open(src))
|
58 |
+
inputs = processor(new_image, return_tensors="pt")
|
59 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
60 |
+
my_sam_model.eval()
|
61 |
+
# forward pass
|
62 |
+
with torch.no_grad():
|
63 |
+
outputs = my_sam_model(**inputs, multimask_output=False)
|
64 |
+
|
65 |
+
# apply sigmoid
|
66 |
+
single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
|
67 |
+
# convert soft mask to hard mask
|
68 |
+
single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
|
69 |
+
single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
|
70 |
+
|
71 |
return img
|
72 |
return None
|
73 |
|