rlrocha commited on
Commit
37af9c1
1 Parent(s): 4aa14ae
Files changed (1) hide show
  1. utils.py +155 -0
utils.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import imageio
3
+ import pickle
4
+ import tensorflow as tf
5
+ import matplotlib.pyplot as plt
6
+ import cartopy.crs as ccrs
7
+ import matplotlib.ticker as mticker
8
+ from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER
9
+ import shapefile as shp
10
+ from matplotlib import animation
11
+ from IPython.display import HTML
12
+
13
+ def get_array(source, scaler_dict):
14
+
15
+ reader = imageio.get_reader(source)
16
+
17
+ source_video = []
18
+
19
+ try:
20
+ for im in reader:
21
+ source_video.append(im)
22
+ except RuntimeError:
23
+ pass
24
+
25
+ reader.close()
26
+
27
+ scaler_path = scaler_dict[source[-14:]]
28
+
29
+ with open(scaler_path, 'rb') as f:
30
+ sc = pickle.load(f)
31
+
32
+ data = np.array(source_video)[:,:,:,0]
33
+ data = sc.inverse_transform(data)
34
+ data = np.swapaxes(data, 0, 2)
35
+ data = np.swapaxes(data, 0, 1)
36
+
37
+ X = data[:,:,0:12]
38
+ y = data[:,:,12:]
39
+
40
+ return X, y
41
+
42
+ def get_slices(values, slices):
43
+
44
+ dim_size = len(values)
45
+ idx_step = int(dim_size/slices)
46
+
47
+ slices_list = []
48
+
49
+ for i in range(idx_step, dim_size, idx_step):
50
+
51
+ slices_list.append(np.round(values[i], 2))
52
+
53
+ return slices_list
54
+
55
+ def save_video(X, threshold=0, file_path = 'data/video.mp4'):
56
+
57
+ # Get vmax
58
+ var = X.copy()
59
+
60
+ var[np.isnan(var)] = 0
61
+ var[var<=0] = 0
62
+
63
+ counts, bins = np.histogram(var[:])
64
+
65
+ value = counts[counts>np.median(counts)][-1]
66
+ idx = np.where(counts==value)[0][0]
67
+
68
+ vmax = np.round(bins[idx])
69
+
70
+ # Latitude and longitude
71
+ lon = np.loadtxt('data/longitude.txt')
72
+ lat = np.loadtxt('data/latitude.txt')
73
+
74
+ area = [lon.min(),lon.max(),lat.min(),lat.max()]
75
+
76
+ lat_list = get_slices(lat, 4)
77
+ lon_list = get_slices(lon, 6)
78
+
79
+ # Visualization
80
+ ims = []
81
+ fig = plt.figure(figsize=(7,5))
82
+ ax = plt.axes(projection=ccrs.PlateCarree())
83
+
84
+ gl = ax.gridlines(crs=ccrs.PlateCarree(),
85
+ draw_labels=True,
86
+ linewidth=0.3,
87
+ color='black',
88
+ linestyle='--')
89
+
90
+ gl.top_labels = False
91
+ gl.right_labels = False
92
+ gl.xlines = True
93
+ gl.xlocator = mticker.FixedLocator(lon_list)
94
+ gl.ylocator = mticker.FixedLocator(lat_list)
95
+ gl.xformatter = LONGITUDE_FORMATTER
96
+ gl.yformatter = LATITUDE_FORMATTER
97
+ gl.xlabel_style = {'size':10, 'color':'black'}
98
+ gl.ylabel_style = {'size':10, 'color':'black'}
99
+
100
+ frames = X
101
+ frames[frames<=threshold] = np.nan
102
+
103
+ barra = np.arange(0, vmax+1, 5)
104
+
105
+ for i in range(frames.shape[2]):
106
+ im = plt.imshow(frames[..., i],
107
+ cmap=plt.cm.rainbow,
108
+ vmin=0,
109
+ vmax=vmax,
110
+ extent=area,
111
+ origin='lower',
112
+ animated=True)
113
+
114
+ ims.append([im])
115
+
116
+ cbar = plt.colorbar(ax=ax, pad=0.02, aspect=16, shrink=0.77)
117
+ cbar.set_ticks(barra)
118
+ cbar.set_label('mm/h')
119
+
120
+ shapeID = shp.Reader("data/shapefile/regiao_sul.shp")
121
+
122
+ for shape in shapeID.shapeRecords():
123
+ point = np.array( shape.shape.points )
124
+ dummy = plt.plot( point[:,0] , point[:,1], color="black", linewidth=0.5 ) # 1
125
+
126
+ ani = animation.ArtistAnimation(fig, ims, interval=500, blit=True, repeat_delay=1000)
127
+
128
+ FFwriter = animation.FFMpegWriter(fps=2)
129
+ ani.save(file_path, writer = FFwriter)
130
+
131
+ # ani.save(f'data/vis.gif', writer='pillow', fps=6)
132
+
133
+ plt.close(ani._fig)
134
+ HTML(ani.to_html5_video())
135
+
136
+ def make_predictions(X):
137
+
138
+ filepath = "models/model.h5"
139
+ model = tf.keras.models.load_model(filepath)
140
+
141
+ X[np.isnan(X)] = 0
142
+
143
+ X = np.expand_dims(X, axis=0)
144
+
145
+ scaler_path = 'models/scaler.pkl'
146
+ with open(scaler_path, 'rb') as f:
147
+ sc = pickle.load(f)
148
+
149
+ X = sc.transform(X)
150
+ ypred = model.predict(X)
151
+ print(ypred.shape)
152
+ ypred = sc.inverse_transform(ypred)[0]
153
+ print(ypred.shape)
154
+
155
+ return ypred