Trang Dang commited on
Commit
4aaf6ae
1 Parent(s): 218ce85

updare requirements

Browse files
Files changed (2) hide show
  1. requirements.txt +3 -1
  2. run.py +14 -1
requirements.txt CHANGED
@@ -3,4 +3,6 @@ shinyswatch==0.6.0
3
  seaborn==0.12.2
4
  matplotlib==3.7.1
5
  transformers
6
- torch
 
 
 
3
  seaborn==0.12.2
4
  matplotlib==3.7.1
5
  transformers
6
+ torch
7
+ patchify
8
+ PIL
run.py CHANGED
@@ -61,8 +61,21 @@ def pred(src):
61
 
62
  single_patch = Image.fromarray(random_array)
63
  inputs = processor(single_patch, input_points=input_points, return_tensors="pt")
64
-
65
  inputs = {k: v.to(device) for k, v in inputs.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  x = 1
67
  # my_sam_model.eval()
68
  # # forward pass
 
61
 
62
  single_patch = Image.fromarray(random_array)
63
  inputs = processor(single_patch, input_points=input_points, return_tensors="pt")
64
+
65
  inputs = {k: v.to(device) for k, v in inputs.items()}
66
+
67
+ my_mito_model.eval()
68
+
69
+ # forward pass
70
+ with torch.no_grad():
71
+ outputs = my_mito_model(**inputs, multimask_output=False)
72
+
73
+ # apply sigmoid
74
+ single_patch_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
75
+ # convert soft mask to hard mask
76
+ single_patch_prob = single_patch_prob.cpu().numpy().squeeze()
77
+ single_patch_prediction = (single_patch_prob > 0.5).astype(np.uint8)
78
+
79
  x = 1
80
  # my_sam_model.eval()
81
  # # forward pass