incrl committed on
Commit
5b557cf
·
verified ·
1 Parent(s): 303fe96

Initial Upload (attempt 2)

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +7 -0
  2. LICENSE +21 -0
  3. README.md +36 -14
  4. app.py +302 -0
  5. clusters.py +234 -0
  6. example-broche-rose-gold.splat +3 -0
  7. example.jpg +0 -0
  8. graph_helpers.py +400 -0
  9. graph_io.py +306 -0
  10. graph_networks/.DS_Store +0 -0
  11. graph_networks/LinearStyleTransfer/.DS_Store +0 -0
  12. graph_networks/LinearStyleTransfer/LICENSE +25 -0
  13. graph_networks/LinearStyleTransfer/README.md +102 -0
  14. graph_networks/LinearStyleTransfer/TestArtistic.py +98 -0
  15. graph_networks/LinearStyleTransfer/TestPhotoReal.py +118 -0
  16. graph_networks/LinearStyleTransfer/TestVideo.py +108 -0
  17. graph_networks/LinearStyleTransfer/Train.py +185 -0
  18. graph_networks/LinearStyleTransfer/TrainSPN.py +141 -0
  19. graph_networks/LinearStyleTransfer/__init__.py +0 -0
  20. graph_networks/LinearStyleTransfer/libs/.DS_Store +0 -0
  21. graph_networks/LinearStyleTransfer/libs/Criterion.py +62 -0
  22. graph_networks/LinearStyleTransfer/libs/Loader.py +44 -0
  23. graph_networks/LinearStyleTransfer/libs/LoaderPhotoReal.py +162 -0
  24. graph_networks/LinearStyleTransfer/libs/Matrix.py +89 -0
  25. graph_networks/LinearStyleTransfer/libs/MatrixTest.py +154 -0
  26. graph_networks/LinearStyleTransfer/libs/SPN.py +156 -0
  27. graph_networks/LinearStyleTransfer/libs/__init__.py +0 -0
  28. graph_networks/LinearStyleTransfer/libs/models.py +662 -0
  29. graph_networks/LinearStyleTransfer/libs/pytorch_spn/README.md +12 -0
  30. graph_networks/LinearStyleTransfer/libs/pytorch_spn/__init__.py +0 -0
  31. graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/__init__.py +0 -0
  32. graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/gaterecurrent2dnoind/__init__.py +15 -0
  33. graph_networks/LinearStyleTransfer/libs/pytorch_spn/build.py +34 -0
  34. graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/__init__.py +0 -0
  35. graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/gaterecurrent2dnoind.py +47 -0
  36. graph_networks/LinearStyleTransfer/libs/pytorch_spn/left_right_demo.py +46 -0
  37. graph_networks/LinearStyleTransfer/libs/pytorch_spn/make.sh +9 -0
  38. graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/__init__.py +1 -0
  39. graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/gaterecurrent2dnoind.py +12 -0
  40. graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/.DS_Store +0 -0
  41. graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu +697 -0
  42. graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu.o +0 -0
  43. graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.h +28 -0
  44. graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.c +91 -0
  45. graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.h +6 -0
  46. graph_networks/LinearStyleTransfer/libs/smooth_filter.py +407 -0
  47. graph_networks/LinearStyleTransfer/libs/utils.py +92 -0
  48. graph_networks/LinearStyleTransfer/models/dec_r31.pth +3 -0
  49. graph_networks/LinearStyleTransfer/models/dec_r41.pth +3 -0
  50. graph_networks/LinearStyleTransfer/models/r31.pth +3 -0
.gitattributes CHANGED
@@ -40,3 +40,10 @@ FastSplatStyler_huggingface/style_ims/style0.jpg filter=lfs diff=lfs merge=lfs -
40
  FastSplatStyler_huggingface/style_ims/style2.jpg filter=lfs diff=lfs merge=lfs -text
41
  FastSplatStyler_huggingface/style_ims/style44.jpg filter=lfs diff=lfs merge=lfs -text
42
  FastSplatStyler_huggingface/style_ims/style6.jpg filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
40
  FastSplatStyler_huggingface/style_ims/style2.jpg filter=lfs diff=lfs merge=lfs -text
41
  FastSplatStyler_huggingface/style_ims/style44.jpg filter=lfs diff=lfs merge=lfs -text
42
  FastSplatStyler_huggingface/style_ims/style6.jpg filter=lfs diff=lfs merge=lfs -text
43
+ example-broche-rose-gold.splat filter=lfs diff=lfs merge=lfs -text
44
+ output.splat filter=lfs diff=lfs merge=lfs -text
45
+ style_ims/style-10.jpg filter=lfs diff=lfs merge=lfs -text
46
+ style_ims/style0.jpg filter=lfs diff=lfs merge=lfs -text
47
+ style_ims/style2.jpg filter=lfs diff=lfs merge=lfs -text
48
+ style_ims/style44.jpg filter=lfs diff=lfs merge=lfs -text
49
+ style_ims/style6.jpg filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 ECU Computer Vision Lab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,36 @@
1
- ---
2
- title: FastSplatStyler
3
- emoji: 🐨
4
- colorFrom: blue
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 6.9.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Optimization-Free Style Transfer for 3D Gaussian Splats
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FastSplatStyler
2
+ Official Implementation of "Optimization-Free Style Transfer of 3D Gaussian Splats"
3
+
4
+ [arXiv Paper](https://arxiv.org/abs/2508.05813)
5
+
6
+ ![](example.jpg)
7
+
8
+ ## Example Outputs
9
+
10
+ Example Outputs can be visualized using the [Antimatter WebGL viewer](https://antimatter15.com/splat/) at the following links.
11
+
12
+ - Broche: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/broche-rose-gold_original.splat) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/broche-rose-gold_style3.splat)
13
+ - Crystal Lamp: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/crystal-lamp-original.splat) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/crystal-lamp-style2.splat)
14
+ - Family Statue: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/family.ply) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/family-style-6.splat)
15
+ - M60 Tanks: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/m60.ply) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/m60-style-31.splat)
16
+ - Table: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/Table.ply) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/Table_style5.splat)
17
+ - Train: [Original](https://antimatter15.com/splat/) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/train_style1.splat)
18
+ - Truck: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/truck.splat) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/truck-style-21.splat)
19
+
20
+ Example inputs and outputs can be downloaded from [Google Drive](https://drive.google.com/drive/folders/10YmtcCOKGosXfPEi84ho1AfYRYioYo12?usp=drive_link) or [Hugging Face](https://huggingface.co/datasets/incrl/fast-splat-styler/tree/main)
21
+
22
+ ## Demo
23
+
24
+ **Coming soon**
25
+
26
+ ## Install
27
+
28
+ This work relies heavily on the [Pytorch](https://pytorch.org/) and [Pytorch Geometric](https://www.pyg.org/) libraries.
29
+
30
+ This code was tested with Python 3.12, Pytorch 2.9.1 (CUDA Toolkit 12.8), and Pytorch Geometric 2.8 on Windows. It was also tested on Mac with the same setup (without CUDA). The provided requirements.txt file comes from the Mac configuration.
31
+
32
+ This repository relies on a graph networks library that was presented in a [previous work](https://github.com/davidmhart/interpolated-selectionconv/tree/main). The library can be downloaded, with included model weights, at this [google drive link](https://drive.google.com/drive/folders/10YmtcCOKGosXfPEi84ho1AfYRYioYo12?usp=drive_link). Place the "graph_networks" folder in the main directory.
33
+
34
+ To stylize an example splat, run `python styletransfer_splat.py example-broche-rose-gold.splat --stylePath style_ims/style0.jpg --samplingRate 1.5`. You can change the input splat and style image to your specific use case.
35
+
36
+ Supports `.splat` and `.ply` files.
app.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import tempfile
5
+ import shutil
6
+ from pathlib import Path
7
+ from time import time
8
+
9
+ # ── Core style-transfer logic (adapted from styletransfer_splat.py) ──────────
10
+ import pointCloudToMesh as ply2M
11
+ import utils
12
+ import graph_io as gio
13
+ from clusters import *
14
+ import splat_mesh_helpers as splt
15
+ import clusters as cl
16
+ from torch_geometric.data import Data
17
+ from scipy.interpolate import NearestNDInterpolator
18
+
19
+ from graph_networks.LinearStyleTransfer_vgg import encoder, decoder
20
+ from graph_networks.LinearStyleTransfer_matrix import TransformLayer
21
+ from graph_networks.LinearStyleTransfer.libs.Matrix import MulLayer
22
+ from graph_networks.LinearStyleTransfer.libs.models import encoder4, decoder4
23
+
24
+
25
# ── Example assets (place your own files in ./examples/) ─────────────────────
# Each row is [splat_path, style_image_path]. Rows are only shown as clickable
# Gradio examples when both files exist on disk (checked in build_ui).
EXAMPLE_SPLATS = [
    ["example-broche-rose-gold.splat", "style_ims/style2.jpg"],
    ["example-broche-rose-gold.splat", "style_ims/style6.jpg"],
]
30
+
31
+
32
# ── Style-transfer function called by Gradio ─────────────────────────────────
def run_style_transfer(
    splat_file,
    style_image,
    threshold: float,
    sampling_rate: float,
    device_choice: str,
    progress=gr.Progress(track_tqdm=True),
):
    """Stylize a 3D Gaussian Splat with the style of an image.

    Pipeline: load splat -> optional Gaussian super-sampling -> build a
    surface graph with pooling clusters -> run the Linear Style Transfer
    networks on the graph -> interpolate stylized colors back onto the
    original splat -> save to a temp file.

    Parameters
    ----------
    splat_file : Gradio file value (object with ``.name`` or a path string)
        Input ``.ply``/``.splat`` scene.
    style_image : path-like
        Style image; loaded and resized to 512x512.
    threshold : float
        Opacity percentile passed to the splat unpacker and normal estimator.
    sampling_rate : float
        >1 super-samples the Gaussians by that factor; <=1 leaves them as-is.
    device_choice : str
        ``"cpu"`` or a CUDA device index string (becomes ``"cuda:<idx>"``).
    progress : gr.Progress
        Gradio progress reporter.

    Returns
    -------
    (str, str)
        Path of the saved stylized splat and the accumulated progress log.

    Raises
    ------
    gr.Error
        If either input is missing.
    """
    if splat_file is None:
        raise gr.Error("Please upload a 3D Gaussian Splat file (.ply or .splat).")
    if style_image is None:
        raise gr.Error("Please upload a style image.")

    device = device_choice if device_choice == "cpu" else f"cuda:{device_choice}"

    # ── Parameters ────────────────────────────────────────────────────────────
    n = 25          # NOTE(review): passed to splat_unpacker_with_threshold; meaning defined there — confirm
    ratio = 0.25    # cluster-reduction ratio per pooling level
    depth = 3       # number of pooling levels in the graph hierarchy
    style_shape = (512, 512)

    logs = []

    def log(msg):
        # Accumulate and echo; the joined log is returned to the UI at the end.
        logs.append(msg)
        print(msg)
        return "\n".join(logs)

    # ── 1. Load splat ─────────────────────────────────────────────────────────
    progress(0.05, desc="Loading splat…")
    splat_path = splat_file.name if hasattr(splat_file, "name") else splat_file
    log(f"Loading splat: {splat_path}")

    pos3D_Original, _, colors_Original, opacity_Original, scales_Original, rots_Original, fileType = \
        splt.splat_unpacker_with_threshold(n, splat_path, threshold)

    # ── 2. Gaussian super-sampling ────────────────────────────────────────────
    progress(0.15, desc="Super-sampling…")
    t0 = time()
    if sampling_rate > 1:
        GaussianSamples = int(pos3D_Original.shape[0] * sampling_rate)
        # Clones keep the *_Original tensors untouched for the write-back step.
        pos3D, colors = splt.splat_GaussianSuperSampler(
            pos3D_Original.clone(), colors_Original.clone(),
            opacity_Original.clone(), scales_Original.clone(), rots_Original.clone(),
            GaussianSamples,
        )
    else:
        pos3D, colors = pos3D_Original, colors_Original
    log(f"Nodes in graph: {pos3D.shape[0]} ({time()-t0:.1f}s)")

    # ── 3. Graph construction ─────────────────────────────────────────────────
    progress(0.30, desc="Building surface graph…")
    t0 = time()
    style_ref = utils.loadImage(style_image, shape=style_shape)

    normalsNP = ply2M.Estimate_Normals(pos3D, threshold)
    normals = torch.from_numpy(normalsNP)

    # Arbitrary diagonal up-vector, normalized to unit length.
    up_vector = torch.tensor([[1, 1, 1]], dtype=torch.float)
    up_vector = up_vector / torch.linalg.norm(up_vector, dim=1)

    pos3D = pos3D.to(device)
    colors = colors.to(device)
    normals = normals.to(device)
    up_vector = up_vector.to(device)

    edge_index, directions = gh.surface2Edges(pos3D, normals, up_vector, k_neighbors=16)
    edge_index, selections, interps = gh.edges2Selections(edge_index, directions, interpolated=True)

    clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(
        pos3D, normals, edge_index, selections, interps,
        ratio=ratio, up_vector=up_vector, depth=depth, device=device,
    )
    log(f"Graph built ({time()-t0:.1f}s)")

    # ── 4. Load networks ──────────────────────────────────────────────────────
    progress(0.50, desc="Loading networks…")
    t0 = time()

    # Reference CNN models supply the pretrained weights …
    enc_ref = encoder4()
    dec_ref = decoder4()
    matrix_ref = MulLayer("r41")

    enc_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/vgg_r41.pth", map_location=device))
    dec_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/dec_r41.pth", map_location=device))
    matrix_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/r41.pth", map_location=device))

    # … which are copied into the graph-network equivalents.
    enc = encoder(padding_mode="replicate")
    dec = decoder(padding_mode="replicate")
    matrix = TransformLayer()

    with torch.no_grad():
        enc.copy_weights(enc_ref)
        dec.copy_weights(dec_ref)
        matrix.copy_weights(matrix_ref)

    content = Data(
        x=colors, clusters=clusters,
        edge_indexes=edge_indexes,
        selections_list=selections_list,
        interps_list=interps_list,
    ).to(device)

    style, _ = gio.image2Graph(style_ref, depth=3, device=device)

    enc = enc.to(device)
    dec = dec.to(device)
    matrix = matrix.to(device)
    log(f"Networks loaded ({time()-t0:.1f}s)")

    # ── 5. Style transfer ─────────────────────────────────────────────────────
    progress(0.70, desc="Running style transfer…")
    t0 = time()

    with torch.no_grad():
        cF = enc(content)
        sF = enc(style)
        # Match content/style statistics at the deepest ("r41") level.
        feature, _ = matrix(
            cF["r41"], sF["r41"],
            content.edge_indexes[3], content.selections_list[3],
            style.edge_indexes[3], style.selections_list[3],
            content.interps_list[3] if hasattr(content, "interps_list") else None,
        )
        result = dec(feature, content).clamp(0, 1)

    # Overwrite RGB channels in place with the stylized values.
    colors[:, 0:3] = result
    log(f"Stylization done ({time()-t0:.1f}s)")

    # ── 6. Interpolate back to original resolution ────────────────────────────
    progress(0.88, desc="Interpolating back to original splat…")
    t0 = time()

    # Nearest-neighbour lookup from (possibly super-sampled) nodes back to the
    # original Gaussian centers.
    interp2 = NearestNDInterpolator(pos3D.cpu(), colors.cpu())
    results_OriginalNP = interp2(pos3D_Original)
    results_Original = torch.from_numpy(results_OriginalNP).to(torch.float32)
    colors_and_opacity_Original = torch.cat(
        (results_Original, opacity_Original.unsqueeze(1)), dim=1
    )
    log(f"Interpolation done ({time()-t0:.1f}s)")

    # ── 7. Save output ────────────────────────────────────────────────────────
    progress(0.95, desc="Saving output splat…")
    suffix = ".splat" if fileType == "splat" else ".ply"
    out_dir = tempfile.mkdtemp()
    out_path = os.path.join(out_dir, f"stylized{suffix}")

    splt.splat_save(
        pos3D_Original.numpy(),
        scales_Original.numpy(),
        rots_Original.numpy(),
        colors_and_opacity_Original.numpy(),
        out_path,
        fileType,
    )
    log(f"Saved to: {out_path}")
    progress(1.0, desc="Done!")

    return out_path, "\n".join(logs)
191
+
192
+
193
# ── Gradio UI ─────────────────────────────────────────────────────────────────
def build_ui():
    """Build and return the Gradio Blocks demo for splat style transfer.

    Lays out an input column (splat file, style image, advanced settings),
    an output column (download + progress log), optional on-disk examples,
    and wires the Run button to :func:`run_style_transfer`.

    Returns
    -------
    gr.Blocks
        The assembled (not yet launched) demo.
    """
    # CUDA devices first so a GPU is the default selection when available.
    available_devices = (
        [str(i) for i in range(torch.cuda.device_count())] + ["cpu"]
        if torch.cuda.is_available()
        else ["cpu"]
    )

    with gr.Blocks(
        title="3DGS Style Transfer",
        theme=gr.themes.Soft(primary_hue="violet"),
        css="""
        #title { text-align: center; }
        #subtitle { text-align: center; color: #666; margin-bottom: 1rem; }
        .panel { border-radius: 12px; }
        #run-btn { font-size: 1.1rem; }
        """,
    ) as demo:

        gr.Markdown("# 🎨 3D Gaussian Splat Style Transfer", elem_id="title")
        gr.Markdown(
            "Upload a 3DGS scene and a style image — the app will repaint the splat "
            "with the artistic style of the image and give you a stylized splat to download. "
            "After downloading, you can view your splat with an [online viewer](https://antimatter15.com/splat/).",
            elem_id="subtitle",
        )

        with gr.Row():
            # ── Left column: inputs ───────────────────────────────────────────
            with gr.Column(scale=1, elem_classes="panel"):
                gr.Markdown("### 📂 Inputs")

                splat_input = gr.File(
                    label="3D Gaussian Splat (.ply or .splat)",
                    file_types=[".ply", ".splat"],
                    type="filepath",
                )

                style_input = gr.Image(
                    label="Style Image",
                    type="filepath",
                    height=240,
                )

                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    threshold_slider = gr.Slider(
                        minimum=90.0, maximum=100.0, value=99.8, step=0.1,
                        label="Opacity threshold (percentile)",
                        info="Points below this opacity percentile are removed.",
                    )
                    sampling_slider = gr.Slider(
                        minimum=0.5, maximum=3.0, value=1.5, step=0.1,
                        label="Gaussian super-sampling rate",
                        info="Values > 1 add extra samples; 1.0 = no super-sampling.",
                    )
                    device_radio = gr.Radio(
                        choices=available_devices,
                        value=available_devices[0],
                        label="Device",
                    )

                run_btn = gr.Button("🚀 Run Style Transfer", variant="primary", elem_id="run-btn")

            # ── Right column: outputs ─────────────────────────────────────────
            with gr.Column(scale=1, elem_classes="panel"):
                gr.Markdown("### 📥 Output")

                output_file = gr.File(
                    label="Download Stylized Splat",
                    interactive=False,
                )

                log_box = gr.Textbox(
                    label="Progress log",
                    lines=12,
                    max_lines=20,
                    interactive=False,
                    placeholder="Logs will appear here once processing starts…",
                )

        # ── Examples ─────────────────────────────────────────────────────────
        # (Removed two unused locals that split EXAMPLE_SPLATS into separate
        # path lists; only the filtered rows below were ever used.)
        valid_examples = [
            row for row in EXAMPLE_SPLATS
            if os.path.exists(row[0]) and os.path.exists(row[1])
        ]

        if valid_examples:
            gr.Markdown("### 🖼️ Examples")
            gr.Examples(
                examples=valid_examples,
                inputs=[splat_input, style_input],
                label="Click an example to load it",
            )

        # ── Event wiring ──────────────────────────────────────────────────────
        run_btn.click(
            fn=run_style_transfer,
            inputs=[splat_input, style_input, threshold_slider, sampling_slider, device_radio],
            outputs=[output_file, log_box],
        )

    return demo
298
+
299
+
300
if __name__ == "__main__":
    # Build the Gradio demo and serve it locally (no public share link).
    build_ui().launch(share=False)
clusters.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_scatter import scatter
3
+ from torch_geometric.nn.pool.consecutive import consecutive_cluster
4
+ from torch_geometric.utils import add_self_loops, add_remaining_self_loops, remove_self_loops
5
+ from torch_geometric.nn import fps, knn
6
+ from torch_sparse import coalesce
7
+ import graph_helpers as gh
8
+ import sphere_helpers as sh
9
+ import mesh_helpers as mh
10
+
11
+ import math
12
+ from math import pi,sqrt
13
+
14
+
15
+ from warnings import warn
16
+
17
def makeImageClusters(pos2D,Nx,Ny,edge_index,selections,depth=1,device='cpu',stride=2):
    """Build a grid-pooling hierarchy for an image graph.

    Level 0 of the returned edge/selection lists is a copy of the input
    graph; each further level shrinks the virtual pixel grid by `stride`
    and pools nodes that fall into the same cell.

    Returns (clusters, edge_indexes, selections_list).
    """
    clusters = []
    edge_indexes = [torch.clone(edge_index).to(device)]
    selections_list = [torch.clone(selections).to(device)]

    grid_w, grid_h = Nx, Ny
    for _ in range(depth):
        # Coarsen the grid for this level.
        grid_w //= stride
        grid_h //= stride

        cell_x, cell_y = getGrid(pos2D, grid_w, grid_h)
        level_cluster, pos2D = gridCluster(pos2D, cell_x, cell_y, grid_w)
        edge_index, selections = selectionAverage(level_cluster, edge_index, selections)

        clusters.append(torch.clone(level_cluster).to(device))
        edge_indexes.append(torch.clone(edge_index).to(device))
        selections_list.append(torch.clone(selections).to(device))

    return clusters, edge_indexes, selections_list
34
+
35
def makeSphereClusters(pos3D,edge_index,selections,interps,rows,cols,cluster_method="layering",stride=2,bary_d=None,depth=1,device='cpu',ratio=0.25):
    """Build a pooling hierarchy for a spherical graph.

    Each level resamples the sphere at a coarser resolution (rows/cols
    divided by `stride`), assigns every current node to its nearest new
    centroid, and regenerates the surface graph at that resolution.

    Parameters
    ----------
    pos3D : (N, 3) tensor of node positions on the sphere.
    edge_index, selections, interps : level-0 graph (copied into the output).
    rows, cols : sampling resolution; divided by `stride` each level.
    cluster_method : one of "equirec", "layering", "spiral", "icosphere",
        "random", "random_nodes", "fps".
    bary_d : optional barycentric distance, scaled by `stride` each level.
    depth : number of pooling levels to build.
    ratio : FPS sampling ratio (used only by cluster_method="fps").

    Returns
    -------
    (clusters, edge_indexes, selections_list, interps_list)

    Raises
    ------
    ValueError if cluster_method is unknown.
    """
    clusters = []
    edge_indexes = [torch.clone(edge_index).to(device)]
    selections_list = [torch.clone(selections).to(device)]
    interps_list = [torch.clone(interps).to(device)]

    for _ in range(depth):

        rows = rows//stride
        cols = cols//stride

        if bary_d is not None:
            bary_d = bary_d*stride

        if cluster_method == "equirec":
            centroids, _ = sh.sampleSphere_Equirec(rows,cols)

        elif cluster_method == "layering":
            centroids, _ = sh.sampleSphere_Layering(rows)

        elif cluster_method == "spiral":
            centroids, _ = sh.sampleSphere_Spiral(rows,cols)

        elif cluster_method == "icosphere":
            centroids, _ = sh.sampleSphere_Icosphere(rows)

        elif cluster_method == "random":
            centroids, _ = sh.sampleSphere_Random(rows,cols)

        elif cluster_method == "random_nodes":
            # BUGFIX: `N` was previously undefined here (NameError). Sample
            # as many existing nodes as the rows*cols budget used by the
            # other sampling methods.
            N = rows*cols
            index = torch.multinomial(torch.ones(len(pos3D)),N) # close equivalent to np.random.choice
            centroids = pos3D[index]

        elif cluster_method == "fps":
            # Farthest Point Sampling used in PointNet++.
            # BUGFIX: `ratio` was previously undefined here (NameError); it
            # is now an explicit keyword parameter (default 0.25).
            index = fps(pos3D, ratio=ratio)
            centroids = pos3D[index]
        else:
            raise ValueError("Sphere cluster_method unknown")


        # Find closest centroid to each current point
        cluster = knn(centroids,pos3D,1)[1]
        cluster, _ = consecutive_cluster(cluster)
        pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')

        # Regenerate surface graph; points lie on the unit sphere, so the
        # normalized position doubles as the surface normal.
        normals = pos3D/torch.linalg.norm(pos3D,dim=1,keepdims=True) # Make sure normals are unit vectors
        edge_index,directions = gh.surface2Edges(pos3D,normals)
        edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True,bary_d=bary_d)

        clusters.append(torch.clone(cluster).to(device))
        edge_indexes.append(torch.clone(edge_index).to(device))
        selections_list.append(torch.clone(selections).to(device))
        interps_list.append(torch.clone(interps).to(device))

    return clusters, edge_indexes, selections_list, interps_list
92
+
93
def makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,cluster_method="random",ratio=.25,up_vector=None,depth=1,device='cpu'):
    """Build a pooling hierarchy for a point-cloud surface graph.

    Each level keeps roughly `ratio` of the nodes as centroids, pools the
    remaining nodes onto their nearest centroid (positions and normals are
    averaged), and regenerates the surface graph at the coarser level.

    Parameters
    ----------
    pos3D, normals : (N, 3) node positions and normals.
    edge_index, selections, interps : level-0 graph (copied into the output).
    cluster_method : "random" (multinomial node sampling) or "fps"
        (Farthest Point Sampling as in PointNet++).
    ratio : fraction of nodes kept as centroids per level.
    up_vector : optional reference direction passed to surface2Edges.
    depth : number of pooling levels to build.

    Returns
    -------
    (clusters, edge_indexes, selections_list, interps_list)

    Raises
    ------
    ValueError if cluster_method is unknown.
    """
    clusters = []
    edge_indexes = [torch.clone(edge_index).to(device)]
    selections_list = [torch.clone(selections).to(device)]
    interps_list = [torch.clone(interps).to(device)]

    for _ in range(depth):

        # Desired number of clusters in the next level
        N = int(len(pos3D) * ratio)

        if cluster_method == "random":
            index = torch.multinomial(torch.ones(len(pos3D)),N) # close equivalent to np.random.choice
            centroids = pos3D[index]

        elif cluster_method == "fps":
            # Farthest Point Sampling used in PointNet++
            index = fps(pos3D, ratio=ratio)
            centroids = pos3D[index]

        else:
            # BUGFIX: an unknown method previously fell through to a
            # NameError on `centroids`; fail fast like the sibling builders.
            raise ValueError("Surface cluster_method unknown")

        # Find closest centroid to each current point
        cluster = knn(centroids,pos3D,1)[1]
        cluster, _ = consecutive_cluster(cluster)
        pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')
        normals = scatter(normals, cluster, dim=0, reduce='mean')

        # Regenerate surface graph
        normals = normals/torch.linalg.norm(normals,dim=1,keepdims=True) # Make sure normals are unit vectors
        edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16)
        edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)

        clusters.append(torch.clone(cluster).to(device))
        edge_indexes.append(torch.clone(edge_index).to(device))
        selections_list.append(torch.clone(selections).to(device))
        interps_list.append(torch.clone(interps).to(device))

    return clusters, edge_indexes, selections_list, interps_list
130
+
131
+
132
def makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=.25,up_vector=None,depth=1,device='cpu'):
    """Build a pooling hierarchy for a mesh-surface graph.

    Each level draws `ratio * N` fresh centroids (with normals) from the
    mesh surface, pools every current node onto its nearest centroid, and
    regenerates the surface graph at the coarser level.

    Returns (clusters, edge_indexes, selections_list, interps_list); index 0
    of the edge/selection/interp lists is a copy of the input graph.
    """
    clusters = []
    edge_indexes = [torch.clone(edge_index).to(device)]
    selections_list = [torch.clone(selections).to(device)]
    interps_list = [torch.clone(interps).to(device)]

    for _ in range(depth):
        # Target node count for the next (coarser) level.
        target = int(len(pos3D) * ratio)

        # Sample fresh centroids (and their normals) from the mesh surface.
        centroids, normals = mh.sampleSurface(mesh, target, return_x=False)

        # Pool each current point onto its nearest centroid.
        assignment = knn(centroids, pos3D, 1)[1]
        assignment, _ = consecutive_cluster(assignment)
        pos3D = scatter(pos3D, assignment, dim=0, reduce='mean')

        # Rebuild the surface graph at the coarser resolution.
        edge_index, directions = gh.surface2Edges(pos3D, normals, up_vector)
        edge_index, selections, interps = gh.edges2Selections(edge_index, directions, interpolated=True)

        clusters.append(torch.clone(assignment).to(device))
        edge_indexes.append(torch.clone(edge_index).to(device))
        selections_list.append(torch.clone(selections).to(device))
        interps_list.append(torch.clone(interps).to(device))

    return clusters, edge_indexes, selections_list, interps_list
162
+
163
+
164
+
165
def getGrid(pos,Nx,Ny,xrange=None,yrange=None):
    """Assign each 2D point to an (x, y) grid-cell coordinate.

    The bounding box (taken from the data unless xrange/yrange are given)
    is split into Nx-by-Ny cells; cell indices are clamped to the valid
    range so boundary points land in the last cell.

    Returns (cx, cy) float tensors of per-point cell coordinates.
    """
    if xrange is None:
        x_lo, x_hi = torch.min(pos[:,0]), torch.max(pos[:,0])
    else:
        x_lo, x_hi = xrange[0], xrange[1]
    if yrange is None:
        y_lo, y_hi = torch.min(pos[:,1]), torch.max(pos[:,1])
    else:
        y_lo, y_hi = yrange[0], yrange[1]

    cx = torch.clamp(torch.floor((pos[:,0] - x_lo) / (x_hi - x_lo) * Nx), 0, Nx - 1)
    cy = torch.clamp(torch.floor((pos[:,1] - y_lo) / (y_hi - y_lo) * Ny), 0, Ny - 1)
    return cx, cy
174
+
175
def gridCluster(pos,cx,cy,xmax):
    """Merge points that share a grid cell into single clusters.

    (cx, cy) cell coordinates are linearized into one id per cell, ids are
    renumbered consecutively, and each cluster's position becomes the mean
    of its members. Returns (cluster, pooled_pos).
    """
    # Linearize 2D cell coordinates into a single integer cell id.
    cell_id = (cx + cy * xmax).type(torch.long)
    # Renumber the occupied cells consecutively from zero.
    cell_id, _ = consecutive_cluster(cell_id)
    # Cluster position = mean of member positions.
    pooled = scatter(pos, cell_id, dim=0, reduce='mean')

    return cell_id, pooled
182
+
183
def selectionAverage(cluster, edge_index, selections):
    """Pool edges (and their direction-selection labels) onto cluster ids.

    Endpoints are remapped to cluster ids, self-loops removed, and parallel
    edges merged with `coalesce`. A selection label is the index of a
    direction bin on a 9-way kernel (0 = self); averaging labels across a
    merged edge is only meaningful for adjacent bins, so the mean is taken
    four times, rotating the kernel by 2 bins each pass so every opposite
    bin pair (4/5, then the rotated pairs) sits mid-range when averaged.
    Finally self-loops are re-added with selection 0.

    Returns the pooled (edge_index, selections).
    """
    num_nodes = cluster.size(0)
    # Remap both edge endpoints to their cluster ids.
    edge_index = cluster[edge_index.contiguous().view(1, -1)].view(2, -1)
    edge_index, selections = remove_self_loops(edge_index, selections)
    if edge_index.numel() > 0:

        # To avoid means over discontinuities, do the mean for two selections at a time
        final_edge_index, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean")
        selections_check = torch.round(selections_check).type(torch.long)

        final_selections = torch.zeros_like(selections_check).to(selections.device)

        # Bins 4/5 are reliable in this orientation; commit them directly.
        final_selections[torch.where(selections_check==4)] = 4
        final_selections[torch.where(selections_check==5)] = 5

        # Rotate selection kernel (shift non-zero bins by 2, wrapping 1..8;
        # the floor-div term keeps bin 0 / wrapped values consistent)
        selections += 2
        selections = selections % 9 + torch.div(selections, 9, rounding_mode="floor")

        _, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean")
        selections_check = torch.round(selections_check).type(torch.long)
        # Rotated bins 4/5 correspond to original bins 2/3.
        final_selections[torch.where(selections_check==4)] = 2
        final_selections[torch.where(selections_check==5)] = 3

        # Rotate selection kernel
        selections += 2
        selections = selections % 9 + torch.div(selections, 9, rounding_mode="floor")

        _, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean")
        selections_check = torch.round(selections_check).type(torch.long)
        # Rotated bins 4/5 correspond to original bins 8/1.
        final_selections[torch.where(selections_check==4)] = 8
        final_selections[torch.where(selections_check==5)] = 1

        # Rotate selection kernel
        selections += 2
        selections = selections % 9 + torch.div(selections, 9, rounding_mode="floor")

        _, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean")
        selections_check = torch.round(selections_check).type(torch.long)
        # Rotated bins 4/5 correspond to original bins 6/7.
        final_selections[torch.where(selections_check==4)] = 6
        final_selections[torch.where(selections_check==5)] = 7

        #print(torch.min(final_selections), torch.max(final_selections))
        #print(torch.mean(final_selections.type(torch.float)))

        # Self-loops carry selection 0 (the kernel's "self" bin).
        edge_index, selections = add_remaining_self_loops(final_edge_index,final_selections,fill_value=torch.tensor(0,dtype=torch.long))

    else:
        # Degenerate pool: no surviving edges; keep self-loops only.
        edge_index, selections = add_remaining_self_loops(edge_index,selections,fill_value=torch.tensor(0,dtype=torch.long))
        print("Warning: Edge Pool found no edges")

    return edge_index, selections
example-broche-rose-gold.splat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62cccf0adfb4e5713985e8ee874fa30a6a37ae222a23ead7c6e4639a5802ab62
3
+ size 4157728
example.jpg ADDED
graph_helpers.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch_geometric.nn import radius_graph, knn_graph
3
+ import torch_geometric as tg
4
+ from torch_geometric.utils import subgraph
5
+ import utils
6
+ from math import sqrt
7
+
8
def getImPos(rows, cols, start_row=0, start_col=0):
    """Return a (rows*cols, 2) long tensor of (row, col) pixel coordinates.

    Coordinates are enumerated in row-major order and may be offset by
    start_row / start_col.
    """
    rows_range = torch.arange(start_row, start_row + rows)
    cols_range = torch.arange(start_col, start_col + cols)
    # 'xy' indexing makes the grids (rows, cols)-shaped with cols varying fastest
    grid_c, grid_r = torch.meshgrid(cols_range, rows_range, indexing='xy')
    stacked = torch.stack((grid_r, grid_c), dim=-1)
    return stacked.reshape(rows * cols, 2)
14
+
15
def convertImPos(im_pos, flip_y=True):
    """Convert (row, col) image coordinates into float (x, y) positions.

    Columns become x and rows become y. When flip_y is True, the y axis is
    flipped so y increases upward (mathematical convention), matching the
    edges2Selections settings.

    The input tensor is never modified; a new float tensor is returned.
    """
    # Bug fix: .float() alone is a no-op (aliases the input) when im_pos is
    # already a float tensor, so the in-place ops below would corrupt the
    # caller's data. Clone first to guarantee a private copy.
    pos2D = im_pos.clone().float()

    # Switch (row, col) -> (x, y)
    pos2D[:, [1, 0]] = pos2D[:, [0, 1]]

    if flip_y:
        # Flip the y-axis to match the mathematical definition
        pos2D[:, 1] = torch.amax(pos2D[:, 1]) - pos2D[:, 1]

    return pos2D
29
+
30
+
31
def grid2Edges(locs):
    """Connect grid nodes (assumed unit spacing) to their 8-neighborhood.

    Radius 1.44 covers diagonal neighbors (sqrt(2) ~ 1.414) while excluding
    anything two cells away; self loops are included.
    """
    return radius_graph(locs, 1.44, loop=True)
35
+
36
def radius2Edges(locs, r=1.0):
    """Build a radius graph (with self loops) over the given positions."""
    return radius_graph(locs, r, loop=True)
39
+
40
def knn2Edges(locs, knn=9):
    """Build a k-nearest-neighbor graph (with self loops) over the positions."""
    return knn_graph(locs, knn, loop=True)
43
+
44
def surface2Edges(pos3D, normals, up_vector=None, k_neighbors=9):
    """Build edges over a surface point cloud and return local 2D directions.

    Each point is connected to its k nearest neighbors; neighbors whose
    normals disagree (non-positive dot product) are culled. Edge directions
    are then expressed in a local tangent frame at the source point and the
    out-of-plane component is dropped, yielding 2D directions suitable for
    edges2Selections.
    """
    if up_vector is None:
        up_vector = torch.tensor([[0.0, 1.0, 0.0]]).to(pos3D.device)

    # K nearest neighbors graph (self loops included)
    edge_index = knn_graph(pos3D, k_neighbors, loop=True)

    # Cull neighbors whose normals point away from each other (dot them together)
    culling = torch.sum(torch.multiply(normals[edge_index[1]], normals[edge_index[0]]), dim=1)
    edge_index = edge_index[:, torch.where(culling > 0)[0]]

    # For each edge, build a tangent frame via Gram-Schmidt orthogonalization:
    # the source normal is the (implicit) z axis, x/y span the tangent plane.
    # (The original also normalized the normal into a z_dir that was never
    # used; that dead code is removed here.)
    norms = normals[edge_index[0]]

    # torch.cross doesn't broadcast properly in some versions of torch
    x_dir = utils.cross(up_vector, norms)
    x_dir = x_dir / torch.linalg.norm(x_dir, dim=1, keepdims=True)
    y_dir = utils.cross(norms, x_dir)
    y_dir = y_dir / torch.linalg.norm(y_dir, dim=1, keepdims=True)

    directions = (pos3D[edge_index[1]] - pos3D[edge_index[0]])

    # Rotate into the tangent frame by multiplying out the rotation matrix;
    # the z component is never consumed, so it is not computed.
    temp = torch.clone(directions)  # Buffer
    directions[:, 0] = temp[:, 0] * x_dir[:, 0] + temp[:, 1] * x_dir[:, 1] + temp[:, 2] * x_dir[:, 2]
    directions[:, 1] = temp[:, 0] * y_dir[:, 0] + temp[:, 1] * y_dir[:, 1] + temp[:, 2] * y_dir[:, 2]

    # Drop z coordinate
    directions = directions[:, :2]

    return edge_index, directions
80
+
81
def edges2Selections(edge_index, directions, interpolated=True, bary_d=None, y_down=False):
    """Map 2D edge directions onto the 8-way selection kernel.

    Selection ordering (y pointing up):
        4 3 2
        5 0 1
        6 7 8
    Selection 0 is the same-cell (self) selection.

    With interpolated=True, returns (edge_index, selections, interps) where
    directions are split across the two nearest kernel selections
    (barycentric when bary_d is given, angle-based otherwise) and interps
    are normalized per node/selection. With interpolated=False, returns the
    hard argmax selections only.
    """
    s2 = sqrt(2) / 2
    if y_down:
        kernel = [[1, 0], [s2, -s2], [0, -1], [-s2, -s2],
                  [-1, 0], [-s2, s2], [0, 1], [s2, s2]]
    else:
        kernel = [[1, 0], [s2, s2], [0, 1], [-s2, s2],
                  [-1, 0], [-s2, -s2], [0, -1], [s2, -s2]]
    vectorList = torch.tensor(kernel, dtype=torch.float).transpose(1, 0)

    if not interpolated:
        selections = torch.argmax(torch.matmul(directions, vectorList), dim=1) + 1
        # Zero-length directions are the same-cell selection
        zero_dir = torch.where(torch.sum(torch.abs(directions), axis=1) == 0)
        selections[zero_dir] = 0
        return selections

    if bary_d is None:
        edge_index, selections, interps = interpolateSelections(edge_index, directions, vectorList)
    else:
        edge_index, selections, interps = interpolateSelections_barycentric(edge_index, directions, bary_d, vectorList)
    interps = normalizeEdges(edge_index, selections, interps)
    return edge_index, selections, interps
105
+
106
+
107
def makeEdges(prev_sources, prev_targets, prev_selections, sources, targets, selection, reverse=True):
    """Append the edges sources[i] -> targets[i], all with one selection.

    Extends the running python lists prev_sources / prev_targets /
    prev_selections in place and also returns them. When reverse=True, the
    reversed edges are appended as well, using the mirrored selection from
    utils.reverse_selection.
    """
    sources = sources.flatten()
    targets = targets.flatten()

    src_list = sources.tolist()
    tgt_list = targets.tolist()

    prev_sources += src_list
    prev_targets += tgt_list
    prev_selections += len(src_list) * [selection]

    if reverse:
        # Bug fix: the reversed edges previously extended the lists with raw
        # tensor elements (`prev_sources += targets`), leaving 0-d tensors
        # mixed with the python ints appended above. Use the same .tolist()
        # conversion for both directions.
        prev_sources += tgt_list
        prev_targets += src_list
        prev_selections += len(src_list) * [utils.reverse_selection(selection)]

    return prev_sources, prev_targets, prev_selections
122
+
123
def maskNodes(mask, x):
    """Keep only the rows of x where the boolean mask is True."""
    return x[torch.where(mask)]
127
+
128
def maskPoints(mask, x, y):
    """Return indices of points whose 4 bilinear neighbors are all unmasked.

    mask is a 2D boolean image (extra singleton dims are squeezed); x indexes
    columns and y indexes rows. A point is kept only when every pixel that
    would contribute to its bilinear interpolation is True in the mask.
    """
    mask = torch.squeeze(mask)

    # Integer corners surrounding each sample point
    x0 = torch.floor(x).long()
    x1 = x0 + 1
    y0 = torch.floor(y).long()
    y1 = y0 + 1

    # Clamp to valid pixel indices
    x0 = torch.clip(x0, 0, mask.shape[1] - 1)
    x1 = torch.clip(x1, 0, mask.shape[1] - 1)
    y0 = torch.clip(y0, 0, mask.shape[0] - 1)
    y1 = torch.clip(y1, 0, mask.shape[0] - 1)

    # Mask values at the four surrounding corners
    Ma = mask[y0, x0]
    Mb = mask[y1, x0]
    Mc = mask[y0, x1]
    Md = mask[y1, x1]

    # Keep only the points whose entire bilinear footprint is unmasked
    node_mask = torch.where(Ma & Mb & Mc & Md)[0]

    return node_mask
150
+
151
+
152
def maskGraph(mask, edge_index, selections, interps=None):
    """Restrict the graph to the masked nodes, relabeling node indices.

    Returns (edge_index, selections) or, when interps is supplied,
    (edge_index, selections, interps), each filtered to surviving edges.
    """
    edge_index, _, edge_mask = subgraph(mask, edge_index, relabel_nodes=True, return_edge_mask=True)
    selections = selections[edge_mask]

    # Bug fix: `if interps:` evaluated tensor truthiness, which raises a
    # RuntimeError for any multi-element tensor; test identity against None.
    if interps is not None:
        interps = interps[edge_mask]
        return edge_index, selections, interps
    else:
        return edge_index, selections
162
+
163
def interpolateSelections(edge_index, directions, vectorList=None):
    """Assign each edge a selection plus an angle-based interpolation weight.

    Every direction falls between two of the 8 kernel directions. The edge
    keeps its best-matching selection with weight (1 - t), and a duplicate
    edge is appended carrying the neighboring selection with weight t, where
    t is the angular fraction toward that neighbor. Self edges are assigned
    the central selection 0 with weight 1.

    Returns the (possibly extended) (edge_index, selections, interps).
    """
    if vectorList is None:
        # Selection ordering (y pointing up):
        #   4 3 2
        #   5 0 1
        #   6 7 8
        s2 = sqrt(2) / 2
        vectorList = torch.tensor(
            [[1, 0], [s2, s2], [0, 1], [-s2, s2], [-1, 0], [-s2, -s2], [0, -1], [s2, -s2]],
            dtype=torch.float).transpose(1, 0)

    # Work with unit directions so dot products are cosines of angles
    directions = directions / torch.linalg.norm(directions, dim=1, keepdims=True)

    cosines = torch.matmul(directions, vectorList)
    best = torch.unsqueeze(torch.argmax(cosines, dim=1), 1)
    best_cos = torch.take_along_dim(cosines, best, dim=1)

    # The runner-up is whichever angular neighbor of the best matches closer
    cos_prev = torch.take_along_dim(cosines, (best - 1) % 8, dim=1)
    cos_next = torch.take_along_dim(cosines, (best + 1) % 8, dim=1)
    neighbors = torch.cat((cos_prev, cos_next), dim=1)
    runner_up_cos = torch.amax(neighbors, dim=1)
    runner_up = torch.argmax(neighbors, dim=1)

    # Interpolation parameter t from the two angular distances
    best_cos = torch.minimum(best_cos[:, 0], torch.tensor(1, device=directions.device))  # Prep for arccos
    angle_best = torch.arccos(best_cos)
    angle_runner_up = torch.arccos(runner_up_cos)
    t = angle_best / (angle_runner_up + angle_best)

    # Negative t marks interpolation toward the clockwise neighbor
    clockwise = torch.where(runner_up == 0)[0]
    t[clockwise] = -t[clockwise]

    # Handle computation problems at the poles
    t = torch.nan_to_num(t)

    selections = best[:, 0] + 1

    # Self edges always map to the central selection with no interpolation
    same_locs = torch.where(edge_index[0] == edge_index[1])
    selections[same_locs] = 0
    t[same_locs] = 0

    # Weight kept by the primary selection
    interps = torch.ones_like(t) - torch.abs(t)

    # Duplicate edges carrying the counter-clockwise share
    ccw_locs = torch.where(t > 1e-2)[0]
    ccw_interps = t[ccw_locs]
    ccw_edges = edge_index[:, ccw_locs]
    ccw_selections = selections[ccw_locs] + 1
    ccw_selections[torch.where(ccw_selections > 8)] = 1  # Account for wrap around

    # Duplicate edges carrying the clockwise share
    cw_locs = torch.where(t < -1e-2)[0]
    cw_interps = torch.abs(t[cw_locs])
    cw_edges = edge_index[:, cw_locs]
    cw_selections = selections[cw_locs] - 1
    cw_selections[torch.where(cw_selections < 1)] = 8  # Account for wrap around

    edge_index = torch.cat((edge_index, ccw_edges, cw_edges), dim=1)
    selections = torch.cat((selections, ccw_selections, cw_selections), dim=0)
    interps = torch.cat((interps, ccw_interps, cw_interps), dim=0)

    return edge_index, selections, interps
236
+
237
def interpolateSelections_barycentric(edge_index,directions,d,vectorList=None):
    """Split edges across selections using barycentric interpolation.

    Each non-self edge direction is placed inside the triangle spanned by
    the node center and its two closest kernel directions; the edge is
    duplicated so the central, best, and runner-up selections each carry a
    barycentric weight (I0 / I1 / I2). `d` is the kernel spacing used to
    scale the raw (unnormalized) directions. Pre-existing self edges are
    re-appended at the end with selection 0 and weight 1.

    Returns the extended (edge_index, selections, interps).
    """

    if vectorList is None:
        # Current Ordering (y pointing down)
        # 4 3 2
        # 5 0 1
        # 6 7 8
        vectorList = torch.tensor([[1,0],[sqrt(2)/2,-sqrt(2)/2],[0,-1],[-sqrt(2)/2,-sqrt(2)/2],[-1,0],[-sqrt(2)/2,sqrt(2)/2],[0,1],[sqrt(2)/2,sqrt(2)/2]],dtype=torch.float).transpose(1,0).to(directions.device)

    # Preprune central selections and reappend them at the end
    same_locs = torch.where(edge_index[0] == edge_index[1])
    same_edges = edge_index[:,same_locs[0]]

    different_locs = torch.where(edge_index[0] != edge_index[1])
    edge_index = edge_index[:,different_locs[0]]
    directions = directions[different_locs[0]]

    # Normalize directions for simplicity of calculations
    dir_norm = torch.linalg.norm(directions,dim=1,keepdims=True)
    unit_directions = directions/dir_norm
    #locs = torch.where(dir_norm > 1)[0]
    #directions[locs] = directions[locs]/dir_norm[locs]

    # Cosine against each kernel direction; best is the primary selection index
    values = torch.matmul(unit_directions,vectorList)
    best = torch.unsqueeze(torch.argmax(values,dim=1),1)
    #best_val = torch.take_along_dim(values,best,dim=1)

    # Look at both neighbors to see who is closer
    lower_val = torch.take_along_dim(values,(best-1) % 8,dim=1)
    upper_val = torch.take_along_dim(values,(best+1) % 8,dim=1)

    comp_vals = torch.cat((lower_val,upper_val),dim=1)

    second_best = torch.argmax(comp_vals,dim=1)
    #second_best_vals = torch.amax(comp_vals,dim=1)

    # Convert into uv coordinates for the barycentric interpolation calculation
    #    /|
    #   / |v
    #  /__|
    #    u
    scaled_directions = torch.abs(directions/d)
    u = torch.amax(scaled_directions,dim=1)
    v = torch.amin(scaled_directions,dim=1)

    # Force coordinates to be within the triangle
    # NOTE(review): u/v are already scaled by 1/d, so comparing u against d
    # looks dimensionally inconsistent — presumably the intended threshold
    # is 1.0 unless callers always pass d ~ 1; TODO confirm against callers
    # (sphere2Graph passes bary_d = pi/(scale*rows), which is << 1).
    boundary_check = torch.where(u > d)
    v[boundary_check] /= u[boundary_check]
    u[boundary_check] = 1.0

    # Precalculated barycentric values from linear matrix solve
    I0 = 1 - u
    I1 = u - v
    I2 = v

    # Make first selections and proper interps
    selections = best[:,0] + 1
    interps = I1
    even_sels = torch.where(selections % 2 == 0)
    interps[even_sels] = I2[even_sels] # Corners get different weights

    # Make new edges for the central selections
    central_edges = torch.clone(edge_index).to(edge_index.device)
    central_selections = torch.zeros_like(selections)
    central_interps = I0

    # Make new edges for the last selection
    pos_locs = torch.where(second_best==1)[0]
    pos_edges = edge_index[:,pos_locs]
    pos_selections = selections[pos_locs] + 1
    pos_selections[torch.where(pos_selections>8)] = 1 #Account for wrap around
    pos_interps = I1[pos_locs]
    even_sels = torch.where(pos_selections % 2 == 0)
    pos_interps[even_sels] = I2[pos_locs][even_sels]

    neg_locs = torch.where(second_best==0)[0]
    neg_edges = edge_index[:,neg_locs]
    neg_selections = selections[neg_locs] - 1
    neg_selections[torch.where(neg_selections<1)] = 8 # Account for wrap around
    neg_interps = I1[neg_locs]
    even_sels = torch.where(neg_selections % 2 == 0)
    neg_interps[even_sels] = I2[neg_locs][even_sels]

    # Account for the previously pruned same node edges
    # NOTE(review): these are created on the default (CPU) device — if the
    # other tensors live on GPU the concatenations below would fail; confirm
    # intended device handling.
    same_selections = torch.zeros(same_edges.shape[1],dtype=torch.long)
    same_interps = torch.ones(same_edges.shape[1],dtype=torch.float)

    # Combine
    edge_index = torch.cat((edge_index,central_edges,pos_edges,neg_edges,same_edges),dim=1)
    selections = torch.cat((selections,central_selections,pos_selections,neg_selections,same_selections),dim=0)
    interps = torch.cat((interps,central_interps,pos_interps,neg_interps,same_interps),dim=0)

    #edge_index = torch.cat((edge_index,central_edges,pos_edges,neg_edges),dim=1)
    #selections = torch.cat((selections,central_selections,pos_selections,neg_selections),dim=0)
    #interps = torch.cat((interps,central_interps,pos_interps,neg_interps),dim=0)

    # Account for edges to the same node
    #same_locs = torch.where(edge_index[0] == edge_index[1])
    #selections[same_locs] = 0
    #interps[same_locs] = 1

    return edge_index,selections,interps
340
+
341
def normalizeEdges(edge_index, selections, interps=None, kernel_norm=False):
    """Normalize per-edge weights so each node's selections aggregate to 1.

    Given an edge_index and selections, normalize the edges for each node so
    that aggregation of edges sums to 1 per selection. If interps is given,
    a weighted average is used (every edge starts with weight 1 otherwise).
    If kernel_norm=True, account for missing selections by increasing the
    weight on the other selections instead.
    """
    num_nodes = torch.max(edge_index) + 1
    num_selections = torch.max(selections) + 1

    accumulated = torch.zeros((num_nodes, num_selections), dtype=torch.float).to(edge_index.device)

    if interps is None:
        interps = torch.ones(len(selections), dtype=torch.float).to(edge_index.device)

    # Sum the weight landing in each (node, selection) bin
    sources = edge_index[0]
    accumulated.index_put_((sources, selections), interps, accumulate=True)

    if kernel_norm:
        # Scale by the node's total so absent selections boost the rest
        per_node = torch.sum(accumulated, dim=1)
        return interps * num_selections / per_node[sources]

    # Divide each weight by its (node, selection) bin total
    divisors = accumulated[sources, selections]
    divisors[torch.where(divisors < 1e-6)] = 1e-6  # Avoid divide by zero error
    return interps / divisors
369
+
370
def simplifyGraph(edge_index, selections, edge_lengths):
    """Keep only the shortest edge per (source node, selection) pair.

    Take the shortest edge for the set of the same selections on a given
    node. Returns the filtered (edge_index, selections).
    """
    num_edges = edge_index.shape[1]

    # Keep track of which edges survive
    keep_edges = torch.zeros(num_edges, dtype=torch.bool).to(edge_index.device)

    num_nodes = torch.amax(edge_index) + 1
    num_selections = torch.amax(selections) + 1

    # Bug fix: the best-distance table previously used dtype=torch.long,
    # which truncated float edge lengths on store (e.g. 1.5 stored as 1)
    # and could keep the wrong edge; track distances in the edge_lengths
    # dtype instead.
    best_distance = 100000 * torch.ones((num_nodes, num_selections), dtype=edge_lengths.dtype).to(edge_index.device)
    best_edge = -1 * torch.ones((num_nodes, num_selections), dtype=torch.long).to(edge_index.device)

    for i in range(num_edges):
        start_node = edge_index[0, i]
        selection = selections[i]
        distance = edge_lengths[i]

        if distance < best_distance[start_node, selection]:
            best_distance[start_node, selection] = distance
            keep_edges[i] = True

            # Un-keep the previous best edge for this (node, selection) slot
            prev = best_edge[start_node, selection]
            if prev != -1:
                keep_edges[prev] = False

            best_edge[start_node, selection] = i

    kept = torch.where(keep_edges)[0]
    return edge_index[:, kept], selections[kept]
400
+
graph_io.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch_geometric.nn import knn
4
+ from torch_geometric.data import Data
5
+ from torch_geometric.nn import radius_graph, knn_graph
6
+
7
+ import graph_helpers as gh
8
+ import sphere_helpers as sh
9
+ import mesh_helpers as mh
10
+ import clusters as cl
11
+ import utils
12
+
13
+ from torch_scatter import scatter
14
+
15
+ import math
16
+ from math import pi, sqrt
17
+
18
+ from warnings import warn
19
+
20
def image2Graph(data, gt = None, mask = None, depth = 1, x_only = False, device = 'cpu'):
    """Build a selection graph (plus metadata) from a single image tensor.

    data: (1, ch, rows, cols) image batch; one node is created per pixel.
    gt: optional per-pixel labels, flattened into graph.y.
    mask: optional boolean mask; only True pixels become nodes.
    depth: number of downsampled cluster levels to precompute.
    x_only: if True, return just the node features (and labels when gt is
        given) without building any graph structure.

    Returns (graph, metadata), where metadata carries what graph2Image
    needs to paint results back into an image.
    """

    _,ch,rows,cols = data.shape

    # One node per pixel, channels as features
    x = torch.reshape(data,(ch,rows*cols)).permute((1,0)).to(device)

    if mask is not None:
        # Mask out nodes
        node_mask = torch.where(mask.flatten())
        x = x[node_mask]

    if gt is not None:
        y = gt.flatten().to(device)
        if mask is not None:
            y = y[node_mask]

    if x_only:
        if gt is not None:
            return x,y
        else:
            return x

    im_pos = gh.getImPos(rows,cols)

    if mask is not None:
        im_pos = im_pos[node_mask]

    # Make "point cloud" for clustering
    pos2D = gh.convertImPos(im_pos,flip_y=False)

    # Generate initial graph (8-neighborhood grid; image y axis points down)
    edge_index = gh.grid2Edges(pos2D)
    directions = pos2D[edge_index[1]] - pos2D[edge_index[0]]
    selections = gh.edges2Selections(edge_index,directions,interpolated=False,y_down=True)

    # Generate info for downsampled versions of the graph
    clusters, edge_indexes, selections_list = cl.makeImageClusters(pos2D,cols,rows,edge_index,selections,depth=depth,device=device)

    # Make final graph and metadata needed for mapping the result after going through the network
    graph = Data(x=x,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=None)
    metadata = Data(original=data,im_pos=im_pos.long(),rows=rows,cols=cols,ch=ch)

    if gt is not None:
        graph.y = y

    return graph,metadata
66
+
67
def graph2Image(result, metadata, canvas=None):
    """Scatter per-node results back into an image at the stored pixel positions.

    A fresh canvas is built from the original image unless one is provided;
    painting over the original is necessary for masked images.
    """
    values = utils.toNumpy(result, permute=False)
    pixels = utils.toNumpy(metadata.im_pos, permute=False)

    if canvas is None:
        canvas = utils.makeCanvas(values, metadata.original)

    # Paint over the original image (necessary for masked images)
    canvas[pixels[:, 0], pixels[:, 1]] = values

    return canvas
78
+
79
+ ### Begin Interpolated Methods ###
80
+
81
def sphere2Graph(data, structure="layering", cluster_method="layering", scale=1.0, stride=2, interpolation_mode = "angle", gt = None, mask = None, depth = 1, x_only = False, device = 'cpu'):
    """Sample a sphere, pull features from an equirectangular image, and build a graph.

    data: (1, ch, rows, cols) equirectangular image.
    structure: sphere sampling scheme ("equirec", "layering", "spiral",
        "icosphere", or "random"); scale multiplies the sampling resolution.
    cluster_method / stride / depth: control the precomputed downsampling levels.
    interpolation_mode: "bary" uses barycentric selection interpolation with
        spacing pi/(scale*rows); anything else uses angle interpolation.
    gt / mask / x_only / device: as in image2Graph.

    Returns (graph, metadata); metadata carries what graph2Sphere needs.

    Raises ValueError for an unknown structure name.
    """

    _,ch,rows,cols = data.shape

    # Choose the sphere sampling pattern
    if structure == "equirec":
        # Use the original data to start with
        cartesian, spherical = sh.sampleSphere_Equirec(scale*rows,scale*cols)
    elif structure == "layering":
        cartesian, spherical = sh.sampleSphere_Layering(scale*rows)
    elif structure == "spiral":
        cartesian, spherical = sh.sampleSphere_Spiral(scale*rows,scale*cols)
    elif structure == "icosphere":
        cartesian, spherical = sh.sampleSphere_Icosphere(scale*rows)
    elif structure == "random":
        cartesian, spherical = sh.sampleSphere_Random(scale*rows,scale*cols)
    else:
        raise ValueError("Sphere structure unknown")

    if interpolation_mode == "bary":
        bary_d = pi/(scale*rows)
    else:
        bary_d = None

    # Get the landing point for each node
    sample_x, sample_y = sh.spherical2equirec(spherical[:,0],spherical[:,1],rows,cols)

    if mask is not None:
        # Keep only samples whose bilinear footprint is fully unmasked
        node_mask = gh.maskPoints(mask,sample_x,sample_y)
        sample_x = sample_x[node_mask]
        sample_y = sample_y[node_mask]
        spherical = spherical[node_mask]
        cartesian = cartesian[node_mask]

    # Node features come from bilinear samples of the image
    features = utils.bilinear_interpolate(data, sample_x, sample_y).to(device)

    if gt is not None:
        features_y = utils.bilinear_interpolate(gt.unsqueeze(0), sample_x, sample_y).to(device)

    if x_only:
        if gt is not None:
            return features,features_y
        else:
            return features

    # Build initial graph (on the unit sphere the positions double as normals)
    edge_index,directions = gh.surface2Edges(cartesian,cartesian)
    edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True,bary_d=bary_d)

    # Generate info for downsampled versions of the graph
    clusters, edge_indexes, selections_list, interps_list = cl.makeSphereClusters(cartesian,edge_index,selections,interps,rows*scale,cols*scale,cluster_method,stride=stride,bary_d=bary_d,depth=depth,device=device)

    # Make final graph and metadata needed for mapping the result after going through the network
    graph = Data(x=features,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list)
    metadata = Data(original=data,pos3D=cartesian,mask=mask,rows=rows,cols=cols,ch=ch)

    if gt is not None:
        graph.y = features_y

    return graph, metadata
141
+
142
def graph2Sphere(features,metadata):
    """Resample per-node features back into an equirectangular image.

    Each output pixel is an inverse-distance blend of its 3 nearest graph
    nodes in 3D. With a mask in metadata, only masked pixels are painted
    over the original image; otherwise the full (rows, cols, ch) numpy
    array is returned.
    """

    # Generate equirectangular points and their 3D locations
    theta, phi = sh.equirec2spherical(metadata.rows, metadata.cols)
    x,y,z = sh.spherical2xyz(theta,phi)

    v = torch.stack((x,y,z),dim=1)

    # Find closest 3D point to each equirectangular point
    nearest = torch.reshape(knn(metadata.pos3D,v,3)[1],(len(v),3))

    # Interpolate based on proximity to each node (inverse-distance weights)
    w0 = 1/torch.linalg.norm((v - metadata.pos3D[nearest[:,0]]),dim=1, keepdim=True).to(features.device)
    w1 = 1/torch.linalg.norm((v - metadata.pos3D[nearest[:,1]]),dim=1, keepdim=True).to(features.device)
    w2 = 1/torch.linalg.norm((v - metadata.pos3D[nearest[:,2]]),dim=1, keepdim=True).to(features.device)

    # A zero distance gives a nan/inf weight; pin it to a large finite value
    w0 = torch.nan_to_num(w0, nan=1e6)
    w1 = torch.nan_to_num(w1, nan=1e6)
    w2 = torch.nan_to_num(w2, nan=1e6)

    w0 = torch.clamp(w0,0,1e6)
    w1 = torch.clamp(w1,0,1e6)
    w2 = torch.clamp(w2,0,1e6)

    total = w0 + w1 + w2

    #w0,w1,w2 = mh.getBarycentricWeights(v,metadata.pos3D[nearest[:,0]],metadata.pos3D[nearest[:,1]],metadata.pos3D[nearest[:,2]])

    #w0 = w0.unsqueeze(1).to(features.device)
    #w1 = w1.unsqueeze(1).to(features.device)
    #w2 = w2.unsqueeze(1).to(features.device)

    # Normalized weighted average of the three neighbors
    result = (w0*features[nearest[:,0]] + w1*features[nearest[:,1]] + w2*features[nearest[:,2]])/total

    #result = result.clamp(0,1)

    if hasattr(metadata,"mask"):
        # Paint only the masked region over the original image
        mask = utils.toNumpy(metadata.mask.squeeze(),permute=False)
        canvas = utils.makeCanvas(result,metadata.original)
        result = np.reshape(result.data.cpu().numpy(),(metadata.rows,metadata.cols,features.shape[1]))
        canvas[np.where(mask)] = result[np.where(mask)]
        return canvas
    else:
        return np.reshape(result.data.cpu().numpy(),(metadata.rows,metadata.cols,features.shape[1]))
186
+
187
+
188
+
189
def splat2Graph(data, mesh, up_vector = None, N = 100000, ratio=.25, depth = 1, device = 'cpu'):
    """ Sample mesh faces to determine graph.

    Samples N surface points (positions + normals) from the mesh, builds an
    interpolated selection graph over them, and precomputes `depth` levels
    of clustered/downsampled graphs.

    Returns (graph, metadata); metadata holds the original data, the
    sampled 3D positions and the mesh.
    """

    # Bug fix: `up_vector == None` performs an elementwise comparison when a
    # tensor is supplied (and can raise on truthiness); use an identity check.
    if up_vector is None:
        up_vector = torch.tensor([[1,1,1]],dtype=torch.float)
        #up_vector = 2*torch.rand((1,3))-1
        up_vector = up_vector/torch.linalg.norm(up_vector,dim=1)

    # position and normal vector per sampled point
    pos3D, normals = mh.sampleSurface(mesh,N)

    # Build initial graph
    # edge_index are neighbors of a point, directions are the directions from that point
    edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16)
    # directions are mapped onto the star-like selection coordinate system
    edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)

    # Generate info for downsampled versions of the graph
    clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)
    #clusters, edge_indexes, selections_list, interps_list = cl.makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)

    # Make final graph and metadata needed for mapping the result after going through the network
    # NOTE(review): graph2Splat reads metadata.uvs, which is not stored here
    # (unlike mesh2Graph) — confirm which output path this metadata feeds.
    graph = Data(clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list)
    metadata = Data(original=data,pos3D=pos3D,mesh=mesh)

    return graph,metadata
216
+
217
def mesh2Graph(data, mesh, up_vector = None, N = 100000, ratio=.25, mask = None, depth = 1, x_only = False, device = 'cpu'):
    """ Sample mesh faces to determine graph.

    Samples N surface points (positions, normals, uvs, colors) from the
    mesh and builds an interpolated selection graph with `depth` levels of
    clustering. x_only short-circuits and returns just the sampled colors.

    Returns (graph, metadata); metadata.uvs is what graph2Mesh uses to map
    node features back onto the texture.
    """

    # Bug fix: `up_vector == None` performs an elementwise comparison when a
    # tensor is supplied (and can raise on truthiness); use an identity check.
    if up_vector is None:
        up_vector = torch.tensor([[1,1,1]],dtype=torch.float)
        #up_vector = 2*torch.rand((1,3))-1
        up_vector = up_vector/torch.linalg.norm(up_vector,dim=1)

    if mask is not None:
        warn("Masks are not currently implemented for mesh graphs")

    # position, normal vector, uv coordinates in the texture map, x is color
    pos3D, normals, uvs, x = mh.sampleSurface(mesh,N,return_x=True)

    x = x.to(device)

    if x_only:
        warn("x_only returns randomly selected points for mesh2Graph. Do not use with previous graph structures")
        return x

    # Build initial graph
    # edge_index are neighbors of a point, directions are the directions from that point
    edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16)
    # directions are mapped onto the star-like selection coordinate system
    edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)

    # Generate info for downsampled versions of the graph
    clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)
    #clusters, edge_indexes, selections_list, interps_list = cl.makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)

    # Make final graph and metadata needed for mapping the result after going through the network
    graph = Data(x=x,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list)
    metadata = Data(original=data,pos3D=pos3D,uvs=uvs,mesh=mesh)

    return graph,metadata
252
+
253
def graph2Splat(features,metadata,view3D=False):
    """Interpolate per-node features back into the texture image via UVs.

    Returns the repainted texture (clipped to [0, 1]); optionally shows the
    textured mesh when view3D=True.
    """

    features = features.cpu().numpy()

    canvas = utils.toNumpy(metadata.original)
    rows,cols,ch = canvas.shape

    # Get 2D positions by scaling uv.
    # Bug fix: .numpy() on a CPU tensor shares memory with it, so the
    # in-place scaling below mutated metadata.uvs and corrupted any second
    # call; work on a private copy instead.
    pos2D = metadata.uvs.cpu().numpy().copy()
    pos2D[:,0] = pos2D[:,0]*cols
    pos2D[:,1] = 1-pos2D[:,1] # UV puts y=0 at the bottom
    pos2D[:,1] = pos2D[:,1]*rows

    # Generate desired points
    row_space = np.arange(rows)
    col_space = np.arange(cols)
    col_image,row_image = np.meshgrid(col_space,row_space)

    canvas = utils.interpolatePointCloud2D(pos2D,features,col_image,row_image)
    canvas = np.clip(canvas,0,1)

    if view3D:
        mesh = mh.setTexture(metadata.mesh,canvas)
        mesh.show()

    return canvas
279
+
280
+
281
def graph2Mesh(features,metadata,view3D=False):
    """Interpolate per-node features back into the mesh texture image via UVs.

    Returns the repainted texture (clipped to [0, 1]); optionally shows the
    textured mesh when view3D=True.
    """

    features = features.cpu().numpy()

    canvas = utils.toNumpy(metadata.original)
    rows,cols,ch = canvas.shape

    # Get 2D positions by scaling uv.
    # Bug fix: .numpy() on a CPU tensor shares memory with it, so the
    # in-place scaling below mutated metadata.uvs and corrupted any second
    # call; work on a private copy instead.
    pos2D = metadata.uvs.cpu().numpy().copy()
    pos2D[:,0] = pos2D[:,0]*cols
    pos2D[:,1] = 1-pos2D[:,1] # UV puts y=0 at the bottom
    pos2D[:,1] = pos2D[:,1]*rows

    # Generate desired points
    row_space = np.arange(rows)
    col_space = np.arange(cols)
    col_image,row_image = np.meshgrid(col_space,row_space)

    canvas = utils.interpolatePointCloud2D(pos2D,features,col_image,row_image)
    canvas = np.clip(canvas,0,1)

    if view3D:
        mesh = mh.setTexture(metadata.mesh,canvas)
        mesh.show()

    return canvas
graph_networks/.DS_Store ADDED
Binary file (6.15 kB). View file
 
graph_networks/LinearStyleTransfer/.DS_Store ADDED
Binary file (8.2 kB). View file
 
graph_networks/LinearStyleTransfer/LICENSE ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 2-Clause License
2
+
3
+ Copyright (c) 2018, SunshineAtNoon
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ * Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ * Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
graph_networks/LinearStyleTransfer/README.md ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Learning Linear Transformations for Fast Image and Video Style Transfer
2
+ **[[Paper]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Li_Learning_Linear_Transformations_for_Fast_Image_and_Video_Style_Transfer_CVPR_2019_paper.pdf)** **[[Project Page]](https://sites.google.com/view/linear-style-transfer-cvpr19/)**
3
+
4
+ <img src="doc/images/chicago_paste.png" height="149" hspace="5"><img src="doc/images/photo_content.png" height="150" hspace="5"><img src="doc/images/content.gif" height="150" hspace="5">
5
+ <img src="doc/images/chicago_27.png" height="150" hspace="5"><img src="doc/images/in5_result.png" height="150" hspace="5"><img src="doc/images/test.gif" height="150" hspace="5">
6
+
7
+ ## Prerequisites
8
+ - [Pytorch](http://pytorch.org/)
9
+ - [torchvision](https://github.com/pytorch/vision)
10
+ - [opencv](https://opencv.org/) for video generation
11
+
12
+ **All code tested on Ubuntu 16.04, pytorch 0.4.1, and opencv 3.4.2**
13
+
14
+ ## Style Transfer
15
+ - Clone from github: `git clone https://github.com/sunshineatnoon/LinearStyleTransfer`
16
+ - Download pre-trained models from [google drive](https://drive.google.com/file/d/1H9T5rfXGlGCUh04DGkpkMFbVnmscJAbs/view?usp=sharing).
17
+ - Uncompress to root folder :
18
+ ```
19
+ cd LinearStyleTransfer
20
+ unzip models.zip
21
+ rm models.zip
22
+ ```
23
+
24
+ #### Artistic style transfer
25
+ ```
26
+ python TestArtistic.py
27
+ ```
28
+ or conduct style transfer on relu_31 features
29
+ ```
30
+ python TestArtistic.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --matrixPath models/r31.pth --layer r31
31
+ ```
32
+
33
+ #### Photo-realistic style transfer
34
+ For photo-realistic style transfer, we first need to compile the [pytorch_spn](https://github.com/Liusifei/pytorch_spn) repository.
35
+ ```
36
+ cd libs/pytorch_spn
37
+ sh make.sh
38
+ cd ../..
39
+ ```
40
+ Then:
41
+ ```
42
+ python TestPhotoReal.py
43
+ ```
44
+ Note: images with `_filtered.png` as postfix are images filtered by the SPN after style transfer, images with `_smooth.png` as postfix are images post-processed by a [smooth filter](https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py).
45
+
46
+ #### Video style transfer
47
+ ```
48
+ python TestVideo.py
49
+ ```
50
+
51
+ #### Real-time video demo
52
+ ```
53
+ python real-time-demo.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --matrixPath models/r31.pth --layer r31
54
+ ```
55
+
56
+ ## Model Training
57
+ ### Data Preparation
58
+ - MSCOCO
59
+ ```
60
+ wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip
61
+ ```
62
+ - WikiArt
63
+ - Either manually download from [kaggle](https://www.kaggle.com/c/painter-by-numbers).
64
+ - Or install [kaggle-cli](https://github.com/floydwch/kaggle-cli) and download by running:
65
+ ```
66
+ kg download -u <username> -p <password> -c painter-by-numbers -f train.zip
67
+ ```
68
+
69
+ ### Training
70
+ #### Train a style transfer model
71
+ To train a model that transfers relu4_1 features, run:
72
+ ```
73
+ python Train.py --vgg_dir models/vgg_r41.pth --decoder_dir models/dec_r41.pth --layer r41 --contentPath PATH_TO_MSCOCO --stylePath PATH_TO_WikiArt --outf OUTPUT_DIR
74
+ ```
75
+ or train a model that transfers relu3_1 features:
76
+ ```
77
+ python Train.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --layer r31 --contentPath PATH_TO_MSCOCO --stylePath PATH_TO_WikiArt --outf OUTPUT_DIR
78
+ ```
79
+ Key hyper-parameters:
80
+ - style_layers: which features to compute style loss.
81
+ - style_weight: larger style weight leads to heavier style in transferred images.
82
+
83
+ Intermediate results and weights will be stored in `OUTPUT_DIR`
84
+
85
+ #### Train a SPN model to cancel distortions for photo-realistic style transfer
86
+ Run:
87
+ ```
88
+ python TrainSPN.py --contentPath PATH_TO_MSCOCO
89
+ ```
90
+
91
+ ### Acknowledgement
92
+ - We use the [smooth filter](https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py) by [LouieYang](https://github.com/LouieYang) in the photo-realistic style transfer.
93
+
94
+ ### Citation
95
+ ```
96
+ @inproceedings{li2018learning,
97
+ author = {Li, Xueting and Liu, Sifei and Kautz, Jan and Yang, Ming-Hsuan},
98
+ title = {Learning Linear Transformations for Fast Arbitrary Style Transfer},
99
+ booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
100
+ year = {2019}
101
+ }
102
+ ```
graph_networks/LinearStyleTransfer/TestArtistic.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Artistic style transfer with pre-trained linear transformation matrices.

For every (content, style) image pair found under --contentPath and
--stylePath, both images are encoded with a VGG encoder, the content
features are transformed by a learned linear matrix (MulLayer), decoded,
and the result is saved to --outf as <content>_<style>.png.
"""
import os
import torch
import argparse
from libs.Loader import Dataset
from libs.Matrix import MulLayer
import torchvision.utils as vutils
import torch.backends.cudnn as cudnn
from libs.utils import print_options
from libs.models import encoder3, encoder4, encoder5
from libs.models import decoder3, decoder4, decoder5

parser = argparse.ArgumentParser()
parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
                    help='pre-trained encoder path')
parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
                    help='pre-trained decoder path')
parser.add_argument("--matrixPath", default='models/r41.pth',
                    help='pre-trained model path')
parser.add_argument("--stylePath", default="./data/style/",
                    help='path to style image')
parser.add_argument("--contentPath", default="./data/content/",
                    help='path to frames')
parser.add_argument("--outf", default="Artistic/",
                    help='path to transferred images')
parser.add_argument("--batchSize", type=int, default=1,
                    help='batch size')
parser.add_argument('--loadSize', type=int, default=256,
                    help='scale image size')
parser.add_argument('--fineSize', type=int, default=256,
                    help='crop image size')
parser.add_argument("--layer", default="r41",
                    help='which features to transfer, either r31 or r41')

################# PREPARATIONS #################
opt = parser.parse_args()
opt.cuda = torch.cuda.is_available()
print_options(opt)

os.makedirs(opt.outf, exist_ok=True)
cudnn.benchmark = True

################# DATA #################
content_dataset = Dataset(opt.contentPath, opt.loadSize, opt.fineSize, test=True)
content_loader = torch.utils.data.DataLoader(dataset=content_dataset,
                                             batch_size=opt.batchSize,
                                             shuffle=False)
style_dataset = Dataset(opt.stylePath, opt.loadSize, opt.fineSize, test=True)
style_loader = torch.utils.data.DataLoader(dataset=style_dataset,
                                           batch_size=opt.batchSize,
                                           shuffle=False)

################# MODEL #################
if opt.layer == 'r31':
    vgg = encoder3()
    dec = decoder3()
elif opt.layer == 'r41':
    vgg = encoder4()
    dec = decoder4()
else:
    # Previously an unsupported --layer value crashed later with a NameError.
    raise ValueError("--layer must be 'r31' or 'r41', got %r" % opt.layer)
matrix = MulLayer(opt.layer)
vgg.load_state_dict(torch.load(opt.vgg_dir))
dec.load_state_dict(torch.load(opt.decoder_dir))
matrix.load_state_dict(torch.load(opt.matrixPath))

################# GLOBAL VARIABLE #################
contentV = torch.Tensor(opt.batchSize, 3, opt.fineSize, opt.fineSize)
styleV = torch.Tensor(opt.batchSize, 3, opt.fineSize, opt.fineSize)

################# GPU #################
if opt.cuda:
    vgg.cuda()
    dec.cuda()
    matrix.cuda()
    contentV = contentV.cuda()
    styleV = styleV.cuda()

# Transfer every style onto every content image.
for ci, (content, contentName) in enumerate(content_loader):
    contentName = contentName[0]
    contentV.resize_(content.size()).copy_(content)
    for sj, (style, styleName) in enumerate(style_loader):
        styleName = styleName[0]
        styleV.resize_(style.size()).copy_(style)

        # Forward pass is pure inference; keep ALL of it gradient-free.
        # (Previously the transform/decode ran outside no_grad and built
        # an unused autograd graph.)
        with torch.no_grad():
            sF = vgg(styleV)
            cF = vgg(contentV)

            if opt.layer == 'r41':
                feature, transmatrix = matrix(cF[opt.layer], sF[opt.layer])
            else:
                feature, transmatrix = matrix(cF, sF)
            transfer = dec(feature)

        transfer = transfer.clamp(0, 1)
        vutils.save_image(transfer, '%s/%s_%s.png' % (opt.outf, contentName, styleName),
                          normalize=True, scale_each=True, nrow=opt.batchSize)
        print('Transferred image saved at %s%s_%s.png' % (opt.outf, contentName, styleName))
graph_networks/LinearStyleTransfer/TestPhotoReal.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Photo-realistic style transfer.

Encodes content/style images with VGG, applies a masked linear feature
transform, decodes, then removes distortions with a spatial propagation
network (SPN) and a local smoothing filter.  Saves *_transfer.png,
*_filtered.png (SPN output) and *_smooth.png (smooth filter output).
"""
import os
import cv2
import time
import torch
import argparse
import numpy as np
from PIL import Image
from libs.SPN import SPN
import torchvision.utils as vutils
from libs.utils import print_options
from libs.MatrixTest import MulLayer
import torch.backends.cudnn as cudnn
from libs.LoaderPhotoReal import Dataset
from libs.models import encoder3, encoder4
from libs.models import decoder3, decoder4
import torchvision.transforms as transforms
from libs.smooth_filter import smooth_filter

parser = argparse.ArgumentParser()
parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
                    help='pre-trained encoder path')
parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
                    help='pre-trained decoder path')
parser.add_argument("--matrixPath", default='models/r41.pth',
                    help='pre-trained model path')
parser.add_argument("--stylePath", default="data/photo_real/style/images/",
                    help='path to style image')
parser.add_argument("--styleSegPath", default="data/photo_real/styleSeg/",
                    help='path to style image masks')
parser.add_argument("--contentPath", default="data/photo_real/content/images/",
                    help='path to content image')
parser.add_argument("--contentSegPath", default="data/photo_real/contentSeg/",
                    help='path to content image masks')
parser.add_argument("--outf", default="PhotoReal/",
                    help='path to save output images')
parser.add_argument("--batchSize", type=int, default=1,
                    help='batch size')
parser.add_argument('--fineSize', type=int, default=512,
                    help='image size')
parser.add_argument("--layer", default="r41",
                    help='features of which layer to transform, either r31 or r41')
parser.add_argument("--spn_dir", default='models/r41_spn.pth',
                    help='path to pretrained SPN model')

################# PREPARATIONS #################
opt = parser.parse_args()
opt.cuda = torch.cuda.is_available()
print_options(opt)

os.makedirs(opt.outf, exist_ok=True)

cudnn.benchmark = True

################# DATA #################
dataset = Dataset(opt.contentPath, opt.stylePath, opt.contentSegPath, opt.styleSegPath, opt.fineSize)
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=1,
                                     shuffle=False)

################# MODEL #################
if opt.layer == 'r31':
    vgg = encoder3()
    dec = decoder3()
elif opt.layer == 'r41':
    vgg = encoder4()
    dec = decoder4()
else:
    # Previously an unsupported --layer value crashed later with a NameError.
    raise ValueError("--layer must be 'r31' or 'r41', got %r" % opt.layer)
matrix = MulLayer(opt.layer)
vgg.load_state_dict(torch.load(opt.vgg_dir))
dec.load_state_dict(torch.load(opt.decoder_dir))
matrix.load_state_dict(torch.load(opt.matrixPath))
spn = SPN()
spn.load_state_dict(torch.load(opt.spn_dir))

################# GLOBAL VARIABLE #################
contentV = torch.Tensor(opt.batchSize, 3, opt.fineSize, opt.fineSize)
styleV = torch.Tensor(opt.batchSize, 3, opt.fineSize, opt.fineSize)
whitenV = torch.Tensor(opt.batchSize, 3, opt.fineSize, opt.fineSize)

################# GPU #################
if opt.cuda:
    vgg.cuda()
    dec.cuda()
    spn.cuda()
    matrix.cuda()
    contentV = contentV.cuda()
    styleV = styleV.cuda()
    whitenV = whitenV.cuda()

for i, (contentImg, styleImg, whitenImg, cmasks, smasks, imname) in enumerate(loader):
    imname = imname[0]
    contentV.resize_(contentImg.size()).copy_(contentImg)
    styleV.resize_(styleImg.size()).copy_(styleImg)
    whitenV.resize_(whitenImg.size()).copy_(whitenImg)

    # Whole pass is inference only, so run it entirely without autograd.
    # (Previously the two encoder calls ran outside no_grad.)
    with torch.no_grad():
        sF = vgg(styleV)
        cF = vgg(contentV)

        if opt.layer == 'r41':
            feature = matrix(cF[opt.layer], sF[opt.layer], cmasks, smasks)
        else:
            feature = matrix(cF, sF, cmasks, smasks)
        transfer = dec(feature)
        filtered = spn(transfer, whitenV)
    vutils.save_image(transfer, os.path.join(opt.outf, '%s_transfer.png' % (imname.split('.')[0])))

    filtered = filtered.clamp(0, 1)
    filtered = filtered.cpu()
    vutils.save_image(filtered, '%s/%s_filtered.png' % (opt.outf, imname.split('.')[0]))
    out_img = filtered.squeeze(0).mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    content = contentImg.squeeze(0).mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    # .copy() presumably yields contiguous, writable arrays for
    # smooth_filter — TODO confirm against its implementation.
    content = content.copy()
    out_img = out_img.copy()
    smoothed = smooth_filter(out_img, content, f_radius=15, f_edge=1e-1)
    smoothed.save('%s/%s_smooth.png' % (opt.outf, imname.split('.')[0]))
    print('Transferred image saved at %s%s, filtered image saved at %s%s_filtered.png' \
          % (opt.outf, imname, opt.outf, imname.split('.')[0]))
graph_networks/LinearStyleTransfer/TestVideo.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Video style transfer.

Applies a single style image to every frame of a video (a folder of
frames), then assembles the stylised frames into a video with makeVideo.
"""
import os
import torch
import argparse
from PIL import Image
from libs.Loader import Dataset
from libs.Matrix import MulLayer
import torch.backends.cudnn as cudnn
from libs.models import encoder3, encoder4
from libs.models import decoder3, decoder4
import torchvision.transforms as transforms
from libs.utils import makeVideo, print_options

parser = argparse.ArgumentParser()
parser.add_argument("--vgg_dir", default='models/vgg_r31.pth',
                    help='pre-trained encoder path')
parser.add_argument("--decoder_dir", default='models/dec_r31.pth',
                    help='pre-trained decoder path')
parser.add_argument("--matrix_dir", default="models/r31.pth",
                    help='path to pre-trained model')
parser.add_argument("--style", default="data/style/in2.jpg",
                    help='path to style image')
parser.add_argument("--content_dir", default="data/videos/content/mountain_2/",
                    help='path to video frames')
parser.add_argument('--loadSize', type=int, default=512,
                    help='scale image size')
parser.add_argument('--fineSize', type=int, default=512,
                    help='crop image size')
parser.add_argument("--name", default="transferred_video",
                    help="name of generated video")
parser.add_argument("--layer", default="r31",
                    help="features of which layer to transform")
parser.add_argument("--outf", default="videos",
                    help="output folder")

################# PREPARATIONS #################
opt = parser.parse_args()
opt.cuda = torch.cuda.is_available()
print_options(opt)

os.makedirs(opt.outf, exist_ok=True)
cudnn.benchmark = True

################# DATA #################
def loadImg(imgPath):
    """Load an RGB image and resize its shorter side to opt.fineSize."""
    img = Image.open(imgPath).convert('RGB')
    # transforms.Scale was deprecated and removed from torchvision;
    # transforms.Resize is the drop-in replacement.
    transform = transforms.Compose([
        transforms.Resize(opt.fineSize),
        transforms.ToTensor()])
    return transform(img)

styleV = loadImg(opt.style).unsqueeze(0)

content_dataset = Dataset(opt.content_dir,
                          loadSize=opt.loadSize,
                          fineSize=opt.fineSize,
                          test=True,
                          video=True)
content_loader = torch.utils.data.DataLoader(dataset=content_dataset,
                                             batch_size=1,
                                             shuffle=False)

################# MODEL #################
if opt.layer == 'r31':
    vgg = encoder3()
    dec = decoder3()
elif opt.layer == 'r41':
    vgg = encoder4()
    dec = decoder4()
else:
    # Previously an unsupported --layer value crashed later with a NameError.
    raise ValueError("--layer must be 'r31' or 'r41', got %r" % opt.layer)
matrix = MulLayer(layer=opt.layer)
vgg.load_state_dict(torch.load(opt.vgg_dir))
dec.load_state_dict(torch.load(opt.decoder_dir))
matrix.load_state_dict(torch.load(opt.matrix_dir))

################# GLOBAL VARIABLE #################
contentV = torch.Tensor(1, 3, opt.fineSize, opt.fineSize)

################# GPU #################
if opt.cuda:
    vgg.cuda()
    dec.cuda()
    matrix.cuda()

    styleV = styleV.cuda()
    contentV = contentV.cuda()

result_frames = []
contents = []
style = styleV.squeeze(0).cpu().numpy()
# Style features are frame-invariant: encode once, reuse for all frames.
sF = vgg(styleV)

for i, (content, contentName) in enumerate(content_loader):
    print('Transfer frame %d...' % i)
    contentName = contentName[0]
    contentV.resize_(content.size()).copy_(content)
    contents.append(content.squeeze(0).float().numpy())
    # Inference only — no gradients needed.
    with torch.no_grad():
        cF = vgg(contentV)

        if opt.layer == 'r41':
            feature, transmatrix = matrix(cF[opt.layer], sF[opt.layer])
        else:
            feature, transmatrix = matrix(cF, sF)
        transfer = dec(feature)

    transfer = transfer.clamp(0, 1)
    result_frames.append(transfer.squeeze(0).cpu().numpy())

makeVideo(contents, style, result_frames, opt.outf)
graph_networks/LinearStyleTransfer/Train.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Train a linear style-transformation matrix (MulLayer).

Content images (MSCOCO) and style images (WikiArt) are encoded with a
frozen VGG encoder; only the transformation matrix is optimised, against
a perceptual loss computed by a frozen relu5_1 loss network.
"""
import os
import torch
import argparse
import torch.nn as nn
import torch.optim as optim
from libs.Loader import Dataset
from libs.Matrix import MulLayer
import torchvision.utils as vutils
import torch.backends.cudnn as cudnn
from libs.utils import print_options
from libs.Criterion import LossCriterion
from libs.models import encoder3, encoder4
from libs.models import decoder3, decoder4
from libs.models import encoder5 as loss_network

parser = argparse.ArgumentParser()
parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
                    help='pre-trained encoder path')
parser.add_argument("--loss_network_dir", default='models/vgg_r51.pth',
                    help='used for loss network')
parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
                    help='pre-trained decoder path')
parser.add_argument("--stylePath", default="/home/xtli/DATA/wikiArt/train/images/",
                    help='path to wikiArt dataset')
parser.add_argument("--contentPath", default="/home/xtli/DATA/MSCOCO/train2014/images/",
                    help='path to MSCOCO dataset')
parser.add_argument("--outf", default="trainingOutput/",
                    help='folder to output images and model checkpoints')
parser.add_argument("--content_layers", default="r41",
                    help='layers for content')
parser.add_argument("--style_layers", default="r11,r21,r31,r41",
                    help='layers for style')
parser.add_argument("--batchSize", type=int, default=8,
                    help='batch size')
parser.add_argument("--niter", type=int, default=100000,
                    help='iterations to train the model')
parser.add_argument('--loadSize', type=int, default=300,
                    help='scale image size')
parser.add_argument('--fineSize', type=int, default=256,
                    help='crop image size')
parser.add_argument("--lr", type=float, default=1e-4,
                    help='learning rate')
parser.add_argument("--content_weight", type=float, default=1.0,
                    help='content loss weight')
parser.add_argument("--style_weight", type=float, default=0.02,
                    help='style loss weight')
parser.add_argument("--log_interval", type=int, default=500,
                    help='log interval')
parser.add_argument("--gpu_id", type=int, default=0,
                    help='which gpu to use')
parser.add_argument("--save_interval", type=int, default=5000,
                    help='checkpoint save interval')
parser.add_argument("--layer", default="r41",
                    help='which features to transfer, either r31 or r41')

################# PREPARATIONS #################
opt = parser.parse_args()
opt.content_layers = opt.content_layers.split(',')
opt.style_layers = opt.style_layers.split(',')
opt.cuda = torch.cuda.is_available()
if opt.cuda:
    torch.cuda.set_device(opt.gpu_id)

os.makedirs(opt.outf, exist_ok=True)
cudnn.benchmark = True
print_options(opt)

################# DATA #################
content_dataset = Dataset(opt.contentPath, opt.loadSize, opt.fineSize)
content_loader_ = torch.utils.data.DataLoader(dataset=content_dataset,
                                              batch_size=opt.batchSize,
                                              shuffle=True,
                                              num_workers=1,
                                              drop_last=True)
content_loader = iter(content_loader_)
style_dataset = Dataset(opt.stylePath, opt.loadSize, opt.fineSize)
style_loader_ = torch.utils.data.DataLoader(dataset=style_dataset,
                                            batch_size=opt.batchSize,
                                            shuffle=True,
                                            num_workers=1,
                                            drop_last=True)
style_loader = iter(style_loader_)

################# MODEL #################
vgg5 = loss_network()
if opt.layer == 'r31':
    matrix = MulLayer('r31')
    vgg = encoder3()
    dec = decoder3()
elif opt.layer == 'r41':
    matrix = MulLayer('r41')
    vgg = encoder4()
    dec = decoder4()
else:
    # Previously an unsupported --layer value crashed later with a NameError.
    raise ValueError("--layer must be 'r31' or 'r41', got %r" % opt.layer)
vgg.load_state_dict(torch.load(opt.vgg_dir))
dec.load_state_dict(torch.load(opt.decoder_dir))
vgg5.load_state_dict(torch.load(opt.loss_network_dir))

# Freeze everything except the transformation matrix.
for param in vgg.parameters():
    param.requires_grad = False
for param in vgg5.parameters():
    param.requires_grad = False
for param in dec.parameters():
    param.requires_grad = False

################# LOSS & OPTIMIZER #################
criterion = LossCriterion(opt.style_layers,
                          opt.content_layers,
                          opt.style_weight,
                          opt.content_weight)
optimizer = optim.Adam(matrix.parameters(), opt.lr)

################# GLOBAL VARIABLE #################
contentV = torch.Tensor(opt.batchSize, 3, opt.fineSize, opt.fineSize)
styleV = torch.Tensor(opt.batchSize, 3, opt.fineSize, opt.fineSize)

################# GPU #################
if opt.cuda:
    vgg.cuda()
    dec.cuda()
    vgg5.cuda()
    matrix.cuda()
    contentV = contentV.cuda()
    styleV = styleV.cuda()

################# TRAINING #################
def adjust_learning_rate(optimizer, iteration):
    """Decay the initial LR by a factor of 1/(1 + 1e-5 * iteration)."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = opt.lr / (1 + iteration * 1e-5)


def next_batch(loader, loader_factory):
    """Return (batch, loader), restarting the finite loader when exhausted.

    batch is None when the image could not be read (OSError from PIL),
    so the caller can skip that training iteration.
    NOTE: the original called the Python-2-only loader.next(); under
    Python 3 that raised AttributeError, which a bare `except: continue`
    swallowed — so the loop spun forever without training.
    """
    try:
        batch, _ = next(loader)
    except StopIteration:
        loader = iter(loader_factory)
        batch, _ = next(loader)
    except OSError:
        # Corrupt/unreadable image — skip this iteration.
        return None, loader
    return batch, loader


for iteration in range(1, opt.niter + 1):
    optimizer.zero_grad()
    content, content_loader = next_batch(content_loader, content_loader_)
    if content is None:
        continue
    style, style_loader = next_batch(style_loader, style_loader_)
    if style is None:
        continue

    contentV.resize_(content.size()).copy_(content)
    styleV.resize_(style.size()).copy_(style)

    # forward
    sF = vgg(styleV)
    cF = vgg(contentV)

    if opt.layer == 'r41':
        feature, transmatrix = matrix(cF[opt.layer], sF[opt.layer])
    else:
        feature, transmatrix = matrix(cF, sF)
    transfer = dec(feature)

    sF_loss = vgg5(styleV)
    cF_loss = vgg5(contentV)
    tF = vgg5(transfer)
    loss, styleLoss, contentLoss = criterion(tF, sF_loss, cF_loss)

    # backward & optimization
    loss.backward()
    optimizer.step()
    # The original swapped the first two % arguments and printed
    # [niter/iteration]; fixed to [iteration/niter].
    print('Iteration: [%d/%d] Loss: %.4f contentLoss: %.4f styleLoss: %.4f Learng Rate is %.6f' %
          (iteration, opt.niter, loss, contentLoss, styleLoss, optimizer.param_groups[0]['lr']))

    adjust_learning_rate(optimizer, iteration)

    if iteration % opt.log_interval == 0:
        transfer = transfer.clamp(0, 1)
        concat = torch.cat((content, style, transfer.cpu()), dim=0)
        vutils.save_image(concat, '%s/%d.png' % (opt.outf, iteration),
                          normalize=True, scale_each=True, nrow=opt.batchSize)

    if iteration > 0 and iteration % opt.save_interval == 0:
        torch.save(matrix.state_dict(), '%s/%s.pth' % (opt.outf, opt.layer))
graph_networks/LinearStyleTransfer/TrainSPN.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import argparse
4
+
5
+ from libs.SPN import SPN
6
+ from libs.Loader import Dataset
7
+ from libs.models import encoder4
8
+ from libs.models import decoder4
9
+ from libs.utils import print_options
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.optim as optim
14
+ import torch.nn.functional as F
15
+ import torchvision.utils as vutils
16
+ import torch.backends.cudnn as cudnn
17
+ import torchvision.transforms as transforms
18
+
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
21
+ help='pre-trained encoder path')
22
+ parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
23
+ help='pre-trained decoder path')
24
+ parser.add_argument("--contentPath", default="/home/xtli/DATA/MSCOCO/train2014/images/",
25
+ help='path to MSCOCO dataset')
26
+ parser.add_argument("--outf", default="trainingSPNOutput/",
27
+ help='folder to output images and model checkpoints')
28
+ parser.add_argument("--layer", default="r41",
29
+ help='layers for content')
30
+ parser.add_argument("--batchSize", type=int,default=8,
31
+ help='batch size')
32
+ parser.add_argument("--niter", type=int,default=100000,
33
+ help='iterations to train the model')
34
+ parser.add_argument('--loadSize', type=int, default=512,
35
+ help='scale image size')
36
+ parser.add_argument('--fineSize', type=int, default=256,
37
+ help='crop image size')
38
+ parser.add_argument("--lr", type=float, default=1e-3,
39
+ help='learning rate')
40
+ parser.add_argument("--log_interval", type=int, default=500,
41
+ help='log interval')
42
+ parser.add_argument("--save_interval", type=int, default=5000,
43
+ help='checkpoint save interval')
44
+ parser.add_argument("--spn_num", type=int, default=1,
45
+ help='number of spn filters')
46
+
47
+ ################# PREPARATIONS #################
48
+ opt = parser.parse_args()
49
+ opt.cuda = torch.cuda.is_available()
50
+ print_options(opt)
51
+
52
+
53
+ os.makedirs(opt.outf, exist_ok = True)
54
+
55
+ cudnn.benchmark = True
56
+
57
+ ################# DATA #################
58
+ content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize)
59
+ content_loader_ = torch.utils.data.DataLoader(dataset=content_dataset,
60
+ batch_size = opt.batchSize,
61
+ shuffle = True,
62
+ num_workers = 4,
63
+ drop_last = True)
64
+ content_loader = iter(content_loader_)
65
+
66
+ ################# MODEL #################
67
+ spn = SPN(spn=opt.spn_num)
68
+ if(opt.layer == 'r31'):
69
+ vgg = encoder3()
70
+ dec = decoder3()
71
+ elif(opt.layer == 'r41'):
72
+ vgg = encoder4()
73
+ dec = decoder4()
74
+ vgg.load_state_dict(torch.load(opt.vgg_dir))
75
+ dec.load_state_dict(torch.load(opt.decoder_dir))
76
+
77
+ for param in vgg.parameters():
78
+ param.requires_grad = False
79
+ for param in dec.parameters():
80
+ param.requires_grad = False
81
+
82
+ ################# LOSS & OPTIMIZER #################
83
+ criterion = nn.MSELoss(size_average=False)
84
+ #optimizer_spn = optim.SGD(spn.parameters(), opt.lr)
85
+ optimizer_spn = optim.Adam(spn.parameters(), opt.lr)
86
+
87
+ ################# GLOBAL VARIABLE #################
88
+ contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
89
+
90
+ ################# GPU #################
91
+ if(opt.cuda):
92
+ vgg.cuda()
93
+ dec.cuda()
94
+ spn.cuda()
95
+ contentV = contentV.cuda()
96
+
97
+ ################# TRAINING #################
98
+ def adjust_learning_rate(optimizer, iteration):
99
+ for param_group in optimizer.param_groups:
100
+ param_group['lr'] = opt.lr / (1+iteration*1e-5)
101
+
102
+ spn.train()
103
+ for iteration in range(1,opt.niter+1):
104
+ optimizer_spn.zero_grad()
105
+ try:
106
+ content,_ = content_loader.next()
107
+ except IOError:
108
+ content,_ = content_loader.next()
109
+ except StopIteration:
110
+ content_loader = iter(content_loader_)
111
+ content,_ = content_loader.next()
112
+ except:
113
+ continue
114
+
115
+ contentV.resize_(content.size()).copy_(content)
116
+
117
+ # forward
118
+ cF = vgg(contentV)
119
+ transfer = dec(cF['r41'])
120
+
121
+
122
+ propagated = spn(transfer,contentV)
123
+ loss = criterion(propagated,contentV)
124
+
125
+ # backward & optimization
126
+ loss.backward()
127
+ #nn.utils.clip_grad_norm(spn.parameters(), 1000)
128
+ optimizer_spn.step()
129
+ print('Iteration: [%d/%d] Loss: %.4f Learng Rate is %.6f'
130
+ %(opt.niter,iteration,loss,optimizer_spn.param_groups[0]['lr']))
131
+
132
+ adjust_learning_rate(optimizer_spn,iteration)
133
+
134
+ if((iteration) % opt.log_interval == 0):
135
+ transfer = transfer.clamp(0,1)
136
+ propagated = propagated.clamp(0,1)
137
+ vutils.save_image(transfer,'%s/%d_transfer.png'%(opt.outf,iteration))
138
+ vutils.save_image(propagated,'%s/%d_propagated.png'%(opt.outf,iteration))
139
+
140
+ if(iteration > 0 and (iteration) % opt.save_interval == 0):
141
+ torch.save(spn.state_dict(), '%s/%s_spn.pth' % (opt.outf,opt.layer))
graph_networks/LinearStyleTransfer/__init__.py ADDED
File without changes
graph_networks/LinearStyleTransfer/libs/.DS_Store ADDED
Binary file (6.15 kB). View file
 
graph_networks/LinearStyleTransfer/libs/Criterion.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class styleLoss(nn.Module):
    """Style loss between two feature maps.

    Sum-reduced MSE of the per-channel means plus sum-reduced MSE of the
    normalised Gram matrices, divided by the batch size.
    """

    def forward(self, input, target):
        ib, ic, ih, iw = input.size()
        iF = input.view(ib, ic, -1)
        iMean = torch.mean(iF, dim=2)
        iCov = GramMatrix()(input)

        tb, tc, th, tw = target.size()
        tF = target.view(tb, tc, -1)
        tMean = torch.mean(tF, dim=2)
        tCov = GramMatrix()(target)

        # reduction='sum' replaces the deprecated size_average=False.
        loss = nn.functional.mse_loss(iMean, tMean, reduction='sum') \
             + nn.functional.mse_loss(iCov, tCov, reduction='sum')
        return loss / tb
18
+
19
class GramMatrix(nn.Module):
    """Normalised Gram matrix of a (b, c, h, w) feature map.

    Entry (i, j) of the (b, c, c) output is the inner product of
    channels i and j over all spatial positions, divided by c*h*w.
    """

    def forward(self, input):
        batch, channels, height, width = input.size()
        # Flatten each channel's spatial grid: (b, c, h*w).
        flat = input.view(batch, channels, height * width)
        # Batched (b,c,hw) x (b,hw,c) -> (b,c,c) channel correlations.
        gram = torch.bmm(flat, flat.transpose(1, 2))
        return gram.div_(channels * height * width)
27
+
28
class LossCriterion(nn.Module):
    """Weighted perceptual criterion combining content and style losses.

    forward(tF, sF, cF) takes dicts of features (transfer, style,
    content) keyed by layer name and returns
    (total, weighted style loss, weighted content loss).
    """

    def __init__(self, style_layers, content_layers, style_weight, content_weight):
        super(LossCriterion, self).__init__()

        self.style_layers = style_layers
        self.content_layers = content_layers
        self.style_weight = style_weight
        self.content_weight = content_weight

        # One loss module per layer (modules are stateless, so sharing
        # a single instance across layers is fine).
        self.styleLosses = [styleLoss()] * len(style_layers)
        self.contentLosses = [nn.MSELoss()] * len(content_layers)

    def forward(self, tF, sF, cF):
        # Content term: MSE against detached content features.
        totalContentLoss = 0
        for idx, layer in enumerate(self.content_layers):
            reference = cF[layer].detach()
            totalContentLoss = totalContentLoss + self.contentLosses[idx](tF[layer], reference)
        totalContentLoss = totalContentLoss * self.content_weight

        # Style term: styleLoss against detached style features.
        totalStyleLoss = 0
        for idx, layer in enumerate(self.style_layers):
            reference = sF[layer].detach()
            totalStyleLoss = totalStyleLoss + self.styleLosses[idx](tF[layer], reference)
        totalStyleLoss = totalStyleLoss * self.style_weight

        loss = totalStyleLoss + totalContentLoss
        return loss, totalStyleLoss, totalContentLoss
graph_networks/LinearStyleTransfer/libs/Loader.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import torch.utils.data as data
4
+ import torchvision.transforms as transforms
5
+
6
def is_image_file(filename):
    """True if *filename* ends with a supported (lower-case) image extension."""
    return filename.endswith((".png", ".jpg", ".jpeg"))
8
+
9
def default_loader(path):
    """Open the image at *path* and force 3-channel RGB mode."""
    img = Image.open(path)
    return img.convert('RGB')
11
+
12
class Dataset(data.Dataset):
    """Flat folder of images, returned as (tensor, basename) pairs.

    Train mode (test=False): resize -> random crop -> random horizontal flip.
    Test mode: resize only.  `loadSize` is accepted for API compatibility but
    unused (as in the original).  `video` is accepted for API compatibility;
    the list is always sorted, which already gives video-frame ordering.
    """

    def __init__(self, dataPath, loadSize, fineSize, test=False, video=False):
        super(Dataset, self).__init__()
        self.dataPath = dataPath
        # Sort once for deterministic ordering (the original sorted twice
        # when video=True, which was redundant).
        self.image_list = sorted(x for x in os.listdir(dataPath) if is_image_file(x))
        if not test:
            self.transform = transforms.Compose([
                transforms.Resize(fineSize),
                transforms.RandomCrop(fineSize),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()])
        else:
            self.transform = transforms.Compose([
                transforms.Resize(fineSize),
                transforms.ToTensor()])
        self.test = test

    def __getitem__(self, index):
        imgPath = os.path.join(self.dataPath, self.image_list[index])
        ImgA = self.transform(default_loader(imgPath))
        # splitext (not split('.')) so names containing dots are not truncated.
        imgName = os.path.splitext(self.image_list[index])[0]
        return ImgA, imgName

    def __len__(self):
        return len(self.image_list)
graph_networks/LinearStyleTransfer/libs/LoaderPhotoReal.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torchvision.transforms as transforms
3
+ import torchvision.utils as vutils
4
+ import torch.utils.data as data
5
+ from os import listdir
6
+ from os.path import join
7
+ import numpy as np
8
+ import torch
9
+ import os
10
+ import torch.nn as nn
11
+ from torch.autograd import Variable
12
+ import numpy as np
13
+ from libs.utils import whiten
14
+
15
def is_image_file(filename):
    """Return True when *filename* carries one of the accepted image suffixes."""
    supported = (".png", ".jpg", ".jpeg")
    return filename.endswith(supported)
17
+
18
def default_loader(path, fineSize):
    """Load an RGB image and resize it so the shorter side equals `fineSize`.

    The longer side is scaled proportionally and floored to a multiple of 8
    (the shorter side is left exactly at `fineSize`, as in the original).
    """
    img = Image.open(path).convert('RGB')
    w, h = img.size
    if w < h:
        new_w = fineSize
        new_h = int(h * new_w / w / 8) * 8
    else:
        new_h = fineSize
        new_w = int(w * new_h / h / 8) * 8
    return img.resize((new_w, new_h))
31
+
32
def MaskHelper(seg, color):
    """Binary float mask of the pixels in `seg` matching a color code.

    `seg` is a 3xHxW RGB tensor with values in [0, 1].  A channel "matches"
    when it is < 0.1 (low) or > 0.9 (high), per the table below.  Unrecognized
    colors print a warning and return an empty float tensor.
    """
    LO, HI = 'lo', 'hi'
    # Expected (R, G, B) level for each supported color.
    levels = {
        'green':     (LO, HI, LO),
        'black':     (LO, LO, LO),
        'white':     (HI, HI, HI),
        'red':       (HI, LO, LO),
        'blue':      (LO, LO, HI),
        'yellow':    (HI, HI, LO),
        # NOTE(review): 'grey' uses the same thresholds as 'black' in the
        # original code; preserved verbatim.
        'grey':      (LO, LO, LO),
        'lightblue': (LO, HI, HI),
        'purple':    (HI, LO, HI),
    }
    if color not in levels:
        print('MaskHelper(): color not recognized, color = ' + color)
        return torch.Tensor().float()
    mask = None
    for channel, level in enumerate(levels[color]):
        if level == LO:
            test = torch.lt(seg[channel], 0.1)
        else:
            test = torch.gt(seg[channel], 1 - 0.1)
        mask = test if mask is None else torch.mul(mask, test)
    return mask.float()
74
+
75
def ExtractMask(Seg):
    """Split a color-coded segmentation image into one binary mask per color.

    Returns the masks in the fixed order below, so index i always refers to
    the same color for both content and style segmentations.
    """
    color_codes = ['blue', 'green', 'black', 'white', 'red', 'yellow',
                   'grey', 'lightblue', 'purple']
    return [MaskHelper(Seg, color) for color in color_codes]
91
+
92
def calculate_size(h, w, fineSize):
    """Scale (h, w) so the larger side becomes `fineSize`, then floor both
    sides to a multiple of 8 (required by the network's downsampling).

    Returns (new_w, new_h).
    """
    if h > w:
        new_h = fineSize
        new_w = int(w * 1.0 * new_h / h)
    else:
        new_w = fineSize
        new_h = int(h * 1.0 * new_w / w)
    return (new_w // 8) * 8, (new_h // 8) * 8
102
+
103
class Dataset(data.Dataset):
    """Paired content/style images plus optional color-coded segmentations.

    __getitem__ returns (content, style, whitened-content, content_masks,
    style_masks, image name).  Images and segmentations are resized so the
    scaled size is a multiple of 8.  Missing segmentations fall back to an
    all-black image (a single 'black' region).
    """

    def __init__(self, contentPath, stylePath, contentSegPath, styleSegPath, fineSize):
        super(Dataset, self).__init__()
        self.contentPath = contentPath
        self.image_list = [x for x in listdir(contentPath) if is_image_file(x)]
        self.stylePath = stylePath
        self.contentSegPath = contentSegPath
        self.styleSegPath = styleSegPath
        self.fineSize = fineSize

    def __getitem__(self, index):
        contentImgPath = os.path.join(self.contentPath, self.image_list[index])
        styleImgPath = os.path.join(self.stylePath, self.image_list[index])
        contentImg = default_loader(contentImgPath, self.fineSize)
        styleImg = default_loader(styleImgPath, self.fineSize)

        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; any load failure keeps the black-mask fallback.
        try:
            contentSegImgPath = os.path.join(self.contentSegPath, self.image_list[index])
            contentSegImg = default_loader(contentSegImgPath, self.fineSize)
        except Exception:
            print('no mask provided, fake a whole black one')
            contentSegImg = Image.new('RGB', contentImg.size)

        try:
            styleSegImgPath = os.path.join(self.styleSegPath, self.image_list[index])
            styleSegImg = default_loader(styleSegImgPath, self.fineSize)
        except Exception:
            print('no mask provided, fake a whole black one')
            styleSegImg = Image.new('RGB', styleImg.size)

        # NOTE(review): PIL's Image.size is (w, h) but calculate_size takes
        # (h, w); the original passed the tuple unchanged — kept as-is.
        hs, ws = styleImg.size
        newhs, newws = calculate_size(hs, ws, self.fineSize)
        transform = transforms.Compose([
            transforms.Resize((newhs, newws)),
            transforms.ToTensor()])
        # Turn the segmentation image into tensors alongside the image itself.
        styleSegImg = transform(styleSegImg)
        styleImgArbi = transform(styleImg)

        hc, wc = contentImg.size
        newhc, newwc = calculate_size(hc, wc, self.fineSize)
        transform = transforms.Compose([
            transforms.Resize((newhc, newwc)),
            transforms.ToTensor()])
        contentSegImg = transform(contentSegImg)
        contentImgArbi = transform(contentImg)

        # Per-color binary masks for region-wise style transfer.
        content_masks = ExtractMask(contentSegImg)
        style_masks = ExtractMask(styleSegImg)

        # Whitened content pixels (double precision for numerical stability).
        ImgW = whiten(contentImgArbi.view(3, -1).double())
        ImgW = ImgW.view(contentImgArbi.size()).float()

        return (contentImgArbi.squeeze(0), styleImgArbi.squeeze(0), ImgW,
                content_masks, style_masks, self.image_list[index])

    def __len__(self):
        return len(self.image_list)
graph_networks/LinearStyleTransfer/libs/Matrix.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class CNN(nn.Module):
    """Predict a flattened matrixSize x matrixSize transform from VGG features.

    The conv stack compresses the channels down to `matrixSize`; the fc layer
    maps the normalized Gram matrix of the compressed features to the output.
    Attribute names (`convs`, `fc`) are kept for checkpoint compatibility.
    """

    def __init__(self, layer, matrixSize=32):
        super(CNN, self).__init__()
        if layer == 'r31':
            # Input features: 256 x 64 x 64.
            self.convs = nn.Sequential(
                nn.Conv2d(256, 128, 3, 1, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 64, 3, 1, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, matrixSize, 3, 1, 1))
        elif layer == 'r41':
            # Input features: 512 x 32 x 32.
            self.convs = nn.Sequential(
                nn.Conv2d(512, 256, 3, 1, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 128, 3, 1, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, matrixSize, 3, 1, 1))
        self.fc = nn.Linear(matrixSize * matrixSize, matrixSize * matrixSize)

    def forward(self, x):
        feat = self.convs(x)
        b, c, h, w = feat.size()
        flat = feat.view(b, c, -1)
        # Gram matrix of the compressed features, normalized by spatial size.
        gram = torch.bmm(flat, flat.transpose(1, 2)).div(h * w)
        return self.fc(gram.view(b, -1))
36
+
37
+ class MulLayer(nn.Module):
38
+ def __init__(self,layer,matrixSize=32):
39
+ super(MulLayer,self).__init__()
40
+ self.snet = CNN(layer,matrixSize)
41
+ self.cnet = CNN(layer,matrixSize)
42
+ self.matrixSize = matrixSize
43
+
44
+ if(layer == 'r41'):
45
+ self.compress = nn.Conv2d(512,matrixSize,1,1,0)
46
+ self.unzip = nn.Conv2d(matrixSize,512,1,1,0)
47
+ elif(layer == 'r31'):
48
+ self.compress = nn.Conv2d(256,matrixSize,1,1,0)
49
+ self.unzip = nn.Conv2d(matrixSize,256,1,1,0)
50
+ self.transmatrix = None
51
+
52
+ def forward(self,cF,sF,trans=True):
53
+ cFBK = cF.clone()
54
+ cb,cc,ch,cw = cF.size()
55
+ cFF = cF.view(cb,cc,-1)
56
+ cMean = torch.mean(cFF,dim=2,keepdim=True)
57
+ cMean = cMean.unsqueeze(3)
58
+ cMean = cMean.expand_as(cF)
59
+ cF = cF - cMean
60
+
61
+ sb,sc,sh,sw = sF.size()
62
+ sFF = sF.view(sb,sc,-1)
63
+ sMean = torch.mean(sFF,dim=2,keepdim=True)
64
+ sMean = sMean.unsqueeze(3)
65
+ sMeanC = sMean.expand_as(cF)
66
+ sMeanS = sMean.expand_as(sF)
67
+ sF = sF - sMeanS
68
+
69
+
70
+ compress_content = self.compress(cF)
71
+ b,c,h,w = compress_content.size()
72
+ compress_content = compress_content.view(b,c,-1)
73
+
74
+ if(trans):
75
+ cMatrix = self.cnet(cF)
76
+ sMatrix = self.snet(sF)
77
+
78
+ sMatrix = sMatrix.view(sMatrix.size(0),self.matrixSize,self.matrixSize)
79
+ cMatrix = cMatrix.view(cMatrix.size(0),self.matrixSize,self.matrixSize)
80
+ transmatrix = torch.bmm(sMatrix,cMatrix)
81
+ print(cMatrix)
82
+ transfeature = torch.bmm(transmatrix,compress_content).view(b,c,h,w)
83
+ out = self.unzip(transfeature.view(b,c,h,w))
84
+ out = out + sMeanC
85
+ return out, transmatrix
86
+ else:
87
+ out = self.unzip(compress_content.view(b,c,h,w))
88
+ out = out + cMean
89
+ return out
graph_networks/LinearStyleTransfer/libs/MatrixTest.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import cv2
6
+ from torch.autograd import Variable
7
+ import torchvision.utils as vutils
8
+
9
+
10
class CNN(nn.Module):
    """Mask-aware variant of the transform-predicting CNN (test-time only).

    Instead of one global transform, predicts one matrixSize x matrixSize
    transform per color-coded mask region.
    NOTE(review): requires CUDA (`.cuda()` on index tensors) and cv2;
    processes a single image at a time (features are flattened to (c, h*w)).
    """
    def __init__(self,layer,matrixSize=32):
        super(CNN,self).__init__()
        # 256x64x64
        if(layer == 'r31'):
            self.convs = nn.Sequential(nn.Conv2d(256,128,3,1,1),
                                       nn.ReLU(inplace=True),
                                       nn.Conv2d(128,64,3,1,1),
                                       nn.ReLU(inplace=True),
                                       nn.Conv2d(64,matrixSize,3,1,1))
        elif(layer == 'r41'):
            # 512x32x32
            self.convs = nn.Sequential(nn.Conv2d(512,256,3,1,1),
                                       nn.ReLU(inplace=True),
                                       nn.Conv2d(256,128,3,1,1),
                                       nn.ReLU(inplace=True),
                                       nn.Conv2d(128,matrixSize,3,1,1))
        # NOTE(review): hard-coded 32*32 here, unlike Matrix.py's CNN which
        # uses matrixSize*matrixSize.
        self.fc = nn.Linear(32*32,32*32)

    def forward(self,x,masks,style=False):
        """Return ({region index -> flattened transform}, mean-subtracted features).

        `style` is accepted but currently unused in the body.
        """
        # Number of supported mask colors (see ExtractMask / MaskHelper).
        color_code_number = 9
        xb,xc,xh,xw = x.size()
        x = x.view(xc,-1)
        feature_sub_mean = x.clone()
        # Pass 1: zero-center the features of every sufficiently large region.
        for i in range(color_code_number):
            mask = masks[i].clone().squeeze(0)
            # Resize the mask to the feature resolution (nearest keeps it binary).
            mask = cv2.resize(mask.numpy(),(xw,xh),interpolation=cv2.INTER_NEAREST)
            mask = torch.FloatTensor(mask)
            mask = mask.long()
            # Skip regions smaller than 10 pixels.
            if(torch.sum(mask) >= 10):
                mask = mask.view(-1)

                # dilation here
                """
                kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(5,5))
                mask = mask.cpu().numpy()
                mask = cv2.dilate(mask.astype(np.float32), kernel)
                mask = torch.from_numpy(mask)
                mask = mask.squeeze()
                """

                fgmask = (mask>0).nonzero().squeeze(1)
                fgmask = fgmask.cuda()  # NOTE(review): assumes CUDA is available
                selectFeature = torch.index_select(x,1,fgmask) # 32x96
                # subtract mean
                f_mean = torch.mean(selectFeature,1)
                f_mean = f_mean.unsqueeze(1).expand_as(selectFeature)
                selectFeature = selectFeature - f_mean
                feature_sub_mean.index_copy_(1,fgmask,selectFeature)

        feature = self.convs(feature_sub_mean.view(xb,xc,xh,xw))
        # 32x16x16
        b,c,h,w = feature.size()
        transMatrices = {}
        feature = feature.view(c,-1)

        # Pass 2: per-region covariance -> per-region transform via the fc layer.
        for i in range(color_code_number):
            mask = masks[i].clone().squeeze(0)
            mask = cv2.resize(mask.numpy(),(w,h),interpolation=cv2.INTER_NEAREST)
            mask = torch.FloatTensor(mask)
            mask = mask.long()
            if(torch.sum(mask) >= 10):
                mask = mask.view(-1)
                fgmask = Variable((mask==1).nonzero().squeeze(1))
                fgmask = fgmask.cuda()
                selectFeature = torch.index_select(feature,1,fgmask) # 32x96
                tc,tN = selectFeature.size()

                # Covariance of the region's compressed features.
                covMatrix = torch.mm(selectFeature,selectFeature.transpose(0,1)).div(tN)
                transmatrix = self.fc(covMatrix.view(-1))
                transMatrices[i] = transmatrix
        return transMatrices,feature_sub_mean
82
+
83
class MulLayer(nn.Module):
    """Mask-aware linear style transform (test-time, single image).

    Applies a separately learned transform per color-coded region, copying
    the styled features and style means back region by region.
    NOTE(review): CUDA-only (`.cuda()` / `.cuda(0)`), depends on cv2.
    """
    def __init__(self,layer,matrixSize=32):
        super(MulLayer,self).__init__()
        self.snet = CNN(layer)   # per-region style matrices
        self.cnet = CNN(layer)   # per-region content matrices
        self.matrixSize = matrixSize

        if(layer == 'r41'):
            self.compress = nn.Conv2d(512,matrixSize,1,1,0)
            self.unzip = nn.Conv2d(matrixSize,512,1,1,0)
        elif(layer == 'r31'):
            self.compress = nn.Conv2d(256,matrixSize,1,1,0)
            self.unzip = nn.Conv2d(matrixSize,256,1,1,0)

    def forward(self,cF,sF,cmasks,smasks):
        """Style-transform content features `cF` toward `sF`, region-wise.

        `cmasks`/`smasks` are the per-color mask lists from ExtractMask.
        """
        sb,sc,sh,sw = sF.size()

        # Per-region transforms plus globally mean-subtracted features.
        sMatrices,sF_sub_mean = self.snet(sF,smasks,style=True)
        cMatrices,cF_sub_mean = self.cnet(cF,cmasks,style=False)

        compress_content = self.compress(cF_sub_mean.view(cF.size()))
        cb,cc,ch,cw = compress_content.size()
        compress_content = compress_content.view(cc,-1)
        transfeature = compress_content.clone()
        color_code_number = 9
        # Accumulates the style mean for every content pixel, region by region.
        finalSMean = Variable(torch.zeros(cF.size()).cuda(0))
        finalSMean = finalSMean.view(sc,-1)
        for i in range(color_code_number):
            cmask = cmasks[i].clone().squeeze(0)
            smask = smasks[i].clone().squeeze(0)

            # Resize both masks to the (compressed) feature resolutions.
            cmask = cv2.resize(cmask.numpy(),(cw,ch),interpolation=cv2.INTER_NEAREST)
            cmask = torch.FloatTensor(cmask)
            cmask = cmask.long()
            smask = cv2.resize(smask.numpy(),(sw,sh),interpolation=cv2.INTER_NEAREST)
            smask = torch.FloatTensor(smask)
            smask = smask.long()
            # Only transfer a region when it is present (>= 10 px) in BOTH
            # images and both networks produced a transform for it.
            if(torch.sum(cmask) >= 10 and torch.sum(smask) >= 10
                    and (i in sMatrices) and (i in cMatrices)):
                cmask = cmask.view(-1)
                fgcmask = Variable((cmask==1).nonzero().squeeze(1))
                fgcmask = fgcmask.cuda()

                smask = smask.view(-1)
                fgsmask = Variable((smask==1).nonzero().squeeze(1))
                fgsmask = fgsmask.cuda()

                # Mean of the style features inside this region.
                sFF = sF.view(sc,-1)
                sFF_select = torch.index_select(sFF,1,fgsmask)
                sMean = torch.mean(sFF_select,dim=1,keepdim=True)
                sMean = sMean.view(1,sc,1,1)
                sMean = sMean.expand_as(cF)

                sMatrix = sMatrices[i]
                cMatrix = cMatrices[i]

                sMatrix = sMatrix.view(self.matrixSize,self.matrixSize)
                cMatrix = cMatrix.view(self.matrixSize,self.matrixSize)

                transmatrix = torch.mm(sMatrix,cMatrix) # (C*C)

                # Transform only the pixels of this content region.
                compress_content_select = torch.index_select(compress_content,1,fgcmask)

                transfeatureFG = torch.mm(transmatrix,compress_content_select)
                transfeature.index_copy_(1,fgcmask,transfeatureFG)

                # Write this region's style mean at the content-region positions.
                sMean = sMean.contiguous()
                sMean_select = torch.index_select(sMean.view(sc,-1),1,fgcmask)
                finalSMean.index_copy_(1,fgcmask,sMean_select)
        out = self.unzip(transfeature.view(cb,cc,ch,cw))
        return out + finalSMean.view(out.size())
graph_networks/LinearStyleTransfer/libs/SPN.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision.models import vgg16
4
+ from torch.autograd import Variable
5
+ from collections import OrderedDict
6
+ import torch.nn.functional as F
7
+ import sys
8
+ sys.path.append('../')
9
+ from libs.pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind
10
+
11
class spn_block(nn.Module):
    """One directional spatial-propagation pass (e.g. left-to-right).

    Wraps GateRecurrent2dnoind; `horizontal` selects the scan axis and
    `reverse` the scan direction.
    """
    def __init__(self, horizontal, reverse):
        super(spn_block, self).__init__()
        self.propagator = GateRecurrent2dnoind(horizontal,reverse)

    def forward(self,x,G1,G2,G3):
        """Propagate `x` using three gate maps G1..G3 (one per neighbor)."""
        # Stability requires |G1|+|G2|+|G3| <= 1 at every position; normalize
        # only where the sum exceeds 1, leaving already-stable gates untouched.
        sum_abs = G1.abs() + G2.abs() + G3.abs()
        sum_abs.data[sum_abs.data == 0] = 1e-6  # avoid division by zero
        mask_need_norm = sum_abs.ge(1)
        mask_need_norm = mask_need_norm.float()
        G1_norm = torch.div(G1, sum_abs)
        G2_norm = torch.div(G2, sum_abs)
        G3_norm = torch.div(G3, sum_abs)

        # Blend: (1 - mask) * raw gate + mask * normalized gate.
        G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm
        G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm
        G3 = torch.add(-mask_need_norm, 1) * G3 + mask_need_norm * G3_norm

        return self.propagator(x,G1,G2,G3)
30
+
31
class VGG(nn.Module):
    """Small VGG-like guidance encoder returning pre-activation feature maps.

    Output dict: 'conv1' (full res), 'conv2' (1/2), 'conv3' (1/4),
    'pool3' (1/8, after ReLU+pool), 'conv4' (1/8).
    Attribute names are kept for checkpoint compatibility.
    """

    def __init__(self, nf):
        super(VGG, self).__init__()
        self.conv1 = nn.Conv2d(3, nf, 3, padding=1)           # full resolution
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(nf, nf * 2, 3, padding=1)      # 1/2 resolution
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(nf * 2, nf * 4, 3, padding=1)  # 1/4 resolution
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(nf * 4, nf * 8, 3, padding=1)  # 1/8 resolution

    def forward(self, x):
        output = {}
        # Convolution outputs are stored BEFORE the ReLU (used as skip inputs).
        output['conv1'] = self.conv1(x)
        h = self.pool1(F.relu(output['conv1']))
        output['conv2'] = self.conv2(h)
        h = self.pool2(F.relu(output['conv2']))
        output['conv3'] = self.conv3(h)
        output['pool3'] = self.pool3(F.relu(output['conv3']))
        output['conv4'] = self.conv4(output['pool3'])
        return output
62
+
63
class Decoder(nn.Module):
    """Decode guidance features into gate maps for the SPN.

    Mirrors the VGG encoder with additive skip connections; emits nf*12
    channels when spn == 1 (3 gates x 4 directions) and nf*24 otherwise.
    Attribute names are kept for checkpoint compatibility.
    """

    def __init__(self, nf=32, spn=1):
        super(Decoder, self).__init__()
        # 1/8 resolution input.
        self.layer0 = nn.Conv2d(nf * 8, nf * 4, 1, 1, 0)  # edge_conv5
        self.layer1 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.layer2 = nn.Sequential(nn.Conv2d(nf * 4, nf * 4, 3, 1, 1),  # edge_conv8
                                    nn.ELU(inplace=True))
        self.layer3 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.layer4 = nn.Sequential(nn.Conv2d(nf * 4, nf * 2, 3, 1, 1),
                                    nn.ELU(inplace=True))
        self.layer5 = nn.Upsample(scale_factor=2, mode='bilinear')
        self.layer6 = nn.Sequential(nn.Conv2d(nf * 2, nf, 3, 1, 1),
                                    nn.ELU(inplace=True))
        self.layer7 = nn.Conv2d(nf, nf * 12 if spn == 1 else nf * 24, 3, 1, 1)
        self.spn = spn

    def forward(self, encode_feature):
        # Additive skip connections with the encoder's pre-activation maps.
        h = self.layer1(self.layer0(encode_feature['conv4']))
        h = self.layer2(h) + encode_feature['conv3']               # 1/4 res
        h = self.layer4(self.layer3(h)) + encode_feature['conv2']  # 1/2 res
        h = self.layer6(self.layer5(h)) + encode_feature['conv1']  # full res
        return self.layer7(h)
107
+
108
+
109
class SPN(nn.Module):
    """Spatial propagation network: refine a mask `x` guided by an image `rgb`.

    Four directional propagation passes are fused with an element-wise max,
    then projected back to 3 channels.
    """

    def __init__(self, nf=32, spn=1):
        super(SPN, self).__init__()
        # Feature extractor for the mask input.
        self.mask_conv = nn.Conv2d(3, nf, 3, 1, 1)
        # Guidance network (encoder/decoder over the RGB image).
        self.encoder = VGG(nf)
        self.decoder = Decoder(nf, spn)
        # One propagation block per scan direction.
        self.left_right = spn_block(True, False)
        self.right_left = spn_block(True, True)
        self.top_down = spn_block(False, False)
        self.down_top = spn_block(False, True)
        # Final projection back to an RGB-shaped output.
        self.post = nn.Conv2d(nf, 3, 3, 1, 1)
        self.nf = nf

    def forward(self, x, rgb):
        X = self.mask_conv(x)                    # mask features
        guide = self.decoder(self.encoder(rgb))  # gate maps from the guidance image
        # 12 groups of nf channels: 3 gates for each of the 4 directions.
        G = torch.split(guide, self.nf, 1)
        results = [
            self.left_right(X, G[0], G[1], G[2]),
            self.right_left(X, G[3], G[4], G[5]),
            self.top_down(X, G[6], G[7], G[8]),
            self.down_top(X, G[9], G[10], G[11]),
        ]
        # Fuse the four directional passes with an element-wise maximum.
        fused = results[0]
        for r in results[1:]:
            fused = torch.max(fused, r)
        return self.post(fused)
148
+
149
if __name__ == '__main__':
    # Smoke test: repeatedly push uninitialized 256x256 tensors through the
    # network and report the output size.  NOTE(review): requires CUDA.
    spn = SPN().cuda()
    for _ in range(100):
        x = Variable(torch.Tensor(1, 3, 256, 256)).cuda()
        rgb = Variable(torch.Tensor(1, 3, 256, 256)).cuda()
        print(spn(x, rgb).size())
graph_networks/LinearStyleTransfer/libs/__init__.py ADDED
File without changes
graph_networks/LinearStyleTransfer/libs/models.py ADDED
@@ -0,0 +1,662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class encoder3(nn.Module):
    """VGG-style encoder up to relu3_1 (reflection-padded convolutions).

    Layer/attribute names are kept verbatim so pretrained checkpoints that
    load parameters by name keep working.
    """

    def __init__(self):
        super(encoder3, self).__init__()
        # 1x1 learned RGB re-mix (VGG preprocessing layer).
        self.conv1 = nn.Conv2d(3, 3, 1, 1, 0)
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv2 = nn.Conv2d(3, 64, 3, 1, 0)
        self.relu2 = nn.ReLU(inplace=True)
        self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4 = nn.Conv2d(64, 128, 3, 1, 0)
        self.relu4 = nn.ReLU(inplace=True)
        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu5 = nn.ReLU(inplace=True)
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = nn.Conv2d(128, 256, 3, 1, 0)
        self.relu6 = nn.ReLU(inplace=True)

    def forward(self, x):
        h = self.relu2(self.conv2(self.reflecPad1(self.conv1(x))))
        h = self.relu3(self.conv3(self.reflecPad3(h)))
        h, _ = self.maxPool(h)   # pooling indices are discarded
        h = self.relu4(self.conv4(self.reflecPad4(h)))
        h = self.relu5(self.conv5(self.reflecPad5(h)))
        h, _ = self.maxPool2(h)
        h = self.relu6(self.conv6(self.reflecPad6(h)))
        return h
62
+
63
class decoder3(nn.Module):
    """Mirror of encoder3: decode relu3_1-level features back to an RGB image.

    Attribute names are kept verbatim for pretrained-checkpoint compatibility.
    """

    def __init__(self):
        super(decoder3, self).__init__()
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = nn.Conv2d(256, 128, 3, 1, 0)
        self.relu7 = nn.ReLU(inplace=True)
        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv8 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu8 = nn.ReLU(inplace=True)
        self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv9 = nn.Conv2d(128, 64, 3, 1, 0)
        self.relu9 = nn.ReLU(inplace=True)
        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
        self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv10 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu10 = nn.ReLU(inplace=True)
        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = nn.Conv2d(64, 3, 3, 1, 0)

    def forward(self, x):
        # (Removed an unused `output = {}` dead local from the original.)
        out = self.relu7(self.conv7(self.reflecPad7(x)))
        out = self.unpool(out)                       # 2x nearest upsample
        out = self.relu8(self.conv8(self.reflecPad8(out)))
        out = self.relu9(self.conv9(self.reflecPad9(out)))
        out = self.unpool2(out)
        out = self.relu10(self.conv10(self.reflecPad10(out)))
        out = self.conv11(self.reflecPad11(out))     # no final activation
        return out
113
+
114
class encoder4(nn.Module):
    """VGG-style encoder up to relu4_1, exposing intermediate feature maps.

    forward() returns a dict of named activations ('r11'..'r41', 'p1'..'p3').
    When `matrix31` is given, the r31 features are style-transformed in place
    before the remaining layers run.
    Attribute names are kept for pretrained-checkpoint compatibility.
    """
    def __init__(self):
        super(encoder4,self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3,3,1,1,0)  # 1x1 learned RGB re-mix
        self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
        # 226 x 226

        self.conv2 = nn.Conv2d(3,64,3,1,0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224

        self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
        self.conv3 = nn.Conv2d(64,64,3,1,0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224

        self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
        # 112 x 112

        self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
        self.conv4 = nn.Conv2d(64,128,3,1,0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112

        self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
        self.conv5 = nn.Conv2d(128,128,3,1,0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112

        self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
        # 56 x 56

        self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
        self.conv6 = nn.Conv2d(128,256,3,1,0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
        self.conv7 = nn.Conv2d(256,256,3,1,0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
        self.conv8 = nn.Conv2d(256,256,3,1,0)
        self.relu8 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
        self.conv9 = nn.Conv2d(256,256,3,1,0)
        self.relu9 = nn.ReLU(inplace=True)
        # 56 x 56

        self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
        # 28 x 28

        self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
        self.conv10 = nn.Conv2d(256,512,3,1,0)
        self.relu10 = nn.ReLU(inplace=True)
        # 28 x 28
    def forward(self,x,sF=None,matrix11=None,matrix21=None,matrix31=None):
        """Run the encoder; `matrix11`/`matrix21` are accepted but unused here.

        With `matrix31` set, sF['r31'] is required and the returned transform
        matrix from matrix31 is discarded.
        """
        output = {}
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        output['r11'] = self.relu2(out)
        # NOTE(review): reflecPad7 is reused below where reflecPad3/reflecPad5
        # would be expected; all pads are identical (1,1,1,1) so the output is
        # unchanged (padding layers have no parameters).
        out = self.reflecPad7(output['r11'])

        out = self.conv3(out)
        output['r12'] = self.relu3(out)

        output['p1'] = self.maxPool(output['r12'])
        out = self.reflecPad4(output['p1'])
        out = self.conv4(out)
        output['r21'] = self.relu4(out)
        out = self.reflecPad7(output['r21'])

        out = self.conv5(out)
        output['r22'] = self.relu5(out)

        output['p2'] = self.maxPool2(output['r22'])
        out = self.reflecPad6(output['p2'])
        out = self.conv6(out)
        output['r31'] = self.relu6(out)
        if(matrix31 is not None):
            # Style-transform the r31 features; the transform matrix is unused.
            feature3,transmatrix3 = matrix31(output['r31'],sF['r31'])
            out = self.reflecPad7(feature3)
        else:
            out = self.reflecPad7(output['r31'])
        out = self.conv7(out)
        output['r32'] = self.relu7(out)

        out = self.reflecPad8(output['r32'])
        out = self.conv8(out)
        output['r33'] = self.relu8(out)

        out = self.reflecPad9(output['r33'])
        out = self.conv9(out)
        output['r34'] = self.relu9(out)

        output['p3'] = self.maxPool3(output['r34'])
        out = self.reflecPad10(output['p3'])
        out = self.conv10(out)
        output['r41'] = self.relu10(out)

        return output
221
+
222
class decoder4(nn.Module):
    """Mirror of encoder4: decode relu4_1-level features back to an RGB image.

    NOTE(review): this class is re-defined immediately below in the file; this
    first definition is shadowed (dead code) and is kept only for reference.
    Attribute names follow the pretrained-checkpoint naming convention.
    """
    def __init__(self):
        super(decoder4,self).__init__()
        # decoder
        self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
        self.conv11 = nn.Conv2d(512,256,3,1,0)
        self.relu11 = nn.ReLU(inplace=True)
        # 28 x 28

        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
        # 56 x 56

        self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
        self.conv12 = nn.Conv2d(256,256,3,1,0)
        self.relu12 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
        self.conv13 = nn.Conv2d(256,256,3,1,0)
        self.relu13 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
        self.conv14 = nn.Conv2d(256,256,3,1,0)
        self.relu14 = nn.ReLU(inplace=True)
        # 56 x 56

        self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
        self.conv15 = nn.Conv2d(256,128,3,1,0)
        self.relu15 = nn.ReLU(inplace=True)
        # 56 x 56

        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
        # 112 x 112

        self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
        self.conv16 = nn.Conv2d(128,128,3,1,0)
        self.relu16 = nn.ReLU(inplace=True)
        # 112 x 112

        self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
        self.conv17 = nn.Conv2d(128,64,3,1,0)
        self.relu17 = nn.ReLU(inplace=True)
        # 112 x 112

        self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
        # 224 x 224

        self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
        self.conv18 = nn.Conv2d(64,64,3,1,0)
        self.relu18 = nn.ReLU(inplace=True)
        # 224 x 224

        self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
        self.conv19 = nn.Conv2d(64,3,3,1,0)

    def forward(self,x):
        # decoder: pad -> conv -> relu chains with nearest-neighbor upsampling;
        # the final conv19 has no activation (raw RGB output).
        out = self.reflecPad11(x)
        out = self.conv11(out)
        out = self.relu11(out)
        out = self.unpool(out)
        out = self.reflecPad12(out)
        out = self.conv12(out)

        out = self.relu12(out)
        out = self.reflecPad13(out)
        out = self.conv13(out)
        out = self.relu13(out)
        out = self.reflecPad14(out)
        out = self.conv14(out)
        out = self.relu14(out)
        out = self.reflecPad15(out)
        out = self.conv15(out)
        out = self.relu15(out)
        out = self.unpool2(out)
        out = self.reflecPad16(out)
        out = self.conv16(out)
        out = self.relu16(out)
        out = self.reflecPad17(out)
        out = self.conv17(out)
        out = self.relu17(out)
        out = self.unpool3(out)
        out = self.reflecPad18(out)
        out = self.conv18(out)
        out = self.relu18(out)
        out = self.reflecPad19(out)
        out = self.conv19(out)
        return out
311
+
312
# NOTE(review): this is an exact duplicate of the `decoder4` class defined
# earlier in this module; the later definition shadows the earlier one.
# Kept for bit-for-bit compatibility — consider deleting one copy.
class decoder4(nn.Module):
    """Mirror of the 4-level VGG encoder (duplicate definition).

    Maps relu4_1 features (N, 512, H, W) to an RGB image (N, 3, 8H, 8W)
    through three nearest-neighbour x2 upsamplings. Attribute names match the
    pretrained checkpoint's ``state_dict`` keys and must not change.
    """

    def __init__(self):
        super(decoder4, self).__init__()
        # 28 x 28: 512 -> 256 channels
        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = nn.Conv2d(512, 256, 3, 1, 0)
        self.relu11 = nn.ReLU(inplace=True)

        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)  # 28 -> 56

        # 56 x 56: three 256-channel convs, then reduce to 128
        self.reflecPad12 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv12 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu12 = nn.ReLU(inplace=True)

        self.reflecPad13 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv13 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu13 = nn.ReLU(inplace=True)

        self.reflecPad14 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv14 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu14 = nn.ReLU(inplace=True)

        self.reflecPad15 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv15 = nn.Conv2d(256, 128, 3, 1, 0)
        self.relu15 = nn.ReLU(inplace=True)

        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)  # 56 -> 112

        # 112 x 112: 128 -> 128 -> 64 channels
        self.reflecPad16 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv16 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu16 = nn.ReLU(inplace=True)

        self.reflecPad17 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv17 = nn.Conv2d(128, 64, 3, 1, 0)
        self.relu17 = nn.ReLU(inplace=True)

        self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)  # 112 -> 224

        # 224 x 224: final projection to RGB (no output activation)
        self.reflecPad18 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv18 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu18 = nn.ReLU(inplace=True)

        self.reflecPad19 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv19 = nn.Conv2d(64, 3, 3, 1, 0)

    def forward(self, x):
        """Decode relu4_1 features back to an image tensor."""
        pipeline = (
            self.reflecPad11, self.conv11, self.relu11,
            self.unpool,
            self.reflecPad12, self.conv12, self.relu12,
            self.reflecPad13, self.conv13, self.relu13,
            self.reflecPad14, self.conv14, self.relu14,
            self.reflecPad15, self.conv15, self.relu15,
            self.unpool2,
            self.reflecPad16, self.conv16, self.relu16,
            self.reflecPad17, self.conv17, self.relu17,
            self.unpool3,
            self.reflecPad18, self.conv18, self.relu18,
            self.reflecPad19, self.conv19,
        )
        out = x
        for layer in pipeline:
            out = layer(out)
        return out
401
+
402
+ class encoder5(nn.Module):
403
+ def __init__(self):
404
+ super(encoder5,self).__init__()
405
+ # vgg
406
+ # 224 x 224
407
+ self.conv1 = nn.Conv2d(3,3,1,1,0)
408
+ self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
409
+ # 226 x 226
410
+
411
+ self.conv2 = nn.Conv2d(3,64,3,1,0)
412
+ self.relu2 = nn.ReLU(inplace=True)
413
+ # 224 x 224
414
+
415
+ self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
416
+ self.conv3 = nn.Conv2d(64,64,3,1,0)
417
+ self.relu3 = nn.ReLU(inplace=True)
418
+ # 224 x 224
419
+
420
+ self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
421
+ # 112 x 112
422
+
423
+ self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
424
+ self.conv4 = nn.Conv2d(64,128,3,1,0)
425
+ self.relu4 = nn.ReLU(inplace=True)
426
+ # 112 x 112
427
+
428
+ self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
429
+ self.conv5 = nn.Conv2d(128,128,3,1,0)
430
+ self.relu5 = nn.ReLU(inplace=True)
431
+ # 112 x 112
432
+
433
+ self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
434
+ # 56 x 56
435
+
436
+ self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
437
+ self.conv6 = nn.Conv2d(128,256,3,1,0)
438
+ self.relu6 = nn.ReLU(inplace=True)
439
+ # 56 x 56
440
+
441
+ self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
442
+ self.conv7 = nn.Conv2d(256,256,3,1,0)
443
+ self.relu7 = nn.ReLU(inplace=True)
444
+ # 56 x 56
445
+
446
+ self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
447
+ self.conv8 = nn.Conv2d(256,256,3,1,0)
448
+ self.relu8 = nn.ReLU(inplace=True)
449
+ # 56 x 56
450
+
451
+ self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
452
+ self.conv9 = nn.Conv2d(256,256,3,1,0)
453
+ self.relu9 = nn.ReLU(inplace=True)
454
+ # 56 x 56
455
+
456
+ self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
457
+ # 28 x 28
458
+
459
+ self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
460
+ self.conv10 = nn.Conv2d(256,512,3,1,0)
461
+ self.relu10 = nn.ReLU(inplace=True)
462
+
463
+ self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
464
+ self.conv11 = nn.Conv2d(512,512,3,1,0)
465
+ self.relu11 = nn.ReLU(inplace=True)
466
+
467
+ self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
468
+ self.conv12 = nn.Conv2d(512,512,3,1,0)
469
+ self.relu12 = nn.ReLU(inplace=True)
470
+
471
+ self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
472
+ self.conv13 = nn.Conv2d(512,512,3,1,0)
473
+ self.relu13 = nn.ReLU(inplace=True)
474
+
475
+ self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2)
476
+ self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
477
+ self.conv14 = nn.Conv2d(512,512,3,1,0)
478
+ self.relu14 = nn.ReLU(inplace=True)
479
+
480
+ def forward(self,x,sF=None,contentV256=None,styleV256=None,matrix11=None,matrix21=None,matrix31=None):
481
+ output = {}
482
+ out = self.conv1(x)
483
+ out = self.reflecPad1(out)
484
+ out = self.conv2(out)
485
+ output['r11'] = self.relu2(out)
486
+ out = self.reflecPad7(output['r11'])
487
+
488
+ #out = self.reflecPad3(output['r11'])
489
+ out = self.conv3(out)
490
+ output['r12'] = self.relu3(out)
491
+
492
+ output['p1'] = self.maxPool(output['r12'])
493
+ out = self.reflecPad4(output['p1'])
494
+ out = self.conv4(out)
495
+ output['r21'] = self.relu4(out)
496
+ out = self.reflecPad7(output['r21'])
497
+
498
+ #out = self.reflecPad5(output['r21'])
499
+ out = self.conv5(out)
500
+ output['r22'] = self.relu5(out)
501
+
502
+ output['p2'] = self.maxPool2(output['r22'])
503
+ out = self.reflecPad6(output['p2'])
504
+ out = self.conv6(out)
505
+ output['r31'] = self.relu6(out)
506
+ if(styleV256 is not None):
507
+ feature = matrix31(output['r31'],sF['r31'],contentV256,styleV256)
508
+ out = self.reflecPad7(feature)
509
+ else:
510
+ out = self.reflecPad7(output['r31'])
511
+ out = self.conv7(out)
512
+ output['r32'] = self.relu7(out)
513
+
514
+ out = self.reflecPad8(output['r32'])
515
+ out = self.conv8(out)
516
+ output['r33'] = self.relu8(out)
517
+
518
+ out = self.reflecPad9(output['r33'])
519
+ out = self.conv9(out)
520
+ output['r34'] = self.relu9(out)
521
+
522
+ output['p3'] = self.maxPool3(output['r34'])
523
+ out = self.reflecPad10(output['p3'])
524
+ out = self.conv10(out)
525
+ output['r41'] = self.relu10(out)
526
+
527
+ out = self.reflecPad11(output['r41'])
528
+ out = self.conv11(out)
529
+ output['r42'] = self.relu11(out)
530
+
531
+ out = self.reflecPad12(output['r42'])
532
+ out = self.conv12(out)
533
+ output['r43'] = self.relu12(out)
534
+
535
+ out = self.reflecPad13(output['r43'])
536
+ out = self.conv13(out)
537
+ output['r44'] = self.relu13(out)
538
+
539
+ output['p4'] = self.maxPool4(output['r44'])
540
+
541
+ out = self.reflecPad14(output['p4'])
542
+ out = self.conv14(out)
543
+ output['r51'] = self.relu14(out)
544
+ return output
545
+
546
class decoder5(nn.Module):
    """Mirror of the 5-level VGG encoder.

    Maps relu5_1 features (N, 512, H, W) back to an RGB image
    (N, 3, 16H, 16W) via four nearest-neighbour x2 upsamplings.
    Attribute names match the pretrained checkpoint's ``state_dict`` keys.
    """

    def __init__(self):
        super(decoder5, self).__init__()
        # 14 x 14: 512 -> 512
        self.reflecPad15 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv15 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu15 = nn.ReLU(inplace=True)

        self.unpool = nn.UpsamplingNearest2d(scale_factor=2)  # -> 28 x 28

        # 28 x 28: three 512-channel convs, then reduce to 256
        self.reflecPad16 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv16 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu16 = nn.ReLU(inplace=True)

        self.reflecPad17 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv17 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu17 = nn.ReLU(inplace=True)

        self.reflecPad18 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv18 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu18 = nn.ReLU(inplace=True)

        self.reflecPad19 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv19 = nn.Conv2d(512, 256, 3, 1, 0)
        self.relu19 = nn.ReLU(inplace=True)

        self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)  # -> 56 x 56

        # 56 x 56: three 256-channel convs, then reduce to 128
        self.reflecPad20 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv20 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu20 = nn.ReLU(inplace=True)

        self.reflecPad21 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv21 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu21 = nn.ReLU(inplace=True)

        self.reflecPad22 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv22 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu22 = nn.ReLU(inplace=True)

        self.reflecPad23 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv23 = nn.Conv2d(256, 128, 3, 1, 0)
        self.relu23 = nn.ReLU(inplace=True)

        self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)  # -> 112 x 112

        self.reflecPad24 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv24 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu24 = nn.ReLU(inplace=True)

        self.reflecPad25 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv25 = nn.Conv2d(128, 64, 3, 1, 0)
        self.relu25 = nn.ReLU(inplace=True)

        self.unpool4 = nn.UpsamplingNearest2d(scale_factor=2)  # -> 224 x 224

        self.reflecPad26 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv26 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu26 = nn.ReLU(inplace=True)

        # Final RGB projection — deliberately no activation on the output.
        self.reflecPad27 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv27 = nn.Conv2d(64, 3, 3, 1, 0)

    def forward(self, x):
        """Decode relu5_1 features back to an image tensor."""
        pipeline = (
            self.reflecPad15, self.conv15, self.relu15,
            self.unpool,
            self.reflecPad16, self.conv16, self.relu16,
            self.reflecPad17, self.conv17, self.relu17,
            self.reflecPad18, self.conv18, self.relu18,
            self.reflecPad19, self.conv19, self.relu19,
            self.unpool2,
            self.reflecPad20, self.conv20, self.relu20,
            self.reflecPad21, self.conv21, self.relu21,
            self.reflecPad22, self.conv22, self.relu22,
            self.reflecPad23, self.conv23, self.relu23,
            self.unpool3,
            self.reflecPad24, self.conv24, self.relu24,
            self.reflecPad25, self.conv25, self.relu25,
            self.unpool4,
            self.reflecPad26, self.conv26, self.relu26,
            self.reflecPad27, self.conv27,
        )
        out = x
        for layer in pipeline:
            out = layer(out)
        return out
graph_networks/LinearStyleTransfer/libs/pytorch_spn/README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_spn
2
+ To build, install [pytorch](https://github.com/pytorch) and run:
3
+
4
+ $ sh make.sh
5
+
6
+ See left_right_demo.py for usage:
7
+
8
+ $ mv left_right_demo.py ../
9
+
10
+ $ python left_right_demo.py
11
+
12
+ The original code (caffe) and models will be released [HERE](https://github.com/Liusifei/caffe-spn.git).
graph_networks/LinearStyleTransfer/libs/pytorch_spn/__init__.py ADDED
File without changes
graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/__init__.py ADDED
File without changes
graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/gaterecurrent2dnoind/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# NOTE(review): `torch.utils.ffi` was removed in PyTorch 1.0; this wrapper
# only works with the legacy 0.x releases this project targets.
from torch.utils.ffi import _wrap_function
from ._gaterecurrent2dnoind import lib as _lib, ffi as _ffi

__all__ = []


def _import_symbols(namespace):
    """Expose every symbol of the compiled FFI extension in *namespace*.

    Callables are wrapped so they accept torch tensors; plain constants are
    copied through unchanged. Each exported name is appended to ``__all__``.
    """
    for name in dir(_lib):
        obj = getattr(_lib, name)
        namespace[name] = _wrap_function(obj, _ffi) if callable(obj) else obj
        __all__.append(name)


_import_symbols(locals())
graph_networks/LinearStyleTransfer/libs/pytorch_spn/build.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Build script for the legacy `gaterecurrent2dnoind` FFI extension.

NOTE(review): relies on `torch.utils.ffi.create_extension`, which was removed
in PyTorch 1.0 — building requires a 0.x-era PyTorch.

FIX: the original assigned `this_file = os.path.dirname(__file__)` and then
immediately shadowed it with the realpath version; the dead first assignment
is removed.
"""
import os

import torch
from torch.utils.ffi import create_extension

# Absolute directory of this file, used to locate the pre-built CUDA object
# (compiled by make.sh before this script runs).
this_file = os.path.dirname(os.path.realpath(__file__))

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
    print('Including CUDA code.')
    sources += ['src/gaterecurrent2dnoind_cuda.c']
    headers += ['src/gaterecurrent2dnoind_cuda.h']
    defines += [('WITH_CUDA', None)]
    with_cuda = True

extra_objects = ['src/cuda/gaterecurrent2dnoind_kernel.cu.o']
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]

ffi = create_extension(
    '_ext.gaterecurrent2dnoind',
    headers=headers,
    sources=sources,
    define_macros=defines,
    relative_to=__file__,
    with_cuda=with_cuda,
    extra_objects=extra_objects
)

if __name__ == '__main__':
    ffi.build()
graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/__init__.py ADDED
File without changes
graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/gaterecurrent2dnoind.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from torch.autograd import Function
from .._ext import gaterecurrent2dnoind as gaterecurrent2d


class GateRecurrent2dnoindFunction(Function):
    """Autograd wrapper around the CUDA spatial-propagation (SPN) kernels.

    NOTE(review): written against the legacy, stateful autograd Function API
    (state stored on ``self``); modern PyTorch requires a static
    forward/backward with a ``ctx`` object. CUDA-only: the CPU path is not
    implemented and returns 0.
    """

    def __init__(self, horizontal_, reverse_):
        # Propagation direction: horizontal vs. vertical, forward vs. reverse.
        self.horizontal = horizontal_
        self.reverse = reverse_

    def forward(self, X, G1, G2, G3):
        """Propagate X across the map, gated per-pixel by G1..G3."""
        num, channels, height, width = X.size()
        output = torch.zeros(num, channels, height, width)

        if not X.is_cuda:
            # No CPU kernel exists upstream; mirror the original behaviour.
            print("cpu version is not ready at this time")
            return 0

        output = output.cuda()
        gaterecurrent2d.gaterecurrent2dnoind_forward_cuda(
            self.horizontal, self.reverse, X, G1, G2, G3, output)

        # Stash everything backward() needs (legacy-API equivalent of
        # save_for_backward).
        self.X = X
        self.G1 = G1
        self.G2 = G2
        self.G3 = G3
        self.output = output
        self.hiddensize = X.size()
        return output

    def backward(self, grad_output):
        """Compute gradients w.r.t. X and the three gate maps."""
        assert(self.hiddensize is not None and grad_output.is_cuda)
        num, channels, height, width = self.hiddensize

        # One zeroed CUDA buffer per gradient output.
        grad_X, grad_G1, grad_G2, grad_G3 = (
            torch.zeros(num, channels, height, width).cuda()
            for _ in range(4)
        )

        gaterecurrent2d.gaterecurrent2dnoind_backward_cuda(
            self.horizontal, self.reverse, self.output, grad_output,
            self.X, self.G1, self.G2, self.G3,
            grad_X, grad_G1, grad_G2, grad_G3)

        # Drop cached tensors so they can be freed promptly.
        del self.hiddensize
        del self.G1
        del self.G2
        del self.G3
        del self.output
        del self.X

        return grad_X, grad_G1, grad_G2, grad_G3
graph_networks/LinearStyleTransfer/libs/pytorch_spn/left_right_demo.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ An example of left->right propagation
3
+
4
+ Other direction settings:
5
+ left->right: Propagator = GateRecurrent2dnoind(True,False)
6
+ right->left: Propagator = GateRecurrent2dnoind(True,True)
7
+ top->bottom: Propagator = GateRecurrent2dnoind(False,False)
8
+ bottom->top: Propagator = GateRecurrent2dnoind(False,True)
9
+
10
+ X: any signal/feature map to be filtered
11
+ G1~G3: three coefficient maps (e.g., left-top, left-center, left-bottom)
12
+
13
+ Note:
14
+ 1. G1~G3 constitute the affinity, they can be a bounch of output maps coming from any CNN, with the input of any useful known information (e.g., RGB images)
15
+ 2. for any pixel i, |G1(i)| + |G2(i)| + |G3(i)| <= 1 is a sufficent condition for model stability (see paper)
16
+ """
17
+ import torch
18
+ from torch.autograd import Variable
19
+ from pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind
20
+
21
+ Propagator = GateRecurrent2dnoind(True,False)
22
+
23
+ X = Variable(torch.randn(1,3,10,10))
24
+ G1 = Variable(torch.randn(1,3,10,10))
25
+ G2 = Variable(torch.randn(1,3,10,10))
26
+ G3 = Variable(torch.randn(1,3,10,10))
27
+
28
+ sum_abs = G1.abs() + G2.abs() + G3.abs()
29
+ mask_need_norm = sum_abs.ge(1)
30
+ mask_need_norm = mask_need_norm.float()
31
+ G1_norm = torch.div(G1, sum_abs)
32
+ G2_norm = torch.div(G2, sum_abs)
33
+ G3_norm = torch.div(G3, sum_abs)
34
+
35
+ G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm
36
+ G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm
37
+ G3 = torch.add(-mask_need_norm, 1) * G3 + mask_need_norm * G3_norm
38
+
39
+ X = X.cuda()
40
+ G1 = G1.cuda()
41
+ G2 = G2.cuda()
42
+ G3 = G3.cuda()
43
+
44
+ output = Propagator.forward(X,G1,G2,G3)
45
+ print(X)
46
+ print(output)
graph_networks/LinearStyleTransfer/libs/pytorch_spn/make.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Build the gaterecurrent2dnoind extension in two steps:
#   1) compile the CUDA kernel object with nvcc (sm_52 target),
#   2) build the Python FFI wrapper via build.py.

CUDA_PATH=/usr/local/cuda/

cd src/cuda/
echo "Compiling gaterecurrent2dnoind layer kernels by nvcc..."
# -x cu forces CUDA compilation; -fPIC is required for the shared extension.
nvcc -c -o gaterecurrent2dnoind_kernel.cu.o gaterecurrent2dnoind_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
cd ../../
python build.py
graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/gaterecurrent2dnoind.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch.nn as nn
from ..functions.gaterecurrent2dnoind import GateRecurrent2dnoindFunction


class GateRecurrent2dnoind(nn.Module):
    """nn.Module front-end for the SPN gated 2-D propagation op.

    ``horizontal_`` selects row-wise vs. column-wise propagation and
    ``reverse_`` flips the sweep direction.
    """

    def __init__(self, horizontal_, reverse_):
        super(GateRecurrent2dnoind, self).__init__()
        self.horizontal = horizontal_
        self.reverse = reverse_

    def forward(self, X, G1, G2, G3):
        # Legacy autograd API: instantiate a fresh Function per call.
        fn = GateRecurrent2dnoindFunction(self.horizontal, self.reverse)
        return fn(X, G1, G2, G3)
graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/.DS_Store ADDED
Binary file (8.2 kB). View file
 
graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu ADDED
@@ -0,0 +1,697 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifdef __cplusplus
2
+ extern "C" {
3
+ #endif
4
+
5
+ #include <stdio.h>
6
+ #include <math.h>
7
+ #include <float.h>
8
+ #include "gaterecurrent2dnoind_kernel.h"
9
+
10
+ #define CUDA_1D_KERNEL_LOOP(i, n) \
11
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
12
+ i += blockDim.x * gridDim.x)
13
+
14
// Return in `out` the (h, w) coordinates of the pixel that OWNS the gate
// between the two neighbouring pixels (h1, w1) and (h2, w2): the pixel that
// comes later along the propagation direction. The four direction cases are
// mutually exclusive, so an if/else chain is equivalent to the original
// independent ifs.
__device__ void get_gate_idx_sf(int h1, int w1, int h2, int w2, int * out, int horizontal, int reverse)
{
	if (horizontal && !reverse) {           // left -> right: larger w owns it
		if (w1 > w2) { out[0] = h1; out[1] = w1; }
		else         { out[0] = h2; out[1] = w2; }
	} else if (horizontal && reverse) {     // right -> left: smaller w
		if (w1 < w2) { out[0] = h1; out[1] = w1; }
		else         { out[0] = h2; out[1] = w2; }
	} else if (!horizontal && !reverse) {   // top -> bottom: larger h
		if (h1 > h2) { out[0] = h1; out[1] = w1; }
		else         { out[0] = h2; out[1] = w2; }
	} else {                                // bottom -> top: smaller h
		if (h1 < h2) { out[0] = h1; out[1] = w1; }
		else         { out[0] = h2; out[1] = w2; }
	}
}
70
+
71
// Read data[n][c][h][w] from an NCHW buffer; out-of-range coordinates read
// as 0 (implicit zero padding at the borders).
__device__ float get_data_sf(float * data, int num, int channels, int height, int width, int n, int c, int h, int w)
{
	if (h < 0 || h >= height || w < 0 || w >= width)
		return 0;
	return data[((n * channels + c) * height + h) * width + w];
}
80
+
81
// Write v into data[n][c][h][w] of an NCHW buffer; out-of-range writes are
// silently dropped (matches the zero-padded reads in get_data_sf).
__device__ void set_data_sf(float * data, int num, int channels, int height, int width, int n, int c, int h, int w, float v)
{
	if (h < 0 || h >= height || w < 0 || w >= width)
		return;
	data[((n * channels + c) * height + h) * width + w] = v;
}
90
+
91
// Read the gate value shared by the neighbouring pixels (h1,w1) and (h2,w2);
// returns 0 when either coordinate is outside the map. The actual storage
// location is resolved by get_gate_idx_sf.
__device__ float get_gate_sf(float * data, int num, int channels, int height, int width, int n, int c, int h1, int w1, int h2, int w2, int horizontal, int reverse)
{
	if (h1 < 0 || h1 >= height || w1 < 0 || w1 >= width)
		return 0;
	if (h2 < 0 || h2 >= height || w2 < 0 || w2 >= width)
		return 0;

	int idx[2];
	get_gate_idx_sf(h1, w1, h2, w2, idx, horizontal, reverse);
	return data[((n * channels + c) * height + idx[0]) * width + idx[1]];
}
110
+
111
// Write v into the gate slot shared by neighbours (h1,w1)/(h2,w2); ignored
// when either coordinate is outside the map. Storage location resolved by
// get_gate_idx_sf, mirroring get_gate_sf.
__device__ void set_gate_sf(float * data, int num, int channels, int height, int width, int n, int c, int h1, int w1, int h2, int w2, int horizontal, int reverse, float v)
{
	if (h1 < 0 || h1 >= height || w1 < 0 || w1 >= width)
		return;
	if (h2 < 0 || h2 >= height || w2 < 0 || w2 >= width)
		return;

	int idx[2];
	get_gate_idx_sf(h1, w1, h2, w2, idx, horizontal, reverse);
	data[((n * channels + c) * height + idx[0]) * width + idx[1]] = v;
}
130
+
131
+ // we do not use set_gate_add_sf(...) in the caffe implimentation
132
+ // avoid using atomicAdd
133
+
134
// One left->right recurrence step: update column w = T for all (n, c, h):
//   H(h,w) = (1 - g1 - g2 - g3) * X(h,w)
//          + g1*H(h-1,w-1) + g2*H(h,w-1) + g3*H(h+1,w-1)
// Reads of column w-1 in H were produced by the previous launch, so each
// launch only touches one column and there is no intra-launch dependency.
__global__ void forward_one_col_left_right( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, int horizontal, int reverse) {
  CUDA_1D_KERNEL_LOOP(index, count) {
	// Decompose the flat index into (n, c, h); the column is fixed to T.
	const int hc_count = height * channels;
	int rem = index;
	const int w = T;
	const int n = rem / hc_count;
	rem = rem % hc_count;
	const int c = rem / height;
	const int h = rem % height;

	float x_data = get_data_sf(X, num, channels, height, width, n, c, h, w);

	// Gated contributions from the three left-hand neighbours.
	float g_data_1 = get_gate_sf(G1, num, channels, height, width, n, c, h, w, h - 1, w - 1, horizontal, reverse);
	float h1_minus1 = g_data_1 * get_data_sf(H, num, channels, height, width, n, c, h - 1, w - 1);

	float g_data_2 = get_gate_sf(G2, num, channels, height, width, n, c, h, w, h, w - 1, horizontal, reverse);
	float h2_minus1 = g_data_2 * get_data_sf(H, num, channels, height, width, n, c, h, w - 1);

	float g_data_3 = get_gate_sf(G3, num, channels, height, width, n, c, h, w, h + 1, w - 1, horizontal, reverse);
	float h3_minus1 = g_data_3 * get_data_sf(H, num, channels, height, width, n, c, h + 1, w - 1);

	float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
	float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;

	set_data_sf(H, num, channels, height, width, n, c, h, w, x_hype + h_hype);
  }
}
172
+
173
// One right->left recurrence step: update column w = T for all (n, c, h):
//   H(h,w) = (1 - g1 - g2 - g3) * X(h,w)
//          + g1*H(h-1,w+1) + g2*H(h,w+1) + g3*H(h+1,w+1)
__global__ void forward_one_col_right_left( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) {
  CUDA_1D_KERNEL_LOOP(index, count) {
	// Decompose the flat index into (n, c, h); the column is fixed to T.
	const int hc_count = height * channels;
	int rem = index;
	const int w = T;
	const int n = rem / hc_count;
	rem = rem % hc_count;
	const int c = rem / height;
	const int h = rem % height;

	float x_data = get_data_sf(X, num, channels, height, width, n, c, h, w);

	// Gated contributions from the three right-hand neighbours.
	float g_data_1 = get_gate_sf(G1, num, channels, height, width, n, c, h, w, h - 1, w + 1, horizontal, reverse);
	float h1_minus1 = g_data_1 * get_data_sf(H, num, channels, height, width, n, c, h - 1, w + 1);

	float g_data_2 = get_gate_sf(G2, num, channels, height, width, n, c, h, w, h, w + 1, horizontal, reverse);
	float h2_minus1 = g_data_2 * get_data_sf(H, num, channels, height, width, n, c, h, w + 1);

	float g_data_3 = get_gate_sf(G3, num, channels, height, width, n, c, h, w, h + 1, w + 1, horizontal, reverse);
	float h3_minus1 = g_data_3 * get_data_sf(H, num, channels, height, width, n, c, h + 1, w + 1);

	float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
	float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;

	set_data_sf(H, num, channels, height, width, n, c, h, w, x_hype + h_hype);
  }
}
209
+
210
// One top->bottom recurrence step: update row h = T for all (n, c, w):
//   H(h,w) = (1 - g1 - g2 - g3) * X(h,w)
//          + g1*H(h-1,w-1) + g2*H(h-1,w) + g3*H(h-1,w+1)
__global__ void forward_one_row_top_bottom( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) {
  CUDA_1D_KERNEL_LOOP(index, count) {
	// Decompose the flat index into (n, c, w); the row is fixed to T.
	const int wc_count = width * channels;
	int rem = index;
	const int h = T;
	const int n = rem / wc_count;
	rem = rem % wc_count;
	const int c = rem / width;
	const int w = rem % width;

	float x_data = get_data_sf(X, num, channels, height, width, n, c, h, w);

	// Gated contributions from the three neighbours in the row above.
	float g_data_1 = get_gate_sf(G1, num, channels, height, width, n, c, h, w, h - 1, w - 1, horizontal, reverse);
	float h1_minus1 = g_data_1 * get_data_sf(H, num, channels, height, width, n, c, h - 1, w - 1);

	float g_data_2 = get_gate_sf(G2, num, channels, height, width, n, c, h, w, h - 1, w, horizontal, reverse);
	float h2_minus1 = g_data_2 * get_data_sf(H, num, channels, height, width, n, c, h - 1, w);

	float g_data_3 = get_gate_sf(G3, num, channels, height, width, n, c, h, w, h - 1, w + 1, horizontal, reverse);
	float h3_minus1 = g_data_3 * get_data_sf(H, num, channels, height, width, n, c, h - 1, w + 1);

	float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
	float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;

	set_data_sf(H, num, channels, height, width, n, c, h, w, x_hype + h_hype);
  }
}
249
+
250
// One bottom->top recurrence step: update row h = T for all (n, c, w):
//   H(h,w) = (1 - g1 - g2 - g3) * X(h,w)
//          + g1*H(h+1,w-1) + g2*H(h+1,w) + g3*H(h+1,w+1)
__global__ void forward_one_row_bottom_top( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) {
  CUDA_1D_KERNEL_LOOP(index, count) {
	// Decompose the flat index into (n, c, w); the row is fixed to T.
	const int wc_count = width * channels;
	int rem = index;
	const int h = T;
	const int n = rem / wc_count;
	rem = rem % wc_count;
	const int c = rem / width;
	const int w = rem % width;

	float x_data = get_data_sf(X, num, channels, height, width, n, c, h, w);

	// Gated contributions from the three neighbours in the row below.
	float g_data_1 = get_gate_sf(G1, num, channels, height, width, n, c, h, w, h + 1, w - 1, horizontal, reverse);
	float h1_minus1 = g_data_1 * get_data_sf(H, num, channels, height, width, n, c, h + 1, w - 1);

	float g_data_2 = get_gate_sf(G2, num, channels, height, width, n, c, h, w, h + 1, w, horizontal, reverse);
	float h2_minus1 = g_data_2 * get_data_sf(H, num, channels, height, width, n, c, h + 1, w);

	float g_data_3 = get_gate_sf(G3, num, channels, height, width, n, c, h, w, h + 1, w + 1, horizontal, reverse);
	float h3_minus1 = g_data_3 * get_data_sf(H, num, channels, height, width, n, c, h + 1, w + 1);

	float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
	float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;

	set_data_sf(H, num, channels, height, width, n, c, h, w, x_hype + h_hype);
  }
}
290
+
291
+
292
// Backward pass for the left->right sweep at column w = T.
// First accumulates dL/dH(h,w) (top gradient plus gated gradients flowing
// back from the three right-hand neighbours that consumed H(h,w)), stores it
// back into Hdiff so the next launch (column T-1) can reuse it, then emits
// dL/dX and the three gate gradients.
__global__ void backward_one_col_left_right( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, float * X_diff, float * G1_diff,float* G2_diff,float * G3_diff, float * Hdiff,int horizontal,int reverse) {
  CUDA_1D_KERNEL_LOOP(index, count) {
	// Decompose the flat index into (n, c, h); the column is fixed to T.
	const int hc_count = height * channels;
	int rem = index;
	const int w = T;
	const int n = rem / hc_count;
	rem = rem % hc_count;
	const int c = rem / height;
	const int h = rem % height;

	float x_data = get_data_sf(X, num, channels, height, width, n, c, h, w);

	// h(t)_diff starts from the incoming top gradient ...
	float h_diff = get_data_sf(Hdiff, num, channels, height, width, n, c, h, w);

	// ... plus h(t+1)_diff * g(t+1) from the three pixels in column w+1.
	float add1_h3_diff = get_data_sf(Hdiff, num, channels, height, width, n, c, h - 1, w + 1);
	float add1_g3_data = get_gate_sf(G3, num, channels, height, width, n, c, h, w, h - 1, w + 1, horizontal, reverse);

	float add1_h2_diff = get_data_sf(Hdiff, num, channels, height, width, n, c, h, w + 1);
	float add1_g2_data = get_gate_sf(G2, num, channels, height, width, n, c, h, w, h, w + 1, horizontal, reverse);

	float add1_h1_diff = get_data_sf(Hdiff, num, channels, height, width, n, c, h + 1, w + 1);
	float add1_g1_data = get_gate_sf(G1, num, channels, height, width, n, c, h, w, h + 1, w + 1, horizontal, reverse);

	h_diff = h_diff + add1_h3_diff * add1_g3_data + add1_h2_diff * add1_g2_data + add1_h1_diff * add1_g1_data;

	// Persist the accumulated gradient for the next (earlier) column.
	set_data_sf(Hdiff, num, channels, height, width, n, c, h, w, h_diff);

	// x(t)_diff = (1 - sum(gates)) * h(t)_diff, gates toward column w-1.
	float g1_data = get_gate_sf(G1, num, channels, height, width, n, c, h, w, h - 1, w - 1, horizontal, reverse);
	float g2_data = get_gate_sf(G2, num, channels, height, width, n, c, h, w, h, w - 1, horizontal, reverse);
	float g3_data = get_gate_sf(G3, num, channels, height, width, n, c, h, w, h + 1, w - 1, horizontal, reverse);

	float x_diff = (1 - g1_data - g2_data - g3_data) * h_diff;
	set_data_sf(X_diff, num, channels, height, width, n, c, h, w, x_diff);

	// g_diff = h_diff * (H(neighbour at t-1) - X(t)) for each of the three gates.
	float h1_minus1_data = get_data_sf(H, num, channels, height, width, n, c, h - 1, w - 1);
	set_gate_sf(G1_diff, num, channels, height, width, n, c, h, w, h - 1, w - 1, horizontal, reverse,
	            h_diff * (h1_minus1_data - x_data));

	float h2_minus1_data = get_data_sf(H, num, channels, height, width, n, c, h, w - 1);
	set_gate_sf(G2_diff, num, channels, height, width, n, c, h, w, h, w - 1, horizontal, reverse,
	            h_diff * (h2_minus1_data - x_data));

	float h3_minus1_data = get_data_sf(H, num, channels, height, width, n, c, h + 1, w - 1);
	set_gate_sf(G3_diff, num, channels, height, width, n, c, h, w, h + 1, w - 1, horizontal, reverse,
	            h_diff * (h3_minus1_data - x_data));
  }
}
353
+
354
+
355
+ __global__ void backward_one_col_right_left( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, float * X_diff, float * G1_diff,float* G2_diff,float * G3_diff, float * Hdiff,int horizontal,int reverse) {
356
+ CUDA_1D_KERNEL_LOOP(index, count) {
357
+
358
+ int hc_count = height * channels;
359
+
360
+ int n,c,h,w;
361
+ int temp=index;
362
+
363
+
364
+ w = T;
365
+ n = temp / hc_count;
366
+ temp = temp % hc_count;
367
+ c = temp / height;
368
+ temp = temp % height;
369
+ h = temp;
370
+
371
+
372
+ float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
373
+
374
+ //h(t)_diff = top(t)_diff
375
+ float h_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h,w);
376
+
377
+ ///h(t)_diff += h(t+1)_diff * g(t+1) if t<T
378
+ float add1_h3_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h-1,w-1);
379
+ float add1_g3_data = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse);
380
+
381
+ float add1_h2_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h,w-1);
382
+ float add1_g2_data = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w-1,horizontal,reverse);
383
+
384
+ float add1_h1_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h+1,w-1);
385
+ float add1_g1_data = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse);
386
+
387
+ h_diff = h_diff + add1_h3_diff * add1_g3_data + add1_h2_diff * add1_g2_data + add1_h1_diff * add1_g1_data;
388
+
389
+
390
+ set_data_sf(Hdiff,num,channels,height,width,n,c,h,w,h_diff);
391
+
392
+ float g1_data = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse);
393
+ float g2_data = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w+1,horizontal,reverse);
394
+ float g3_data = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse);
395
+ float x_diff = (1- g1_data -g2_data -g3_data) * h_diff;
396
+ set_data_sf(X_diff,num,channels,height,width,n,c,h,w,x_diff);
397
+
398
+ // g_diff = h_diff * (h_data(t-1) - x_data)
399
+ float h1_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h-1,w+1);
400
+ float g1_diff = h_diff * (h1_minus1_data - x_data);
401
+ set_gate_sf(G1_diff,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse,g1_diff);
402
+
403
+
404
+ float h2_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h,w+1);
405
+ float g2_diff = h_diff * (h2_minus1_data - x_data);
406
+ set_gate_sf(G2_diff,num,channels,height,width,n,c,h,w,h,w+1,horizontal,reverse,g2_diff);
407
+
408
+ float h3_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h+1,w+1);
409
+ float g3_diff = h_diff * (h3_minus1_data - x_data);
410
+ set_gate_sf(G3_diff,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse,g3_diff);
411
+
412
+
413
+ }
414
+ }
415
+
416
+ __global__ void backward_one_row_top_bottom( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, float * X_diff, float * G1_diff,float* G2_diff,float * G3_diff, float * Hdiff,int horizontal,int reverse) {
417
+ CUDA_1D_KERNEL_LOOP(index, count) {
418
+
419
+
420
+ int wc_count = width * channels;
421
+
422
+ int n,c,h,w;
423
+ int temp=index;
424
+ h = T;
425
+ n = temp / wc_count;
426
+ temp = temp % wc_count;
427
+ c = temp / width;
428
+ temp = temp % width;
429
+ w = temp;
430
+
431
+ float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
432
+
433
+ float h_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h,w);
434
+
435
+ //h(t)_diff += h(t+1)_diff * g(t+1) if t<T
436
+ float add1_h3_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h+1,w-1);
437
+ float add1_g3_data = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse);
438
+
439
+ float add1_h2_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h+1,w);
440
+ float add1_g2_data = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h+1,w,horizontal,reverse);
441
+
442
+ float add1_h1_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h+1,w+1);
443
+ float add1_g1_data = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse);
444
+
445
+ h_diff = h_diff + add1_h3_diff * add1_g3_data + add1_h2_diff * add1_g2_data + add1_h1_diff * add1_g1_data;
446
+
447
+
448
+ set_data_sf(Hdiff,num,channels,height,width,n,c,h,w,h_diff);
449
+
450
+
451
+ //x(t)_diff=(1-g(t))*h(t)_diff
452
+ float g1_data = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse);
453
+ float g2_data = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h-1,w,horizontal,reverse);
454
+ float g3_data = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse);
455
+ float x_diff = (1- g1_data -g2_data -g3_data) * h_diff;
456
+ set_data_sf(X_diff,num,channels,height,width,n,c,h,w,x_diff);
457
+
458
+
459
+
460
+ // g_diff = h_diff * (h_data(t-1) - x_data)
461
+ float h1_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h-1,w-1);
462
+ float g1_diff = h_diff * (h1_minus1_data - x_data);
463
+ set_gate_sf(G1_diff,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse,g1_diff);
464
+
465
+ float h2_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h-1,w);
466
+ float g2_diff = h_diff * (h2_minus1_data - x_data);
467
+ set_gate_sf(G2_diff,num,channels,height,width,n,c,h,w,h-1,w,horizontal,reverse,g2_diff);
468
+
469
+ float h3_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h-1,w+1);
470
+ float g3_diff = h_diff * (h3_minus1_data - x_data);
471
+ set_gate_sf(G3_diff,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse,g3_diff);
472
+
473
+ }
474
+ }
475
+
476
+ __global__ void backward_one_row_bottom_top( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, float * X_diff, float * G1_diff,float* G2_diff,float * G3_diff, float * Hdiff,int horizontal,int reverse) {
477
+ CUDA_1D_KERNEL_LOOP(index, count) {
478
+
479
+ int wc_count = width * channels;
480
+
481
+ int n,c,h,w;
482
+ int temp=index;
483
+ h = T;
484
+ n = temp / wc_count;
485
+ temp = temp % wc_count;
486
+ c = temp / width;
487
+ temp = temp % width;
488
+ w = temp;
489
+
490
+ float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
491
+
492
+ //h(t)_diff = top(t)_diff
493
+ float h_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h,w);
494
+
495
+ //h(t)_diff += h(t+1)_diff * g(t+1) if t<T
496
+ float add1_h3_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h-1,w-1);
497
+ float add1_g3_data = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse);
498
+
499
+ float add1_h2_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h-1,w);
500
+ float add1_g2_data = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h-1,w,horizontal,reverse);
501
+
502
+ float add1_h1_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h-1,w+1);
503
+ float add1_g1_data = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse);
504
+
505
+ h_diff = h_diff + add1_h3_diff * add1_g3_data + add1_h2_diff * add1_g2_data + add1_h1_diff * add1_g1_data;
506
+
507
+
508
+ set_data_sf(Hdiff,num,channels,height,width,n,c,h,w,h_diff);
509
+
510
+
511
+ //x(t)_diff=(1-g(t))*h(t)_diff
512
+ float g1_data = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse);
513
+ float g2_data = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h+1,w,horizontal,reverse);
514
+ float g3_data = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse);
515
+ float x_diff = (1- g1_data -g2_data -g3_data) * h_diff;
516
+ set_data_sf(X_diff,num,channels,height,width,n,c,h,w,x_diff);
517
+
518
+
519
+ // g_diff = h_diff * (h_data(t-1) - x_data)
520
+ float h1_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h+1,w-1);
521
+ float g1_diff = h_diff * (h1_minus1_data - x_data);
522
+ set_gate_sf(G1_diff,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse,g1_diff);
523
+
524
+ //float g2_diff = h_diff * g2_idx * x_data * -1;
525
+ float h2_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h+1,w);
526
+ float g2_diff = h_diff * (h2_minus1_data - x_data);
527
+ set_gate_sf(G2_diff,num,channels,height,width,n,c,h,w,h+1,w,horizontal,reverse,g2_diff);
528
+
529
+ //float g3_diff = h_diff * g3_idx * x_data * -1;
530
+ float h3_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h+1,w+1);
531
+ float g3_diff = h_diff * (h3_minus1_data - x_data);
532
+ set_gate_sf(G3_diff,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse,g3_diff);
533
+
534
+
535
+ }
536
+ }
537
+
538
+
539
+ int Forward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
540
+ {
541
+ int count = height_ * channels_ * num_;
542
+ int kThreadsPerBlock = 1024;
543
+ cudaError_t err;
544
+
545
+ for(int t=0; t<width_; t++) {
546
+ forward_one_col_left_right<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
547
+
548
+ err = cudaGetLastError();
549
+ if(cudaSuccess != err)
550
+ {
551
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
552
+ exit( -1 );
553
+ }
554
+ }
555
+ return 1;
556
+ }
557
+
558
+ int Forward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
559
+ {
560
+ int count = height_ * channels_ * num_;
561
+ int kThreadsPerBlock = 1024;
562
+ cudaError_t err;
563
+
564
+ for(int t = width_ - 1; t >= 0; t--) {
565
+ forward_one_col_right_left<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
566
+
567
+ err = cudaGetLastError();
568
+ if(cudaSuccess != err)
569
+ {
570
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
571
+ exit( -1 );
572
+ }
573
+ }
574
+ return 1;
575
+ }
576
+
577
+ int Forward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
578
+ {
579
+ int count = width_ * channels_ * num_;
580
+ int kThreadsPerBlock = 1024;
581
+ cudaError_t err;
582
+
583
+ for(int t=0; t< height_; t++) {
584
+ forward_one_row_top_bottom<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
585
+
586
+ err = cudaGetLastError();
587
+ if(cudaSuccess != err)
588
+ {
589
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
590
+ exit( -1 );
591
+ }
592
+ }
593
+ return 1;
594
+ }
595
+
596
+ int Forward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
597
+ {
598
+ int count = width_ * channels_ * num_;
599
+ int kThreadsPerBlock = 1024;
600
+ cudaError_t err;
601
+
602
+ for(int t = height_-1; t >= 0; t--) {
603
+ forward_one_row_bottom_top<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
604
+
605
+ err = cudaGetLastError();
606
+ if(cudaSuccess != err)
607
+ {
608
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
609
+ exit( -1 );
610
+ }
611
+ }
612
+ return 1;
613
+ }
614
+
615
+ int Backward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
616
+ {
617
+ int count = height_ * channels_ * num_;
618
+ int kThreadsPerBlock = 1024;
619
+ cudaError_t err;
620
+
621
+ for(int t = width_ -1; t>=0; t--)
622
+ {
623
+ backward_one_col_left_right<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
624
+
625
+ err = cudaGetLastError();
626
+ if(cudaSuccess != err)
627
+ {
628
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
629
+ exit( -1 );
630
+ }
631
+ }
632
+ return 1;
633
+ }
634
+
635
+ int Backward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
636
+ {
637
+ int count = height_ * channels_ * num_;
638
+ int kThreadsPerBlock = 1024;
639
+ cudaError_t err;
640
+
641
+ for(int t = 0; t<width_; t++)
642
+ {
643
+ backward_one_col_right_left<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
644
+
645
+ err = cudaGetLastError();
646
+ if(cudaSuccess != err)
647
+ {
648
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
649
+ exit( -1 );
650
+ }
651
+ }
652
+ return 1;
653
+ }
654
+
655
+ int Backward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
656
+ {
657
+ int count = width_ * channels_ * num_;
658
+ int kThreadsPerBlock = 1024;
659
+ cudaError_t err;
660
+
661
+ for(int t = height_-1; t>=0; t--)
662
+ {
663
+ backward_one_row_top_bottom<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
664
+
665
+ err = cudaGetLastError();
666
+ if(cudaSuccess != err)
667
+ {
668
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
669
+ exit( -1 );
670
+ }
671
+ }
672
+ return 1;
673
+ }
674
+
675
+ int Backward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
676
+ {
677
+ int count = width_ * channels_ * num_;
678
+ int kThreadsPerBlock = 1024;
679
+ cudaError_t err;
680
+
681
+ for(int t = 0; t<height_; t++)
682
+ {
683
+ backward_one_row_bottom_top<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
684
+
685
+ err = cudaGetLastError();
686
+ if(cudaSuccess != err)
687
+ {
688
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
689
+ exit( -1 );
690
+ }
691
+ }
692
+ return 1;
693
+ }
694
+
695
+ #ifdef __cplusplus
696
+ }
697
+ #endif
graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu.o ADDED
Binary file (98.8 kB). View file
 
graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.h ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifndef _GATERECURRENT2DNOIND_KERNEL
2
+ #define _GATERECURRENT2DNOIND_KERNEL
3
+
4
+ #ifdef __cplusplus
5
+ extern "C" {
6
+ #endif
7
+
8
+ int Forward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
9
+
10
+ int Forward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
11
+
12
+ int Forward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
13
+
14
+ int Forward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
15
+
16
+ int Backward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
17
+
18
+ int Backward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
19
+
20
+ int Backward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
21
+
22
+ int Backward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
23
+
24
+ #ifdef __cplusplus
25
+ }
26
+ #endif
27
+
28
+ #endif
graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.c ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // gaterecurrent2dnoind_cuda.c
2
+ #include <THC/THC.h>
3
+ #include <math.h>
4
+ #include "gaterecurrent2dnoind_cuda.h"
5
+ #include "cuda/gaterecurrent2dnoind_kernel.h"
6
+
7
+ // typedef bool boolean;
8
+
9
+ // this symbol will be resolved automatically from PyTorch libs
10
+ extern THCState *state;
11
+
12
+ int gaterecurrent2dnoind_forward_cuda(int horizontal_, int reverse_, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * output)
13
+ {
14
+ // Grab the input tensor to flat
15
+ float * X_data = THCudaTensor_data(state, X);
16
+ float * G1_data = THCudaTensor_data(state, G1);
17
+ float * G2_data = THCudaTensor_data(state, G2);
18
+ float * G3_data = THCudaTensor_data(state, G3);
19
+ float * H_data = THCudaTensor_data(state, output);
20
+
21
+ // dimensions
22
+ int num_ = THCudaTensor_size(state, X, 0);
23
+ int channels_ = THCudaTensor_size(state, X, 1);
24
+ int height_ = THCudaTensor_size(state, X, 2);
25
+ int width_ = THCudaTensor_size(state, X, 3);
26
+
27
+ cudaStream_t stream = THCState_getCurrentStream(state);
28
+
29
+ if(horizontal_ && !reverse_) // left to right
30
+ {
31
+ //const int count = height_ * channels_ * num_;
32
+ Forward_left_right(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
33
+ }
34
+ else if(horizontal_ && reverse_) // right to left
35
+ {
36
+ Forward_right_left(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
37
+ }
38
+ else if(!horizontal_ && !reverse_) // top to bottom
39
+ {
40
+ Forward_top_bottom(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
41
+ }
42
+ else
43
+ {
44
+ Forward_bottom_top(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
45
+ }
46
+
47
+ return 1;
48
+ }
49
+
50
+ int gaterecurrent2dnoind_backward_cuda(int horizontal_, int reverse_, THCudaTensor* top, THCudaTensor* top_grad, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * X_grad, THCudaTensor * G1_grad, THCudaTensor * G2_grad, THCudaTensor * G3_grad)
51
+ {
52
+ //Grab the input tensor to flat
53
+ float * X_data = THCudaTensor_data(state, X);
54
+ float * G1_data = THCudaTensor_data(state, G1);
55
+ float * G2_data = THCudaTensor_data(state, G2);
56
+ float * G3_data = THCudaTensor_data(state, G3);
57
+ float * H_data = THCudaTensor_data(state, top);
58
+
59
+ float * H_diff = THCudaTensor_data(state, top_grad);
60
+
61
+ float * X_diff = THCudaTensor_data(state, X_grad);
62
+ float * G1_diff = THCudaTensor_data(state, G1_grad);
63
+ float * G2_diff = THCudaTensor_data(state, G2_grad);
64
+ float * G3_diff = THCudaTensor_data(state, G3_grad);
65
+
66
+ // dimensions
67
+ int num_ = THCudaTensor_size(state, X, 0);
68
+ int channels_ = THCudaTensor_size(state, X, 1);
69
+ int height_ = THCudaTensor_size(state, X, 2);
70
+ int width_ = THCudaTensor_size(state, X, 3);
71
+
72
+ cudaStream_t stream = THCState_getCurrentStream(state);
73
+
74
+ if(horizontal_ && ! reverse_) //left to right
75
+ {
76
+ Backward_left_right(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
77
+ }
78
+ else if(horizontal_ && reverse_) //right to left
79
+ {
80
+ Backward_right_left(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
81
+ }
82
+ else if(!horizontal_ && !reverse_) //top to bottom
83
+ {
84
+ Backward_top_bottom(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
85
+ }
86
+ else {
87
+ Backward_bottom_top(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
88
+ }
89
+
90
+ return 1;
91
+ }
graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.h ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+ // #include <stdbool.h>
3
+ // gaterecurrent2dnoind_cuda.h
4
+ int gaterecurrent2dnoind_forward_cuda(int horizontal_, int reverse_, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * output);
5
+
6
+ int gaterecurrent2dnoind_backward_cuda(int horizontal_, int reverse_, THCudaTensor* top, THCudaTensor* top_grad, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * X_diff, THCudaTensor * G1_diff, THCudaTensor * G2_diff, THCudaTensor * G3_diff);
graph_networks/LinearStyleTransfer/libs/smooth_filter.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code cc from https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py
3
+ """
4
+ src = '''
5
+ #include "/usr/local/cuda/include/math_functions.h"
6
+ #define TB 256
7
+ #define EPS 1e-7
8
+
9
+ __device__ bool InverseMat4x4(double m_in[4][4], double inv_out[4][4]) {
10
+ double m[16], inv[16];
11
+ for (int i = 0; i < 4; i++) {
12
+ for (int j = 0; j < 4; j++) {
13
+ m[i * 4 + j] = m_in[i][j];
14
+ }
15
+ }
16
+
17
+ inv[0] = m[5] * m[10] * m[15] -
18
+ m[5] * m[11] * m[14] -
19
+ m[9] * m[6] * m[15] +
20
+ m[9] * m[7] * m[14] +
21
+ m[13] * m[6] * m[11] -
22
+ m[13] * m[7] * m[10];
23
+
24
+ inv[4] = -m[4] * m[10] * m[15] +
25
+ m[4] * m[11] * m[14] +
26
+ m[8] * m[6] * m[15] -
27
+ m[8] * m[7] * m[14] -
28
+ m[12] * m[6] * m[11] +
29
+ m[12] * m[7] * m[10];
30
+
31
+ inv[8] = m[4] * m[9] * m[15] -
32
+ m[4] * m[11] * m[13] -
33
+ m[8] * m[5] * m[15] +
34
+ m[8] * m[7] * m[13] +
35
+ m[12] * m[5] * m[11] -
36
+ m[12] * m[7] * m[9];
37
+
38
+ inv[12] = -m[4] * m[9] * m[14] +
39
+ m[4] * m[10] * m[13] +
40
+ m[8] * m[5] * m[14] -
41
+ m[8] * m[6] * m[13] -
42
+ m[12] * m[5] * m[10] +
43
+ m[12] * m[6] * m[9];
44
+
45
+ inv[1] = -m[1] * m[10] * m[15] +
46
+ m[1] * m[11] * m[14] +
47
+ m[9] * m[2] * m[15] -
48
+ m[9] * m[3] * m[14] -
49
+ m[13] * m[2] * m[11] +
50
+ m[13] * m[3] * m[10];
51
+
52
+ inv[5] = m[0] * m[10] * m[15] -
53
+ m[0] * m[11] * m[14] -
54
+ m[8] * m[2] * m[15] +
55
+ m[8] * m[3] * m[14] +
56
+ m[12] * m[2] * m[11] -
57
+ m[12] * m[3] * m[10];
58
+
59
+ inv[9] = -m[0] * m[9] * m[15] +
60
+ m[0] * m[11] * m[13] +
61
+ m[8] * m[1] * m[15] -
62
+ m[8] * m[3] * m[13] -
63
+ m[12] * m[1] * m[11] +
64
+ m[12] * m[3] * m[9];
65
+
66
+ inv[13] = m[0] * m[9] * m[14] -
67
+ m[0] * m[10] * m[13] -
68
+ m[8] * m[1] * m[14] +
69
+ m[8] * m[2] * m[13] +
70
+ m[12] * m[1] * m[10] -
71
+ m[12] * m[2] * m[9];
72
+
73
+ inv[2] = m[1] * m[6] * m[15] -
74
+ m[1] * m[7] * m[14] -
75
+ m[5] * m[2] * m[15] +
76
+ m[5] * m[3] * m[14] +
77
+ m[13] * m[2] * m[7] -
78
+ m[13] * m[3] * m[6];
79
+
80
+ inv[6] = -m[0] * m[6] * m[15] +
81
+ m[0] * m[7] * m[14] +
82
+ m[4] * m[2] * m[15] -
83
+ m[4] * m[3] * m[14] -
84
+ m[12] * m[2] * m[7] +
85
+ m[12] * m[3] * m[6];
86
+
87
+ inv[10] = m[0] * m[5] * m[15] -
88
+ m[0] * m[7] * m[13] -
89
+ m[4] * m[1] * m[15] +
90
+ m[4] * m[3] * m[13] +
91
+ m[12] * m[1] * m[7] -
92
+ m[12] * m[3] * m[5];
93
+
94
+ inv[14] = -m[0] * m[5] * m[14] +
95
+ m[0] * m[6] * m[13] +
96
+ m[4] * m[1] * m[14] -
97
+ m[4] * m[2] * m[13] -
98
+ m[12] * m[1] * m[6] +
99
+ m[12] * m[2] * m[5];
100
+
101
+ inv[3] = -m[1] * m[6] * m[11] +
102
+ m[1] * m[7] * m[10] +
103
+ m[5] * m[2] * m[11] -
104
+ m[5] * m[3] * m[10] -
105
+ m[9] * m[2] * m[7] +
106
+ m[9] * m[3] * m[6];
107
+
108
+ inv[7] = m[0] * m[6] * m[11] -
109
+ m[0] * m[7] * m[10] -
110
+ m[4] * m[2] * m[11] +
111
+ m[4] * m[3] * m[10] +
112
+ m[8] * m[2] * m[7] -
113
+ m[8] * m[3] * m[6];
114
+
115
+ inv[11] = -m[0] * m[5] * m[11] +
116
+ m[0] * m[7] * m[9] +
117
+ m[4] * m[1] * m[11] -
118
+ m[4] * m[3] * m[9] -
119
+ m[8] * m[1] * m[7] +
120
+ m[8] * m[3] * m[5];
121
+
122
+ inv[15] = m[0] * m[5] * m[10] -
123
+ m[0] * m[6] * m[9] -
124
+ m[4] * m[1] * m[10] +
125
+ m[4] * m[2] * m[9] +
126
+ m[8] * m[1] * m[6] -
127
+ m[8] * m[2] * m[5];
128
+
129
+ double det = m[0] * inv[0] + m[1] * inv[4] + m[2] * inv[8] + m[3] * inv[12];
130
+
131
+ if (abs(det) < 1e-9) {
132
+ return false;
133
+ }
134
+
135
+
136
+ det = 1.0 / det;
137
+
138
+ for (int i = 0; i < 4; i++) {
139
+ for (int j = 0; j < 4; j++) {
140
+ inv_out[i][j] = inv[i * 4 + j] * det;
141
+ }
142
+ }
143
+
144
+ return true;
145
+ }
146
+
147
+ extern "C"
148
+ __global__ void best_local_affine_kernel(
149
+ float *output, float *input, float *affine_model,
150
+ int h, int w, float epsilon, int kernel_radius
151
+ )
152
+ {
153
+ int size = h * w;
154
+ int id = blockIdx.x * blockDim.x + threadIdx.x;
155
+
156
+ if (id < size) {
157
+ int x = id % w, y = id / w;
158
+
159
+ double Mt_M[4][4] = {}; // 4x4
160
+ double invMt_M[4][4] = {};
161
+ double Mt_S[3][4] = {}; // RGB -> 1x4
162
+ double A[3][4] = {};
163
+ for (int i = 0; i < 4; i++)
164
+ for (int j = 0; j < 4; j++) {
165
+ Mt_M[i][j] = 0, invMt_M[i][j] = 0;
166
+ if (i != 3) {
167
+ Mt_S[i][j] = 0, A[i][j] = 0;
168
+ if (i == j)
169
+ Mt_M[i][j] = 1e-3;
170
+ }
171
+ }
172
+
173
+ for (int dy = -kernel_radius; dy <= kernel_radius; dy++) {
174
+ for (int dx = -kernel_radius; dx <= kernel_radius; dx++) {
175
+
176
+ int xx = x + dx, yy = y + dy;
177
+ int id2 = yy * w + xx;
178
+
179
+ if (0 <= xx && xx < w && 0 <= yy && yy < h) {
180
+
181
+ Mt_M[0][0] += input[id2 + 2*size] * input[id2 + 2*size];
182
+ Mt_M[0][1] += input[id2 + 2*size] * input[id2 + size];
183
+ Mt_M[0][2] += input[id2 + 2*size] * input[id2];
184
+ Mt_M[0][3] += input[id2 + 2*size];
185
+
186
+ Mt_M[1][0] += input[id2 + size] * input[id2 + 2*size];
187
+ Mt_M[1][1] += input[id2 + size] * input[id2 + size];
188
+ Mt_M[1][2] += input[id2 + size] * input[id2];
189
+ Mt_M[1][3] += input[id2 + size];
190
+
191
+ Mt_M[2][0] += input[id2] * input[id2 + 2*size];
192
+ Mt_M[2][1] += input[id2] * input[id2 + size];
193
+ Mt_M[2][2] += input[id2] * input[id2];
194
+ Mt_M[2][3] += input[id2];
195
+
196
+ Mt_M[3][0] += input[id2 + 2*size];
197
+ Mt_M[3][1] += input[id2 + size];
198
+ Mt_M[3][2] += input[id2];
199
+ Mt_M[3][3] += 1;
200
+
201
+ Mt_S[0][0] += input[id2 + 2*size] * output[id2 + 2*size];
202
+ Mt_S[0][1] += input[id2 + size] * output[id2 + 2*size];
203
+ Mt_S[0][2] += input[id2] * output[id2 + 2*size];
204
+ Mt_S[0][3] += output[id2 + 2*size];
205
+
206
+ Mt_S[1][0] += input[id2 + 2*size] * output[id2 + size];
207
+ Mt_S[1][1] += input[id2 + size] * output[id2 + size];
208
+ Mt_S[1][2] += input[id2] * output[id2 + size];
209
+ Mt_S[1][3] += output[id2 + size];
210
+
211
+ Mt_S[2][0] += input[id2 + 2*size] * output[id2];
212
+ Mt_S[2][1] += input[id2 + size] * output[id2];
213
+ Mt_S[2][2] += input[id2] * output[id2];
214
+ Mt_S[2][3] += output[id2];
215
+ }
216
+ }
217
+ }
218
+
219
+ bool success = InverseMat4x4(Mt_M, invMt_M);
220
+
221
+ for (int i = 0; i < 3; i++) {
222
+ for (int j = 0; j < 4; j++) {
223
+ for (int k = 0; k < 4; k++) {
224
+ A[i][j] += invMt_M[j][k] * Mt_S[i][k];
225
+ }
226
+ }
227
+ }
228
+
229
+ for (int i = 0; i < 3; i++) {
230
+ for (int j = 0; j < 4; j++) {
231
+ int affine_id = i * 4 + j;
232
+ affine_model[12 * id + affine_id] = A[i][j];
233
+ }
234
+ }
235
+ }
236
+ return ;
237
+ }
238
+
239
+ extern "C"
240
+ __global__ void bilateral_smooth_kernel(
241
+ float *affine_model, float *filtered_affine_model, float *guide,
242
+ int h, int w, int kernel_radius, float sigma1, float sigma2
243
+ )
244
+ {
245
+ int id = blockIdx.x * blockDim.x + threadIdx.x;
246
+ int size = h * w;
247
+ if (id < size) {
248
+ int x = id % w;
249
+ int y = id / w;
250
+
251
+ double sum_affine[12] = {};
252
+ double sum_weight = 0;
253
+ for (int dx = -kernel_radius; dx <= kernel_radius; dx++) {
254
+ for (int dy = -kernel_radius; dy <= kernel_radius; dy++) {
255
+ int yy = y + dy, xx = x + dx;
256
+ int id2 = yy * w + xx;
257
+ if (0 <= xx && xx < w && 0 <= yy && yy < h) {
258
+ float color_diff1 = guide[yy*w + xx] - guide[y*w + x];
259
+ float color_diff2 = guide[yy*w + xx + size] - guide[y*w + x + size];
260
+ float color_diff3 = guide[yy*w + xx + 2*size] - guide[y*w + x + 2*size];
261
+ float color_diff_sqr =
262
+ (color_diff1*color_diff1 + color_diff2*color_diff2 + color_diff3*color_diff3) / 3;
263
+
264
+ float v1 = exp(-(dx * dx + dy * dy) / (2 * sigma1 * sigma1));
265
+ float v2 = exp(-(color_diff_sqr) / (2 * sigma2 * sigma2));
266
+ float weight = v1 * v2;
267
+
268
+ for (int i = 0; i < 3; i++) {
269
+ for (int j = 0; j < 4; j++) {
270
+ int affine_id = i * 4 + j;
271
+ sum_affine[affine_id] += weight * affine_model[id2*12 + affine_id];
272
+ }
273
+ }
274
+ sum_weight += weight;
275
+ }
276
+ }
277
+ }
278
+
279
+ for (int i = 0; i < 3; i++) {
280
+ for (int j = 0; j < 4; j++) {
281
+ int affine_id = i * 4 + j;
282
+ filtered_affine_model[id*12 + affine_id] = sum_affine[affine_id] / sum_weight;
283
+ }
284
+ }
285
+ }
286
+ return ;
287
+ }
288
+
289
+
290
+ extern "C"
291
+ __global__ void reconstruction_best_kernel(
292
+ float *input, float *filtered_affine_model, float *filtered_best_output,
293
+ int h, int w
294
+ )
295
+ {
296
+ int id = blockIdx.x * blockDim.x + threadIdx.x;
297
+ int size = h * w;
298
+ if (id < size) {
299
+ double out1 =
300
+ input[id + 2*size] * filtered_affine_model[id*12 + 0] + // A[0][0] +
301
+ input[id + size] * filtered_affine_model[id*12 + 1] + // A[0][1] +
302
+ input[id] * filtered_affine_model[id*12 + 2] + // A[0][2] +
303
+ filtered_affine_model[id*12 + 3]; //A[0][3];
304
+ double out2 =
305
+ input[id + 2*size] * filtered_affine_model[id*12 + 4] + //A[1][0] +
306
+ input[id + size] * filtered_affine_model[id*12 + 5] + //A[1][1] +
307
+ input[id] * filtered_affine_model[id*12 + 6] + //A[1][2] +
308
+ filtered_affine_model[id*12 + 7]; //A[1][3];
309
+ double out3 =
310
+ input[id + 2*size] * filtered_affine_model[id*12 + 8] + //A[2][0] +
311
+ input[id + size] * filtered_affine_model[id*12 + 9] + //A[2][1] +
312
+ input[id] * filtered_affine_model[id*12 + 10] + //A[2][2] +
313
+ filtered_affine_model[id*12 + 11]; // A[2][3];
314
+
315
+ filtered_best_output[id] = out1;
316
+ filtered_best_output[id + size] = out2;
317
+ filtered_best_output[id + 2*size] = out3;
318
+ }
319
+ return ;
320
+ }
321
+ '''
322
+
323
+ import cv2
324
+ import torch
325
+ import numpy as np
326
+ from PIL import Image
327
+ from cupy.cuda import function
328
+ from pynvrtc.compiler import Program
329
+ from collections import namedtuple
330
+
331
+
332
def smooth_local_affine(output_cpu, input_cpu, epsilon, patch, h, w, f_r, f_e):
    """Smooth ``output_cpu`` with per-pixel local affine color models guided by ``input_cpu``.

    Runs three CUDA kernels compiled from the module-level ``src`` string:
    1. fit the best local affine model per pixel (``best_local_affine_kernel``),
    2. bilaterally smooth the affine coefficients (``bilateral_smooth_kernel``),
    3. reconstruct the output from the filtered coefficients
       (``reconstruction_best_kernel``).

    :param output_cpu: stylized image, float32 array, channel-first (3, h, w)
        — the kernels index three planes of size h*w
    :param input_cpu: guide/content image, float32 array of the same layout
    :param epsilon: regularization used when fitting the affine model
    :param patch: side length of the local fitting window (odd)
    :param h: image height
    :param w: image width
    :param f_r: bilateral filter radius
    :param f_e: bilateral edge-stopping sigma
    :return: smoothed image as a float32 numpy array, same shape as ``input_cpu``
    """
    # Compile the CUDA source once per process and cache the loaded module;
    # recompiling with NVRTC on every call is needlessly slow and the source
    # never changes at runtime.
    if not hasattr(smooth_local_affine, '_cuda_module'):
        # program = Program(src.encode('utf-8'), 'best_local_affine_kernel.cu'.encode('utf-8'))
        # ptx = program.compile(['-I/usr/local/cuda/include'.encode('utf-8')])
        program = Program(src, 'best_local_affine_kernel.cu')
        ptx = program.compile(['-I/usr/local/cuda/include'])
        module = function.Module()
        module.load(bytes(ptx.encode()))
        smooth_local_affine._cuda_module = module
    m = smooth_local_affine._cuda_module

    _reconstruction_best_kernel = m.get_function('reconstruction_best_kernel')
    _bilateral_smooth_kernel = m.get_function('bilateral_smooth_kernel')
    _best_local_affine_kernel = m.get_function('best_local_affine_kernel')
    Stream = namedtuple('Stream', ['ptr'])
    s = Stream(ptr=torch.cuda.current_stream().cuda_stream)

    filter_radius = f_r
    sigma1 = filter_radius / 3
    sigma2 = f_e
    # Half-window of the fitting patch; integer division yields the same value
    # np.int32 truncation produced before.
    radius = (patch - 1) // 2

    filtered_best_output = torch.zeros(np.shape(input_cpu)).cuda()
    affine_model = torch.zeros((h * w, 12)).cuda()
    filtered_affine_model = torch.zeros((h * w, 12)).cuda()

    input_ = torch.from_numpy(input_cpu).cuda()
    output_ = torch.from_numpy(output_cpu).cuda()

    # One thread per pixel, 256 threads per block.
    grid = (int((h * w) / 256 + 1), 1)
    block = (256, 1, 1)

    _best_local_affine_kernel(
        grid=grid,
        block=block,
        args=[output_.data_ptr(), input_.data_ptr(), affine_model.data_ptr(),
              np.int32(h), np.int32(w), np.float32(epsilon), np.int32(radius)], stream=s
    )

    _bilateral_smooth_kernel(
        grid=grid,
        block=block,
        args=[affine_model.data_ptr(), filtered_affine_model.data_ptr(), input_.data_ptr(),
              np.int32(h), np.int32(w), np.int32(f_r), np.float32(sigma1), np.float32(sigma2)], stream=s
    )

    _reconstruction_best_kernel(
        grid=grid,
        block=block,
        args=[input_.data_ptr(), filtered_affine_model.data_ptr(), filtered_best_output.data_ptr(),
              np.int32(h), np.int32(w)], stream=s
    )
    numpy_filtered_best_output = filtered_best_output.cpu().numpy()
    return numpy_filtered_best_output
378
+
379
+
380
def smooth_filter(initImg, contentImg, f_radius=15, f_edge=1e-1):
    '''
    Smooth a stylized image with a local-affine filter guided by its content image.

    :param initImg: intermediate stylized output. Either an image path, a PIL
        Image, or an RGB array of shape (H, W, 3).
    :param contentImg: content image. Either a path, a PIL Image, or an ndarray.
        NOTE(review): the ndarray path is only resized, not channel-swapped
        relative to the PIL path — callers must pass it in the layout they
        already use upstream.
    :param f_radius: bilateral filter radius forwarded to smooth_local_affine.
    :param f_edge: bilateral edge-stopping sigma forwarded to smooth_local_affine.
    :return: stylized output image as a PIL Image.
    '''
    if isinstance(initImg, str):
        initImg = Image.open(initImg).convert("RGB")
    best_image_bgr = np.array(initImg, dtype=np.float32)
    # np.array(PIL image) is laid out (height, width, channels).
    height, width, _channels = best_image_bgr.shape
    best_image_bgr = best_image_bgr[:, :, ::-1]           # RGB -> BGR
    best_image_bgr = best_image_bgr.transpose((2, 0, 1))  # HWC -> CHW

    if isinstance(contentImg, str):
        contentImg = Image.open(contentImg).convert("RGB")
    if isinstance(contentImg, Image.Image):
        # PIL resize takes (width, height). Previously a PIL Image passed
        # directly (not as a path) fell through to cv2.resize, which needs an
        # ndarray and would fail; route all PIL Images through Image.resize.
        content_input = contentImg.resize((width, height))
    else:
        # cv2.resize dsize is also (width, height).
        content_input = cv2.resize(contentImg, (width, height))
    content_input = np.array(content_input, dtype=np.float32)
    content_input = content_input[:, :, ::-1]
    content_input = content_input.transpose((2, 0, 1))

    input_ = np.ascontiguousarray(content_input, dtype=np.float32) / 255.
    _, H, W = np.shape(input_)
    output_ = np.ascontiguousarray(best_image_bgr, dtype=np.float32) / 255.
    best_ = smooth_local_affine(output_, input_, 1e-7, 3, H, W, f_radius, f_edge)
    best_ = best_.transpose(1, 2, 0)
    result = Image.fromarray(np.uint8(np.clip(best_ * 255., 0, 255.)))
    return result
graph_networks/LinearStyleTransfer/libs/utils.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division
2
+ import os
3
+ import cv2
4
+ import time
5
+ import torch
6
+ import scipy.misc
7
+ import numpy as np
8
+ import scipy.sparse
9
+ from PIL import Image
10
+ import scipy.sparse.linalg
11
+ from cv2.ximgproc import jointBilateralFilter
12
+ from numpy.lib.stride_tricks import as_strided
13
+
14
def whiten(cF):
    """Whiten a feature matrix (channels x pixels), as used in WCT-style transfer.

    Subtracts the per-channel mean, then multiplies by the inverse square root
    of the regularized channel covariance so the channels become (approximately)
    decorrelated.

    :param cF: 2-D tensor of shape (C, N), e.g. flattened conv features
    :return: whitened tensor of the same shape, dtype and device as ``cF``
    """
    cFSize = cF.size()
    c_mean = torch.mean(cF, 1)  # per-channel mean, shape (C,)
    c_mean = c_mean.unsqueeze(1).expand_as(cF)
    cF = cF - c_mean

    # Regularized covariance; adding the identity keeps it well-conditioned
    # (all eigenvalues >= 1). Build the eye with cF's dtype/device rather than
    # forcing .double(): the old hard-coded double matrix made the final
    # torch.mm fail with a dtype mismatch for float32 (or GPU) inputs, while
    # behavior for double inputs is unchanged.
    contentConv = torch.mm(cF, cF.t()).div(cFSize[1] - 1) \
        + torch.eye(cFSize[0], dtype=cF.dtype, device=cF.device)
    c_u, c_e, c_v = torch.svd(contentConv, some=False)

    # Truncate near-zero singular values (numerical rank of the covariance).
    k_c = cFSize[0]
    for i in range(cFSize[0]):
        if c_e[i] < 0.00001:
            k_c = i
            break

    c_d = (c_e[0:k_c]).pow(-0.5)
    step1 = torch.mm(c_v[:, 0:k_c], torch.diag(c_d))
    step2 = torch.mm(step1, (c_v[:, 0:k_c].t()))  # contentConv^(-1/2)
    whiten_cF = torch.mm(step2, cF)
    return whiten_cF
34
+
35
def numpy2cv2(cont, style, prop, width, height):
    """Convert channel-first float frames to cv2-style images.

    Each input is transposed CHW -> HWC, channel-reversed (RGB <-> BGR),
    scaled by 255 and resized to (width, height). ``style`` is converted as
    well but only ``(prop, cont)`` are returned.

    :param cont: content frame, channel-first array
    :param style: style image, channel-first array
    :param prop: propagated/transferred frame, channel-first array
    :param width: target width in pixels
    :param height: target height in pixels
    :return: tuple (prop, cont) as HWC arrays scaled by 255
    """
    def _to_cv2(img):
        # CHW -> HWC, reverse channel order, scale, then resize.
        hwc = img.transpose((1, 2, 0))[..., ::-1] * 255
        return cv2.resize(hwc, (width, height))

    cont = _to_cv2(cont)
    style = _to_cv2(style)
    prop = _to_cv2(prop)
    return prop, cont
53
+
54
def makeVideo(content, style, props, outf):
    """Assemble transferred frames and original content frames into AVI videos.

    Writes ``transfer.avi`` (stylized frames) and ``content.avi`` (original
    frames) into ``outf`` at 10 fps using the MJPG codec.

    :param content: list of channel-first frames of the source video
    :param style: style image (channel-first array), converted alongside each frame
    :param props: list of channel-first transferred frames, parallel to ``content``
    :param outf: output directory
    """
    print('Stack transferred frames back to video...')
    layers, height, width = content[0].shape
    codec = cv2.VideoWriter_fourcc(*'MJPG')
    video = cv2.VideoWriter(os.path.join(outf, 'transfer.avi'), codec, 10.0, (width, height))
    ori_video = cv2.VideoWriter(os.path.join(outf, 'content.avi'), codec, 10.0, (width, height))
    for frame_idx in range(len(content)):
        prop, cont = numpy2cv2(content[frame_idx], style, props[frame_idx], width, height)
        # Round-trip through PNG files so cv2.VideoWriter gets uint8 images.
        # TODO: this is ugly, fix this
        cv2.imwrite('prop.png', prop)
        cv2.imwrite('content.png', cont)
        video.write(cv2.imread('prop.png'))
        ori_video.write(cv2.imread('content.png'))
    # RGB or BRG, yuks
    video.release()
    ori_video.release()
    os.remove('prop.png')
    os.remove('content.png')
    print('Transferred video saved at %s.' % outf)
76
+
77
def print_options(opt):
    """Pretty-print all option attributes and persist them to ``<opt.outf>/opt.txt``.

    :param opt: parsed options object (e.g. argparse.Namespace); every attribute
        is listed in sorted order. Must provide ``outf``, the output directory.
    """
    lines = ['----------------- Options ---------------']
    for key, val in sorted(vars(opt).items()):
        note = ''
        lines.append('{:>25}: {:<30}{}'.format(str(key), str(val), note))
    lines.append('----------------- End -------------------')
    message = '\n'.join(lines)
    print(message)

    # save to the disk
    expr_dir = os.path.join(opt.outf)
    os.makedirs(expr_dir, exist_ok=True)
    file_name = os.path.join(expr_dir, 'opt.txt')
    with open(file_name, 'wt') as opt_file:
        opt_file.write(message)
        opt_file.write('\n')
+ opt_file.write('\n')
graph_networks/LinearStyleTransfer/models/dec_r31.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ccc3bbc97a15e1d002d0b13523543c518dda2d0346c8f4d39c1d381d8490f68
3
+ size 2221888
graph_networks/LinearStyleTransfer/models/dec_r41.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6858f96d3d0882fa3b40652a0315928219086e4bcb0e3efbe43bd04ea631911
3
+ size 14023509
graph_networks/LinearStyleTransfer/models/r31.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bb75be684b331a105a8fad266343556067e4eec888249f7abb693a66b5ad7e3
3
+ size 11564438