Luecke commited on
Commit
0c1bad2
·
1 Parent(s): 72e071a

ready for merging

Browse files
detectree2/predictions/predict.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:53ca357b6e96813becd6f494b040a24286387127134141d77220afa259d28373
3
- size 8857
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:87d72c31226fa90773202d299b6e413ed5330cc5a25a15c44e413c5f15d23178
3
+ size 4749
detectree2/predictions/predict.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # REQUIREMENTS
3
+ """
4
+ !python -m pip -q install torchvision torch
5
+ !python -m pip -q install rasterio
6
+ !python -m pip -q install git+https://github.com/PatBall1/detectree2.git # in order for this to work, you must have installed gdal
7
+ !python -m pip install opencv-python
8
+ !python -m pip install requests
9
+ """
10
+ from detectree2.preprocessing.tiling import tile_data
11
+ from detectree2.models.outputs import project_to_geojson, stitch_crowns, clean_crowns
12
+ from detectree2.models.predict import predict_on_data
13
+ from detectree2.models.train import setup_cfg
14
+ from detectron2.engine import DefaultPredictor
15
+ import rasterio
16
+ import os
17
+ import requests
18
+
19
+ #Somehow this tiles_path where the tilings are stored, only works if the absolute path is provided
20
+ #Do not use relative path
21
+
22
+ #Make sure that tiles_path ends with '/' otherwise the predict_on_data() will not work later
23
+
24
+ def create_tiles(input_path, tile_width, tile_height, tile_buffer):
25
+ img_path = input_path
26
+
27
+ current_directory = os.getcwd()
28
+ tiles_directory = os.path.join(current_directory, "tiles")
29
+ if not os.path.exists(tiles_directory):
30
+ os.makedirs(tiles_directory)
31
+
32
+ data = rasterio.open(img_path)
33
+
34
+ buffer = tile_buffer
35
+ tile_width = tile_width
36
+ tile_height = tile_height
37
+ tile_data(data, tiles_directory, buffer, tile_width, tile_height, dtype_bool = True)
38
+
39
+ return tiles_directory
40
+
41
+ def download_file(url, local_filename):
42
+ with requests.get(url, stream=True) as r:
43
+ r.raise_for_status()
44
+ with open(local_filename, 'wb') as f:
45
+ for chunk in r.iter_content(chunk_size=8192):
46
+ f.write(chunk)
47
+ return local_filename
48
+
49
+ def predict(tile_path, overlap_threshold, confidence_threshold, simplify_value, store_path):
50
+ url = "https://zenodo.org/records/10522461/files/230103_randresize_full.pth"
51
+ trained_model = "./230103_randresize_full.pth"
52
+
53
+ download_file(url=url, local_filename=trained_model)
54
+
55
+ cfg = setup_cfg(update_model=trained_model)
56
+ #cfg.MODEL.DEVICE = "cpu"
57
+ predict_on_data(tile_path, predictor=DefaultPredictor(cfg))
58
+
59
+ project_to_geojson(tile_path, tile_path + "predictions/", tile_path + "predictions_geo/")
60
+ crowns = stitch_crowns(tile_path + "predictions_geo/", 1)
61
+ clean = clean_crowns(crowns, overlap_threshold, confidence=confidence_threshold)
62
+ clean = clean.set_geometry(clean.simplify(simplify_value))
63
+ clean.to_file(store_path + "predicted_delineations.geojson")
64
+
65
+ def run_detectree2(tif_input_path, tile_width=20, tile_height=20, tile_buffer=20, overlap_threshold=0.35, confidence_threshold=0.2, simplify_value=0.2, store_path='./train_outputs/'):
66
+ tile_path = create_tiles(input_path=tif_input_path, tile_width=tile_width, tile_height=tile_height, tile_buffer=tile_buffer)
67
+ predict(tile_path=tile_path, overlap_threshold=overlap_threshold, confidence_threshold=confidence_threshold, simplify_value=simplify_value, store_path=store_path)
68
+
69
+ run_detectree2(tif_input_path='/Users/jonathanseele/ETH/Hackathons/EcoHackathon/input_dataset/GeoData/TreeCrownVectorDataset.tif')