Spaces:
Running
Running
Initial Upload (attempt 2)
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +7 -0
- LICENSE +21 -0
- README.md +36 -14
- app.py +302 -0
- clusters.py +234 -0
- example-broche-rose-gold.splat +3 -0
- example.jpg +0 -0
- graph_helpers.py +400 -0
- graph_io.py +306 -0
- graph_networks/.DS_Store +0 -0
- graph_networks/LinearStyleTransfer/.DS_Store +0 -0
- graph_networks/LinearStyleTransfer/LICENSE +25 -0
- graph_networks/LinearStyleTransfer/README.md +102 -0
- graph_networks/LinearStyleTransfer/TestArtistic.py +98 -0
- graph_networks/LinearStyleTransfer/TestPhotoReal.py +118 -0
- graph_networks/LinearStyleTransfer/TestVideo.py +108 -0
- graph_networks/LinearStyleTransfer/Train.py +185 -0
- graph_networks/LinearStyleTransfer/TrainSPN.py +141 -0
- graph_networks/LinearStyleTransfer/__init__.py +0 -0
- graph_networks/LinearStyleTransfer/libs/.DS_Store +0 -0
- graph_networks/LinearStyleTransfer/libs/Criterion.py +62 -0
- graph_networks/LinearStyleTransfer/libs/Loader.py +44 -0
- graph_networks/LinearStyleTransfer/libs/LoaderPhotoReal.py +162 -0
- graph_networks/LinearStyleTransfer/libs/Matrix.py +89 -0
- graph_networks/LinearStyleTransfer/libs/MatrixTest.py +154 -0
- graph_networks/LinearStyleTransfer/libs/SPN.py +156 -0
- graph_networks/LinearStyleTransfer/libs/__init__.py +0 -0
- graph_networks/LinearStyleTransfer/libs/models.py +662 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/README.md +12 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/__init__.py +0 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/__init__.py +0 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/gaterecurrent2dnoind/__init__.py +15 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/build.py +34 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/__init__.py +0 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/gaterecurrent2dnoind.py +47 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/left_right_demo.py +46 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/make.sh +9 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/__init__.py +1 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/gaterecurrent2dnoind.py +12 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/.DS_Store +0 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu +697 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu.o +0 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.h +28 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.c +91 -0
- graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.h +6 -0
- graph_networks/LinearStyleTransfer/libs/smooth_filter.py +407 -0
- graph_networks/LinearStyleTransfer/libs/utils.py +92 -0
- graph_networks/LinearStyleTransfer/models/dec_r31.pth +3 -0
- graph_networks/LinearStyleTransfer/models/dec_r41.pth +3 -0
- graph_networks/LinearStyleTransfer/models/r31.pth +3 -0
.gitattributes
CHANGED
|
@@ -40,3 +40,10 @@ FastSplatStyler_huggingface/style_ims/style0.jpg filter=lfs diff=lfs merge=lfs -
|
|
| 40 |
FastSplatStyler_huggingface/style_ims/style2.jpg filter=lfs diff=lfs merge=lfs -text
|
| 41 |
FastSplatStyler_huggingface/style_ims/style44.jpg filter=lfs diff=lfs merge=lfs -text
|
| 42 |
FastSplatStyler_huggingface/style_ims/style6.jpg filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
FastSplatStyler_huggingface/style_ims/style2.jpg filter=lfs diff=lfs merge=lfs -text
|
| 41 |
FastSplatStyler_huggingface/style_ims/style44.jpg filter=lfs diff=lfs merge=lfs -text
|
| 42 |
FastSplatStyler_huggingface/style_ims/style6.jpg filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
example-broche-rose-gold.splat filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
output.splat filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
style_ims/style-10.jpg filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
style_ims/style0.jpg filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
style_ims/style2.jpg filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
style_ims/style44.jpg filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
style_ims/style6.jpg filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 ECU Computer Vision Lab
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,14 +1,36 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
---
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FastSplatStyler
|
| 2 |
+
Official Implementation of "Optimization-Free Style Transfer of 3D Gaussian Splats"
|
| 3 |
+
|
| 4 |
+
[arXiv Paper](https://arxiv.org/abs/2508.05813)
|
| 5 |
+
|
| 6 |
+

|
| 7 |
+
|
| 8 |
+
## Example Outputs
|
| 9 |
+
|
| 10 |
+
Example Outputs can be visualized using the [Antimatter WebGL viewer](https://antimatter15.com/splat/) at the following links.
|
| 11 |
+
|
| 12 |
+
- Broche: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/broche-rose-gold_original.splat) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/broche-rose-gold_style3.splat)
|
| 13 |
+
- Crystal Lamp: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/crystal-lamp-original.splat) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/crystal-lamp-style2.splat)
|
| 14 |
+
- Family Statue: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/family.ply) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/family-style-6.splat)
|
| 15 |
+
- M60 Tanks: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/m60.ply) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/m60-style-31.splat)
|
| 16 |
+
- Table: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/Table.ply) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/Table_style5.splat)
|
| 17 |
+
- Train: [Original](https://antimatter15.com/splat/) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/train_style1.splat)
|
| 18 |
+
- Truck: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/truck.splat) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/truck-style-21.splat)
|
| 19 |
+
|
| 20 |
+
Example inputs and outputs can be downloaded from [Google Drive](https://drive.google.com/drive/folders/10YmtcCOKGosXfPEi84ho1AfYRYioYo12?usp=drive_link) or [Hugging Face](https://huggingface.co/datasets/incrl/fast-splat-styler/tree/main)
|
| 21 |
+
|
| 22 |
+
## Demo
|
| 23 |
+
|
| 24 |
+
**Coming soon**
|
| 25 |
+
|
| 26 |
+
## Install
|
| 27 |
+
|
| 28 |
+
This work relies heavily on the [Pytorch](https://pytorch.org/) and [Pytorch Geometric](https://www.pyg.org/) libraries.
|
| 29 |
+
|
| 30 |
+
This code was tested with Python 3.12, Pytorch 2.9.1 (CUDA Toolkit 12.8), and Pytorch Geometric 2.8 on Windows. It was also tested on Mac with the same setup (without Cuda). The provided requirements.txt file comes from the Mac configuration.
|
| 31 |
+
|
| 32 |
+
This repository relies on a graph networks library that was presented in a [previous work](https://github.com/davidmhart/interpolated-selectionconv/tree/main). The library can be downloaded, with included model weights, at this [google drive link](https://drive.google.com/drive/folders/10YmtcCOKGosXfPEi84ho1AfYRYioYo12?usp=drive_link). Place the "graph_networks" folder in the main directory.
|
| 33 |
+
|
| 34 |
+
To stylize an example splat, run `python styletransfer_splat.py example-broche-rose-gold.splat --stylePath style_ims/style0.jpg --samplingRate 1.5`. You can change the input splat and style image to your specific use case.
|
| 35 |
+
|
| 36 |
+
Supports `.splat` and `.ply` files.
|
app.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
import tempfile
|
| 5 |
+
import shutil
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from time import time
|
| 8 |
+
|
| 9 |
+
# ── Core style-transfer logic (adapted from styletransfer_splat.py) ──────────
|
| 10 |
+
import pointCloudToMesh as ply2M
|
| 11 |
+
import utils
|
| 12 |
+
import graph_io as gio
|
| 13 |
+
from clusters import *
|
| 14 |
+
import splat_mesh_helpers as splt
|
| 15 |
+
import clusters as cl
|
| 16 |
+
from torch_geometric.data import Data
|
| 17 |
+
from scipy.interpolate import NearestNDInterpolator
|
| 18 |
+
|
| 19 |
+
from graph_networks.LinearStyleTransfer_vgg import encoder, decoder
|
| 20 |
+
from graph_networks.LinearStyleTransfer_matrix import TransformLayer
|
| 21 |
+
from graph_networks.LinearStyleTransfer.libs.Matrix import MulLayer
|
| 22 |
+
from graph_networks.LinearStyleTransfer.libs.models import encoder4, decoder4
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ── Example assets (place your own files in ./examples/) ─────────────────────
|
| 26 |
+
EXAMPLE_SPLATS = [
|
| 27 |
+
["example-broche-rose-gold.splat", "style_ims/style2.jpg"],
|
| 28 |
+
["example-broche-rose-gold.splat", "style_ims/style6.jpg"],
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ── Style-transfer function called by Gradio ─────────────────────────────────
|
| 33 |
+
def run_style_transfer(
|
| 34 |
+
splat_file,
|
| 35 |
+
style_image,
|
| 36 |
+
threshold: float,
|
| 37 |
+
sampling_rate: float,
|
| 38 |
+
device_choice: str,
|
| 39 |
+
progress=gr.Progress(track_tqdm=True),
|
| 40 |
+
):
|
| 41 |
+
if splat_file is None:
|
| 42 |
+
raise gr.Error("Please upload a 3D Gaussian Splat file (.ply or .splat).")
|
| 43 |
+
if style_image is None:
|
| 44 |
+
raise gr.Error("Please upload a style image.")
|
| 45 |
+
|
| 46 |
+
device = device_choice if device_choice == "cpu" else f"cuda:{device_choice}"
|
| 47 |
+
|
| 48 |
+
# ── Parameters ────────────────────────────────────────────────────────────
|
| 49 |
+
n = 25
|
| 50 |
+
ratio = 0.25
|
| 51 |
+
depth = 3
|
| 52 |
+
style_shape = (512, 512)
|
| 53 |
+
|
| 54 |
+
logs = []
|
| 55 |
+
|
| 56 |
+
def log(msg):
|
| 57 |
+
logs.append(msg)
|
| 58 |
+
print(msg)
|
| 59 |
+
return "\n".join(logs)
|
| 60 |
+
|
| 61 |
+
# ── 1. Load splat ─────────────────────────────────────────────────────────
|
| 62 |
+
progress(0.05, desc="Loading splat…")
|
| 63 |
+
splat_path = splat_file.name if hasattr(splat_file, "name") else splat_file
|
| 64 |
+
log(f"Loading splat: {splat_path}")
|
| 65 |
+
|
| 66 |
+
pos3D_Original, _, colors_Original, opacity_Original, scales_Original, rots_Original, fileType = \
|
| 67 |
+
splt.splat_unpacker_with_threshold(n, splat_path, threshold)
|
| 68 |
+
|
| 69 |
+
# ── 2. Gaussian super-sampling ────────────────────────────────────────────
|
| 70 |
+
progress(0.15, desc="Super-sampling…")
|
| 71 |
+
t0 = time()
|
| 72 |
+
if sampling_rate > 1:
|
| 73 |
+
GaussianSamples = int(pos3D_Original.shape[0] * sampling_rate)
|
| 74 |
+
pos3D, colors = splt.splat_GaussianSuperSampler(
|
| 75 |
+
pos3D_Original.clone(), colors_Original.clone(),
|
| 76 |
+
opacity_Original.clone(), scales_Original.clone(), rots_Original.clone(),
|
| 77 |
+
GaussianSamples,
|
| 78 |
+
)
|
| 79 |
+
else:
|
| 80 |
+
pos3D, colors = pos3D_Original, colors_Original
|
| 81 |
+
log(f"Nodes in graph: {pos3D.shape[0]} ({time()-t0:.1f}s)")
|
| 82 |
+
|
| 83 |
+
# ── 3. Graph construction ─────────────────────────────────────────────────
|
| 84 |
+
progress(0.30, desc="Building surface graph…")
|
| 85 |
+
t0 = time()
|
| 86 |
+
style_ref = utils.loadImage(style_image, shape=style_shape)
|
| 87 |
+
|
| 88 |
+
normalsNP = ply2M.Estimate_Normals(pos3D, threshold)
|
| 89 |
+
normals = torch.from_numpy(normalsNP)
|
| 90 |
+
|
| 91 |
+
up_vector = torch.tensor([[1, 1, 1]], dtype=torch.float)
|
| 92 |
+
up_vector = up_vector / torch.linalg.norm(up_vector, dim=1)
|
| 93 |
+
|
| 94 |
+
pos3D = pos3D.to(device)
|
| 95 |
+
colors = colors.to(device)
|
| 96 |
+
normals = normals.to(device)
|
| 97 |
+
up_vector = up_vector.to(device)
|
| 98 |
+
|
| 99 |
+
edge_index, directions = gh.surface2Edges(pos3D, normals, up_vector, k_neighbors=16)
|
| 100 |
+
edge_index, selections, interps = gh.edges2Selections(edge_index, directions, interpolated=True)
|
| 101 |
+
|
| 102 |
+
clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(
|
| 103 |
+
pos3D, normals, edge_index, selections, interps,
|
| 104 |
+
ratio=ratio, up_vector=up_vector, depth=depth, device=device,
|
| 105 |
+
)
|
| 106 |
+
log(f"Graph built ({time()-t0:.1f}s)")
|
| 107 |
+
|
| 108 |
+
# ── 4. Load networks ──────────────────────────────────────────────────────
|
| 109 |
+
progress(0.50, desc="Loading networks…")
|
| 110 |
+
t0 = time()
|
| 111 |
+
|
| 112 |
+
enc_ref = encoder4()
|
| 113 |
+
dec_ref = decoder4()
|
| 114 |
+
matrix_ref = MulLayer("r41")
|
| 115 |
+
|
| 116 |
+
enc_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/vgg_r41.pth", map_location=device))
|
| 117 |
+
dec_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/dec_r41.pth", map_location=device))
|
| 118 |
+
matrix_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/r41.pth", map_location=device))
|
| 119 |
+
|
| 120 |
+
enc = encoder(padding_mode="replicate")
|
| 121 |
+
dec = decoder(padding_mode="replicate")
|
| 122 |
+
matrix = TransformLayer()
|
| 123 |
+
|
| 124 |
+
with torch.no_grad():
|
| 125 |
+
enc.copy_weights(enc_ref)
|
| 126 |
+
dec.copy_weights(dec_ref)
|
| 127 |
+
matrix.copy_weights(matrix_ref)
|
| 128 |
+
|
| 129 |
+
content = Data(
|
| 130 |
+
x=colors, clusters=clusters,
|
| 131 |
+
edge_indexes=edge_indexes,
|
| 132 |
+
selections_list=selections_list,
|
| 133 |
+
interps_list=interps_list,
|
| 134 |
+
).to(device)
|
| 135 |
+
|
| 136 |
+
style, _ = gio.image2Graph(style_ref, depth=3, device=device)
|
| 137 |
+
|
| 138 |
+
enc = enc.to(device)
|
| 139 |
+
dec = dec.to(device)
|
| 140 |
+
matrix = matrix.to(device)
|
| 141 |
+
log(f"Networks loaded ({time()-t0:.1f}s)")
|
| 142 |
+
|
| 143 |
+
# ── 5. Style transfer ─────────────────────────────────────────────────────
|
| 144 |
+
progress(0.70, desc="Running style transfer…")
|
| 145 |
+
t0 = time()
|
| 146 |
+
|
| 147 |
+
with torch.no_grad():
|
| 148 |
+
cF = enc(content)
|
| 149 |
+
sF = enc(style)
|
| 150 |
+
feature, _ = matrix(
|
| 151 |
+
cF["r41"], sF["r41"],
|
| 152 |
+
content.edge_indexes[3], content.selections_list[3],
|
| 153 |
+
style.edge_indexes[3], style.selections_list[3],
|
| 154 |
+
content.interps_list[3] if hasattr(content, "interps_list") else None,
|
| 155 |
+
)
|
| 156 |
+
result = dec(feature, content).clamp(0, 1)
|
| 157 |
+
|
| 158 |
+
colors[:, 0:3] = result
|
| 159 |
+
log(f"Stylization done ({time()-t0:.1f}s)")
|
| 160 |
+
|
| 161 |
+
# ── 6. Interpolate back to original resolution ────────────────────────────
|
| 162 |
+
progress(0.88, desc="Interpolating back to original splat…")
|
| 163 |
+
t0 = time()
|
| 164 |
+
|
| 165 |
+
interp2 = NearestNDInterpolator(pos3D.cpu(), colors.cpu())
|
| 166 |
+
results_OriginalNP = interp2(pos3D_Original)
|
| 167 |
+
results_Original = torch.from_numpy(results_OriginalNP).to(torch.float32)
|
| 168 |
+
colors_and_opacity_Original = torch.cat(
|
| 169 |
+
(results_Original, opacity_Original.unsqueeze(1)), dim=1
|
| 170 |
+
)
|
| 171 |
+
log(f"Interpolation done ({time()-t0:.1f}s)")
|
| 172 |
+
|
| 173 |
+
# ── 7. Save output ────────────────────────────────────────────────────────
|
| 174 |
+
progress(0.95, desc="Saving output splat…")
|
| 175 |
+
suffix = ".splat" if fileType == "splat" else ".ply"
|
| 176 |
+
out_dir = tempfile.mkdtemp()
|
| 177 |
+
out_path = os.path.join(out_dir, f"stylized{suffix}")
|
| 178 |
+
|
| 179 |
+
splt.splat_save(
|
| 180 |
+
pos3D_Original.numpy(),
|
| 181 |
+
scales_Original.numpy(),
|
| 182 |
+
rots_Original.numpy(),
|
| 183 |
+
colors_and_opacity_Original.numpy(),
|
| 184 |
+
out_path,
|
| 185 |
+
fileType,
|
| 186 |
+
)
|
| 187 |
+
log(f"Saved to: {out_path}")
|
| 188 |
+
progress(1.0, desc="Done!")
|
| 189 |
+
|
| 190 |
+
return out_path, "\n".join(logs)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# ── Gradio UI ─────────────────────────────────────────────────────────────────
|
| 194 |
+
def build_ui():
|
| 195 |
+
available_devices = (
|
| 196 |
+
[str(i) for i in range(torch.cuda.device_count())] + ["cpu"]
|
| 197 |
+
if torch.cuda.is_available()
|
| 198 |
+
else ["cpu"]
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
with gr.Blocks(
|
| 202 |
+
title="3DGS Style Transfer",
|
| 203 |
+
theme=gr.themes.Soft(primary_hue="violet"),
|
| 204 |
+
css="""
|
| 205 |
+
#title { text-align: center; }
|
| 206 |
+
#subtitle { text-align: center; color: #666; margin-bottom: 1rem; }
|
| 207 |
+
.panel { border-radius: 12px; }
|
| 208 |
+
#run-btn { font-size: 1.1rem; }
|
| 209 |
+
""",
|
| 210 |
+
) as demo:
|
| 211 |
+
|
| 212 |
+
gr.Markdown("# 🎨 3D Gaussian Splat Style Transfer", elem_id="title")
|
| 213 |
+
gr.Markdown(
|
| 214 |
+
"Upload a 3DGS scene and a style image — the app will repaint the splat "
|
| 215 |
+
"with the artistic style of the image and give you a stylized splat to download. "
|
| 216 |
+
"After downloading, you can view your splat with an [online viewer](https://antimatter15.com/splat/).",
|
| 217 |
+
elem_id="subtitle",
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
with gr.Row():
|
| 221 |
+
# ── Left column: inputs ───────────────────────────────────────────
|
| 222 |
+
with gr.Column(scale=1, elem_classes="panel"):
|
| 223 |
+
gr.Markdown("### 📂 Inputs")
|
| 224 |
+
|
| 225 |
+
splat_input = gr.File(
|
| 226 |
+
label="3D Gaussian Splat (.ply or .splat)",
|
| 227 |
+
file_types=[".ply", ".splat"],
|
| 228 |
+
type="filepath",
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
style_input = gr.Image(
|
| 232 |
+
label="Style Image",
|
| 233 |
+
type="filepath",
|
| 234 |
+
height=240,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
with gr.Accordion("⚙️ Advanced Settings", open=False):
|
| 238 |
+
threshold_slider = gr.Slider(
|
| 239 |
+
minimum=90.0, maximum=100.0, value=99.8, step=0.1,
|
| 240 |
+
label="Opacity threshold (percentile)",
|
| 241 |
+
info="Points below this opacity percentile are removed.",
|
| 242 |
+
)
|
| 243 |
+
sampling_slider = gr.Slider(
|
| 244 |
+
minimum=0.5, maximum=3.0, value=1.5, step=0.1,
|
| 245 |
+
label="Gaussian super-sampling rate",
|
| 246 |
+
info="Values > 1 add extra samples; 1.0 = no super-sampling.",
|
| 247 |
+
)
|
| 248 |
+
device_radio = gr.Radio(
|
| 249 |
+
choices=available_devices,
|
| 250 |
+
value=available_devices[0],
|
| 251 |
+
label="Device",
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
run_btn = gr.Button("🚀 Run Style Transfer", variant="primary", elem_id="run-btn")
|
| 255 |
+
|
| 256 |
+
# ── Right column: outputs ─────────────────────────────────────────
|
| 257 |
+
with gr.Column(scale=1, elem_classes="panel"):
|
| 258 |
+
gr.Markdown("### 📥 Output")
|
| 259 |
+
|
| 260 |
+
output_file = gr.File(
|
| 261 |
+
label="Download Stylized Splat",
|
| 262 |
+
interactive=False,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
log_box = gr.Textbox(
|
| 266 |
+
label="Progress log",
|
| 267 |
+
lines=12,
|
| 268 |
+
max_lines=20,
|
| 269 |
+
interactive=False,
|
| 270 |
+
placeholder="Logs will appear here once processing starts…",
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
# ── Examples ─────────────────────────────────────────────────────────
|
| 274 |
+
example_splat_paths = [row[0] for row in EXAMPLE_SPLATS]
|
| 275 |
+
example_style_paths = [row[1] for row in EXAMPLE_SPLATS]
|
| 276 |
+
|
| 277 |
+
valid_examples = [
|
| 278 |
+
row for row in EXAMPLE_SPLATS
|
| 279 |
+
if os.path.exists(row[0]) and os.path.exists(row[1])
|
| 280 |
+
]
|
| 281 |
+
|
| 282 |
+
if valid_examples:
|
| 283 |
+
gr.Markdown("### 🖼️ Examples")
|
| 284 |
+
gr.Examples(
|
| 285 |
+
examples=valid_examples,
|
| 286 |
+
inputs=[splat_input, style_input],
|
| 287 |
+
label="Click an example to load it",
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
# ── Event wiring ──────────────────────────────────────────────────────
|
| 291 |
+
run_btn.click(
|
| 292 |
+
fn=run_style_transfer,
|
| 293 |
+
inputs=[splat_input, style_input, threshold_slider, sampling_slider, device_radio],
|
| 294 |
+
outputs=[output_file, log_box],
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
return demo
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
if __name__ == "__main__":
|
| 301 |
+
demo = build_ui()
|
| 302 |
+
demo.launch(share=False)
|
clusters.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch_scatter import scatter
|
| 3 |
+
from torch_geometric.nn.pool.consecutive import consecutive_cluster
|
| 4 |
+
from torch_geometric.utils import add_self_loops, add_remaining_self_loops, remove_self_loops
|
| 5 |
+
from torch_geometric.nn import fps, knn
|
| 6 |
+
from torch_sparse import coalesce
|
| 7 |
+
import graph_helpers as gh
|
| 8 |
+
import sphere_helpers as sh
|
| 9 |
+
import mesh_helpers as mh
|
| 10 |
+
|
| 11 |
+
import math
|
| 12 |
+
from math import pi,sqrt
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
from warnings import warn
|
| 16 |
+
|
| 17 |
+
def makeImageClusters(pos2D,Nx,Ny,edge_index,selections,depth=1,device='cpu',stride=2):
|
| 18 |
+
clusters = []
|
| 19 |
+
edge_indexes = [torch.clone(edge_index).to(device)]
|
| 20 |
+
selections_list = [torch.clone(selections).to(device)]
|
| 21 |
+
|
| 22 |
+
for _ in range(depth):
|
| 23 |
+
Nx = Nx//stride
|
| 24 |
+
Ny = Ny//stride
|
| 25 |
+
cx,cy = getGrid(pos2D,Nx,Ny)
|
| 26 |
+
cluster, pos2D = gridCluster(pos2D,cx,cy,Nx)
|
| 27 |
+
edge_index, selections = selectionAverage(cluster, edge_index, selections)
|
| 28 |
+
|
| 29 |
+
clusters.append(torch.clone(cluster).to(device))
|
| 30 |
+
edge_indexes.append(torch.clone(edge_index).to(device))
|
| 31 |
+
selections_list.append(torch.clone(selections).to(device))
|
| 32 |
+
|
| 33 |
+
return clusters, edge_indexes, selections_list
|
| 34 |
+
|
| 35 |
+
def makeSphereClusters(pos3D,edge_index,selections,interps,rows,cols,cluster_method="layering",stride=2,bary_d=None,depth=1,device='cpu'):
|
| 36 |
+
clusters = []
|
| 37 |
+
edge_indexes = [torch.clone(edge_index).to(device)]
|
| 38 |
+
selections_list = [torch.clone(selections).to(device)]
|
| 39 |
+
interps_list = [torch.clone(interps).to(device)]
|
| 40 |
+
|
| 41 |
+
for _ in range(depth):
|
| 42 |
+
|
| 43 |
+
rows = rows//stride
|
| 44 |
+
cols = cols//stride
|
| 45 |
+
|
| 46 |
+
if bary_d is not None:
|
| 47 |
+
bary_d = bary_d*stride
|
| 48 |
+
|
| 49 |
+
if cluster_method == "equirec":
|
| 50 |
+
centroids, _ = sh.sampleSphere_Equirec(rows,cols)
|
| 51 |
+
|
| 52 |
+
elif cluster_method == "layering":
|
| 53 |
+
centroids, _ = sh.sampleSphere_Layering(rows)
|
| 54 |
+
|
| 55 |
+
elif cluster_method == "spiral":
|
| 56 |
+
centroids, _ = sh.sampleSphere_Spiral(rows,cols)
|
| 57 |
+
|
| 58 |
+
elif cluster_method == "icosphere":
|
| 59 |
+
centroids, _ = sh.sampleSphere_Icosphere(rows)
|
| 60 |
+
|
| 61 |
+
elif cluster_method == "random":
|
| 62 |
+
centroids, _ = sh.sampleSphere_Random(rows,cols)
|
| 63 |
+
|
| 64 |
+
elif cluster_method == "random_nodes":
|
| 65 |
+
index = torch.multinomial(torch.ones(len(pos3D)),N) # close equivalent to np.random.choice
|
| 66 |
+
centroids = pos3D[index]
|
| 67 |
+
|
| 68 |
+
elif cluster_method == "fps":
|
| 69 |
+
# Farthest Point Search used in PointNet++
|
| 70 |
+
index = fps(pos3D, ratio=ratio)
|
| 71 |
+
centroids = pos3D[index]
|
| 72 |
+
else:
|
| 73 |
+
raise ValueError("Sphere cluster_method unknown")
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Find closest centriod to each current point
|
| 77 |
+
cluster = knn(centroids,pos3D,1)[1]
|
| 78 |
+
cluster, _ = consecutive_cluster(cluster)
|
| 79 |
+
pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')
|
| 80 |
+
|
| 81 |
+
# Regenerate surface graph
|
| 82 |
+
normals = pos3D/torch.linalg.norm(pos3D,dim=1,keepdims=True) # Make sure normals are unit vectors
|
| 83 |
+
edge_index,directions = gh.surface2Edges(pos3D,normals)
|
| 84 |
+
edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True,bary_d=bary_d)
|
| 85 |
+
|
| 86 |
+
clusters.append(torch.clone(cluster).to(device))
|
| 87 |
+
edge_indexes.append(torch.clone(edge_index).to(device))
|
| 88 |
+
selections_list.append(torch.clone(selections).to(device))
|
| 89 |
+
interps_list.append(torch.clone(interps).to(device))
|
| 90 |
+
|
| 91 |
+
return clusters, edge_indexes, selections_list, interps_list
|
| 92 |
+
|
| 93 |
+
def makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,cluster_method="random",ratio=.25,up_vector=None,depth=1,device='cpu'):
|
| 94 |
+
clusters = []
|
| 95 |
+
edge_indexes = [torch.clone(edge_index).to(device)]
|
| 96 |
+
selections_list = [torch.clone(selections).to(device)]
|
| 97 |
+
interps_list = [torch.clone(interps).to(device)]
|
| 98 |
+
|
| 99 |
+
for _ in range(depth):
|
| 100 |
+
|
| 101 |
+
#Desired number of clusters in the next level
|
| 102 |
+
N = int(len(pos3D) * ratio)
|
| 103 |
+
|
| 104 |
+
if cluster_method == "random":
|
| 105 |
+
index = torch.multinomial(torch.ones(len(pos3D)),N) # close equivalent to np.random.choice
|
| 106 |
+
centroids = pos3D[index]
|
| 107 |
+
|
| 108 |
+
elif cluster_method == "fps":
|
| 109 |
+
# Farthest Point Search used in PointNet++
|
| 110 |
+
index = fps(pos3D, ratio=ratio)
|
| 111 |
+
centroids = pos3D[index]
|
| 112 |
+
|
| 113 |
+
# Find closest centriod to each current point
|
| 114 |
+
cluster = knn(centroids,pos3D,1)[1]
|
| 115 |
+
cluster, _ = consecutive_cluster(cluster)
|
| 116 |
+
pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')
|
| 117 |
+
normals = scatter(normals, cluster, dim=0, reduce='mean')
|
| 118 |
+
|
| 119 |
+
# Regenerate surface graph
|
| 120 |
+
normals = normals/torch.linalg.norm(normals,dim=1,keepdims=True) # Make sure normals are unit vectors
|
| 121 |
+
edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16)
|
| 122 |
+
edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)
|
| 123 |
+
|
| 124 |
+
clusters.append(torch.clone(cluster).to(device))
|
| 125 |
+
edge_indexes.append(torch.clone(edge_index).to(device))
|
| 126 |
+
selections_list.append(torch.clone(selections).to(device))
|
| 127 |
+
interps_list.append(torch.clone(interps).to(device))
|
| 128 |
+
|
| 129 |
+
return clusters, edge_indexes, selections_list, interps_list
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=.25,up_vector=None,depth=1,device='cpu'):
|
| 133 |
+
clusters = []
|
| 134 |
+
edge_indexes = [torch.clone(edge_index).to(device)]
|
| 135 |
+
selections_list = [torch.clone(selections).to(device)]
|
| 136 |
+
interps_list = [torch.clone(interps).to(device)]
|
| 137 |
+
|
| 138 |
+
for _ in range(depth):
|
| 139 |
+
|
| 140 |
+
#Desired number of clusters in the next level
|
| 141 |
+
N = int(len(pos3D) * ratio)
|
| 142 |
+
|
| 143 |
+
# Generate new point cloud from downsampled version of texture map
|
| 144 |
+
centroids, normals = mh.sampleSurface(mesh,N,return_x=False)
|
| 145 |
+
|
| 146 |
+
# Find closest centriod to each current point
|
| 147 |
+
cluster = knn(centroids,pos3D,1)[1]
|
| 148 |
+
cluster, _ = consecutive_cluster(cluster)
|
| 149 |
+
pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')
|
| 150 |
+
|
| 151 |
+
# Regenerate surface graph
|
| 152 |
+
#normals = normals/torch.linalg.norm(normals,dim=1,keepdims=True) # Make sure normals are unit vectors
|
| 153 |
+
edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector)
|
| 154 |
+
edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)
|
| 155 |
+
|
| 156 |
+
clusters.append(torch.clone(cluster).to(device))
|
| 157 |
+
edge_indexes.append(torch.clone(edge_index).to(device))
|
| 158 |
+
selections_list.append(torch.clone(selections).to(device))
|
| 159 |
+
interps_list.append(torch.clone(interps).to(device))
|
| 160 |
+
|
| 161 |
+
return clusters, edge_indexes, selections_list, interps_list
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def getGrid(pos,Nx,Ny,xrange=None,yrange=None):
|
| 166 |
+
xmin = torch.min(pos[:,0]) if xrange is None else xrange[0]
|
| 167 |
+
ymin = torch.min(pos[:,1]) if yrange is None else yrange[0]
|
| 168 |
+
xmax = torch.max(pos[:,0]) if xrange is None else xrange[1]
|
| 169 |
+
ymax = torch.max(pos[:,1]) if yrange is None else yrange[1]
|
| 170 |
+
|
| 171 |
+
cx = torch.clamp(torch.floor((pos[:,0] - xmin)/(xmax-xmin) * Nx),0,Nx-1)
|
| 172 |
+
cy = torch.clamp(torch.floor((pos[:,1] - ymin)/(ymax-ymin) * Ny),0,Ny-1)
|
| 173 |
+
return cx, cy
|
| 174 |
+
|
| 175 |
+
def gridCluster(pos,cx,cy,xmax):
|
| 176 |
+
cluster = cx + cy*xmax
|
| 177 |
+
cluster = cluster.type(torch.long) # Cast appropriately
|
| 178 |
+
cluster, _ = consecutive_cluster(cluster)
|
| 179 |
+
pos = scatter(pos, cluster, dim=0, reduce='mean')
|
| 180 |
+
|
| 181 |
+
return cluster, pos
|
| 182 |
+
|
| 183 |
+
def selectionAverage(cluster, edge_index, selections):
|
| 184 |
+
num_nodes = cluster.size(0)
|
| 185 |
+
edge_index = cluster[edge_index.contiguous().view(1, -1)].view(2, -1)
|
| 186 |
+
edge_index, selections = remove_self_loops(edge_index, selections)
|
| 187 |
+
if edge_index.numel() > 0:
|
| 188 |
+
|
| 189 |
+
# To avoid means over discontinuities, do mean for two selections at at a time
|
| 190 |
+
final_edge_index, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean")
|
| 191 |
+
selections_check = torch.round(selections_check).type(torch.long)
|
| 192 |
+
|
| 193 |
+
final_selections = torch.zeros_like(selections_check).to(selections.device)
|
| 194 |
+
|
| 195 |
+
final_selections[torch.where(selections_check==4)] = 4
|
| 196 |
+
final_selections[torch.where(selections_check==5)] = 5
|
| 197 |
+
|
| 198 |
+
#Rotate selection kernel
|
| 199 |
+
selections += 2
|
| 200 |
+
selections = selections % 9 + torch.div(selections, 9, rounding_mode="floor")
|
| 201 |
+
|
| 202 |
+
_, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean")
|
| 203 |
+
selections_check = torch.round(selections_check).type(torch.long)
|
| 204 |
+
final_selections[torch.where(selections_check==4)] = 2
|
| 205 |
+
final_selections[torch.where(selections_check==5)] = 3
|
| 206 |
+
|
| 207 |
+
#Rotate selection kernel
|
| 208 |
+
selections += 2
|
| 209 |
+
selections = selections % 9 + torch.div(selections, 9, rounding_mode="floor")
|
| 210 |
+
|
| 211 |
+
_, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean")
|
| 212 |
+
selections_check = torch.round(selections_check).type(torch.long)
|
| 213 |
+
final_selections[torch.where(selections_check==4)] = 8
|
| 214 |
+
final_selections[torch.where(selections_check==5)] = 1
|
| 215 |
+
|
| 216 |
+
#Rotate selection kernel
|
| 217 |
+
selections += 2
|
| 218 |
+
selections = selections % 9 + torch.div(selections, 9, rounding_mode="floor")
|
| 219 |
+
|
| 220 |
+
_, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean")
|
| 221 |
+
selections_check = torch.round(selections_check).type(torch.long)
|
| 222 |
+
final_selections[torch.where(selections_check==4)] = 6
|
| 223 |
+
final_selections[torch.where(selections_check==5)] = 7
|
| 224 |
+
|
| 225 |
+
#print(torch.min(final_selections), torch.max(final_selections))
|
| 226 |
+
#print(torch.mean(final_selections.type(torch.float)))
|
| 227 |
+
|
| 228 |
+
edge_index, selections = add_remaining_self_loops(final_edge_index,final_selections,fill_value=torch.tensor(0,dtype=torch.long))
|
| 229 |
+
|
| 230 |
+
else:
|
| 231 |
+
edge_index, selections = add_remaining_self_loops(edge_index,selections,fill_value=torch.tensor(0,dtype=torch.long))
|
| 232 |
+
print("Warning: Edge Pool found no edges")
|
| 233 |
+
|
| 234 |
+
return edge_index, selections
|
example-broche-rose-gold.splat
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:62cccf0adfb4e5713985e8ee874fa30a6a37ae222a23ead7c6e4639a5802ab62
|
| 3 |
+
size 4157728
|
example.jpg
ADDED
|
graph_helpers.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch_geometric.nn import radius_graph, knn_graph
|
| 3 |
+
import torch_geometric as tg
|
| 4 |
+
from torch_geometric.utils import subgraph
|
| 5 |
+
import utils
|
| 6 |
+
from math import sqrt
|
| 7 |
+
|
| 8 |
+
def getImPos(rows,cols,start_row=0,start_col=0):
|
| 9 |
+
row_space = torch.arange(start_row,rows+start_row)
|
| 10 |
+
col_space = torch.arange(start_col,cols+start_col)
|
| 11 |
+
col_image,row_image = torch.meshgrid(col_space,row_space,indexing='xy')
|
| 12 |
+
im_pos = torch.reshape(torch.stack((row_image,col_image),dim=-1),(rows*cols,2))
|
| 13 |
+
return im_pos
|
| 14 |
+
|
| 15 |
+
def convertImPos(im_pos,flip_y=True):
|
| 16 |
+
|
| 17 |
+
# Cast to float for clustering based methods
|
| 18 |
+
pos2D = im_pos.float()
|
| 19 |
+
|
| 20 |
+
# Switch rows,cols to x,y
|
| 21 |
+
pos2D[:,[1,0]] = pos2D[:,[0,1]]
|
| 22 |
+
|
| 23 |
+
if flip_y:
|
| 24 |
+
|
| 25 |
+
# Flip to y-axis to match mathematical definition and edges2Selections settings
|
| 26 |
+
pos2D[:,1] = torch.amax(pos2D[:,1]) - pos2D[:,1]
|
| 27 |
+
|
| 28 |
+
return pos2D
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def grid2Edges(locs):
|
| 32 |
+
# Assume locs are already spaced at a distance of 1 structure
|
| 33 |
+
edge_index = radius_graph(locs,1.44,loop=True)
|
| 34 |
+
return edge_index
|
| 35 |
+
|
| 36 |
+
def radius2Edges(locs,r=1.0):
|
| 37 |
+
edge_index = radius_graph(locs,r,loop=True)
|
| 38 |
+
return edge_index
|
| 39 |
+
|
| 40 |
+
def knn2Edges(locs,knn=9):
|
| 41 |
+
edge_index = knn_graph(locs,knn,loop=True)
|
| 42 |
+
return edge_index
|
| 43 |
+
|
| 44 |
+
def surface2Edges(pos3D,normals,up_vector=None,k_neighbors=9):
|
| 45 |
+
|
| 46 |
+
if up_vector is None:
|
| 47 |
+
up_vector = torch.tensor([[0.0,1.0,0.0]]).to(pos3D.device)
|
| 48 |
+
|
| 49 |
+
# K Nearest Neighbors graph
|
| 50 |
+
edge_index = knn_graph(pos3D,k_neighbors,loop=True)
|
| 51 |
+
|
| 52 |
+
# Cull neighbors based on normals (dot them together)
|
| 53 |
+
culling = torch.sum(torch.multiply(normals[edge_index[1]],normals[edge_index[0]]),dim=1)
|
| 54 |
+
edge_index = edge_index[:,torch.where(culling>0)[0]]
|
| 55 |
+
|
| 56 |
+
# For each node, rotate based on Grahm-Schmidt Orthognalization
|
| 57 |
+
norms = normals[edge_index[0]]
|
| 58 |
+
|
| 59 |
+
z_dir = norms
|
| 60 |
+
z_dir = z_dir/torch.linalg.norm(z_dir,dim=1,keepdims=True) # Make sure it is a unit vector
|
| 61 |
+
#x_dir = torch.cross(up_vector,norms,dim=1)
|
| 62 |
+
x_dir = utils.cross(up_vector,norms) # torch.cross doesn't broadcast properly in some versions of torch
|
| 63 |
+
x_dir = x_dir/torch.linalg.norm(x_dir,dim=1,keepdims=True)
|
| 64 |
+
#y_dir = torch.cross(norms,x_dir,dim=1)
|
| 65 |
+
y_dir = utils.cross(norms,x_dir)
|
| 66 |
+
y_dir = y_dir/torch.linalg.norm(y_dir,dim=1,keepdims=True)
|
| 67 |
+
|
| 68 |
+
directions = (pos3D[edge_index[1]] - pos3D[edge_index[0]])
|
| 69 |
+
|
| 70 |
+
# Perform rotation by multiplying out rotation matrix
|
| 71 |
+
temp = torch.clone(directions) # Buffer
|
| 72 |
+
directions[:,0] = temp[:,0] * x_dir[:,0] + temp[:,1] * x_dir[:,1] + temp[:,2] * x_dir[:,2]
|
| 73 |
+
directions[:,1] = temp[:,0] * y_dir[:,0] + temp[:,1] * y_dir[:,1] + temp[:,2] * y_dir[:,2]
|
| 74 |
+
#directions[:,2] = temp[:,0] * z_dir[:,0] + temp[:,1] * z_dir[:,1] + temp[:,2] * z_dir[:,2]
|
| 75 |
+
|
| 76 |
+
# Drop z coordinate
|
| 77 |
+
directions = directions[:,:2]
|
| 78 |
+
|
| 79 |
+
return edge_index, directions
|
| 80 |
+
|
| 81 |
+
def edges2Selections(edge_index,directions,interpolated=True,bary_d=None,y_down=False):
|
| 82 |
+
|
| 83 |
+
# Current Ordering
|
| 84 |
+
# 4 3 2
|
| 85 |
+
# 5 0 1
|
| 86 |
+
# 6 7 8
|
| 87 |
+
if y_down:
|
| 88 |
+
vectorList = torch.tensor([[1,0],[sqrt(2)/2,-sqrt(2)/2],[0,-1],[-sqrt(2)/2,-sqrt(2)/2],[-1,0],[-sqrt(2)/2,sqrt(2)/2],[0,1],[sqrt(2)/2,sqrt(2)/2]],dtype=torch.float).transpose(1,0)
|
| 89 |
+
else:
|
| 90 |
+
vectorList = torch.tensor([[1,0],[sqrt(2)/2,sqrt(2)/2],[0,1],[-sqrt(2)/2,sqrt(2)/2],[-1,0],[-sqrt(2)/2,-sqrt(2)/2],[0,-1],[sqrt(2)/2,-sqrt(2)/2]],dtype=torch.float).transpose(1,0)
|
| 91 |
+
|
| 92 |
+
if interpolated:
|
| 93 |
+
|
| 94 |
+
if bary_d is None:
|
| 95 |
+
edge_index,selections,interps = interpolateSelections(edge_index,directions,vectorList)
|
| 96 |
+
else:
|
| 97 |
+
edge_index,selections,interps = interpolateSelections_barycentric(edge_index,directions,bary_d,vectorList)
|
| 98 |
+
interps = normalizeEdges(edge_index,selections,interps)
|
| 99 |
+
return edge_index,selections,interps
|
| 100 |
+
|
| 101 |
+
else:
|
| 102 |
+
selections = torch.argmax(torch.matmul(directions,vectorList),dim=1) + 1
|
| 103 |
+
selections[torch.where(torch.sum(torch.abs(directions),axis=1) == 0)] = 0 # Same cell selection
|
| 104 |
+
return selections
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def makeEdges(prev_sources,prev_targets,prev_selections,sources,targets,selection,reverse=True):
|
| 108 |
+
|
| 109 |
+
sources = sources.flatten()
|
| 110 |
+
targets = targets.flatten()
|
| 111 |
+
|
| 112 |
+
prev_sources += sources.tolist()
|
| 113 |
+
prev_targets += targets.tolist()
|
| 114 |
+
prev_selections += len(sources)*[selection]
|
| 115 |
+
|
| 116 |
+
if reverse:
|
| 117 |
+
prev_sources += targets
|
| 118 |
+
prev_targets += sources
|
| 119 |
+
prev_selections += len(sources)*[utils.reverse_selection(selection)]
|
| 120 |
+
|
| 121 |
+
return prev_sources,prev_targets,prev_selections
|
| 122 |
+
|
| 123 |
+
def maskNodes(mask,x):
|
| 124 |
+
node_mask = torch.where(mask)
|
| 125 |
+
x = x[node_mask]
|
| 126 |
+
return x
|
| 127 |
+
|
| 128 |
+
def maskPoints(mask,x,y):
|
| 129 |
+
|
| 130 |
+
mask = torch.squeeze(mask)
|
| 131 |
+
|
| 132 |
+
x0 = torch.floor(x).long()
|
| 133 |
+
x1 = x0 + 1
|
| 134 |
+
y0 = torch.floor(y).long()
|
| 135 |
+
y1 = y0 + 1
|
| 136 |
+
|
| 137 |
+
x0 = torch.clip(x0, 0, mask.shape[1]-1);
|
| 138 |
+
x1 = torch.clip(x1, 0, mask.shape[1]-1);
|
| 139 |
+
y0 = torch.clip(y0, 0, mask.shape[0]-1);
|
| 140 |
+
y1 = torch.clip(y1, 0, mask.shape[0]-1);
|
| 141 |
+
|
| 142 |
+
Ma = mask[ y0, x0 ]
|
| 143 |
+
Mb = mask[ y1, x0 ]
|
| 144 |
+
Mc = mask[ y0, x1 ]
|
| 145 |
+
Md = mask[ y1, x1 ]
|
| 146 |
+
|
| 147 |
+
node_mask = torch.where(torch.logical_and(torch.logical_and(torch.logical_and(Ma,Mb),Mc),Md))[0]
|
| 148 |
+
|
| 149 |
+
return node_mask
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def maskGraph(mask,edge_index,selections,interps=None):
|
| 153 |
+
|
| 154 |
+
edge_index,_,edge_mask = subgraph(mask,edge_index,relabel_nodes=True,return_edge_mask=True)
|
| 155 |
+
selections = selections[edge_mask]
|
| 156 |
+
|
| 157 |
+
if interps:
|
| 158 |
+
interps = interps[edge_mask]
|
| 159 |
+
return edge_index, selections, interps
|
| 160 |
+
else:
|
| 161 |
+
return edge_index, selections
|
| 162 |
+
|
| 163 |
+
def interpolateSelections(edge_index,directions,vectorList=None):
|
| 164 |
+
|
| 165 |
+
if vectorList is None:
|
| 166 |
+
# Current Ordering
|
| 167 |
+
# 4 3 2
|
| 168 |
+
# 5 0 1
|
| 169 |
+
# 6 7 8
|
| 170 |
+
vectorList = torch.tensor([[1,0],[sqrt(2)/2,sqrt(2)/2],[0,1],[-sqrt(2)/2,sqrt(2)/2],[-1,0],[-sqrt(2)/2,-sqrt(2)/2],[0,-1],[sqrt(2)/2,-sqrt(2)/2]],dtype=torch.float).transpose(1,0)
|
| 171 |
+
|
| 172 |
+
# Normalize directions for simplicity of calculations
|
| 173 |
+
dir_norm = torch.linalg.norm(directions,dim=1,keepdims=True)
|
| 174 |
+
directions = directions/dir_norm
|
| 175 |
+
#locs = torch.where(dir_norm > 1)[0]
|
| 176 |
+
#directions[locs] = directions[locs]/dir_norm[locs]
|
| 177 |
+
|
| 178 |
+
values = torch.matmul(directions,vectorList)
|
| 179 |
+
best = torch.unsqueeze(torch.argmax(values,dim=1),1)
|
| 180 |
+
|
| 181 |
+
best_val = torch.take_along_dim(values,best,dim=1)
|
| 182 |
+
|
| 183 |
+
# Look at both neighbors to see who is closer
|
| 184 |
+
lower_val = torch.take_along_dim(values,(best-1) % 8,dim=1)
|
| 185 |
+
upper_val = torch.take_along_dim(values,(best+1) % 8,dim=1)
|
| 186 |
+
|
| 187 |
+
comp_vals = torch.cat((lower_val,upper_val),dim=1)
|
| 188 |
+
|
| 189 |
+
second_best_vals = torch.amax(comp_vals,dim=1)
|
| 190 |
+
second_best = torch.argmax(comp_vals,dim=1)
|
| 191 |
+
|
| 192 |
+
# Find the interpolation value (in terms of angles)
|
| 193 |
+
best_val = torch.minimum(best_val[:,0],torch.tensor(1,device=directions.device)) # Prep for arccos function
|
| 194 |
+
angle_best = torch.arccos(best_val)
|
| 195 |
+
angle_second_best = torch.arccos(second_best_vals)
|
| 196 |
+
|
| 197 |
+
angle_vals = angle_best/(angle_second_best + angle_best)
|
| 198 |
+
|
| 199 |
+
# Use negative values for clockwise selections
|
| 200 |
+
clockwise = torch.where(second_best == 0)[0]
|
| 201 |
+
angle_vals[clockwise] = -angle_vals[clockwise]
|
| 202 |
+
|
| 203 |
+
# Handle computation problems at the poles
|
| 204 |
+
angle_vals = torch.nan_to_num(angle_vals)
|
| 205 |
+
|
| 206 |
+
# Make Selections
|
| 207 |
+
selections = best[:,0] + 1
|
| 208 |
+
|
| 209 |
+
# Same cell selection
|
| 210 |
+
same_locs = torch.where(edge_index[0] == edge_index[1])
|
| 211 |
+
selections[same_locs] = 0
|
| 212 |
+
angle_vals[same_locs] = 0
|
| 213 |
+
|
| 214 |
+
# Make starting interp_values
|
| 215 |
+
interps = torch.ones_like(angle_vals)
|
| 216 |
+
interps -= torch.abs(angle_vals)
|
| 217 |
+
|
| 218 |
+
# Add new edges
|
| 219 |
+
pos_interp_locs = torch.where(angle_vals > 1e-2)[0]
|
| 220 |
+
pos_interps = angle_vals[pos_interp_locs]
|
| 221 |
+
pos_edges = edge_index[:,pos_interp_locs]
|
| 222 |
+
pos_selections = selections[pos_interp_locs] + 1
|
| 223 |
+
pos_selections[torch.where(pos_selections>8)] = 1 # Account for wrap around
|
| 224 |
+
|
| 225 |
+
neg_interp_locs = torch.where(angle_vals < -1e-2)[0]
|
| 226 |
+
neg_interps = torch.abs(angle_vals[neg_interp_locs])
|
| 227 |
+
neg_edges = edge_index[:,neg_interp_locs]
|
| 228 |
+
neg_selections = selections[neg_interp_locs] - 1
|
| 229 |
+
neg_selections[torch.where(neg_selections<1)] = 8 # Account for wrap around
|
| 230 |
+
|
| 231 |
+
edge_index = torch.cat((edge_index,pos_edges,neg_edges),dim=1)
|
| 232 |
+
selections = torch.cat((selections,pos_selections,neg_selections),dim=0)
|
| 233 |
+
interps = torch.cat((interps,pos_interps,neg_interps),dim=0)
|
| 234 |
+
|
| 235 |
+
return edge_index,selections,interps
|
| 236 |
+
|
| 237 |
+
def interpolateSelections_barycentric(edge_index,directions,d,vectorList=None):
|
| 238 |
+
|
| 239 |
+
if vectorList is None:
|
| 240 |
+
# Current Ordering
|
| 241 |
+
# 4 3 2
|
| 242 |
+
# 5 0 1
|
| 243 |
+
# 6 7 8
|
| 244 |
+
vectorList = torch.tensor([[1,0],[sqrt(2)/2,-sqrt(2)/2],[0,-1],[-sqrt(2)/2,-sqrt(2)/2],[-1,0],[-sqrt(2)/2,sqrt(2)/2],[0,1],[sqrt(2)/2,sqrt(2)/2]],dtype=torch.float).transpose(1,0).to(directions.device)
|
| 245 |
+
|
| 246 |
+
# Preprune central selections and reappend them at the end
|
| 247 |
+
same_locs = torch.where(edge_index[0] == edge_index[1])
|
| 248 |
+
same_edges = edge_index[:,same_locs[0]]
|
| 249 |
+
|
| 250 |
+
different_locs = torch.where(edge_index[0] != edge_index[1])
|
| 251 |
+
edge_index = edge_index[:,different_locs[0]]
|
| 252 |
+
directions = directions[different_locs[0]]
|
| 253 |
+
|
| 254 |
+
# Normalize directions for simplicity of calculations
|
| 255 |
+
dir_norm = torch.linalg.norm(directions,dim=1,keepdims=True)
|
| 256 |
+
unit_directions = directions/dir_norm
|
| 257 |
+
#locs = torch.where(dir_norm > 1)[0]
|
| 258 |
+
#directions[locs] = directions[locs]/dir_norm[locs]
|
| 259 |
+
|
| 260 |
+
values = torch.matmul(unit_directions,vectorList)
|
| 261 |
+
best = torch.unsqueeze(torch.argmax(values,dim=1),1)
|
| 262 |
+
#best_val = torch.take_along_dim(values,best,dim=1)
|
| 263 |
+
|
| 264 |
+
# Look at both neighbors to see who is closer
|
| 265 |
+
lower_val = torch.take_along_dim(values,(best-1) % 8,dim=1)
|
| 266 |
+
upper_val = torch.take_along_dim(values,(best+1) % 8,dim=1)
|
| 267 |
+
|
| 268 |
+
comp_vals = torch.cat((lower_val,upper_val),dim=1)
|
| 269 |
+
|
| 270 |
+
second_best = torch.argmax(comp_vals,dim=1)
|
| 271 |
+
#second_best_vals = torch.amax(comp_vals,dim=1)
|
| 272 |
+
|
| 273 |
+
# Convert into uv cooridnates for barycentric interpolation calculation
|
| 274 |
+
# /|
|
| 275 |
+
# / |v
|
| 276 |
+
# /__|
|
| 277 |
+
# u
|
| 278 |
+
|
| 279 |
+
scaled_directions = torch.abs(directions/d)
|
| 280 |
+
u = torch.amax(scaled_directions,dim=1)
|
| 281 |
+
v = torch.amin(scaled_directions,dim=1)
|
| 282 |
+
|
| 283 |
+
# Force coordinates to be within the triangle
|
| 284 |
+
boundary_check = torch.where(u > d)
|
| 285 |
+
v[boundary_check] /= u[boundary_check]
|
| 286 |
+
u[boundary_check] = 1.0
|
| 287 |
+
|
| 288 |
+
# Precalculated barycentric values from linear matrix solve
|
| 289 |
+
I0 = 1 - u
|
| 290 |
+
I1 = u - v
|
| 291 |
+
I2 = v
|
| 292 |
+
|
| 293 |
+
# Make first selections and proper interps
|
| 294 |
+
selections = best[:,0] + 1
|
| 295 |
+
interps = I1
|
| 296 |
+
even_sels = torch.where(selections % 2 == 0)
|
| 297 |
+
interps[even_sels] = I2[even_sels] # Corners get different weights
|
| 298 |
+
|
| 299 |
+
# Make new edges for the central selections
|
| 300 |
+
central_edges = torch.clone(edge_index).to(edge_index.device)
|
| 301 |
+
central_selections = torch.zeros_like(selections)
|
| 302 |
+
central_interps = I0
|
| 303 |
+
|
| 304 |
+
# Make new edges for the last selection
|
| 305 |
+
pos_locs = torch.where(second_best==1)[0]
|
| 306 |
+
pos_edges = edge_index[:,pos_locs]
|
| 307 |
+
pos_selections = selections[pos_locs] + 1
|
| 308 |
+
pos_selections[torch.where(pos_selections>8)] = 1 #Account for wrap around
|
| 309 |
+
pos_interps = I1[pos_locs]
|
| 310 |
+
even_sels = torch.where(pos_selections % 2 == 0)
|
| 311 |
+
pos_interps[even_sels] = I2[pos_locs][even_sels]
|
| 312 |
+
|
| 313 |
+
neg_locs = torch.where(second_best==0)[0]
|
| 314 |
+
neg_edges = edge_index[:,neg_locs]
|
| 315 |
+
neg_selections = selections[neg_locs] - 1
|
| 316 |
+
neg_selections[torch.where(neg_selections<1)] = 8 # Account for wrap around
|
| 317 |
+
neg_interps = I1[neg_locs]
|
| 318 |
+
even_sels = torch.where(neg_selections % 2 == 0)
|
| 319 |
+
neg_interps[even_sels] = I2[neg_locs][even_sels]
|
| 320 |
+
|
| 321 |
+
# Account for the previously pruned same node edges
|
| 322 |
+
same_selections = torch.zeros(same_edges.shape[1],dtype=torch.long)
|
| 323 |
+
same_interps = torch.ones(same_edges.shape[1],dtype=torch.float)
|
| 324 |
+
|
| 325 |
+
# Combine
|
| 326 |
+
edge_index = torch.cat((edge_index,central_edges,pos_edges,neg_edges,same_edges),dim=1)
|
| 327 |
+
selections = torch.cat((selections,central_selections,pos_selections,neg_selections,same_selections),dim=0)
|
| 328 |
+
interps = torch.cat((interps,central_interps,pos_interps,neg_interps,same_interps),dim=0)
|
| 329 |
+
|
| 330 |
+
#edge_index = torch.cat((edge_index,central_edges,pos_edges,neg_edges),dim=1)
|
| 331 |
+
#selections = torch.cat((selections,central_selections,pos_selections,neg_selections),dim=0)
|
| 332 |
+
#interps = torch.cat((interps,central_interps,pos_interps,neg_interps),dim=0)
|
| 333 |
+
|
| 334 |
+
# Account for edges to the same node
|
| 335 |
+
#same_locs = torch.where(edge_index[0] == edge_index[1])
|
| 336 |
+
#selections[same_locs] = 0
|
| 337 |
+
#interps[same_locs] = 1
|
| 338 |
+
|
| 339 |
+
return edge_index,selections,interps
|
| 340 |
+
|
| 341 |
+
def normalizeEdges(edge_index,selections,interps=None,kernel_norm=False):
|
| 342 |
+
'''Given an edge_index and selections, normalize the edges for each node so that
|
| 343 |
+
aggregation of edges with interps = 1. If interps is given, use a weighted average.
|
| 344 |
+
if kernel_norm = True, account for missing selections by increasing weight on other selections.'''
|
| 345 |
+
|
| 346 |
+
N = torch.max(edge_index) + 1
|
| 347 |
+
S = torch.max(selections) + 1
|
| 348 |
+
|
| 349 |
+
total_weight = torch.zeros((N,S),dtype=torch.float).to(edge_index.device)
|
| 350 |
+
|
| 351 |
+
if interps is None:
|
| 352 |
+
interps = torch.ones(len(selections),dtype=torch.float).to(edge_index.device)
|
| 353 |
+
|
| 354 |
+
# Aggregate all edges to determine normalizations per selection
|
| 355 |
+
nodes = edge_index[0]
|
| 356 |
+
#total_weight[nodes,selections] += interps
|
| 357 |
+
total_weight.index_put_((nodes,selections),interps,accumulate=True)
|
| 358 |
+
|
| 359 |
+
# Reassign interps accordingly
|
| 360 |
+
if kernel_norm:
|
| 361 |
+
row_totals = torch.sum(total_weight,dim=1)
|
| 362 |
+
interps = interps * S/row_totals[nodes]
|
| 363 |
+
else:
|
| 364 |
+
norms = total_weight[nodes,selections]
|
| 365 |
+
norms[torch.where(norms < 1e-6)] = 1e-6 # Avoid divide by zero error
|
| 366 |
+
interps = interps/norms
|
| 367 |
+
|
| 368 |
+
return interps
|
| 369 |
+
|
| 370 |
+
def simplifyGraph(edge_index,selections,edge_lengths):
|
| 371 |
+
# Take the shortest edge for the set of the same selections on a given node
|
| 372 |
+
num_edges = edge_index.shape[1]
|
| 373 |
+
|
| 374 |
+
# Keep track of which nodes have been visited
|
| 375 |
+
keep_edges = torch.zeros(num_edges,dtype=torch.bool).to(edge_index.device)
|
| 376 |
+
|
| 377 |
+
previous_best_distance = 100000*torch.ones((torch.amax(edge_index)+1,torch.amax(selections)+1),dtype=torch.long).to(edge_index.device)
|
| 378 |
+
previous_best_edge = -1*torch.ones((torch.amax(edge_index)+1,torch.amax(selections)+1),dtype=torch.long).to(edge_index.device)
|
| 379 |
+
|
| 380 |
+
for i in range(num_edges):
|
| 381 |
+
start_node = edge_index[0,i]
|
| 382 |
+
#end_node = edge_index[1,i]
|
| 383 |
+
selection = selections[i]
|
| 384 |
+
distance = edge_lengths[i]
|
| 385 |
+
|
| 386 |
+
if distance < previous_best_distance[start_node,selection]:
|
| 387 |
+
previous_best_distance[start_node,selection] = distance
|
| 388 |
+
keep_edges[i] = True
|
| 389 |
+
|
| 390 |
+
prev = previous_best_edge[start_node,selection]
|
| 391 |
+
if prev != -1:
|
| 392 |
+
keep_edges[prev] = False
|
| 393 |
+
|
| 394 |
+
previous_best_edge[start_node,selection] = i
|
| 395 |
+
|
| 396 |
+
edge_index = edge_index[:,torch.where(keep_edges)[0]]
|
| 397 |
+
selections = selections[torch.where(keep_edges)]
|
| 398 |
+
|
| 399 |
+
return edge_index, selections
|
| 400 |
+
|
graph_io.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from torch_geometric.nn import knn
|
| 4 |
+
from torch_geometric.data import Data
|
| 5 |
+
from torch_geometric.nn import radius_graph, knn_graph
|
| 6 |
+
|
| 7 |
+
import graph_helpers as gh
|
| 8 |
+
import sphere_helpers as sh
|
| 9 |
+
import mesh_helpers as mh
|
| 10 |
+
import clusters as cl
|
| 11 |
+
import utils
|
| 12 |
+
|
| 13 |
+
from torch_scatter import scatter
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from math import pi, sqrt
|
| 17 |
+
|
| 18 |
+
from warnings import warn
|
| 19 |
+
|
| 20 |
+
def image2Graph(data, gt = None, mask = None, depth = 1, x_only = False, device = 'cpu'):
|
| 21 |
+
|
| 22 |
+
_,ch,rows,cols = data.shape
|
| 23 |
+
|
| 24 |
+
x = torch.reshape(data,(ch,rows*cols)).permute((1,0)).to(device)
|
| 25 |
+
|
| 26 |
+
if mask is not None:
|
| 27 |
+
# Mask out nodes
|
| 28 |
+
node_mask = torch.where(mask.flatten())
|
| 29 |
+
x = x[node_mask]
|
| 30 |
+
|
| 31 |
+
if gt is not None:
|
| 32 |
+
y = gt.flatten().to(device)
|
| 33 |
+
if mask is not None:
|
| 34 |
+
y = y[node_mask]
|
| 35 |
+
|
| 36 |
+
if x_only:
|
| 37 |
+
if gt is not None:
|
| 38 |
+
return x,y
|
| 39 |
+
else:
|
| 40 |
+
return x
|
| 41 |
+
|
| 42 |
+
im_pos = gh.getImPos(rows,cols)
|
| 43 |
+
|
| 44 |
+
if mask is not None:
|
| 45 |
+
im_pos = im_pos[node_mask]
|
| 46 |
+
|
| 47 |
+
# Make "point cloud" for clustering
|
| 48 |
+
pos2D = gh.convertImPos(im_pos,flip_y=False)
|
| 49 |
+
|
| 50 |
+
# Generate initial graph
|
| 51 |
+
edge_index = gh.grid2Edges(pos2D)
|
| 52 |
+
directions = pos2D[edge_index[1]] - pos2D[edge_index[0]]
|
| 53 |
+
selections = gh.edges2Selections(edge_index,directions,interpolated=False,y_down=True)
|
| 54 |
+
|
| 55 |
+
# Generate info for downsampled versions of the graph
|
| 56 |
+
clusters, edge_indexes, selections_list = cl.makeImageClusters(pos2D,cols,rows,edge_index,selections,depth=depth,device=device)
|
| 57 |
+
|
| 58 |
+
# Make final graph and metadata needed for mapping the result after going through the network
|
| 59 |
+
graph = Data(x=x,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=None)
|
| 60 |
+
metadata = Data(original=data,im_pos=im_pos.long(),rows=rows,cols=cols,ch=ch)
|
| 61 |
+
|
| 62 |
+
if gt is not None:
|
| 63 |
+
graph.y = y
|
| 64 |
+
|
| 65 |
+
return graph,metadata
|
| 66 |
+
|
| 67 |
+
def graph2Image(result,metadata,canvas=None):
|
| 68 |
+
|
| 69 |
+
x = utils.toNumpy(result,permute=False)
|
| 70 |
+
im_pos = utils.toNumpy(metadata.im_pos,permute=False)
|
| 71 |
+
if canvas is None:
|
| 72 |
+
canvas = utils.makeCanvas(x,metadata.original)
|
| 73 |
+
|
| 74 |
+
# Paint over the original image (neccesary for masked images)
|
| 75 |
+
canvas[im_pos[:,0],im_pos[:,1]] = x
|
| 76 |
+
|
| 77 |
+
return canvas
|
| 78 |
+
|
| 79 |
+
### Begin Interpolated Methods ###
|
| 80 |
+
|
| 81 |
+
def sphere2Graph(data, structure="layering", cluster_method="layering", scale=1.0, stride=2, interpolation_mode = "angle", gt = None, mask = None, depth = 1, x_only = False, device = 'cpu'):
|
| 82 |
+
|
| 83 |
+
_,ch,rows,cols = data.shape
|
| 84 |
+
|
| 85 |
+
if structure == "equirec":
|
| 86 |
+
# Use the original data to start with
|
| 87 |
+
cartesian, spherical = sh.sampleSphere_Equirec(scale*rows,scale*cols)
|
| 88 |
+
elif structure == "layering":
|
| 89 |
+
cartesian, spherical = sh.sampleSphere_Layering(scale*rows)
|
| 90 |
+
elif structure == "spiral":
|
| 91 |
+
cartesian, spherical = sh.sampleSphere_Spiral(scale*rows,scale*cols)
|
| 92 |
+
elif structure == "icosphere":
|
| 93 |
+
cartesian, spherical = sh.sampleSphere_Icosphere(scale*rows)
|
| 94 |
+
elif structure == "random":
|
| 95 |
+
cartesian, spherical = sh.sampleSphere_Random(scale*rows,scale*cols)
|
| 96 |
+
else:
|
| 97 |
+
raise ValueError("Sphere structure unknown")
|
| 98 |
+
|
| 99 |
+
if interpolation_mode == "bary":
|
| 100 |
+
bary_d = pi/(scale*rows)
|
| 101 |
+
else:
|
| 102 |
+
bary_d = None
|
| 103 |
+
|
| 104 |
+
# Get the landing point for each node
|
| 105 |
+
sample_x, sample_y = sh.spherical2equirec(spherical[:,0],spherical[:,1],rows,cols)
|
| 106 |
+
|
| 107 |
+
if mask is not None:
|
| 108 |
+
|
| 109 |
+
node_mask = gh.maskPoints(mask,sample_x,sample_y)
|
| 110 |
+
sample_x = sample_x[node_mask]
|
| 111 |
+
sample_y = sample_y[node_mask]
|
| 112 |
+
spherical = spherical[node_mask]
|
| 113 |
+
cartesian = cartesian[node_mask]
|
| 114 |
+
|
| 115 |
+
features = utils.bilinear_interpolate(data, sample_x, sample_y).to(device)
|
| 116 |
+
|
| 117 |
+
if gt is not None:
|
| 118 |
+
features_y = utils.bilinear_interpolate(gt.unsqueeze(0), sample_x, sample_y).to(device)
|
| 119 |
+
|
| 120 |
+
if x_only:
|
| 121 |
+
if gt is not None:
|
| 122 |
+
return features,features_y
|
| 123 |
+
else:
|
| 124 |
+
return features
|
| 125 |
+
|
| 126 |
+
# Build initial graph
|
| 127 |
+
edge_index,directions = gh.surface2Edges(cartesian,cartesian)
|
| 128 |
+
edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True,bary_d=bary_d)
|
| 129 |
+
|
| 130 |
+
# Generate info for downsampled versions of the graph
|
| 131 |
+
clusters, edge_indexes, selections_list, interps_list = cl.makeSphereClusters(cartesian,edge_index,selections,interps,rows*scale,cols*scale,cluster_method,stride=stride,bary_d=bary_d,depth=depth,device=device)
|
| 132 |
+
|
| 133 |
+
# Make final graph and metadata needed for mapping the result after going through the network
|
| 134 |
+
graph = Data(x=features,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list)
|
| 135 |
+
metadata = Data(original=data,pos3D=cartesian,mask=mask,rows=rows,cols=cols,ch=ch)
|
| 136 |
+
|
| 137 |
+
if gt is not None:
|
| 138 |
+
graph.y = features_y
|
| 139 |
+
|
| 140 |
+
return graph, metadata
|
| 141 |
+
|
| 142 |
+
def graph2Sphere(features,metadata):
|
| 143 |
+
|
| 144 |
+
# Generate equirectangular points and their 3D locations
|
| 145 |
+
theta, phi = sh.equirec2spherical(metadata.rows, metadata.cols)
|
| 146 |
+
x,y,z = sh.spherical2xyz(theta,phi)
|
| 147 |
+
|
| 148 |
+
v = torch.stack((x,y,z),dim=1)
|
| 149 |
+
|
| 150 |
+
# Find closest 3D point to each equirectangular point
|
| 151 |
+
nearest = torch.reshape(knn(metadata.pos3D,v,3)[1],(len(v),3))
|
| 152 |
+
|
| 153 |
+
#Interpolate based on proximty to each node
|
| 154 |
+
w0 = 1/torch.linalg.norm((v - metadata.pos3D[nearest[:,0]]),dim=1, keepdim=True).to(features.device)
|
| 155 |
+
w1 = 1/torch.linalg.norm((v - metadata.pos3D[nearest[:,1]]),dim=1, keepdim=True).to(features.device)
|
| 156 |
+
w2 = 1/torch.linalg.norm((v - metadata.pos3D[nearest[:,2]]),dim=1, keepdim=True).to(features.device)
|
| 157 |
+
|
| 158 |
+
w0 = torch.nan_to_num(w0, nan=1e6)
|
| 159 |
+
w1 = torch.nan_to_num(w1, nan=1e6)
|
| 160 |
+
w2 = torch.nan_to_num(w2, nan=1e6)
|
| 161 |
+
|
| 162 |
+
w0 = torch.clamp(w0,0,1e6)
|
| 163 |
+
w1 = torch.clamp(w1,0,1e6)
|
| 164 |
+
w2 = torch.clamp(w2,0,1e6)
|
| 165 |
+
|
| 166 |
+
total = w0 + w1 + w2
|
| 167 |
+
|
| 168 |
+
#w0,w1,w2 = mh.getBarycentricWeights(v,metadata.pos3D[nearest[:,0]],metadata.pos3D[nearest[:,1]],metadata.pos3D[nearest[:,2]])
|
| 169 |
+
|
| 170 |
+
#w0 = w0.unsqueeze(1).to(features.device)
|
| 171 |
+
#w1 = w1.unsqueeze(1).to(features.device)
|
| 172 |
+
#w2 = w2.unsqueeze(1).to(features.device)
|
| 173 |
+
|
| 174 |
+
result = (w0*features[nearest[:,0]] + w1*features[nearest[:,1]] + w2*features[nearest[:,2]])/total
|
| 175 |
+
|
| 176 |
+
#result = result.clamp(0,1)
|
| 177 |
+
|
| 178 |
+
if hasattr(metadata,"mask"):
|
| 179 |
+
mask = utils.toNumpy(metadata.mask.squeeze(),permute=False)
|
| 180 |
+
canvas = utils.makeCanvas(result,metadata.original)
|
| 181 |
+
result = np.reshape(result.data.cpu().numpy(),(metadata.rows,metadata.cols,features.shape[1]))
|
| 182 |
+
canvas[np.where(mask)] = result[np.where(mask)]
|
| 183 |
+
return canvas
|
| 184 |
+
else:
|
| 185 |
+
return np.reshape(result.data.cpu().numpy(),(metadata.rows,metadata.cols,features.shape[1]))
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def splat2Graph(data, mesh, up_vector = None, N = 100000, ratio=.25, depth = 1, device = 'cpu'):
|
| 190 |
+
""" Sample mesh faces to determine graph """
|
| 191 |
+
|
| 192 |
+
if up_vector == None:
|
| 193 |
+
up_vector = torch.tensor([[1,1,1]],dtype=torch.float)
|
| 194 |
+
#up_vector = 2*torch.rand((1,3))-1
|
| 195 |
+
up_vector = up_vector/torch.linalg.norm(up_vector,dim=1)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
#position, normal vector, uv coordinates in the texture map, x is color
|
| 199 |
+
pos3D, normals = mh.sampleSurface(mesh,N)
|
| 200 |
+
|
| 201 |
+
# Build initial graph
|
| 202 |
+
#edge_index are neighbors of a point, directions are the directions from that point
|
| 203 |
+
edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16)
|
| 204 |
+
#directions need to be turned into selections "W sub n" from the star-like coordinate system from Dr. Hart's github interpolated-selectionconv
|
| 205 |
+
edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)
|
| 206 |
+
|
| 207 |
+
# Generate info for downsampled versions of the graph
|
| 208 |
+
clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)
|
| 209 |
+
#clusters, edge_indexes, selections_list, interps_list = cl.makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)
|
| 210 |
+
|
| 211 |
+
# Make final graph and metadata needed for mapping the result after going through the network
|
| 212 |
+
graph = Data(clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list)
|
| 213 |
+
metadata = Data(original=data,pos3D=pos3D,mesh=mesh)
|
| 214 |
+
|
| 215 |
+
return graph,metadata
|
| 216 |
+
|
| 217 |
+
def mesh2Graph(data, mesh, up_vector = None, N = 100000, ratio=.25, mask = None, depth = 1, x_only = False, device = 'cpu'):
|
| 218 |
+
""" Sample mesh faces to determine graph """
|
| 219 |
+
|
| 220 |
+
if up_vector == None:
|
| 221 |
+
up_vector = torch.tensor([[1,1,1]],dtype=torch.float)
|
| 222 |
+
#up_vector = 2*torch.rand((1,3))-1
|
| 223 |
+
up_vector = up_vector/torch.linalg.norm(up_vector,dim=1)
|
| 224 |
+
|
| 225 |
+
if mask is not None:
|
| 226 |
+
warn("Masks are not currently implemented for mesh graphs")
|
| 227 |
+
|
| 228 |
+
#position, normal vector, uv coordinates in the texture map, x is color
|
| 229 |
+
pos3D, normals, uvs, x = mh.sampleSurface(mesh,N,return_x=True)
|
| 230 |
+
|
| 231 |
+
x = x.to(device)
|
| 232 |
+
|
| 233 |
+
if x_only:
|
| 234 |
+
warn("x_only returns randomly selected points for mesh2Graph. Do not use with previous graph structures")
|
| 235 |
+
return x
|
| 236 |
+
|
| 237 |
+
# Build initial graph
|
| 238 |
+
#edge_index are neighbors of a point, directions are the directions from that point
|
| 239 |
+
edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16)
|
| 240 |
+
#directions need to be turned into selections "W sub n" from the star-like coordinate system from Dr. Hart's github interpolated-selectionconv
|
| 241 |
+
edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)
|
| 242 |
+
|
| 243 |
+
# Generate info for downsampled versions of the graph
|
| 244 |
+
clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)
|
| 245 |
+
#clusters, edge_indexes, selections_list, interps_list = cl.makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)
|
| 246 |
+
|
| 247 |
+
# Make final graph and metadata needed for mapping the result after going through the network
|
| 248 |
+
graph = Data(x=x,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list)
|
| 249 |
+
metadata = Data(original=data,pos3D=pos3D,uvs=uvs,mesh=mesh)
|
| 250 |
+
|
| 251 |
+
return graph,metadata
|
| 252 |
+
|
| 253 |
+
def graph2Splat(features,metadata,view3D=False):
|
| 254 |
+
|
| 255 |
+
features = features.cpu().numpy()
|
| 256 |
+
|
| 257 |
+
canvas = utils.toNumpy(metadata.original)
|
| 258 |
+
rows,cols,ch = canvas.shape
|
| 259 |
+
|
| 260 |
+
# Get 2D positions by scaling uv
|
| 261 |
+
pos2D = metadata.uvs.cpu().numpy()
|
| 262 |
+
pos2D[:,0] = pos2D[:,0]*cols
|
| 263 |
+
pos2D[:,1] = 1-pos2D[:,1] # UV puts y=0 at the bottom
|
| 264 |
+
pos2D[:,1] = pos2D[:,1]*rows
|
| 265 |
+
|
| 266 |
+
# Generate desired points
|
| 267 |
+
row_space = np.arange(rows)
|
| 268 |
+
col_space = np.arange(cols)
|
| 269 |
+
col_image,row_image = np.meshgrid(col_space,row_space)
|
| 270 |
+
|
| 271 |
+
canvas = utils.interpolatePointCloud2D(pos2D,features,col_image,row_image)
|
| 272 |
+
canvas = np.clip(canvas,0,1)
|
| 273 |
+
|
| 274 |
+
if view3D:
|
| 275 |
+
mesh = mh.setTexture(metadata.mesh,canvas)
|
| 276 |
+
mesh.show()
|
| 277 |
+
|
| 278 |
+
return canvas
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def graph2Mesh(features,metadata,view3D=False):
|
| 282 |
+
|
| 283 |
+
features = features.cpu().numpy()
|
| 284 |
+
|
| 285 |
+
canvas = utils.toNumpy(metadata.original)
|
| 286 |
+
rows,cols,ch = canvas.shape
|
| 287 |
+
|
| 288 |
+
# Get 2D positions by scaling uv
|
| 289 |
+
pos2D = metadata.uvs.cpu().numpy()
|
| 290 |
+
pos2D[:,0] = pos2D[:,0]*cols
|
| 291 |
+
pos2D[:,1] = 1-pos2D[:,1] # UV puts y=0 at the bottom
|
| 292 |
+
pos2D[:,1] = pos2D[:,1]*rows
|
| 293 |
+
|
| 294 |
+
# Generate desired points
|
| 295 |
+
row_space = np.arange(rows)
|
| 296 |
+
col_space = np.arange(cols)
|
| 297 |
+
col_image,row_image = np.meshgrid(col_space,row_space)
|
| 298 |
+
|
| 299 |
+
canvas = utils.interpolatePointCloud2D(pos2D,features,col_image,row_image)
|
| 300 |
+
canvas = np.clip(canvas,0,1)
|
| 301 |
+
|
| 302 |
+
if view3D:
|
| 303 |
+
mesh = mh.setTexture(metadata.mesh,canvas)
|
| 304 |
+
mesh.show()
|
| 305 |
+
|
| 306 |
+
return canvas
|
graph_networks/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
graph_networks/LinearStyleTransfer/.DS_Store
ADDED
|
Binary file (8.2 kB). View file
|
|
|
graph_networks/LinearStyleTransfer/LICENSE
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
BSD 2-Clause License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2018, SunshineAtNoon
|
| 4 |
+
All rights reserved.
|
| 5 |
+
|
| 6 |
+
Redistribution and use in source and binary forms, with or without
|
| 7 |
+
modification, are permitted provided that the following conditions are met:
|
| 8 |
+
|
| 9 |
+
* Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
list of conditions and the following disclaimer.
|
| 11 |
+
|
| 12 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
and/or other materials provided with the distribution.
|
| 15 |
+
|
| 16 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 17 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 18 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 19 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 20 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 21 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 22 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 23 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 24 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 25 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
graph_networks/LinearStyleTransfer/README.md
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Learning Linear Transformations for Fast Image and Video Style Transfer
|
| 2 |
+
**[[Paper]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Li_Learning_Linear_Transformations_for_Fast_Image_and_Video_Style_Transfer_CVPR_2019_paper.pdf)** **[[Project Page]](https://sites.google.com/view/linear-style-transfer-cvpr19/)**
|
| 3 |
+
|
| 4 |
+
<img src="doc/images/chicago_paste.png" height="149" hspace="5"><img src="doc/images/photo_content.png" height="150" hspace="5"><img src="doc/images/content.gif" height="150" hspace="5">
|
| 5 |
+
<img src="doc/images/chicago_27.png" height="150" hspace="5"><img src="doc/images/in5_result.png" height="150" hspace="5"><img src="doc/images/test.gif" height="150" hspace="5">
|
| 6 |
+
|
| 7 |
+
## Prerequisites
|
| 8 |
+
- [Pytorch](http://pytorch.org/)
|
| 9 |
+
- [torchvision](https://github.com/pytorch/vision)
|
| 10 |
+
- [opencv](https://opencv.org/) for video generation
|
| 11 |
+
|
| 12 |
+
**All code tested on Ubuntu 16.04, pytorch 0.4.1, and opencv 3.4.2**
|
| 13 |
+
|
| 14 |
+
## Style Transfer
|
| 15 |
+
- Clone from github: `git clone https://github.com/sunshineatnoon/LinearStyleTransfer`
|
| 16 |
+
- Download pre-trained models from [google drive](https://drive.google.com/file/d/1H9T5rfXGlGCUh04DGkpkMFbVnmscJAbs/view?usp=sharing).
|
| 17 |
+
- Uncompress to root folder :
|
| 18 |
+
```
|
| 19 |
+
cd LinearStyleTransfer
|
| 20 |
+
unzip models.zip
|
| 21 |
+
rm models.zip
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
#### Artistic style transfer
|
| 25 |
+
```
|
| 26 |
+
python TestArtistic.py
|
| 27 |
+
```
|
| 28 |
+
or conduct style transfer on relu_31 features
|
| 29 |
+
```
|
| 30 |
+
python TestArtistic.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --matrixPath models/r31.pth --layer r31
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
#### Photo-realistic style transfer
|
| 34 |
+
For photo-realistic style transfer, we need first compile the [pytorch_spn](https://github.com/Liusifei/pytorch_spn) repository.
|
| 35 |
+
```
|
| 36 |
+
cd libs/pytorch_spn
|
| 37 |
+
sh make.sh
|
| 38 |
+
cd ../..
|
| 39 |
+
```
|
| 40 |
+
Then:
|
| 41 |
+
```
|
| 42 |
+
python TestPhotoReal.py
|
| 43 |
+
```
|
| 44 |
+
Note: images with `_filtered.png` as postfix are images filtered by the SPN after style transfer, images with `_smooth.png` as postfix are images post process by a [smooth filter](https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py).
|
| 45 |
+
|
| 46 |
+
#### Video style transfer
|
| 47 |
+
```
|
| 48 |
+
python TestVideo.py
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
#### Real-time video demo
|
| 52 |
+
```
|
| 53 |
+
python real-time-demo.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --matrixPath models/r31.pth --layer r31
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
## Model Training
|
| 57 |
+
### Data Preparation
|
| 58 |
+
- MSCOCO
|
| 59 |
+
```
|
| 60 |
+
wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip
|
| 61 |
+
```
|
| 62 |
+
- WikiArt
|
| 63 |
+
- Either manually download from [kaggle](https://www.kaggle.com/c/painter-by-numbers).
|
| 64 |
+
- Or install [kaggle-cli](https://github.com/floydwch/kaggle-cli) and download by running:
|
| 65 |
+
```
|
| 66 |
+
kg download -u <username> -p <password> -c painter-by-numbers -f train.zip
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### Training
|
| 70 |
+
#### Train a style transfer model
|
| 71 |
+
To train a model that transfers relu4_1 features, run:
|
| 72 |
+
```
|
| 73 |
+
python Train.py --vgg_dir models/vgg_r41.pth --decoder_dir models/dec_r41.pth --layer r41 --contentPath PATH_TO_MSCOCO --stylePath PATH_TO_WikiArt --outf OUTPUT_DIR
|
| 74 |
+
```
|
| 75 |
+
or train a model that transfers relu3_1 features:
|
| 76 |
+
```
|
| 77 |
+
python Train.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --layer r31 --contentPath PATH_TO_MSCOCO --stylePath PATH_TO_WikiArt --outf OUTPUT_DIR
|
| 78 |
+
```
|
| 79 |
+
Key hyper-parameters:
|
| 80 |
+
- style_layers: which features to compute style loss.
|
| 81 |
+
- style_weight: larger style weight leads to heavier style in transferred images.
|
| 82 |
+
|
| 83 |
+
Intermediate results and weight will be stored in `OUTPUT_DIR`
|
| 84 |
+
|
| 85 |
+
#### Train a SPN model to cancel distortions for photo-realistic style transfer
|
| 86 |
+
Run:
|
| 87 |
+
```
|
| 88 |
+
python TrainSPN.py --contentPath PATH_TO_MSCOCO
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
### Acknowledgement
|
| 92 |
+
- We use the [smooth filter](https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py) by [LouieYang](https://github.com/LouieYang) in the photo-realistic style transfer.
|
| 93 |
+
|
| 94 |
+
### Citation
|
| 95 |
+
```
|
| 96 |
+
@inproceedings{li2018learning,
|
| 97 |
+
author = {Li, Xueting and Liu, Sifei and Kautz, Jan and Yang, Ming-Hsuan},
|
| 98 |
+
title = {Learning Linear Transformations for Fast Arbitrary Style Transfer},
|
| 99 |
+
booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
|
| 100 |
+
year = {2019}
|
| 101 |
+
}
|
| 102 |
+
```
|
graph_networks/LinearStyleTransfer/TestArtistic.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import argparse
|
| 4 |
+
from libs.Loader import Dataset
|
| 5 |
+
from libs.Matrix import MulLayer
|
| 6 |
+
import torchvision.utils as vutils
|
| 7 |
+
import torch.backends.cudnn as cudnn
|
| 8 |
+
from libs.utils import print_options
|
| 9 |
+
from libs.models import encoder3,encoder4, encoder5
|
| 10 |
+
from libs.models import decoder3,decoder4, decoder5
|
| 11 |
+
|
| 12 |
+
parser = argparse.ArgumentParser()
|
| 13 |
+
parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
|
| 14 |
+
help='pre-trained encoder path')
|
| 15 |
+
parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
|
| 16 |
+
help='pre-trained decoder path')
|
| 17 |
+
parser.add_argument("--matrixPath", default='models/r41.pth',
|
| 18 |
+
help='pre-trained model path')
|
| 19 |
+
parser.add_argument("--stylePath", default="./data/style/",
|
| 20 |
+
help='path to style image')
|
| 21 |
+
parser.add_argument("--contentPath", default="./data/content/",
|
| 22 |
+
help='path to frames')
|
| 23 |
+
parser.add_argument("--outf", default="Artistic/",
|
| 24 |
+
help='path to transferred images')
|
| 25 |
+
parser.add_argument("--batchSize", type=int,default=1,
|
| 26 |
+
help='batch size')
|
| 27 |
+
parser.add_argument('--loadSize', type=int, default=256,
|
| 28 |
+
help='scale image size')
|
| 29 |
+
parser.add_argument('--fineSize', type=int, default=256,
|
| 30 |
+
help='crop image size')
|
| 31 |
+
parser.add_argument("--layer", default="r41",
|
| 32 |
+
help='which features to transfer, either r31 or r41')
|
| 33 |
+
|
| 34 |
+
################# PREPARATIONS #################
|
| 35 |
+
opt = parser.parse_args()
|
| 36 |
+
opt.cuda = torch.cuda.is_available()
|
| 37 |
+
print_options(opt)
|
| 38 |
+
|
| 39 |
+
os.makedirs(opt.outf,exist_ok=True)
|
| 40 |
+
cudnn.benchmark = True
|
| 41 |
+
|
| 42 |
+
################# DATA #################
|
| 43 |
+
content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize,test=True)
|
| 44 |
+
content_loader = torch.utils.data.DataLoader(dataset=content_dataset,
|
| 45 |
+
batch_size = opt.batchSize,
|
| 46 |
+
shuffle = False)
|
| 47 |
+
#num_workers = 1)
|
| 48 |
+
style_dataset = Dataset(opt.stylePath,opt.loadSize,opt.fineSize,test=True)
|
| 49 |
+
style_loader = torch.utils.data.DataLoader(dataset=style_dataset,
|
| 50 |
+
batch_size = opt.batchSize,
|
| 51 |
+
shuffle = False)
|
| 52 |
+
#num_workers = 1)
|
| 53 |
+
|
| 54 |
+
################# MODEL #################
|
| 55 |
+
if(opt.layer == 'r31'):
|
| 56 |
+
vgg = encoder3()
|
| 57 |
+
dec = decoder3()
|
| 58 |
+
elif(opt.layer == 'r41'):
|
| 59 |
+
vgg = encoder4()
|
| 60 |
+
dec = decoder4()
|
| 61 |
+
matrix = MulLayer(opt.layer)
|
| 62 |
+
vgg.load_state_dict(torch.load(opt.vgg_dir))
|
| 63 |
+
dec.load_state_dict(torch.load(opt.decoder_dir))
|
| 64 |
+
matrix.load_state_dict(torch.load(opt.matrixPath))
|
| 65 |
+
|
| 66 |
+
################# GLOBAL VARIABLE #################
|
| 67 |
+
contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
|
| 68 |
+
styleV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
|
| 69 |
+
|
| 70 |
+
################# GPU #################
|
| 71 |
+
if(opt.cuda):
|
| 72 |
+
vgg.cuda()
|
| 73 |
+
dec.cuda()
|
| 74 |
+
matrix.cuda()
|
| 75 |
+
contentV = contentV.cuda()
|
| 76 |
+
styleV = styleV.cuda()
|
| 77 |
+
|
| 78 |
+
for ci,(content,contentName) in enumerate(content_loader):
|
| 79 |
+
contentName = contentName[0]
|
| 80 |
+
contentV.resize_(content.size()).copy_(content)
|
| 81 |
+
for sj,(style,styleName) in enumerate(style_loader):
|
| 82 |
+
styleName = styleName[0]
|
| 83 |
+
styleV.resize_(style.size()).copy_(style)
|
| 84 |
+
|
| 85 |
+
# forward
|
| 86 |
+
with torch.no_grad():
|
| 87 |
+
sF = vgg(styleV)
|
| 88 |
+
cF = vgg(contentV)
|
| 89 |
+
|
| 90 |
+
if(opt.layer == 'r41'):
|
| 91 |
+
feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer])
|
| 92 |
+
else:
|
| 93 |
+
feature,transmatrix = matrix(cF,sF)
|
| 94 |
+
transfer = dec(feature)
|
| 95 |
+
|
| 96 |
+
transfer = transfer.clamp(0,1)
|
| 97 |
+
vutils.save_image(transfer,'%s/%s_%s.png'%(opt.outf,contentName,styleName),normalize=True,scale_each=True,nrow=opt.batchSize)
|
| 98 |
+
print('Transferred image saved at %s%s_%s.png'%(opt.outf,contentName,styleName))
|
graph_networks/LinearStyleTransfer/TestPhotoReal.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import time
|
| 4 |
+
import torch
|
| 5 |
+
import argparse
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from libs.SPN import SPN
|
| 9 |
+
import torchvision.utils as vutils
|
| 10 |
+
from libs.utils import print_options
|
| 11 |
+
from libs.MatrixTest import MulLayer
|
| 12 |
+
import torch.backends.cudnn as cudnn
|
| 13 |
+
from libs.LoaderPhotoReal import Dataset
|
| 14 |
+
from libs.models import encoder3,encoder4
|
| 15 |
+
from libs.models import decoder3,decoder4
|
| 16 |
+
import torchvision.transforms as transforms
|
| 17 |
+
from libs.smooth_filter import smooth_filter
|
| 18 |
+
|
| 19 |
+
parser = argparse.ArgumentParser()
|
| 20 |
+
parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
|
| 21 |
+
help='pre-trained encoder path')
|
| 22 |
+
parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
|
| 23 |
+
help='pre-trained decoder path')
|
| 24 |
+
parser.add_argument("--matrixPath", default='models/r41.pth',
|
| 25 |
+
help='pre-trained model path')
|
| 26 |
+
parser.add_argument("--stylePath", default="data/photo_real/style/images/",
|
| 27 |
+
help='path to style image')
|
| 28 |
+
parser.add_argument("--styleSegPath", default="data/photo_real/styleSeg/",
|
| 29 |
+
help='path to style image masks')
|
| 30 |
+
parser.add_argument("--contentPath", default="data/photo_real/content/images/",
|
| 31 |
+
help='path to content image')
|
| 32 |
+
parser.add_argument("--contentSegPath", default="data/photo_real/contentSeg/",
|
| 33 |
+
help='path to content image masks')
|
| 34 |
+
parser.add_argument("--outf", default="PhotoReal/",
|
| 35 |
+
help='path to save output images')
|
| 36 |
+
parser.add_argument("--batchSize", type=int,default=1,
|
| 37 |
+
help='batch size')
|
| 38 |
+
parser.add_argument('--fineSize', type=int, default=512,
|
| 39 |
+
help='image size')
|
| 40 |
+
parser.add_argument("--layer", default="r41",
|
| 41 |
+
help='features of which layer to transform, either r31 or r41')
|
| 42 |
+
parser.add_argument("--spn_dir", default='models/r41_spn.pth',
|
| 43 |
+
help='path to pretrained SPN model')
|
| 44 |
+
|
| 45 |
+
################# PREPARATIONS #################
|
| 46 |
+
opt = parser.parse_args()
|
| 47 |
+
opt.cuda = torch.cuda.is_available()
|
| 48 |
+
print_options(opt)
|
| 49 |
+
|
| 50 |
+
os.makedirs(opt.outf, exist_ok=True)
|
| 51 |
+
|
| 52 |
+
cudnn.benchmark = True
|
| 53 |
+
|
| 54 |
+
################# DATA #################
|
| 55 |
+
dataset = Dataset(opt.contentPath,opt.stylePath,opt.contentSegPath,opt.styleSegPath,opt.fineSize)
|
| 56 |
+
loader = torch.utils.data.DataLoader(dataset=dataset,
|
| 57 |
+
batch_size=1,
|
| 58 |
+
shuffle=False)
|
| 59 |
+
|
| 60 |
+
################# MODEL #################
|
| 61 |
+
if(opt.layer == 'r31'):
|
| 62 |
+
vgg = encoder3()
|
| 63 |
+
dec = decoder3()
|
| 64 |
+
elif(opt.layer == 'r41'):
|
| 65 |
+
vgg = encoder4()
|
| 66 |
+
dec = decoder4()
|
| 67 |
+
matrix = MulLayer(opt.layer)
|
| 68 |
+
vgg.load_state_dict(torch.load(opt.vgg_dir))
|
| 69 |
+
dec.load_state_dict(torch.load(opt.decoder_dir))
|
| 70 |
+
matrix.load_state_dict(torch.load(opt.matrixPath))
|
| 71 |
+
spn = SPN()
|
| 72 |
+
spn.load_state_dict(torch.load(opt.spn_dir))
|
| 73 |
+
|
| 74 |
+
################# GLOBAL VARIABLE #################
|
| 75 |
+
contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
|
| 76 |
+
styleV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
|
| 77 |
+
whitenV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
|
| 78 |
+
|
| 79 |
+
################# GPU #################
|
| 80 |
+
if(opt.cuda):
|
| 81 |
+
vgg.cuda()
|
| 82 |
+
dec.cuda()
|
| 83 |
+
spn.cuda()
|
| 84 |
+
matrix.cuda()
|
| 85 |
+
contentV = contentV.cuda()
|
| 86 |
+
styleV = styleV.cuda()
|
| 87 |
+
whitenV = whitenV.cuda()
|
| 88 |
+
|
| 89 |
+
for i,(contentImg,styleImg,whitenImg,cmasks,smasks,imname) in enumerate(loader):
|
| 90 |
+
imname = imname[0]
|
| 91 |
+
contentV.resize_(contentImg.size()).copy_(contentImg)
|
| 92 |
+
styleV.resize_(styleImg.size()).copy_(styleImg)
|
| 93 |
+
whitenV.resize_(whitenImg.size()).copy_(whitenImg)
|
| 94 |
+
|
| 95 |
+
# forward
|
| 96 |
+
sF = vgg(styleV)
|
| 97 |
+
cF = vgg(contentV)
|
| 98 |
+
|
| 99 |
+
with torch.no_grad():
|
| 100 |
+
if(opt.layer == 'r41'):
|
| 101 |
+
feature = matrix(cF[opt.layer],sF[opt.layer],cmasks,smasks)
|
| 102 |
+
else:
|
| 103 |
+
feature = matrix(cF,sF,cmasks,smasks)
|
| 104 |
+
transfer = dec(feature)
|
| 105 |
+
filtered = spn(transfer,whitenV)
|
| 106 |
+
vutils.save_image(transfer,os.path.join(opt.outf,'%s_transfer.png'%(imname.split('.')[0])))
|
| 107 |
+
|
| 108 |
+
filtered = filtered.clamp(0,1)
|
| 109 |
+
filtered = filtered.cpu()
|
| 110 |
+
vutils.save_image(filtered,'%s/%s_filtered.png'%(opt.outf,imname.split('.')[0]))
|
| 111 |
+
out_img = filtered.squeeze(0).mul(255).clamp(0,255).byte().permute(1,2,0).cpu().numpy()
|
| 112 |
+
content = contentImg.squeeze(0).mul(255).clamp(0,255).byte().permute(1,2,0).cpu().numpy()
|
| 113 |
+
content = content.copy()
|
| 114 |
+
out_img = out_img.copy()
|
| 115 |
+
smoothed = smooth_filter(out_img, content, f_radius=15, f_edge=1e-1)
|
| 116 |
+
smoothed.save('%s/%s_smooth.png'%(opt.outf,imname.split('.')[0]))
|
| 117 |
+
print('Transferred image saved at %s%s, filtered image saved at %s%s_filtered.png' \
|
| 118 |
+
%(opt.outf,imname,opt.outf,imname.split('.')[0]))
|
graph_networks/LinearStyleTransfer/TestVideo.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import argparse
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from libs.Loader import Dataset
|
| 6 |
+
from libs.Matrix import MulLayer
|
| 7 |
+
import torch.backends.cudnn as cudnn
|
| 8 |
+
from libs.models import encoder3,encoder4
|
| 9 |
+
from libs.models import decoder3,decoder4
|
| 10 |
+
import torchvision.transforms as transforms
|
| 11 |
+
from libs.utils import makeVideo, print_options
|
| 12 |
+
|
| 13 |
+
parser = argparse.ArgumentParser()
|
| 14 |
+
parser.add_argument("--vgg_dir", default='models/vgg_r31.pth',
|
| 15 |
+
help='pre-trained encoder path')
|
| 16 |
+
parser.add_argument("--decoder_dir", default='models/dec_r31.pth',
|
| 17 |
+
help='pre-trained decoder path')
|
| 18 |
+
parser.add_argument("--matrix_dir", default="models/r31.pth",
|
| 19 |
+
help='path to pre-trained model')
|
| 20 |
+
parser.add_argument("--style", default="data/style/in2.jpg",
|
| 21 |
+
help='path to style image')
|
| 22 |
+
parser.add_argument("--content_dir", default="data/videos/content/mountain_2/",
|
| 23 |
+
help='path to video frames')
|
| 24 |
+
parser.add_argument('--loadSize', type=int, default=512,
|
| 25 |
+
help='scale image size')
|
| 26 |
+
parser.add_argument('--fineSize', type=int, default=512,
|
| 27 |
+
help='crop image size')
|
| 28 |
+
parser.add_argument("--name",default="transferred_video",
|
| 29 |
+
help="name of generated video")
|
| 30 |
+
parser.add_argument("--layer",default="r31",
|
| 31 |
+
help="features of which layer to transform")
|
| 32 |
+
parser.add_argument("--outf",default="videos",
|
| 33 |
+
help="output folder")
|
| 34 |
+
|
| 35 |
+
################# PREPARATIONS #################
|
| 36 |
+
opt = parser.parse_args()
|
| 37 |
+
opt.cuda = torch.cuda.is_available()
|
| 38 |
+
print_options(opt)
|
| 39 |
+
|
| 40 |
+
os.makedirs(opt.outf,exist_ok=True)
|
| 41 |
+
cudnn.benchmark = True
|
| 42 |
+
|
| 43 |
+
################# DATA #################
|
| 44 |
+
def loadImg(imgPath):
|
| 45 |
+
img = Image.open(imgPath).convert('RGB')
|
| 46 |
+
transform = transforms.Compose([
|
| 47 |
+
transforms.Scale(opt.fineSize),
|
| 48 |
+
transforms.ToTensor()])
|
| 49 |
+
return transform(img)
|
| 50 |
+
styleV = loadImg(opt.style).unsqueeze(0)
|
| 51 |
+
|
| 52 |
+
content_dataset = Dataset(opt.content_dir,
|
| 53 |
+
loadSize = opt.loadSize,
|
| 54 |
+
fineSize = opt.fineSize,
|
| 55 |
+
test = True,
|
| 56 |
+
video = True)
|
| 57 |
+
content_loader = torch.utils.data.DataLoader(dataset = content_dataset,
|
| 58 |
+
batch_size = 1,
|
| 59 |
+
shuffle = False)
|
| 60 |
+
|
| 61 |
+
################# MODEL #################
|
| 62 |
+
if(opt.layer == 'r31'):
|
| 63 |
+
vgg = encoder3()
|
| 64 |
+
dec = decoder3()
|
| 65 |
+
elif(opt.layer == 'r41'):
|
| 66 |
+
vgg = encoder4()
|
| 67 |
+
dec = decoder4()
|
| 68 |
+
matrix = MulLayer(layer=opt.layer)
|
| 69 |
+
vgg.load_state_dict(torch.load(opt.vgg_dir))
|
| 70 |
+
dec.load_state_dict(torch.load(opt.decoder_dir))
|
| 71 |
+
matrix.load_state_dict(torch.load(opt.matrix_dir))
|
| 72 |
+
|
| 73 |
+
################# GLOBAL VARIABLE #################
|
| 74 |
+
contentV = torch.Tensor(1,3,opt.fineSize,opt.fineSize)
|
| 75 |
+
|
| 76 |
+
################# GPU #################
|
| 77 |
+
if(opt.cuda):
|
| 78 |
+
vgg.cuda()
|
| 79 |
+
dec.cuda()
|
| 80 |
+
matrix.cuda()
|
| 81 |
+
|
| 82 |
+
styleV = styleV.cuda()
|
| 83 |
+
contentV = contentV.cuda()
|
| 84 |
+
|
| 85 |
+
result_frames = []
|
| 86 |
+
contents = []
|
| 87 |
+
style = styleV.squeeze(0).cpu().numpy()
|
| 88 |
+
sF = vgg(styleV)
|
| 89 |
+
|
| 90 |
+
for i,(content,contentName) in enumerate(content_loader):
|
| 91 |
+
print('Transfer frame %d...'%i)
|
| 92 |
+
contentName = contentName[0]
|
| 93 |
+
contentV.resize_(content.size()).copy_(content)
|
| 94 |
+
contents.append(content.squeeze(0).float().numpy())
|
| 95 |
+
# forward
|
| 96 |
+
with torch.no_grad():
|
| 97 |
+
cF = vgg(contentV)
|
| 98 |
+
|
| 99 |
+
if(opt.layer == 'r41'):
|
| 100 |
+
feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer])
|
| 101 |
+
else:
|
| 102 |
+
feature,transmatrix = matrix(cF,sF)
|
| 103 |
+
transfer = dec(feature)
|
| 104 |
+
|
| 105 |
+
transfer = transfer.clamp(0,1)
|
| 106 |
+
result_frames.append(transfer.squeeze(0).cpu().numpy())
|
| 107 |
+
|
| 108 |
+
makeVideo(contents,style,result_frames,opt.outf)
|
graph_networks/LinearStyleTransfer/Train.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import argparse
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
from libs.Loader import Dataset
|
| 7 |
+
from libs.Matrix import MulLayer
|
| 8 |
+
import torchvision.utils as vutils
|
| 9 |
+
import torch.backends.cudnn as cudnn
|
| 10 |
+
from libs.utils import print_options
|
| 11 |
+
from libs.Criterion import LossCriterion
|
| 12 |
+
from libs.models import encoder3,encoder4
|
| 13 |
+
from libs.models import decoder3,decoder4
|
| 14 |
+
from libs.models import encoder5 as loss_network
|
| 15 |
+
|
| 16 |
+
parser = argparse.ArgumentParser()
|
| 17 |
+
parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
|
| 18 |
+
help='pre-trained encoder path')
|
| 19 |
+
parser.add_argument("--loss_network_dir", default='models/vgg_r51.pth',
|
| 20 |
+
help='used for loss network')
|
| 21 |
+
parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
|
| 22 |
+
help='pre-trained decoder path')
|
| 23 |
+
parser.add_argument("--stylePath", default="/home/xtli/DATA/wikiArt/train/images/",
|
| 24 |
+
help='path to wikiArt dataset')
|
| 25 |
+
parser.add_argument("--contentPath", default="/home/xtli/DATA/MSCOCO/train2014/images/",
|
| 26 |
+
help='path to MSCOCO dataset')
|
| 27 |
+
parser.add_argument("--outf", default="trainingOutput/",
|
| 28 |
+
help='folder to output images and model checkpoints')
|
| 29 |
+
parser.add_argument("--content_layers", default="r41",
|
| 30 |
+
help='layers for content')
|
| 31 |
+
parser.add_argument("--style_layers", default="r11,r21,r31,r41",
|
| 32 |
+
help='layers for style')
|
| 33 |
+
parser.add_argument("--batchSize", type=int,default=8,
|
| 34 |
+
help='batch size')
|
| 35 |
+
parser.add_argument("--niter", type=int,default=100000,
|
| 36 |
+
help='iterations to train the model')
|
| 37 |
+
parser.add_argument('--loadSize', type=int, default=300,
|
| 38 |
+
help='scale image size')
|
| 39 |
+
parser.add_argument('--fineSize', type=int, default=256,
|
| 40 |
+
help='crop image size')
|
| 41 |
+
parser.add_argument("--lr", type=float, default=1e-4,
|
| 42 |
+
help='learning rate')
|
| 43 |
+
parser.add_argument("--content_weight", type=float, default=1.0,
|
| 44 |
+
help='content loss weight')
|
| 45 |
+
parser.add_argument("--style_weight", type=float, default=0.02,
|
| 46 |
+
help='style loss weight')
|
| 47 |
+
parser.add_argument("--log_interval", type=int, default=500,
|
| 48 |
+
help='log interval')
|
| 49 |
+
parser.add_argument("--gpu_id", type=int, default=0,
|
| 50 |
+
help='which gpu to use')
|
| 51 |
+
parser.add_argument("--save_interval", type=int, default=5000,
|
| 52 |
+
help='checkpoint save interval')
|
| 53 |
+
parser.add_argument("--layer", default="r41",
|
| 54 |
+
help='which features to transfer, either r31 or r41')
|
| 55 |
+
|
| 56 |
+
################# PREPARATIONS #################
|
| 57 |
+
opt = parser.parse_args()
|
| 58 |
+
opt.content_layers = opt.content_layers.split(',')
|
| 59 |
+
opt.style_layers = opt.style_layers.split(',')
|
| 60 |
+
opt.cuda = torch.cuda.is_available()
|
| 61 |
+
if(opt.cuda):
|
| 62 |
+
torch.cuda.set_device(opt.gpu_id)
|
| 63 |
+
|
| 64 |
+
os.makedirs(opt.outf,exist_ok=True)
|
| 65 |
+
cudnn.benchmark = True
|
| 66 |
+
print_options(opt)
|
| 67 |
+
|
| 68 |
+
################# DATA #################
|
| 69 |
+
content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize)
|
| 70 |
+
content_loader_ = torch.utils.data.DataLoader(dataset = content_dataset,
|
| 71 |
+
batch_size = opt.batchSize,
|
| 72 |
+
shuffle = True,
|
| 73 |
+
num_workers = 1,
|
| 74 |
+
drop_last = True)
|
| 75 |
+
content_loader = iter(content_loader_)
|
| 76 |
+
style_dataset = Dataset(opt.stylePath,opt.loadSize,opt.fineSize)
|
| 77 |
+
style_loader_ = torch.utils.data.DataLoader(dataset = style_dataset,
|
| 78 |
+
batch_size = opt.batchSize,
|
| 79 |
+
shuffle = True,
|
| 80 |
+
num_workers = 1,
|
| 81 |
+
drop_last = True)
|
| 82 |
+
style_loader = iter(style_loader_)
|
| 83 |
+
|
| 84 |
+
################# MODEL #################
|
| 85 |
+
vgg5 = loss_network()
|
| 86 |
+
if(opt.layer == 'r31'):
|
| 87 |
+
matrix = MulLayer('r31')
|
| 88 |
+
vgg = encoder3()
|
| 89 |
+
dec = decoder3()
|
| 90 |
+
elif(opt.layer == 'r41'):
|
| 91 |
+
matrix = MulLayer('r41')
|
| 92 |
+
vgg = encoder4()
|
| 93 |
+
dec = decoder4()
|
| 94 |
+
vgg.load_state_dict(torch.load(opt.vgg_dir))
|
| 95 |
+
dec.load_state_dict(torch.load(opt.decoder_dir))
|
| 96 |
+
vgg5.load_state_dict(torch.load(opt.loss_network_dir))
|
| 97 |
+
|
| 98 |
+
for param in vgg.parameters():
|
| 99 |
+
param.requires_grad = False
|
| 100 |
+
for param in vgg5.parameters():
|
| 101 |
+
param.requires_grad = False
|
| 102 |
+
for param in dec.parameters():
|
| 103 |
+
param.requires_grad = False
|
| 104 |
+
|
| 105 |
+
################# LOSS & OPTIMIZER #################
|
| 106 |
+
criterion = LossCriterion(opt.style_layers,
|
| 107 |
+
opt.content_layers,
|
| 108 |
+
opt.style_weight,
|
| 109 |
+
opt.content_weight)
|
| 110 |
+
optimizer = optim.Adam(matrix.parameters(), opt.lr)
|
| 111 |
+
|
| 112 |
+
################# GLOBAL VARIABLE #################
|
| 113 |
+
contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
|
| 114 |
+
styleV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
|
| 115 |
+
|
| 116 |
+
################# GPU #################
|
| 117 |
+
if(opt.cuda):
|
| 118 |
+
vgg.cuda()
|
| 119 |
+
dec.cuda()
|
| 120 |
+
vgg5.cuda()
|
| 121 |
+
matrix.cuda()
|
| 122 |
+
contentV = contentV.cuda()
|
| 123 |
+
styleV = styleV.cuda()
|
| 124 |
+
|
| 125 |
+
################# TRAINING #################
|
| 126 |
+
def adjust_learning_rate(optimizer, iteration):
|
| 127 |
+
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
|
| 128 |
+
for param_group in optimizer.param_groups:
|
| 129 |
+
param_group['lr'] = opt.lr / (1+iteration*1e-5)
|
| 130 |
+
|
| 131 |
+
for iteration in range(1,opt.niter+1):
|
| 132 |
+
optimizer.zero_grad()
|
| 133 |
+
try:
|
| 134 |
+
content,_ = content_loader.next()
|
| 135 |
+
except IOError:
|
| 136 |
+
content,_ = content_loader.next()
|
| 137 |
+
except StopIteration:
|
| 138 |
+
content_loader = iter(content_loader_)
|
| 139 |
+
content,_ = content_loader.next()
|
| 140 |
+
except:
|
| 141 |
+
continue
|
| 142 |
+
|
| 143 |
+
try:
|
| 144 |
+
style,_ = style_loader.next()
|
| 145 |
+
except IOError:
|
| 146 |
+
style,_ = style_loader.next()
|
| 147 |
+
except StopIteration:
|
| 148 |
+
style_loader = iter(style_loader_)
|
| 149 |
+
style,_ = style_loader.next()
|
| 150 |
+
except:
|
| 151 |
+
continue
|
| 152 |
+
|
| 153 |
+
contentV.resize_(content.size()).copy_(content)
|
| 154 |
+
styleV.resize_(style.size()).copy_(style)
|
| 155 |
+
|
| 156 |
+
# forward
|
| 157 |
+
sF = vgg(styleV)
|
| 158 |
+
cF = vgg(contentV)
|
| 159 |
+
|
| 160 |
+
if(opt.layer == 'r41'):
|
| 161 |
+
feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer])
|
| 162 |
+
else:
|
| 163 |
+
feature,transmatrix = matrix(cF,sF)
|
| 164 |
+
transfer = dec(feature)
|
| 165 |
+
|
| 166 |
+
sF_loss = vgg5(styleV)
|
| 167 |
+
cF_loss = vgg5(contentV)
|
| 168 |
+
tF = vgg5(transfer)
|
| 169 |
+
loss,styleLoss,contentLoss = criterion(tF,sF_loss,cF_loss)
|
| 170 |
+
|
| 171 |
+
# backward & optimization
|
| 172 |
+
loss.backward()
|
| 173 |
+
optimizer.step()
|
| 174 |
+
print('Iteration: [%d/%d] Loss: %.4f contentLoss: %.4f styleLoss: %.4f Learng Rate is %.6f'%
|
| 175 |
+
(opt.niter,iteration,loss,contentLoss,styleLoss,optimizer.param_groups[0]['lr']))
|
| 176 |
+
|
| 177 |
+
adjust_learning_rate(optimizer,iteration)
|
| 178 |
+
|
| 179 |
+
if((iteration) % opt.log_interval == 0):
|
| 180 |
+
transfer = transfer.clamp(0,1)
|
| 181 |
+
concat = torch.cat((content,style,transfer.cpu()),dim=0)
|
| 182 |
+
vutils.save_image(concat,'%s/%d.png'%(opt.outf,iteration),normalize=True,scale_each=True,nrow=opt.batchSize)
|
| 183 |
+
|
| 184 |
+
if(iteration > 0 and (iteration) % opt.save_interval == 0):
|
| 185 |
+
torch.save(matrix.state_dict(), '%s/%s.pth' % (opt.outf,opt.layer))
|
graph_networks/LinearStyleTransfer/TrainSPN.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
from libs.SPN import SPN
|
| 6 |
+
from libs.Loader import Dataset
|
| 7 |
+
from libs.models import encoder4
|
| 8 |
+
from libs.models import decoder4
|
| 9 |
+
from libs.utils import print_options
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.optim as optim
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
import torchvision.utils as vutils
|
| 16 |
+
import torch.backends.cudnn as cudnn
|
| 17 |
+
import torchvision.transforms as transforms
|
| 18 |
+
|
| 19 |
+
parser = argparse.ArgumentParser()
|
| 20 |
+
parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
|
| 21 |
+
help='pre-trained encoder path')
|
| 22 |
+
parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
|
| 23 |
+
help='pre-trained decoder path')
|
| 24 |
+
parser.add_argument("--contentPath", default="/home/xtli/DATA/MSCOCO/train2014/images/",
|
| 25 |
+
help='path to MSCOCO dataset')
|
| 26 |
+
parser.add_argument("--outf", default="trainingSPNOutput/",
|
| 27 |
+
help='folder to output images and model checkpoints')
|
| 28 |
+
parser.add_argument("--layer", default="r41",
|
| 29 |
+
help='layers for content')
|
| 30 |
+
parser.add_argument("--batchSize", type=int,default=8,
|
| 31 |
+
help='batch size')
|
| 32 |
+
parser.add_argument("--niter", type=int,default=100000,
|
| 33 |
+
help='iterations to train the model')
|
| 34 |
+
parser.add_argument('--loadSize', type=int, default=512,
|
| 35 |
+
help='scale image size')
|
| 36 |
+
parser.add_argument('--fineSize', type=int, default=256,
|
| 37 |
+
help='crop image size')
|
| 38 |
+
parser.add_argument("--lr", type=float, default=1e-3,
|
| 39 |
+
help='learning rate')
|
| 40 |
+
parser.add_argument("--log_interval", type=int, default=500,
|
| 41 |
+
help='log interval')
|
| 42 |
+
parser.add_argument("--save_interval", type=int, default=5000,
|
| 43 |
+
help='checkpoint save interval')
|
| 44 |
+
parser.add_argument("--spn_num", type=int, default=1,
|
| 45 |
+
help='number of spn filters')
|
| 46 |
+
|
| 47 |
+
################# PREPARATIONS #################
|
| 48 |
+
opt = parser.parse_args()
|
| 49 |
+
opt.cuda = torch.cuda.is_available()
|
| 50 |
+
print_options(opt)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
os.makedirs(opt.outf, exist_ok = True)
|
| 54 |
+
|
| 55 |
+
cudnn.benchmark = True
|
| 56 |
+
|
| 57 |
+
################# DATA #################
|
| 58 |
+
content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize)
|
| 59 |
+
content_loader_ = torch.utils.data.DataLoader(dataset=content_dataset,
|
| 60 |
+
batch_size = opt.batchSize,
|
| 61 |
+
shuffle = True,
|
| 62 |
+
num_workers = 4,
|
| 63 |
+
drop_last = True)
|
| 64 |
+
content_loader = iter(content_loader_)
|
| 65 |
+
|
| 66 |
+
################# MODEL #################
|
| 67 |
+
spn = SPN(spn=opt.spn_num)
|
| 68 |
+
if(opt.layer == 'r31'):
|
| 69 |
+
vgg = encoder3()
|
| 70 |
+
dec = decoder3()
|
| 71 |
+
elif(opt.layer == 'r41'):
|
| 72 |
+
vgg = encoder4()
|
| 73 |
+
dec = decoder4()
|
| 74 |
+
vgg.load_state_dict(torch.load(opt.vgg_dir))
|
| 75 |
+
dec.load_state_dict(torch.load(opt.decoder_dir))
|
| 76 |
+
|
| 77 |
+
for param in vgg.parameters():
|
| 78 |
+
param.requires_grad = False
|
| 79 |
+
for param in dec.parameters():
|
| 80 |
+
param.requires_grad = False
|
| 81 |
+
|
| 82 |
+
################# LOSS & OPTIMIZER #################
|
| 83 |
+
criterion = nn.MSELoss(size_average=False)
|
| 84 |
+
#optimizer_spn = optim.SGD(spn.parameters(), opt.lr)
|
| 85 |
+
optimizer_spn = optim.Adam(spn.parameters(), opt.lr)
|
| 86 |
+
|
| 87 |
+
################# GLOBAL VARIABLE #################
|
| 88 |
+
contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
|
| 89 |
+
|
| 90 |
+
################# GPU #################
|
| 91 |
+
if(opt.cuda):
|
| 92 |
+
vgg.cuda()
|
| 93 |
+
dec.cuda()
|
| 94 |
+
spn.cuda()
|
| 95 |
+
contentV = contentV.cuda()
|
| 96 |
+
|
| 97 |
+
################# TRAINING #################
|
| 98 |
+
def adjust_learning_rate(optimizer, iteration):
|
| 99 |
+
for param_group in optimizer.param_groups:
|
| 100 |
+
param_group['lr'] = opt.lr / (1+iteration*1e-5)
|
| 101 |
+
|
| 102 |
+
spn.train()
|
| 103 |
+
for iteration in range(1,opt.niter+1):
|
| 104 |
+
optimizer_spn.zero_grad()
|
| 105 |
+
try:
|
| 106 |
+
content,_ = content_loader.next()
|
| 107 |
+
except IOError:
|
| 108 |
+
content,_ = content_loader.next()
|
| 109 |
+
except StopIteration:
|
| 110 |
+
content_loader = iter(content_loader_)
|
| 111 |
+
content,_ = content_loader.next()
|
| 112 |
+
except:
|
| 113 |
+
continue
|
| 114 |
+
|
| 115 |
+
contentV.resize_(content.size()).copy_(content)
|
| 116 |
+
|
| 117 |
+
# forward
|
| 118 |
+
cF = vgg(contentV)
|
| 119 |
+
transfer = dec(cF['r41'])
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
propagated = spn(transfer,contentV)
|
| 123 |
+
loss = criterion(propagated,contentV)
|
| 124 |
+
|
| 125 |
+
# backward & optimization
|
| 126 |
+
loss.backward()
|
| 127 |
+
#nn.utils.clip_grad_norm(spn.parameters(), 1000)
|
| 128 |
+
optimizer_spn.step()
|
| 129 |
+
print('Iteration: [%d/%d] Loss: %.4f Learng Rate is %.6f'
|
| 130 |
+
%(opt.niter,iteration,loss,optimizer_spn.param_groups[0]['lr']))
|
| 131 |
+
|
| 132 |
+
adjust_learning_rate(optimizer_spn,iteration)
|
| 133 |
+
|
| 134 |
+
if((iteration) % opt.log_interval == 0):
|
| 135 |
+
transfer = transfer.clamp(0,1)
|
| 136 |
+
propagated = propagated.clamp(0,1)
|
| 137 |
+
vutils.save_image(transfer,'%s/%d_transfer.png'%(opt.outf,iteration))
|
| 138 |
+
vutils.save_image(propagated,'%s/%d_propagated.png'%(opt.outf,iteration))
|
| 139 |
+
|
| 140 |
+
if(iteration > 0 and (iteration) % opt.save_interval == 0):
|
| 141 |
+
torch.save(spn.state_dict(), '%s/%s_spn.pth' % (opt.outf,opt.layer))
|
graph_networks/LinearStyleTransfer/__init__.py
ADDED
|
File without changes
|
graph_networks/LinearStyleTransfer/libs/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
graph_networks/LinearStyleTransfer/libs/Criterion.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class styleLoss(nn.Module):
|
| 5 |
+
def forward(self,input,target):
|
| 6 |
+
ib,ic,ih,iw = input.size()
|
| 7 |
+
iF = input.view(ib,ic,-1)
|
| 8 |
+
iMean = torch.mean(iF,dim=2)
|
| 9 |
+
iCov = GramMatrix()(input)
|
| 10 |
+
|
| 11 |
+
tb,tc,th,tw = target.size()
|
| 12 |
+
tF = target.view(tb,tc,-1)
|
| 13 |
+
tMean = torch.mean(tF,dim=2)
|
| 14 |
+
tCov = GramMatrix()(target)
|
| 15 |
+
|
| 16 |
+
loss = nn.MSELoss(size_average=False)(iMean,tMean) + nn.MSELoss(size_average=False)(iCov,tCov)
|
| 17 |
+
return loss/tb
|
| 18 |
+
|
| 19 |
+
class GramMatrix(nn.Module):
|
| 20 |
+
def forward(self,input):
|
| 21 |
+
b, c, h, w = input.size()
|
| 22 |
+
f = input.view(b,c,h*w) # bxcx(hxw)
|
| 23 |
+
# torch.bmm(batch1, batch2, out=None) #
|
| 24 |
+
# batch1: bxmxp, batch2: bxpxn -> bxmxn #
|
| 25 |
+
G = torch.bmm(f,f.transpose(1,2)) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
|
| 26 |
+
return G.div_(c*h*w)
|
| 27 |
+
|
| 28 |
+
class LossCriterion(nn.Module):
|
| 29 |
+
def __init__(self,style_layers,content_layers,style_weight,content_weight):
|
| 30 |
+
super(LossCriterion,self).__init__()
|
| 31 |
+
|
| 32 |
+
self.style_layers = style_layers
|
| 33 |
+
self.content_layers = content_layers
|
| 34 |
+
self.style_weight = style_weight
|
| 35 |
+
self.content_weight = content_weight
|
| 36 |
+
|
| 37 |
+
self.styleLosses = [styleLoss()] * len(style_layers)
|
| 38 |
+
self.contentLosses = [nn.MSELoss()] * len(content_layers)
|
| 39 |
+
|
| 40 |
+
def forward(self,tF,sF,cF):
|
| 41 |
+
# content loss
|
| 42 |
+
totalContentLoss = 0
|
| 43 |
+
for i,layer in enumerate(self.content_layers):
|
| 44 |
+
cf_i = cF[layer]
|
| 45 |
+
cf_i = cf_i.detach()
|
| 46 |
+
tf_i = tF[layer]
|
| 47 |
+
loss_i = self.contentLosses[i]
|
| 48 |
+
totalContentLoss += loss_i(tf_i,cf_i)
|
| 49 |
+
totalContentLoss = totalContentLoss * self.content_weight
|
| 50 |
+
|
| 51 |
+
# style loss
|
| 52 |
+
totalStyleLoss = 0
|
| 53 |
+
for i,layer in enumerate(self.style_layers):
|
| 54 |
+
sf_i = sF[layer]
|
| 55 |
+
sf_i = sf_i.detach()
|
| 56 |
+
tf_i = tF[layer]
|
| 57 |
+
loss_i = self.styleLosses[i]
|
| 58 |
+
totalStyleLoss += loss_i(tf_i,sf_i)
|
| 59 |
+
totalStyleLoss = totalStyleLoss * self.style_weight
|
| 60 |
+
loss = totalStyleLoss + totalContentLoss
|
| 61 |
+
|
| 62 |
+
return loss,totalStyleLoss,totalContentLoss
|
graph_networks/LinearStyleTransfer/libs/Loader.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import torch.utils.data as data
|
| 4 |
+
import torchvision.transforms as transforms
|
| 5 |
+
|
| 6 |
+
def is_image_file(filename):
|
| 7 |
+
return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])
|
| 8 |
+
|
| 9 |
+
def default_loader(path):
|
| 10 |
+
return Image.open(path).convert('RGB')
|
| 11 |
+
|
| 12 |
+
class Dataset(data.Dataset):
|
| 13 |
+
def __init__(self,dataPath,loadSize,fineSize,test=False,video=False):
|
| 14 |
+
super(Dataset,self).__init__()
|
| 15 |
+
self.dataPath = dataPath
|
| 16 |
+
self.image_list = [x for x in os.listdir(dataPath) if is_image_file(x)]
|
| 17 |
+
self.image_list = sorted(self.image_list)
|
| 18 |
+
if(video):
|
| 19 |
+
self.image_list = sorted(self.image_list)
|
| 20 |
+
if not test:
|
| 21 |
+
self.transform = transforms.Compose([
|
| 22 |
+
transforms.Resize(fineSize),
|
| 23 |
+
transforms.RandomCrop(fineSize),
|
| 24 |
+
transforms.RandomHorizontalFlip(),
|
| 25 |
+
transforms.ToTensor()])
|
| 26 |
+
else:
|
| 27 |
+
self.transform = transforms.Compose([
|
| 28 |
+
transforms.Resize(fineSize),
|
| 29 |
+
transforms.ToTensor()])
|
| 30 |
+
|
| 31 |
+
self.test = test
|
| 32 |
+
|
| 33 |
+
def __getitem__(self,index):
|
| 34 |
+
dataPath = os.path.join(self.dataPath,self.image_list[index])
|
| 35 |
+
|
| 36 |
+
Img = default_loader(dataPath)
|
| 37 |
+
ImgA = self.transform(Img)
|
| 38 |
+
|
| 39 |
+
imgName = self.image_list[index]
|
| 40 |
+
imgName = imgName.split('.')[0]
|
| 41 |
+
return ImgA,imgName
|
| 42 |
+
|
| 43 |
+
def __len__(self):
|
| 44 |
+
return len(self.image_list)
|
graph_networks/LinearStyleTransfer/libs/LoaderPhotoReal.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import torchvision.transforms as transforms
|
| 3 |
+
import torchvision.utils as vutils
|
| 4 |
+
import torch.utils.data as data
|
| 5 |
+
from os import listdir
|
| 6 |
+
from os.path import join
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import os
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.autograd import Variable
|
| 12 |
+
import numpy as np
|
| 13 |
+
from libs.utils import whiten
|
| 14 |
+
|
| 15 |
+
def is_image_file(filename):
|
| 16 |
+
return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])
|
| 17 |
+
|
| 18 |
+
def default_loader(path,fineSize):
|
| 19 |
+
img = Image.open(path).convert('RGB')
|
| 20 |
+
w,h = img.size
|
| 21 |
+
if(w < h):
|
| 22 |
+
neww = fineSize
|
| 23 |
+
newh = h * neww / w
|
| 24 |
+
newh = int(newh / 8) * 8
|
| 25 |
+
else:
|
| 26 |
+
newh = fineSize
|
| 27 |
+
neww = w * newh / h
|
| 28 |
+
neww = int(neww / 8) * 8
|
| 29 |
+
img = img.resize((neww,newh))
|
| 30 |
+
return img
|
| 31 |
+
|
| 32 |
+
def MaskHelper(seg,color):
|
| 33 |
+
# green
|
| 34 |
+
mask = torch.Tensor()
|
| 35 |
+
if(color == 'green'):
|
| 36 |
+
mask = torch.lt(seg[0],0.1)
|
| 37 |
+
mask = torch.mul(mask,torch.gt(seg[1],1-0.1))
|
| 38 |
+
mask = torch.mul(mask,torch.lt(seg[2],0.1))
|
| 39 |
+
elif(color == 'black'):
|
| 40 |
+
mask = torch.lt(seg[0], 0.1)
|
| 41 |
+
mask = torch.mul(mask,torch.lt(seg[1], 0.1))
|
| 42 |
+
mask = torch.mul(mask,torch.lt(seg[2], 0.1))
|
| 43 |
+
elif(color == 'white'):
|
| 44 |
+
mask = torch.gt(seg[0], 1-0.1)
|
| 45 |
+
mask = torch.mul(mask,torch.gt(seg[1], 1-0.1))
|
| 46 |
+
mask = torch.mul(mask,torch.gt(seg[2], 1-0.1))
|
| 47 |
+
elif(color == 'red'):
|
| 48 |
+
mask = torch.gt(seg[0], 1-0.1)
|
| 49 |
+
mask = torch.mul(mask,torch.lt(seg[1], 0.1))
|
| 50 |
+
mask = torch.mul(mask,torch.lt(seg[2], 0.1))
|
| 51 |
+
elif(color == 'blue'):
|
| 52 |
+
mask = torch.lt(seg[0], 0.1)
|
| 53 |
+
mask = torch.mul(mask,torch.lt(seg[1], 0.1))
|
| 54 |
+
mask = torch.mul(mask,torch.gt(seg[2], 1-0.1))
|
| 55 |
+
elif(color == 'yellow'):
|
| 56 |
+
mask = torch.gt(seg[0], 1-0.1)
|
| 57 |
+
mask = torch.mul(mask,torch.gt(seg[1], 1-0.1))
|
| 58 |
+
mask = torch.mul(mask,torch.lt(seg[2], 0.1))
|
| 59 |
+
elif(color == 'grey'):
|
| 60 |
+
mask = torch.lt(seg[0], 0.1)
|
| 61 |
+
mask = torch.mul(mask,torch.lt(seg[1], 0.1))
|
| 62 |
+
mask = torch.mul(mask,torch.lt(seg[2], 0.1))
|
| 63 |
+
elif(color == 'lightblue'):
|
| 64 |
+
mask = torch.lt(seg[0], 0.1)
|
| 65 |
+
mask = torch.mul(mask,torch.gt(seg[1], 1-0.1))
|
| 66 |
+
mask = torch.mul(mask,torch.gt(seg[2], 1-0.1))
|
| 67 |
+
elif(color == 'purple'):
|
| 68 |
+
mask = torch.gt(seg[0], 1-0.1)
|
| 69 |
+
mask = torch.mul(mask,torch.lt(seg[1], 0.1))
|
| 70 |
+
mask = torch.mul(mask,torch.gt(seg[2], 1-0.1))
|
| 71 |
+
else:
|
| 72 |
+
print('MaskHelper(): color not recognized, color = ' + color)
|
| 73 |
+
return mask.float()
|
| 74 |
+
|
| 75 |
+
def ExtractMask(Seg):
|
| 76 |
+
# Given segmentation for content and style, we get a list of segmentation for each color
|
| 77 |
+
'''
|
| 78 |
+
Test Code:
|
| 79 |
+
content_masks,style_masks = ExtractMask(contentSegImg,styleSegImg)
|
| 80 |
+
for i,mask in enumerate(content_masks):
|
| 81 |
+
vutils.save_image(mask,'samples/content_%d.png' % (i),normalize=True)
|
| 82 |
+
for i,mask in enumerate(style_masks):
|
| 83 |
+
vutils.save_image(mask,'samples/style_%d.png' % (i),normalize=True)
|
| 84 |
+
'''
|
| 85 |
+
color_codes = ['blue', 'green', 'black', 'white', 'red', 'yellow', 'grey', 'lightblue', 'purple']
|
| 86 |
+
masks = []
|
| 87 |
+
for color in color_codes:
|
| 88 |
+
mask = MaskHelper(Seg,color)
|
| 89 |
+
masks.append(mask)
|
| 90 |
+
return masks
|
| 91 |
+
|
| 92 |
+
def calculate_size(h,w,fineSize):
|
| 93 |
+
if(h > w):
|
| 94 |
+
newh = fineSize
|
| 95 |
+
neww = int(w * 1.0 * newh / h)
|
| 96 |
+
else:
|
| 97 |
+
neww = fineSize
|
| 98 |
+
newh = int(h * 1.0 * neww / w)
|
| 99 |
+
newh = (newh // 8) * 8
|
| 100 |
+
neww = (neww // 8) * 8
|
| 101 |
+
return neww, newh
|
| 102 |
+
|
| 103 |
+
class Dataset(data.Dataset):
|
| 104 |
+
def __init__(self,contentPath,stylePath,contentSegPath,styleSegPath,fineSize):
|
| 105 |
+
super(Dataset,self).__init__()
|
| 106 |
+
self.contentPath = contentPath
|
| 107 |
+
self.image_list = [x for x in listdir(contentPath) if is_image_file(x)]
|
| 108 |
+
self.stylePath = stylePath
|
| 109 |
+
self.contentSegPath = contentSegPath
|
| 110 |
+
self.styleSegPath = styleSegPath
|
| 111 |
+
self.fineSize = fineSize
|
| 112 |
+
|
| 113 |
+
def __getitem__(self,index):
|
| 114 |
+
contentImgPath = os.path.join(self.contentPath,self.image_list[index])
|
| 115 |
+
styleImgPath = os.path.join(self.stylePath,self.image_list[index])
|
| 116 |
+
contentImg = default_loader(contentImgPath,self.fineSize)
|
| 117 |
+
styleImg = default_loader(styleImgPath,self.fineSize)
|
| 118 |
+
|
| 119 |
+
try:
|
| 120 |
+
contentSegImgPath = os.path.join(self.contentSegPath,self.image_list[index])
|
| 121 |
+
contentSegImg = default_loader(contentSegImgPath,self.fineSize)
|
| 122 |
+
except :
|
| 123 |
+
print('no mask provided, fake a whole black one')
|
| 124 |
+
contentSegImg = Image.new('RGB', (contentImg.size))
|
| 125 |
+
|
| 126 |
+
try:
|
| 127 |
+
styleSegImgPath = os.path.join(self.styleSegPath,self.image_list[index])
|
| 128 |
+
styleSegImg = default_loader(styleSegImgPath,self.fineSize)
|
| 129 |
+
except :
|
| 130 |
+
print('no mask provided, fake a whole black one')
|
| 131 |
+
styleSegImg = Image.new('RGB', (styleImg.size))
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
hs, ws = styleImg.size
|
| 135 |
+
newhs, newws = calculate_size(hs,ws,self.fineSize)
|
| 136 |
+
|
| 137 |
+
transform = transforms.Compose([
|
| 138 |
+
transforms.Resize((newhs, newws)),
|
| 139 |
+
transforms.ToTensor()])
|
| 140 |
+
# Turning segmentation images into masks
|
| 141 |
+
styleSegImg = transform(styleSegImg)
|
| 142 |
+
styleImgArbi = transform(styleImg)
|
| 143 |
+
|
| 144 |
+
hc, wc = contentImg.size
|
| 145 |
+
newhc, newwc = calculate_size(hc,wc,self.fineSize)
|
| 146 |
+
|
| 147 |
+
transform = transforms.Compose([
|
| 148 |
+
transforms.Resize((newhc, newwc)),
|
| 149 |
+
transforms.ToTensor()])
|
| 150 |
+
contentSegImg = transform(contentSegImg)
|
| 151 |
+
contentImgArbi = transform(contentImg)
|
| 152 |
+
|
| 153 |
+
content_masks = ExtractMask(contentSegImg)
|
| 154 |
+
style_masks = ExtractMask(styleSegImg)
|
| 155 |
+
|
| 156 |
+
ImgW = whiten(contentImgArbi.view(3,-1).double())
|
| 157 |
+
ImgW = ImgW.view(contentImgArbi.size()).float()
|
| 158 |
+
|
| 159 |
+
return contentImgArbi.squeeze(0),styleImgArbi.squeeze(0),ImgW,content_masks,style_masks,self.image_list[index]
|
| 160 |
+
|
| 161 |
+
def __len__(self):
|
| 162 |
+
return len(self.image_list)
|
graph_networks/LinearStyleTransfer/libs/Matrix.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class CNN(nn.Module):
|
| 5 |
+
def __init__(self,layer,matrixSize=32):
|
| 6 |
+
super(CNN,self).__init__()
|
| 7 |
+
if(layer == 'r31'):
|
| 8 |
+
# 256x64x64
|
| 9 |
+
self.convs = nn.Sequential(nn.Conv2d(256,128,3,1,1),
|
| 10 |
+
nn.ReLU(inplace=True),
|
| 11 |
+
nn.Conv2d(128,64,3,1,1),
|
| 12 |
+
nn.ReLU(inplace=True),
|
| 13 |
+
nn.Conv2d(64,matrixSize,3,1,1))
|
| 14 |
+
elif(layer == 'r41'):
|
| 15 |
+
# 512x32x32
|
| 16 |
+
self.convs = nn.Sequential(nn.Conv2d(512,256,3,1,1),
|
| 17 |
+
nn.ReLU(inplace=True),
|
| 18 |
+
nn.Conv2d(256,128,3,1,1),
|
| 19 |
+
nn.ReLU(inplace=True),
|
| 20 |
+
nn.Conv2d(128,matrixSize,3,1,1))
|
| 21 |
+
|
| 22 |
+
# 32x8x8
|
| 23 |
+
self.fc = nn.Linear(matrixSize*matrixSize,matrixSize*matrixSize)
|
| 24 |
+
#self.fc = nn.Linear(32*64,256*256)
|
| 25 |
+
|
| 26 |
+
def forward(self,x):
|
| 27 |
+
out = self.convs(x)
|
| 28 |
+
# 32x8x8
|
| 29 |
+
b,c,h,w = out.size()
|
| 30 |
+
out = out.view(b,c,-1)
|
| 31 |
+
# 32x64
|
| 32 |
+
out = torch.bmm(out,out.transpose(1,2)).div(h*w)
|
| 33 |
+
# 32x32
|
| 34 |
+
out = out.view(out.size(0),-1)
|
| 35 |
+
return self.fc(out)
|
| 36 |
+
|
| 37 |
+
class MulLayer(nn.Module):
|
| 38 |
+
def __init__(self,layer,matrixSize=32):
|
| 39 |
+
super(MulLayer,self).__init__()
|
| 40 |
+
self.snet = CNN(layer,matrixSize)
|
| 41 |
+
self.cnet = CNN(layer,matrixSize)
|
| 42 |
+
self.matrixSize = matrixSize
|
| 43 |
+
|
| 44 |
+
if(layer == 'r41'):
|
| 45 |
+
self.compress = nn.Conv2d(512,matrixSize,1,1,0)
|
| 46 |
+
self.unzip = nn.Conv2d(matrixSize,512,1,1,0)
|
| 47 |
+
elif(layer == 'r31'):
|
| 48 |
+
self.compress = nn.Conv2d(256,matrixSize,1,1,0)
|
| 49 |
+
self.unzip = nn.Conv2d(matrixSize,256,1,1,0)
|
| 50 |
+
self.transmatrix = None
|
| 51 |
+
|
| 52 |
+
def forward(self,cF,sF,trans=True):
|
| 53 |
+
cFBK = cF.clone()
|
| 54 |
+
cb,cc,ch,cw = cF.size()
|
| 55 |
+
cFF = cF.view(cb,cc,-1)
|
| 56 |
+
cMean = torch.mean(cFF,dim=2,keepdim=True)
|
| 57 |
+
cMean = cMean.unsqueeze(3)
|
| 58 |
+
cMean = cMean.expand_as(cF)
|
| 59 |
+
cF = cF - cMean
|
| 60 |
+
|
| 61 |
+
sb,sc,sh,sw = sF.size()
|
| 62 |
+
sFF = sF.view(sb,sc,-1)
|
| 63 |
+
sMean = torch.mean(sFF,dim=2,keepdim=True)
|
| 64 |
+
sMean = sMean.unsqueeze(3)
|
| 65 |
+
sMeanC = sMean.expand_as(cF)
|
| 66 |
+
sMeanS = sMean.expand_as(sF)
|
| 67 |
+
sF = sF - sMeanS
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
compress_content = self.compress(cF)
|
| 71 |
+
b,c,h,w = compress_content.size()
|
| 72 |
+
compress_content = compress_content.view(b,c,-1)
|
| 73 |
+
|
| 74 |
+
if(trans):
|
| 75 |
+
cMatrix = self.cnet(cF)
|
| 76 |
+
sMatrix = self.snet(sF)
|
| 77 |
+
|
| 78 |
+
sMatrix = sMatrix.view(sMatrix.size(0),self.matrixSize,self.matrixSize)
|
| 79 |
+
cMatrix = cMatrix.view(cMatrix.size(0),self.matrixSize,self.matrixSize)
|
| 80 |
+
transmatrix = torch.bmm(sMatrix,cMatrix)
|
| 81 |
+
print(cMatrix)
|
| 82 |
+
transfeature = torch.bmm(transmatrix,compress_content).view(b,c,h,w)
|
| 83 |
+
out = self.unzip(transfeature.view(b,c,h,w))
|
| 84 |
+
out = out + sMeanC
|
| 85 |
+
return out, transmatrix
|
| 86 |
+
else:
|
| 87 |
+
out = self.unzip(compress_content.view(b,c,h,w))
|
| 88 |
+
out = out + cMean
|
| 89 |
+
return out
|
graph_networks/LinearStyleTransfer/libs/MatrixTest.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cv2
|
| 6 |
+
from torch.autograd import Variable
|
| 7 |
+
import torchvision.utils as vutils
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CNN(nn.Module):
|
| 11 |
+
def __init__(self,layer,matrixSize=32):
|
| 12 |
+
super(CNN,self).__init__()
|
| 13 |
+
# 256x64x64
|
| 14 |
+
if(layer == 'r31'):
|
| 15 |
+
self.convs = nn.Sequential(nn.Conv2d(256,128,3,1,1),
|
| 16 |
+
nn.ReLU(inplace=True),
|
| 17 |
+
nn.Conv2d(128,64,3,1,1),
|
| 18 |
+
nn.ReLU(inplace=True),
|
| 19 |
+
nn.Conv2d(64,matrixSize,3,1,1))
|
| 20 |
+
elif(layer == 'r41'):
|
| 21 |
+
# 512x32x32
|
| 22 |
+
self.convs = nn.Sequential(nn.Conv2d(512,256,3,1,1),
|
| 23 |
+
nn.ReLU(inplace=True),
|
| 24 |
+
nn.Conv2d(256,128,3,1,1),
|
| 25 |
+
nn.ReLU(inplace=True),
|
| 26 |
+
nn.Conv2d(128,matrixSize,3,1,1))
|
| 27 |
+
self.fc = nn.Linear(32*32,32*32)
|
| 28 |
+
|
| 29 |
+
def forward(self,x,masks,style=False):
|
| 30 |
+
color_code_number = 9
|
| 31 |
+
xb,xc,xh,xw = x.size()
|
| 32 |
+
x = x.view(xc,-1)
|
| 33 |
+
feature_sub_mean = x.clone()
|
| 34 |
+
for i in range(color_code_number):
|
| 35 |
+
mask = masks[i].clone().squeeze(0)
|
| 36 |
+
mask = cv2.resize(mask.numpy(),(xw,xh),interpolation=cv2.INTER_NEAREST)
|
| 37 |
+
mask = torch.FloatTensor(mask)
|
| 38 |
+
mask = mask.long()
|
| 39 |
+
if(torch.sum(mask) >= 10):
|
| 40 |
+
mask = mask.view(-1)
|
| 41 |
+
|
| 42 |
+
# dilation here
|
| 43 |
+
"""
|
| 44 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(5,5))
|
| 45 |
+
mask = mask.cpu().numpy()
|
| 46 |
+
mask = cv2.dilate(mask.astype(np.float32), kernel)
|
| 47 |
+
mask = torch.from_numpy(mask)
|
| 48 |
+
mask = mask.squeeze()
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
fgmask = (mask>0).nonzero().squeeze(1)
|
| 52 |
+
fgmask = fgmask.cuda()
|
| 53 |
+
selectFeature = torch.index_select(x,1,fgmask) # 32x96
|
| 54 |
+
# subtract mean
|
| 55 |
+
f_mean = torch.mean(selectFeature,1)
|
| 56 |
+
f_mean = f_mean.unsqueeze(1).expand_as(selectFeature)
|
| 57 |
+
selectFeature = selectFeature - f_mean
|
| 58 |
+
feature_sub_mean.index_copy_(1,fgmask,selectFeature)
|
| 59 |
+
|
| 60 |
+
feature = self.convs(feature_sub_mean.view(xb,xc,xh,xw))
|
| 61 |
+
# 32x16x16
|
| 62 |
+
b,c,h,w = feature.size()
|
| 63 |
+
transMatrices = {}
|
| 64 |
+
feature = feature.view(c,-1)
|
| 65 |
+
|
| 66 |
+
for i in range(color_code_number):
|
| 67 |
+
mask = masks[i].clone().squeeze(0)
|
| 68 |
+
mask = cv2.resize(mask.numpy(),(w,h),interpolation=cv2.INTER_NEAREST)
|
| 69 |
+
mask = torch.FloatTensor(mask)
|
| 70 |
+
mask = mask.long()
|
| 71 |
+
if(torch.sum(mask) >= 10):
|
| 72 |
+
mask = mask.view(-1)
|
| 73 |
+
fgmask = Variable((mask==1).nonzero().squeeze(1))
|
| 74 |
+
fgmask = fgmask.cuda()
|
| 75 |
+
selectFeature = torch.index_select(feature,1,fgmask) # 32x96
|
| 76 |
+
tc,tN = selectFeature.size()
|
| 77 |
+
|
| 78 |
+
covMatrix = torch.mm(selectFeature,selectFeature.transpose(0,1)).div(tN)
|
| 79 |
+
transmatrix = self.fc(covMatrix.view(-1))
|
| 80 |
+
transMatrices[i] = transmatrix
|
| 81 |
+
return transMatrices,feature_sub_mean
|
| 82 |
+
|
| 83 |
+
class MulLayer(nn.Module):
|
| 84 |
+
def __init__(self,layer,matrixSize=32):
|
| 85 |
+
super(MulLayer,self).__init__()
|
| 86 |
+
self.snet = CNN(layer)
|
| 87 |
+
self.cnet = CNN(layer)
|
| 88 |
+
self.matrixSize = matrixSize
|
| 89 |
+
|
| 90 |
+
if(layer == 'r41'):
|
| 91 |
+
self.compress = nn.Conv2d(512,matrixSize,1,1,0)
|
| 92 |
+
self.unzip = nn.Conv2d(matrixSize,512,1,1,0)
|
| 93 |
+
elif(layer == 'r31'):
|
| 94 |
+
self.compress = nn.Conv2d(256,matrixSize,1,1,0)
|
| 95 |
+
self.unzip = nn.Conv2d(matrixSize,256,1,1,0)
|
| 96 |
+
|
| 97 |
+
def forward(self,cF,sF,cmasks,smasks):
|
| 98 |
+
|
| 99 |
+
sb,sc,sh,sw = sF.size()
|
| 100 |
+
|
| 101 |
+
sMatrices,sF_sub_mean = self.snet(sF,smasks,style=True)
|
| 102 |
+
cMatrices,cF_sub_mean = self.cnet(cF,cmasks,style=False)
|
| 103 |
+
|
| 104 |
+
compress_content = self.compress(cF_sub_mean.view(cF.size()))
|
| 105 |
+
cb,cc,ch,cw = compress_content.size()
|
| 106 |
+
compress_content = compress_content.view(cc,-1)
|
| 107 |
+
transfeature = compress_content.clone()
|
| 108 |
+
color_code_number = 9
|
| 109 |
+
finalSMean = Variable(torch.zeros(cF.size()).cuda(0))
|
| 110 |
+
finalSMean = finalSMean.view(sc,-1)
|
| 111 |
+
for i in range(color_code_number):
|
| 112 |
+
cmask = cmasks[i].clone().squeeze(0)
|
| 113 |
+
smask = smasks[i].clone().squeeze(0)
|
| 114 |
+
|
| 115 |
+
cmask = cv2.resize(cmask.numpy(),(cw,ch),interpolation=cv2.INTER_NEAREST)
|
| 116 |
+
cmask = torch.FloatTensor(cmask)
|
| 117 |
+
cmask = cmask.long()
|
| 118 |
+
smask = cv2.resize(smask.numpy(),(sw,sh),interpolation=cv2.INTER_NEAREST)
|
| 119 |
+
smask = torch.FloatTensor(smask)
|
| 120 |
+
smask = smask.long()
|
| 121 |
+
if(torch.sum(cmask) >= 10 and torch.sum(smask) >= 10
|
| 122 |
+
and (i in sMatrices) and (i in cMatrices)):
|
| 123 |
+
cmask = cmask.view(-1)
|
| 124 |
+
fgcmask = Variable((cmask==1).nonzero().squeeze(1))
|
| 125 |
+
fgcmask = fgcmask.cuda()
|
| 126 |
+
|
| 127 |
+
smask = smask.view(-1)
|
| 128 |
+
fgsmask = Variable((smask==1).nonzero().squeeze(1))
|
| 129 |
+
fgsmask = fgsmask.cuda()
|
| 130 |
+
|
| 131 |
+
sFF = sF.view(sc,-1)
|
| 132 |
+
sFF_select = torch.index_select(sFF,1,fgsmask)
|
| 133 |
+
sMean = torch.mean(sFF_select,dim=1,keepdim=True)
|
| 134 |
+
sMean = sMean.view(1,sc,1,1)
|
| 135 |
+
sMean = sMean.expand_as(cF)
|
| 136 |
+
|
| 137 |
+
sMatrix = sMatrices[i]
|
| 138 |
+
cMatrix = cMatrices[i]
|
| 139 |
+
|
| 140 |
+
sMatrix = sMatrix.view(self.matrixSize,self.matrixSize)
|
| 141 |
+
cMatrix = cMatrix.view(self.matrixSize,self.matrixSize)
|
| 142 |
+
|
| 143 |
+
transmatrix = torch.mm(sMatrix,cMatrix) # (C*C)
|
| 144 |
+
|
| 145 |
+
compress_content_select = torch.index_select(compress_content,1,fgcmask)
|
| 146 |
+
|
| 147 |
+
transfeatureFG = torch.mm(transmatrix,compress_content_select)
|
| 148 |
+
transfeature.index_copy_(1,fgcmask,transfeatureFG)
|
| 149 |
+
|
| 150 |
+
sMean = sMean.contiguous()
|
| 151 |
+
sMean_select = torch.index_select(sMean.view(sc,-1),1,fgcmask)
|
| 152 |
+
finalSMean.index_copy_(1,fgcmask,sMean_select)
|
| 153 |
+
out = self.unzip(transfeature.view(cb,cc,ch,cw))
|
| 154 |
+
return out + finalSMean.view(out.size())
|
graph_networks/LinearStyleTransfer/libs/SPN.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torchvision.models import vgg16
|
| 4 |
+
from torch.autograd import Variable
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import sys
|
| 8 |
+
sys.path.append('../')
|
| 9 |
+
from libs.pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind
|
| 10 |
+
|
| 11 |
+
class spn_block(nn.Module):
|
| 12 |
+
def __init__(self, horizontal, reverse):
|
| 13 |
+
super(spn_block, self).__init__()
|
| 14 |
+
self.propagator = GateRecurrent2dnoind(horizontal,reverse)
|
| 15 |
+
|
| 16 |
+
def forward(self,x,G1,G2,G3):
|
| 17 |
+
sum_abs = G1.abs() + G2.abs() + G3.abs()
|
| 18 |
+
sum_abs.data[sum_abs.data == 0] = 1e-6
|
| 19 |
+
mask_need_norm = sum_abs.ge(1)
|
| 20 |
+
mask_need_norm = mask_need_norm.float()
|
| 21 |
+
G1_norm = torch.div(G1, sum_abs)
|
| 22 |
+
G2_norm = torch.div(G2, sum_abs)
|
| 23 |
+
G3_norm = torch.div(G3, sum_abs)
|
| 24 |
+
|
| 25 |
+
G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm
|
| 26 |
+
G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm
|
| 27 |
+
G3 = torch.add(-mask_need_norm, 1) * G3 + mask_need_norm * G3_norm
|
| 28 |
+
|
| 29 |
+
return self.propagator(x,G1,G2,G3)
|
| 30 |
+
|
| 31 |
+
class VGG(nn.Module):
|
| 32 |
+
def __init__(self,nf):
|
| 33 |
+
super(VGG,self).__init__()
|
| 34 |
+
self.conv1 = nn.Conv2d(3,nf,3,padding = 1)
|
| 35 |
+
# 256 x 256
|
| 36 |
+
self.pool1 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1)
|
| 37 |
+
self.conv2 = nn.Conv2d(nf,nf*2,3,padding = 1)
|
| 38 |
+
# 128 x 128
|
| 39 |
+
self.pool2 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1)
|
| 40 |
+
self.conv3 = nn.Conv2d(nf*2,nf*4,3,padding = 1)
|
| 41 |
+
# 64 x 64
|
| 42 |
+
self.pool3 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1)
|
| 43 |
+
# 32 x 32
|
| 44 |
+
self.conv4 = nn.Conv2d(nf*4,nf*8,3,padding = 1)
|
| 45 |
+
|
| 46 |
+
def forward(self,x):
|
| 47 |
+
output = {}
|
| 48 |
+
output['conv1'] = self.conv1(x)
|
| 49 |
+
x = F.relu(output['conv1'])
|
| 50 |
+
x = self.pool1(x)
|
| 51 |
+
output['conv2'] = self.conv2(x)
|
| 52 |
+
# 128 x 128
|
| 53 |
+
x = F.relu(output['conv2'])
|
| 54 |
+
x = self.pool2(x)
|
| 55 |
+
output['conv3'] = self.conv3(x)
|
| 56 |
+
# 64 x 64
|
| 57 |
+
x = F.relu(output['conv3'])
|
| 58 |
+
output['pool3'] = self.pool3(x)
|
| 59 |
+
# 32 x 32
|
| 60 |
+
output['conv4'] = self.conv4(output['pool3'])
|
| 61 |
+
return output
|
| 62 |
+
|
| 63 |
+
class Decoder(nn.Module):
|
| 64 |
+
def __init__(self,nf=32,spn=1):
|
| 65 |
+
super(Decoder,self).__init__()
|
| 66 |
+
# 32 x 32
|
| 67 |
+
self.layer0 = nn.Conv2d(nf*8,nf*4,1,1,0) # edge_conv5
|
| 68 |
+
self.layer1 = nn.Upsample(scale_factor=2,mode='bilinear')
|
| 69 |
+
self.layer2 = nn.Sequential(nn.Conv2d(nf*4,nf*4,3,1,1), # edge_conv8
|
| 70 |
+
nn.ELU(inplace=True))
|
| 71 |
+
# 64 x 64
|
| 72 |
+
self.layer3 = nn.Upsample(scale_factor=2,mode='bilinear')
|
| 73 |
+
self.layer4 = nn.Sequential(nn.Conv2d(nf*4,nf*2,3,1,1), # edge_conv8
|
| 74 |
+
nn.ELU(inplace=True))
|
| 75 |
+
# 128 x 128
|
| 76 |
+
self.layer5 = nn.Upsample(scale_factor=2,mode='bilinear')
|
| 77 |
+
self.layer6 = nn.Sequential(nn.Conv2d(nf*2,nf,3,1,1), # edge_conv8
|
| 78 |
+
nn.ELU(inplace=True))
|
| 79 |
+
if(spn == 1):
|
| 80 |
+
self.layer7 = nn.Conv2d(nf,nf*12,3,1,1)
|
| 81 |
+
else:
|
| 82 |
+
self.layer7 = nn.Conv2d(nf,nf*24,3,1,1)
|
| 83 |
+
self.spn = spn
|
| 84 |
+
# 256 x 256
|
| 85 |
+
|
| 86 |
+
def forward(self,encode_feature):
|
| 87 |
+
output = {}
|
| 88 |
+
output['0'] = self.layer0(encode_feature['conv4'])
|
| 89 |
+
output['1'] = self.layer1(output['0'])
|
| 90 |
+
|
| 91 |
+
output['2'] = self.layer2(output['1'])
|
| 92 |
+
output['2res'] = output['2'] + encode_feature['conv3']
|
| 93 |
+
# 64 x 64
|
| 94 |
+
|
| 95 |
+
output['3'] = self.layer3(output['2res'])
|
| 96 |
+
output['4'] = self.layer4(output['3'])
|
| 97 |
+
output['4res'] = output['4'] + encode_feature['conv2']
|
| 98 |
+
# 128 x 128
|
| 99 |
+
|
| 100 |
+
output['5'] = self.layer5(output['4res'])
|
| 101 |
+
output['6'] = self.layer6(output['5'])
|
| 102 |
+
output['6res'] = output['6'] + encode_feature['conv1']
|
| 103 |
+
|
| 104 |
+
output['7'] = self.layer7(output['6res'])
|
| 105 |
+
|
| 106 |
+
return output['7']
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class SPN(nn.Module):
|
| 110 |
+
def __init__(self,nf=32,spn=1):
|
| 111 |
+
super(SPN,self).__init__()
|
| 112 |
+
# conv for mask
|
| 113 |
+
self.mask_conv = nn.Conv2d(3,nf,3,1,1)
|
| 114 |
+
|
| 115 |
+
# guidance network
|
| 116 |
+
self.encoder = VGG(nf)
|
| 117 |
+
self.decoder = Decoder(nf,spn)
|
| 118 |
+
|
| 119 |
+
# spn blocks
|
| 120 |
+
self.left_right = spn_block(True,False)
|
| 121 |
+
self.right_left = spn_block(True,True)
|
| 122 |
+
self.top_down = spn_block(False, False)
|
| 123 |
+
self.down_top = spn_block(False,True)
|
| 124 |
+
|
| 125 |
+
# post upsample
|
| 126 |
+
self.post = nn.Conv2d(nf,3,3,1,1)
|
| 127 |
+
self.nf = nf
|
| 128 |
+
|
| 129 |
+
def forward(self,x,rgb):
|
| 130 |
+
# feature for mask
|
| 131 |
+
X = self.mask_conv(x)
|
| 132 |
+
|
| 133 |
+
# guidance
|
| 134 |
+
features = self.encoder(rgb)
|
| 135 |
+
guide = self.decoder(features)
|
| 136 |
+
|
| 137 |
+
G = torch.split(guide,self.nf,1)
|
| 138 |
+
out1 = self.left_right(X,G[0],G[1],G[2])
|
| 139 |
+
out2 = self.right_left(X,G[3],G[4],G[5])
|
| 140 |
+
out3 = self.top_down(X,G[6],G[7],G[8])
|
| 141 |
+
out4 = self.down_top(X,G[9],G[10],G[11])
|
| 142 |
+
|
| 143 |
+
out = torch.max(out1,out2)
|
| 144 |
+
out = torch.max(out,out3)
|
| 145 |
+
out = torch.max(out,out4)
|
| 146 |
+
|
| 147 |
+
return self.post(out)
|
| 148 |
+
|
| 149 |
+
if __name__ == '__main__':
|
| 150 |
+
spn = SPN()
|
| 151 |
+
spn = spn.cuda()
|
| 152 |
+
for i in range(100):
|
| 153 |
+
x = Variable(torch.Tensor(1,3,256,256)).cuda()
|
| 154 |
+
rgb = Variable(torch.Tensor(1,3,256,256)).cuda()
|
| 155 |
+
output = spn(x,rgb)
|
| 156 |
+
print(output.size())
|
graph_networks/LinearStyleTransfer/libs/__init__.py
ADDED
|
File without changes
|
graph_networks/LinearStyleTransfer/libs/models.py
ADDED
|
@@ -0,0 +1,662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class encoder3(nn.Module):
|
| 5 |
+
def __init__(self):
|
| 6 |
+
super(encoder3,self).__init__()
|
| 7 |
+
# vgg
|
| 8 |
+
# 224 x 224
|
| 9 |
+
self.conv1 = nn.Conv2d(3,3,1,1,0)
|
| 10 |
+
self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
|
| 11 |
+
# 226 x 226
|
| 12 |
+
|
| 13 |
+
self.conv2 = nn.Conv2d(3,64,3,1,0)
|
| 14 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 15 |
+
# 224 x 224
|
| 16 |
+
|
| 17 |
+
self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
|
| 18 |
+
self.conv3 = nn.Conv2d(64,64,3,1,0)
|
| 19 |
+
self.relu3 = nn.ReLU(inplace=True)
|
| 20 |
+
# 224 x 224
|
| 21 |
+
|
| 22 |
+
self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
|
| 23 |
+
# 112 x 112
|
| 24 |
+
|
| 25 |
+
self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
|
| 26 |
+
self.conv4 = nn.Conv2d(64,128,3,1,0)
|
| 27 |
+
self.relu4 = nn.ReLU(inplace=True)
|
| 28 |
+
# 112 x 112
|
| 29 |
+
|
| 30 |
+
self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
|
| 31 |
+
self.conv5 = nn.Conv2d(128,128,3,1,0)
|
| 32 |
+
self.relu5 = nn.ReLU(inplace=True)
|
| 33 |
+
# 112 x 112
|
| 34 |
+
|
| 35 |
+
self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
|
| 36 |
+
# 56 x 56
|
| 37 |
+
|
| 38 |
+
self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
|
| 39 |
+
self.conv6 = nn.Conv2d(128,256,3,1,0)
|
| 40 |
+
self.relu6 = nn.ReLU(inplace=True)
|
| 41 |
+
# 56 x 56
|
| 42 |
+
def forward(self,x):
|
| 43 |
+
out = self.conv1(x)
|
| 44 |
+
out = self.reflecPad1(out)
|
| 45 |
+
out = self.conv2(out)
|
| 46 |
+
out = self.relu2(out)
|
| 47 |
+
out = self.reflecPad3(out)
|
| 48 |
+
out = self.conv3(out)
|
| 49 |
+
pool1 = self.relu3(out)
|
| 50 |
+
out,pool_idx = self.maxPool(pool1)
|
| 51 |
+
out = self.reflecPad4(out)
|
| 52 |
+
out = self.conv4(out)
|
| 53 |
+
out = self.relu4(out)
|
| 54 |
+
out = self.reflecPad5(out)
|
| 55 |
+
out = self.conv5(out)
|
| 56 |
+
pool2 = self.relu5(out)
|
| 57 |
+
out,pool_idx2 = self.maxPool2(pool2)
|
| 58 |
+
out = self.reflecPad6(out)
|
| 59 |
+
out = self.conv6(out)
|
| 60 |
+
out = self.relu6(out)
|
| 61 |
+
return out
|
| 62 |
+
|
| 63 |
+
class decoder3(nn.Module):
|
| 64 |
+
def __init__(self):
|
| 65 |
+
super(decoder3,self).__init__()
|
| 66 |
+
# decoder
|
| 67 |
+
self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
|
| 68 |
+
self.conv7 = nn.Conv2d(256,128,3,1,0)
|
| 69 |
+
self.relu7 = nn.ReLU(inplace=True)
|
| 70 |
+
# 56 x 56
|
| 71 |
+
|
| 72 |
+
self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
|
| 73 |
+
# 112 x 112
|
| 74 |
+
|
| 75 |
+
self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
|
| 76 |
+
self.conv8 = nn.Conv2d(128,128,3,1,0)
|
| 77 |
+
self.relu8 = nn.ReLU(inplace=True)
|
| 78 |
+
# 112 x 112
|
| 79 |
+
|
| 80 |
+
self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
|
| 81 |
+
self.conv9 = nn.Conv2d(128,64,3,1,0)
|
| 82 |
+
self.relu9 = nn.ReLU(inplace=True)
|
| 83 |
+
|
| 84 |
+
self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
|
| 85 |
+
# 224 x 224
|
| 86 |
+
|
| 87 |
+
self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
|
| 88 |
+
self.conv10 = nn.Conv2d(64,64,3,1,0)
|
| 89 |
+
self.relu10 = nn.ReLU(inplace=True)
|
| 90 |
+
|
| 91 |
+
self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
|
| 92 |
+
self.conv11 = nn.Conv2d(64,3,3,1,0)
|
| 93 |
+
|
| 94 |
+
def forward(self,x):
|
| 95 |
+
output = {}
|
| 96 |
+
out = self.reflecPad7(x)
|
| 97 |
+
out = self.conv7(out)
|
| 98 |
+
out = self.relu7(out)
|
| 99 |
+
out = self.unpool(out)
|
| 100 |
+
out = self.reflecPad8(out)
|
| 101 |
+
out = self.conv8(out)
|
| 102 |
+
out = self.relu8(out)
|
| 103 |
+
out = self.reflecPad9(out)
|
| 104 |
+
out = self.conv9(out)
|
| 105 |
+
out_relu9 = self.relu9(out)
|
| 106 |
+
out = self.unpool2(out_relu9)
|
| 107 |
+
out = self.reflecPad10(out)
|
| 108 |
+
out = self.conv10(out)
|
| 109 |
+
out = self.relu10(out)
|
| 110 |
+
out = self.reflecPad11(out)
|
| 111 |
+
out = self.conv11(out)
|
| 112 |
+
return out
|
| 113 |
+
|
| 114 |
+
class encoder4(nn.Module):
|
| 115 |
+
def __init__(self):
|
| 116 |
+
super(encoder4,self).__init__()
|
| 117 |
+
# vgg
|
| 118 |
+
# 224 x 224
|
| 119 |
+
self.conv1 = nn.Conv2d(3,3,1,1,0)
|
| 120 |
+
self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
|
| 121 |
+
# 226 x 226
|
| 122 |
+
|
| 123 |
+
self.conv2 = nn.Conv2d(3,64,3,1,0)
|
| 124 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 125 |
+
# 224 x 224
|
| 126 |
+
|
| 127 |
+
self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
|
| 128 |
+
self.conv3 = nn.Conv2d(64,64,3,1,0)
|
| 129 |
+
self.relu3 = nn.ReLU(inplace=True)
|
| 130 |
+
# 224 x 224
|
| 131 |
+
|
| 132 |
+
self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
|
| 133 |
+
# 112 x 112
|
| 134 |
+
|
| 135 |
+
self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
|
| 136 |
+
self.conv4 = nn.Conv2d(64,128,3,1,0)
|
| 137 |
+
self.relu4 = nn.ReLU(inplace=True)
|
| 138 |
+
# 112 x 112
|
| 139 |
+
|
| 140 |
+
self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
|
| 141 |
+
self.conv5 = nn.Conv2d(128,128,3,1,0)
|
| 142 |
+
self.relu5 = nn.ReLU(inplace=True)
|
| 143 |
+
# 112 x 112
|
| 144 |
+
|
| 145 |
+
self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
|
| 146 |
+
# 56 x 56
|
| 147 |
+
|
| 148 |
+
self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
|
| 149 |
+
self.conv6 = nn.Conv2d(128,256,3,1,0)
|
| 150 |
+
self.relu6 = nn.ReLU(inplace=True)
|
| 151 |
+
# 56 x 56
|
| 152 |
+
|
| 153 |
+
self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
|
| 154 |
+
self.conv7 = nn.Conv2d(256,256,3,1,0)
|
| 155 |
+
self.relu7 = nn.ReLU(inplace=True)
|
| 156 |
+
# 56 x 56
|
| 157 |
+
|
| 158 |
+
self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
|
| 159 |
+
self.conv8 = nn.Conv2d(256,256,3,1,0)
|
| 160 |
+
self.relu8 = nn.ReLU(inplace=True)
|
| 161 |
+
# 56 x 56
|
| 162 |
+
|
| 163 |
+
self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
|
| 164 |
+
self.conv9 = nn.Conv2d(256,256,3,1,0)
|
| 165 |
+
self.relu9 = nn.ReLU(inplace=True)
|
| 166 |
+
# 56 x 56
|
| 167 |
+
|
| 168 |
+
self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
|
| 169 |
+
# 28 x 28
|
| 170 |
+
|
| 171 |
+
self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
|
| 172 |
+
self.conv10 = nn.Conv2d(256,512,3,1,0)
|
| 173 |
+
self.relu10 = nn.ReLU(inplace=True)
|
| 174 |
+
# 28 x 28
|
| 175 |
+
def forward(self,x,sF=None,matrix11=None,matrix21=None,matrix31=None):
|
| 176 |
+
output = {}
|
| 177 |
+
out = self.conv1(x)
|
| 178 |
+
out = self.reflecPad1(out)
|
| 179 |
+
out = self.conv2(out)
|
| 180 |
+
output['r11'] = self.relu2(out)
|
| 181 |
+
out = self.reflecPad7(output['r11'])
|
| 182 |
+
|
| 183 |
+
out = self.conv3(out)
|
| 184 |
+
output['r12'] = self.relu3(out)
|
| 185 |
+
|
| 186 |
+
output['p1'] = self.maxPool(output['r12'])
|
| 187 |
+
out = self.reflecPad4(output['p1'])
|
| 188 |
+
out = self.conv4(out)
|
| 189 |
+
output['r21'] = self.relu4(out)
|
| 190 |
+
out = self.reflecPad7(output['r21'])
|
| 191 |
+
|
| 192 |
+
out = self.conv5(out)
|
| 193 |
+
output['r22'] = self.relu5(out)
|
| 194 |
+
|
| 195 |
+
output['p2'] = self.maxPool2(output['r22'])
|
| 196 |
+
out = self.reflecPad6(output['p2'])
|
| 197 |
+
out = self.conv6(out)
|
| 198 |
+
output['r31'] = self.relu6(out)
|
| 199 |
+
if(matrix31 is not None):
|
| 200 |
+
feature3,transmatrix3 = matrix31(output['r31'],sF['r31'])
|
| 201 |
+
out = self.reflecPad7(feature3)
|
| 202 |
+
else:
|
| 203 |
+
out = self.reflecPad7(output['r31'])
|
| 204 |
+
out = self.conv7(out)
|
| 205 |
+
output['r32'] = self.relu7(out)
|
| 206 |
+
|
| 207 |
+
out = self.reflecPad8(output['r32'])
|
| 208 |
+
out = self.conv8(out)
|
| 209 |
+
output['r33'] = self.relu8(out)
|
| 210 |
+
|
| 211 |
+
out = self.reflecPad9(output['r33'])
|
| 212 |
+
out = self.conv9(out)
|
| 213 |
+
output['r34'] = self.relu9(out)
|
| 214 |
+
|
| 215 |
+
output['p3'] = self.maxPool3(output['r34'])
|
| 216 |
+
out = self.reflecPad10(output['p3'])
|
| 217 |
+
out = self.conv10(out)
|
| 218 |
+
output['r41'] = self.relu10(out)
|
| 219 |
+
|
| 220 |
+
return output
|
| 221 |
+
|
| 222 |
+
class decoder4(nn.Module):
|
| 223 |
+
def __init__(self):
|
| 224 |
+
super(decoder4,self).__init__()
|
| 225 |
+
# decoder
|
| 226 |
+
self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
|
| 227 |
+
self.conv11 = nn.Conv2d(512,256,3,1,0)
|
| 228 |
+
self.relu11 = nn.ReLU(inplace=True)
|
| 229 |
+
# 28 x 28
|
| 230 |
+
|
| 231 |
+
self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
|
| 232 |
+
# 56 x 56
|
| 233 |
+
|
| 234 |
+
self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
|
| 235 |
+
self.conv12 = nn.Conv2d(256,256,3,1,0)
|
| 236 |
+
self.relu12 = nn.ReLU(inplace=True)
|
| 237 |
+
# 56 x 56
|
| 238 |
+
|
| 239 |
+
self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
|
| 240 |
+
self.conv13 = nn.Conv2d(256,256,3,1,0)
|
| 241 |
+
self.relu13 = nn.ReLU(inplace=True)
|
| 242 |
+
# 56 x 56
|
| 243 |
+
|
| 244 |
+
self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
|
| 245 |
+
self.conv14 = nn.Conv2d(256,256,3,1,0)
|
| 246 |
+
self.relu14 = nn.ReLU(inplace=True)
|
| 247 |
+
# 56 x 56
|
| 248 |
+
|
| 249 |
+
self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
|
| 250 |
+
self.conv15 = nn.Conv2d(256,128,3,1,0)
|
| 251 |
+
self.relu15 = nn.ReLU(inplace=True)
|
| 252 |
+
# 56 x 56
|
| 253 |
+
|
| 254 |
+
self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
|
| 255 |
+
# 112 x 112
|
| 256 |
+
|
| 257 |
+
self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
|
| 258 |
+
self.conv16 = nn.Conv2d(128,128,3,1,0)
|
| 259 |
+
self.relu16 = nn.ReLU(inplace=True)
|
| 260 |
+
# 112 x 112
|
| 261 |
+
|
| 262 |
+
self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
|
| 263 |
+
self.conv17 = nn.Conv2d(128,64,3,1,0)
|
| 264 |
+
self.relu17 = nn.ReLU(inplace=True)
|
| 265 |
+
# 112 x 112
|
| 266 |
+
|
| 267 |
+
self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
|
| 268 |
+
# 224 x 224
|
| 269 |
+
|
| 270 |
+
self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
|
| 271 |
+
self.conv18 = nn.Conv2d(64,64,3,1,0)
|
| 272 |
+
self.relu18 = nn.ReLU(inplace=True)
|
| 273 |
+
# 224 x 224
|
| 274 |
+
|
| 275 |
+
self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
|
| 276 |
+
self.conv19 = nn.Conv2d(64,3,3,1,0)
|
| 277 |
+
|
| 278 |
+
def forward(self,x):
|
| 279 |
+
# decoder
|
| 280 |
+
out = self.reflecPad11(x)
|
| 281 |
+
out = self.conv11(out)
|
| 282 |
+
out = self.relu11(out)
|
| 283 |
+
out = self.unpool(out)
|
| 284 |
+
out = self.reflecPad12(out)
|
| 285 |
+
out = self.conv12(out)
|
| 286 |
+
|
| 287 |
+
out = self.relu12(out)
|
| 288 |
+
out = self.reflecPad13(out)
|
| 289 |
+
out = self.conv13(out)
|
| 290 |
+
out = self.relu13(out)
|
| 291 |
+
out = self.reflecPad14(out)
|
| 292 |
+
out = self.conv14(out)
|
| 293 |
+
out = self.relu14(out)
|
| 294 |
+
out = self.reflecPad15(out)
|
| 295 |
+
out = self.conv15(out)
|
| 296 |
+
out = self.relu15(out)
|
| 297 |
+
out = self.unpool2(out)
|
| 298 |
+
out = self.reflecPad16(out)
|
| 299 |
+
out = self.conv16(out)
|
| 300 |
+
out = self.relu16(out)
|
| 301 |
+
out = self.reflecPad17(out)
|
| 302 |
+
out = self.conv17(out)
|
| 303 |
+
out = self.relu17(out)
|
| 304 |
+
out = self.unpool3(out)
|
| 305 |
+
out = self.reflecPad18(out)
|
| 306 |
+
out = self.conv18(out)
|
| 307 |
+
out = self.relu18(out)
|
| 308 |
+
out = self.reflecPad19(out)
|
| 309 |
+
out = self.conv19(out)
|
| 310 |
+
return out
|
| 311 |
+
|
| 312 |
+
class decoder4(nn.Module):
|
| 313 |
+
def __init__(self):
|
| 314 |
+
super(decoder4,self).__init__()
|
| 315 |
+
# decoder
|
| 316 |
+
self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
|
| 317 |
+
self.conv11 = nn.Conv2d(512,256,3,1,0)
|
| 318 |
+
self.relu11 = nn.ReLU(inplace=True)
|
| 319 |
+
# 28 x 28
|
| 320 |
+
|
| 321 |
+
self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
|
| 322 |
+
# 56 x 56
|
| 323 |
+
|
| 324 |
+
self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
|
| 325 |
+
self.conv12 = nn.Conv2d(256,256,3,1,0)
|
| 326 |
+
self.relu12 = nn.ReLU(inplace=True)
|
| 327 |
+
# 56 x 56
|
| 328 |
+
|
| 329 |
+
self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
|
| 330 |
+
self.conv13 = nn.Conv2d(256,256,3,1,0)
|
| 331 |
+
self.relu13 = nn.ReLU(inplace=True)
|
| 332 |
+
# 56 x 56
|
| 333 |
+
|
| 334 |
+
self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
|
| 335 |
+
self.conv14 = nn.Conv2d(256,256,3,1,0)
|
| 336 |
+
self.relu14 = nn.ReLU(inplace=True)
|
| 337 |
+
# 56 x 56
|
| 338 |
+
|
| 339 |
+
self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
|
| 340 |
+
self.conv15 = nn.Conv2d(256,128,3,1,0)
|
| 341 |
+
self.relu15 = nn.ReLU(inplace=True)
|
| 342 |
+
# 56 x 56
|
| 343 |
+
|
| 344 |
+
self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
|
| 345 |
+
# 112 x 112
|
| 346 |
+
|
| 347 |
+
self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
|
| 348 |
+
self.conv16 = nn.Conv2d(128,128,3,1,0)
|
| 349 |
+
self.relu16 = nn.ReLU(inplace=True)
|
| 350 |
+
# 112 x 112
|
| 351 |
+
|
| 352 |
+
self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
|
| 353 |
+
self.conv17 = nn.Conv2d(128,64,3,1,0)
|
| 354 |
+
self.relu17 = nn.ReLU(inplace=True)
|
| 355 |
+
# 112 x 112
|
| 356 |
+
|
| 357 |
+
self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
|
| 358 |
+
# 224 x 224
|
| 359 |
+
|
| 360 |
+
self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
|
| 361 |
+
self.conv18 = nn.Conv2d(64,64,3,1,0)
|
| 362 |
+
self.relu18 = nn.ReLU(inplace=True)
|
| 363 |
+
# 224 x 224
|
| 364 |
+
|
| 365 |
+
self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
|
| 366 |
+
self.conv19 = nn.Conv2d(64,3,3,1,0)
|
| 367 |
+
|
| 368 |
+
def forward(self,x):
|
| 369 |
+
# decoder
|
| 370 |
+
out = self.reflecPad11(x)
|
| 371 |
+
out = self.conv11(out)
|
| 372 |
+
out = self.relu11(out)
|
| 373 |
+
out = self.unpool(out)
|
| 374 |
+
out = self.reflecPad12(out)
|
| 375 |
+
out = self.conv12(out)
|
| 376 |
+
|
| 377 |
+
out = self.relu12(out)
|
| 378 |
+
out = self.reflecPad13(out)
|
| 379 |
+
out = self.conv13(out)
|
| 380 |
+
out = self.relu13(out)
|
| 381 |
+
out = self.reflecPad14(out)
|
| 382 |
+
out = self.conv14(out)
|
| 383 |
+
out = self.relu14(out)
|
| 384 |
+
out = self.reflecPad15(out)
|
| 385 |
+
out = self.conv15(out)
|
| 386 |
+
out = self.relu15(out)
|
| 387 |
+
out = self.unpool2(out)
|
| 388 |
+
out = self.reflecPad16(out)
|
| 389 |
+
out = self.conv16(out)
|
| 390 |
+
out = self.relu16(out)
|
| 391 |
+
out = self.reflecPad17(out)
|
| 392 |
+
out = self.conv17(out)
|
| 393 |
+
out = self.relu17(out)
|
| 394 |
+
out = self.unpool3(out)
|
| 395 |
+
out = self.reflecPad18(out)
|
| 396 |
+
out = self.conv18(out)
|
| 397 |
+
out = self.relu18(out)
|
| 398 |
+
out = self.reflecPad19(out)
|
| 399 |
+
out = self.conv19(out)
|
| 400 |
+
return out
|
| 401 |
+
|
| 402 |
+
class encoder5(nn.Module):
|
| 403 |
+
def __init__(self):
|
| 404 |
+
super(encoder5,self).__init__()
|
| 405 |
+
# vgg
|
| 406 |
+
# 224 x 224
|
| 407 |
+
self.conv1 = nn.Conv2d(3,3,1,1,0)
|
| 408 |
+
self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
|
| 409 |
+
# 226 x 226
|
| 410 |
+
|
| 411 |
+
self.conv2 = nn.Conv2d(3,64,3,1,0)
|
| 412 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 413 |
+
# 224 x 224
|
| 414 |
+
|
| 415 |
+
self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
|
| 416 |
+
self.conv3 = nn.Conv2d(64,64,3,1,0)
|
| 417 |
+
self.relu3 = nn.ReLU(inplace=True)
|
| 418 |
+
# 224 x 224
|
| 419 |
+
|
| 420 |
+
self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
|
| 421 |
+
# 112 x 112
|
| 422 |
+
|
| 423 |
+
self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
|
| 424 |
+
self.conv4 = nn.Conv2d(64,128,3,1,0)
|
| 425 |
+
self.relu4 = nn.ReLU(inplace=True)
|
| 426 |
+
# 112 x 112
|
| 427 |
+
|
| 428 |
+
self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
|
| 429 |
+
self.conv5 = nn.Conv2d(128,128,3,1,0)
|
| 430 |
+
self.relu5 = nn.ReLU(inplace=True)
|
| 431 |
+
# 112 x 112
|
| 432 |
+
|
| 433 |
+
self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
|
| 434 |
+
# 56 x 56
|
| 435 |
+
|
| 436 |
+
self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
|
| 437 |
+
self.conv6 = nn.Conv2d(128,256,3,1,0)
|
| 438 |
+
self.relu6 = nn.ReLU(inplace=True)
|
| 439 |
+
# 56 x 56
|
| 440 |
+
|
| 441 |
+
self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
|
| 442 |
+
self.conv7 = nn.Conv2d(256,256,3,1,0)
|
| 443 |
+
self.relu7 = nn.ReLU(inplace=True)
|
| 444 |
+
# 56 x 56
|
| 445 |
+
|
| 446 |
+
self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
|
| 447 |
+
self.conv8 = nn.Conv2d(256,256,3,1,0)
|
| 448 |
+
self.relu8 = nn.ReLU(inplace=True)
|
| 449 |
+
# 56 x 56
|
| 450 |
+
|
| 451 |
+
self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
|
| 452 |
+
self.conv9 = nn.Conv2d(256,256,3,1,0)
|
| 453 |
+
self.relu9 = nn.ReLU(inplace=True)
|
| 454 |
+
# 56 x 56
|
| 455 |
+
|
| 456 |
+
self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
|
| 457 |
+
# 28 x 28
|
| 458 |
+
|
| 459 |
+
self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
|
| 460 |
+
self.conv10 = nn.Conv2d(256,512,3,1,0)
|
| 461 |
+
self.relu10 = nn.ReLU(inplace=True)
|
| 462 |
+
|
| 463 |
+
self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
|
| 464 |
+
self.conv11 = nn.Conv2d(512,512,3,1,0)
|
| 465 |
+
self.relu11 = nn.ReLU(inplace=True)
|
| 466 |
+
|
| 467 |
+
self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
|
| 468 |
+
self.conv12 = nn.Conv2d(512,512,3,1,0)
|
| 469 |
+
self.relu12 = nn.ReLU(inplace=True)
|
| 470 |
+
|
| 471 |
+
self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
|
| 472 |
+
self.conv13 = nn.Conv2d(512,512,3,1,0)
|
| 473 |
+
self.relu13 = nn.ReLU(inplace=True)
|
| 474 |
+
|
| 475 |
+
self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2)
|
| 476 |
+
self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
|
| 477 |
+
self.conv14 = nn.Conv2d(512,512,3,1,0)
|
| 478 |
+
self.relu14 = nn.ReLU(inplace=True)
|
| 479 |
+
|
| 480 |
+
def forward(self,x,sF=None,contentV256=None,styleV256=None,matrix11=None,matrix21=None,matrix31=None):
|
| 481 |
+
output = {}
|
| 482 |
+
out = self.conv1(x)
|
| 483 |
+
out = self.reflecPad1(out)
|
| 484 |
+
out = self.conv2(out)
|
| 485 |
+
output['r11'] = self.relu2(out)
|
| 486 |
+
out = self.reflecPad7(output['r11'])
|
| 487 |
+
|
| 488 |
+
#out = self.reflecPad3(output['r11'])
|
| 489 |
+
out = self.conv3(out)
|
| 490 |
+
output['r12'] = self.relu3(out)
|
| 491 |
+
|
| 492 |
+
output['p1'] = self.maxPool(output['r12'])
|
| 493 |
+
out = self.reflecPad4(output['p1'])
|
| 494 |
+
out = self.conv4(out)
|
| 495 |
+
output['r21'] = self.relu4(out)
|
| 496 |
+
out = self.reflecPad7(output['r21'])
|
| 497 |
+
|
| 498 |
+
#out = self.reflecPad5(output['r21'])
|
| 499 |
+
out = self.conv5(out)
|
| 500 |
+
output['r22'] = self.relu5(out)
|
| 501 |
+
|
| 502 |
+
output['p2'] = self.maxPool2(output['r22'])
|
| 503 |
+
out = self.reflecPad6(output['p2'])
|
| 504 |
+
out = self.conv6(out)
|
| 505 |
+
output['r31'] = self.relu6(out)
|
| 506 |
+
if(styleV256 is not None):
|
| 507 |
+
feature = matrix31(output['r31'],sF['r31'],contentV256,styleV256)
|
| 508 |
+
out = self.reflecPad7(feature)
|
| 509 |
+
else:
|
| 510 |
+
out = self.reflecPad7(output['r31'])
|
| 511 |
+
out = self.conv7(out)
|
| 512 |
+
output['r32'] = self.relu7(out)
|
| 513 |
+
|
| 514 |
+
out = self.reflecPad8(output['r32'])
|
| 515 |
+
out = self.conv8(out)
|
| 516 |
+
output['r33'] = self.relu8(out)
|
| 517 |
+
|
| 518 |
+
out = self.reflecPad9(output['r33'])
|
| 519 |
+
out = self.conv9(out)
|
| 520 |
+
output['r34'] = self.relu9(out)
|
| 521 |
+
|
| 522 |
+
output['p3'] = self.maxPool3(output['r34'])
|
| 523 |
+
out = self.reflecPad10(output['p3'])
|
| 524 |
+
out = self.conv10(out)
|
| 525 |
+
output['r41'] = self.relu10(out)
|
| 526 |
+
|
| 527 |
+
out = self.reflecPad11(output['r41'])
|
| 528 |
+
out = self.conv11(out)
|
| 529 |
+
output['r42'] = self.relu11(out)
|
| 530 |
+
|
| 531 |
+
out = self.reflecPad12(output['r42'])
|
| 532 |
+
out = self.conv12(out)
|
| 533 |
+
output['r43'] = self.relu12(out)
|
| 534 |
+
|
| 535 |
+
out = self.reflecPad13(output['r43'])
|
| 536 |
+
out = self.conv13(out)
|
| 537 |
+
output['r44'] = self.relu13(out)
|
| 538 |
+
|
| 539 |
+
output['p4'] = self.maxPool4(output['r44'])
|
| 540 |
+
|
| 541 |
+
out = self.reflecPad14(output['p4'])
|
| 542 |
+
out = self.conv14(out)
|
| 543 |
+
output['r51'] = self.relu14(out)
|
| 544 |
+
return output
|
| 545 |
+
|
| 546 |
+
class decoder5(nn.Module):
|
| 547 |
+
def __init__(self):
|
| 548 |
+
super(decoder5,self).__init__()
|
| 549 |
+
|
| 550 |
+
# decoder
|
| 551 |
+
self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
|
| 552 |
+
self.conv15 = nn.Conv2d(512,512,3,1,0)
|
| 553 |
+
self.relu15 = nn.ReLU(inplace=True)
|
| 554 |
+
|
| 555 |
+
self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
|
| 556 |
+
# 28 x 28
|
| 557 |
+
|
| 558 |
+
self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
|
| 559 |
+
self.conv16 = nn.Conv2d(512,512,3,1,0)
|
| 560 |
+
self.relu16 = nn.ReLU(inplace=True)
|
| 561 |
+
# 28 x 28
|
| 562 |
+
|
| 563 |
+
self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
|
| 564 |
+
self.conv17 = nn.Conv2d(512,512,3,1,0)
|
| 565 |
+
self.relu17 = nn.ReLU(inplace=True)
|
| 566 |
+
# 28 x 28
|
| 567 |
+
|
| 568 |
+
self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
|
| 569 |
+
self.conv18 = nn.Conv2d(512,512,3,1,0)
|
| 570 |
+
self.relu18 = nn.ReLU(inplace=True)
|
| 571 |
+
# 28 x 28
|
| 572 |
+
|
| 573 |
+
self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
|
| 574 |
+
self.conv19 = nn.Conv2d(512,256,3,1,0)
|
| 575 |
+
self.relu19 = nn.ReLU(inplace=True)
|
| 576 |
+
# 28 x 28
|
| 577 |
+
|
| 578 |
+
self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
|
| 579 |
+
# 56 x 56
|
| 580 |
+
|
| 581 |
+
self.reflecPad20 = nn.ReflectionPad2d((1,1,1,1))
|
| 582 |
+
self.conv20 = nn.Conv2d(256,256,3,1,0)
|
| 583 |
+
self.relu20 = nn.ReLU(inplace=True)
|
| 584 |
+
# 56 x 56
|
| 585 |
+
|
| 586 |
+
self.reflecPad21 = nn.ReflectionPad2d((1,1,1,1))
|
| 587 |
+
self.conv21 = nn.Conv2d(256,256,3,1,0)
|
| 588 |
+
self.relu21 = nn.ReLU(inplace=True)
|
| 589 |
+
|
| 590 |
+
self.reflecPad22 = nn.ReflectionPad2d((1,1,1,1))
|
| 591 |
+
self.conv22 = nn.Conv2d(256,256,3,1,0)
|
| 592 |
+
self.relu22 = nn.ReLU(inplace=True)
|
| 593 |
+
|
| 594 |
+
self.reflecPad23 = nn.ReflectionPad2d((1,1,1,1))
|
| 595 |
+
self.conv23 = nn.Conv2d(256,128,3,1,0)
|
| 596 |
+
self.relu23 = nn.ReLU(inplace=True)
|
| 597 |
+
|
| 598 |
+
self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
|
| 599 |
+
# 112 X 112
|
| 600 |
+
|
| 601 |
+
self.reflecPad24 = nn.ReflectionPad2d((1,1,1,1))
|
| 602 |
+
self.conv24 = nn.Conv2d(128,128,3,1,0)
|
| 603 |
+
self.relu24 = nn.ReLU(inplace=True)
|
| 604 |
+
|
| 605 |
+
self.reflecPad25 = nn.ReflectionPad2d((1,1,1,1))
|
| 606 |
+
self.conv25 = nn.Conv2d(128,64,3,1,0)
|
| 607 |
+
self.relu25 = nn.ReLU(inplace=True)
|
| 608 |
+
|
| 609 |
+
self.unpool4 = nn.UpsamplingNearest2d(scale_factor=2)
|
| 610 |
+
|
| 611 |
+
self.reflecPad26 = nn.ReflectionPad2d((1,1,1,1))
|
| 612 |
+
self.conv26 = nn.Conv2d(64,64,3,1,0)
|
| 613 |
+
self.relu26 = nn.ReLU(inplace=True)
|
| 614 |
+
|
| 615 |
+
self.reflecPad27 = nn.ReflectionPad2d((1,1,1,1))
|
| 616 |
+
self.conv27 = nn.Conv2d(64,3,3,1,0)
|
| 617 |
+
|
| 618 |
+
def forward(self,x):
|
| 619 |
+
# decoder
|
| 620 |
+
out = self.reflecPad15(x)
|
| 621 |
+
out = self.conv15(out)
|
| 622 |
+
out = self.relu15(out)
|
| 623 |
+
out = self.unpool(out)
|
| 624 |
+
out = self.reflecPad16(out)
|
| 625 |
+
out = self.conv16(out)
|
| 626 |
+
out = self.relu16(out)
|
| 627 |
+
out = self.reflecPad17(out)
|
| 628 |
+
out = self.conv17(out)
|
| 629 |
+
out = self.relu17(out)
|
| 630 |
+
out = self.reflecPad18(out)
|
| 631 |
+
out = self.conv18(out)
|
| 632 |
+
out = self.relu18(out)
|
| 633 |
+
out = self.reflecPad19(out)
|
| 634 |
+
out = self.conv19(out)
|
| 635 |
+
out = self.relu19(out)
|
| 636 |
+
out = self.unpool2(out)
|
| 637 |
+
out = self.reflecPad20(out)
|
| 638 |
+
out = self.conv20(out)
|
| 639 |
+
out = self.relu20(out)
|
| 640 |
+
out = self.reflecPad21(out)
|
| 641 |
+
out = self.conv21(out)
|
| 642 |
+
out = self.relu21(out)
|
| 643 |
+
out = self.reflecPad22(out)
|
| 644 |
+
out = self.conv22(out)
|
| 645 |
+
out = self.relu22(out)
|
| 646 |
+
out = self.reflecPad23(out)
|
| 647 |
+
out = self.conv23(out)
|
| 648 |
+
out = self.relu23(out)
|
| 649 |
+
out = self.unpool3(out)
|
| 650 |
+
out = self.reflecPad24(out)
|
| 651 |
+
out = self.conv24(out)
|
| 652 |
+
out = self.relu24(out)
|
| 653 |
+
out = self.reflecPad25(out)
|
| 654 |
+
out = self.conv25(out)
|
| 655 |
+
out = self.relu25(out)
|
| 656 |
+
out = self.unpool4(out)
|
| 657 |
+
out = self.reflecPad26(out)
|
| 658 |
+
out = self.conv26(out)
|
| 659 |
+
out = self.relu26(out)
|
| 660 |
+
out = self.reflecPad27(out)
|
| 661 |
+
out = self.conv27(out)
|
| 662 |
+
return out
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pytorch_spn
|
| 2 |
+
To build, install [pytorch](https://github.com/pytorch) and run:
|
| 3 |
+
|
| 4 |
+
$ sh make.sh
|
| 5 |
+
|
| 6 |
+
See left_right_demo.py for usage:
|
| 7 |
+
|
| 8 |
+
$ mv left_right_demo.py ../
|
| 9 |
+
|
| 10 |
+
$ python left_right_demo.py
|
| 11 |
+
|
| 12 |
+
The original codes (caffe) and models will be relesed [HERE](https://github.com/Liusifei/caffe-spn.git).
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/__init__.py
ADDED
|
File without changes
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/__init__.py
ADDED
|
File without changes
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/gaterecurrent2dnoind/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from torch.utils.ffi import _wrap_function
|
| 3 |
+
from ._gaterecurrent2dnoind import lib as _lib, ffi as _ffi
|
| 4 |
+
|
| 5 |
+
__all__ = []
|
| 6 |
+
def _import_symbols(locals):
|
| 7 |
+
for symbol in dir(_lib):
|
| 8 |
+
fn = getattr(_lib, symbol)
|
| 9 |
+
if callable(fn):
|
| 10 |
+
locals[symbol] = _wrap_function(fn, _ffi)
|
| 11 |
+
else:
|
| 12 |
+
locals[symbol] = fn
|
| 13 |
+
__all__.append(symbol)
|
| 14 |
+
|
| 15 |
+
_import_symbols(locals())
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/build.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.ffi import create_extension
|
| 4 |
+
|
| 5 |
+
this_file = os.path.dirname(__file__)
|
| 6 |
+
|
| 7 |
+
sources = []
|
| 8 |
+
headers = []
|
| 9 |
+
defines = []
|
| 10 |
+
with_cuda = False
|
| 11 |
+
|
| 12 |
+
if torch.cuda.is_available():
|
| 13 |
+
print('Including CUDA code.')
|
| 14 |
+
sources += ['src/gaterecurrent2dnoind_cuda.c']
|
| 15 |
+
headers += ['src/gaterecurrent2dnoind_cuda.h']
|
| 16 |
+
defines += [('WITH_CUDA', None)]
|
| 17 |
+
with_cuda = True
|
| 18 |
+
|
| 19 |
+
this_file = os.path.dirname(os.path.realpath(__file__))
|
| 20 |
+
extra_objects = ['src/cuda/gaterecurrent2dnoind_kernel.cu.o']
|
| 21 |
+
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
|
| 22 |
+
|
| 23 |
+
ffi = create_extension(
|
| 24 |
+
'_ext.gaterecurrent2dnoind',
|
| 25 |
+
headers=headers,
|
| 26 |
+
sources=sources,
|
| 27 |
+
define_macros=defines,
|
| 28 |
+
relative_to=__file__,
|
| 29 |
+
with_cuda=with_cuda,
|
| 30 |
+
extra_objects=extra_objects
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
if __name__ == '__main__':
|
| 34 |
+
ffi.build()
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/__init__.py
ADDED
|
File without changes
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/gaterecurrent2dnoind.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.autograd import Function
|
| 3 |
+
from .._ext import gaterecurrent2dnoind as gaterecurrent2d
|
| 4 |
+
|
| 5 |
+
class GateRecurrent2dnoindFunction(Function):
|
| 6 |
+
def __init__(self, horizontal_, reverse_):
|
| 7 |
+
self.horizontal = horizontal_
|
| 8 |
+
self.reverse = reverse_
|
| 9 |
+
|
| 10 |
+
def forward(self, X, G1, G2, G3):
|
| 11 |
+
num, channels, height, width = X.size()
|
| 12 |
+
output = torch.zeros(num, channels, height, width)
|
| 13 |
+
|
| 14 |
+
if not X.is_cuda:
|
| 15 |
+
print("cpu version is not ready at this time")
|
| 16 |
+
return 0
|
| 17 |
+
else:
|
| 18 |
+
output = output.cuda()
|
| 19 |
+
gaterecurrent2d.gaterecurrent2dnoind_forward_cuda(self.horizontal,self.reverse, X, G1, G2, G3, output)
|
| 20 |
+
|
| 21 |
+
self.X = X
|
| 22 |
+
self.G1 = G1
|
| 23 |
+
self.G2 = G2
|
| 24 |
+
self.G3 = G3
|
| 25 |
+
self.output = output
|
| 26 |
+
self.hiddensize = X.size()
|
| 27 |
+
return output
|
| 28 |
+
|
| 29 |
+
def backward(self, grad_output):
|
| 30 |
+
assert(self.hiddensize is not None and grad_output.is_cuda)
|
| 31 |
+
num, channels, height, width = self.hiddensize
|
| 32 |
+
|
| 33 |
+
grad_X = torch.zeros(num, channels, height, width).cuda()
|
| 34 |
+
grad_G1 = torch.zeros(num, channels, height, width).cuda()
|
| 35 |
+
grad_G2 = torch.zeros(num, channels, height, width).cuda()
|
| 36 |
+
grad_G3 = torch.zeros(num, channels, height, width).cuda()
|
| 37 |
+
|
| 38 |
+
gaterecurrent2d.gaterecurrent2dnoind_backward_cuda(self.horizontal, self.reverse, self.output, grad_output, self.X, self.G1, self.G2, self.G3, grad_X, grad_G1, grad_G2, grad_G3)
|
| 39 |
+
|
| 40 |
+
del self.hiddensize
|
| 41 |
+
del self.G1
|
| 42 |
+
del self.G2
|
| 43 |
+
del self.G3
|
| 44 |
+
del self.output
|
| 45 |
+
del self.X
|
| 46 |
+
|
| 47 |
+
return grad_X, grad_G1, grad_G2, grad_G3
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/left_right_demo.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
An example of left->right propagation
|
| 3 |
+
|
| 4 |
+
Other direction settings:
|
| 5 |
+
left->right: Propagator = GateRecurrent2dnoind(True,False)
|
| 6 |
+
right->left: Propagator = GateRecurrent2dnoind(True,True)
|
| 7 |
+
top->bottom: Propagator = GateRecurrent2dnoind(False,False)
|
| 8 |
+
bottom->top: Propagator = GateRecurrent2dnoind(False,True)
|
| 9 |
+
|
| 10 |
+
X: any signal/feature map to be filtered
|
| 11 |
+
G1~G3: three coefficient maps (e.g., left-top, left-center, left-bottom)
|
| 12 |
+
|
| 13 |
+
Note:
|
| 14 |
+
1. G1~G3 constitute the affinity, they can be a bounch of output maps coming from any CNN, with the input of any useful known information (e.g., RGB images)
|
| 15 |
+
2. for any pixel i, |G1(i)| + |G2(i)| + |G3(i)| <= 1 is a sufficent condition for model stability (see paper)
|
| 16 |
+
"""
|
| 17 |
+
import torch
|
| 18 |
+
from torch.autograd import Variable
|
| 19 |
+
from pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind
|
| 20 |
+
|
| 21 |
+
Propagator = GateRecurrent2dnoind(True,False)
|
| 22 |
+
|
| 23 |
+
X = Variable(torch.randn(1,3,10,10))
|
| 24 |
+
G1 = Variable(torch.randn(1,3,10,10))
|
| 25 |
+
G2 = Variable(torch.randn(1,3,10,10))
|
| 26 |
+
G3 = Variable(torch.randn(1,3,10,10))
|
| 27 |
+
|
| 28 |
+
sum_abs = G1.abs() + G2.abs() + G3.abs()
|
| 29 |
+
mask_need_norm = sum_abs.ge(1)
|
| 30 |
+
mask_need_norm = mask_need_norm.float()
|
| 31 |
+
G1_norm = torch.div(G1, sum_abs)
|
| 32 |
+
G2_norm = torch.div(G2, sum_abs)
|
| 33 |
+
G3_norm = torch.div(G3, sum_abs)
|
| 34 |
+
|
| 35 |
+
G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm
|
| 36 |
+
G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm
|
| 37 |
+
G3 = torch.add(-mask_need_norm, 1) * G3 + mask_need_norm * G3_norm
|
| 38 |
+
|
| 39 |
+
X = X.cuda()
|
| 40 |
+
G1 = G1.cuda()
|
| 41 |
+
G2 = G2.cuda()
|
| 42 |
+
G3 = G3.cuda()
|
| 43 |
+
|
| 44 |
+
output = Propagator.forward(X,G1,G2,G3)
|
| 45 |
+
print(X)
|
| 46 |
+
print(output)
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/make.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
CUDA_PATH=/usr/local/cuda/
|
| 4 |
+
|
| 5 |
+
cd src/cuda/
|
| 6 |
+
echo "Compiling gaterecurrent2dnoind layer kernels by nvcc..."
|
| 7 |
+
nvcc -c -o gaterecurrent2dnoind_kernel.cu.o gaterecurrent2dnoind_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
|
| 8 |
+
cd ../../
|
| 9 |
+
python build.py
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/gaterecurrent2dnoind.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from ..functions.gaterecurrent2dnoind import GateRecurrent2dnoindFunction
|
| 3 |
+
|
| 4 |
+
class GateRecurrent2dnoind(nn.Module):
|
| 5 |
+
"""docstring for ."""
|
| 6 |
+
def __init__(self, horizontal_, reverse_):
|
| 7 |
+
super(GateRecurrent2dnoind, self).__init__()
|
| 8 |
+
self.horizontal = horizontal_
|
| 9 |
+
self.reverse = reverse_
|
| 10 |
+
|
| 11 |
+
def forward(self, X, G1, G2, G3):
|
| 12 |
+
return GateRecurrent2dnoindFunction(self.horizontal, self.reverse)(X, G1, G2, G3)
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/.DS_Store
ADDED
|
Binary file (8.2 kB). View file
|
|
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu
ADDED
|
@@ -0,0 +1,697 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifdef __cplusplus
|
| 2 |
+
extern "C" {
|
| 3 |
+
#endif
|
| 4 |
+
|
| 5 |
+
#include <stdio.h>
|
| 6 |
+
#include <math.h>
|
| 7 |
+
#include <float.h>
|
| 8 |
+
#include "gaterecurrent2dnoind_kernel.h"
|
| 9 |
+
|
| 10 |
+
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
| 11 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
|
| 12 |
+
i += blockDim.x * gridDim.x)
|
| 13 |
+
|
| 14 |
+
__device__ void get_gate_idx_sf(int h1, int w1, int h2, int w2, int * out, int horizontal, int reverse)
|
| 15 |
+
{
|
| 16 |
+
if(horizontal && ! reverse) // left -> right
|
| 17 |
+
{
|
| 18 |
+
if(w1>w2)
|
| 19 |
+
{
|
| 20 |
+
out[0]=h1;
|
| 21 |
+
out[1]=w1;
|
| 22 |
+
}
|
| 23 |
+
else
|
| 24 |
+
{
|
| 25 |
+
out[0]=h2;
|
| 26 |
+
out[1]=w2;
|
| 27 |
+
}
|
| 28 |
+
}
|
| 29 |
+
if(horizontal && reverse) // right -> left
|
| 30 |
+
{
|
| 31 |
+
if(w1<w2)
|
| 32 |
+
{
|
| 33 |
+
out[0]=h1;
|
| 34 |
+
out[1]=w1;
|
| 35 |
+
}
|
| 36 |
+
else
|
| 37 |
+
{
|
| 38 |
+
out[0]=h2;
|
| 39 |
+
out[1]=w2;
|
| 40 |
+
}
|
| 41 |
+
}
|
| 42 |
+
if(!horizontal && !reverse) // top -> bottom
|
| 43 |
+
{
|
| 44 |
+
if(h1>h2)
|
| 45 |
+
{
|
| 46 |
+
out[0]=h1;
|
| 47 |
+
out[1]=w1;
|
| 48 |
+
}
|
| 49 |
+
else
|
| 50 |
+
{
|
| 51 |
+
out[0]=h2;
|
| 52 |
+
out[1]=w2;
|
| 53 |
+
}
|
| 54 |
+
}
|
| 55 |
+
if(!horizontal && reverse) // bottom -> top
|
| 56 |
+
{
|
| 57 |
+
if(h1<h2)
|
| 58 |
+
{
|
| 59 |
+
out[0]=h1;
|
| 60 |
+
out[1]=w1;
|
| 61 |
+
}
|
| 62 |
+
else
|
| 63 |
+
{
|
| 64 |
+
out[0]=h2;
|
| 65 |
+
out[1]=w2;
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
__device__ float get_data_sf(float * data, int num, int channels,int height, int width,int n,int c,int h,int w)
|
| 72 |
+
{
|
| 73 |
+
if(h<0 || h >=height)
|
| 74 |
+
return 0;
|
| 75 |
+
if(w<0 || w >= width)
|
| 76 |
+
return 0;
|
| 77 |
+
|
| 78 |
+
return data[n*channels*height*width + c * height*width + h * width + w];
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
__device__ void set_data_sf(float * data, int num, int channels,int height, int width,int n,int c,int h,int w, float v)
|
| 82 |
+
{
|
| 83 |
+
if(h<0 || h >=height)
|
| 84 |
+
return ;
|
| 85 |
+
if(w<0 || w >= width)
|
| 86 |
+
return ;
|
| 87 |
+
|
| 88 |
+
data[n*channels*height*width + c * height*width + h * width + w]=v;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
__device__ float get_gate_sf(float * data, int num, int channels,int height, int width,int n,int c,int h1,int w1,int h2,int w2,int horizontal,int reverse)
|
| 92 |
+
{
|
| 93 |
+
if(h1<0 || h1 >=height)
|
| 94 |
+
return 0;
|
| 95 |
+
if(w1<0 || w1 >= width)
|
| 96 |
+
return 0;
|
| 97 |
+
if(h2<0 || h2 >=height)
|
| 98 |
+
return 0;
|
| 99 |
+
if(w2<0 || w2 >= width)
|
| 100 |
+
return 0;
|
| 101 |
+
int idx[2];
|
| 102 |
+
|
| 103 |
+
get_gate_idx_sf(h1,w1,h2,w2, idx,horizontal, reverse);
|
| 104 |
+
|
| 105 |
+
int h = idx[0];
|
| 106 |
+
int w = idx[1];
|
| 107 |
+
|
| 108 |
+
return data[n*channels*height*width + c * height*width + h * width + w];
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
__device__ void set_gate_sf(float * data, int num, int channels,int height, int width,int n,int c,int h1,int w1,int h2,int w2,int horizontal,int reverse, float v)
|
| 112 |
+
{
|
| 113 |
+
if(h1<0 || h1 >=height)
|
| 114 |
+
return ;
|
| 115 |
+
if(w1<0 || w1 >= width)
|
| 116 |
+
return ;
|
| 117 |
+
if(h2<0 || h2 >=height)
|
| 118 |
+
return ;
|
| 119 |
+
if(w2<0 || w2 >= width)
|
| 120 |
+
return ;
|
| 121 |
+
int idx[2];
|
| 122 |
+
|
| 123 |
+
get_gate_idx_sf(h1,w1,h2,w2, idx,horizontal, reverse);
|
| 124 |
+
|
| 125 |
+
int h = idx[0];
|
| 126 |
+
int w = idx[1];
|
| 127 |
+
|
| 128 |
+
data[n*channels*height*width + c * height*width + h * width + w]=v;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
// we do not use set_gate_add_sf(...) in the caffe implimentation
|
| 132 |
+
// avoid using atomicAdd
|
| 133 |
+
|
| 134 |
+
__global__ void forward_one_col_left_right( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, int horizontal, int reverse) {
|
| 135 |
+
CUDA_1D_KERNEL_LOOP(index, count) {
|
| 136 |
+
|
| 137 |
+
int hc_count = height * channels;
|
| 138 |
+
|
| 139 |
+
int n,c,h,w;
|
| 140 |
+
int temp=index;
|
| 141 |
+
w = T;
|
| 142 |
+
n = temp / hc_count;
|
| 143 |
+
temp = temp % hc_count;
|
| 144 |
+
c = temp / height;
|
| 145 |
+
temp = temp % height;
|
| 146 |
+
h = temp;
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
|
| 150 |
+
|
| 151 |
+
float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse);
|
| 152 |
+
float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w-1);
|
| 153 |
+
float h1_minus1 = g_data_1 * h_minus1_data_1;
|
| 154 |
+
|
| 155 |
+
float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w-1,horizontal,reverse);
|
| 156 |
+
float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h,w-1);
|
| 157 |
+
float h2_minus1 = g_data_2 * h_minus1_data_2;
|
| 158 |
+
|
| 159 |
+
float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse);
|
| 160 |
+
float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w-1);
|
| 161 |
+
float h3_minus1 = g_data_3 * h_minus1_data_3;
|
| 162 |
+
|
| 163 |
+
float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
|
| 164 |
+
float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;
|
| 165 |
+
|
| 166 |
+
float h_data = x_hype + h_hype;
|
| 167 |
+
|
| 168 |
+
set_data_sf(H,num,channels,height,width,n,c,h,w,h_data);
|
| 169 |
+
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
__global__ void forward_one_col_right_left( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) {
|
| 174 |
+
CUDA_1D_KERNEL_LOOP(index, count) {
|
| 175 |
+
|
| 176 |
+
int hc_count = height * channels;
|
| 177 |
+
int n,c,h,w;
|
| 178 |
+
int temp=index;
|
| 179 |
+
w = T;
|
| 180 |
+
n = temp / hc_count;
|
| 181 |
+
temp = temp % hc_count;
|
| 182 |
+
c = temp / height;
|
| 183 |
+
temp = temp % height;
|
| 184 |
+
h = temp;
|
| 185 |
+
|
| 186 |
+
float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
|
| 187 |
+
|
| 188 |
+
float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse);
|
| 189 |
+
float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w+1);
|
| 190 |
+
float h1_minus1 = g_data_1 * h_minus1_data_1;
|
| 191 |
+
|
| 192 |
+
float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w+1,horizontal,reverse);
|
| 193 |
+
float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h,w+1);
|
| 194 |
+
float h2_minus1 = g_data_2 * h_minus1_data_2;
|
| 195 |
+
|
| 196 |
+
float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse);
|
| 197 |
+
float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w+1);
|
| 198 |
+
float h3_minus1 = g_data_3 * h_minus1_data_3;
|
| 199 |
+
|
| 200 |
+
float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
|
| 201 |
+
float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;
|
| 202 |
+
|
| 203 |
+
float h_data = x_hype + h_hype;
|
| 204 |
+
|
| 205 |
+
set_data_sf(H,num,channels,height,width,n,c,h,w,h_data);
|
| 206 |
+
|
| 207 |
+
}
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
__global__ void forward_one_row_top_bottom( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) {
|
| 211 |
+
CUDA_1D_KERNEL_LOOP(index, count) {
|
| 212 |
+
|
| 213 |
+
int wc_count = width * channels;
|
| 214 |
+
|
| 215 |
+
int n,c,h,w;
|
| 216 |
+
int temp=index;
|
| 217 |
+
h = T;
|
| 218 |
+
n = temp / wc_count;
|
| 219 |
+
temp = temp % wc_count;
|
| 220 |
+
c = temp / width;
|
| 221 |
+
temp = temp % width;
|
| 222 |
+
w = temp;
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse);
|
| 229 |
+
float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w-1);
|
| 230 |
+
float h1_minus1 = g_data_1 * h_minus1_data_1;
|
| 231 |
+
|
| 232 |
+
float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h-1,w,horizontal,reverse);
|
| 233 |
+
float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h-1,w);
|
| 234 |
+
float h2_minus1 = g_data_2 * h_minus1_data_2;
|
| 235 |
+
|
| 236 |
+
float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse);
|
| 237 |
+
float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h-1,w+1);
|
| 238 |
+
float h3_minus1 = g_data_3 * h_minus1_data_3;
|
| 239 |
+
|
| 240 |
+
float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
|
| 241 |
+
float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;
|
| 242 |
+
|
| 243 |
+
float h_data = x_hype + h_hype;
|
| 244 |
+
|
| 245 |
+
set_data_sf(H,num,channels,height,width,n,c,h,w,h_data);
|
| 246 |
+
|
| 247 |
+
}
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
__global__ void forward_one_row_bottom_top( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) {
|
| 251 |
+
CUDA_1D_KERNEL_LOOP(index, count) {
|
| 252 |
+
|
| 253 |
+
int wc_count = width * channels;
|
| 254 |
+
|
| 255 |
+
int n,c,h,w;
|
| 256 |
+
int temp=index;
|
| 257 |
+
h = T;
|
| 258 |
+
n = temp / wc_count;
|
| 259 |
+
temp = temp % wc_count;
|
| 260 |
+
c = temp / width;
|
| 261 |
+
temp = temp % width;
|
| 262 |
+
w = temp;
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse);
|
| 269 |
+
float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h+1,w-1);
|
| 270 |
+
float h1_minus1 = g_data_1 * h_minus1_data_1;
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h+1,w,horizontal,reverse);
|
| 274 |
+
float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h+1,w);
|
| 275 |
+
float h2_minus1 = g_data_2 * h_minus1_data_2;
|
| 276 |
+
|
| 277 |
+
float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse);
|
| 278 |
+
float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w+1);
|
| 279 |
+
float h3_minus1 = g_data_3 * h_minus1_data_3;
|
| 280 |
+
|
| 281 |
+
float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
|
| 282 |
+
float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;
|
| 283 |
+
|
| 284 |
+
float h_data = x_hype + h_hype;
|
| 285 |
+
|
| 286 |
+
set_data_sf(H,num,channels,height,width,n,c,h,w,h_data);
|
| 287 |
+
|
| 288 |
+
}
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
__global__ void backward_one_col_left_right( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, float * X_diff, float * G1_diff,float* G2_diff,float * G3_diff, float * Hdiff,int horizontal,int reverse) {
|
| 293 |
+
CUDA_1D_KERNEL_LOOP(index, count) {
|
| 294 |
+
|
| 295 |
+
int hc_count = height * channels;
|
| 296 |
+
|
| 297 |
+
int n,c,h,w;
|
| 298 |
+
int temp=index;
|
| 299 |
+
|
| 300 |
+
w = T;
|
| 301 |
+
n = temp / hc_count;
|
| 302 |
+
temp = temp % hc_count;
|
| 303 |
+
c = temp / height;
|
| 304 |
+
temp = temp % height;
|
| 305 |
+
h = temp;
|
| 306 |
+
|
| 307 |
+
float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
|
| 308 |
+
|
| 309 |
+
//h(t)_diff = top(t)_diff
|
| 310 |
+
float h_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h,w);
|
| 311 |
+
|
| 312 |
+
//h(t)_diff += h(t+1)_diff * g(t+1) if t<T
|
| 313 |
+
float add1_h3_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h-1,w+1);
|
| 314 |
+
float add1_g3_data = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse);
|
| 315 |
+
|
| 316 |
+
float add1_h2_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h,w+1);
|
| 317 |
+
float add1_g2_data = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w+1,horizontal,reverse);
|
| 318 |
+
|
| 319 |
+
float add1_h1_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h+1,w+1);
|
| 320 |
+
float add1_g1_data = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse);
|
| 321 |
+
|
| 322 |
+
h_diff = h_diff + add1_h3_diff * add1_g3_data + add1_h2_diff * add1_g2_data + add1_h1_diff * add1_g1_data;
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
//Hdiff[n*channels*height*width + c*height*width + h*width + w]=0;
|
| 326 |
+
set_data_sf(Hdiff,num,channels,height,width,n,c,h,w,h_diff);
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
//x(t)_diff=(1-sum(g_date))*h(t)_diff
|
| 330 |
+
float g1_data = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse);
|
| 331 |
+
float g2_data = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w-1,horizontal,reverse);
|
| 332 |
+
float g3_data = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse);
|
| 333 |
+
|
| 334 |
+
float x_diff = (1- g1_data -g2_data -g3_data) * h_diff;
|
| 335 |
+
set_data_sf(X_diff,num,channels,height,width,n,c,h,w,x_diff);
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
// g_diff = h_diff * (h_data(t-1) - x_data)
|
| 339 |
+
float h1_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h-1,w-1);
|
| 340 |
+
float g1_diff = h_diff * (h1_minus1_data - x_data);
|
| 341 |
+
set_gate_sf(G1_diff,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse,g1_diff);
|
| 342 |
+
|
| 343 |
+
float h2_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h,w-1);
|
| 344 |
+
float g2_diff = h_diff * (h2_minus1_data - x_data);
|
| 345 |
+
set_gate_sf(G2_diff,num,channels,height,width,n,c,h,w,h,w-1,horizontal,reverse,g2_diff);
|
| 346 |
+
|
| 347 |
+
float h3_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h+1,w-1);
|
| 348 |
+
float g3_diff = h_diff * (h3_minus1_data - x_data);
|
| 349 |
+
set_gate_sf(G3_diff,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse,g3_diff);
|
| 350 |
+
|
| 351 |
+
}
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
__global__ void backward_one_col_right_left( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, float * X_diff, float * G1_diff,float* G2_diff,float * G3_diff, float * Hdiff,int horizontal,int reverse) {
|
| 356 |
+
CUDA_1D_KERNEL_LOOP(index, count) {
|
| 357 |
+
|
| 358 |
+
int hc_count = height * channels;
|
| 359 |
+
|
| 360 |
+
int n,c,h,w;
|
| 361 |
+
int temp=index;
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
w = T;
|
| 365 |
+
n = temp / hc_count;
|
| 366 |
+
temp = temp % hc_count;
|
| 367 |
+
c = temp / height;
|
| 368 |
+
temp = temp % height;
|
| 369 |
+
h = temp;
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
|
| 373 |
+
|
| 374 |
+
//h(t)_diff = top(t)_diff
|
| 375 |
+
float h_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h,w);
|
| 376 |
+
|
| 377 |
+
///h(t)_diff += h(t+1)_diff * g(t+1) if t<T
|
| 378 |
+
float add1_h3_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h-1,w-1);
|
| 379 |
+
float add1_g3_data = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse);
|
| 380 |
+
|
| 381 |
+
float add1_h2_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h,w-1);
|
| 382 |
+
float add1_g2_data = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w-1,horizontal,reverse);
|
| 383 |
+
|
| 384 |
+
float add1_h1_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h+1,w-1);
|
| 385 |
+
float add1_g1_data = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse);
|
| 386 |
+
|
| 387 |
+
h_diff = h_diff + add1_h3_diff * add1_g3_data + add1_h2_diff * add1_g2_data + add1_h1_diff * add1_g1_data;
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
set_data_sf(Hdiff,num,channels,height,width,n,c,h,w,h_diff);
|
| 391 |
+
|
| 392 |
+
float g1_data = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse);
|
| 393 |
+
float g2_data = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w+1,horizontal,reverse);
|
| 394 |
+
float g3_data = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse);
|
| 395 |
+
float x_diff = (1- g1_data -g2_data -g3_data) * h_diff;
|
| 396 |
+
set_data_sf(X_diff,num,channels,height,width,n,c,h,w,x_diff);
|
| 397 |
+
|
| 398 |
+
// g_diff = h_diff * (h_data(t-1) - x_data)
|
| 399 |
+
float h1_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h-1,w+1);
|
| 400 |
+
float g1_diff = h_diff * (h1_minus1_data - x_data);
|
| 401 |
+
set_gate_sf(G1_diff,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse,g1_diff);
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
float h2_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h,w+1);
|
| 405 |
+
float g2_diff = h_diff * (h2_minus1_data - x_data);
|
| 406 |
+
set_gate_sf(G2_diff,num,channels,height,width,n,c,h,w,h,w+1,horizontal,reverse,g2_diff);
|
| 407 |
+
|
| 408 |
+
float h3_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h+1,w+1);
|
| 409 |
+
float g3_diff = h_diff * (h3_minus1_data - x_data);
|
| 410 |
+
set_gate_sf(G3_diff,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse,g3_diff);
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
}
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
__global__ void backward_one_row_top_bottom( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, float * X_diff, float * G1_diff,float* G2_diff,float * G3_diff, float * Hdiff,int horizontal,int reverse) {
|
| 417 |
+
CUDA_1D_KERNEL_LOOP(index, count) {
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
int wc_count = width * channels;
|
| 421 |
+
|
| 422 |
+
int n,c,h,w;
|
| 423 |
+
int temp=index;
|
| 424 |
+
h = T;
|
| 425 |
+
n = temp / wc_count;
|
| 426 |
+
temp = temp % wc_count;
|
| 427 |
+
c = temp / width;
|
| 428 |
+
temp = temp % width;
|
| 429 |
+
w = temp;
|
| 430 |
+
|
| 431 |
+
float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
|
| 432 |
+
|
| 433 |
+
float h_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h,w);
|
| 434 |
+
|
| 435 |
+
//h(t)_diff += h(t+1)_diff * g(t+1) if t<T
|
| 436 |
+
float add1_h3_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h+1,w-1);
|
| 437 |
+
float add1_g3_data = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse);
|
| 438 |
+
|
| 439 |
+
float add1_h2_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h+1,w);
|
| 440 |
+
float add1_g2_data = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h+1,w,horizontal,reverse);
|
| 441 |
+
|
| 442 |
+
float add1_h1_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h+1,w+1);
|
| 443 |
+
float add1_g1_data = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse);
|
| 444 |
+
|
| 445 |
+
h_diff = h_diff + add1_h3_diff * add1_g3_data + add1_h2_diff * add1_g2_data + add1_h1_diff * add1_g1_data;
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
set_data_sf(Hdiff,num,channels,height,width,n,c,h,w,h_diff);
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
//x(t)_diff=(1-g(t))*h(t)_diff
|
| 452 |
+
float g1_data = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse);
|
| 453 |
+
float g2_data = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h-1,w,horizontal,reverse);
|
| 454 |
+
float g3_data = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse);
|
| 455 |
+
float x_diff = (1- g1_data -g2_data -g3_data) * h_diff;
|
| 456 |
+
set_data_sf(X_diff,num,channels,height,width,n,c,h,w,x_diff);
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
// g_diff = h_diff * (h_data(t-1) - x_data)
|
| 461 |
+
float h1_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h-1,w-1);
|
| 462 |
+
float g1_diff = h_diff * (h1_minus1_data - x_data);
|
| 463 |
+
set_gate_sf(G1_diff,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse,g1_diff);
|
| 464 |
+
|
| 465 |
+
float h2_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h-1,w);
|
| 466 |
+
float g2_diff = h_diff * (h2_minus1_data - x_data);
|
| 467 |
+
set_gate_sf(G2_diff,num,channels,height,width,n,c,h,w,h-1,w,horizontal,reverse,g2_diff);
|
| 468 |
+
|
| 469 |
+
float h3_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h-1,w+1);
|
| 470 |
+
float g3_diff = h_diff * (h3_minus1_data - x_data);
|
| 471 |
+
set_gate_sf(G3_diff,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse,g3_diff);
|
| 472 |
+
|
| 473 |
+
}
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
__global__ void backward_one_row_bottom_top( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, float * X_diff, float * G1_diff,float* G2_diff,float * G3_diff, float * Hdiff,int horizontal,int reverse) {
|
| 477 |
+
CUDA_1D_KERNEL_LOOP(index, count) {
|
| 478 |
+
|
| 479 |
+
int wc_count = width * channels;
|
| 480 |
+
|
| 481 |
+
int n,c,h,w;
|
| 482 |
+
int temp=index;
|
| 483 |
+
h = T;
|
| 484 |
+
n = temp / wc_count;
|
| 485 |
+
temp = temp % wc_count;
|
| 486 |
+
c = temp / width;
|
| 487 |
+
temp = temp % width;
|
| 488 |
+
w = temp;
|
| 489 |
+
|
| 490 |
+
float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
|
| 491 |
+
|
| 492 |
+
//h(t)_diff = top(t)_diff
|
| 493 |
+
float h_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h,w);
|
| 494 |
+
|
| 495 |
+
//h(t)_diff += h(t+1)_diff * g(t+1) if t<T
|
| 496 |
+
float add1_h3_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h-1,w-1);
|
| 497 |
+
float add1_g3_data = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse);
|
| 498 |
+
|
| 499 |
+
float add1_h2_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h-1,w);
|
| 500 |
+
float add1_g2_data = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h-1,w,horizontal,reverse);
|
| 501 |
+
|
| 502 |
+
float add1_h1_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h-1,w+1);
|
| 503 |
+
float add1_g1_data = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse);
|
| 504 |
+
|
| 505 |
+
h_diff = h_diff + add1_h3_diff * add1_g3_data + add1_h2_diff * add1_g2_data + add1_h1_diff * add1_g1_data;
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
set_data_sf(Hdiff,num,channels,height,width,n,c,h,w,h_diff);
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
//x(t)_diff=(1-g(t))*h(t)_diff
|
| 512 |
+
float g1_data = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse);
|
| 513 |
+
float g2_data = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h+1,w,horizontal,reverse);
|
| 514 |
+
float g3_data = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse);
|
| 515 |
+
float x_diff = (1- g1_data -g2_data -g3_data) * h_diff;
|
| 516 |
+
set_data_sf(X_diff,num,channels,height,width,n,c,h,w,x_diff);
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
// g_diff = h_diff * (h_data(t-1) - x_data)
|
| 520 |
+
float h1_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h+1,w-1);
|
| 521 |
+
float g1_diff = h_diff * (h1_minus1_data - x_data);
|
| 522 |
+
set_gate_sf(G1_diff,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse,g1_diff);
|
| 523 |
+
|
| 524 |
+
//float g2_diff = h_diff * g2_idx * x_data * -1;
|
| 525 |
+
float h2_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h+1,w);
|
| 526 |
+
float g2_diff = h_diff * (h2_minus1_data - x_data);
|
| 527 |
+
set_gate_sf(G2_diff,num,channels,height,width,n,c,h,w,h+1,w,horizontal,reverse,g2_diff);
|
| 528 |
+
|
| 529 |
+
//float g3_diff = h_diff * g3_idx * x_data * -1;
|
| 530 |
+
float h3_minus1_data = get_data_sf(H,num,channels,height,width,n,c,h+1,w+1);
|
| 531 |
+
float g3_diff = h_diff * (h3_minus1_data - x_data);
|
| 532 |
+
set_gate_sf(G3_diff,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse,g3_diff);
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
}
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
int Forward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
|
| 540 |
+
{
|
| 541 |
+
int count = height_ * channels_ * num_;
|
| 542 |
+
int kThreadsPerBlock = 1024;
|
| 543 |
+
cudaError_t err;
|
| 544 |
+
|
| 545 |
+
for(int t=0; t<width_; t++) {
|
| 546 |
+
forward_one_col_left_right<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
|
| 547 |
+
|
| 548 |
+
err = cudaGetLastError();
|
| 549 |
+
if(cudaSuccess != err)
|
| 550 |
+
{
|
| 551 |
+
fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
|
| 552 |
+
exit( -1 );
|
| 553 |
+
}
|
| 554 |
+
}
|
| 555 |
+
return 1;
|
| 556 |
+
}
|
| 557 |
+
|
| 558 |
+
int Forward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
|
| 559 |
+
{
|
| 560 |
+
int count = height_ * channels_ * num_;
|
| 561 |
+
int kThreadsPerBlock = 1024;
|
| 562 |
+
cudaError_t err;
|
| 563 |
+
|
| 564 |
+
for(int t = width_ - 1; t >= 0; t--) {
|
| 565 |
+
forward_one_col_right_left<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
|
| 566 |
+
|
| 567 |
+
err = cudaGetLastError();
|
| 568 |
+
if(cudaSuccess != err)
|
| 569 |
+
{
|
| 570 |
+
fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
|
| 571 |
+
exit( -1 );
|
| 572 |
+
}
|
| 573 |
+
}
|
| 574 |
+
return 1;
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
int Forward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
|
| 578 |
+
{
|
| 579 |
+
int count = width_ * channels_ * num_;
|
| 580 |
+
int kThreadsPerBlock = 1024;
|
| 581 |
+
cudaError_t err;
|
| 582 |
+
|
| 583 |
+
for(int t=0; t< height_; t++) {
|
| 584 |
+
forward_one_row_top_bottom<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
|
| 585 |
+
|
| 586 |
+
err = cudaGetLastError();
|
| 587 |
+
if(cudaSuccess != err)
|
| 588 |
+
{
|
| 589 |
+
fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
|
| 590 |
+
exit( -1 );
|
| 591 |
+
}
|
| 592 |
+
}
|
| 593 |
+
return 1;
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
int Forward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
|
| 597 |
+
{
|
| 598 |
+
int count = width_ * channels_ * num_;
|
| 599 |
+
int kThreadsPerBlock = 1024;
|
| 600 |
+
cudaError_t err;
|
| 601 |
+
|
| 602 |
+
for(int t = height_-1; t >= 0; t--) {
|
| 603 |
+
forward_one_row_bottom_top<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
|
| 604 |
+
|
| 605 |
+
err = cudaGetLastError();
|
| 606 |
+
if(cudaSuccess != err)
|
| 607 |
+
{
|
| 608 |
+
fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
|
| 609 |
+
exit( -1 );
|
| 610 |
+
}
|
| 611 |
+
}
|
| 612 |
+
return 1;
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
int Backward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
|
| 616 |
+
{
|
| 617 |
+
int count = height_ * channels_ * num_;
|
| 618 |
+
int kThreadsPerBlock = 1024;
|
| 619 |
+
cudaError_t err;
|
| 620 |
+
|
| 621 |
+
for(int t = width_ -1; t>=0; t--)
|
| 622 |
+
{
|
| 623 |
+
backward_one_col_left_right<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
|
| 624 |
+
|
| 625 |
+
err = cudaGetLastError();
|
| 626 |
+
if(cudaSuccess != err)
|
| 627 |
+
{
|
| 628 |
+
fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
|
| 629 |
+
exit( -1 );
|
| 630 |
+
}
|
| 631 |
+
}
|
| 632 |
+
return 1;
|
| 633 |
+
}
|
| 634 |
+
|
| 635 |
+
int Backward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
|
| 636 |
+
{
|
| 637 |
+
int count = height_ * channels_ * num_;
|
| 638 |
+
int kThreadsPerBlock = 1024;
|
| 639 |
+
cudaError_t err;
|
| 640 |
+
|
| 641 |
+
for(int t = 0; t<width_; t++)
|
| 642 |
+
{
|
| 643 |
+
backward_one_col_right_left<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
|
| 644 |
+
|
| 645 |
+
err = cudaGetLastError();
|
| 646 |
+
if(cudaSuccess != err)
|
| 647 |
+
{
|
| 648 |
+
fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
|
| 649 |
+
exit( -1 );
|
| 650 |
+
}
|
| 651 |
+
}
|
| 652 |
+
return 1;
|
| 653 |
+
}
|
| 654 |
+
|
| 655 |
+
int Backward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
|
| 656 |
+
{
|
| 657 |
+
int count = width_ * channels_ * num_;
|
| 658 |
+
int kThreadsPerBlock = 1024;
|
| 659 |
+
cudaError_t err;
|
| 660 |
+
|
| 661 |
+
for(int t = height_-1; t>=0; t--)
|
| 662 |
+
{
|
| 663 |
+
backward_one_row_top_bottom<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
|
| 664 |
+
|
| 665 |
+
err = cudaGetLastError();
|
| 666 |
+
if(cudaSuccess != err)
|
| 667 |
+
{
|
| 668 |
+
fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
|
| 669 |
+
exit( -1 );
|
| 670 |
+
}
|
| 671 |
+
}
|
| 672 |
+
return 1;
|
| 673 |
+
}
|
| 674 |
+
|
| 675 |
+
int Backward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
|
| 676 |
+
{
|
| 677 |
+
int count = width_ * channels_ * num_;
|
| 678 |
+
int kThreadsPerBlock = 1024;
|
| 679 |
+
cudaError_t err;
|
| 680 |
+
|
| 681 |
+
for(int t = 0; t<height_; t++)
|
| 682 |
+
{
|
| 683 |
+
backward_one_row_bottom_top<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
|
| 684 |
+
|
| 685 |
+
err = cudaGetLastError();
|
| 686 |
+
if(cudaSuccess != err)
|
| 687 |
+
{
|
| 688 |
+
fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
|
| 689 |
+
exit( -1 );
|
| 690 |
+
}
|
| 691 |
+
}
|
| 692 |
+
return 1;
|
| 693 |
+
}
|
| 694 |
+
|
| 695 |
+
#ifdef __cplusplus
|
| 696 |
+
}
|
| 697 |
+
#endif
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu.o
ADDED
|
Binary file (98.8 kB). View file
|
|
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.h
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifndef _GATERECURRENT2DNOIND_KERNEL
|
| 2 |
+
#define _GATERECURRENT2DNOIND_KERNEL
|
| 3 |
+
|
| 4 |
+
#ifdef __cplusplus
|
| 5 |
+
extern "C" {
|
| 6 |
+
#endif
|
| 7 |
+
|
| 8 |
+
int Forward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
|
| 9 |
+
|
| 10 |
+
int Forward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
|
| 11 |
+
|
| 12 |
+
int Forward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
|
| 13 |
+
|
| 14 |
+
int Forward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
|
| 15 |
+
|
| 16 |
+
int Backward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
|
| 17 |
+
|
| 18 |
+
int Backward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
|
| 19 |
+
|
| 20 |
+
int Backward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
|
| 21 |
+
|
| 22 |
+
int Backward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
|
| 23 |
+
|
| 24 |
+
#ifdef __cplusplus
|
| 25 |
+
}
|
| 26 |
+
#endif
|
| 27 |
+
|
| 28 |
+
#endif
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.c
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// gaterecurrent2dnoind_cuda.c
|
| 2 |
+
#include <THC/THC.h>
|
| 3 |
+
#include <math.h>
|
| 4 |
+
#include "gaterecurrent2dnoind_cuda.h"
|
| 5 |
+
#include "cuda/gaterecurrent2dnoind_kernel.h"
|
| 6 |
+
|
| 7 |
+
// typedef bool boolean;
|
| 8 |
+
|
| 9 |
+
// this symbol will be resolved automatically from PyTorch libs
|
| 10 |
+
extern THCState *state;
|
| 11 |
+
|
| 12 |
+
int gaterecurrent2dnoind_forward_cuda(int horizontal_, int reverse_, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * output)
|
| 13 |
+
{
|
| 14 |
+
// Grab the input tensor to flat
|
| 15 |
+
float * X_data = THCudaTensor_data(state, X);
|
| 16 |
+
float * G1_data = THCudaTensor_data(state, G1);
|
| 17 |
+
float * G2_data = THCudaTensor_data(state, G2);
|
| 18 |
+
float * G3_data = THCudaTensor_data(state, G3);
|
| 19 |
+
float * H_data = THCudaTensor_data(state, output);
|
| 20 |
+
|
| 21 |
+
// dimensions
|
| 22 |
+
int num_ = THCudaTensor_size(state, X, 0);
|
| 23 |
+
int channels_ = THCudaTensor_size(state, X, 1);
|
| 24 |
+
int height_ = THCudaTensor_size(state, X, 2);
|
| 25 |
+
int width_ = THCudaTensor_size(state, X, 3);
|
| 26 |
+
|
| 27 |
+
cudaStream_t stream = THCState_getCurrentStream(state);
|
| 28 |
+
|
| 29 |
+
if(horizontal_ && !reverse_) // left to right
|
| 30 |
+
{
|
| 31 |
+
//const int count = height_ * channels_ * num_;
|
| 32 |
+
Forward_left_right(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
|
| 33 |
+
}
|
| 34 |
+
else if(horizontal_ && reverse_) // right to left
|
| 35 |
+
{
|
| 36 |
+
Forward_right_left(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
|
| 37 |
+
}
|
| 38 |
+
else if(!horizontal_ && !reverse_) // top to bottom
|
| 39 |
+
{
|
| 40 |
+
Forward_top_bottom(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
|
| 41 |
+
}
|
| 42 |
+
else
|
| 43 |
+
{
|
| 44 |
+
Forward_bottom_top(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
return 1;
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
int gaterecurrent2dnoind_backward_cuda(int horizontal_, int reverse_, THCudaTensor* top, THCudaTensor* top_grad, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * X_grad, THCudaTensor * G1_grad, THCudaTensor * G2_grad, THCudaTensor * G3_grad)
|
| 51 |
+
{
|
| 52 |
+
//Grab the input tensor to flat
|
| 53 |
+
float * X_data = THCudaTensor_data(state, X);
|
| 54 |
+
float * G1_data = THCudaTensor_data(state, G1);
|
| 55 |
+
float * G2_data = THCudaTensor_data(state, G2);
|
| 56 |
+
float * G3_data = THCudaTensor_data(state, G3);
|
| 57 |
+
float * H_data = THCudaTensor_data(state, top);
|
| 58 |
+
|
| 59 |
+
float * H_diff = THCudaTensor_data(state, top_grad);
|
| 60 |
+
|
| 61 |
+
float * X_diff = THCudaTensor_data(state, X_grad);
|
| 62 |
+
float * G1_diff = THCudaTensor_data(state, G1_grad);
|
| 63 |
+
float * G2_diff = THCudaTensor_data(state, G2_grad);
|
| 64 |
+
float * G3_diff = THCudaTensor_data(state, G3_grad);
|
| 65 |
+
|
| 66 |
+
// dimensions
|
| 67 |
+
int num_ = THCudaTensor_size(state, X, 0);
|
| 68 |
+
int channels_ = THCudaTensor_size(state, X, 1);
|
| 69 |
+
int height_ = THCudaTensor_size(state, X, 2);
|
| 70 |
+
int width_ = THCudaTensor_size(state, X, 3);
|
| 71 |
+
|
| 72 |
+
cudaStream_t stream = THCState_getCurrentStream(state);
|
| 73 |
+
|
| 74 |
+
if(horizontal_ && ! reverse_) //left to right
|
| 75 |
+
{
|
| 76 |
+
Backward_left_right(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
|
| 77 |
+
}
|
| 78 |
+
else if(horizontal_ && reverse_) //right to left
|
| 79 |
+
{
|
| 80 |
+
Backward_right_left(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
|
| 81 |
+
}
|
| 82 |
+
else if(!horizontal_ && !reverse_) //top to bottom
|
| 83 |
+
{
|
| 84 |
+
Backward_top_bottom(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
|
| 85 |
+
}
|
| 86 |
+
else {
|
| 87 |
+
Backward_bottom_top(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
return 1;
|
| 91 |
+
}
|
graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.h
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
// #include <stdbool.h>
|
| 3 |
+
// gaterecurrent2dnoind_cuda.h
|
| 4 |
+
int gaterecurrent2dnoind_forward_cuda(int horizontal_, int reverse_, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * output);
|
| 5 |
+
|
| 6 |
+
int gaterecurrent2dnoind_backward_cuda(int horizontal_, int reverse_, THCudaTensor* top, THCudaTensor* top_grad, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * X_diff, THCudaTensor * G1_diff, THCudaTensor * G2_diff, THCudaTensor * G3_diff);
|
graph_networks/LinearStyleTransfer/libs/smooth_filter.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Code cc from https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py
|
| 3 |
+
"""
|
| 4 |
+
src = '''
|
| 5 |
+
#include "/usr/local/cuda/include/math_functions.h"
|
| 6 |
+
#define TB 256
|
| 7 |
+
#define EPS 1e-7
|
| 8 |
+
|
| 9 |
+
__device__ bool InverseMat4x4(double m_in[4][4], double inv_out[4][4]) {
|
| 10 |
+
double m[16], inv[16];
|
| 11 |
+
for (int i = 0; i < 4; i++) {
|
| 12 |
+
for (int j = 0; j < 4; j++) {
|
| 13 |
+
m[i * 4 + j] = m_in[i][j];
|
| 14 |
+
}
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
inv[0] = m[5] * m[10] * m[15] -
|
| 18 |
+
m[5] * m[11] * m[14] -
|
| 19 |
+
m[9] * m[6] * m[15] +
|
| 20 |
+
m[9] * m[7] * m[14] +
|
| 21 |
+
m[13] * m[6] * m[11] -
|
| 22 |
+
m[13] * m[7] * m[10];
|
| 23 |
+
|
| 24 |
+
inv[4] = -m[4] * m[10] * m[15] +
|
| 25 |
+
m[4] * m[11] * m[14] +
|
| 26 |
+
m[8] * m[6] * m[15] -
|
| 27 |
+
m[8] * m[7] * m[14] -
|
| 28 |
+
m[12] * m[6] * m[11] +
|
| 29 |
+
m[12] * m[7] * m[10];
|
| 30 |
+
|
| 31 |
+
inv[8] = m[4] * m[9] * m[15] -
|
| 32 |
+
m[4] * m[11] * m[13] -
|
| 33 |
+
m[8] * m[5] * m[15] +
|
| 34 |
+
m[8] * m[7] * m[13] +
|
| 35 |
+
m[12] * m[5] * m[11] -
|
| 36 |
+
m[12] * m[7] * m[9];
|
| 37 |
+
|
| 38 |
+
inv[12] = -m[4] * m[9] * m[14] +
|
| 39 |
+
m[4] * m[10] * m[13] +
|
| 40 |
+
m[8] * m[5] * m[14] -
|
| 41 |
+
m[8] * m[6] * m[13] -
|
| 42 |
+
m[12] * m[5] * m[10] +
|
| 43 |
+
m[12] * m[6] * m[9];
|
| 44 |
+
|
| 45 |
+
inv[1] = -m[1] * m[10] * m[15] +
|
| 46 |
+
m[1] * m[11] * m[14] +
|
| 47 |
+
m[9] * m[2] * m[15] -
|
| 48 |
+
m[9] * m[3] * m[14] -
|
| 49 |
+
m[13] * m[2] * m[11] +
|
| 50 |
+
m[13] * m[3] * m[10];
|
| 51 |
+
|
| 52 |
+
inv[5] = m[0] * m[10] * m[15] -
|
| 53 |
+
m[0] * m[11] * m[14] -
|
| 54 |
+
m[8] * m[2] * m[15] +
|
| 55 |
+
m[8] * m[3] * m[14] +
|
| 56 |
+
m[12] * m[2] * m[11] -
|
| 57 |
+
m[12] * m[3] * m[10];
|
| 58 |
+
|
| 59 |
+
inv[9] = -m[0] * m[9] * m[15] +
|
| 60 |
+
m[0] * m[11] * m[13] +
|
| 61 |
+
m[8] * m[1] * m[15] -
|
| 62 |
+
m[8] * m[3] * m[13] -
|
| 63 |
+
m[12] * m[1] * m[11] +
|
| 64 |
+
m[12] * m[3] * m[9];
|
| 65 |
+
|
| 66 |
+
inv[13] = m[0] * m[9] * m[14] -
|
| 67 |
+
m[0] * m[10] * m[13] -
|
| 68 |
+
m[8] * m[1] * m[14] +
|
| 69 |
+
m[8] * m[2] * m[13] +
|
| 70 |
+
m[12] * m[1] * m[10] -
|
| 71 |
+
m[12] * m[2] * m[9];
|
| 72 |
+
|
| 73 |
+
inv[2] = m[1] * m[6] * m[15] -
|
| 74 |
+
m[1] * m[7] * m[14] -
|
| 75 |
+
m[5] * m[2] * m[15] +
|
| 76 |
+
m[5] * m[3] * m[14] +
|
| 77 |
+
m[13] * m[2] * m[7] -
|
| 78 |
+
m[13] * m[3] * m[6];
|
| 79 |
+
|
| 80 |
+
inv[6] = -m[0] * m[6] * m[15] +
|
| 81 |
+
m[0] * m[7] * m[14] +
|
| 82 |
+
m[4] * m[2] * m[15] -
|
| 83 |
+
m[4] * m[3] * m[14] -
|
| 84 |
+
m[12] * m[2] * m[7] +
|
| 85 |
+
m[12] * m[3] * m[6];
|
| 86 |
+
|
| 87 |
+
inv[10] = m[0] * m[5] * m[15] -
|
| 88 |
+
m[0] * m[7] * m[13] -
|
| 89 |
+
m[4] * m[1] * m[15] +
|
| 90 |
+
m[4] * m[3] * m[13] +
|
| 91 |
+
m[12] * m[1] * m[7] -
|
| 92 |
+
m[12] * m[3] * m[5];
|
| 93 |
+
|
| 94 |
+
inv[14] = -m[0] * m[5] * m[14] +
|
| 95 |
+
m[0] * m[6] * m[13] +
|
| 96 |
+
m[4] * m[1] * m[14] -
|
| 97 |
+
m[4] * m[2] * m[13] -
|
| 98 |
+
m[12] * m[1] * m[6] +
|
| 99 |
+
m[12] * m[2] * m[5];
|
| 100 |
+
|
| 101 |
+
inv[3] = -m[1] * m[6] * m[11] +
|
| 102 |
+
m[1] * m[7] * m[10] +
|
| 103 |
+
m[5] * m[2] * m[11] -
|
| 104 |
+
m[5] * m[3] * m[10] -
|
| 105 |
+
m[9] * m[2] * m[7] +
|
| 106 |
+
m[9] * m[3] * m[6];
|
| 107 |
+
|
| 108 |
+
inv[7] = m[0] * m[6] * m[11] -
|
| 109 |
+
m[0] * m[7] * m[10] -
|
| 110 |
+
m[4] * m[2] * m[11] +
|
| 111 |
+
m[4] * m[3] * m[10] +
|
| 112 |
+
m[8] * m[2] * m[7] -
|
| 113 |
+
m[8] * m[3] * m[6];
|
| 114 |
+
|
| 115 |
+
inv[11] = -m[0] * m[5] * m[11] +
|
| 116 |
+
m[0] * m[7] * m[9] +
|
| 117 |
+
m[4] * m[1] * m[11] -
|
| 118 |
+
m[4] * m[3] * m[9] -
|
| 119 |
+
m[8] * m[1] * m[7] +
|
| 120 |
+
m[8] * m[3] * m[5];
|
| 121 |
+
|
| 122 |
+
inv[15] = m[0] * m[5] * m[10] -
|
| 123 |
+
m[0] * m[6] * m[9] -
|
| 124 |
+
m[4] * m[1] * m[10] +
|
| 125 |
+
m[4] * m[2] * m[9] +
|
| 126 |
+
m[8] * m[1] * m[6] -
|
| 127 |
+
m[8] * m[2] * m[5];
|
| 128 |
+
|
| 129 |
+
double det = m[0] * inv[0] + m[1] * inv[4] + m[2] * inv[8] + m[3] * inv[12];
|
| 130 |
+
|
| 131 |
+
if (abs(det) < 1e-9) {
|
| 132 |
+
return false;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
det = 1.0 / det;
|
| 137 |
+
|
| 138 |
+
for (int i = 0; i < 4; i++) {
|
| 139 |
+
for (int j = 0; j < 4; j++) {
|
| 140 |
+
inv_out[i][j] = inv[i * 4 + j] * det;
|
| 141 |
+
}
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
return true;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
extern "C"
|
| 148 |
+
__global__ void best_local_affine_kernel(
|
| 149 |
+
float *output, float *input, float *affine_model,
|
| 150 |
+
int h, int w, float epsilon, int kernel_radius
|
| 151 |
+
)
|
| 152 |
+
{
|
| 153 |
+
int size = h * w;
|
| 154 |
+
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
| 155 |
+
|
| 156 |
+
if (id < size) {
|
| 157 |
+
int x = id % w, y = id / w;
|
| 158 |
+
|
| 159 |
+
double Mt_M[4][4] = {}; // 4x4
|
| 160 |
+
double invMt_M[4][4] = {};
|
| 161 |
+
double Mt_S[3][4] = {}; // RGB -> 1x4
|
| 162 |
+
double A[3][4] = {};
|
| 163 |
+
for (int i = 0; i < 4; i++)
|
| 164 |
+
for (int j = 0; j < 4; j++) {
|
| 165 |
+
Mt_M[i][j] = 0, invMt_M[i][j] = 0;
|
| 166 |
+
if (i != 3) {
|
| 167 |
+
Mt_S[i][j] = 0, A[i][j] = 0;
|
| 168 |
+
if (i == j)
|
| 169 |
+
Mt_M[i][j] = 1e-3;
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
for (int dy = -kernel_radius; dy <= kernel_radius; dy++) {
|
| 174 |
+
for (int dx = -kernel_radius; dx <= kernel_radius; dx++) {
|
| 175 |
+
|
| 176 |
+
int xx = x + dx, yy = y + dy;
|
| 177 |
+
int id2 = yy * w + xx;
|
| 178 |
+
|
| 179 |
+
if (0 <= xx && xx < w && 0 <= yy && yy < h) {
|
| 180 |
+
|
| 181 |
+
Mt_M[0][0] += input[id2 + 2*size] * input[id2 + 2*size];
|
| 182 |
+
Mt_M[0][1] += input[id2 + 2*size] * input[id2 + size];
|
| 183 |
+
Mt_M[0][2] += input[id2 + 2*size] * input[id2];
|
| 184 |
+
Mt_M[0][3] += input[id2 + 2*size];
|
| 185 |
+
|
| 186 |
+
Mt_M[1][0] += input[id2 + size] * input[id2 + 2*size];
|
| 187 |
+
Mt_M[1][1] += input[id2 + size] * input[id2 + size];
|
| 188 |
+
Mt_M[1][2] += input[id2 + size] * input[id2];
|
| 189 |
+
Mt_M[1][3] += input[id2 + size];
|
| 190 |
+
|
| 191 |
+
Mt_M[2][0] += input[id2] * input[id2 + 2*size];
|
| 192 |
+
Mt_M[2][1] += input[id2] * input[id2 + size];
|
| 193 |
+
Mt_M[2][2] += input[id2] * input[id2];
|
| 194 |
+
Mt_M[2][3] += input[id2];
|
| 195 |
+
|
| 196 |
+
Mt_M[3][0] += input[id2 + 2*size];
|
| 197 |
+
Mt_M[3][1] += input[id2 + size];
|
| 198 |
+
Mt_M[3][2] += input[id2];
|
| 199 |
+
Mt_M[3][3] += 1;
|
| 200 |
+
|
| 201 |
+
Mt_S[0][0] += input[id2 + 2*size] * output[id2 + 2*size];
|
| 202 |
+
Mt_S[0][1] += input[id2 + size] * output[id2 + 2*size];
|
| 203 |
+
Mt_S[0][2] += input[id2] * output[id2 + 2*size];
|
| 204 |
+
Mt_S[0][3] += output[id2 + 2*size];
|
| 205 |
+
|
| 206 |
+
Mt_S[1][0] += input[id2 + 2*size] * output[id2 + size];
|
| 207 |
+
Mt_S[1][1] += input[id2 + size] * output[id2 + size];
|
| 208 |
+
Mt_S[1][2] += input[id2] * output[id2 + size];
|
| 209 |
+
Mt_S[1][3] += output[id2 + size];
|
| 210 |
+
|
| 211 |
+
Mt_S[2][0] += input[id2 + 2*size] * output[id2];
|
| 212 |
+
Mt_S[2][1] += input[id2 + size] * output[id2];
|
| 213 |
+
Mt_S[2][2] += input[id2] * output[id2];
|
| 214 |
+
Mt_S[2][3] += output[id2];
|
| 215 |
+
}
|
| 216 |
+
}
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
bool success = InverseMat4x4(Mt_M, invMt_M);
|
| 220 |
+
|
| 221 |
+
for (int i = 0; i < 3; i++) {
|
| 222 |
+
for (int j = 0; j < 4; j++) {
|
| 223 |
+
for (int k = 0; k < 4; k++) {
|
| 224 |
+
A[i][j] += invMt_M[j][k] * Mt_S[i][k];
|
| 225 |
+
}
|
| 226 |
+
}
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
for (int i = 0; i < 3; i++) {
|
| 230 |
+
for (int j = 0; j < 4; j++) {
|
| 231 |
+
int affine_id = i * 4 + j;
|
| 232 |
+
affine_model[12 * id + affine_id] = A[i][j];
|
| 233 |
+
}
|
| 234 |
+
}
|
| 235 |
+
}
|
| 236 |
+
return ;
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
extern "C"
|
| 240 |
+
__global__ void bilateral_smooth_kernel(
|
| 241 |
+
float *affine_model, float *filtered_affine_model, float *guide,
|
| 242 |
+
int h, int w, int kernel_radius, float sigma1, float sigma2
|
| 243 |
+
)
|
| 244 |
+
{
|
| 245 |
+
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
| 246 |
+
int size = h * w;
|
| 247 |
+
if (id < size) {
|
| 248 |
+
int x = id % w;
|
| 249 |
+
int y = id / w;
|
| 250 |
+
|
| 251 |
+
double sum_affine[12] = {};
|
| 252 |
+
double sum_weight = 0;
|
| 253 |
+
for (int dx = -kernel_radius; dx <= kernel_radius; dx++) {
|
| 254 |
+
for (int dy = -kernel_radius; dy <= kernel_radius; dy++) {
|
| 255 |
+
int yy = y + dy, xx = x + dx;
|
| 256 |
+
int id2 = yy * w + xx;
|
| 257 |
+
if (0 <= xx && xx < w && 0 <= yy && yy < h) {
|
| 258 |
+
float color_diff1 = guide[yy*w + xx] - guide[y*w + x];
|
| 259 |
+
float color_diff2 = guide[yy*w + xx + size] - guide[y*w + x + size];
|
| 260 |
+
float color_diff3 = guide[yy*w + xx + 2*size] - guide[y*w + x + 2*size];
|
| 261 |
+
float color_diff_sqr =
|
| 262 |
+
(color_diff1*color_diff1 + color_diff2*color_diff2 + color_diff3*color_diff3) / 3;
|
| 263 |
+
|
| 264 |
+
float v1 = exp(-(dx * dx + dy * dy) / (2 * sigma1 * sigma1));
|
| 265 |
+
float v2 = exp(-(color_diff_sqr) / (2 * sigma2 * sigma2));
|
| 266 |
+
float weight = v1 * v2;
|
| 267 |
+
|
| 268 |
+
for (int i = 0; i < 3; i++) {
|
| 269 |
+
for (int j = 0; j < 4; j++) {
|
| 270 |
+
int affine_id = i * 4 + j;
|
| 271 |
+
sum_affine[affine_id] += weight * affine_model[id2*12 + affine_id];
|
| 272 |
+
}
|
| 273 |
+
}
|
| 274 |
+
sum_weight += weight;
|
| 275 |
+
}
|
| 276 |
+
}
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
for (int i = 0; i < 3; i++) {
|
| 280 |
+
for (int j = 0; j < 4; j++) {
|
| 281 |
+
int affine_id = i * 4 + j;
|
| 282 |
+
filtered_affine_model[id*12 + affine_id] = sum_affine[affine_id] / sum_weight;
|
| 283 |
+
}
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
+
return ;
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
extern "C"
|
| 291 |
+
__global__ void reconstruction_best_kernel(
|
| 292 |
+
float *input, float *filtered_affine_model, float *filtered_best_output,
|
| 293 |
+
int h, int w
|
| 294 |
+
)
|
| 295 |
+
{
|
| 296 |
+
int id = blockIdx.x * blockDim.x + threadIdx.x;
|
| 297 |
+
int size = h * w;
|
| 298 |
+
if (id < size) {
|
| 299 |
+
double out1 =
|
| 300 |
+
input[id + 2*size] * filtered_affine_model[id*12 + 0] + // A[0][0] +
|
| 301 |
+
input[id + size] * filtered_affine_model[id*12 + 1] + // A[0][1] +
|
| 302 |
+
input[id] * filtered_affine_model[id*12 + 2] + // A[0][2] +
|
| 303 |
+
filtered_affine_model[id*12 + 3]; //A[0][3];
|
| 304 |
+
double out2 =
|
| 305 |
+
input[id + 2*size] * filtered_affine_model[id*12 + 4] + //A[1][0] +
|
| 306 |
+
input[id + size] * filtered_affine_model[id*12 + 5] + //A[1][1] +
|
| 307 |
+
input[id] * filtered_affine_model[id*12 + 6] + //A[1][2] +
|
| 308 |
+
filtered_affine_model[id*12 + 7]; //A[1][3];
|
| 309 |
+
double out3 =
|
| 310 |
+
input[id + 2*size] * filtered_affine_model[id*12 + 8] + //A[2][0] +
|
| 311 |
+
input[id + size] * filtered_affine_model[id*12 + 9] + //A[2][1] +
|
| 312 |
+
input[id] * filtered_affine_model[id*12 + 10] + //A[2][2] +
|
| 313 |
+
filtered_affine_model[id*12 + 11]; // A[2][3];
|
| 314 |
+
|
| 315 |
+
filtered_best_output[id] = out1;
|
| 316 |
+
filtered_best_output[id + size] = out2;
|
| 317 |
+
filtered_best_output[id + 2*size] = out3;
|
| 318 |
+
}
|
| 319 |
+
return ;
|
| 320 |
+
}
|
| 321 |
+
'''
|
| 322 |
+
|
| 323 |
+
import cv2
|
| 324 |
+
import torch
|
| 325 |
+
import numpy as np
|
| 326 |
+
from PIL import Image
|
| 327 |
+
from cupy.cuda import function
|
| 328 |
+
from pynvrtc.compiler import Program
|
| 329 |
+
from collections import namedtuple
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def smooth_local_affine(output_cpu, input_cpu, epsilon, patch, h, w, f_r, f_e):
|
| 333 |
+
# program = Program(src.encode('utf-8'), 'best_local_affine_kernel.cu'.encode('utf-8'))
|
| 334 |
+
# ptx = program.compile(['-I/usr/local/cuda/include'.encode('utf-8')])
|
| 335 |
+
program = Program(src, 'best_local_affine_kernel.cu')
|
| 336 |
+
ptx = program.compile(['-I/usr/local/cuda/include'])
|
| 337 |
+
m = function.Module()
|
| 338 |
+
m.load(bytes(ptx.encode()))
|
| 339 |
+
|
| 340 |
+
_reconstruction_best_kernel = m.get_function('reconstruction_best_kernel')
|
| 341 |
+
_bilateral_smooth_kernel = m.get_function('bilateral_smooth_kernel')
|
| 342 |
+
_best_local_affine_kernel = m.get_function('best_local_affine_kernel')
|
| 343 |
+
Stream = namedtuple('Stream', ['ptr'])
|
| 344 |
+
s = Stream(ptr=torch.cuda.current_stream().cuda_stream)
|
| 345 |
+
|
| 346 |
+
filter_radius = f_r
|
| 347 |
+
sigma1 = filter_radius / 3
|
| 348 |
+
sigma2 = f_e
|
| 349 |
+
radius = (patch - 1) / 2
|
| 350 |
+
|
| 351 |
+
filtered_best_output = torch.zeros(np.shape(input_cpu)).cuda()
|
| 352 |
+
affine_model = torch.zeros((h * w, 12)).cuda()
|
| 353 |
+
filtered_affine_model =torch.zeros((h * w, 12)).cuda()
|
| 354 |
+
|
| 355 |
+
input_ = torch.from_numpy(input_cpu).cuda()
|
| 356 |
+
output_ = torch.from_numpy(output_cpu).cuda()
|
| 357 |
+
_best_local_affine_kernel(
|
| 358 |
+
grid=(int((h * w) / 256 + 1), 1),
|
| 359 |
+
block=(256, 1, 1),
|
| 360 |
+
args=[output_.data_ptr(), input_.data_ptr(), affine_model.data_ptr(),
|
| 361 |
+
np.int32(h), np.int32(w), np.float32(epsilon), np.int32(radius)], stream=s
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
_bilateral_smooth_kernel(
|
| 365 |
+
grid=(int((h * w) / 256 + 1), 1),
|
| 366 |
+
block=(256, 1, 1),
|
| 367 |
+
args=[affine_model.data_ptr(), filtered_affine_model.data_ptr(), input_.data_ptr(), np.int32(h), np.int32(w), np.int32(f_r), np.float32(sigma1), np.float32(sigma2)], stream=s
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
_reconstruction_best_kernel(
|
| 371 |
+
grid=(int((h * w) / 256 + 1), 1),
|
| 372 |
+
block=(256, 1, 1),
|
| 373 |
+
args=[input_.data_ptr(), filtered_affine_model.data_ptr(), filtered_best_output.data_ptr(),
|
| 374 |
+
np.int32(h), np.int32(w)], stream=s
|
| 375 |
+
)
|
| 376 |
+
numpy_filtered_best_output = filtered_best_output.cpu().numpy()
|
| 377 |
+
return numpy_filtered_best_output
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def smooth_filter(initImg, contentImg, f_radius=15,f_edge=1e-1):
|
| 381 |
+
'''
|
| 382 |
+
:param initImg: intermediate output. Either image path or PIL Image
|
| 383 |
+
:param contentImg: content image output. Either path or PIL Image
|
| 384 |
+
:return: stylized output image. PIL Image
|
| 385 |
+
'''
|
| 386 |
+
if type(initImg) == str:
|
| 387 |
+
initImg = Image.open(initImg).convert("RGB")
|
| 388 |
+
best_image_bgr = np.array(initImg, dtype=np.float32)
|
| 389 |
+
bW, bH, bC = best_image_bgr.shape
|
| 390 |
+
best_image_bgr = best_image_bgr[:, :, ::-1]
|
| 391 |
+
best_image_bgr = best_image_bgr.transpose((2, 0, 1))
|
| 392 |
+
|
| 393 |
+
if type(contentImg) == str:
|
| 394 |
+
contentImg = Image.open(contentImg).convert("RGB")
|
| 395 |
+
content_input = contentImg.resize((bH,bW))
|
| 396 |
+
else:
|
| 397 |
+
content_input = cv2.resize(contentImg,(bH,bW))
|
| 398 |
+
content_input = np.array(content_input, dtype=np.float32)
|
| 399 |
+
content_input = content_input[:, :, ::-1]
|
| 400 |
+
content_input = content_input.transpose((2, 0, 1))
|
| 401 |
+
input_ = np.ascontiguousarray(content_input, dtype=np.float32) / 255.
|
| 402 |
+
_, H, W = np.shape(input_)
|
| 403 |
+
output_ = np.ascontiguousarray(best_image_bgr, dtype=np.float32) / 255.
|
| 404 |
+
best_ = smooth_local_affine(output_, input_, 1e-7, 3, H, W, f_radius, f_edge)
|
| 405 |
+
best_ = best_.transpose(1, 2, 0)
|
| 406 |
+
result = Image.fromarray(np.uint8(np.clip(best_ * 255., 0, 255.)))
|
| 407 |
+
return result
|
graph_networks/LinearStyleTransfer/libs/utils.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import division
|
| 2 |
+
import os
|
| 3 |
+
import cv2
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
import scipy.misc
|
| 7 |
+
import numpy as np
|
| 8 |
+
import scipy.sparse
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import scipy.sparse.linalg
|
| 11 |
+
from cv2.ximgproc import jointBilateralFilter
|
| 12 |
+
from numpy.lib.stride_tricks import as_strided
|
| 13 |
+
|
| 14 |
+
def whiten(cF):
|
| 15 |
+
cFSize = cF.size()
|
| 16 |
+
c_mean = torch.mean(cF,1) # c x (h x w)
|
| 17 |
+
c_mean = c_mean.unsqueeze(1).expand_as(cF)
|
| 18 |
+
cF = cF - c_mean
|
| 19 |
+
|
| 20 |
+
contentConv = torch.mm(cF,cF.t()).div(cFSize[1]-1) + torch.eye(cFSize[0]).double()
|
| 21 |
+
c_u,c_e,c_v = torch.svd(contentConv,some=False)
|
| 22 |
+
|
| 23 |
+
k_c = cFSize[0]
|
| 24 |
+
for i in range(cFSize[0]):
|
| 25 |
+
if c_e[i] < 0.00001:
|
| 26 |
+
k_c = i
|
| 27 |
+
break
|
| 28 |
+
|
| 29 |
+
c_d = (c_e[0:k_c]).pow(-0.5)
|
| 30 |
+
step1 = torch.mm(c_v[:,0:k_c],torch.diag(c_d))
|
| 31 |
+
step2 = torch.mm(step1,(c_v[:,0:k_c].t()))
|
| 32 |
+
whiten_cF = torch.mm(step2,cF)
|
| 33 |
+
return whiten_cF
|
| 34 |
+
|
| 35 |
+
def numpy2cv2(cont,style,prop,width,height):
|
| 36 |
+
cont = cont.transpose((1,2,0))
|
| 37 |
+
cont = cont[...,::-1]
|
| 38 |
+
cont = cont * 255
|
| 39 |
+
cont = cv2.resize(cont,(width,height))
|
| 40 |
+
#cv2.resize(iimg,(width,height))
|
| 41 |
+
style = style.transpose((1,2,0))
|
| 42 |
+
style = style[...,::-1]
|
| 43 |
+
style = style * 255
|
| 44 |
+
style = cv2.resize(style,(width,height))
|
| 45 |
+
|
| 46 |
+
prop = prop.transpose((1,2,0))
|
| 47 |
+
prop = prop[...,::-1]
|
| 48 |
+
prop = prop * 255
|
| 49 |
+
prop = cv2.resize(prop,(width,height))
|
| 50 |
+
|
| 51 |
+
#return np.concatenate((cont,np.concatenate((style,prop),axis=1)),axis=1)
|
| 52 |
+
return prop,cont
|
| 53 |
+
|
| 54 |
+
def makeVideo(content,style,props,outf):
|
| 55 |
+
print('Stack transferred frames back to video...')
|
| 56 |
+
layers,height,width = content[0].shape
|
| 57 |
+
fourcc = cv2.VideoWriter_fourcc(*'MJPG')
|
| 58 |
+
video = cv2.VideoWriter(os.path.join(outf,'transfer.avi'),fourcc,10.0,(width,height))
|
| 59 |
+
ori_video = cv2.VideoWriter(os.path.join(outf,'content.avi'),fourcc,10.0,(width,height))
|
| 60 |
+
for j in range(len(content)):
|
| 61 |
+
prop,cont = numpy2cv2(content[j],style,props[j],width,height)
|
| 62 |
+
cv2.imwrite('prop.png',prop)
|
| 63 |
+
cv2.imwrite('content.png',cont)
|
| 64 |
+
# TODO: this is ugly, fix this
|
| 65 |
+
imgj = cv2.imread('prop.png')
|
| 66 |
+
imgc = cv2.imread('content.png')
|
| 67 |
+
|
| 68 |
+
video.write(imgj)
|
| 69 |
+
ori_video.write(imgc)
|
| 70 |
+
# RGB or BRG, yuks
|
| 71 |
+
video.release()
|
| 72 |
+
ori_video.release()
|
| 73 |
+
os.remove('prop.png')
|
| 74 |
+
os.remove('content.png')
|
| 75 |
+
print('Transferred video saved at %s.'%outf)
|
| 76 |
+
|
| 77 |
+
def print_options(opt):
|
| 78 |
+
message = ''
|
| 79 |
+
message += '----------------- Options ---------------\n'
|
| 80 |
+
for k, v in sorted(vars(opt).items()):
|
| 81 |
+
comment = ''
|
| 82 |
+
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
| 83 |
+
message += '----------------- End -------------------'
|
| 84 |
+
print(message)
|
| 85 |
+
|
| 86 |
+
# save to the disk
|
| 87 |
+
expr_dir = os.path.join(opt.outf)
|
| 88 |
+
os.makedirs(expr_dir,exist_ok=True)
|
| 89 |
+
file_name = os.path.join(expr_dir, 'opt.txt')
|
| 90 |
+
with open(file_name, 'wt') as opt_file:
|
| 91 |
+
opt_file.write(message)
|
| 92 |
+
opt_file.write('\n')
|
graph_networks/LinearStyleTransfer/models/dec_r31.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3ccc3bbc97a15e1d002d0b13523543c518dda2d0346c8f4d39c1d381d8490f68
|
| 3 |
+
size 2221888
|
graph_networks/LinearStyleTransfer/models/dec_r41.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e6858f96d3d0882fa3b40652a0315928219086e4bcb0e3efbe43bd04ea631911
|
| 3 |
+
size 14023509
|
graph_networks/LinearStyleTransfer/models/r31.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8bb75be684b331a105a8fad266343556067e4eec888249f7abb693a66b5ad7e3
|
| 3 |
+
size 11564438
|