Yuanhao Zhai commited on
Commit
959a00c
1 Parent(s): 779bb30
Files changed (6) hide show
  1. README.md +19 -0
  2. demo.py +53 -0
  3. demo/au.jpg +0 -0
  4. demo/tp.jpg +0 -0
  5. engine.py +1 -1
  6. requirements.txt +0 -1
README.md CHANGED
@@ -11,6 +11,10 @@
11
 
12
  This repo contains the MIL-FCN version of our WSCL implementation.
13
 
 
 
 
 
14
  ## 1. Setup
15
  Clone this repo
16
 
@@ -46,6 +50,21 @@ python main.py --load configs/final.yaml --eval --resume checkpoint-path
46
 
47
  We provide our pre-trained checkpoint [here](https://buffalo.box.com/s/2t3eqvwp7ua2ircpdx12sfq04sne4x50).
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  ## Citation
50
  If you feel this project is helpful, please consider citing our paper
51
  ```bibtex
 
11
 
12
  This repo contains the MIL-FCN version of our WSCL implementation.
13
 
14
+ ## 🚨News
15
+
16
+ **03/2024**: add demo script! Check [here](https://github.com/yhZhai/WSCL?tab=readme-ov-file#4-demo) for more details!
17
+
18
  ## 1. Setup
19
  Clone this repo
20
 
 
50
 
51
  We provide our pre-trained checkpoint [here](https://buffalo.box.com/s/2t3eqvwp7ua2ircpdx12sfq04sne4x50).
52
 
53
+
54
+ ## 4. Demo
55
+
56
+ Running our manipulation model on your custom data!
57
+ Before running, please configure your desired input and output path in the `demo.py` file.
58
+
59
+ ```shell
60
+ python demo.py --load configs/final.yaml --resume checkpoint-path
61
+ ```
62
+
63
+ By default, it evaluates all `.jpg` files in the `demo` folder, and saves the
64
+ detection result in `tmp`.
65
+
66
+
67
+
68
  ## Citation
69
  If you feel this project is helpful, please consider citing our paper
70
  ```bibtex
demo.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import albumentations as A
2
+ import cv2
3
+ import torch
4
+ import tqdm
5
+ from albumentations.pytorch.functional import img_to_tensor
6
+ from pathlib import Path
7
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
8
+ from torchvision.utils import draw_segmentation_masks, make_grid, save_image
9
+
10
+ import utils.misc as misc
11
+ from models import get_ensemble_model
12
+ from opt import get_opt
13
+
14
+
15
+ def demo(folder_path, output_path=Path("tmp")):
16
+ opt = get_opt()
17
+ model = get_ensemble_model(opt).to(opt.device)
18
+ misc.resume_from(model, opt.resume)
19
+
20
+ with torch.no_grad():
21
+ for image_path in tqdm.tqdm(folder_path.glob("*.jpg")):
22
+ image = cv2.imread(image_path.as_posix())
23
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
24
+ dsm_image = torch.from_numpy(image).permute(2, 0, 1)
25
+ image_size = image.shape[:2]
26
+ raw_image = img_to_tensor(image)
27
+ image = img_to_tensor(
28
+ image,
29
+ normalize={"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD},
30
+ )
31
+ image = image.to(opt.device).unsqueeze(0)
32
+ outputs = model(image, seg_size=image_size)
33
+ out_map = outputs["ensemble"]["out_map"][0, ...].detach().cpu()
34
+
35
+ overlay = draw_segmentation_masks(
36
+ dsm_image, masks=out_map[0, ...] > opt.mask_threshold
37
+ )
38
+ grid_image = make_grid(
39
+ [
40
+ raw_image,
41
+ (out_map.repeat(3, 1, 1) > opt.mask_threshold).float() * 255,
42
+ overlay / 255.0,
43
+ ],
44
+ padding=5,
45
+ )
46
+ save_image(grid_image, (output_path / image_path.name).as_posix())
47
+
48
+
49
+ if __name__ == "__main__":
50
+ folder_path = Path("demo")
51
+ output_path = Path("tmp")
52
+ output_path.mkdir(exist_ok=True, parents=True)
53
+ demo(folder_path)
demo/au.jpg ADDED
demo/tp.jpg ADDED
engine.py CHANGED
@@ -10,7 +10,7 @@ import prettytable as pt
10
  import torch
11
  import torch.nn as nn
12
  from fast_pytorch_kmeans import KMeans
13
- from pathlib2 import Path
14
  from scipy.stats import hmean
15
  from sklearn import metrics
16
  from termcolor import cprint
 
10
  import torch
11
  import torch.nn as nn
12
  from fast_pytorch_kmeans import KMeans
13
+ from pathlib import Path
14
  from scipy.stats import hmean
15
  from sklearn import metrics
16
  from termcolor import cprint
requirements.txt CHANGED
@@ -10,7 +10,6 @@ opencv_contrib_python==4.5.3.56
10
  opencv_python==4.4.0.46
11
  opencv_python_headless==4.5.3.56
12
  pandas==1.3.5
13
- pathlib2==2.3.5
14
  Pillow==9.4.0
15
  prettytable==2.2.1
16
  pydensecrf==1.0rc2
 
10
  opencv_python==4.4.0.46
11
  opencv_python_headless==4.5.3.56
12
  pandas==1.3.5
 
13
  Pillow==9.4.0
14
  prettytable==2.2.1
15
  pydensecrf==1.0rc2