venite committed
Commit: f670afc
Parent: 29fd2b4
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +122 -13
  2. __pycache__/options.cpython-38.pyc +0 -0
  3. __pycache__/test.cpython-38.pyc +0 -0
  4. __pycache__/utils.cpython-38.pyc +0 -0
  5. app.py +244 -0
  6. data/CVACT_Shi.py +119 -0
  7. data/CVUSA.py +86 -0
  8. dataset/INSTALL.md +32 -0
  9. demo_img/case1/groundview.image.png +0 -0
  10. demo_img/case1/groundview.sky.png +0 -0
  11. demo_img/case1/satview-input.png +0 -0
  12. demo_img/case10/groundview.image.png +0 -0
  13. demo_img/case10/groundview.sky.png +0 -0
  14. demo_img/case10/satview-input.png +0 -0
  15. demo_img/case11/groundview.image.png +0 -0
  16. demo_img/case11/groundview.sky.png +0 -0
  17. demo_img/case11/satview-input.png +0 -0
  18. demo_img/case12/groundview.image.png +0 -0
  19. demo_img/case12/groundview.sky.png +0 -0
  20. demo_img/case12/satview-input.png +0 -0
  21. demo_img/case13/groundview.image.png +0 -0
  22. demo_img/case13/groundview.sky.png +0 -0
  23. demo_img/case13/satview-input.png +0 -0
  24. demo_img/case2/groundview.image.png +0 -0
  25. demo_img/case2/groundview.sky.png +0 -0
  26. demo_img/case2/satview-input.png +0 -0
  27. demo_img/case3/groundview.image.png +0 -0
  28. demo_img/case3/groundview.sky.png +0 -0
  29. demo_img/case3/satview-input.png +0 -0
  30. demo_img/case4/groundview.image.png +0 -0
  31. demo_img/case4/groundview.sky.png +0 -0
  32. demo_img/case4/satview-input.png +0 -0
  33. demo_img/case5/groundview.image.png +0 -0
  34. demo_img/case5/groundview.sky.png +0 -0
  35. demo_img/case5/satview-input.png +0 -0
  36. demo_img/case6/groundview.image.png +0 -0
  37. demo_img/case6/groundview.sky.png +0 -0
  38. demo_img/case6/satview-input.png +0 -0
  39. demo_img/case7/groundview.image.png +0 -0
  40. demo_img/case7/groundview.sky.png +0 -0
  41. demo_img/case7/satview-input.png +0 -0
  42. demo_img/case8/groundview.image.png +0 -0
  43. demo_img/case8/groundview.sky.png +0 -0
  44. demo_img/case8/satview-input.png +0 -0
  45. demo_img/case9/groundview.image.png +0 -0
  46. demo_img/case9/groundview.sky.png +0 -0
  47. demo_img/case9/satview-input.png +0 -0
  48. demo_img/runall.sh +30 -0
  49. imaginaire/__init__.py +4 -0
  50. imaginaire/__pycache__/__init__.cpython-38.pyc +0 -0
README.md CHANGED
@@ -1,13 +1,122 @@
- ---
- title: Sat3density
- emoji: 🏆
- colorFrom: green
- colorTo: blue
- sdk: gradio
- sdk_version: 3.41.2
- app_file: app.py
- pinned: false
- license: other
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Sat2Density: Faithful Density Learning from Satellite-Ground Image Pairs
+
+ > [Ming Qian](https://qianmingduowan.github.io/), Jincheng Xiong, [Gui-Song Xia](http://www.captain-whu.com/xia_En.html), [Nan Xue](https://xuenan.net)
+ >
+ > IEEE/CVF International Conference on Computer Vision (ICCV), 2023
+ >
+ > [Project](https://sat2density.github.io/) | [Paper](https://arxiv.org/abs/2303.14672) | [Data]() | [Install.md](docs/INSTALL.md)
+
+ > <p align="center" float="left">
+ > <img src="docs/figures/demo/case1.sat.gif" alt="drawing" width="19%">
+ > <img src="docs/figures/demo-density/case1.gif" alt="drawing" width="38%">
+ > <img src="docs/figures/demo/case1.render.gif" alt="drawing" width="38%">
+ > </p>
+
+ > <p align="center" float="left">
+ > <img src="docs/figures/demo/case2.sat.gif" alt="drawing" width="19%">
+ > <img src="docs/figures/demo-density/case2.gif" alt="drawing" width="38%">
+ > <img src="docs/figures/demo/case2.render.gif" alt="drawing" width="38%">
+ > </p>
+
+ > <p align="center" float="left">
+ > <img src="docs/figures/demo/case3.sat.gif" alt="drawing" width="19%">
+ > <img src="docs/figures/demo-density/case3.gif" alt="drawing" width="38%">
+ > <img src="docs/figures/demo/case3.render.gif" alt="drawing" width="38%">
+ > </p>
+
+ > <p align="center" float="left">
+ > <img src="docs/figures/demo/case4.sat.gif" alt="drawing" width="19%">
+ > <img src="docs/figures/demo-density/case4.gif" alt="drawing" width="38%">
+ > <img src="docs/figures/demo/case4.render.gif" alt="drawing" width="38%">
+ > </p>
+
+ ## Downloading Checkpoints
+ > Checkpoints for CVACT and CVUSA are available at [this URL](https://github.com/sat2density/checkpoints/releases). You can also run the following command to download them:
+ ```
+ bash scripts/download_weights.sh
+ ```
+
+ ## QuickStart Demo
+ ### Video Synthesis
+ #### Example Usage
+ ```
+ python test.py --yaml=sat2density_cvact \
+   --test_ckpt_path=2u87bj8w \
+   --task=test_vid \
+   --demo_img=demo_img/case1/satview-input.png \
+   --sty_img=demo_img/case1/groundview.image.png \
+   --save_dir=results/case1
+ ```
+
+ ### Illumination Interpolation
+ <!-- ```
+ bash inference/quick_demo_interpolation.sh
+ ``` -->
+ ```
+ python test.py --task=test_interpolation \
+   --yaml=sat2density_cvact \
+   --test_ckpt_path=2u87bj8w \
+   --sty_img1=demo_img/case9/groundview.image.png \
+   --sty_img2=demo_img/case7/groundview.image.png \
+   --demo_img=demo_img/case3/satview-input.png \
+   --save_dir=results/case2
+ ```
+
66
+ ## Train & Inference
67
+ - *We trained our model using 1 V100 32GB GPU. The training phase will take about 20 hours.*
68
+ - *For data preparation, please check out [data.md](dataset/INSTALL.md).*
69
+
70
+
71
+
72
+
73
+ ### Inference
74
+
75
+ To test Center Ground-View Synthesis setting
76
+ If you want save results, please add --task=vis_test
77
+ ```bash
78
+ # CVACT
79
+ python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w
80
+ # CVUSA
81
+ python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4
82
+ ```
83
+
84
+ To test inference with different illumination
85
+ ```bash
86
+ # CVACT
87
+ bash inference/single_style_test_cvact.sh
88
+ # CVUSA
89
+ bash inference/single_style_test_cvusa.sh
90
+ ```
91
+
92
+ To test synthesis ground videos
93
+ ```bash
94
+ bash inference/synthesis_video.sh
95
+ ```
96
+
97
+ ## Training
98
+
99
+ ### Training command
100
+
101
+ ```bash
102
+ # CVACT
103
+ CUDA_VISIBLE_DEVICES=X python train.py --yaml=sat2density_cvact
104
+ # CVUSA
105
+ CUDA_VISIBLE_DEVICES=X python train.py --yaml=sat2density_cvusa
106
+ ```
107
+
108
+ ## Citation
109
+ If you use this code for your research, please cite
110
+
111
+ ```
112
+ @inproceedings{qian2021sat2density,
113
+ title={Sat2Density: Faithful Density Learning from Satellite-Ground Image Pairs},
114
+ author={Qian, Ming and Xiong, Jincheng and Xia, Gui-Song and Xue, Nan},
115
+ booktitle={ICCV},
116
+ year={2023}
117
+ }
118
+ ```
119
+
120
+ ## License
121
+ This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
122
+ For commercial use, please contact [mingqian@whu.edu.cn].
__pycache__/options.cpython-38.pyc ADDED
Binary file (3.74 kB).
 
__pycache__/test.cpython-38.pyc ADDED
Binary file (9.11 kB).
 
__pycache__/utils.cpython-38.pyc ADDED
Binary file (8.16 kB).
 
app.py ADDED
@@ -0,0 +1,244 @@
+ import gradio as gr
+ import numpy as np
+ import os
+ from PIL import Image
+ import torch
+ import torchvision.transforms as transforms
+ import options
+ import test
+ import importlib
+ from scipy.interpolate import splev, splprep
+ import cv2
+
+
+ def get_single(sat_img, style_img, x_offset, y_offset):
+     # Find which demo case the selected style image belongs to,
+     # so that its matching sky mask can be loaded.
+     name = ''
+     for i in [name for name in os.listdir('demo_img') if 'case' in name]:
+         style = Image.open('demo_img/{}/groundview.image.png'.format(i)).convert('RGB')
+         style = np.array(style)
+         if (style == style_img).all():
+             name = i
+             break
+
+     input_dict = {}
+     trans = transforms.ToTensor()
+     input_dict['sat'] = trans(sat_img)
+     input_dict['pano'] = trans(style_img)
+     input_dict['paths'] = "demo.png"
+     sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L"))
+     input_a = input_dict['pano'] * sky
+     # Per-channel color histogram of the sky region: the style signal.
+     sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
+     input_dict['sky_histc'] = sky_histc
+     input_dict['sky_mask'] = sky
+
+     for key in input_dict.keys():
+         if isinstance(input_dict[key], torch.Tensor):
+             input_dict[key] = input_dict[key].unsqueeze(0)
+
+     args = ["--yaml=sat2density_cvact",
+             "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth",
+             "--task=test_vid",
+             "--demo_img=demo_img/case1/satview-input.png",
+             "--sty_img=demo_img/case1/groundview.image.png",
+             "--save_dir=output"]
+     opt_cmd = options.parse_arguments(args=args)
+     opt = options.set(opt_cmd=opt_cmd)
+     opt.isTrain = False
+     opt.name = opt.yaml if opt.name is None else opt.name
+     opt.batch_size = 1
+
+     m = importlib.import_module("model.{}".format(opt.model))
+     model = m.Model(opt)
+
+     # m.load_dataset(opt)
+     model.build_networks(opt)
+     ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
+     model.netG.load_state_dict(ckpt['netG'])
+     model.netG.eval()
+
+     model.set_input(input_dict)
+     model.style_temp = model.sky_histc
+     # Map the (0, 1) slider offsets into the model's (-1, 1) frame.
+     opt.origin_H_W = [-(y_offset * 256 - 128) / 128, (x_offset * 256 - 128) / 128]  # TODO: remove hard-coded values
+
+     model.forward(opt)
+
+     rgb = model.out_put.pred[0].clamp(min=0, max=1.0).cpu().detach().numpy().transpose((1, 2, 0))
+     rgb = np.array(rgb * 255, dtype=np.uint8)
+     return rgb
+
+
+ def get_video(sat_img, style_img, positions):
+     # Find which demo case the selected style image belongs to.
+     name = ''
+     for i in [name for name in os.listdir('demo_img') if 'case' in name]:
+         style = Image.open('demo_img/{}/groundview.image.png'.format(i)).convert('RGB')
+         style = np.array(style)
+         if (style == style_img).all():
+             name = i
+             break
+
+     input_dict = {}
+     trans = transforms.ToTensor()
+     input_dict['sat'] = trans(sat_img)
+     input_dict['pano'] = trans(style_img)
+     input_dict['paths'] = "demo.png"
+     sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L"))
+     input_a = input_dict['pano'] * sky
+     sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
+     input_dict['sky_histc'] = sky_histc
+     input_dict['sky_mask'] = sky
+
+     for key in input_dict.keys():
+         if isinstance(input_dict[key], torch.Tensor):
+             input_dict[key] = input_dict[key].unsqueeze(0)
+
+     args = ["--yaml=sat2density_cvact",
+             "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth",
+             "--task=test_vid",
+             "--demo_img=demo_img/case1/satview-input.png",
+             "--sty_img=demo_img/case1/groundview.image.png",
+             "--save_dir=output"]
+     opt_cmd = options.parse_arguments(args=args)
+     opt = options.set(opt_cmd=opt_cmd)
+     opt.isTrain = False
+     opt.name = opt.yaml if opt.name is None else opt.name
+     opt.batch_size = 1
+
+     m = importlib.import_module("model.{}".format(opt.model))
+     model = m.Model(opt)
+
+     # m.load_dataset(opt)
+     model.build_networks(opt)
+     ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
+     model.netG.load_state_dict(ckpt['netG'])
+     model.netG.eval()
+
+     model.set_input(input_dict)
+     model.style_temp = model.sky_histc
+
+     # De-duplicate the clicked points while preserving their order.
+     pixels = np.array(list(dict.fromkeys(positions)))
+     # Fit a smoothing spline through the clicks and resample 80 waypoints.
+     tck, u = splprep(pixels.T, s=25, per=0)
+     u_new = np.linspace(u.min(), u.max(), 80)
+     x_new, y_new = splev(u_new, tck)
+     smooth_path = np.array([x_new, y_new]).T
+
+     rendered_image_list = []
+     rendered_depth_list = []
+
+     for i, (x, y) in enumerate(smooth_path):
+         opt.origin_H_W = [(y - 128) / 128, (x - 128) / 128]  # TODO: remove hard-coded values
+         print('Rendering at ({}, {})'.format(x, y))
+         model.forward(opt)
+
+         rgb = model.out_put.pred[0].clamp(min=0, max=1.0).cpu().detach().numpy().transpose((1, 2, 0))
+         rgb = np.array(rgb * 255, dtype=np.uint8)
+         rendered_image_list.append(rgb)
+         rendered_depth_list.append(model.out_put.depth[0, 0].cpu().detach().numpy())
+
+     output_video_path = 'output_video.mp4'
+
+     # Frame rate and frame size of the output video.
+     frame_rate = 15
+     frame_width = 512
+     frame_height = 128
+
+     # Create an OpenCV video writer using the MPEG-4 ('mp4v') codec.
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (frame_width, frame_height))
+
+     # Write each rendered frame; OpenCV expects BGR channel order.
+     for image_np in rendered_image_list:
+         image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+         out.write(image_np)
+
+     # Release the video writer.
+     out.release()
+
+     return output_video_path
+
+
+ def copy_image(image):
+     return image
+
+
+ def show_image_and_point(image, x, y):
+     # Highlight the slider-selected position with a circular mask
+     # (the y slider is measured from the bottom of the image).
+     x = int(x * image.shape[1])
+     y = image.shape[0] - int(y * image.shape[0])
+     mask = np.zeros(image.shape[:2])
+     radius = min(image.shape[0], image.shape[1]) // 60
+     for i in range(x - radius - 2, x + radius + 2):
+         for j in range(y - radius - 2, y + radius + 2):
+             if (i - x) ** 2 + (j - y) ** 2 <= radius ** 2:
+                 mask[j, i] = 1
+     return (image, [(mask, 'render point')])
+
+
+ def add_select_point(image, evt: gr.SelectData, state1):
+     # Record the clicked point and stamp a black dot on the image.
+     if state1 is None:
+         state1 = []
+     x, y = evt.index
+     state1.append((x, y))
+     print(state1)
+     radius = min(image.shape[0], image.shape[1]) // 60
+     for i in range(x - radius - 2, x + radius + 2):
+         for j in range(y - radius - 2, y + radius + 2):
+             if (i - x) ** 2 + (j - y) ** 2 <= radius ** 2:
+                 image[j, i, :] = 0
+     return image, state1
+
+
+ def reset_select_points(image):
+     return image, []
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Sat2Density Demos")
+     gr.Markdown("### Select/upload the satellite image and select the style image")
+     with gr.Row():
+         with gr.Column():
+             sat_img = gr.Image(source='upload', shape=[256, 256], interactive=True)
+             img_examples = gr.Examples(examples=['demo_img/{}/satview-input.png'.format(i) for i in os.listdir('demo_img') if 'case' in i],
+                                        inputs=sat_img, outputs=None, examples_per_page=20)
+         with gr.Column():
+             style_img = gr.Image()
+             style_examples = gr.Examples(examples=['demo_img/{}/groundview.image.png'.format(i) for i in os.listdir('demo_img') if 'case' in i],
+                                          inputs=style_img, outputs=None, examples_per_page=20)
+
+     gr.Markdown("### Select a point to generate a single ground-view image")
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 with gr.Column():
+                     slider_x = gr.Slider(0.2, 0.8, 0.5, label="x-axis position")
+                     slider_y = gr.Slider(0.2, 0.8, 0.5, label="y-axis position")
+                     btn_single = gr.Button(label="demo1")
+         annotation_image = gr.AnnotatedImage()
+
+     out_single = gr.Image()
+
+     gr.Markdown("### Draw a trajectory on the map to generate a video")
+     state_select_points = gr.State()
+     with gr.Row():
+         with gr.Column():
+             draw_img = gr.Image(shape=[256, 256], interactive=True)
+         with gr.Column():
+             out_video = gr.Video()
+             reset_btn = gr.Button(value="Reset")
+             btn_video = gr.Button(label="demo1")
+
+     sat_img.change(copy_image, inputs=sat_img, outputs=draw_img)
+     draw_img.select(add_select_point, [draw_img, state_select_points], [draw_img, state_select_points])
+     sat_img.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image)
+     slider_x.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image, show_progress='hidden')
+     slider_y.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image, show_progress='hidden')
+     btn_single.click(get_single, inputs=[sat_img, style_img, slider_x, slider_y], outputs=out_single)
+     reset_btn.click(reset_select_points, [sat_img], [draw_img, state_select_points])
+     btn_video.click(get_video, inputs=[sat_img, style_img, state_select_points], outputs=out_video)  # triggers video rendering
+
+
+ demo.launch()
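The style signal threaded through both functions above is a per-channel color histogram of the masked sky region. A self-contained sketch of just that computation (the shapes and the bin-dropping follow the code above; the random tensors stand in for a real panorama and mask):

```python
import torch

def sky_color_histogram(pano: torch.Tensor, sky_mask: torch.Tensor) -> torch.Tensor:
    """Per-channel histogram of the sky region, as computed in app.py.

    pano:     (3, H, W) float tensor in [0, 1]
    sky_mask: (1, H, W) float tensor, 1 = sky, 0 = everything else
    """
    masked = pano * sky_mask  # non-sky pixels become 0
    # histc() defaults to 100 bins over the tensor's own value range;
    # dropping the first 10 bins discards the near-black mass contributed
    # by the masked-out pixels.
    return torch.cat([masked[c].histc()[10:] for c in reversed(range(3))])

pano = torch.rand(3, 128, 512)                 # stand-in panorama
sky = (torch.rand(1, 128, 512) > 0.5).float()  # stand-in sky mask
print(sky_color_histogram(pano, sky).shape)    # torch.Size([270])
```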
data/CVACT_Shi.py ADDED
@@ -0,0 +1,119 @@
+ import os
+
+ import torch
+ from torch.utils.data.dataset import Dataset
+ from PIL import Image
+ import scipy.io as sio
+ import torchvision.transforms as transforms
+
+
+ def data_list(img_root, mode):
+     exist_aer_list = os.listdir(os.path.join(img_root, 'satview_correct'))
+     exist_grd_list = os.listdir(os.path.join(img_root, 'streetview'))
+     allDataList = os.path.join(img_root, 'ACT_data.mat')
+     anuData = sio.loadmat(allDataList)
+
+     all_data_list = []
+     for i in range(0, len(anuData['panoIds'])):
+         grd_id_align = anuData['panoIds'][i] + '_grdView.png'
+         sat_id_ori = anuData['panoIds'][i] + '_satView_polish.png'
+         all_data_list.append([grd_id_align, sat_id_ori])
+
+     data_list = []
+     if mode == 'train':
+         training_inds = anuData['trainSet']['trainInd'][0][0] - 1
+         trainNum = len(training_inds)
+         for k in range(trainNum):
+             data_list.append(all_data_list[training_inds[k][0]])
+     else:
+         val_inds = anuData['valSet']['valInd'][0][0] - 1
+         valNum = len(val_inds)
+         for k in range(valNum):
+             data_list.append(all_data_list[val_inds[k][0]])
+
+     # Keep only pairs whose ground and satellite images both exist on disk.
+     pano_list = [img_root + 'streetview/' + item[0] for item in data_list
+                  if item[0] in exist_grd_list and item[1] in exist_aer_list]
+     return pano_list
+
+
+ def img_read(img, size=None, datatype='RGB'):
+     img = Image.open(img).convert('RGB' if datatype == 'RGB' else 'L')
+     if size:
+         if type(size) is int:
+             size = (size, size)
+         img = img.resize(size=size, resample=Image.BICUBIC if datatype == 'RGB' else Image.NEAREST)
+     img = transforms.ToTensor()(img)
+     return img
+
+
+ class Dataset(Dataset):
+     def __init__(self, opt, split='train', sub=None, sty_img=None):
+         if sty_img:
+             assert sty_img.endswith('grdView.png')
+             demo_img_path = os.path.join(opt.data.root, 'streetview', sty_img)
+             self.pano_list = [demo_img_path]
+         elif opt.task in ['test_vid', 'test_interpolation']:
+             demo_img_path = os.path.join(opt.data.root, 'streetview', opt.demo_img.replace('satView_polish.png', 'grdView.png'))
+             self.pano_list = [demo_img_path]
+         else:
+             self.pano_list = data_list(img_root=opt.data.root, mode=split)
+             if sub:
+                 self.pano_list = self.pano_list[:sub]
+
+         # Select some ground images to test the influence of different skies;
+         # different skies induce different illumination intensities and colors.
+         if opt.task == 'test_sty':
+             demo_name = [
+                 'dataset/CVACT/streetview/pPfo7qQ1fP_24rXrJ2Uxog_grdView.png',
+                 'dataset/CVACT/streetview/YL81FiK9PucIvAkr1FHkpA_grdView.png',
+                 'dataset/CVACT/streetview/Tzis1jBKHjbXiVB2oRYwAQ_grdView.png',
+                 'dataset/CVACT/streetview/eqGgeBLGXRhSj6c-0h0KoQ_grdView.png',
+                 'dataset/CVACT/streetview/pdZmLHYEhe2PHj_8-WHMhw_grdView.png',
+                 'dataset/CVACT/streetview/ehsu9Q3iTin5t52DM-MwyQ_grdView.png',
+                 'dataset/CVACT/streetview/agLEcuq3_-qFj7wwGbktVg_grdView.png',
+                 'dataset/CVACT/streetview/HwQIDdMI3GfHyPGtCSo6aA_grdView.png',
+                 'dataset/CVACT/streetview/hV8svb3ZVXcQ0AtTRFE1dQ_grdView.png',
+                 'dataset/CVACT/streetview/fzq2mBfKP3UIczAd9KpMMg_grdView.png',
+                 'dataset/CVACT/streetview/acRP98sACUIlwl2ZIsEyiQ_grdView.png',
+                 'dataset/CVACT/streetview/WSh9tNVryLdupUlU0ri2tQ_grdView.png',
+                 'dataset/CVACT/streetview/FhEuB9NA5o08VJ_TBCbHjw_grdView.png',
+                 'dataset/CVACT/streetview/YHfpn2Mgu1lqgT2OUeBpOg_grdView.png',
+                 'dataset/CVACT/streetview/vNhv7ZP1dUkJ93UwFXagJw_grdView.png',
+             ]
+             self.pano_list = demo_name
+
+         self.opt = opt
+
+     def __len__(self):
+         return len(self.pano_list)
+
+     def __getitem__(self, index):
+         pano = self.pano_list[index]
+         aer = pano.replace('streetview', 'satview_correct').replace('_grdView', '_satView_polish')
+         if self.opt.data.sky_mask:
+             sky = pano.replace('streetview', 'pano_sky_mask')
+         name = pano
+         aer = img_read(aer, size=self.opt.data.sat_size)
+         pano = img_read(pano, size=self.opt.data.pano_size)
+         if self.opt.data.sky_mask:
+             sky = img_read(sky, size=self.opt.data.pano_size, datatype='L')
+
+         input = {}
+         input['sat'] = aer
+         input['pano'] = pano
+         input['paths'] = name
+         if self.opt.data.sky_mask:
+             input['sky_mask'] = sky
+             black_ground = torch.zeros_like(pano)
+             if self.opt.data.histo_mode == 'grey':
+                 input['sky_histc'] = (pano * sky + black_ground * (1 - sky)).histc()[10:]
+             elif self.opt.data.histo_mode in ['rgb', 'RGB']:
+                 # Concatenate per-channel sky histograms (channels reversed).
+                 input_a = pano * sky + black_ground * (1 - sky)
+                 for idx in range(len(input_a)):
+                     if idx == 0:
+                         sky_histc = input_a[idx].histc()[10:]
+                     else:
+                         sky_histc = torch.cat([input_a[idx].histc()[10:], sky_histc], dim=0)
+                 input['sky_histc'] = sky_histc
+         return input
+
data/CVUSA.py ADDED
@@ -0,0 +1,86 @@
+ import os
+
+ import torch
+ from torch.utils.data.dataset import Dataset
+ from PIL import Image
+ import torchvision.transforms as transforms
+
+
+ def data_list(img_root, mode):
+     # The train and val splits differ only in the CSV file they read.
+     csv_name = 'splits/train-19zl.csv' if mode == 'train' else 'splits/val-19zl.csv'
+     split_file = os.path.join(img_root, csv_name)
+     data_list = []
+     with open(split_file) as f:
+         for line in f.readlines():
+             aerial_name, panorama_name = line.rstrip('\n').split(',')[:2]
+             data_list.append([aerial_name, panorama_name])
+     print('length of dataset is: ', len(data_list))
+     return [os.path.join(img_root, i[1]) for i in data_list]
+
+
+ def img_read(img, size=None, datatype='RGB'):
+     img = Image.open(img).convert('RGB' if datatype == 'RGB' else 'L')
+     if size:
+         if type(size) is int:
+             size = (size, size)
+         img = img.resize(size=size, resample=Image.BICUBIC if datatype == 'RGB' else Image.NEAREST)
+     img = transforms.ToTensor()(img)
+     return img
+
+
+ class Dataset(Dataset):
+     def __init__(self, opt, split='train', sub=None, sty_img=None):
+         self.pano_list = data_list(img_root=opt.data.root, mode=split)
+         if sub:
+             self.pano_list = self.pano_list[:sub]
+         if opt.task == 'test_vid':
+             demo_img_path = os.path.join(opt.data.root, 'streetview/panos', opt.demo_img)
+             self.pano_list = [demo_img_path]
+         if sty_img:
+             assert opt.sty_img.split('.')[-1] == 'jpg'
+             demo_img_path = os.path.join(opt.data.root, 'streetview/panos', opt.sty_img)
+             self.pano_list = [demo_img_path]
+
+         self.opt = opt
+
+     def __len__(self):
+         return len(self.pano_list)
+
+     def __getitem__(self, index):
+         pano = self.pano_list[index]
+         aer = pano.replace('streetview/panos', 'bingmap/19')
+         if self.opt.data.sky_mask:
+             sky = pano.replace('streetview/panos', 'sky_mask').replace('jpg', 'png')
+         name = pano
+         aer = img_read(aer, size=self.opt.data.sat_size)
+         pano = img_read(pano, size=self.opt.data.pano_size)
+         if self.opt.data.sky_mask:
+             sky = img_read(sky, size=self.opt.data.pano_size, datatype='L')
+
+         input = {}
+         input['sat'] = aer
+         input['pano'] = pano
+         input['paths'] = name
+         if self.opt.data.sky_mask:
+             input['sky_mask'] = sky
+             black_ground = torch.zeros_like(pano)
+             if self.opt.data.histo_mode == 'grey':
+                 input['sky_histc'] = (pano * sky + black_ground * (1 - sky)).histc()[10:]
+             elif self.opt.data.histo_mode in ['rgb', 'RGB']:
+                 # Concatenate per-channel sky histograms (channels reversed).
+                 input_a = pano * sky + black_ground * (1 - sky)
+                 for idx in range(len(input_a)):
+                     if idx == 0:
+                         sky_histc = input_a[idx].histc()[10:]
+                     else:
+                         sky_histc = torch.cat([input_a[idx].histc()[10:], sky_histc], dim=0)
+                 input['sky_histc'] = sky_histc
+         return input
dataset/INSTALL.md ADDED
@@ -0,0 +1,32 @@
+ To reproduce our paper, first download these four zip files:
+
+ `
+ CVACT/satview_correct.zip,
+ CVACT/streetview.zip,
+ CVUSA/bingmap/19.zip,
+ CVUSA/streetview/panos.zip
+ `
+ from [here](https://anu365-my.sharepoint.com/:f:/g/personal/u6293587_anu_edu_au/EuOBUDUQNClJvCpQ8bD1hnoBjdRBWxsHOVp946YVahiMGg?e=F4yRAC); the project page is [Sat2StrPanoramaSynthesis](https://github.com/shiyujiao/Sat2StrPanoramaSynthesis).
+
+ Then download the sky masks from [here](https://drive.google.com/drive/folders/1pfzwONg4P-Mzvxvzb2HoCpuZFynElPCk?usp=sharing).
+
+ Finally, organize the dataset as follows:
+ ```
+ dataset
+ ├── CVACT
+ │   ├── streetview
+ │   ├── satview_correct
+ │   ├── pano_sky_mask
+ │   └── ACT_data.mat
+ └── CVUSA
+     ├── bingmap
+     │   └── 19
+     ├── streetview
+     │   └── panos
+     ├── sky_mask
+     └── splits
+ ```
+
+ Tip: The sky masks are produced with [Trans4PASS](https://github.com/jamycheung/Trans4PASS).
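As a quick sanity check before training, a short script (not part of the repo) can verify that the layout above is in place:

```python
import os

# Paths follow the directory tree shown above.
EXPECTED = [
    'dataset/CVACT/streetview',
    'dataset/CVACT/satview_correct',
    'dataset/CVACT/pano_sky_mask',
    'dataset/CVACT/ACT_data.mat',
    'dataset/CVUSA/bingmap/19',
    'dataset/CVUSA/streetview/panos',
    'dataset/CVUSA/sky_mask',
    'dataset/CVUSA/splits',
]

missing = [p for p in EXPECTED if not os.path.exists(p)]
print('dataset OK' if not missing else 'missing: {}'.format(missing))
```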
demo_img/case1/groundview.image.png ADDED
demo_img/case1/groundview.sky.png ADDED
demo_img/case1/satview-input.png ADDED
demo_img/case10/groundview.image.png ADDED
demo_img/case10/groundview.sky.png ADDED
demo_img/case10/satview-input.png ADDED
demo_img/case11/groundview.image.png ADDED
demo_img/case11/groundview.sky.png ADDED
demo_img/case11/satview-input.png ADDED
demo_img/case12/groundview.image.png ADDED
demo_img/case12/groundview.sky.png ADDED
demo_img/case12/satview-input.png ADDED
demo_img/case13/groundview.image.png ADDED
demo_img/case13/groundview.sky.png ADDED
demo_img/case13/satview-input.png ADDED
demo_img/case2/groundview.image.png ADDED
demo_img/case2/groundview.sky.png ADDED
demo_img/case2/satview-input.png ADDED
demo_img/case3/groundview.image.png ADDED
demo_img/case3/groundview.sky.png ADDED
demo_img/case3/satview-input.png ADDED
demo_img/case4/groundview.image.png ADDED
demo_img/case4/groundview.sky.png ADDED
demo_img/case4/satview-input.png ADDED
demo_img/case5/groundview.image.png ADDED
demo_img/case5/groundview.sky.png ADDED
demo_img/case5/satview-input.png ADDED
demo_img/case6/groundview.image.png ADDED
demo_img/case6/groundview.sky.png ADDED
demo_img/case6/satview-input.png ADDED
demo_img/case7/groundview.image.png ADDED
demo_img/case7/groundview.sky.png ADDED
demo_img/case7/satview-input.png ADDED
demo_img/case8/groundview.image.png ADDED
demo_img/case8/groundview.sky.png ADDED
demo_img/case8/satview-input.png ADDED
demo_img/case9/groundview.image.png ADDED
demo_img/case9/groundview.sky.png ADDED
demo_img/case9/satview-input.png ADDED
demo_img/runall.sh ADDED
@@ -0,0 +1,30 @@
+ # for case in `ls -d demo_img/case*`
+ for case_id in 1 2 3 4
+ do
+     case=demo_img/case$case_id
+     echo $case
+     python test.py --yaml=sat2density_cvact \
+         --test_ckpt_path=2u87bj8w \
+         --task=test_vid \
+         --demo_img=$case/satview-input.png \
+         --sty_img=$case/groundview.image.png \
+         --save_dir=results/$case
+     # Two-pass GIF encoding: generate a palette, then map frames through it.
+     # ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png results/$case/render.gif
+     ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png -vf "palettegen" results/$case-palette.png
+     ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png -i results/$case-palette.png -filter_complex "paletteuse" results/$case/render.gif
+
+     ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png -vf "palettegen" results/$case-palette.png
+     ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png -i results/$case-palette.png -filter_complex "paletteuse" results/$case/sat.gif
+     # ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png results/$case/sat.gif
+ done
+
+ # for case in `ls -d demo_img/case*`
+ for case_id in 1 2 3 4
+ do
+     case=demo_img/case$case_id
+     sat_gif=results/$case/sat.gif
+     render_gif=results/$case/render.gif
+     # echo $sat_gif
+     cp $sat_gif docs/figures/demo/case$case_id.sat.gif
+     cp $render_gif docs/figures/demo/case$case_id.render.gif
+ done
imaginaire/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # This work is made available under the Nvidia Source Code License-NC.
+ # To view a copy of this license, check out LICENSE.md
imaginaire/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (135 Bytes).