ONNX
import json
from pathlib import Path
from typing import Any, Dict, Tuple

import numpy as np
import onnxruntime as ort
from skimage import measure

from croplands.io import read_zarr, read_zarr_profile
from croplands.polygonize import polygonize_raster
from croplands.utils import impute_nan, normalize_s2

class CroplandHandler:
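  """Run cropland-segmentation inference with the exported UTAE ONNX model,
  handling zarr pre-processing and polygon post-processing of the predictions."""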

  def __init__(self, input_dir: str, output_dir: str, device: str = "cpu") -> None:

    self.input_dir = Path(input_dir)
    self.output_dir = Path(output_dir)

    
    assert self.input_dir.exists(), "Input directory doesn't exist"
    assert self.output_dir.exists(), "Output directory doesn't exist"
    assert device == "cpu" or device.startswith("cuda"), f"{device} is not a valid device."

    
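    # Load the exported model and pick the execution provider matching the requested device.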
    model_path = "model_repository/utae.onnx"
    provider = "CUDAExecutionProvider" if device.startswith("cuda") else "CPUExecutionProvider"
    self.session = ort.InferenceSession(model_path, providers=[provider])

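    # Acquisition months for each patch; passed to the model as its "batch_positions" input.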
    with open("months_per_patch.json") as f:
      self.dates = json.load(f)
  
  def preprocess(self, file: str) -> Tuple[np.ndarray, Dict, np.ndarray]:

    assert file is not None, "Missing input file for inference"

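    # Read the zarr time series, fill missing values, and normalize the Sentinel-2 bands.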
    file_path = self.input_dir / file
    data = read_zarr(file_path)
    data = impute_nan(data)
    data = normalize_s2(data)
    profile = read_zarr_profile(file_path)
    dates = self.dates[file_path.stem]
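    # Add a leading batch dimension to both the data cube and the date vector.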
    batch = np.expand_dims(data, axis=0)
    dates = np.expand_dims(np.array(dates), axis=0)
    return batch, profile, dates
  
  def postprocess(self, outputs: Any, file: str, profile: Dict, save_raster: bool = False) -> np.ndarray:
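    # Convert the raw session outputs to an array; optionally polygonize the predicted
    # class map and write it to the output directory as (Geo)Parquet.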
    outputs = np.array(outputs)

    if save_raster:
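      # Class map -> binary cropland mask -> connected components -> vector polygons.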
      out_class = np.argmax(outputs[0][0], axis=0)
      out_bin = (out_class != 0).astype(np.uint8)
      components = measure.label(out_bin, connectivity=1)
      gdf = polygonize_raster(out_class, components, tolerance=0.0001,
                              transform=profile["transform"], crs=profile["crs"])
      data_path = self.input_dir / file
      save_path = self.output_dir / (data_path.stem + ".parquet")
      gdf.to_parquet(save_path)

    return outputs

  def predict(self, file: str, save_raster: bool = False) -> np.ndarray:

    # Preprocessing
    batch, profile, dates = self.preprocess(file)
    # Inference
    outputs = self.session.run(None, {"input": batch, "batch_positions": dates})
    # Postprocessing
    outputs = self.postprocess(outputs, file, profile, save_raster)

    return outputs
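

# Minimal usage sketch (illustrative only; the directory, patch, and device names below are placeholders,
# and the patch stem must have an entry in months_per_patch.json):
#   handler = CroplandHandler(input_dir="data/patches", output_dir="predictions", device="cuda:0")
#   outputs = handler.predict("patch_0001.zarr", save_raster=True)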