jeremyLE-Ekimetrics commited on
Commit
4debc65
1 Parent(s): 86735e0
Files changed (2) hide show
  1. biomap/inference.py +8 -1
  2. biomap/streamlit_app.py +1 -4
biomap/inference.py CHANGED
@@ -13,6 +13,7 @@ preprocess = T.Compose(
13
  ]
14
  )
15
 
 
16
  def inference(images, model):
17
  logging.info("Inference on Images")
18
  x = torch.stack([preprocess(image) for image in images]).cpu()
@@ -25,6 +26,10 @@ def inference(images, model):
25
  "img": x[i].detach().cpu(),
26
  "linear_preds": linear_pred[i].detach().cpu(),
27
  } for i in range(x.shape[0])]
 
 
 
 
28
  return outputs
29
 
30
 
@@ -32,6 +37,7 @@ if __name__ == "__main__":
32
  import hydra
33
  from model import LitUnsupervisedSegmenter
34
  from utils_gee import extract_img, transform_ee_img
 
35
  latitude = 2.98
36
  longitude = 48.81
37
  start_date = '2020-03-20'
@@ -49,7 +55,8 @@ if __name__ == "__main__":
49
  cfg = hydra.compose(config_name="my_train_config.yml")
50
 
51
  # Load the model
52
- model_path = "checkpoint/model/model.pt"
 
53
  saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
54
 
55
  nbclasses = cfg.dir_dataset_n_classes
 
13
  ]
14
  )
15
 
16
+ import numpy as np
17
  def inference(images, model):
18
  logging.info("Inference on Images")
19
  x = torch.stack([preprocess(image) for image in images]).cpu()
 
26
  "img": x[i].detach().cpu(),
27
  "linear_preds": linear_pred[i].detach().cpu(),
28
  } for i in range(x.shape[0])]
29
+
30
+ # water to natural green
31
+ for output in outputs:
32
+ output["linear_preds"] = torch.where(output["linear_preds"] == 5, 3, output["linear_preds"])
33
  return outputs
34
 
35
 
 
37
  import hydra
38
  from model import LitUnsupervisedSegmenter
39
  from utils_gee import extract_img, transform_ee_img
40
+ import os
41
  latitude = 2.98
42
  longitude = 48.81
43
  start_date = '2020-03-20'
 
55
  cfg = hydra.compose(config_name="my_train_config.yml")
56
 
57
  # Load the model
58
+
59
+ model_path = os.path.join(os.path.dirname(__file__), "checkpoint/model/model.pt")
60
  saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
61
 
62
  nbclasses = cfg.dir_dataset_n_classes
biomap/streamlit_app.py CHANGED
@@ -64,6 +64,7 @@ def app(model):
64
  st.markdown("<p style='text-align: center;'>The segmentation model is an association of UNet and DinoV1 trained on the dataset CORINE. Land use is divided into 6 differents classes : Each class is assigned a GBS score from 0 to 1</p>", unsafe_allow_html=True)
65
  st.markdown("<p style='text-align: center;'>Buildings : 0.1 | Infrastructure : 0.1 | Cultivation : 0.4 | Wetland : 0.9 | Water : 0.9 | Natural green : 1 </p>", unsafe_allow_html=True)
66
  st.markdown("<p style='text-align: center;'>The score is then averaged on the full image.</p>", unsafe_allow_html=True)
 
67
  if st.session_state["submit"]:
68
  fig = inference_on_location(model, st.session_state["lat"], st.session_state["long"], st.session_state["start_date"], st.session_state["end_date"], st.session_state["segment_interval"])
69
  st.session_state["infered"] = True
@@ -76,10 +77,6 @@ def app(model):
76
 
77
  if st.session_state["infered"]:
78
  st.plotly_chart(st.session_state["previous_fig"], use_container_width=True)
79
-
80
-
81
-
82
-
83
 
84
  col_1, col_2 = st.columns([0.5, 0.5])
85
  with col_1:
 
64
  st.markdown("<p style='text-align: center;'>The segmentation model is an association of UNet and DinoV1 trained on the dataset CORINE. Land use is divided into 6 differents classes : Each class is assigned a GBS score from 0 to 1</p>", unsafe_allow_html=True)
65
  st.markdown("<p style='text-align: center;'>Buildings : 0.1 | Infrastructure : 0.1 | Cultivation : 0.4 | Wetland : 0.9 | Water : 0.9 | Natural green : 1 </p>", unsafe_allow_html=True)
66
  st.markdown("<p style='text-align: center;'>The score is then averaged on the full image.</p>", unsafe_allow_html=True)
67
+
68
  if st.session_state["submit"]:
69
  fig = inference_on_location(model, st.session_state["lat"], st.session_state["long"], st.session_state["start_date"], st.session_state["end_date"], st.session_state["segment_interval"])
70
  st.session_state["infered"] = True
 
77
 
78
  if st.session_state["infered"]:
79
  st.plotly_chart(st.session_state["previous_fig"], use_container_width=True)
 
 
 
 
80
 
81
  col_1, col_2 = st.columns([0.5, 0.5])
82
  with col_1: