jeremyLE-Ekimetrics commited on
Commit
f49b1cc
1 Parent(s): f186d18

Update biomap/plot_functions.py

Browse files
Files changed (1) hide show
  1. biomap/plot_functions.py +777 -777
biomap/plot_functions.py CHANGED
@@ -1,778 +1,778 @@
1
- from PIL import Image
2
-
3
- import hydra
4
- import matplotlib as mpl
5
- from utils import prep_for_plot
6
-
7
- import torch.multiprocessing
8
- import torchvision.transforms as T
9
- # import matplotlib.pyplot as plt
10
- from model import LitUnsupervisedSegmenter
11
- colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
12
- class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
13
- cmap = mpl.colors.ListedColormap(colors)
14
- #from train_segmentation import LitUnsupervisedSegmenter, cmap
15
-
16
- from utils_gee import extract_img, transform_ee_img
17
-
18
- import plotly.graph_objects as go
19
- import plotly.express as px
20
- import numpy as np
21
- from plotly.subplots import make_subplots
22
-
23
- import os
24
- os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
25
-
26
-
27
- colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
28
- class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
29
- scores_init = [2,3,4,3,1,4,0]
30
-
31
- # Import model configs
32
- hydra.initialize(config_path="configs", job_name="corine")
33
- cfg = hydra.compose(config_name="my_train_config.yml")
34
-
35
- nbclasses = cfg.dir_dataset_n_classes
36
-
37
- # Load Model
38
- model_path = "checkpoint/model/model.pt"
39
- saved_state_dict = torch.load(model_path,map_location=torch.device('cpu'))
40
-
41
- model = LitUnsupervisedSegmenter(nbclasses, cfg)
42
- model.load_state_dict(saved_state_dict)
43
-
44
- from PIL import Image
45
-
46
- import hydra
47
-
48
- from utils import prep_for_plot
49
-
50
- import torch.multiprocessing
51
- import torchvision.transforms as T
52
- # import matplotlib.pyplot as plt
53
-
54
- from model import LitUnsupervisedSegmenter
55
-
56
- from utils_gee import extract_img, transform_ee_img
57
-
58
- import plotly.graph_objects as go
59
- import plotly.express as px
60
- import numpy as np
61
- from plotly.subplots import make_subplots
62
-
63
- import os
64
- os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
65
-
66
-
67
- colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
68
- cmap = mpl.colors.ListedColormap(colors)
69
- class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
70
- scores_init = [2,3,4,3,1,4,0]
71
-
72
- # Import model configs
73
- #hydra.initialize(config_path="configs", job_name="corine")
74
- cfg = hydra.compose(config_name="my_train_config.yml")
75
-
76
- nbclasses = cfg.dir_dataset_n_classes
77
-
78
- # Load Model
79
- model_path = "checkpoint/model/model.pt"
80
- saved_state_dict = torch.load(model_path,map_location=torch.device('cpu'))
81
-
82
- model = LitUnsupervisedSegmenter(nbclasses, cfg)
83
- model.load_state_dict(saved_state_dict)
84
-
85
-
86
- #normalize img
87
- preprocess = T.Compose([
88
- T.ToPILImage(),
89
- T.Resize((320,320)),
90
- # T.CenterCrop(224),
91
- T.ToTensor(),
92
- T.Normalize(
93
- mean=[0.485, 0.456, 0.406],
94
- std=[0.229, 0.224, 0.225]
95
- )
96
- ])
97
-
98
- # Function that look for img on EE and segment it
99
- # -- 3 ways possible to avoid cloudy environment -- monthly / bi-monthly / yearly meaned img
100
-
101
- def segment_loc(location, month, year, how = "month", month_end = '12', year_end = None) :
102
- if how == 'month':
103
- img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month +'-28')
104
- elif how == 'year' :
105
- if year_end == None :
106
- img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
107
- else :
108
- img = extract_img(location, year +'-'+ month +'-01', year_end +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
109
-
110
-
111
- img_test= transform_ee_img(img, max = 0.25)
112
-
113
- # Preprocess opened img
114
- x = preprocess(img_test)
115
- x = torch.unsqueeze(x, dim=0).cpu()
116
- # model=model.cpu()
117
-
118
- with torch.no_grad():
119
- feats, code = model.net(x)
120
- linear_preds = model.linear_probe(x, code)
121
- linear_preds = linear_preds.argmax(1)
122
- outputs = {
123
- 'img': x[:model.cfg.n_images].detach().cpu(),
124
- 'linear_preds': linear_preds[:model.cfg.n_images].detach().cpu()
125
- }
126
- return outputs
127
-
128
-
129
- # Function that look for all img on EE and extract all segments with the date as first output arg
130
-
131
- def segment_group(location, start_date, end_date, how = 'month') :
132
- outputs = []
133
- st_month = int(start_date[5:7])
134
- end_month = int(end_date[5:7])
135
-
136
- st_year = int(start_date[0:4])
137
- end_year = int(end_date[0:4])
138
-
139
-
140
-
141
- for year in range(st_year, end_year+1) :
142
-
143
- if year != end_year :
144
- last = 12
145
- else :
146
- last = end_month
147
-
148
- if year != st_year:
149
- start = 1
150
- else :
151
- start = st_month
152
-
153
- if how == 'month' :
154
- for month in range(start, last + 1):
155
- month_str = f"{month:0>2d}"
156
- year_str = str(year)
157
-
158
- outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str)))
159
-
160
- elif how == 'year' :
161
- outputs.append((str(year) + '-' + f"{start:0>2d}", segment_loc(location, f"{start:0>2d}", str(year), how = 'year', month_end=f"{last:0>2d}")))
162
-
163
- elif how == '2months' :
164
- for month in range(start, last + 1):
165
- month_str = f"{month:0>2d}"
166
- year_str = str(year)
167
- month_end = (month) % 12 +1
168
- if month_end < month :
169
- year_end = year +1
170
- else :
171
- year_end = year
172
- month_end= f"{month_end:0>2d}"
173
- year_end = str(year_end)
174
-
175
- outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str,how = 'year', month_end=month_end, year_end=year_end)))
176
-
177
-
178
- return outputs
179
-
180
-
181
- # Function that transforms an output to PIL images
182
-
183
- def transform_to_pil(outputs,alpha=0.3):
184
- # Transform img with torch
185
- img = torch.moveaxis(prep_for_plot(outputs['img'][0]),-1,0)
186
- img=T.ToPILImage()(img)
187
-
188
- # Transform label by saving it then open it
189
- # label = outputs['linear_preds'][0]
190
- # plt.imsave('label.png',label,cmap=cmap)
191
- # label = Image.open('label.png')
192
-
193
- cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)])
194
- labels = np.array(outputs['linear_preds'][0])-1
195
- label = T.ToPILImage()((cmaplist[labels]*255).astype(np.uint8))
196
-
197
-
198
- # Overlay labels with img wit alpha
199
- background = img.convert("RGBA")
200
- overlay = label.convert("RGBA")
201
-
202
- labeled_img = Image.blend(background, overlay, alpha)
203
-
204
- return img, label, labeled_img
205
-
206
-
207
-
208
- # Function that extract labeled_img(PIL) and nb_values(number of pixels for each class) and the score for each observation
209
-
210
- def values_from_output(output):
211
- imgs = transform_to_pil(output,alpha = 0.3)
212
-
213
- img = imgs[0]
214
- img = np.array(img.convert('RGB'))
215
-
216
- labeled_img = imgs[2]
217
- labeled_img = np.array(labeled_img.convert('RGB'))
218
-
219
- nb_values = []
220
- for i in range(7):
221
- nb_values.append(np.count_nonzero(output['linear_preds'][0] == i+1))
222
-
223
- score = sum(x * y for x, y in zip(scores_init, nb_values)) / sum(nb_values) / max(scores_init)
224
-
225
- return img, labeled_img, nb_values, score
226
-
227
-
228
- # Function that extract from outputs (from segment_group function) all dates/ all images
229
- def values_from_outputs(outputs) :
230
- months = []
231
- imgs = []
232
- imgs_label = []
233
- nb_values = []
234
- scores = []
235
-
236
- for output in outputs:
237
- img, labeled_img, nb_value, score = values_from_output(output[1])
238
- months.append(output[0])
239
- imgs.append(img)
240
- imgs_label.append(labeled_img)
241
- nb_values.append(nb_value)
242
- scores.append(score)
243
-
244
- return months, imgs, imgs_label, nb_values, scores
245
-
246
-
247
-
248
- def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
249
-
250
- fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
251
- fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
252
-
253
- # Scores
254
- scatters = []
255
- temp = []
256
- for score in scores :
257
- temp_score = []
258
- temp_date = []
259
- score = scores[i]
260
- temp.append(score)
261
- text_temp = ["" for i in temp]
262
- text_temp[-1] = str(round(score,2))
263
- scatters.append(go.Scatter(x=text_temp, y=temp, mode="lines+markers+text", marker_color="black", text = text_temp, textposition="top center"))
264
-
265
-
266
- # Scores
267
- fig = make_subplots(
268
- rows=1, cols=4,
269
- # specs=[[{"rowspan": 2}, {"rowspan": 2}, {"type": "pie"}, None]]
270
- # row_heights=[0.8, 0.2],
271
- column_widths = [0.6, 0.6,0.3, 0.3],
272
- subplot_titles=("Localisation visualization", "labeled visualisation", "Segments repartition", "Biodiversity scores")
273
- )
274
-
275
- fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
276
- fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
277
-
278
- fig.add_trace(go.Pie(labels = class_names,
279
- values = nb_values[0],
280
- marker_colors = colors,
281
- name="Segment repartition",
282
- textposition='inside',
283
- texttemplate = "%{percent:.0%}",
284
- textfont_size=14
285
- ),
286
- row=1, col=3)
287
-
288
-
289
- fig.add_trace(scatters[0], row=1, col=4)
290
- # fig.add_annotation(text='score:' + str(scores[0]),
291
- # showarrow=False,
292
- # row=2, col=2)
293
-
294
-
295
- number_frames = len(imgs)
296
- frames = [dict(
297
- name = k,
298
- data = [ fig2["frames"][k]["data"][0],
299
- fig3["frames"][k]["data"][0],
300
- go.Pie(labels = class_names,
301
- values = nb_values[k],
302
- marker_colors = colors,
303
- name="Segment repartition",
304
- textposition='inside',
305
- texttemplate = "%{percent:.0%}",
306
- textfont_size=14
307
- ),
308
- scatters[k]
309
- ],
310
- traces=[0, 1,2,3] # the elements of the list [0,1,2] give info on the traces in fig.data
311
- # that are updated by the above three go.Scatter instances
312
- ) for k in range(number_frames)]
313
-
314
- updatemenus = [dict(type='buttons',
315
- buttons=[dict(label='Play',
316
- method='animate',
317
- args=[[f'{k}' for k in range(number_frames)],
318
- dict(frame=dict(duration=500, redraw=False),
319
- transition=dict(duration=0),
320
- easing='linear',
321
- fromcurrent=True,
322
- mode='immediate'
323
- )])],
324
- direction= 'left',
325
- pad=dict(r= 10, t=85),
326
- showactive =True, x= 0.1, y= 0.13, xanchor= 'right', yanchor= 'top')
327
- ]
328
-
329
- sliders = [{'yanchor': 'top',
330
- 'xanchor': 'left',
331
- 'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
332
- 'transition': {'duration': 500.0, 'easing': 'linear'},
333
- 'pad': {'b': 10, 't': 50},
334
- 'len': 0.9, 'x': 0.1, 'y': 0,
335
- 'steps': [{'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
336
- 'transition': {'duration': 0, 'easing': 'linear'}}],
337
- 'label': months[k], 'method': 'animate'} for k in range(number_frames)
338
- ]}]
339
-
340
-
341
- fig.update(frames=frames)
342
-
343
- for i,fr in enumerate(fig["frames"]):
344
- fr.update(
345
- layout={
346
- "xaxis": {
347
- "range": [0,imgs[0].shape[1]+i/100000]
348
- },
349
- "yaxis": {
350
- "range": [imgs[0].shape[0]+i/100000,0]
351
- },
352
- })
353
-
354
- fr.update(layout_title_text= months[i])
355
-
356
-
357
- fig.update(layout_title_text= 'tot')
358
- fig.update(
359
- layout={
360
- "xaxis": {
361
- "range": [0,imgs[0].shape[1]+i/100000],
362
- 'showgrid': False, # thin lines in the background
363
- 'zeroline': False, # thick line at x=0
364
- 'visible': False, # numbers below
365
- },
366
-
367
- "yaxis": {
368
- "range": [imgs[0].shape[0]+i/100000,0],
369
- 'showgrid': False, # thin lines in the background
370
- 'zeroline': False, # thick line at y=0
371
- 'visible': False,},
372
-
373
- "xaxis3": {
374
- "range": [0,len(scores)+1],
375
- 'autorange': False, # thin lines in the background
376
- 'showgrid': False, # thin lines in the background
377
- 'zeroline': False, # thick line at y=0
378
- 'visible': False
379
- },
380
-
381
- "yaxis3": {
382
- "range": [0,1.5],
383
- 'autorange': False,
384
- 'showgrid': False, # thin lines in the background
385
- 'zeroline': False, # thick line at y=0
386
- 'visible': False # thin lines in the background
387
- }
388
- },
389
- legend=dict(
390
- yanchor="bottom",
391
- y=0.99,
392
- xanchor="center",
393
- x=0.01
394
- )
395
- )
396
-
397
-
398
- fig.update_layout(updatemenus=updatemenus,
399
- sliders=sliders)
400
-
401
- fig.update_layout(margin=dict(b=0, r=0))
402
-
403
- # fig.show() #in jupyter notebook
404
-
405
- return fig
406
-
407
-
408
-
409
- # Last function (global one)
410
- # how = 'month' or '2months' or 'year'
411
-
412
- def segment_region(location, start_date, end_date, how = 'month'):
413
-
414
- #extract the outputs for each image
415
- outputs = segment_group(location, start_date, end_date, how = how)
416
-
417
- #extract the intersting values from image
418
- months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)
419
-
420
- #Create the figure
421
- fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)
422
-
423
- return fig
424
- #normalize img
425
- preprocess = T.Compose([
426
- T.ToPILImage(),
427
- T.Resize((320,320)),
428
- # T.CenterCrop(224),
429
- T.ToTensor(),
430
- T.Normalize(
431
- mean=[0.485, 0.456, 0.406],
432
- std=[0.229, 0.224, 0.225]
433
- )
434
- ])
435
-
436
- # Function that look for img on EE and segment it
437
- # -- 3 ways possible to avoid cloudy environment -- monthly / bi-monthly / yearly meaned img
438
-
439
- def segment_loc(location, month, year, how = "month", month_end = '12', year_end = None) :
440
- if how == 'month':
441
- img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month +'-28')
442
- elif how == 'year' :
443
- if year_end == None :
444
- img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
445
- else :
446
- img = extract_img(location, year +'-'+ month +'-01', year_end +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
447
-
448
-
449
- img_test= transform_ee_img(img, max = 0.25)
450
-
451
- # Preprocess opened img
452
- x = preprocess(img_test)
453
- x = torch.unsqueeze(x, dim=0).cpu()
454
- # model=model.cpu()
455
-
456
- with torch.no_grad():
457
- feats, code = model.net(x)
458
- linear_preds = model.linear_probe(x, code)
459
- linear_preds = linear_preds.argmax(1)
460
- outputs = {
461
- 'img': x[:model.cfg.n_images].detach().cpu(),
462
- 'linear_preds': linear_preds[:model.cfg.n_images].detach().cpu()
463
- }
464
- return outputs
465
-
466
-
467
- # Function that look for all img on EE and extract all segments with the date as first output arg
468
-
469
- def segment_group(location, start_date, end_date, how = 'month') :
470
- outputs = []
471
- st_month = int(start_date[5:7])
472
- end_month = int(end_date[5:7])
473
-
474
- st_year = int(start_date[0:4])
475
- end_year = int(end_date[0:4])
476
-
477
-
478
-
479
- for year in range(st_year, end_year+1) :
480
-
481
- if year != end_year :
482
- last = 12
483
- else :
484
- last = end_month
485
-
486
- if year != st_year:
487
- start = 1
488
- else :
489
- start = st_month
490
-
491
- if how == 'month' :
492
- for month in range(start, last + 1):
493
- month_str = f"{month:0>2d}"
494
- year_str = str(year)
495
-
496
- outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str)))
497
-
498
- elif how == 'year' :
499
- outputs.append((str(year) + '-' + f"{start:0>2d}", segment_loc(location, f"{start:0>2d}", str(year), how = 'year', month_end=f"{last:0>2d}")))
500
-
501
- elif how == '2months' :
502
- for month in range(start, last + 1):
503
- month_str = f"{month:0>2d}"
504
- year_str = str(year)
505
- month_end = (month) % 12 +1
506
- if month_end < month :
507
- year_end = year +1
508
- else :
509
- year_end = year
510
- month_end= f"{month_end:0>2d}"
511
- year_end = str(year_end)
512
-
513
- outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str,how = 'year', month_end=month_end, year_end=year_end)))
514
-
515
-
516
- return outputs
517
-
518
-
519
- # Function that transforms an output to PIL images
520
-
521
- def transform_to_pil(outputs,alpha=0.3):
522
- # Transform img with torch
523
- img = torch.moveaxis(prep_for_plot(outputs['img'][0]),-1,0)
524
- img=T.ToPILImage()(img)
525
-
526
- # Transform label by saving it then open it
527
- # label = outputs['linear_preds'][0]
528
- # plt.imsave('label.png',label,cmap=cmap)
529
- # label = Image.open('label.png')
530
-
531
- cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)])
532
- labels = np.array(outputs['linear_preds'][0])-1
533
- label = T.ToPILImage()((cmaplist[labels]*255).astype(np.uint8))
534
-
535
-
536
- # Overlay labels with img wit alpha
537
- background = img.convert("RGBA")
538
- overlay = label.convert("RGBA")
539
-
540
- labeled_img = Image.blend(background, overlay, alpha)
541
-
542
- return img, label, labeled_img
543
-
544
- def values_from_output(output):
545
- imgs = transform_to_pil(output,alpha = 0.3)
546
-
547
- img = imgs[0]
548
- img = np.array(img.convert('RGB'))
549
-
550
- labeled_img = imgs[2]
551
- labeled_img = np.array(labeled_img.convert('RGB'))
552
-
553
- nb_values = []
554
- for i in range(7):
555
- nb_values.append(np.count_nonzero(output['linear_preds'][0] == i+1))
556
-
557
- score = sum(x * y for x, y in zip(scores_init, nb_values)) / sum(nb_values) / max(scores_init)
558
-
559
- return img, labeled_img, nb_values, score
560
-
561
-
562
- # Function that extract labeled_img(PIL) and nb_values(number of pixels for each class) and the score for each observation
563
-
564
-
565
-
566
- # Function that extract from outputs (from segment_group function) all dates/ all images
567
- def values_from_outputs(outputs) :
568
- months = []
569
- imgs = []
570
- imgs_label = []
571
- nb_values = []
572
- scores = []
573
-
574
- for output in outputs:
575
- img, labeled_img, nb_value, score = values_from_output(output[1])
576
- months.append(output[0])
577
- imgs.append(img)
578
- imgs_label.append(labeled_img)
579
- nb_values.append(nb_value)
580
- scores.append(score)
581
-
582
- return months, imgs, imgs_label, nb_values, scores
583
-
584
-
585
-
586
- def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
587
-
588
- fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
589
- fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
590
-
591
- # Scores
592
- scatters = []
593
- temp = []
594
- for score in scores :
595
- temp_score = []
596
- temp_date = []
597
- #score = scores[i]
598
- temp.append(score)
599
- n = len(temp)
600
- text_temp = ["" for i in temp]
601
- text_temp[-1] = str(round(score,2))
602
- scatters.append(go.Scatter(x=[0,1], y=temp, mode="lines+markers+text", marker_color="black", text = text_temp, textposition="top center"))
603
- print(text_temp)
604
-
605
- # Scores
606
- fig = make_subplots(
607
- rows=1, cols=4,
608
- specs=[[{"type": "image"},{"type": "image"}, {"type": "pie"}, {"type": "scatter"}]],
609
- subplot_titles=("Localisation visualization", "Labeled visualisation", "Segments repartition", "Biodiversity scores")
610
- )
611
-
612
- fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
613
- fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
614
-
615
- fig.add_trace(go.Pie(labels = class_names,
616
- values = nb_values[0],
617
- marker_colors = colors,
618
- name="Segment repartition",
619
- textposition='inside',
620
- texttemplate = "%{percent:.0%}",
621
- textfont_size=14
622
- ),
623
- row=1, col=3)
624
-
625
-
626
- fig.add_trace(scatters[0], row=1, col=4)
627
- fig.update_traces(showlegend=False, selector=dict(type='scatter'))
628
- #fig.update_traces(, selector=dict(type='scatter'))
629
- # fig.add_annotation(text='score:' + str(scores[0]),
630
- # showarrow=False,
631
- # row=2, col=2)
632
-
633
-
634
- number_frames = len(imgs)
635
- frames = [dict(
636
- name = k,
637
- data = [ fig2["frames"][k]["data"][0],
638
- fig3["frames"][k]["data"][0],
639
- go.Pie(labels = class_names,
640
- values = nb_values[k],
641
- marker_colors = colors,
642
- name="Segment repartition",
643
- textposition='inside',
644
- texttemplate = "%{percent:.0%}",
645
- textfont_size=14
646
- ),
647
- scatters[k]
648
- ],
649
- traces=[0, 1,2,3] # the elements of the list [0,1,2] give info on the traces in fig.data
650
- # that are updated by the above three go.Scatter instances
651
- ) for k in range(number_frames)]
652
-
653
- updatemenus = [dict(type='buttons',
654
- buttons=[dict(label='Play',
655
- method='animate',
656
- args=[[f'{k}' for k in range(number_frames)],
657
- dict(frame=dict(duration=500, redraw=False),
658
- transition=dict(duration=0),
659
- easing='linear',
660
- fromcurrent=True,
661
- mode='immediate'
662
- )])],
663
- direction= 'left',
664
- pad=dict(r= 10, t=85),
665
- showactive =True, x= 0.1, y= 0.13, xanchor= 'right', yanchor= 'top')
666
- ]
667
-
668
- sliders = [{'yanchor': 'top',
669
- 'xanchor': 'left',
670
- 'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
671
- 'transition': {'duration': 500.0, 'easing': 'linear'},
672
- 'pad': {'b': 10, 't': 50},
673
- 'len': 0.9, 'x': 0.1, 'y': 0,
674
- 'steps': [{'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
675
- 'transition': {'duration': 0, 'easing': 'linear'}}],
676
- 'label': months[k], 'method': 'animate'} for k in range(number_frames)
677
- ]}]
678
-
679
-
680
- fig.update(frames=frames)
681
-
682
- for i,fr in enumerate(fig["frames"]):
683
- fr.update(
684
- layout={
685
- "xaxis": {
686
- "range": [0,imgs[0].shape[1]+i/100000]
687
- },
688
- "yaxis": {
689
- "range": [imgs[0].shape[0]+i/100000,0]
690
- },
691
- })
692
-
693
- fr.update(layout_title_text= months[i])
694
-
695
-
696
- fig.update(layout_title_text= months[0])
697
- fig.update(
698
- layout={
699
- "xaxis": {
700
- "range": [0,imgs[0].shape[1]+i/100000],
701
- 'showgrid': False, # thin lines in the background
702
- 'zeroline': False, # thick line at x=0
703
- 'visible': False, # numbers below
704
- },
705
-
706
- "yaxis": {
707
- "range": [imgs[0].shape[0]+i/100000,0],
708
- 'showgrid': False, # thin lines in the background
709
- 'zeroline': False, # thick line at y=0
710
- 'visible': False,},
711
-
712
- "xaxis2": {
713
- "range": [0,imgs[0].shape[1]+i/100000],
714
- 'showgrid': False, # thin lines in the background
715
- 'zeroline': False, # thick line at x=0
716
- 'visible': False, # numbers below
717
- },
718
-
719
- "yaxis2": {
720
- "range": [imgs[0].shape[0]+i/100000,0],
721
- 'showgrid': False, # thin lines in the background
722
- 'zeroline': False, # thick line at y=0
723
- 'visible': False,},
724
-
725
-
726
- "xaxis3": {
727
- "range": [0,len(scores)+1],
728
- 'autorange': False, # thin lines in the background
729
- 'showgrid': False, # thin lines in the background
730
- 'zeroline': False, # thick line at y=0
731
- 'visible': False
732
- },
733
-
734
- "yaxis3": {
735
- "range": [0,1.5],
736
- 'autorange': False,
737
- 'showgrid': False, # thin lines in the background
738
- 'zeroline': False, # thick line at y=0
739
- 'visible': False # thin lines in the background
740
- }
741
- }
742
- )
743
-
744
-
745
- fig.update_layout(updatemenus=updatemenus,
746
- sliders=sliders,
747
- legend=dict(
748
- yanchor= 'top',
749
- xanchor= 'left',
750
- orientation="h")
751
- )
752
-
753
-
754
- fig.update_layout(margin=dict(b=0, r=0))
755
-
756
- # fig.show() #in jupyter notebook
757
-
758
- return fig
759
-
760
-
761
-
762
- # Last function (global one)
763
- # how = 'month' or '2months' or 'year'
764
-
765
- def segment_region(latitude, longitude, start_date, end_date, how = 'month'):
766
- location = [float(latitude),float(longitude)]
767
- how = how[0]
768
- #extract the outputs for each image
769
- outputs = segment_group(location, start_date, end_date, how = how)
770
-
771
- #extract the intersting values from image
772
- months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)
773
-
774
-
775
- #Create the figure
776
- fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)
777
-
778
  return fig
 
1
+ from PIL import Image
2
+
3
+ import hydra
4
+ import matplotlib as mpl
5
+ from utils import prep_for_plot
6
+
7
+ import torch.multiprocessing
8
+ import torchvision.transforms as T
9
+ # import matplotlib.pyplot as plt
10
+ from model import LitUnsupervisedSegmenter
11
+ colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
12
+ class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
13
+ cmap = mpl.colors.ListedColormap(colors)
14
+ #from train_segmentation import LitUnsupervisedSegmenter, cmap
15
+
16
+ from utils_gee import extract_img, transform_ee_img
17
+
18
+ import plotly.graph_objects as go
19
+ import plotly.express as px
20
+ import numpy as np
21
+ from plotly.subplots import make_subplots
22
+
23
+ import os
24
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
25
+
26
+
27
+ colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
28
+ class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
29
+ scores_init = [2,3,4,3,1,4,0]
30
+
31
+ # Import model configs
32
+ hydra.initialize(config_path="configs", job_name="corine")
33
+ cfg = hydra.compose(config_name="my_train_config.yml")
34
+
35
+ nbclasses = cfg.dir_dataset_n_classes
36
+
37
+ # Load Model
38
+ model_path = "biomap/checkpoint/model/model.pt"
39
+ saved_state_dict = torch.load(model_path,map_location=torch.device('cpu'))
40
+
41
+ model = LitUnsupervisedSegmenter(nbclasses, cfg)
42
+ model.load_state_dict(saved_state_dict)
43
+
44
+ from PIL import Image
45
+
46
+ import hydra
47
+
48
+ from utils import prep_for_plot
49
+
50
+ import torch.multiprocessing
51
+ import torchvision.transforms as T
52
+ # import matplotlib.pyplot as plt
53
+
54
+ from model import LitUnsupervisedSegmenter
55
+
56
+ from utils_gee import extract_img, transform_ee_img
57
+
58
+ import plotly.graph_objects as go
59
+ import plotly.express as px
60
+ import numpy as np
61
+ from plotly.subplots import make_subplots
62
+
63
+ import os
64
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
65
+
66
+
67
+ colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
68
+ cmap = mpl.colors.ListedColormap(colors)
69
+ class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
70
+ scores_init = [2,3,4,3,1,4,0]
71
+
72
+ # Import model configs
73
+ #hydra.initialize(config_path="configs", job_name="corine")
74
+ cfg = hydra.compose(config_name="my_train_config.yml")
75
+
76
+ nbclasses = cfg.dir_dataset_n_classes
77
+
78
+ # Load Model
79
+ model_path = "biomap/checkpoint/model/model.pt"
80
+ saved_state_dict = torch.load(model_path,map_location=torch.device('cpu'))
81
+
82
+ model = LitUnsupervisedSegmenter(nbclasses, cfg)
83
+ model.load_state_dict(saved_state_dict)
84
+
85
+
86
+ #normalize img
87
+ preprocess = T.Compose([
88
+ T.ToPILImage(),
89
+ T.Resize((320,320)),
90
+ # T.CenterCrop(224),
91
+ T.ToTensor(),
92
+ T.Normalize(
93
+ mean=[0.485, 0.456, 0.406],
94
+ std=[0.229, 0.224, 0.225]
95
+ )
96
+ ])
97
+
98
+ # Function that look for img on EE and segment it
99
+ # -- 3 ways possible to avoid cloudy environment -- monthly / bi-monthly / yearly meaned img
100
+
101
+ def segment_loc(location, month, year, how = "month", month_end = '12', year_end = None) :
102
+ if how == 'month':
103
+ img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month +'-28')
104
+ elif how == 'year' :
105
+ if year_end == None :
106
+ img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
107
+ else :
108
+ img = extract_img(location, year +'-'+ month +'-01', year_end +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
109
+
110
+
111
+ img_test= transform_ee_img(img, max = 0.25)
112
+
113
+ # Preprocess opened img
114
+ x = preprocess(img_test)
115
+ x = torch.unsqueeze(x, dim=0).cpu()
116
+ # model=model.cpu()
117
+
118
+ with torch.no_grad():
119
+ feats, code = model.net(x)
120
+ linear_preds = model.linear_probe(x, code)
121
+ linear_preds = linear_preds.argmax(1)
122
+ outputs = {
123
+ 'img': x[:model.cfg.n_images].detach().cpu(),
124
+ 'linear_preds': linear_preds[:model.cfg.n_images].detach().cpu()
125
+ }
126
+ return outputs
127
+
128
+
129
+ # Function that look for all img on EE and extract all segments with the date as first output arg
130
+
131
+ def segment_group(location, start_date, end_date, how = 'month') :
132
+ outputs = []
133
+ st_month = int(start_date[5:7])
134
+ end_month = int(end_date[5:7])
135
+
136
+ st_year = int(start_date[0:4])
137
+ end_year = int(end_date[0:4])
138
+
139
+
140
+
141
+ for year in range(st_year, end_year+1) :
142
+
143
+ if year != end_year :
144
+ last = 12
145
+ else :
146
+ last = end_month
147
+
148
+ if year != st_year:
149
+ start = 1
150
+ else :
151
+ start = st_month
152
+
153
+ if how == 'month' :
154
+ for month in range(start, last + 1):
155
+ month_str = f"{month:0>2d}"
156
+ year_str = str(year)
157
+
158
+ outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str)))
159
+
160
+ elif how == 'year' :
161
+ outputs.append((str(year) + '-' + f"{start:0>2d}", segment_loc(location, f"{start:0>2d}", str(year), how = 'year', month_end=f"{last:0>2d}")))
162
+
163
+ elif how == '2months' :
164
+ for month in range(start, last + 1):
165
+ month_str = f"{month:0>2d}"
166
+ year_str = str(year)
167
+ month_end = (month) % 12 +1
168
+ if month_end < month :
169
+ year_end = year +1
170
+ else :
171
+ year_end = year
172
+ month_end= f"{month_end:0>2d}"
173
+ year_end = str(year_end)
174
+
175
+ outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str,how = 'year', month_end=month_end, year_end=year_end)))
176
+
177
+
178
+ return outputs
179
+
180
+
181
+ # Function that transforms an output to PIL images
182
+
183
+ def transform_to_pil(outputs,alpha=0.3):
184
+ # Transform img with torch
185
+ img = torch.moveaxis(prep_for_plot(outputs['img'][0]),-1,0)
186
+ img=T.ToPILImage()(img)
187
+
188
+ # Transform label by saving it then open it
189
+ # label = outputs['linear_preds'][0]
190
+ # plt.imsave('label.png',label,cmap=cmap)
191
+ # label = Image.open('label.png')
192
+
193
+ cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)])
194
+ labels = np.array(outputs['linear_preds'][0])-1
195
+ label = T.ToPILImage()((cmaplist[labels]*255).astype(np.uint8))
196
+
197
+
198
+ # Overlay labels with img wit alpha
199
+ background = img.convert("RGBA")
200
+ overlay = label.convert("RGBA")
201
+
202
+ labeled_img = Image.blend(background, overlay, alpha)
203
+
204
+ return img, label, labeled_img
205
+
206
+
207
+
208
+ # Function that extract labeled_img(PIL) and nb_values(number of pixels for each class) and the score for each observation
209
+
210
+ def values_from_output(output):
211
+ imgs = transform_to_pil(output,alpha = 0.3)
212
+
213
+ img = imgs[0]
214
+ img = np.array(img.convert('RGB'))
215
+
216
+ labeled_img = imgs[2]
217
+ labeled_img = np.array(labeled_img.convert('RGB'))
218
+
219
+ nb_values = []
220
+ for i in range(7):
221
+ nb_values.append(np.count_nonzero(output['linear_preds'][0] == i+1))
222
+
223
+ score = sum(x * y for x, y in zip(scores_init, nb_values)) / sum(nb_values) / max(scores_init)
224
+
225
+ return img, labeled_img, nb_values, score
226
+
227
+
228
+ # Function that extract from outputs (from segment_group function) all dates/ all images
229
+ def values_from_outputs(outputs) :
230
+ months = []
231
+ imgs = []
232
+ imgs_label = []
233
+ nb_values = []
234
+ scores = []
235
+
236
+ for output in outputs:
237
+ img, labeled_img, nb_value, score = values_from_output(output[1])
238
+ months.append(output[0])
239
+ imgs.append(img)
240
+ imgs_label.append(labeled_img)
241
+ nb_values.append(nb_value)
242
+ scores.append(score)
243
+
244
+ return months, imgs, imgs_label, nb_values, scores
245
+
246
+
247
+
248
+ def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
249
+
250
+ fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
251
+ fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
252
+
253
+ # Scores
254
+ scatters = []
255
+ temp = []
256
+ for score in scores :
257
+ temp_score = []
258
+ temp_date = []
259
+ score = scores[i]
260
+ temp.append(score)
261
+ text_temp = ["" for i in temp]
262
+ text_temp[-1] = str(round(score,2))
263
+ scatters.append(go.Scatter(x=text_temp, y=temp, mode="lines+markers+text", marker_color="black", text = text_temp, textposition="top center"))
264
+
265
+
266
+ # Scores
267
+ fig = make_subplots(
268
+ rows=1, cols=4,
269
+ # specs=[[{"rowspan": 2}, {"rowspan": 2}, {"type": "pie"}, None]]
270
+ # row_heights=[0.8, 0.2],
271
+ column_widths = [0.6, 0.6,0.3, 0.3],
272
+ subplot_titles=("Localisation visualization", "labeled visualisation", "Segments repartition", "Biodiversity scores")
273
+ )
274
+
275
+ fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
276
+ fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
277
+
278
+ fig.add_trace(go.Pie(labels = class_names,
279
+ values = nb_values[0],
280
+ marker_colors = colors,
281
+ name="Segment repartition",
282
+ textposition='inside',
283
+ texttemplate = "%{percent:.0%}",
284
+ textfont_size=14
285
+ ),
286
+ row=1, col=3)
287
+
288
+
289
+ fig.add_trace(scatters[0], row=1, col=4)
290
+ # fig.add_annotation(text='score:' + str(scores[0]),
291
+ # showarrow=False,
292
+ # row=2, col=2)
293
+
294
+
295
+ number_frames = len(imgs)
296
+ frames = [dict(
297
+ name = k,
298
+ data = [ fig2["frames"][k]["data"][0],
299
+ fig3["frames"][k]["data"][0],
300
+ go.Pie(labels = class_names,
301
+ values = nb_values[k],
302
+ marker_colors = colors,
303
+ name="Segment repartition",
304
+ textposition='inside',
305
+ texttemplate = "%{percent:.0%}",
306
+ textfont_size=14
307
+ ),
308
+ scatters[k]
309
+ ],
310
+ traces=[0, 1,2,3] # the elements of the list [0,1,2] give info on the traces in fig.data
311
+ # that are updated by the above three go.Scatter instances
312
+ ) for k in range(number_frames)]
313
+
314
+ updatemenus = [dict(type='buttons',
315
+ buttons=[dict(label='Play',
316
+ method='animate',
317
+ args=[[f'{k}' for k in range(number_frames)],
318
+ dict(frame=dict(duration=500, redraw=False),
319
+ transition=dict(duration=0),
320
+ easing='linear',
321
+ fromcurrent=True,
322
+ mode='immediate'
323
+ )])],
324
+ direction= 'left',
325
+ pad=dict(r= 10, t=85),
326
+ showactive =True, x= 0.1, y= 0.13, xanchor= 'right', yanchor= 'top')
327
+ ]
328
+
329
+ sliders = [{'yanchor': 'top',
330
+ 'xanchor': 'left',
331
+ 'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
332
+ 'transition': {'duration': 500.0, 'easing': 'linear'},
333
+ 'pad': {'b': 10, 't': 50},
334
+ 'len': 0.9, 'x': 0.1, 'y': 0,
335
+ 'steps': [{'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
336
+ 'transition': {'duration': 0, 'easing': 'linear'}}],
337
+ 'label': months[k], 'method': 'animate'} for k in range(number_frames)
338
+ ]}]
339
+
340
+
341
+ fig.update(frames=frames)
342
+
343
+ for i,fr in enumerate(fig["frames"]):
344
+ fr.update(
345
+ layout={
346
+ "xaxis": {
347
+ "range": [0,imgs[0].shape[1]+i/100000]
348
+ },
349
+ "yaxis": {
350
+ "range": [imgs[0].shape[0]+i/100000,0]
351
+ },
352
+ })
353
+
354
+ fr.update(layout_title_text= months[i])
355
+
356
+
357
+ fig.update(layout_title_text= 'tot')
358
+ fig.update(
359
+ layout={
360
+ "xaxis": {
361
+ "range": [0,imgs[0].shape[1]+i/100000],
362
+ 'showgrid': False, # thin lines in the background
363
+ 'zeroline': False, # thick line at x=0
364
+ 'visible': False, # numbers below
365
+ },
366
+
367
+ "yaxis": {
368
+ "range": [imgs[0].shape[0]+i/100000,0],
369
+ 'showgrid': False, # thin lines in the background
370
+ 'zeroline': False, # thick line at y=0
371
+ 'visible': False,},
372
+
373
+ "xaxis3": {
374
+ "range": [0,len(scores)+1],
375
+ 'autorange': False, # thin lines in the background
376
+ 'showgrid': False, # thin lines in the background
377
+ 'zeroline': False, # thick line at y=0
378
+ 'visible': False
379
+ },
380
+
381
+ "yaxis3": {
382
+ "range": [0,1.5],
383
+ 'autorange': False,
384
+ 'showgrid': False, # thin lines in the background
385
+ 'zeroline': False, # thick line at y=0
386
+ 'visible': False # thin lines in the background
387
+ }
388
+ },
389
+ legend=dict(
390
+ yanchor="bottom",
391
+ y=0.99,
392
+ xanchor="center",
393
+ x=0.01
394
+ )
395
+ )
396
+
397
+
398
+ fig.update_layout(updatemenus=updatemenus,
399
+ sliders=sliders)
400
+
401
+ fig.update_layout(margin=dict(b=0, r=0))
402
+
403
+ # fig.show() #in jupyter notebook
404
+
405
+ return fig
406
+
407
+
408
+
409
+ # Last function (global one)
410
+ # how = 'month' or '2months' or 'year'
411
+
412
+ def segment_region(location, start_date, end_date, how = 'month'):
413
+
414
+ #extract the outputs for each image
415
+ outputs = segment_group(location, start_date, end_date, how = how)
416
+
417
+ #extract the intersting values from image
418
+ months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)
419
+
420
+ #Create the figure
421
+ fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)
422
+
423
+ return fig
424
+ #normalize img
425
+ preprocess = T.Compose([
426
+ T.ToPILImage(),
427
+ T.Resize((320,320)),
428
+ # T.CenterCrop(224),
429
+ T.ToTensor(),
430
+ T.Normalize(
431
+ mean=[0.485, 0.456, 0.406],
432
+ std=[0.229, 0.224, 0.225]
433
+ )
434
+ ])
435
+
436
+ # Function that look for img on EE and segment it
437
+ # -- 3 ways possible to avoid cloudy environment -- monthly / bi-monthly / yearly meaned img
438
+
439
+ def segment_loc(location, month, year, how = "month", month_end = '12', year_end = None) :
440
+ if how == 'month':
441
+ img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month +'-28')
442
+ elif how == 'year' :
443
+ if year_end == None :
444
+ img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
445
+ else :
446
+ img = extract_img(location, year +'-'+ month +'-01', year_end +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
447
+
448
+
449
+ img_test= transform_ee_img(img, max = 0.25)
450
+
451
+ # Preprocess opened img
452
+ x = preprocess(img_test)
453
+ x = torch.unsqueeze(x, dim=0).cpu()
454
+ # model=model.cpu()
455
+
456
+ with torch.no_grad():
457
+ feats, code = model.net(x)
458
+ linear_preds = model.linear_probe(x, code)
459
+ linear_preds = linear_preds.argmax(1)
460
+ outputs = {
461
+ 'img': x[:model.cfg.n_images].detach().cpu(),
462
+ 'linear_preds': linear_preds[:model.cfg.n_images].detach().cpu()
463
+ }
464
+ return outputs
465
+
466
+
467
+ # Function that look for all img on EE and extract all segments with the date as first output arg
468
+
469
+ def segment_group(location, start_date, end_date, how = 'month') :
470
+ outputs = []
471
+ st_month = int(start_date[5:7])
472
+ end_month = int(end_date[5:7])
473
+
474
+ st_year = int(start_date[0:4])
475
+ end_year = int(end_date[0:4])
476
+
477
+
478
+
479
+ for year in range(st_year, end_year+1) :
480
+
481
+ if year != end_year :
482
+ last = 12
483
+ else :
484
+ last = end_month
485
+
486
+ if year != st_year:
487
+ start = 1
488
+ else :
489
+ start = st_month
490
+
491
+ if how == 'month' :
492
+ for month in range(start, last + 1):
493
+ month_str = f"{month:0>2d}"
494
+ year_str = str(year)
495
+
496
+ outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str)))
497
+
498
+ elif how == 'year' :
499
+ outputs.append((str(year) + '-' + f"{start:0>2d}", segment_loc(location, f"{start:0>2d}", str(year), how = 'year', month_end=f"{last:0>2d}")))
500
+
501
+ elif how == '2months' :
502
+ for month in range(start, last + 1):
503
+ month_str = f"{month:0>2d}"
504
+ year_str = str(year)
505
+ month_end = (month) % 12 +1
506
+ if month_end < month :
507
+ year_end = year +1
508
+ else :
509
+ year_end = year
510
+ month_end= f"{month_end:0>2d}"
511
+ year_end = str(year_end)
512
+
513
+ outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str,how = 'year', month_end=month_end, year_end=year_end)))
514
+
515
+
516
+ return outputs
517
+
518
+
519
+ # Function that transforms an output to PIL images
520
+
521
+ def transform_to_pil(outputs,alpha=0.3):
522
+ # Transform img with torch
523
+ img = torch.moveaxis(prep_for_plot(outputs['img'][0]),-1,0)
524
+ img=T.ToPILImage()(img)
525
+
526
+ # Transform label by saving it then open it
527
+ # label = outputs['linear_preds'][0]
528
+ # plt.imsave('label.png',label,cmap=cmap)
529
+ # label = Image.open('label.png')
530
+
531
+ cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)])
532
+ labels = np.array(outputs['linear_preds'][0])-1
533
+ label = T.ToPILImage()((cmaplist[labels]*255).astype(np.uint8))
534
+
535
+
536
+ # Overlay labels with img wit alpha
537
+ background = img.convert("RGBA")
538
+ overlay = label.convert("RGBA")
539
+
540
+ labeled_img = Image.blend(background, overlay, alpha)
541
+
542
+ return img, label, labeled_img
543
+
544
+ def values_from_output(output):
545
+ imgs = transform_to_pil(output,alpha = 0.3)
546
+
547
+ img = imgs[0]
548
+ img = np.array(img.convert('RGB'))
549
+
550
+ labeled_img = imgs[2]
551
+ labeled_img = np.array(labeled_img.convert('RGB'))
552
+
553
+ nb_values = []
554
+ for i in range(7):
555
+ nb_values.append(np.count_nonzero(output['linear_preds'][0] == i+1))
556
+
557
+ score = sum(x * y for x, y in zip(scores_init, nb_values)) / sum(nb_values) / max(scores_init)
558
+
559
+ return img, labeled_img, nb_values, score
560
+
561
+
562
+ # Function that extract labeled_img(PIL) and nb_values(number of pixels for each class) and the score for each observation
563
+
564
+
565
+
566
+ # Function that extract from outputs (from segment_group function) all dates/ all images
567
+ def values_from_outputs(outputs) :
568
+ months = []
569
+ imgs = []
570
+ imgs_label = []
571
+ nb_values = []
572
+ scores = []
573
+
574
+ for output in outputs:
575
+ img, labeled_img, nb_value, score = values_from_output(output[1])
576
+ months.append(output[0])
577
+ imgs.append(img)
578
+ imgs_label.append(labeled_img)
579
+ nb_values.append(nb_value)
580
+ scores.append(score)
581
+
582
+ return months, imgs, imgs_label, nb_values, scores
583
+
584
+
585
+
586
+ def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
587
+
588
+ fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
589
+ fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
590
+
591
+ # Scores
592
+ scatters = []
593
+ temp = []
594
+ for score in scores :
595
+ temp_score = []
596
+ temp_date = []
597
+ #score = scores[i]
598
+ temp.append(score)
599
+ n = len(temp)
600
+ text_temp = ["" for i in temp]
601
+ text_temp[-1] = str(round(score,2))
602
+ scatters.append(go.Scatter(x=[0,1], y=temp, mode="lines+markers+text", marker_color="black", text = text_temp, textposition="top center"))
603
+ print(text_temp)
604
+
605
+ # Scores
606
+ fig = make_subplots(
607
+ rows=1, cols=4,
608
+ specs=[[{"type": "image"},{"type": "image"}, {"type": "pie"}, {"type": "scatter"}]],
609
+ subplot_titles=("Localisation visualization", "Labeled visualisation", "Segments repartition", "Biodiversity scores")
610
+ )
611
+
612
+ fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
613
+ fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
614
+
615
+ fig.add_trace(go.Pie(labels = class_names,
616
+ values = nb_values[0],
617
+ marker_colors = colors,
618
+ name="Segment repartition",
619
+ textposition='inside',
620
+ texttemplate = "%{percent:.0%}",
621
+ textfont_size=14
622
+ ),
623
+ row=1, col=3)
624
+
625
+
626
+ fig.add_trace(scatters[0], row=1, col=4)
627
+ fig.update_traces(showlegend=False, selector=dict(type='scatter'))
628
+ #fig.update_traces(, selector=dict(type='scatter'))
629
+ # fig.add_annotation(text='score:' + str(scores[0]),
630
+ # showarrow=False,
631
+ # row=2, col=2)
632
+
633
+
634
+ number_frames = len(imgs)
635
+ frames = [dict(
636
+ name = k,
637
+ data = [ fig2["frames"][k]["data"][0],
638
+ fig3["frames"][k]["data"][0],
639
+ go.Pie(labels = class_names,
640
+ values = nb_values[k],
641
+ marker_colors = colors,
642
+ name="Segment repartition",
643
+ textposition='inside',
644
+ texttemplate = "%{percent:.0%}",
645
+ textfont_size=14
646
+ ),
647
+ scatters[k]
648
+ ],
649
+ traces=[0, 1,2,3] # the elements of the list [0,1,2] give info on the traces in fig.data
650
+ # that are updated by the above three go.Scatter instances
651
+ ) for k in range(number_frames)]
652
+
653
+ updatemenus = [dict(type='buttons',
654
+ buttons=[dict(label='Play',
655
+ method='animate',
656
+ args=[[f'{k}' for k in range(number_frames)],
657
+ dict(frame=dict(duration=500, redraw=False),
658
+ transition=dict(duration=0),
659
+ easing='linear',
660
+ fromcurrent=True,
661
+ mode='immediate'
662
+ )])],
663
+ direction= 'left',
664
+ pad=dict(r= 10, t=85),
665
+ showactive =True, x= 0.1, y= 0.13, xanchor= 'right', yanchor= 'top')
666
+ ]
667
+
668
+ sliders = [{'yanchor': 'top',
669
+ 'xanchor': 'left',
670
+ 'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
671
+ 'transition': {'duration': 500.0, 'easing': 'linear'},
672
+ 'pad': {'b': 10, 't': 50},
673
+ 'len': 0.9, 'x': 0.1, 'y': 0,
674
+ 'steps': [{'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
675
+ 'transition': {'duration': 0, 'easing': 'linear'}}],
676
+ 'label': months[k], 'method': 'animate'} for k in range(number_frames)
677
+ ]}]
678
+
679
+
680
+ fig.update(frames=frames)
681
+
682
+ for i,fr in enumerate(fig["frames"]):
683
+ fr.update(
684
+ layout={
685
+ "xaxis": {
686
+ "range": [0,imgs[0].shape[1]+i/100000]
687
+ },
688
+ "yaxis": {
689
+ "range": [imgs[0].shape[0]+i/100000,0]
690
+ },
691
+ })
692
+
693
+ fr.update(layout_title_text= months[i])
694
+
695
+
696
+ fig.update(layout_title_text= months[0])
697
+ fig.update(
698
+ layout={
699
+ "xaxis": {
700
+ "range": [0,imgs[0].shape[1]+i/100000],
701
+ 'showgrid': False, # thin lines in the background
702
+ 'zeroline': False, # thick line at x=0
703
+ 'visible': False, # numbers below
704
+ },
705
+
706
+ "yaxis": {
707
+ "range": [imgs[0].shape[0]+i/100000,0],
708
+ 'showgrid': False, # thin lines in the background
709
+ 'zeroline': False, # thick line at y=0
710
+ 'visible': False,},
711
+
712
+ "xaxis2": {
713
+ "range": [0,imgs[0].shape[1]+i/100000],
714
+ 'showgrid': False, # thin lines in the background
715
+ 'zeroline': False, # thick line at x=0
716
+ 'visible': False, # numbers below
717
+ },
718
+
719
+ "yaxis2": {
720
+ "range": [imgs[0].shape[0]+i/100000,0],
721
+ 'showgrid': False, # thin lines in the background
722
+ 'zeroline': False, # thick line at y=0
723
+ 'visible': False,},
724
+
725
+
726
+ "xaxis3": {
727
+ "range": [0,len(scores)+1],
728
+ 'autorange': False, # thin lines in the background
729
+ 'showgrid': False, # thin lines in the background
730
+ 'zeroline': False, # thick line at y=0
731
+ 'visible': False
732
+ },
733
+
734
+ "yaxis3": {
735
+ "range": [0,1.5],
736
+ 'autorange': False,
737
+ 'showgrid': False, # thin lines in the background
738
+ 'zeroline': False, # thick line at y=0
739
+ 'visible': False # thin lines in the background
740
+ }
741
+ }
742
+ )
743
+
744
+
745
+ fig.update_layout(updatemenus=updatemenus,
746
+ sliders=sliders,
747
+ legend=dict(
748
+ yanchor= 'top',
749
+ xanchor= 'left',
750
+ orientation="h")
751
+ )
752
+
753
+
754
+ fig.update_layout(margin=dict(b=0, r=0))
755
+
756
+ # fig.show() #in jupyter notebook
757
+
758
+ return fig
759
+
760
+
761
+
762
+ # Last function (global one)
763
+ # how = 'month' or '2months' or 'year'
764
+
765
+ def segment_region(latitude, longitude, start_date, end_date, how = 'month'):
766
+ location = [float(latitude),float(longitude)]
767
+ how = how[0]
768
+ #extract the outputs for each image
769
+ outputs = segment_group(location, start_date, end_date, how = how)
770
+
771
+ #extract the intersting values from image
772
+ months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)
773
+
774
+
775
+ #Create the figure
776
+ fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)
777
+
778
  return fig