File size: 6,036 Bytes
5c718d1
 
 
9fcd62f
5dd3935
5c718d1
9fcd62f
 
5c718d1
5dd3935
5c718d1
5dd3935
 
9fcd62f
 
5c718d1
9fcd62f
709a47d
5c718d1
 
 
 
 
 
 
 
 
 
 
 
5dd3935
 
 
 
 
 
 
 
 
 
 
 
 
 
5c718d1
 
 
 
5dd3935
 
5c718d1
 
 
1527861
 
 
9fcd62f
5dd3935
5c718d1
5dd3935
 
 
 
 
5c718d1
5dd3935
 
709a47d
851dbaf
5dd3935
 
 
9fcd62f
709a47d
5c718d1
 
 
 
 
 
 
 
 
 
 
 
5dd3935
 
5c718d1
 
 
 
 
5dd3935
 
 
 
 
 
 
 
851dbaf
709a47d
851dbaf
9fcd62f
5c718d1
 
 
 
 
5dd3935
5c718d1
5dd3935
 
 
 
 
5c718d1
 
 
 
 
 
 
 
 
 
5dd3935
5c718d1
 
 
5dd3935
5c718d1
5dd3935
5c718d1
5dd3935
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch.multiprocessing
import torchvision.transforms as T
import numpy as np
from utils import transform_to_pil, compute_biodiv_score, plot_imgs_labels, plot_image
from utils_gee import get_image
from dateutil.relativedelta import relativedelta

from model import LitUnsupervisedSegmenter
import datetime
import matplotlib as mpl
from joblib import Parallel, cpu_count, delayed
import logging
from inference import inference
import streamlit as st
import cv2

@st.cache_data(hash_funcs={LitUnsupervisedSegmenter: lambda dt: dt.name})
def inference_on_location(model, latitude=48.81, longitude=2.98, start_date=2020, end_date=2022, how="year"):
    """Run inference on a time series of images at one location and plot it.

    Consecutive date windows are built starting on January 1st of
    ``start_date``. One image per window is fetched with ``get_image``, the
    batch is passed to ``inference``, a biodiversity score is computed for
    each output, and everything is rendered with ``plot_imgs_labels``.
    The call is wrapped in Streamlit's ``st.cache_data``.

    Args:
        model: the model forwarded to ``inference``. For the Streamlit cache
            a ``LitUnsupervisedSegmenter`` is hashed by its ``name`` attribute.
        latitude (float): latitude of the landscape (anything ``float()`` accepts).
        longitude (float): longitude of the landscape (anything ``float()`` accepts).
        start_date (int | str): first year of the time lapse.
        end_date (int | str): last year of the time lapse; must be strictly
            greater than ``start_date``.
        how (str): step between two images: ``"month"``, ``"2months"`` or ``"year"``.

    Returns:
        The figure returned by ``plot_imgs_labels``.

    Raises:
        ValueError: if ``how`` is not a supported interval, or if ``end_date``
            is not strictly greater than ``start_date``.
    """
    logging.info("Running Inference on location")
    logging.info(f"latitude : {latitude} & longitude : {longitude}")
    logging.info(f"start date : {start_date} & end_date : {end_date}")
    logging.info(f"Prediction on intervale : {how}")
    if how == "month":
        delta_month = 1
    elif how == "2months":
        delta_month = 2
    elif how == "year":
        # NOTE(review): a "year" step advances by 11 months, not 12 — confirm
        # this is intentional before changing it.
        delta_month = 11
    else:
        raise ValueError("Wrong interval")

    # Normalise both years once so that string years ("2020") work for the
    # start date as well as for the end date.
    start_year = int(start_date)
    end_year = int(end_date)
    if end_year <= start_year:
        # Explicit raise instead of `assert`: assertions are stripped under -O.
        raise ValueError("end date must be stricly higher than start date")
    location = [float(latitude), float(longitude)]

    # Build the window boundaries: keep stepping until the last boundary
    # reaches (or passes) January 1st of the end year.
    last_boundary = datetime.datetime(end_year, 1, 1, 0, 0, 0)
    dates = [datetime.datetime(start_year, 1, 1, 0, 0, 0)]
    while dates[-1] < last_boundary:
        dates.append(dates[-1] + relativedelta(months=delta_month))

    dates = [d.strftime("%Y-%m-%d") for d in dates]

    n_jobs = min([12, len(dates)])

    # One image per (window start, window end) pair, fetched in threads.
    # N boundaries give N-1 windows, so `dates` has one more entry than `all_image`.
    all_image = Parallel(n_jobs=n_jobs, prefer="threads")(
        delayed(get_image)(location, d1, d2) for d1, d2 in zip(dates[:-1], dates[1:])
    )
    outputs = inference(np.array(all_image), model)

    logging.info("Calculating Biodiversity Scores...")
    scores, scores_details = map(
        list,
        zip(*[compute_biodiv_score(output["linear_preds"].detach().numpy()) for output in outputs]),
    )
    logging.info(f"Calculated Biodiversity Score : {scores}")

    # transform_to_pil yields three values per output; the middle one is not used here.
    imgs, _, labeled_imgs = map(list, zip(*[transform_to_pil(output) for output in outputs]))

    images = [np.asarray(img) for img in imgs]
    labeled_imgs = [np.asarray(img) for img in labeled_imgs]
    title = f"TimeLapse at location ({location[0]:.2f},{location[1]:.2f}) between {start_date} and {end_date}"
    fig = plot_imgs_labels(dates, images, labeled_imgs, scores_details, scores, title=title)
    return fig

@st.cache_data(hash_funcs={LitUnsupervisedSegmenter: lambda dt: dt.name})
def inference_on_location_and_month(model, latitude = 48.81, longitude = 2.98, start_date = '2020-03-20'):
    """Run inference on a single one-month window at one location and plot it.

    The window starts at ``start_date`` and ends one calendar month later.
    A single image is fetched with ``get_image``, passed to ``inference``,
    scored with ``compute_biodiv_score`` and rendered with ``plot_image``.
    The call is wrapped in Streamlit's ``st.cache_data``.

    Args:
        model: the model forwarded to ``inference``. For the Streamlit cache
            a ``LitUnsupervisedSegmenter`` is hashed by its ``name`` attribute.
        latitude (float): latitude of the landscape (anything ``float()`` accepts).
        longitude (float): longitude of the landscape (anything ``float()`` accepts).
        start_date (str): first day of the window, formatted ``YYYY-MM-DD``.

    Returns:
        The figure returned by ``plot_image``.
    """
    logging.info("Running Inference on location and month")
    logging.info(f"latitude : {latitude} & longitude : {longitude}")
    location = [float(latitude), float(longitude)]

    # The window ends exactly one calendar month after it starts.
    date_format = "%Y-%m-%d"
    window_start = datetime.datetime.strptime(start_date, date_format)
    end_date = (window_start + relativedelta(months=1)).strftime(date_format)

    # Fetch the single image and run the model on a batch of one.
    tile = get_image(location, start_date, end_date)
    outputs = inference(np.array([tile]), model)

    logging.info("Calculating Biodiversity Score...")
    prediction = outputs[0]
    linear_preds = prediction["linear_preds"].detach().numpy()
    score, score_details = compute_biodiv_score(linear_preds)
    logging.info(f"Calculated Biodiversity Score : {score}")

    # transform_to_pil yields three values; the middle one is not used here.
    img, _, labeled_img = transform_to_pil(prediction)

    title = f"Prediction at location ({location[0]:.2f},{location[1]:.2f})  at {start_date}"
    return plot_image(
        [start_date],
        [np.asarray(img)],
        [np.asarray(labeled_img)],
        [score_details],
        [score],
        title=title,
    )


if __name__ == "__main__":
    # Manual smoke test: configure logging, load the hydra config and the
    # model checkpoint, then run one time-lapse inference with the defaults.
    import sys

    import hydra

    # `logging.basicConfig` only honours `encoding` when it creates the
    # FileHandler itself (from `filename`); with explicit `handlers` the
    # argument has no effect, so the encoding is set on the handler directly.
    file_handler = logging.FileHandler(filename='biomap.log', encoding='utf-8')
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    handlers = [file_handler, stdout_handler]

    logging.basicConfig(
        handlers=handlers,
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    # Initialize hydra with configs
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")
    logging.info(f"config : {cfg}")

    # Build the model from the config, then load its weights on CPU.
    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    logging.info("Model Initialiazed")

    model_path = "biomap/checkpoint/model/model.pt"
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    model.load_state_dict(saved_state_dict)

    logging.info("Model Loaded")
    # inference_on_location_and_month(model)
    inference_on_location(model)