gera-richarte committed on
Commit
2ef57a2
1 Parent(s): 3e43423

Big refactoring and extensive use of numpy

Browse files
Files changed (1) hide show
  1. app.py +132 -86
app.py CHANGED
@@ -1,6 +1,7 @@
1
  from datasets import load_dataset, get_dataset_config_names
2
  from functools import partial
3
  from pandas import DataFrame
 
4
  import gradio as gr
5
  import numpy as np
6
  import tqdm
@@ -59,10 +60,74 @@ def open_dataset(dataset, set_name, split, batch_size, state, shard = -1):
59
  state
60
  )
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def get_images(batch_size, state):
63
- items = []
64
- metadatas = []
65
 
 
 
 
66
  for i in tqdm.trange(batch_size, desc=f"Getting images"):
67
  if DEBUG:
68
  image = np.random.randint(0,255,(384,384,3))
@@ -73,40 +138,20 @@ def get_images(batch_size, state):
73
  except StopIteration:
74
  break
75
  metadata = item["metadata"]
76
- if state["config"] == "satellogic":
77
- # image = (np.asarray(item["1m"])).astype("uint8")
78
- # items.append(image[0,0,:,:])
79
- image = np.asarray(item["rgb"][0]).astype(np.uint8)
80
- items.append(image.transpose(1,2,0))
81
-
82
- if state["config"] == "sentinel_1":
83
- metadata = json.loads(metadata)
84
- data = np.asarray(item["10m"])
85
- for i in range(data.shape[0]):
86
- # Mapping of V and H to RGB. May not be correct
87
- # https://gis.stackexchange.com/questions/400726/creating-composite-rgb-images-from-sentinel-1-channels
88
- image = np.zeros((3,384,384), "uint8")
89
- image[0] = data[i][0]
90
- image[1] = data[i][1]
91
- image[2] = (image[0]/(image[1]+0.1))*256
92
- items.append(image.transpose(1,2,0))
93
-
94
- if state["config"] == "default":
95
- dataRGB = np.asarray(item["rgb"]).astype("uint8")
96
- dataCHM = np.asarray(item["chm"]).astype("uint8")
97
- data1m = np.asarray(item["1m"]).astype("uint8")
98
- for i in range(dataRGB.shape[0]):
99
- image = dataRGB[i,:,:,:]
100
- items.append(image.transpose(1,2,0))
101
-
102
- image = dataCHM[i,0,:,:]
103
- items.append(image)
104
-
105
- image = data1m[i,0,:,:]
106
- items.append(image)
107
- metadatas.append(metadata)
108
-
109
- return items, DataFrame(metadatas)
110
 
111
  def update_shape(rows, columns):
112
  return gr.update(rows=rows, columns=columns)
@@ -114,53 +159,54 @@ def update_shape(rows, columns):
114
  def new_state():
115
  return gr.State({})
116
 
117
- with gr.Blocks(title="Dataset Explorer", fill_height = True) as demo:
118
- state = new_state()
119
-
120
- gr.Markdown(f"# Viewer for [{DATASET}](https://huggingface.co/datasets/satellogic/EarthView) Dataset")
121
- batch_size = gr.Number(10, label = "Batch Size", render=False)
122
- shard = gr.Slider(label="Shard", minimum=0, maximum=10000, step=1, render=False)
123
- table = gr.DataFrame(render = False)
124
- # headers=["Index","TimeStamp","Bounds","CRS"],
125
-
126
- gallery = gr.Gallery(
127
- label=DATASET,
128
- interactive=False,
129
- columns=5, rows=2, render=False)
130
-
131
- with gr.Row():
132
- dataset = gr.Textbox(label="Dataset", value=DATASET, interactive=False)
133
- config = gr.Dropdown(choices=get_dataset_config_names(DATASET), label="Config", value="satellogic", )
134
- split = gr.Textbox(label="Split", value="train")
135
- initial_shard = gr.Number(label = "Initial shard", value=0, info="-1 for whole dataset")
136
-
137
- gr.Button("Load (minutes)").click(
138
- open_dataset,
139
- inputs=[dataset, config, split, batch_size, state, initial_shard],
140
- outputs=[shard, gallery, table, state])
141
-
142
- gallery.render()
143
-
144
- with gr.Row():
145
- batch_size.render()
146
-
147
- rows = gr.Number(2, label="Rows")
148
- columns = gr.Number(5, label="Coluns")
149
-
150
- rows.change(update_shape, [rows, columns], [gallery])
151
- columns.change(update_shape, [rows, columns], [gallery])
152
-
153
- with gr.Row():
154
- shard.render()
155
- shard.release(
156
- open_dataset,
157
- inputs=[dataset, config, split, batch_size, state, shard],
158
- outputs=[shard, gallery, table, state])
159
-
160
- btn = gr.Button("Next Batch (same shard)", scale=0)
161
- btn.click(get_images, [batch_size, state], [gallery, table])
162
- btn.click()
163
-
164
- table.render()
165
-
166
- demo.launch(show_api=False)
 
 
1
  from datasets import load_dataset, get_dataset_config_names
2
  from functools import partial
3
  from pandas import DataFrame
4
+ from PIL import Image
5
  import gradio as gr
6
  import numpy as np
7
  import tqdm
 
60
  state
61
  )
62
 
63
def item_to_images(config, item):
    """Convert the raw arrays of one dataset item into PIL images.

    Every value except ``"metadata"`` is converted to a ``uint8`` numpy
    array; the image keys that belong to ``config`` are then replaced by
    lists of ``PIL.Image`` objects, one per leading-axis entry.

    Args:
        config: Dataset configuration name. ``"satellogic"``,
            ``"sentinel_1"`` and ``"default"`` get image conversion; any
            other value leaves the ``uint8`` arrays untouched.
        item: Mapping of one dataset sample. Must contain ``"metadata"``
            (a dict, or a JSON-encoded string) plus the array-like fields
            of the given config.

    Returns:
        A new dict with the converted fields and the decoded metadata under
        ``"metadata"``. The input mapping is not modified.
    """
    metadata = item["metadata"]
    # isinstance (rather than an exact type comparison) also decodes
    # str subclasses instead of silently passing them through undecoded.
    if isinstance(metadata, str):
        metadata = json.loads(metadata)

    item = {
        key: np.asarray(value).astype("uint8")
        for key, value in item.items()
        if key != "metadata"
    }
    item["metadata"] = metadata

    if config == "satellogic":
        # Presumably (channels, height, width) per image; PIL wants the
        # channel axis last.
        item["rgb"] = [
            Image.fromarray(image.transpose(1, 2, 0))
            for image in item["rgb"]
        ]
        # Only the first plane of each "1m" entry is shown.
        item["1m"] = [
            Image.fromarray(image[0, :, :])
            for image in item["1m"]
        ]
    elif config == "sentinel_1":
        # Mapping of V and H to RGB. May not be correct
        # https://gis.stackexchange.com/questions/400726/creating-composite-rgb-images-from-sentinel-1-channels
        i10m = item["10m"]
        # Third channel is the scaled ratio of the first two; the 0.01
        # offset avoids division by zero.
        # NOTE(review): ratios >= 1 exceed 255 before the uint8 cast, so the
        # value does not saturate - confirm whether clipping is wanted.
        ratio = i10m[:, 0, :, :] / (i10m[:, 1, :, :] + 0.01) * 256
        i10m = np.concatenate(
            (i10m, np.expand_dims(ratio, 1).astype("uint8")),
            1,
        )
        item["10m"] = [
            Image.fromarray(image.transpose(1, 2, 0))
            for image in i10m
        ]
    elif config == "default":
        item["rgb"] = [
            Image.fromarray(image.transpose(1, 2, 0))
            for image in item["rgb"]
        ]
        item["chm"] = [
            Image.fromarray(image[0])
            for image in item["chm"]
        ]

        # The next is a very arbitrary conversion from the 369 hyperspectral
        # bands to RGB: it averages each third of the bands and assigns the
        # result to one colour channel.
        item["1m"] = [
            Image.fromarray(
                np.stack(
                    (
                        np.average(image[:124], 0),
                        np.average(image[124:247], 0),
                        np.average(image[247:], 0),
                    ),
                    axis=2,
                ).astype("uint8")
            )
            for image in item["1m"]
        ]
    return item
123
+
124
+
125
  def get_images(batch_size, state):
126
+ config = state["config"]
 
127
 
128
+ images = []
129
+ metadatas = []
130
+
131
  for i in tqdm.trange(batch_size, desc=f"Getting images"):
132
  if DEBUG:
133
  image = np.random.randint(0,255,(384,384,3))
 
138
  except StopIteration:
139
  break
140
  metadata = item["metadata"]
141
+ item = item_to_images(config, item)
142
+
143
+ if config == "satellogic":
144
+ images.extend(item["rgb"])
145
+ images.extend(item["1m"])
146
+ if config == "sentinel_1":
147
+ images.extend(item["10m"])
148
+ if config == "default":
149
+ images.extend(item["rgb"])
150
+ images.extend(item["chm"])
151
+ images.extend(item["1m"])
152
+ metadatas.append(item["metadata"])
153
+
154
+ return images, DataFrame(metadatas)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  def update_shape(rows, columns):
157
  return gr.update(rows=rows, columns=columns)
 
159
  def new_state():
160
  return gr.State({})
161
 
162
if __name__ == "__main__":
    # Build and launch the viewer UI only when run as a script, so importing
    # this module has no UI side effects.
    # NOTE(review): the layout nesting below was reconstructed from a
    # listing without indentation - confirm the row grouping against the UI.
    with gr.Blocks(title="Dataset Explorer", fill_height=True) as demo:
        state = new_state()

        gr.Markdown(f"# Viewer for [{DATASET}](https://huggingface.co/datasets/satellogic/EarthView) Dataset")

        # These components are created up front with render=False so the
        # event handlers can reference them before they are placed.
        batch_size = gr.Number(10, label="Batch Size", render=False)
        shard = gr.Slider(label="Shard", minimum=0, maximum=10000, step=1, render=False)
        table = gr.DataFrame(render=False)
        # headers=["Index","TimeStamp","Bounds","CRS"],

        gallery = gr.Gallery(
            label=DATASET,
            interactive=False,
            columns=5,
            rows=2,
            render=False,
        )

        with gr.Row():
            dataset = gr.Textbox(label="Dataset", value=DATASET, interactive=False)
            # Pass a concrete list rather than a dict view as the choices.
            config = gr.Dropdown(choices=list(sets.keys()), label="Config", value="satellogic")
            split = gr.Textbox(label="Split", value="train")
            initial_shard = gr.Number(label="Initial shard", value=0, info="-1 for whole dataset")

            gr.Button("Load (minutes)").click(
                open_dataset,
                inputs=[dataset, config, split, batch_size, state, initial_shard],
                outputs=[shard, gallery, table, state],
            )

        gallery.render()

        with gr.Row():
            batch_size.render()

            rows = gr.Number(2, label="Rows")
            columns = gr.Number(5, label="Columns")

            # Changing either number re-applies both to the gallery.
            rows.change(update_shape, [rows, columns], [gallery])
            columns.change(update_shape, [rows, columns], [gallery])

        with gr.Row():
            shard.render()
            # Releasing the slider reopens the dataset at the chosen shard.
            shard.release(
                open_dataset,
                inputs=[dataset, config, split, batch_size, state, shard],
                outputs=[shard, gallery, table, state],
            )

            btn = gr.Button("Next Batch (same shard)", scale=0)
            btn.click(get_images, [batch_size, state], [gallery, table])

        table.render()

    demo.launch(show_api=False)