diff --git a/.gitattributes b/.gitattributes
index c6e5ee7d6af749414a50317a4d3711e699df69bd..29f914d2807014282af1141fd1c21ab7b436c57b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -40,3 +40,10 @@ FastSplatStyler_huggingface/style_ims/style0.jpg filter=lfs diff=lfs merge=lfs -
FastSplatStyler_huggingface/style_ims/style2.jpg filter=lfs diff=lfs merge=lfs -text
FastSplatStyler_huggingface/style_ims/style44.jpg filter=lfs diff=lfs merge=lfs -text
FastSplatStyler_huggingface/style_ims/style6.jpg filter=lfs diff=lfs merge=lfs -text
+example-broche-rose-gold.splat filter=lfs diff=lfs merge=lfs -text
+output.splat filter=lfs diff=lfs merge=lfs -text
+style_ims/style-10.jpg filter=lfs diff=lfs merge=lfs -text
+style_ims/style0.jpg filter=lfs diff=lfs merge=lfs -text
+style_ims/style2.jpg filter=lfs diff=lfs merge=lfs -text
+style_ims/style44.jpg filter=lfs diff=lfs merge=lfs -text
+style_ims/style6.jpg filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..dd3fbced4eaaebd97a05c54796d9d2afafd095b2
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 ECU Computer Vision Lab
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index cc29616b9d688b4d4913dede957c11a2b6aa0363..fdf34a3b83a36b9d9b1ac8f2495a9b921473ff9c 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,36 @@
----
-title: FastSplatStyler
-emoji: π¨
-colorFrom: blue
-colorTo: yellow
-sdk: gradio
-sdk_version: 6.9.0
-app_file: app.py
-pinned: false
-license: mit
-short_description: Optimization-Free Style Transfer for 3D Gaussian Splats
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# FastSplatStyler
+Official Implementation of "Optimization-Free Style Transfer of 3D Gaussian Splats"
+
+[arXiv Paper](https://arxiv.org/abs/2508.05813)
+
+
+
+## Example Outputs
+
+Example Outputs can be visualized using the [Antimatter WebGL viewer](https://antimatter15.com/splat/) at the following links.
+
+- Broche: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/broche-rose-gold_original.splat) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/broche-rose-gold_style3.splat)
+- Crystal Lamp: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/crystal-lamp-original.splat) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/crystal-lamp-style2.splat)
+- Family Statue: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/family.ply) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/family-style-6.splat)
+- M60 Tanks: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/m60.ply) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/m60-style-31.splat)
+- Table: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/Table.ply) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/Table_style5.splat)
+- Train: [Original](https://antimatter15.com/splat/) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/train_style1.splat)
+- Truck: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/truck.splat) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/truck-style-21.splat)
+
+Example inputs and outputs can be downloaded from [Google Drive](https://drive.google.com/drive/folders/10YmtcCOKGosXfPEi84ho1AfYRYioYo12?usp=drive_link) or [Hugging Face](https://huggingface.co/datasets/incrl/fast-splat-styler/tree/main)
+
+## Demo
+
+**Coming soon**
+
+## Install
+
+This work relies heavily on the [Pytorch](https://pytorch.org/) and [Pytorch Geometric](https://www.pyg.org/) libraries.
+
+This code was tested with Python 3.12, Pytorch 2.9.1 (CUDA Toolkit 12.8), and Pytorch Geometric 2.8 on Windows. It was also tested on Mac with the same setup (without Cuda). The provided requirements.txt file comes from the Mac configuration.
+
+This repository relies on a graph networks library that was presented in a [previous work](https://github.com/davidmhart/interpolated-selectionconv/tree/main). The library can be downloaded, with included model weights, at this [google drive link](https://drive.google.com/drive/folders/10YmtcCOKGosXfPEi84ho1AfYRYioYo12?usp=drive_link). Place the "graph_networks" folder in the main directory.
+
+To stylize an example splat, run `python styletransfer_splat.py example-broche-rose-gold.splat --stylePath style_ims/style0.jpg --samplingRate 1.5`. You can change the input splat and style image to your specific use case.
+
+Supports `.splat` and `.ply` files.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3bca07be619b8f6a2f19d5470c513b84217c05b
--- /dev/null
+++ b/app.py
@@ -0,0 +1,302 @@
+import gradio as gr
+import torch
+import os
+import tempfile
+import shutil
+from pathlib import Path
+from time import time
+
+# ββ Core style-transfer logic (adapted from styletransfer_splat.py) ββββββββββ
+import pointCloudToMesh as ply2M
+import utils
+import graph_io as gio
+from clusters import *
+import splat_mesh_helpers as splt
+import clusters as cl
+from torch_geometric.data import Data
+from scipy.interpolate import NearestNDInterpolator
+
+from graph_networks.LinearStyleTransfer_vgg import encoder, decoder
+from graph_networks.LinearStyleTransfer_matrix import TransformLayer
+from graph_networks.LinearStyleTransfer.libs.Matrix import MulLayer
+from graph_networks.LinearStyleTransfer.libs.models import encoder4, decoder4
+
+
+# ββ Example assets (place your own files in ./examples/) βββββββββββββββββββββ
+EXAMPLE_SPLATS = [
+ ["example-broche-rose-gold.splat", "style_ims/style2.jpg"],
+ ["example-broche-rose-gold.splat", "style_ims/style6.jpg"],
+]
+
+
+# ββ Style-transfer function called by Gradio βββββββββββββββββββββββββββββββββ
+def run_style_transfer(
+ splat_file,
+ style_image,
+ threshold: float,
+ sampling_rate: float,
+ device_choice: str,
+ progress=gr.Progress(track_tqdm=True),
+):
+ if splat_file is None:
+ raise gr.Error("Please upload a 3D Gaussian Splat file (.ply or .splat).")
+ if style_image is None:
+ raise gr.Error("Please upload a style image.")
+
+ device = device_choice if device_choice == "cpu" else f"cuda:{device_choice}"
+
+ # ββ Parameters ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
+ n = 25
+ ratio = 0.25
+ depth = 3
+ style_shape = (512, 512)
+
+ logs = []
+
+ def log(msg):
+ logs.append(msg)
+ print(msg)
+ return "\n".join(logs)
+
+ # ββ 1. Load splat βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
+ progress(0.05, desc="Loading splatβ¦")
+ splat_path = splat_file.name if hasattr(splat_file, "name") else splat_file
+ log(f"Loading splat: {splat_path}")
+
+ pos3D_Original, _, colors_Original, opacity_Original, scales_Original, rots_Original, fileType = \
+ splt.splat_unpacker_with_threshold(n, splat_path, threshold)
+
+ # ββ 2. Gaussian super-sampling ββββββββββββββββββββββββββββββββββββββββββββ
+ progress(0.15, desc="Super-samplingβ¦")
+ t0 = time()
+ if sampling_rate > 1:
+ GaussianSamples = int(pos3D_Original.shape[0] * sampling_rate)
+ pos3D, colors = splt.splat_GaussianSuperSampler(
+ pos3D_Original.clone(), colors_Original.clone(),
+ opacity_Original.clone(), scales_Original.clone(), rots_Original.clone(),
+ GaussianSamples,
+ )
+ else:
+ pos3D, colors = pos3D_Original, colors_Original
+ log(f"Nodes in graph: {pos3D.shape[0]} ({time()-t0:.1f}s)")
+
+ # ββ 3. Graph construction βββββββββββββββββββββββββββββββββββββββββββββββββ
+ progress(0.30, desc="Building surface graphβ¦")
+ t0 = time()
+ style_ref = utils.loadImage(style_image, shape=style_shape)
+
+ normalsNP = ply2M.Estimate_Normals(pos3D, threshold)
+ normals = torch.from_numpy(normalsNP)
+
+ up_vector = torch.tensor([[1, 1, 1]], dtype=torch.float)
+ up_vector = up_vector / torch.linalg.norm(up_vector, dim=1)
+
+ pos3D = pos3D.to(device)
+ colors = colors.to(device)
+ normals = normals.to(device)
+ up_vector = up_vector.to(device)
+
+ edge_index, directions = gh.surface2Edges(pos3D, normals, up_vector, k_neighbors=16)
+ edge_index, selections, interps = gh.edges2Selections(edge_index, directions, interpolated=True)
+
+ clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(
+ pos3D, normals, edge_index, selections, interps,
+ ratio=ratio, up_vector=up_vector, depth=depth, device=device,
+ )
+ log(f"Graph built ({time()-t0:.1f}s)")
+
+ # ββ 4. Load networks ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
+ progress(0.50, desc="Loading networksβ¦")
+ t0 = time()
+
+ enc_ref = encoder4()
+ dec_ref = decoder4()
+ matrix_ref = MulLayer("r41")
+
+ enc_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/vgg_r41.pth", map_location=device))
+ dec_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/dec_r41.pth", map_location=device))
+ matrix_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/r41.pth", map_location=device))
+
+ enc = encoder(padding_mode="replicate")
+ dec = decoder(padding_mode="replicate")
+ matrix = TransformLayer()
+
+ with torch.no_grad():
+ enc.copy_weights(enc_ref)
+ dec.copy_weights(dec_ref)
+ matrix.copy_weights(matrix_ref)
+
+ content = Data(
+ x=colors, clusters=clusters,
+ edge_indexes=edge_indexes,
+ selections_list=selections_list,
+ interps_list=interps_list,
+ ).to(device)
+
+ style, _ = gio.image2Graph(style_ref, depth=3, device=device)
+
+ enc = enc.to(device)
+ dec = dec.to(device)
+ matrix = matrix.to(device)
+ log(f"Networks loaded ({time()-t0:.1f}s)")
+
+ # ββ 5. Style transfer βββββββββββββββββββββββββββββββββββββββββββββββββββββ
+ progress(0.70, desc="Running style transferβ¦")
+ t0 = time()
+
+ with torch.no_grad():
+ cF = enc(content)
+ sF = enc(style)
+ feature, _ = matrix(
+ cF["r41"], sF["r41"],
+ content.edge_indexes[3], content.selections_list[3],
+ style.edge_indexes[3], style.selections_list[3],
+ content.interps_list[3] if hasattr(content, "interps_list") else None,
+ )
+ result = dec(feature, content).clamp(0, 1)
+
+ colors[:, 0:3] = result
+ log(f"Stylization done ({time()-t0:.1f}s)")
+
+ # ββ 6. Interpolate back to original resolution ββββββββββββββββββββββββββββ
+ progress(0.88, desc="Interpolating back to original splatβ¦")
+ t0 = time()
+
+ interp2 = NearestNDInterpolator(pos3D.cpu(), colors.cpu())
+ results_OriginalNP = interp2(pos3D_Original)
+ results_Original = torch.from_numpy(results_OriginalNP).to(torch.float32)
+ colors_and_opacity_Original = torch.cat(
+ (results_Original, opacity_Original.unsqueeze(1)), dim=1
+ )
+ log(f"Interpolation done ({time()-t0:.1f}s)")
+
+ # ββ 7. Save output ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
+ progress(0.95, desc="Saving output splatβ¦")
+ suffix = ".splat" if fileType == "splat" else ".ply"
+ out_dir = tempfile.mkdtemp()
+ out_path = os.path.join(out_dir, f"stylized{suffix}")
+
+ splt.splat_save(
+ pos3D_Original.numpy(),
+ scales_Original.numpy(),
+ rots_Original.numpy(),
+ colors_and_opacity_Original.numpy(),
+ out_path,
+ fileType,
+ )
+ log(f"Saved to: {out_path}")
+ progress(1.0, desc="Done!")
+
+ return out_path, "\n".join(logs)
+
+
+# ββ Gradio UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
+def build_ui():
+ available_devices = (
+ [str(i) for i in range(torch.cuda.device_count())] + ["cpu"]
+ if torch.cuda.is_available()
+ else ["cpu"]
+ )
+
+ with gr.Blocks(
+ title="3DGS Style Transfer",
+ theme=gr.themes.Soft(primary_hue="violet"),
+ css="""
+ #title { text-align: center; }
+ #subtitle { text-align: center; color: #666; margin-bottom: 1rem; }
+ .panel { border-radius: 12px; }
+ #run-btn { font-size: 1.1rem; }
+ """,
+ ) as demo:
+
+ gr.Markdown("# π¨ 3D Gaussian Splat Style Transfer", elem_id="title")
+ gr.Markdown(
+ "Upload a 3DGS scene and a style image β the app will repaint the splat "
+ "with the artistic style of the image and give you a stylized splat to download. "
+ "After downloading, you can view your splat with an [online viewer](https://antimatter15.com/splat/).",
+ elem_id="subtitle",
+ )
+
+ with gr.Row():
+ # ββ Left column: inputs βββββββββββββββββββββββββββββββββββββββββββ
+ with gr.Column(scale=1, elem_classes="panel"):
+ gr.Markdown("### π Inputs")
+
+ splat_input = gr.File(
+ label="3D Gaussian Splat (.ply or .splat)",
+ file_types=[".ply", ".splat"],
+ type="filepath",
+ )
+
+ style_input = gr.Image(
+ label="Style Image",
+ type="filepath",
+ height=240,
+ )
+
+ with gr.Accordion("βοΈ Advanced Settings", open=False):
+ threshold_slider = gr.Slider(
+ minimum=90.0, maximum=100.0, value=99.8, step=0.1,
+ label="Opacity threshold (percentile)",
+ info="Points below this opacity percentile are removed.",
+ )
+ sampling_slider = gr.Slider(
+ minimum=0.5, maximum=3.0, value=1.5, step=0.1,
+ label="Gaussian super-sampling rate",
+ info="Values > 1 add extra samples; 1.0 = no super-sampling.",
+ )
+ device_radio = gr.Radio(
+ choices=available_devices,
+ value=available_devices[0],
+ label="Device",
+ )
+
+ run_btn = gr.Button("π Run Style Transfer", variant="primary", elem_id="run-btn")
+
+ # ββ Right column: outputs βββββββββββββββββββββββββββββββββββββββββ
+ with gr.Column(scale=1, elem_classes="panel"):
+ gr.Markdown("### π₯ Output")
+
+ output_file = gr.File(
+ label="Download Stylized Splat",
+ interactive=False,
+ )
+
+ log_box = gr.Textbox(
+ label="Progress log",
+ lines=12,
+ max_lines=20,
+ interactive=False,
+ placeholder="Logs will appear here once processing startsβ¦",
+ )
+
+ # ββ Examples βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
+ example_splat_paths = [row[0] for row in EXAMPLE_SPLATS]
+ example_style_paths = [row[1] for row in EXAMPLE_SPLATS]
+
+ valid_examples = [
+ row for row in EXAMPLE_SPLATS
+ if os.path.exists(row[0]) and os.path.exists(row[1])
+ ]
+
+ if valid_examples:
+ gr.Markdown("### πΌοΈ Examples")
+ gr.Examples(
+ examples=valid_examples,
+ inputs=[splat_input, style_input],
+ label="Click an example to load it",
+ )
+
+ # ββ Event wiring ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
+ run_btn.click(
+ fn=run_style_transfer,
+ inputs=[splat_input, style_input, threshold_slider, sampling_slider, device_radio],
+ outputs=[output_file, log_box],
+ )
+
+ return demo
+
+
+if __name__ == "__main__":
+ demo = build_ui()
+ demo.launch(share=False)
diff --git a/clusters.py b/clusters.py
new file mode 100644
index 0000000000000000000000000000000000000000..556fe199df9f85abd6c4a3df52252753dd0d2461
--- /dev/null
+++ b/clusters.py
@@ -0,0 +1,234 @@
+import torch
+from torch_scatter import scatter
+from torch_geometric.nn.pool.consecutive import consecutive_cluster
+from torch_geometric.utils import add_self_loops, add_remaining_self_loops, remove_self_loops
+from torch_geometric.nn import fps, knn
+from torch_sparse import coalesce
+import graph_helpers as gh
+import sphere_helpers as sh
+import mesh_helpers as mh
+
+import math
+from math import pi,sqrt
+
+
+from warnings import warn
+
+def makeImageClusters(pos2D,Nx,Ny,edge_index,selections,depth=1,device='cpu',stride=2):
+ clusters = []
+ edge_indexes = [torch.clone(edge_index).to(device)]
+ selections_list = [torch.clone(selections).to(device)]
+
+ for _ in range(depth):
+ Nx = Nx//stride
+ Ny = Ny//stride
+ cx,cy = getGrid(pos2D,Nx,Ny)
+ cluster, pos2D = gridCluster(pos2D,cx,cy,Nx)
+ edge_index, selections = selectionAverage(cluster, edge_index, selections)
+
+ clusters.append(torch.clone(cluster).to(device))
+ edge_indexes.append(torch.clone(edge_index).to(device))
+ selections_list.append(torch.clone(selections).to(device))
+
+ return clusters, edge_indexes, selections_list
+
+def makeSphereClusters(pos3D,edge_index,selections,interps,rows,cols,cluster_method="layering",stride=2,bary_d=None,depth=1,device='cpu'):
+ clusters = []
+ edge_indexes = [torch.clone(edge_index).to(device)]
+ selections_list = [torch.clone(selections).to(device)]
+ interps_list = [torch.clone(interps).to(device)]
+
+ for _ in range(depth):
+
+ rows = rows//stride
+ cols = cols//stride
+
+ if bary_d is not None:
+ bary_d = bary_d*stride
+
+ if cluster_method == "equirec":
+ centroids, _ = sh.sampleSphere_Equirec(rows,cols)
+
+ elif cluster_method == "layering":
+ centroids, _ = sh.sampleSphere_Layering(rows)
+
+ elif cluster_method == "spiral":
+ centroids, _ = sh.sampleSphere_Spiral(rows,cols)
+
+ elif cluster_method == "icosphere":
+ centroids, _ = sh.sampleSphere_Icosphere(rows)
+
+ elif cluster_method == "random":
+ centroids, _ = sh.sampleSphere_Random(rows,cols)
+
+ elif cluster_method == "random_nodes":
+ index = torch.multinomial(torch.ones(len(pos3D)),N) # close equivalent to np.random.choice
+ centroids = pos3D[index]
+
+ elif cluster_method == "fps":
+ # Farthest Point Search used in PointNet++
+ index = fps(pos3D, ratio=ratio)
+ centroids = pos3D[index]
+ else:
+ raise ValueError("Sphere cluster_method unknown")
+
+
+ # Find closest centriod to each current point
+ cluster = knn(centroids,pos3D,1)[1]
+ cluster, _ = consecutive_cluster(cluster)
+ pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')
+
+ # Regenerate surface graph
+ normals = pos3D/torch.linalg.norm(pos3D,dim=1,keepdims=True) # Make sure normals are unit vectors
+ edge_index,directions = gh.surface2Edges(pos3D,normals)
+ edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True,bary_d=bary_d)
+
+ clusters.append(torch.clone(cluster).to(device))
+ edge_indexes.append(torch.clone(edge_index).to(device))
+ selections_list.append(torch.clone(selections).to(device))
+ interps_list.append(torch.clone(interps).to(device))
+
+ return clusters, edge_indexes, selections_list, interps_list
+
+def makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,cluster_method="random",ratio=.25,up_vector=None,depth=1,device='cpu'):
+ clusters = []
+ edge_indexes = [torch.clone(edge_index).to(device)]
+ selections_list = [torch.clone(selections).to(device)]
+ interps_list = [torch.clone(interps).to(device)]
+
+ for _ in range(depth):
+
+ #Desired number of clusters in the next level
+ N = int(len(pos3D) * ratio)
+
+ if cluster_method == "random":
+ index = torch.multinomial(torch.ones(len(pos3D)),N) # close equivalent to np.random.choice
+ centroids = pos3D[index]
+
+ elif cluster_method == "fps":
+ # Farthest Point Search used in PointNet++
+ index = fps(pos3D, ratio=ratio)
+ centroids = pos3D[index]
+
+ # Find closest centriod to each current point
+ cluster = knn(centroids,pos3D,1)[1]
+ cluster, _ = consecutive_cluster(cluster)
+ pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')
+ normals = scatter(normals, cluster, dim=0, reduce='mean')
+
+ # Regenerate surface graph
+ normals = normals/torch.linalg.norm(normals,dim=1,keepdims=True) # Make sure normals are unit vectors
+ edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16)
+ edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)
+
+ clusters.append(torch.clone(cluster).to(device))
+ edge_indexes.append(torch.clone(edge_index).to(device))
+ selections_list.append(torch.clone(selections).to(device))
+ interps_list.append(torch.clone(interps).to(device))
+
+ return clusters, edge_indexes, selections_list, interps_list
+
+
+def makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=.25,up_vector=None,depth=1,device='cpu'):
+ clusters = []
+ edge_indexes = [torch.clone(edge_index).to(device)]
+ selections_list = [torch.clone(selections).to(device)]
+ interps_list = [torch.clone(interps).to(device)]
+
+ for _ in range(depth):
+
+ #Desired number of clusters in the next level
+ N = int(len(pos3D) * ratio)
+
+ # Generate new point cloud from downsampled version of texture map
+ centroids, normals = mh.sampleSurface(mesh,N,return_x=False)
+
+ # Find closest centriod to each current point
+ cluster = knn(centroids,pos3D,1)[1]
+ cluster, _ = consecutive_cluster(cluster)
+ pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')
+
+ # Regenerate surface graph
+ #normals = normals/torch.linalg.norm(normals,dim=1,keepdims=True) # Make sure normals are unit vectors
+ edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector)
+ edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)
+
+ clusters.append(torch.clone(cluster).to(device))
+ edge_indexes.append(torch.clone(edge_index).to(device))
+ selections_list.append(torch.clone(selections).to(device))
+ interps_list.append(torch.clone(interps).to(device))
+
+ return clusters, edge_indexes, selections_list, interps_list
+
+
+
+def getGrid(pos,Nx,Ny,xrange=None,yrange=None):
+ xmin = torch.min(pos[:,0]) if xrange is None else xrange[0]
+ ymin = torch.min(pos[:,1]) if yrange is None else yrange[0]
+ xmax = torch.max(pos[:,0]) if xrange is None else xrange[1]
+ ymax = torch.max(pos[:,1]) if yrange is None else yrange[1]
+
+ cx = torch.clamp(torch.floor((pos[:,0] - xmin)/(xmax-xmin) * Nx),0,Nx-1)
+ cy = torch.clamp(torch.floor((pos[:,1] - ymin)/(ymax-ymin) * Ny),0,Ny-1)
+ return cx, cy
+
+def gridCluster(pos,cx,cy,xmax):
+ cluster = cx + cy*xmax
+ cluster = cluster.type(torch.long) # Cast appropriately
+ cluster, _ = consecutive_cluster(cluster)
+ pos = scatter(pos, cluster, dim=0, reduce='mean')
+
+ return cluster, pos
+
+def selectionAverage(cluster, edge_index, selections):
+ num_nodes = cluster.size(0)
+ edge_index = cluster[edge_index.contiguous().view(1, -1)].view(2, -1)
+ edge_index, selections = remove_self_loops(edge_index, selections)
+ if edge_index.numel() > 0:
+
+ # To avoid means over discontinuities, do mean for two selections at at a time
+ final_edge_index, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean")
+ selections_check = torch.round(selections_check).type(torch.long)
+
+ final_selections = torch.zeros_like(selections_check).to(selections.device)
+
+ final_selections[torch.where(selections_check==4)] = 4
+ final_selections[torch.where(selections_check==5)] = 5
+
+ #Rotate selection kernel
+ selections += 2
+ selections = selections % 9 + torch.div(selections, 9, rounding_mode="floor")
+
+ _, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean")
+ selections_check = torch.round(selections_check).type(torch.long)
+ final_selections[torch.where(selections_check==4)] = 2
+ final_selections[torch.where(selections_check==5)] = 3
+
+ #Rotate selection kernel
+ selections += 2
+ selections = selections % 9 + torch.div(selections, 9, rounding_mode="floor")
+
+ _, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean")
+ selections_check = torch.round(selections_check).type(torch.long)
+ final_selections[torch.where(selections_check==4)] = 8
+ final_selections[torch.where(selections_check==5)] = 1
+
+ #Rotate selection kernel
+ selections += 2
+ selections = selections % 9 + torch.div(selections, 9, rounding_mode="floor")
+
+ _, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean")
+ selections_check = torch.round(selections_check).type(torch.long)
+ final_selections[torch.where(selections_check==4)] = 6
+ final_selections[torch.where(selections_check==5)] = 7
+
+ #print(torch.min(final_selections), torch.max(final_selections))
+ #print(torch.mean(final_selections.type(torch.float)))
+
+ edge_index, selections = add_remaining_self_loops(final_edge_index,final_selections,fill_value=torch.tensor(0,dtype=torch.long))
+
+ else:
+ edge_index, selections = add_remaining_self_loops(edge_index,selections,fill_value=torch.tensor(0,dtype=torch.long))
+ print("Warning: Edge Pool found no edges")
+
+ return edge_index, selections
diff --git a/example-broche-rose-gold.splat b/example-broche-rose-gold.splat
new file mode 100644
index 0000000000000000000000000000000000000000..4e96c18d7bd1e6436ad9fa64577dd91b1e4ad113
--- /dev/null
+++ b/example-broche-rose-gold.splat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:62cccf0adfb4e5713985e8ee874fa30a6a37ae222a23ead7c6e4639a5802ab62
+size 4157728
diff --git a/example.jpg b/example.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2a3fbbc38cb9027240bd9d50a506c1fb5def1dd7
Binary files /dev/null and b/example.jpg differ
diff --git a/graph_helpers.py b/graph_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b53a78fd9ab2ac39855d9663e3343d3f1093b24
--- /dev/null
+++ b/graph_helpers.py
@@ -0,0 +1,400 @@
+import torch
+from torch_geometric.nn import radius_graph, knn_graph
+import torch_geometric as tg
+from torch_geometric.utils import subgraph
+import utils
+from math import sqrt
+
+def getImPos(rows,cols,start_row=0,start_col=0):
+ row_space = torch.arange(start_row,rows+start_row)
+ col_space = torch.arange(start_col,cols+start_col)
+ col_image,row_image = torch.meshgrid(col_space,row_space,indexing='xy')
+ im_pos = torch.reshape(torch.stack((row_image,col_image),dim=-1),(rows*cols,2))
+ return im_pos
+
+def convertImPos(im_pos,flip_y=True):
+
+ # Cast to float for clustering based methods
+ pos2D = im_pos.float()
+
+ # Switch rows,cols to x,y
+ pos2D[:,[1,0]] = pos2D[:,[0,1]]
+
+ if flip_y:
+
+ # Flip to y-axis to match mathematical definition and edges2Selections settings
+ pos2D[:,1] = torch.amax(pos2D[:,1]) - pos2D[:,1]
+
+ return pos2D
+
+
+def grid2Edges(locs):
+ # Assume locs are already spaced at a distance of 1 structure
+ edge_index = radius_graph(locs,1.44,loop=True)
+ return edge_index
+
+def radius2Edges(locs,r=1.0):
+ edge_index = radius_graph(locs,r,loop=True)
+ return edge_index
+
+def knn2Edges(locs,knn=9):
+ edge_index = knn_graph(locs,knn,loop=True)
+ return edge_index
+
+def surface2Edges(pos3D,normals,up_vector=None,k_neighbors=9):
+
+ if up_vector is None:
+ up_vector = torch.tensor([[0.0,1.0,0.0]]).to(pos3D.device)
+
+ # K Nearest Neighbors graph
+ edge_index = knn_graph(pos3D,k_neighbors,loop=True)
+
+ # Cull neighbors based on normals (dot them together)
+ culling = torch.sum(torch.multiply(normals[edge_index[1]],normals[edge_index[0]]),dim=1)
+ edge_index = edge_index[:,torch.where(culling>0)[0]]
+
+ # For each node, rotate based on Grahm-Schmidt Orthognalization
+ norms = normals[edge_index[0]]
+
+ z_dir = norms
+ z_dir = z_dir/torch.linalg.norm(z_dir,dim=1,keepdims=True) # Make sure it is a unit vector
+ #x_dir = torch.cross(up_vector,norms,dim=1)
+ x_dir = utils.cross(up_vector,norms) # torch.cross doesn't broadcast properly in some versions of torch
+ x_dir = x_dir/torch.linalg.norm(x_dir,dim=1,keepdims=True)
+ #y_dir = torch.cross(norms,x_dir,dim=1)
+ y_dir = utils.cross(norms,x_dir)
+ y_dir = y_dir/torch.linalg.norm(y_dir,dim=1,keepdims=True)
+
+ directions = (pos3D[edge_index[1]] - pos3D[edge_index[0]])
+
+ # Perform rotation by multiplying out rotation matrix
+ temp = torch.clone(directions) # Buffer
+ directions[:,0] = temp[:,0] * x_dir[:,0] + temp[:,1] * x_dir[:,1] + temp[:,2] * x_dir[:,2]
+ directions[:,1] = temp[:,0] * y_dir[:,0] + temp[:,1] * y_dir[:,1] + temp[:,2] * y_dir[:,2]
+ #directions[:,2] = temp[:,0] * z_dir[:,0] + temp[:,1] * z_dir[:,1] + temp[:,2] * z_dir[:,2]
+
+ # Drop z coordinate
+ directions = directions[:,:2]
+
+ return edge_index, directions
+
+def edges2Selections(edge_index,directions,interpolated=True,bary_d=None,y_down=False):
+
+ # Current Ordering
+ # 4 3 2
+ # 5 0 1
+ # 6 7 8
+ if y_down:
+ vectorList = torch.tensor([[1,0],[sqrt(2)/2,-sqrt(2)/2],[0,-1],[-sqrt(2)/2,-sqrt(2)/2],[-1,0],[-sqrt(2)/2,sqrt(2)/2],[0,1],[sqrt(2)/2,sqrt(2)/2]],dtype=torch.float).transpose(1,0)
+ else:
+ vectorList = torch.tensor([[1,0],[sqrt(2)/2,sqrt(2)/2],[0,1],[-sqrt(2)/2,sqrt(2)/2],[-1,0],[-sqrt(2)/2,-sqrt(2)/2],[0,-1],[sqrt(2)/2,-sqrt(2)/2]],dtype=torch.float).transpose(1,0)
+
+ if interpolated:
+
+ if bary_d is None:
+ edge_index,selections,interps = interpolateSelections(edge_index,directions,vectorList)
+ else:
+ edge_index,selections,interps = interpolateSelections_barycentric(edge_index,directions,bary_d,vectorList)
+ interps = normalizeEdges(edge_index,selections,interps)
+ return edge_index,selections,interps
+
+ else:
+ selections = torch.argmax(torch.matmul(directions,vectorList),dim=1) + 1
+ selections[torch.where(torch.sum(torch.abs(directions),axis=1) == 0)] = 0 # Same cell selection
+ return selections
+
+
+def makeEdges(prev_sources,prev_targets,prev_selections,sources,targets,selection,reverse=True):
+
+ sources = sources.flatten()
+ targets = targets.flatten()
+
+ prev_sources += sources.tolist()
+ prev_targets += targets.tolist()
+ prev_selections += len(sources)*[selection]
+
+ if reverse:
+ prev_sources += targets
+ prev_targets += sources
+ prev_selections += len(sources)*[utils.reverse_selection(selection)]
+
+ return prev_sources,prev_targets,prev_selections
+
+def maskNodes(mask,x):
+ node_mask = torch.where(mask)
+ x = x[node_mask]
+ return x
+
+def maskPoints(mask,x,y):
+
+ mask = torch.squeeze(mask)
+
+ x0 = torch.floor(x).long()
+ x1 = x0 + 1
+ y0 = torch.floor(y).long()
+ y1 = y0 + 1
+
+ x0 = torch.clip(x0, 0, mask.shape[1]-1);
+ x1 = torch.clip(x1, 0, mask.shape[1]-1);
+ y0 = torch.clip(y0, 0, mask.shape[0]-1);
+ y1 = torch.clip(y1, 0, mask.shape[0]-1);
+
+ Ma = mask[ y0, x0 ]
+ Mb = mask[ y1, x0 ]
+ Mc = mask[ y0, x1 ]
+ Md = mask[ y1, x1 ]
+
+ node_mask = torch.where(torch.logical_and(torch.logical_and(torch.logical_and(Ma,Mb),Mc),Md))[0]
+
+ return node_mask
+
+
+def maskGraph(mask,edge_index,selections,interps=None):
+
+ edge_index,_,edge_mask = subgraph(mask,edge_index,relabel_nodes=True,return_edge_mask=True)
+ selections = selections[edge_mask]
+
+ if interps:
+ interps = interps[edge_mask]
+ return edge_index, selections, interps
+ else:
+ return edge_index, selections
+
+def interpolateSelections(edge_index,directions,vectorList=None):
+
+ if vectorList is None:
+ # Current Ordering
+ # 4 3 2
+ # 5 0 1
+ # 6 7 8
+ vectorList = torch.tensor([[1,0],[sqrt(2)/2,sqrt(2)/2],[0,1],[-sqrt(2)/2,sqrt(2)/2],[-1,0],[-sqrt(2)/2,-sqrt(2)/2],[0,-1],[sqrt(2)/2,-sqrt(2)/2]],dtype=torch.float).transpose(1,0)
+
+ # Normalize directions for simplicity of calculations
+ dir_norm = torch.linalg.norm(directions,dim=1,keepdims=True)
+ directions = directions/dir_norm
+ #locs = torch.where(dir_norm > 1)[0]
+ #directions[locs] = directions[locs]/dir_norm[locs]
+
+ values = torch.matmul(directions,vectorList)
+ best = torch.unsqueeze(torch.argmax(values,dim=1),1)
+
+ best_val = torch.take_along_dim(values,best,dim=1)
+
+ # Look at both neighbors to see who is closer
+ lower_val = torch.take_along_dim(values,(best-1) % 8,dim=1)
+ upper_val = torch.take_along_dim(values,(best+1) % 8,dim=1)
+
+ comp_vals = torch.cat((lower_val,upper_val),dim=1)
+
+ second_best_vals = torch.amax(comp_vals,dim=1)
+ second_best = torch.argmax(comp_vals,dim=1)
+
+ # Find the interpolation value (in terms of angles)
+ best_val = torch.minimum(best_val[:,0],torch.tensor(1,device=directions.device)) # Prep for arccos function
+ angle_best = torch.arccos(best_val)
+ angle_second_best = torch.arccos(second_best_vals)
+
+ angle_vals = angle_best/(angle_second_best + angle_best)
+
+ # Use negative values for clockwise selections
+ clockwise = torch.where(second_best == 0)[0]
+ angle_vals[clockwise] = -angle_vals[clockwise]
+
+ # Handle computation problems at the poles
+ angle_vals = torch.nan_to_num(angle_vals)
+
+ # Make Selections
+ selections = best[:,0] + 1
+
+ # Same cell selection
+ same_locs = torch.where(edge_index[0] == edge_index[1])
+ selections[same_locs] = 0
+ angle_vals[same_locs] = 0
+
+ # Make starting interp_values
+ interps = torch.ones_like(angle_vals)
+ interps -= torch.abs(angle_vals)
+
+ # Add new edges
+ pos_interp_locs = torch.where(angle_vals > 1e-2)[0]
+ pos_interps = angle_vals[pos_interp_locs]
+ pos_edges = edge_index[:,pos_interp_locs]
+ pos_selections = selections[pos_interp_locs] + 1
+ pos_selections[torch.where(pos_selections>8)] = 1 # Account for wrap around
+
+ neg_interp_locs = torch.where(angle_vals < -1e-2)[0]
+ neg_interps = torch.abs(angle_vals[neg_interp_locs])
+ neg_edges = edge_index[:,neg_interp_locs]
+ neg_selections = selections[neg_interp_locs] - 1
+ neg_selections[torch.where(neg_selections<1)] = 8 # Account for wrap around
+
+ edge_index = torch.cat((edge_index,pos_edges,neg_edges),dim=1)
+ selections = torch.cat((selections,pos_selections,neg_selections),dim=0)
+ interps = torch.cat((interps,pos_interps,neg_interps),dim=0)
+
+ return edge_index,selections,interps
+
+def interpolateSelections_barycentric(edge_index,directions,d,vectorList=None):
+
+ if vectorList is None:
+ # Current Ordering
+ # 4 3 2
+ # 5 0 1
+ # 6 7 8
+ vectorList = torch.tensor([[1,0],[sqrt(2)/2,-sqrt(2)/2],[0,-1],[-sqrt(2)/2,-sqrt(2)/2],[-1,0],[-sqrt(2)/2,sqrt(2)/2],[0,1],[sqrt(2)/2,sqrt(2)/2]],dtype=torch.float).transpose(1,0).to(directions.device)
+
+ # Preprune central selections and reappend them at the end
+ same_locs = torch.where(edge_index[0] == edge_index[1])
+ same_edges = edge_index[:,same_locs[0]]
+
+ different_locs = torch.where(edge_index[0] != edge_index[1])
+ edge_index = edge_index[:,different_locs[0]]
+ directions = directions[different_locs[0]]
+
+ # Normalize directions for simplicity of calculations
+ dir_norm = torch.linalg.norm(directions,dim=1,keepdims=True)
+ unit_directions = directions/dir_norm
+ #locs = torch.where(dir_norm > 1)[0]
+ #directions[locs] = directions[locs]/dir_norm[locs]
+
+ values = torch.matmul(unit_directions,vectorList)
+ best = torch.unsqueeze(torch.argmax(values,dim=1),1)
+ #best_val = torch.take_along_dim(values,best,dim=1)
+
+ # Look at both neighbors to see who is closer
+ lower_val = torch.take_along_dim(values,(best-1) % 8,dim=1)
+ upper_val = torch.take_along_dim(values,(best+1) % 8,dim=1)
+
+ comp_vals = torch.cat((lower_val,upper_val),dim=1)
+
+ second_best = torch.argmax(comp_vals,dim=1)
+ #second_best_vals = torch.amax(comp_vals,dim=1)
+
+ # Convert into uv cooridnates for barycentric interpolation calculation
+ # /|
+ # / |v
+ # /__|
+ # u
+
+ scaled_directions = torch.abs(directions/d)
+ u = torch.amax(scaled_directions,dim=1)
+ v = torch.amin(scaled_directions,dim=1)
+
+ # Force coordinates to be within the triangle
+ boundary_check = torch.where(u > d)
+ v[boundary_check] /= u[boundary_check]
+ u[boundary_check] = 1.0
+
+ # Precalculated barycentric values from linear matrix solve
+ I0 = 1 - u
+ I1 = u - v
+ I2 = v
+
+ # Make first selections and proper interps
+ selections = best[:,0] + 1
+ interps = I1
+ even_sels = torch.where(selections % 2 == 0)
+ interps[even_sels] = I2[even_sels] # Corners get different weights
+
+ # Make new edges for the central selections
+ central_edges = torch.clone(edge_index).to(edge_index.device)
+ central_selections = torch.zeros_like(selections)
+ central_interps = I0
+
+ # Make new edges for the last selection
+ pos_locs = torch.where(second_best==1)[0]
+ pos_edges = edge_index[:,pos_locs]
+ pos_selections = selections[pos_locs] + 1
+ pos_selections[torch.where(pos_selections>8)] = 1 #Account for wrap around
+ pos_interps = I1[pos_locs]
+ even_sels = torch.where(pos_selections % 2 == 0)
+ pos_interps[even_sels] = I2[pos_locs][even_sels]
+
+ neg_locs = torch.where(second_best==0)[0]
+ neg_edges = edge_index[:,neg_locs]
+ neg_selections = selections[neg_locs] - 1
+ neg_selections[torch.where(neg_selections<1)] = 8 # Account for wrap around
+ neg_interps = I1[neg_locs]
+ even_sels = torch.where(neg_selections % 2 == 0)
+ neg_interps[even_sels] = I2[neg_locs][even_sels]
+
+ # Account for the previously pruned same node edges
+ same_selections = torch.zeros(same_edges.shape[1],dtype=torch.long)
+ same_interps = torch.ones(same_edges.shape[1],dtype=torch.float)
+
+ # Combine
+ edge_index = torch.cat((edge_index,central_edges,pos_edges,neg_edges,same_edges),dim=1)
+ selections = torch.cat((selections,central_selections,pos_selections,neg_selections,same_selections),dim=0)
+ interps = torch.cat((interps,central_interps,pos_interps,neg_interps,same_interps),dim=0)
+
+ #edge_index = torch.cat((edge_index,central_edges,pos_edges,neg_edges),dim=1)
+ #selections = torch.cat((selections,central_selections,pos_selections,neg_selections),dim=0)
+ #interps = torch.cat((interps,central_interps,pos_interps,neg_interps),dim=0)
+
+ # Account for edges to the same node
+ #same_locs = torch.where(edge_index[0] == edge_index[1])
+ #selections[same_locs] = 0
+ #interps[same_locs] = 1
+
+ return edge_index,selections,interps
+
+def normalizeEdges(edge_index,selections,interps=None,kernel_norm=False):
+ '''Given an edge_index and selections, normalize the edges for each node so that
+ aggregation of edges with interps = 1. If interps is given, use a weighted average.
+ if kernel_norm = True, account for missing selections by increasing weight on other selections.'''
+
+ N = torch.max(edge_index) + 1
+ S = torch.max(selections) + 1
+
+ total_weight = torch.zeros((N,S),dtype=torch.float).to(edge_index.device)
+
+ if interps is None:
+ interps = torch.ones(len(selections),dtype=torch.float).to(edge_index.device)
+
+ # Aggregate all edges to determine normalizations per selection
+ nodes = edge_index[0]
+ #total_weight[nodes,selections] += interps
+ total_weight.index_put_((nodes,selections),interps,accumulate=True)
+
+ # Reassign interps accordingly
+ if kernel_norm:
+ row_totals = torch.sum(total_weight,dim=1)
+ interps = interps * S/row_totals[nodes]
+ else:
+ norms = total_weight[nodes,selections]
+ norms[torch.where(norms < 1e-6)] = 1e-6 # Avoid divide by zero error
+ interps = interps/norms
+
+ return interps
+
+def simplifyGraph(edge_index,selections,edge_lengths):
+ # Take the shortest edge for the set of the same selections on a given node
+ num_edges = edge_index.shape[1]
+
+ # Keep track of which nodes have been visited
+ keep_edges = torch.zeros(num_edges,dtype=torch.bool).to(edge_index.device)
+
+ previous_best_distance = 100000*torch.ones((torch.amax(edge_index)+1,torch.amax(selections)+1),dtype=torch.long).to(edge_index.device)
+ previous_best_edge = -1*torch.ones((torch.amax(edge_index)+1,torch.amax(selections)+1),dtype=torch.long).to(edge_index.device)
+
+ for i in range(num_edges):
+ start_node = edge_index[0,i]
+ #end_node = edge_index[1,i]
+ selection = selections[i]
+ distance = edge_lengths[i]
+
+ if distance < previous_best_distance[start_node,selection]:
+ previous_best_distance[start_node,selection] = distance
+ keep_edges[i] = True
+
+ prev = previous_best_edge[start_node,selection]
+ if prev != -1:
+ keep_edges[prev] = False
+
+ previous_best_edge[start_node,selection] = i
+
+ edge_index = edge_index[:,torch.where(keep_edges)[0]]
+ selections = selections[torch.where(keep_edges)]
+
+ return edge_index, selections
+
diff --git a/graph_io.py b/graph_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..bce2bb09aff7752716898e64e86fa519af7820b4
--- /dev/null
+++ b/graph_io.py
@@ -0,0 +1,306 @@
+import numpy as np
+import torch
+from torch_geometric.nn import knn
+from torch_geometric.data import Data
+from torch_geometric.nn import radius_graph, knn_graph
+
+import graph_helpers as gh
+import sphere_helpers as sh
+import mesh_helpers as mh
+import clusters as cl
+import utils
+
+from torch_scatter import scatter
+
+import math
+from math import pi, sqrt
+
+from warnings import warn
+
+def image2Graph(data, gt = None, mask = None, depth = 1, x_only = False, device = 'cpu'):
+
+ _,ch,rows,cols = data.shape
+
+ x = torch.reshape(data,(ch,rows*cols)).permute((1,0)).to(device)
+
+ if mask is not None:
+ # Mask out nodes
+ node_mask = torch.where(mask.flatten())
+ x = x[node_mask]
+
+ if gt is not None:
+ y = gt.flatten().to(device)
+ if mask is not None:
+ y = y[node_mask]
+
+ if x_only:
+ if gt is not None:
+ return x,y
+ else:
+ return x
+
+ im_pos = gh.getImPos(rows,cols)
+
+ if mask is not None:
+ im_pos = im_pos[node_mask]
+
+ # Make "point cloud" for clustering
+ pos2D = gh.convertImPos(im_pos,flip_y=False)
+
+ # Generate initial graph
+ edge_index = gh.grid2Edges(pos2D)
+ directions = pos2D[edge_index[1]] - pos2D[edge_index[0]]
+ selections = gh.edges2Selections(edge_index,directions,interpolated=False,y_down=True)
+
+ # Generate info for downsampled versions of the graph
+ clusters, edge_indexes, selections_list = cl.makeImageClusters(pos2D,cols,rows,edge_index,selections,depth=depth,device=device)
+
+ # Make final graph and metadata needed for mapping the result after going through the network
+ graph = Data(x=x,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=None)
+ metadata = Data(original=data,im_pos=im_pos.long(),rows=rows,cols=cols,ch=ch)
+
+ if gt is not None:
+ graph.y = y
+
+ return graph,metadata
+
+def graph2Image(result,metadata,canvas=None):
+
+ x = utils.toNumpy(result,permute=False)
+ im_pos = utils.toNumpy(metadata.im_pos,permute=False)
+ if canvas is None:
+ canvas = utils.makeCanvas(x,metadata.original)
+
+ # Paint over the original image (neccesary for masked images)
+ canvas[im_pos[:,0],im_pos[:,1]] = x
+
+ return canvas
+
+### Begin Interpolated Methods ###
+
+def sphere2Graph(data, structure="layering", cluster_method="layering", scale=1.0, stride=2, interpolation_mode = "angle", gt = None, mask = None, depth = 1, x_only = False, device = 'cpu'):
+
+ _,ch,rows,cols = data.shape
+
+ if structure == "equirec":
+ # Use the original data to start with
+ cartesian, spherical = sh.sampleSphere_Equirec(scale*rows,scale*cols)
+ elif structure == "layering":
+ cartesian, spherical = sh.sampleSphere_Layering(scale*rows)
+ elif structure == "spiral":
+ cartesian, spherical = sh.sampleSphere_Spiral(scale*rows,scale*cols)
+ elif structure == "icosphere":
+ cartesian, spherical = sh.sampleSphere_Icosphere(scale*rows)
+ elif structure == "random":
+ cartesian, spherical = sh.sampleSphere_Random(scale*rows,scale*cols)
+ else:
+ raise ValueError("Sphere structure unknown")
+
+ if interpolation_mode == "bary":
+ bary_d = pi/(scale*rows)
+ else:
+ bary_d = None
+
+ # Get the landing point for each node
+ sample_x, sample_y = sh.spherical2equirec(spherical[:,0],spherical[:,1],rows,cols)
+
+ if mask is not None:
+
+ node_mask = gh.maskPoints(mask,sample_x,sample_y)
+ sample_x = sample_x[node_mask]
+ sample_y = sample_y[node_mask]
+ spherical = spherical[node_mask]
+ cartesian = cartesian[node_mask]
+
+ features = utils.bilinear_interpolate(data, sample_x, sample_y).to(device)
+
+ if gt is not None:
+ features_y = utils.bilinear_interpolate(gt.unsqueeze(0), sample_x, sample_y).to(device)
+
+ if x_only:
+ if gt is not None:
+ return features,features_y
+ else:
+ return features
+
+ # Build initial graph
+ edge_index,directions = gh.surface2Edges(cartesian,cartesian)
+ edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True,bary_d=bary_d)
+
+ # Generate info for downsampled versions of the graph
+ clusters, edge_indexes, selections_list, interps_list = cl.makeSphereClusters(cartesian,edge_index,selections,interps,rows*scale,cols*scale,cluster_method,stride=stride,bary_d=bary_d,depth=depth,device=device)
+
+ # Make final graph and metadata needed for mapping the result after going through the network
+ graph = Data(x=features,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list)
+ metadata = Data(original=data,pos3D=cartesian,mask=mask,rows=rows,cols=cols,ch=ch)
+
+ if gt is not None:
+ graph.y = features_y
+
+ return graph, metadata
+
+def graph2Sphere(features,metadata):
+
+ # Generate equirectangular points and their 3D locations
+ theta, phi = sh.equirec2spherical(metadata.rows, metadata.cols)
+ x,y,z = sh.spherical2xyz(theta,phi)
+
+ v = torch.stack((x,y,z),dim=1)
+
+ # Find closest 3D point to each equirectangular point
+ nearest = torch.reshape(knn(metadata.pos3D,v,3)[1],(len(v),3))
+
+ #Interpolate based on proximty to each node
+ w0 = 1/torch.linalg.norm((v - metadata.pos3D[nearest[:,0]]),dim=1, keepdim=True).to(features.device)
+ w1 = 1/torch.linalg.norm((v - metadata.pos3D[nearest[:,1]]),dim=1, keepdim=True).to(features.device)
+ w2 = 1/torch.linalg.norm((v - metadata.pos3D[nearest[:,2]]),dim=1, keepdim=True).to(features.device)
+
+ w0 = torch.nan_to_num(w0, nan=1e6)
+ w1 = torch.nan_to_num(w1, nan=1e6)
+ w2 = torch.nan_to_num(w2, nan=1e6)
+
+ w0 = torch.clamp(w0,0,1e6)
+ w1 = torch.clamp(w1,0,1e6)
+ w2 = torch.clamp(w2,0,1e6)
+
+ total = w0 + w1 + w2
+
+ #w0,w1,w2 = mh.getBarycentricWeights(v,metadata.pos3D[nearest[:,0]],metadata.pos3D[nearest[:,1]],metadata.pos3D[nearest[:,2]])
+
+ #w0 = w0.unsqueeze(1).to(features.device)
+ #w1 = w1.unsqueeze(1).to(features.device)
+ #w2 = w2.unsqueeze(1).to(features.device)
+
+ result = (w0*features[nearest[:,0]] + w1*features[nearest[:,1]] + w2*features[nearest[:,2]])/total
+
+ #result = result.clamp(0,1)
+
+ if hasattr(metadata,"mask"):
+ mask = utils.toNumpy(metadata.mask.squeeze(),permute=False)
+ canvas = utils.makeCanvas(result,metadata.original)
+ result = np.reshape(result.data.cpu().numpy(),(metadata.rows,metadata.cols,features.shape[1]))
+ canvas[np.where(mask)] = result[np.where(mask)]
+ return canvas
+ else:
+ return np.reshape(result.data.cpu().numpy(),(metadata.rows,metadata.cols,features.shape[1]))
+
+
+
+def splat2Graph(data, mesh, up_vector = None, N = 100000, ratio=.25, depth = 1, device = 'cpu'):
+ """ Sample mesh faces to determine graph """
+
+ if up_vector == None:
+ up_vector = torch.tensor([[1,1,1]],dtype=torch.float)
+ #up_vector = 2*torch.rand((1,3))-1
+ up_vector = up_vector/torch.linalg.norm(up_vector,dim=1)
+
+
+ #position, normal vector, uv coordinates in the texture map, x is color
+ pos3D, normals = mh.sampleSurface(mesh,N)
+
+ # Build initial graph
+ #edge_index are neighbors of a point, directions are the directions from that point
+ edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16)
+ #directions need to be turned into selections "W sub n" from the star-like coordinate system from Dr. Hart's github interpolated-selectionconv
+ edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)
+
+ # Generate info for downsampled versions of the graph
+ clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)
+ #clusters, edge_indexes, selections_list, interps_list = cl.makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)
+
+ # Make final graph and metadata needed for mapping the result after going through the network
+ graph = Data(clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list)
+ metadata = Data(original=data,pos3D=pos3D,mesh=mesh)
+
+ return graph,metadata
+
+def mesh2Graph(data, mesh, up_vector = None, N = 100000, ratio=.25, mask = None, depth = 1, x_only = False, device = 'cpu'):
+ """ Sample mesh faces to determine graph """
+
+ if up_vector == None:
+ up_vector = torch.tensor([[1,1,1]],dtype=torch.float)
+ #up_vector = 2*torch.rand((1,3))-1
+ up_vector = up_vector/torch.linalg.norm(up_vector,dim=1)
+
+ if mask is not None:
+ warn("Masks are not currently implemented for mesh graphs")
+
+ #position, normal vector, uv coordinates in the texture map, x is color
+ pos3D, normals, uvs, x = mh.sampleSurface(mesh,N,return_x=True)
+
+ x = x.to(device)
+
+ if x_only:
+ warn("x_only returns randomly selected points for mesh2Graph. Do not use with previous graph structures")
+ return x
+
+ # Build initial graph
+ #edge_index are neighbors of a point, directions are the directions from that point
+ edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16)
+ #directions need to be turned into selections "W sub n" from the star-like coordinate system from Dr. Hart's github interpolated-selectionconv
+ edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)
+
+ # Generate info for downsampled versions of the graph
+ clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)
+ #clusters, edge_indexes, selections_list, interps_list = cl.makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)
+
+ # Make final graph and metadata needed for mapping the result after going through the network
+ graph = Data(x=x,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list)
+ metadata = Data(original=data,pos3D=pos3D,uvs=uvs,mesh=mesh)
+
+ return graph,metadata
+
+def graph2Splat(features,metadata,view3D=False):
+
+ features = features.cpu().numpy()
+
+ canvas = utils.toNumpy(metadata.original)
+ rows,cols,ch = canvas.shape
+
+ # Get 2D positions by scaling uv
+ pos2D = metadata.uvs.cpu().numpy()
+ pos2D[:,0] = pos2D[:,0]*cols
+ pos2D[:,1] = 1-pos2D[:,1] # UV puts y=0 at the bottom
+ pos2D[:,1] = pos2D[:,1]*rows
+
+ # Generate desired points
+ row_space = np.arange(rows)
+ col_space = np.arange(cols)
+ col_image,row_image = np.meshgrid(col_space,row_space)
+
+ canvas = utils.interpolatePointCloud2D(pos2D,features,col_image,row_image)
+ canvas = np.clip(canvas,0,1)
+
+ if view3D:
+ mesh = mh.setTexture(metadata.mesh,canvas)
+ mesh.show()
+
+ return canvas
+
+
+def graph2Mesh(features,metadata,view3D=False):
+
+ features = features.cpu().numpy()
+
+ canvas = utils.toNumpy(metadata.original)
+ rows,cols,ch = canvas.shape
+
+ # Get 2D positions by scaling uv
+ pos2D = metadata.uvs.cpu().numpy()
+ pos2D[:,0] = pos2D[:,0]*cols
+ pos2D[:,1] = 1-pos2D[:,1] # UV puts y=0 at the bottom
+ pos2D[:,1] = pos2D[:,1]*rows
+
+ # Generate desired points
+ row_space = np.arange(rows)
+ col_space = np.arange(cols)
+ col_image,row_image = np.meshgrid(col_space,row_space)
+
+ canvas = utils.interpolatePointCloud2D(pos2D,features,col_image,row_image)
+ canvas = np.clip(canvas,0,1)
+
+ if view3D:
+ mesh = mh.setTexture(metadata.mesh,canvas)
+ mesh.show()
+
+ return canvas
diff --git a/graph_networks/.DS_Store b/graph_networks/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5562a58c1a663d343a951dca0a853bba3e02bafc
Binary files /dev/null and b/graph_networks/.DS_Store differ
diff --git a/graph_networks/LinearStyleTransfer/.DS_Store b/graph_networks/LinearStyleTransfer/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..ee83e0ec67801c85f3e109379b5003d78c5cacda
Binary files /dev/null and b/graph_networks/LinearStyleTransfer/.DS_Store differ
diff --git a/graph_networks/LinearStyleTransfer/LICENSE b/graph_networks/LinearStyleTransfer/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..b3fb3a86103b77107cc211e4aa9224fd79076f97
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/LICENSE
@@ -0,0 +1,25 @@
+BSD 2-Clause License
+
+Copyright (c) 2018, SunshineAtNoon
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/graph_networks/LinearStyleTransfer/README.md b/graph_networks/LinearStyleTransfer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8fc41f7b2d4f5ce38771709cc7730e93125a7098
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/README.md
@@ -0,0 +1,102 @@
+## Learning Linear Transformations for Fast Image and Video Style Transfer
+**[[Paper]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Li_Learning_Linear_Transformations_for_Fast_Image_and_Video_Style_Transfer_CVPR_2019_paper.pdf)** **[[Project Page]](https://sites.google.com/view/linear-style-transfer-cvpr19/)**
+
+

+

+
+## Prerequisites
+- [Pytorch](http://pytorch.org/)
+- [torchvision](https://github.com/pytorch/vision)
+- [opencv](https://opencv.org/) for video generation
+
+**All code tested on Ubuntu 16.04, pytorch 0.4.1, and opencv 3.4.2**
+
+## Style Transfer
+- Clone from github: `git clone https://github.com/sunshineatnoon/LinearStyleTransfer`
+- Download pre-trained models from [google drive](https://drive.google.com/file/d/1H9T5rfXGlGCUh04DGkpkMFbVnmscJAbs/view?usp=sharing).
+- Uncompress to root folder :
+```
+cd LinearStyleTransfer
+unzip models.zip
+rm models.zip
+```
+
+#### Artistic style transfer
+```
+python TestArtistic.py
+```
+or conduct style transfer on relu_31 features
+```
+python TestArtistic.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --matrixPath models/r31.pth --layer r31
+```
+
+#### Photo-realistic style transfer
+For photo-realistic style transfer, we need first compile the [pytorch_spn](https://github.com/Liusifei/pytorch_spn) repository.
+```
+cd libs/pytorch_spn
+sh make.sh
+cd ../..
+```
+Then:
+```
+python TestPhotoReal.py
+```
+Note: images with `_filtered.png` as postfix are images filtered by the SPN after style transfer, images with `_smooth.png` as postfix are images post process by a [smooth filter](https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py).
+
+#### Video style transfer
+```
+python TestVideo.py
+```
+
+#### Real-time video demo
+```
+python real-time-demo.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --matrixPath models/r31.pth --layer r31
+```
+
+## Model Training
+### Data Preparation
+- MSCOCO
+```
+wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip
+```
+- WikiArt
+ - Either manually download from [kaggle](https://www.kaggle.com/c/painter-by-numbers).
+ - Or install [kaggle-cli](https://github.com/floydwch/kaggle-cli) and download by running:
+ ```
+ kg download -u -p -c painter-by-numbers -f train.zip
+ ```
+
+### Training
+#### Train a style transfer model
+To train a model that transfers relu4_1 features, run:
+```
+python Train.py --vgg_dir models/vgg_r41.pth --decoder_dir models/dec_r41.pth --layer r41 --contentPath PATH_TO_MSCOCO --stylePath PATH_TO_WikiArt --outf OUTPUT_DIR
+```
+or train a model that transfers relu3_1 features:
+```
+python Train.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --layer r31 --contentPath PATH_TO_MSCOCO --stylePath PATH_TO_WikiArt --outf OUTPUT_DIR
+```
+Key hyper-parameters:
+- style_layers: which features to compute style loss.
+- style_weight: larger style weight leads to heavier style in transferred images.
+
+Intermediate results and weight will be stored in `OUTPUT_DIR`
+
+#### Train a SPN model to cancel distortions for photo-realistic style transfer
+Run:
+```
+python TrainSPN.py --contentPath PATH_TO_MSCOCO
+```
+
+### Acknowledgement
+- We use the [smooth filter](https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py) by [LouieYang](https://github.com/LouieYang) in the photo-realistic style transfer.
+
+### Citation
+```
+@inproceedings{li2018learning,
+ author = {Li, Xueting and Liu, Sifei and Kautz, Jan and Yang, Ming-Hsuan},
+ title = {Learning Linear Transformations for Fast Arbitrary Style Transfer},
+ booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
+ year = {2019}
+}
+```
diff --git a/graph_networks/LinearStyleTransfer/TestArtistic.py b/graph_networks/LinearStyleTransfer/TestArtistic.py
new file mode 100644
index 0000000000000000000000000000000000000000..8435542c201b9448145a8642de0cd4111d715ee5
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/TestArtistic.py
@@ -0,0 +1,98 @@
+import os
+import torch
+import argparse
+from libs.Loader import Dataset
+from libs.Matrix import MulLayer
+import torchvision.utils as vutils
+import torch.backends.cudnn as cudnn
+from libs.utils import print_options
+from libs.models import encoder3,encoder4, encoder5
+from libs.models import decoder3,decoder4, decoder5
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
+ help='pre-trained encoder path')
+parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
+ help='pre-trained decoder path')
+parser.add_argument("--matrixPath", default='models/r41.pth',
+ help='pre-trained model path')
+parser.add_argument("--stylePath", default="./data/style/",
+ help='path to style image')
+parser.add_argument("--contentPath", default="./data/content/",
+ help='path to frames')
+parser.add_argument("--outf", default="Artistic/",
+ help='path to transferred images')
+parser.add_argument("--batchSize", type=int,default=1,
+ help='batch size')
+parser.add_argument('--loadSize', type=int, default=256,
+ help='scale image size')
+parser.add_argument('--fineSize', type=int, default=256,
+ help='crop image size')
+parser.add_argument("--layer", default="r41",
+ help='which features to transfer, either r31 or r41')
+
+################# PREPARATIONS #################
+opt = parser.parse_args()
+opt.cuda = torch.cuda.is_available()
+print_options(opt)
+
+os.makedirs(opt.outf,exist_ok=True)
+cudnn.benchmark = True
+
+################# DATA #################
+content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize,test=True)
+content_loader = torch.utils.data.DataLoader(dataset=content_dataset,
+ batch_size = opt.batchSize,
+ shuffle = False)
+ #num_workers = 1)
+style_dataset = Dataset(opt.stylePath,opt.loadSize,opt.fineSize,test=True)
+style_loader = torch.utils.data.DataLoader(dataset=style_dataset,
+ batch_size = opt.batchSize,
+ shuffle = False)
+ #num_workers = 1)
+
+################# MODEL #################
+if(opt.layer == 'r31'):
+ vgg = encoder3()
+ dec = decoder3()
+elif(opt.layer == 'r41'):
+ vgg = encoder4()
+ dec = decoder4()
+matrix = MulLayer(opt.layer)
+vgg.load_state_dict(torch.load(opt.vgg_dir))
+dec.load_state_dict(torch.load(opt.decoder_dir))
+matrix.load_state_dict(torch.load(opt.matrixPath))
+
+################# GLOBAL VARIABLE #################
+contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
+styleV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
+
+################# GPU #################
+if(opt.cuda):
+ vgg.cuda()
+ dec.cuda()
+ matrix.cuda()
+ contentV = contentV.cuda()
+ styleV = styleV.cuda()
+
+for ci,(content,contentName) in enumerate(content_loader):
+ contentName = contentName[0]
+ contentV.resize_(content.size()).copy_(content)
+ for sj,(style,styleName) in enumerate(style_loader):
+ styleName = styleName[0]
+ styleV.resize_(style.size()).copy_(style)
+
+ # forward
+ with torch.no_grad():
+ sF = vgg(styleV)
+ cF = vgg(contentV)
+
+ if(opt.layer == 'r41'):
+ feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer])
+ else:
+ feature,transmatrix = matrix(cF,sF)
+ transfer = dec(feature)
+
+ transfer = transfer.clamp(0,1)
+ vutils.save_image(transfer,'%s/%s_%s.png'%(opt.outf,contentName,styleName),normalize=True,scale_each=True,nrow=opt.batchSize)
+ print('Transferred image saved at %s%s_%s.png'%(opt.outf,contentName,styleName))
diff --git a/graph_networks/LinearStyleTransfer/TestPhotoReal.py b/graph_networks/LinearStyleTransfer/TestPhotoReal.py
new file mode 100644
index 0000000000000000000000000000000000000000..e00c7bb3468dfd8e678076a7b384300c521bf8a7
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/TestPhotoReal.py
@@ -0,0 +1,118 @@
+import os
+import cv2
+import time
+import torch
+import argparse
+import numpy as np
+from PIL import Image
+from libs.SPN import SPN
+import torchvision.utils as vutils
+from libs.utils import print_options
+from libs.MatrixTest import MulLayer
+import torch.backends.cudnn as cudnn
+from libs.LoaderPhotoReal import Dataset
+from libs.models import encoder3,encoder4
+from libs.models import decoder3,decoder4
+import torchvision.transforms as transforms
+from libs.smooth_filter import smooth_filter
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
+ help='pre-trained encoder path')
+parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
+ help='pre-trained decoder path')
+parser.add_argument("--matrixPath", default='models/r41.pth',
+ help='pre-trained model path')
+parser.add_argument("--stylePath", default="data/photo_real/style/images/",
+ help='path to style image')
+parser.add_argument("--styleSegPath", default="data/photo_real/styleSeg/",
+ help='path to style image masks')
+parser.add_argument("--contentPath", default="data/photo_real/content/images/",
+ help='path to content image')
+parser.add_argument("--contentSegPath", default="data/photo_real/contentSeg/",
+ help='path to content image masks')
+parser.add_argument("--outf", default="PhotoReal/",
+ help='path to save output images')
+parser.add_argument("--batchSize", type=int,default=1,
+ help='batch size')
+parser.add_argument('--fineSize', type=int, default=512,
+ help='image size')
+parser.add_argument("--layer", default="r41",
+ help='features of which layer to transform, either r31 or r41')
+parser.add_argument("--spn_dir", default='models/r41_spn.pth',
+ help='path to pretrained SPN model')
+
+################# PREPARATIONS #################
+opt = parser.parse_args()
+opt.cuda = torch.cuda.is_available()
+print_options(opt)
+
+os.makedirs(opt.outf, exist_ok=True)
+
+cudnn.benchmark = True
+
+################# DATA #################
+dataset = Dataset(opt.contentPath,opt.stylePath,opt.contentSegPath,opt.styleSegPath,opt.fineSize)
+loader = torch.utils.data.DataLoader(dataset=dataset,
+ batch_size=1,
+ shuffle=False)
+
+################# MODEL #################
+if(opt.layer == 'r31'):
+ vgg = encoder3()
+ dec = decoder3()
+elif(opt.layer == 'r41'):
+ vgg = encoder4()
+ dec = decoder4()
+matrix = MulLayer(opt.layer)
+vgg.load_state_dict(torch.load(opt.vgg_dir))
+dec.load_state_dict(torch.load(opt.decoder_dir))
+matrix.load_state_dict(torch.load(opt.matrixPath))
+spn = SPN()
+spn.load_state_dict(torch.load(opt.spn_dir))
+
+################# GLOBAL VARIABLE #################
+contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
+styleV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
+whitenV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
+
+################# GPU #################
+if(opt.cuda):
+ vgg.cuda()
+ dec.cuda()
+ spn.cuda()
+ matrix.cuda()
+ contentV = contentV.cuda()
+ styleV = styleV.cuda()
+ whitenV = whitenV.cuda()
+
+for i,(contentImg,styleImg,whitenImg,cmasks,smasks,imname) in enumerate(loader):
+ imname = imname[0]
+ contentV.resize_(contentImg.size()).copy_(contentImg)
+ styleV.resize_(styleImg.size()).copy_(styleImg)
+ whitenV.resize_(whitenImg.size()).copy_(whitenImg)
+
+ # forward
+ sF = vgg(styleV)
+ cF = vgg(contentV)
+
+ with torch.no_grad():
+ if(opt.layer == 'r41'):
+ feature = matrix(cF[opt.layer],sF[opt.layer],cmasks,smasks)
+ else:
+ feature = matrix(cF,sF,cmasks,smasks)
+ transfer = dec(feature)
+ filtered = spn(transfer,whitenV)
+ vutils.save_image(transfer,os.path.join(opt.outf,'%s_transfer.png'%(imname.split('.')[0])))
+
+ filtered = filtered.clamp(0,1)
+ filtered = filtered.cpu()
+ vutils.save_image(filtered,'%s/%s_filtered.png'%(opt.outf,imname.split('.')[0]))
+ out_img = filtered.squeeze(0).mul(255).clamp(0,255).byte().permute(1,2,0).cpu().numpy()
+ content = contentImg.squeeze(0).mul(255).clamp(0,255).byte().permute(1,2,0).cpu().numpy()
+ content = content.copy()
+ out_img = out_img.copy()
+ smoothed = smooth_filter(out_img, content, f_radius=15, f_edge=1e-1)
+ smoothed.save('%s/%s_smooth.png'%(opt.outf,imname.split('.')[0]))
+ print('Transferred image saved at %s%s, filtered image saved at %s%s_filtered.png' \
+ %(opt.outf,imname,opt.outf,imname.split('.')[0]))
diff --git a/graph_networks/LinearStyleTransfer/TestVideo.py b/graph_networks/LinearStyleTransfer/TestVideo.py
new file mode 100644
index 0000000000000000000000000000000000000000..94323aac35353cdd4723048c1d52aa1534608998
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/TestVideo.py
@@ -0,0 +1,108 @@
+import os
+import torch
+import argparse
+from PIL import Image
+from libs.Loader import Dataset
+from libs.Matrix import MulLayer
+import torch.backends.cudnn as cudnn
+from libs.models import encoder3,encoder4
+from libs.models import decoder3,decoder4
+import torchvision.transforms as transforms
+from libs.utils import makeVideo, print_options
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--vgg_dir", default='models/vgg_r31.pth',
+ help='pre-trained encoder path')
+parser.add_argument("--decoder_dir", default='models/dec_r31.pth',
+ help='pre-trained decoder path')
+parser.add_argument("--matrix_dir", default="models/r31.pth",
+ help='path to pre-trained model')
+parser.add_argument("--style", default="data/style/in2.jpg",
+ help='path to style image')
+parser.add_argument("--content_dir", default="data/videos/content/mountain_2/",
+ help='path to video frames')
+parser.add_argument('--loadSize', type=int, default=512,
+ help='scale image size')
+parser.add_argument('--fineSize', type=int, default=512,
+ help='crop image size')
+parser.add_argument("--name",default="transferred_video",
+ help="name of generated video")
+parser.add_argument("--layer",default="r31",
+ help="features of which layer to transform")
+parser.add_argument("--outf",default="videos",
+ help="output folder")
+
+################# PREPARATIONS #################
+opt = parser.parse_args()
+opt.cuda = torch.cuda.is_available()
+print_options(opt)
+
+os.makedirs(opt.outf,exist_ok=True)
+cudnn.benchmark = True
+
+################# DATA #################
+def loadImg(imgPath):
+ img = Image.open(imgPath).convert('RGB')
+ transform = transforms.Compose([
+ transforms.Scale(opt.fineSize),
+ transforms.ToTensor()])
+ return transform(img)
+styleV = loadImg(opt.style).unsqueeze(0)
+
+content_dataset = Dataset(opt.content_dir,
+ loadSize = opt.loadSize,
+ fineSize = opt.fineSize,
+ test = True,
+ video = True)
+content_loader = torch.utils.data.DataLoader(dataset = content_dataset,
+ batch_size = 1,
+ shuffle = False)
+
+################# MODEL #################
+if(opt.layer == 'r31'):
+ vgg = encoder3()
+ dec = decoder3()
+elif(opt.layer == 'r41'):
+ vgg = encoder4()
+ dec = decoder4()
+matrix = MulLayer(layer=opt.layer)
+vgg.load_state_dict(torch.load(opt.vgg_dir))
+dec.load_state_dict(torch.load(opt.decoder_dir))
+matrix.load_state_dict(torch.load(opt.matrix_dir))
+
+################# GLOBAL VARIABLE #################
+contentV = torch.Tensor(1,3,opt.fineSize,opt.fineSize)
+
+################# GPU #################
+if(opt.cuda):
+ vgg.cuda()
+ dec.cuda()
+ matrix.cuda()
+
+ styleV = styleV.cuda()
+ contentV = contentV.cuda()
+
+result_frames = []
+contents = []
+style = styleV.squeeze(0).cpu().numpy()
+sF = vgg(styleV)
+
+for i,(content,contentName) in enumerate(content_loader):
+ print('Transfer frame %d...'%i)
+ contentName = contentName[0]
+ contentV.resize_(content.size()).copy_(content)
+ contents.append(content.squeeze(0).float().numpy())
+ # forward
+ with torch.no_grad():
+ cF = vgg(contentV)
+
+ if(opt.layer == 'r41'):
+ feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer])
+ else:
+ feature,transmatrix = matrix(cF,sF)
+ transfer = dec(feature)
+
+ transfer = transfer.clamp(0,1)
+ result_frames.append(transfer.squeeze(0).cpu().numpy())
+
+makeVideo(contents,style,result_frames,opt.outf)
diff --git a/graph_networks/LinearStyleTransfer/Train.py b/graph_networks/LinearStyleTransfer/Train.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c4a1bda9c99c63557f9d96739d5475cf56db3cb
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/Train.py
@@ -0,0 +1,185 @@
+import os
+import torch
+import argparse
+import torch.nn as nn
+import torch.optim as optim
+from libs.Loader import Dataset
+from libs.Matrix import MulLayer
+import torchvision.utils as vutils
+import torch.backends.cudnn as cudnn
+from libs.utils import print_options
+from libs.Criterion import LossCriterion
+from libs.models import encoder3,encoder4
+from libs.models import decoder3,decoder4
+from libs.models import encoder5 as loss_network
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
+ help='pre-trained encoder path')
+parser.add_argument("--loss_network_dir", default='models/vgg_r51.pth',
+ help='used for loss network')
+parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
+ help='pre-trained decoder path')
+parser.add_argument("--stylePath", default="/home/xtli/DATA/wikiArt/train/images/",
+ help='path to wikiArt dataset')
+parser.add_argument("--contentPath", default="/home/xtli/DATA/MSCOCO/train2014/images/",
+ help='path to MSCOCO dataset')
+parser.add_argument("--outf", default="trainingOutput/",
+ help='folder to output images and model checkpoints')
+parser.add_argument("--content_layers", default="r41",
+ help='layers for content')
+parser.add_argument("--style_layers", default="r11,r21,r31,r41",
+ help='layers for style')
+parser.add_argument("--batchSize", type=int,default=8,
+ help='batch size')
+parser.add_argument("--niter", type=int,default=100000,
+ help='iterations to train the model')
+parser.add_argument('--loadSize', type=int, default=300,
+ help='scale image size')
+parser.add_argument('--fineSize', type=int, default=256,
+ help='crop image size')
+parser.add_argument("--lr", type=float, default=1e-4,
+ help='learning rate')
+parser.add_argument("--content_weight", type=float, default=1.0,
+ help='content loss weight')
+parser.add_argument("--style_weight", type=float, default=0.02,
+ help='style loss weight')
+parser.add_argument("--log_interval", type=int, default=500,
+ help='log interval')
+parser.add_argument("--gpu_id", type=int, default=0,
+ help='which gpu to use')
+parser.add_argument("--save_interval", type=int, default=5000,
+ help='checkpoint save interval')
+parser.add_argument("--layer", default="r41",
+ help='which features to transfer, either r31 or r41')
+
+################# PREPARATIONS #################
+opt = parser.parse_args()
+opt.content_layers = opt.content_layers.split(',')
+opt.style_layers = opt.style_layers.split(',')
+opt.cuda = torch.cuda.is_available()
+if(opt.cuda):
+ torch.cuda.set_device(opt.gpu_id)
+
+os.makedirs(opt.outf,exist_ok=True)
+cudnn.benchmark = True
+print_options(opt)
+
+################# DATA #################
+content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize)
+content_loader_ = torch.utils.data.DataLoader(dataset = content_dataset,
+ batch_size = opt.batchSize,
+ shuffle = True,
+ num_workers = 1,
+ drop_last = True)
+content_loader = iter(content_loader_)
+style_dataset = Dataset(opt.stylePath,opt.loadSize,opt.fineSize)
+style_loader_ = torch.utils.data.DataLoader(dataset = style_dataset,
+ batch_size = opt.batchSize,
+ shuffle = True,
+ num_workers = 1,
+ drop_last = True)
+style_loader = iter(style_loader_)
+
+################# MODEL #################
+vgg5 = loss_network()
+if(opt.layer == 'r31'):
+ matrix = MulLayer('r31')
+ vgg = encoder3()
+ dec = decoder3()
+elif(opt.layer == 'r41'):
+ matrix = MulLayer('r41')
+ vgg = encoder4()
+ dec = decoder4()
+vgg.load_state_dict(torch.load(opt.vgg_dir))
+dec.load_state_dict(torch.load(opt.decoder_dir))
+vgg5.load_state_dict(torch.load(opt.loss_network_dir))
+
+for param in vgg.parameters():
+ param.requires_grad = False
+for param in vgg5.parameters():
+ param.requires_grad = False
+for param in dec.parameters():
+ param.requires_grad = False
+
+################# LOSS & OPTIMIZER #################
+criterion = LossCriterion(opt.style_layers,
+ opt.content_layers,
+ opt.style_weight,
+ opt.content_weight)
+optimizer = optim.Adam(matrix.parameters(), opt.lr)
+
+################# GLOBAL VARIABLE #################
+contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
+styleV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
+
+################# GPU #################
+if(opt.cuda):
+ vgg.cuda()
+ dec.cuda()
+ vgg5.cuda()
+ matrix.cuda()
+ contentV = contentV.cuda()
+ styleV = styleV.cuda()
+
+################# TRAINING #################
+def adjust_learning_rate(optimizer, iteration):
+ """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = opt.lr / (1+iteration*1e-5)
+
+for iteration in range(1,opt.niter+1):
+ optimizer.zero_grad()
+ try:
+ content,_ = content_loader.next()
+ except IOError:
+ content,_ = content_loader.next()
+ except StopIteration:
+ content_loader = iter(content_loader_)
+ content,_ = content_loader.next()
+ except:
+ continue
+
+ try:
+ style,_ = style_loader.next()
+ except IOError:
+ style,_ = style_loader.next()
+ except StopIteration:
+ style_loader = iter(style_loader_)
+ style,_ = style_loader.next()
+ except:
+ continue
+
+ contentV.resize_(content.size()).copy_(content)
+ styleV.resize_(style.size()).copy_(style)
+
+ # forward
+ sF = vgg(styleV)
+ cF = vgg(contentV)
+
+ if(opt.layer == 'r41'):
+ feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer])
+ else:
+ feature,transmatrix = matrix(cF,sF)
+ transfer = dec(feature)
+
+ sF_loss = vgg5(styleV)
+ cF_loss = vgg5(contentV)
+ tF = vgg5(transfer)
+ loss,styleLoss,contentLoss = criterion(tF,sF_loss,cF_loss)
+
+ # backward & optimization
+ loss.backward()
+ optimizer.step()
+ print('Iteration: [%d/%d] Loss: %.4f contentLoss: %.4f styleLoss: %.4f Learng Rate is %.6f'%
+ (opt.niter,iteration,loss,contentLoss,styleLoss,optimizer.param_groups[0]['lr']))
+
+ adjust_learning_rate(optimizer,iteration)
+
+ if((iteration) % opt.log_interval == 0):
+ transfer = transfer.clamp(0,1)
+ concat = torch.cat((content,style,transfer.cpu()),dim=0)
+ vutils.save_image(concat,'%s/%d.png'%(opt.outf,iteration),normalize=True,scale_each=True,nrow=opt.batchSize)
+
+ if(iteration > 0 and (iteration) % opt.save_interval == 0):
+ torch.save(matrix.state_dict(), '%s/%s.pth' % (opt.outf,opt.layer))
diff --git a/graph_networks/LinearStyleTransfer/TrainSPN.py b/graph_networks/LinearStyleTransfer/TrainSPN.py
new file mode 100644
index 0000000000000000000000000000000000000000..b188fb40be531bd96d96aacb134edbd0bdebb7c5
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/TrainSPN.py
@@ -0,0 +1,141 @@
+from __future__ import print_function
+import os
+import argparse
+
+from libs.SPN import SPN
+from libs.Loader import Dataset
+from libs.models import encoder4
+from libs.models import decoder4
+from libs.utils import print_options
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import torchvision.utils as vutils
+import torch.backends.cudnn as cudnn
+import torchvision.transforms as transforms
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--vgg_dir", default='models/vgg_r41.pth',
+ help='pre-trained encoder path')
+parser.add_argument("--decoder_dir", default='models/dec_r41.pth',
+ help='pre-trained decoder path')
+parser.add_argument("--contentPath", default="/home/xtli/DATA/MSCOCO/train2014/images/",
+ help='path to MSCOCO dataset')
+parser.add_argument("--outf", default="trainingSPNOutput/",
+ help='folder to output images and model checkpoints')
+parser.add_argument("--layer", default="r41",
+ help='layers for content')
+parser.add_argument("--batchSize", type=int,default=8,
+ help='batch size')
+parser.add_argument("--niter", type=int,default=100000,
+ help='iterations to train the model')
+parser.add_argument('--loadSize', type=int, default=512,
+ help='scale image size')
+parser.add_argument('--fineSize', type=int, default=256,
+ help='crop image size')
+parser.add_argument("--lr", type=float, default=1e-3,
+ help='learning rate')
+parser.add_argument("--log_interval", type=int, default=500,
+ help='log interval')
+parser.add_argument("--save_interval", type=int, default=5000,
+ help='checkpoint save interval')
+parser.add_argument("--spn_num", type=int, default=1,
+ help='number of spn filters')
+
+################# PREPARATIONS #################
+opt = parser.parse_args()
+opt.cuda = torch.cuda.is_available()
+print_options(opt)
+
+
+os.makedirs(opt.outf, exist_ok = True)
+
+cudnn.benchmark = True
+
+################# DATA #################
+content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize)
+content_loader_ = torch.utils.data.DataLoader(dataset=content_dataset,
+ batch_size = opt.batchSize,
+ shuffle = True,
+ num_workers = 4,
+ drop_last = True)
+content_loader = iter(content_loader_)
+
+################# MODEL #################
+spn = SPN(spn=opt.spn_num)
+if(opt.layer == 'r31'):
+ vgg = encoder3()
+ dec = decoder3()
+elif(opt.layer == 'r41'):
+ vgg = encoder4()
+ dec = decoder4()
+vgg.load_state_dict(torch.load(opt.vgg_dir))
+dec.load_state_dict(torch.load(opt.decoder_dir))
+
+for param in vgg.parameters():
+ param.requires_grad = False
+for param in dec.parameters():
+ param.requires_grad = False
+
+################# LOSS & OPTIMIZER #################
+criterion = nn.MSELoss(size_average=False)
+#optimizer_spn = optim.SGD(spn.parameters(), opt.lr)
+optimizer_spn = optim.Adam(spn.parameters(), opt.lr)
+
+################# GLOBAL VARIABLE #################
+contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
+
+################# GPU #################
+if(opt.cuda):
+ vgg.cuda()
+ dec.cuda()
+ spn.cuda()
+ contentV = contentV.cuda()
+
+################# TRAINING #################
+def adjust_learning_rate(optimizer, iteration):
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = opt.lr / (1+iteration*1e-5)
+
+spn.train()
+for iteration in range(1,opt.niter+1):
+ optimizer_spn.zero_grad()
+ try:
+ content,_ = content_loader.next()
+ except IOError:
+ content,_ = content_loader.next()
+ except StopIteration:
+ content_loader = iter(content_loader_)
+ content,_ = content_loader.next()
+ except:
+ continue
+
+ contentV.resize_(content.size()).copy_(content)
+
+ # forward
+ cF = vgg(contentV)
+ transfer = dec(cF['r41'])
+
+
+ propagated = spn(transfer,contentV)
+ loss = criterion(propagated,contentV)
+
+ # backward & optimization
+ loss.backward()
+ #nn.utils.clip_grad_norm(spn.parameters(), 1000)
+ optimizer_spn.step()
+ print('Iteration: [%d/%d] Loss: %.4f Learng Rate is %.6f'
+ %(opt.niter,iteration,loss,optimizer_spn.param_groups[0]['lr']))
+
+ adjust_learning_rate(optimizer_spn,iteration)
+
+ if((iteration) % opt.log_interval == 0):
+ transfer = transfer.clamp(0,1)
+ propagated = propagated.clamp(0,1)
+ vutils.save_image(transfer,'%s/%d_transfer.png'%(opt.outf,iteration))
+ vutils.save_image(propagated,'%s/%d_propagated.png'%(opt.outf,iteration))
+
+ if(iteration > 0 and (iteration) % opt.save_interval == 0):
+ torch.save(spn.state_dict(), '%s/%s_spn.pth' % (opt.outf,opt.layer))
diff --git a/graph_networks/LinearStyleTransfer/__init__.py b/graph_networks/LinearStyleTransfer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graph_networks/LinearStyleTransfer/libs/.DS_Store b/graph_networks/LinearStyleTransfer/libs/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..570ec19abcb7d1dfa0fd6f0f8eaac9d721f485e6
Binary files /dev/null and b/graph_networks/LinearStyleTransfer/libs/.DS_Store differ
diff --git a/graph_networks/LinearStyleTransfer/libs/Criterion.py b/graph_networks/LinearStyleTransfer/libs/Criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e287d77c891268ef7b2c7ef9a5b55f6303205c0
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/Criterion.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+
+class styleLoss(nn.Module):
+ def forward(self,input,target):
+ ib,ic,ih,iw = input.size()
+ iF = input.view(ib,ic,-1)
+ iMean = torch.mean(iF,dim=2)
+ iCov = GramMatrix()(input)
+
+ tb,tc,th,tw = target.size()
+ tF = target.view(tb,tc,-1)
+ tMean = torch.mean(tF,dim=2)
+ tCov = GramMatrix()(target)
+
+ loss = nn.MSELoss(size_average=False)(iMean,tMean) + nn.MSELoss(size_average=False)(iCov,tCov)
+ return loss/tb
+
+class GramMatrix(nn.Module):
+ def forward(self,input):
+ b, c, h, w = input.size()
+ f = input.view(b,c,h*w) # bxcx(hxw)
+ # torch.bmm(batch1, batch2, out=None) #
+ # batch1: bxmxp, batch2: bxpxn -> bxmxn #
+ G = torch.bmm(f,f.transpose(1,2)) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc
+ return G.div_(c*h*w)
+
+class LossCriterion(nn.Module):
+ def __init__(self,style_layers,content_layers,style_weight,content_weight):
+ super(LossCriterion,self).__init__()
+
+ self.style_layers = style_layers
+ self.content_layers = content_layers
+ self.style_weight = style_weight
+ self.content_weight = content_weight
+
+ self.styleLosses = [styleLoss()] * len(style_layers)
+ self.contentLosses = [nn.MSELoss()] * len(content_layers)
+
+ def forward(self,tF,sF,cF):
+ # content loss
+ totalContentLoss = 0
+ for i,layer in enumerate(self.content_layers):
+ cf_i = cF[layer]
+ cf_i = cf_i.detach()
+ tf_i = tF[layer]
+ loss_i = self.contentLosses[i]
+ totalContentLoss += loss_i(tf_i,cf_i)
+ totalContentLoss = totalContentLoss * self.content_weight
+
+ # style loss
+ totalStyleLoss = 0
+ for i,layer in enumerate(self.style_layers):
+ sf_i = sF[layer]
+ sf_i = sf_i.detach()
+ tf_i = tF[layer]
+ loss_i = self.styleLosses[i]
+ totalStyleLoss += loss_i(tf_i,sf_i)
+ totalStyleLoss = totalStyleLoss * self.style_weight
+ loss = totalStyleLoss + totalContentLoss
+
+ return loss,totalStyleLoss,totalContentLoss
diff --git a/graph_networks/LinearStyleTransfer/libs/Loader.py b/graph_networks/LinearStyleTransfer/libs/Loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..0616bc43835ebb112c9a87f5b206e3d03a5c5fe0
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/Loader.py
@@ -0,0 +1,44 @@
+import os
+from PIL import Image
+import torch.utils.data as data
+import torchvision.transforms as transforms
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])
+
+def default_loader(path):
+ return Image.open(path).convert('RGB')
+
+class Dataset(data.Dataset):
+ def __init__(self,dataPath,loadSize,fineSize,test=False,video=False):
+ super(Dataset,self).__init__()
+ self.dataPath = dataPath
+ self.image_list = [x for x in os.listdir(dataPath) if is_image_file(x)]
+ self.image_list = sorted(self.image_list)
+ if(video):
+ self.image_list = sorted(self.image_list)
+ if not test:
+ self.transform = transforms.Compose([
+ transforms.Resize(fineSize),
+ transforms.RandomCrop(fineSize),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor()])
+ else:
+ self.transform = transforms.Compose([
+ transforms.Resize(fineSize),
+ transforms.ToTensor()])
+
+ self.test = test
+
+ def __getitem__(self,index):
+ dataPath = os.path.join(self.dataPath,self.image_list[index])
+
+ Img = default_loader(dataPath)
+ ImgA = self.transform(Img)
+
+ imgName = self.image_list[index]
+ imgName = imgName.split('.')[0]
+ return ImgA,imgName
+
+ def __len__(self):
+ return len(self.image_list)
diff --git a/graph_networks/LinearStyleTransfer/libs/LoaderPhotoReal.py b/graph_networks/LinearStyleTransfer/libs/LoaderPhotoReal.py
new file mode 100644
index 0000000000000000000000000000000000000000..af1b8e07cff12ef8e72c76786d9d1714adda405a
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/LoaderPhotoReal.py
@@ -0,0 +1,162 @@
+from PIL import Image
+import torchvision.transforms as transforms
+import torchvision.utils as vutils
+import torch.utils.data as data
+from os import listdir
+from os.path import join
+import numpy as np
+import torch
+import os
+import torch.nn as nn
+from torch.autograd import Variable
+import numpy as np
+from libs.utils import whiten
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])
+
+def default_loader(path,fineSize):
+ img = Image.open(path).convert('RGB')
+ w,h = img.size
+ if(w < h):
+ neww = fineSize
+ newh = h * neww / w
+ newh = int(newh / 8) * 8
+ else:
+ newh = fineSize
+ neww = w * newh / h
+ neww = int(neww / 8) * 8
+ img = img.resize((neww,newh))
+ return img
+
+def MaskHelper(seg,color):
+ # green
+ mask = torch.Tensor()
+ if(color == 'green'):
+ mask = torch.lt(seg[0],0.1)
+ mask = torch.mul(mask,torch.gt(seg[1],1-0.1))
+ mask = torch.mul(mask,torch.lt(seg[2],0.1))
+ elif(color == 'black'):
+ mask = torch.lt(seg[0], 0.1)
+ mask = torch.mul(mask,torch.lt(seg[1], 0.1))
+ mask = torch.mul(mask,torch.lt(seg[2], 0.1))
+ elif(color == 'white'):
+ mask = torch.gt(seg[0], 1-0.1)
+ mask = torch.mul(mask,torch.gt(seg[1], 1-0.1))
+ mask = torch.mul(mask,torch.gt(seg[2], 1-0.1))
+ elif(color == 'red'):
+ mask = torch.gt(seg[0], 1-0.1)
+ mask = torch.mul(mask,torch.lt(seg[1], 0.1))
+ mask = torch.mul(mask,torch.lt(seg[2], 0.1))
+ elif(color == 'blue'):
+ mask = torch.lt(seg[0], 0.1)
+ mask = torch.mul(mask,torch.lt(seg[1], 0.1))
+ mask = torch.mul(mask,torch.gt(seg[2], 1-0.1))
+ elif(color == 'yellow'):
+ mask = torch.gt(seg[0], 1-0.1)
+ mask = torch.mul(mask,torch.gt(seg[1], 1-0.1))
+ mask = torch.mul(mask,torch.lt(seg[2], 0.1))
+ elif(color == 'grey'):
+ mask = torch.lt(seg[0], 0.1)
+ mask = torch.mul(mask,torch.lt(seg[1], 0.1))
+ mask = torch.mul(mask,torch.lt(seg[2], 0.1))
+ elif(color == 'lightblue'):
+ mask = torch.lt(seg[0], 0.1)
+ mask = torch.mul(mask,torch.gt(seg[1], 1-0.1))
+ mask = torch.mul(mask,torch.gt(seg[2], 1-0.1))
+ elif(color == 'purple'):
+ mask = torch.gt(seg[0], 1-0.1)
+ mask = torch.mul(mask,torch.lt(seg[1], 0.1))
+ mask = torch.mul(mask,torch.gt(seg[2], 1-0.1))
+ else:
+ print('MaskHelper(): color not recognized, color = ' + color)
+ return mask.float()
+
+def ExtractMask(Seg):
+ # Given segmentation for content and style, we get a list of segmentation for each color
+ '''
+ Test Code:
+ content_masks,style_masks = ExtractMask(contentSegImg,styleSegImg)
+ for i,mask in enumerate(content_masks):
+ vutils.save_image(mask,'samples/content_%d.png' % (i),normalize=True)
+ for i,mask in enumerate(style_masks):
+ vutils.save_image(mask,'samples/style_%d.png' % (i),normalize=True)
+ '''
+ color_codes = ['blue', 'green', 'black', 'white', 'red', 'yellow', 'grey', 'lightblue', 'purple']
+ masks = []
+ for color in color_codes:
+ mask = MaskHelper(Seg,color)
+ masks.append(mask)
+ return masks
+
+def calculate_size(h,w,fineSize):
+ if(h > w):
+ newh = fineSize
+ neww = int(w * 1.0 * newh / h)
+ else:
+ neww = fineSize
+ newh = int(h * 1.0 * neww / w)
+ newh = (newh // 8) * 8
+ neww = (neww // 8) * 8
+ return neww, newh
+
+class Dataset(data.Dataset):
+ def __init__(self,contentPath,stylePath,contentSegPath,styleSegPath,fineSize):
+ super(Dataset,self).__init__()
+ self.contentPath = contentPath
+ self.image_list = [x for x in listdir(contentPath) if is_image_file(x)]
+ self.stylePath = stylePath
+ self.contentSegPath = contentSegPath
+ self.styleSegPath = styleSegPath
+ self.fineSize = fineSize
+
+ def __getitem__(self,index):
+ contentImgPath = os.path.join(self.contentPath,self.image_list[index])
+ styleImgPath = os.path.join(self.stylePath,self.image_list[index])
+ contentImg = default_loader(contentImgPath,self.fineSize)
+ styleImg = default_loader(styleImgPath,self.fineSize)
+
+ try:
+ contentSegImgPath = os.path.join(self.contentSegPath,self.image_list[index])
+ contentSegImg = default_loader(contentSegImgPath,self.fineSize)
+ except :
+ print('no mask provided, fake a whole black one')
+ contentSegImg = Image.new('RGB', (contentImg.size))
+
+ try:
+ styleSegImgPath = os.path.join(self.styleSegPath,self.image_list[index])
+ styleSegImg = default_loader(styleSegImgPath,self.fineSize)
+ except :
+ print('no mask provided, fake a whole black one')
+ styleSegImg = Image.new('RGB', (styleImg.size))
+
+
+ hs, ws = styleImg.size
+ newhs, newws = calculate_size(hs,ws,self.fineSize)
+
+ transform = transforms.Compose([
+ transforms.Resize((newhs, newws)),
+ transforms.ToTensor()])
+ # Turning segmentation images into masks
+ styleSegImg = transform(styleSegImg)
+ styleImgArbi = transform(styleImg)
+
+ hc, wc = contentImg.size
+ newhc, newwc = calculate_size(hc,wc,self.fineSize)
+
+ transform = transforms.Compose([
+ transforms.Resize((newhc, newwc)),
+ transforms.ToTensor()])
+ contentSegImg = transform(contentSegImg)
+ contentImgArbi = transform(contentImg)
+
+ content_masks = ExtractMask(contentSegImg)
+ style_masks = ExtractMask(styleSegImg)
+
+ ImgW = whiten(contentImgArbi.view(3,-1).double())
+ ImgW = ImgW.view(contentImgArbi.size()).float()
+
+ return contentImgArbi.squeeze(0),styleImgArbi.squeeze(0),ImgW,content_masks,style_masks,self.image_list[index]
+
+ def __len__(self):
+ return len(self.image_list)
diff --git a/graph_networks/LinearStyleTransfer/libs/Matrix.py b/graph_networks/LinearStyleTransfer/libs/Matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..2deb029457e6f1536e843efa2386bdce3fe86ad5
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/Matrix.py
@@ -0,0 +1,89 @@
+import torch
+import torch.nn as nn
+
+class CNN(nn.Module):
+ def __init__(self,layer,matrixSize=32):
+ super(CNN,self).__init__()
+ if(layer == 'r31'):
+ # 256x64x64
+ self.convs = nn.Sequential(nn.Conv2d(256,128,3,1,1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128,64,3,1,1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(64,matrixSize,3,1,1))
+ elif(layer == 'r41'):
+ # 512x32x32
+ self.convs = nn.Sequential(nn.Conv2d(512,256,3,1,1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256,128,3,1,1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128,matrixSize,3,1,1))
+
+ # 32x8x8
+ self.fc = nn.Linear(matrixSize*matrixSize,matrixSize*matrixSize)
+ #self.fc = nn.Linear(32*64,256*256)
+
+ def forward(self,x):
+ out = self.convs(x)
+ # 32x8x8
+ b,c,h,w = out.size()
+ out = out.view(b,c,-1)
+ # 32x64
+ out = torch.bmm(out,out.transpose(1,2)).div(h*w)
+ # 32x32
+ out = out.view(out.size(0),-1)
+ return self.fc(out)
+
+class MulLayer(nn.Module):
+ def __init__(self,layer,matrixSize=32):
+ super(MulLayer,self).__init__()
+ self.snet = CNN(layer,matrixSize)
+ self.cnet = CNN(layer,matrixSize)
+ self.matrixSize = matrixSize
+
+ if(layer == 'r41'):
+ self.compress = nn.Conv2d(512,matrixSize,1,1,0)
+ self.unzip = nn.Conv2d(matrixSize,512,1,1,0)
+ elif(layer == 'r31'):
+ self.compress = nn.Conv2d(256,matrixSize,1,1,0)
+ self.unzip = nn.Conv2d(matrixSize,256,1,1,0)
+ self.transmatrix = None
+
+ def forward(self,cF,sF,trans=True):
+ cFBK = cF.clone()
+ cb,cc,ch,cw = cF.size()
+ cFF = cF.view(cb,cc,-1)
+ cMean = torch.mean(cFF,dim=2,keepdim=True)
+ cMean = cMean.unsqueeze(3)
+ cMean = cMean.expand_as(cF)
+ cF = cF - cMean
+
+ sb,sc,sh,sw = sF.size()
+ sFF = sF.view(sb,sc,-1)
+ sMean = torch.mean(sFF,dim=2,keepdim=True)
+ sMean = sMean.unsqueeze(3)
+ sMeanC = sMean.expand_as(cF)
+ sMeanS = sMean.expand_as(sF)
+ sF = sF - sMeanS
+
+
+ compress_content = self.compress(cF)
+ b,c,h,w = compress_content.size()
+ compress_content = compress_content.view(b,c,-1)
+
+ if(trans):
+ cMatrix = self.cnet(cF)
+ sMatrix = self.snet(sF)
+
+ sMatrix = sMatrix.view(sMatrix.size(0),self.matrixSize,self.matrixSize)
+ cMatrix = cMatrix.view(cMatrix.size(0),self.matrixSize,self.matrixSize)
+ transmatrix = torch.bmm(sMatrix,cMatrix)
+ print(cMatrix)
+ transfeature = torch.bmm(transmatrix,compress_content).view(b,c,h,w)
+ out = self.unzip(transfeature.view(b,c,h,w))
+ out = out + sMeanC
+ return out, transmatrix
+ else:
+ out = self.unzip(compress_content.view(b,c,h,w))
+ out = out + cMean
+ return out
diff --git a/graph_networks/LinearStyleTransfer/libs/MatrixTest.py b/graph_networks/LinearStyleTransfer/libs/MatrixTest.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5de4147ebda6c10a756566ccac485f0d83b0f1a
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/MatrixTest.py
@@ -0,0 +1,154 @@
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+import numpy as np
+import cv2
+from torch.autograd import Variable
+import torchvision.utils as vutils
+
+
+class CNN(nn.Module):
+ def __init__(self,layer,matrixSize=32):
+ super(CNN,self).__init__()
+ # 256x64x64
+ if(layer == 'r31'):
+ self.convs = nn.Sequential(nn.Conv2d(256,128,3,1,1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128,64,3,1,1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(64,matrixSize,3,1,1))
+ elif(layer == 'r41'):
+ # 512x32x32
+ self.convs = nn.Sequential(nn.Conv2d(512,256,3,1,1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256,128,3,1,1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128,matrixSize,3,1,1))
+ self.fc = nn.Linear(32*32,32*32)
+
+ def forward(self,x,masks,style=False):
+ color_code_number = 9
+ xb,xc,xh,xw = x.size()
+ x = x.view(xc,-1)
+ feature_sub_mean = x.clone()
+ for i in range(color_code_number):
+ mask = masks[i].clone().squeeze(0)
+ mask = cv2.resize(mask.numpy(),(xw,xh),interpolation=cv2.INTER_NEAREST)
+ mask = torch.FloatTensor(mask)
+ mask = mask.long()
+ if(torch.sum(mask) >= 10):
+ mask = mask.view(-1)
+
+ # dilation here
+ """
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(5,5))
+ mask = mask.cpu().numpy()
+ mask = cv2.dilate(mask.astype(np.float32), kernel)
+ mask = torch.from_numpy(mask)
+ mask = mask.squeeze()
+ """
+
+ fgmask = (mask>0).nonzero().squeeze(1)
+ fgmask = fgmask.cuda()
+ selectFeature = torch.index_select(x,1,fgmask) # 32x96
+ # subtract mean
+ f_mean = torch.mean(selectFeature,1)
+ f_mean = f_mean.unsqueeze(1).expand_as(selectFeature)
+ selectFeature = selectFeature - f_mean
+ feature_sub_mean.index_copy_(1,fgmask,selectFeature)
+
+ feature = self.convs(feature_sub_mean.view(xb,xc,xh,xw))
+ # 32x16x16
+ b,c,h,w = feature.size()
+ transMatrices = {}
+ feature = feature.view(c,-1)
+
+ for i in range(color_code_number):
+ mask = masks[i].clone().squeeze(0)
+ mask = cv2.resize(mask.numpy(),(w,h),interpolation=cv2.INTER_NEAREST)
+ mask = torch.FloatTensor(mask)
+ mask = mask.long()
+ if(torch.sum(mask) >= 10):
+ mask = mask.view(-1)
+ fgmask = Variable((mask==1).nonzero().squeeze(1))
+ fgmask = fgmask.cuda()
+ selectFeature = torch.index_select(feature,1,fgmask) # 32x96
+ tc,tN = selectFeature.size()
+
+ covMatrix = torch.mm(selectFeature,selectFeature.transpose(0,1)).div(tN)
+ transmatrix = self.fc(covMatrix.view(-1))
+ transMatrices[i] = transmatrix
+ return transMatrices,feature_sub_mean
+
+class MulLayer(nn.Module):
+ def __init__(self,layer,matrixSize=32):
+ super(MulLayer,self).__init__()
+ self.snet = CNN(layer)
+ self.cnet = CNN(layer)
+ self.matrixSize = matrixSize
+
+ if(layer == 'r41'):
+ self.compress = nn.Conv2d(512,matrixSize,1,1,0)
+ self.unzip = nn.Conv2d(matrixSize,512,1,1,0)
+ elif(layer == 'r31'):
+ self.compress = nn.Conv2d(256,matrixSize,1,1,0)
+ self.unzip = nn.Conv2d(matrixSize,256,1,1,0)
+
+ def forward(self,cF,sF,cmasks,smasks):
+
+ sb,sc,sh,sw = sF.size()
+
+ sMatrices,sF_sub_mean = self.snet(sF,smasks,style=True)
+ cMatrices,cF_sub_mean = self.cnet(cF,cmasks,style=False)
+
+ compress_content = self.compress(cF_sub_mean.view(cF.size()))
+ cb,cc,ch,cw = compress_content.size()
+ compress_content = compress_content.view(cc,-1)
+ transfeature = compress_content.clone()
+ color_code_number = 9
+ finalSMean = Variable(torch.zeros(cF.size()).cuda(0))
+ finalSMean = finalSMean.view(sc,-1)
+ for i in range(color_code_number):
+ cmask = cmasks[i].clone().squeeze(0)
+ smask = smasks[i].clone().squeeze(0)
+
+ cmask = cv2.resize(cmask.numpy(),(cw,ch),interpolation=cv2.INTER_NEAREST)
+ cmask = torch.FloatTensor(cmask)
+ cmask = cmask.long()
+ smask = cv2.resize(smask.numpy(),(sw,sh),interpolation=cv2.INTER_NEAREST)
+ smask = torch.FloatTensor(smask)
+ smask = smask.long()
+ if(torch.sum(cmask) >= 10 and torch.sum(smask) >= 10
+ and (i in sMatrices) and (i in cMatrices)):
+ cmask = cmask.view(-1)
+ fgcmask = Variable((cmask==1).nonzero().squeeze(1))
+ fgcmask = fgcmask.cuda()
+
+ smask = smask.view(-1)
+ fgsmask = Variable((smask==1).nonzero().squeeze(1))
+ fgsmask = fgsmask.cuda()
+
+ sFF = sF.view(sc,-1)
+ sFF_select = torch.index_select(sFF,1,fgsmask)
+ sMean = torch.mean(sFF_select,dim=1,keepdim=True)
+ sMean = sMean.view(1,sc,1,1)
+ sMean = sMean.expand_as(cF)
+
+ sMatrix = sMatrices[i]
+ cMatrix = cMatrices[i]
+
+ sMatrix = sMatrix.view(self.matrixSize,self.matrixSize)
+ cMatrix = cMatrix.view(self.matrixSize,self.matrixSize)
+
+ transmatrix = torch.mm(sMatrix,cMatrix) # (C*C)
+
+ compress_content_select = torch.index_select(compress_content,1,fgcmask)
+
+ transfeatureFG = torch.mm(transmatrix,compress_content_select)
+ transfeature.index_copy_(1,fgcmask,transfeatureFG)
+
+ sMean = sMean.contiguous()
+ sMean_select = torch.index_select(sMean.view(sc,-1),1,fgcmask)
+ finalSMean.index_copy_(1,fgcmask,sMean_select)
+ out = self.unzip(transfeature.view(cb,cc,ch,cw))
+ return out + finalSMean.view(out.size())
diff --git a/graph_networks/LinearStyleTransfer/libs/SPN.py b/graph_networks/LinearStyleTransfer/libs/SPN.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9bf7c39411ec132afaa50e4889d78f4580c3a08
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/SPN.py
@@ -0,0 +1,156 @@
+import torch
+import torch.nn as nn
+from torchvision.models import vgg16
+from torch.autograd import Variable
+from collections import OrderedDict
+import torch.nn.functional as F
+import sys
+sys.path.append('../')
+from libs.pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind
+
+class spn_block(nn.Module):
+ def __init__(self, horizontal, reverse):
+ super(spn_block, self).__init__()
+ self.propagator = GateRecurrent2dnoind(horizontal,reverse)
+
+ def forward(self,x,G1,G2,G3):
+ sum_abs = G1.abs() + G2.abs() + G3.abs()
+ sum_abs.data[sum_abs.data == 0] = 1e-6
+ mask_need_norm = sum_abs.ge(1)
+ mask_need_norm = mask_need_norm.float()
+ G1_norm = torch.div(G1, sum_abs)
+ G2_norm = torch.div(G2, sum_abs)
+ G3_norm = torch.div(G3, sum_abs)
+
+ G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm
+ G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm
+ G3 = torch.add(-mask_need_norm, 1) * G3 + mask_need_norm * G3_norm
+
+ return self.propagator(x,G1,G2,G3)
+
+class VGG(nn.Module):
+ def __init__(self,nf):
+ super(VGG,self).__init__()
+ self.conv1 = nn.Conv2d(3,nf,3,padding = 1)
+ # 256 x 256
+ self.pool1 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1)
+ self.conv2 = nn.Conv2d(nf,nf*2,3,padding = 1)
+ # 128 x 128
+ self.pool2 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1)
+ self.conv3 = nn.Conv2d(nf*2,nf*4,3,padding = 1)
+ # 64 x 64
+ self.pool3 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1)
+ # 32 x 32
+ self.conv4 = nn.Conv2d(nf*4,nf*8,3,padding = 1)
+
+ def forward(self,x):
+ output = {}
+ output['conv1'] = self.conv1(x)
+ x = F.relu(output['conv1'])
+ x = self.pool1(x)
+ output['conv2'] = self.conv2(x)
+ # 128 x 128
+ x = F.relu(output['conv2'])
+ x = self.pool2(x)
+ output['conv3'] = self.conv3(x)
+ # 64 x 64
+ x = F.relu(output['conv3'])
+ output['pool3'] = self.pool3(x)
+ # 32 x 32
+ output['conv4'] = self.conv4(output['pool3'])
+ return output
+
+class Decoder(nn.Module):
+ def __init__(self,nf=32,spn=1):
+ super(Decoder,self).__init__()
+ # 32 x 32
+ self.layer0 = nn.Conv2d(nf*8,nf*4,1,1,0) # edge_conv5
+ self.layer1 = nn.Upsample(scale_factor=2,mode='bilinear')
+ self.layer2 = nn.Sequential(nn.Conv2d(nf*4,nf*4,3,1,1), # edge_conv8
+ nn.ELU(inplace=True))
+ # 64 x 64
+ self.layer3 = nn.Upsample(scale_factor=2,mode='bilinear')
+ self.layer4 = nn.Sequential(nn.Conv2d(nf*4,nf*2,3,1,1), # edge_conv8
+ nn.ELU(inplace=True))
+ # 128 x 128
+ self.layer5 = nn.Upsample(scale_factor=2,mode='bilinear')
+ self.layer6 = nn.Sequential(nn.Conv2d(nf*2,nf,3,1,1), # edge_conv8
+ nn.ELU(inplace=True))
+ if(spn == 1):
+ self.layer7 = nn.Conv2d(nf,nf*12,3,1,1)
+ else:
+ self.layer7 = nn.Conv2d(nf,nf*24,3,1,1)
+ self.spn = spn
+ # 256 x 256
+
+ def forward(self,encode_feature):
+ output = {}
+ output['0'] = self.layer0(encode_feature['conv4'])
+ output['1'] = self.layer1(output['0'])
+
+ output['2'] = self.layer2(output['1'])
+ output['2res'] = output['2'] + encode_feature['conv3']
+ # 64 x 64
+
+ output['3'] = self.layer3(output['2res'])
+ output['4'] = self.layer4(output['3'])
+ output['4res'] = output['4'] + encode_feature['conv2']
+ # 128 x 128
+
+ output['5'] = self.layer5(output['4res'])
+ output['6'] = self.layer6(output['5'])
+ output['6res'] = output['6'] + encode_feature['conv1']
+
+ output['7'] = self.layer7(output['6res'])
+
+ return output['7']
+
+
+class SPN(nn.Module):
+ def __init__(self,nf=32,spn=1):
+ super(SPN,self).__init__()
+ # conv for mask
+ self.mask_conv = nn.Conv2d(3,nf,3,1,1)
+
+ # guidance network
+ self.encoder = VGG(nf)
+ self.decoder = Decoder(nf,spn)
+
+ # spn blocks
+ self.left_right = spn_block(True,False)
+ self.right_left = spn_block(True,True)
+ self.top_down = spn_block(False, False)
+ self.down_top = spn_block(False,True)
+
+ # post upsample
+ self.post = nn.Conv2d(nf,3,3,1,1)
+ self.nf = nf
+
+ def forward(self,x,rgb):
+ # feature for mask
+ X = self.mask_conv(x)
+
+ # guidance
+ features = self.encoder(rgb)
+ guide = self.decoder(features)
+
+ G = torch.split(guide,self.nf,1)
+ out1 = self.left_right(X,G[0],G[1],G[2])
+ out2 = self.right_left(X,G[3],G[4],G[5])
+ out3 = self.top_down(X,G[6],G[7],G[8])
+ out4 = self.down_top(X,G[9],G[10],G[11])
+
+ out = torch.max(out1,out2)
+ out = torch.max(out,out3)
+ out = torch.max(out,out4)
+
+ return self.post(out)
+
+if __name__ == '__main__':
+ spn = SPN()
+ spn = spn.cuda()
+ for i in range(100):
+ x = Variable(torch.Tensor(1,3,256,256)).cuda()
+ rgb = Variable(torch.Tensor(1,3,256,256)).cuda()
+ output = spn(x,rgb)
+ print(output.size())
diff --git a/graph_networks/LinearStyleTransfer/libs/__init__.py b/graph_networks/LinearStyleTransfer/libs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graph_networks/LinearStyleTransfer/libs/models.py b/graph_networks/LinearStyleTransfer/libs/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..964f273f2d28083b6aa15d09ed3876111c29a49d
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/models.py
@@ -0,0 +1,662 @@
+import torch
+import torch.nn as nn
+
+class encoder3(nn.Module):
+ def __init__(self):
+ super(encoder3,self).__init__()
+ # vgg
+ # 224 x 224
+ self.conv1 = nn.Conv2d(3,3,1,1,0)
+ self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
+ # 226 x 226
+
+ self.conv2 = nn.Conv2d(3,64,3,1,0)
+ self.relu2 = nn.ReLU(inplace=True)
+ # 224 x 224
+
+ self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv3 = nn.Conv2d(64,64,3,1,0)
+ self.relu3 = nn.ReLU(inplace=True)
+ # 224 x 224
+
+ self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
+ # 112 x 112
+
+ self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv4 = nn.Conv2d(64,128,3,1,0)
+ self.relu4 = nn.ReLU(inplace=True)
+ # 112 x 112
+
+ self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv5 = nn.Conv2d(128,128,3,1,0)
+ self.relu5 = nn.ReLU(inplace=True)
+ # 112 x 112
+
+ self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True)
+ # 56 x 56
+
+ self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv6 = nn.Conv2d(128,256,3,1,0)
+ self.relu6 = nn.ReLU(inplace=True)
+ # 56 x 56
+ def forward(self,x):
+ out = self.conv1(x)
+ out = self.reflecPad1(out)
+ out = self.conv2(out)
+ out = self.relu2(out)
+ out = self.reflecPad3(out)
+ out = self.conv3(out)
+ pool1 = self.relu3(out)
+ out,pool_idx = self.maxPool(pool1)
+ out = self.reflecPad4(out)
+ out = self.conv4(out)
+ out = self.relu4(out)
+ out = self.reflecPad5(out)
+ out = self.conv5(out)
+ pool2 = self.relu5(out)
+ out,pool_idx2 = self.maxPool2(pool2)
+ out = self.reflecPad6(out)
+ out = self.conv6(out)
+ out = self.relu6(out)
+ return out
+
+class decoder3(nn.Module):
+ def __init__(self):
+ super(decoder3,self).__init__()
+ # decoder
+ self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv7 = nn.Conv2d(256,128,3,1,0)
+ self.relu7 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
+ # 112 x 112
+
+ self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv8 = nn.Conv2d(128,128,3,1,0)
+ self.relu8 = nn.ReLU(inplace=True)
+ # 112 x 112
+
+ self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv9 = nn.Conv2d(128,64,3,1,0)
+ self.relu9 = nn.ReLU(inplace=True)
+
+ self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
+ # 224 x 224
+
+ self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv10 = nn.Conv2d(64,64,3,1,0)
+ self.relu10 = nn.ReLU(inplace=True)
+
+ self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv11 = nn.Conv2d(64,3,3,1,0)
+
+ def forward(self,x):
+ output = {}
+ out = self.reflecPad7(x)
+ out = self.conv7(out)
+ out = self.relu7(out)
+ out = self.unpool(out)
+ out = self.reflecPad8(out)
+ out = self.conv8(out)
+ out = self.relu8(out)
+ out = self.reflecPad9(out)
+ out = self.conv9(out)
+ out_relu9 = self.relu9(out)
+ out = self.unpool2(out_relu9)
+ out = self.reflecPad10(out)
+ out = self.conv10(out)
+ out = self.relu10(out)
+ out = self.reflecPad11(out)
+ out = self.conv11(out)
+ return out
+
+class encoder4(nn.Module):
+ def __init__(self):
+ super(encoder4,self).__init__()
+ # vgg
+ # 224 x 224
+ self.conv1 = nn.Conv2d(3,3,1,1,0)
+ self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
+ # 226 x 226
+
+ self.conv2 = nn.Conv2d(3,64,3,1,0)
+ self.relu2 = nn.ReLU(inplace=True)
+ # 224 x 224
+
+ self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv3 = nn.Conv2d(64,64,3,1,0)
+ self.relu3 = nn.ReLU(inplace=True)
+ # 224 x 224
+
+ self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
+ # 112 x 112
+
+ self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv4 = nn.Conv2d(64,128,3,1,0)
+ self.relu4 = nn.ReLU(inplace=True)
+ # 112 x 112
+
+ self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv5 = nn.Conv2d(128,128,3,1,0)
+ self.relu5 = nn.ReLU(inplace=True)
+ # 112 x 112
+
+ self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
+ # 56 x 56
+
+ self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv6 = nn.Conv2d(128,256,3,1,0)
+ self.relu6 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv7 = nn.Conv2d(256,256,3,1,0)
+ self.relu7 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv8 = nn.Conv2d(256,256,3,1,0)
+ self.relu8 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv9 = nn.Conv2d(256,256,3,1,0)
+ self.relu9 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
+ # 28 x 28
+
+ self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv10 = nn.Conv2d(256,512,3,1,0)
+ self.relu10 = nn.ReLU(inplace=True)
+ # 28 x 28
+ def forward(self,x,sF=None,matrix11=None,matrix21=None,matrix31=None):
+ output = {}
+ out = self.conv1(x)
+ out = self.reflecPad1(out)
+ out = self.conv2(out)
+ output['r11'] = self.relu2(out)
+ out = self.reflecPad7(output['r11'])
+
+ out = self.conv3(out)
+ output['r12'] = self.relu3(out)
+
+ output['p1'] = self.maxPool(output['r12'])
+ out = self.reflecPad4(output['p1'])
+ out = self.conv4(out)
+ output['r21'] = self.relu4(out)
+ out = self.reflecPad7(output['r21'])
+
+ out = self.conv5(out)
+ output['r22'] = self.relu5(out)
+
+ output['p2'] = self.maxPool2(output['r22'])
+ out = self.reflecPad6(output['p2'])
+ out = self.conv6(out)
+ output['r31'] = self.relu6(out)
+ if(matrix31 is not None):
+ feature3,transmatrix3 = matrix31(output['r31'],sF['r31'])
+ out = self.reflecPad7(feature3)
+ else:
+ out = self.reflecPad7(output['r31'])
+ out = self.conv7(out)
+ output['r32'] = self.relu7(out)
+
+ out = self.reflecPad8(output['r32'])
+ out = self.conv8(out)
+ output['r33'] = self.relu8(out)
+
+ out = self.reflecPad9(output['r33'])
+ out = self.conv9(out)
+ output['r34'] = self.relu9(out)
+
+ output['p3'] = self.maxPool3(output['r34'])
+ out = self.reflecPad10(output['p3'])
+ out = self.conv10(out)
+ output['r41'] = self.relu10(out)
+
+ return output
+
+class decoder4(nn.Module):
+ def __init__(self):
+ super(decoder4,self).__init__()
+ # decoder
+ self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv11 = nn.Conv2d(512,256,3,1,0)
+ self.relu11 = nn.ReLU(inplace=True)
+ # 28 x 28
+
+ self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
+ # 56 x 56
+
+ self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv12 = nn.Conv2d(256,256,3,1,0)
+ self.relu12 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv13 = nn.Conv2d(256,256,3,1,0)
+ self.relu13 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv14 = nn.Conv2d(256,256,3,1,0)
+ self.relu14 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv15 = nn.Conv2d(256,128,3,1,0)
+ self.relu15 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
+ # 112 x 112
+
+ self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv16 = nn.Conv2d(128,128,3,1,0)
+ self.relu16 = nn.ReLU(inplace=True)
+ # 112 x 112
+
+ self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv17 = nn.Conv2d(128,64,3,1,0)
+ self.relu17 = nn.ReLU(inplace=True)
+ # 112 x 112
+
+ self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
+ # 224 x 224
+
+ self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv18 = nn.Conv2d(64,64,3,1,0)
+ self.relu18 = nn.ReLU(inplace=True)
+ # 224 x 224
+
+ self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv19 = nn.Conv2d(64,3,3,1,0)
+
+ def forward(self,x):
+ # decoder
+ out = self.reflecPad11(x)
+ out = self.conv11(out)
+ out = self.relu11(out)
+ out = self.unpool(out)
+ out = self.reflecPad12(out)
+ out = self.conv12(out)
+
+ out = self.relu12(out)
+ out = self.reflecPad13(out)
+ out = self.conv13(out)
+ out = self.relu13(out)
+ out = self.reflecPad14(out)
+ out = self.conv14(out)
+ out = self.relu14(out)
+ out = self.reflecPad15(out)
+ out = self.conv15(out)
+ out = self.relu15(out)
+ out = self.unpool2(out)
+ out = self.reflecPad16(out)
+ out = self.conv16(out)
+ out = self.relu16(out)
+ out = self.reflecPad17(out)
+ out = self.conv17(out)
+ out = self.relu17(out)
+ out = self.unpool3(out)
+ out = self.reflecPad18(out)
+ out = self.conv18(out)
+ out = self.relu18(out)
+ out = self.reflecPad19(out)
+ out = self.conv19(out)
+ return out
+
+class decoder4(nn.Module):
+ def __init__(self):
+ super(decoder4,self).__init__()
+ # decoder
+ self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv11 = nn.Conv2d(512,256,3,1,0)
+ self.relu11 = nn.ReLU(inplace=True)
+ # 28 x 28
+
+ self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
+ # 56 x 56
+
+ self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv12 = nn.Conv2d(256,256,3,1,0)
+ self.relu12 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv13 = nn.Conv2d(256,256,3,1,0)
+ self.relu13 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv14 = nn.Conv2d(256,256,3,1,0)
+ self.relu14 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv15 = nn.Conv2d(256,128,3,1,0)
+ self.relu15 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
+ # 112 x 112
+
+ self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv16 = nn.Conv2d(128,128,3,1,0)
+ self.relu16 = nn.ReLU(inplace=True)
+ # 112 x 112
+
+ self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv17 = nn.Conv2d(128,64,3,1,0)
+ self.relu17 = nn.ReLU(inplace=True)
+ # 112 x 112
+
+ self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
+ # 224 x 224
+
+ self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv18 = nn.Conv2d(64,64,3,1,0)
+ self.relu18 = nn.ReLU(inplace=True)
+ # 224 x 224
+
+ self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv19 = nn.Conv2d(64,3,3,1,0)
+
+ def forward(self,x):
+ # decoder
+ out = self.reflecPad11(x)
+ out = self.conv11(out)
+ out = self.relu11(out)
+ out = self.unpool(out)
+ out = self.reflecPad12(out)
+ out = self.conv12(out)
+
+ out = self.relu12(out)
+ out = self.reflecPad13(out)
+ out = self.conv13(out)
+ out = self.relu13(out)
+ out = self.reflecPad14(out)
+ out = self.conv14(out)
+ out = self.relu14(out)
+ out = self.reflecPad15(out)
+ out = self.conv15(out)
+ out = self.relu15(out)
+ out = self.unpool2(out)
+ out = self.reflecPad16(out)
+ out = self.conv16(out)
+ out = self.relu16(out)
+ out = self.reflecPad17(out)
+ out = self.conv17(out)
+ out = self.relu17(out)
+ out = self.unpool3(out)
+ out = self.reflecPad18(out)
+ out = self.conv18(out)
+ out = self.relu18(out)
+ out = self.reflecPad19(out)
+ out = self.conv19(out)
+ return out
+
+class encoder5(nn.Module):
+ def __init__(self):
+ super(encoder5,self).__init__()
+ # vgg
+ # 224 x 224
+ self.conv1 = nn.Conv2d(3,3,1,1,0)
+ self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1))
+ # 226 x 226
+
+ self.conv2 = nn.Conv2d(3,64,3,1,0)
+ self.relu2 = nn.ReLU(inplace=True)
+ # 224 x 224
+
+ self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv3 = nn.Conv2d(64,64,3,1,0)
+ self.relu3 = nn.ReLU(inplace=True)
+ # 224 x 224
+
+ self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2)
+ # 112 x 112
+
+ self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv4 = nn.Conv2d(64,128,3,1,0)
+ self.relu4 = nn.ReLU(inplace=True)
+ # 112 x 112
+
+ self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv5 = nn.Conv2d(128,128,3,1,0)
+ self.relu5 = nn.ReLU(inplace=True)
+ # 112 x 112
+
+ self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2)
+ # 56 x 56
+
+ self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv6 = nn.Conv2d(128,256,3,1,0)
+ self.relu6 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv7 = nn.Conv2d(256,256,3,1,0)
+ self.relu7 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv8 = nn.Conv2d(256,256,3,1,0)
+ self.relu8 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv9 = nn.Conv2d(256,256,3,1,0)
+ self.relu9 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2)
+ # 28 x 28
+
+ self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv10 = nn.Conv2d(256,512,3,1,0)
+ self.relu10 = nn.ReLU(inplace=True)
+
+ self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv11 = nn.Conv2d(512,512,3,1,0)
+ self.relu11 = nn.ReLU(inplace=True)
+
+ self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv12 = nn.Conv2d(512,512,3,1,0)
+ self.relu12 = nn.ReLU(inplace=True)
+
+ self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv13 = nn.Conv2d(512,512,3,1,0)
+ self.relu13 = nn.ReLU(inplace=True)
+
+ self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2)
+ self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv14 = nn.Conv2d(512,512,3,1,0)
+ self.relu14 = nn.ReLU(inplace=True)
+
+ def forward(self,x,sF=None,contentV256=None,styleV256=None,matrix11=None,matrix21=None,matrix31=None):
+ output = {}
+ out = self.conv1(x)
+ out = self.reflecPad1(out)
+ out = self.conv2(out)
+ output['r11'] = self.relu2(out)
+ out = self.reflecPad7(output['r11'])
+
+ #out = self.reflecPad3(output['r11'])
+ out = self.conv3(out)
+ output['r12'] = self.relu3(out)
+
+ output['p1'] = self.maxPool(output['r12'])
+ out = self.reflecPad4(output['p1'])
+ out = self.conv4(out)
+ output['r21'] = self.relu4(out)
+ out = self.reflecPad7(output['r21'])
+
+ #out = self.reflecPad5(output['r21'])
+ out = self.conv5(out)
+ output['r22'] = self.relu5(out)
+
+ output['p2'] = self.maxPool2(output['r22'])
+ out = self.reflecPad6(output['p2'])
+ out = self.conv6(out)
+ output['r31'] = self.relu6(out)
+ if(styleV256 is not None):
+ feature = matrix31(output['r31'],sF['r31'],contentV256,styleV256)
+ out = self.reflecPad7(feature)
+ else:
+ out = self.reflecPad7(output['r31'])
+ out = self.conv7(out)
+ output['r32'] = self.relu7(out)
+
+ out = self.reflecPad8(output['r32'])
+ out = self.conv8(out)
+ output['r33'] = self.relu8(out)
+
+ out = self.reflecPad9(output['r33'])
+ out = self.conv9(out)
+ output['r34'] = self.relu9(out)
+
+ output['p3'] = self.maxPool3(output['r34'])
+ out = self.reflecPad10(output['p3'])
+ out = self.conv10(out)
+ output['r41'] = self.relu10(out)
+
+ out = self.reflecPad11(output['r41'])
+ out = self.conv11(out)
+ output['r42'] = self.relu11(out)
+
+ out = self.reflecPad12(output['r42'])
+ out = self.conv12(out)
+ output['r43'] = self.relu12(out)
+
+ out = self.reflecPad13(output['r43'])
+ out = self.conv13(out)
+ output['r44'] = self.relu13(out)
+
+ output['p4'] = self.maxPool4(output['r44'])
+
+ out = self.reflecPad14(output['p4'])
+ out = self.conv14(out)
+ output['r51'] = self.relu14(out)
+ return output
+
+class decoder5(nn.Module):
+ def __init__(self):
+ super(decoder5,self).__init__()
+
+ # decoder
+ self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv15 = nn.Conv2d(512,512,3,1,0)
+ self.relu15 = nn.ReLU(inplace=True)
+
+ self.unpool = nn.UpsamplingNearest2d(scale_factor=2)
+ # 28 x 28
+
+ self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv16 = nn.Conv2d(512,512,3,1,0)
+ self.relu16 = nn.ReLU(inplace=True)
+ # 28 x 28
+
+ self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv17 = nn.Conv2d(512,512,3,1,0)
+ self.relu17 = nn.ReLU(inplace=True)
+ # 28 x 28
+
+ self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv18 = nn.Conv2d(512,512,3,1,0)
+ self.relu18 = nn.ReLU(inplace=True)
+ # 28 x 28
+
+ self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv19 = nn.Conv2d(512,256,3,1,0)
+ self.relu19 = nn.ReLU(inplace=True)
+ # 28 x 28
+
+ self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2)
+ # 56 x 56
+
+ self.reflecPad20 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv20 = nn.Conv2d(256,256,3,1,0)
+ self.relu20 = nn.ReLU(inplace=True)
+ # 56 x 56
+
+ self.reflecPad21 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv21 = nn.Conv2d(256,256,3,1,0)
+ self.relu21 = nn.ReLU(inplace=True)
+
+ self.reflecPad22 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv22 = nn.Conv2d(256,256,3,1,0)
+ self.relu22 = nn.ReLU(inplace=True)
+
+ self.reflecPad23 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv23 = nn.Conv2d(256,128,3,1,0)
+ self.relu23 = nn.ReLU(inplace=True)
+
+ self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2)
+ # 112 X 112
+
+ self.reflecPad24 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv24 = nn.Conv2d(128,128,3,1,0)
+ self.relu24 = nn.ReLU(inplace=True)
+
+ self.reflecPad25 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv25 = nn.Conv2d(128,64,3,1,0)
+ self.relu25 = nn.ReLU(inplace=True)
+
+ self.unpool4 = nn.UpsamplingNearest2d(scale_factor=2)
+
+ self.reflecPad26 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv26 = nn.Conv2d(64,64,3,1,0)
+ self.relu26 = nn.ReLU(inplace=True)
+
+ self.reflecPad27 = nn.ReflectionPad2d((1,1,1,1))
+ self.conv27 = nn.Conv2d(64,3,3,1,0)
+
+ def forward(self,x):
+ # decoder
+ out = self.reflecPad15(x)
+ out = self.conv15(out)
+ out = self.relu15(out)
+ out = self.unpool(out)
+ out = self.reflecPad16(out)
+ out = self.conv16(out)
+ out = self.relu16(out)
+ out = self.reflecPad17(out)
+ out = self.conv17(out)
+ out = self.relu17(out)
+ out = self.reflecPad18(out)
+ out = self.conv18(out)
+ out = self.relu18(out)
+ out = self.reflecPad19(out)
+ out = self.conv19(out)
+ out = self.relu19(out)
+ out = self.unpool2(out)
+ out = self.reflecPad20(out)
+ out = self.conv20(out)
+ out = self.relu20(out)
+ out = self.reflecPad21(out)
+ out = self.conv21(out)
+ out = self.relu21(out)
+ out = self.reflecPad22(out)
+ out = self.conv22(out)
+ out = self.relu22(out)
+ out = self.reflecPad23(out)
+ out = self.conv23(out)
+ out = self.relu23(out)
+ out = self.unpool3(out)
+ out = self.reflecPad24(out)
+ out = self.conv24(out)
+ out = self.relu24(out)
+ out = self.reflecPad25(out)
+ out = self.conv25(out)
+ out = self.relu25(out)
+ out = self.unpool4(out)
+ out = self.reflecPad26(out)
+ out = self.conv26(out)
+ out = self.relu26(out)
+ out = self.reflecPad27(out)
+ out = self.conv27(out)
+ return out
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/README.md b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f3c3c7e83fb5e80b0a63d14d4e546772f49cc1e6
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/README.md
@@ -0,0 +1,12 @@
+# pytorch_spn
+To build, install [pytorch](https://github.com/pytorch) and run:
+
+$ sh make.sh
+
+See left_right_demo.py for usage:
+
+$ mv left_right_demo.py ../
+
+$ python left_right_demo.py
+
+The original codes (caffe) and models will be relesed [HERE](https://github.com/Liusifei/caffe-spn.git).
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/__init__.py b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/__init__.py b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/gaterecurrent2dnoind/__init__.py b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/gaterecurrent2dnoind/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cab752cdaacefb3a3a2987dd235f75daf3df150d
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/gaterecurrent2dnoind/__init__.py
@@ -0,0 +1,15 @@
+
+from torch.utils.ffi import _wrap_function
+from ._gaterecurrent2dnoind import lib as _lib, ffi as _ffi
+
+__all__ = []
+def _import_symbols(locals):
+ for symbol in dir(_lib):
+ fn = getattr(_lib, symbol)
+ if callable(fn):
+ locals[symbol] = _wrap_function(fn, _ffi)
+ else:
+ locals[symbol] = fn
+ __all__.append(symbol)
+
+_import_symbols(locals())
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/build.py b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..70aecaecdc5b935cf2addd4e80786ee922d8d259
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/build.py
@@ -0,0 +1,34 @@
+import os
+import torch
+from torch.utils.ffi import create_extension
+
+this_file = os.path.dirname(__file__)
+
+sources = []
+headers = []
+defines = []
+with_cuda = False
+
+if torch.cuda.is_available():
+ print('Including CUDA code.')
+ sources += ['src/gaterecurrent2dnoind_cuda.c']
+ headers += ['src/gaterecurrent2dnoind_cuda.h']
+ defines += [('WITH_CUDA', None)]
+ with_cuda = True
+
+this_file = os.path.dirname(os.path.realpath(__file__))
+extra_objects = ['src/cuda/gaterecurrent2dnoind_kernel.cu.o']
+extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
+
+ffi = create_extension(
+ '_ext.gaterecurrent2dnoind',
+ headers=headers,
+ sources=sources,
+ define_macros=defines,
+ relative_to=__file__,
+ with_cuda=with_cuda,
+ extra_objects=extra_objects
+)
+
+if __name__ == '__main__':
+ ffi.build()
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/__init__.py b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/gaterecurrent2dnoind.py b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/gaterecurrent2dnoind.py
new file mode 100644
index 0000000000000000000000000000000000000000..972b726ce8093322c22ecd54aa0460c59273a985
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/gaterecurrent2dnoind.py
@@ -0,0 +1,47 @@
+import torch
+from torch.autograd import Function
+from .._ext import gaterecurrent2dnoind as gaterecurrent2d
+
+class GateRecurrent2dnoindFunction(Function):
+ def __init__(self, horizontal_, reverse_):
+ self.horizontal = horizontal_
+ self.reverse = reverse_
+
+ def forward(self, X, G1, G2, G3):
+ num, channels, height, width = X.size()
+ output = torch.zeros(num, channels, height, width)
+
+ if not X.is_cuda:
+ print("cpu version is not ready at this time")
+ return 0
+ else:
+ output = output.cuda()
+ gaterecurrent2d.gaterecurrent2dnoind_forward_cuda(self.horizontal,self.reverse, X, G1, G2, G3, output)
+
+ self.X = X
+ self.G1 = G1
+ self.G2 = G2
+ self.G3 = G3
+ self.output = output
+ self.hiddensize = X.size()
+ return output
+
+ def backward(self, grad_output):
+ assert(self.hiddensize is not None and grad_output.is_cuda)
+ num, channels, height, width = self.hiddensize
+
+ grad_X = torch.zeros(num, channels, height, width).cuda()
+ grad_G1 = torch.zeros(num, channels, height, width).cuda()
+ grad_G2 = torch.zeros(num, channels, height, width).cuda()
+ grad_G3 = torch.zeros(num, channels, height, width).cuda()
+
+ gaterecurrent2d.gaterecurrent2dnoind_backward_cuda(self.horizontal, self.reverse, self.output, grad_output, self.X, self.G1, self.G2, self.G3, grad_X, grad_G1, grad_G2, grad_G3)
+
+ del self.hiddensize
+ del self.G1
+ del self.G2
+ del self.G3
+ del self.output
+ del self.X
+
+ return grad_X, grad_G1, grad_G2, grad_G3
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/left_right_demo.py b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/left_right_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c8d58ebbafe29ed90e9f8530c68402156ecce74
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/left_right_demo.py
@@ -0,0 +1,46 @@
+"""
+An example of left->right propagation
+
+Other direction settings:
+left->right: Propagator = GateRecurrent2dnoind(True,False)
+right->left: Propagator = GateRecurrent2dnoind(True,True)
+top->bottom: Propagator = GateRecurrent2dnoind(False,False)
+bottom->top: Propagator = GateRecurrent2dnoind(False,True)
+
+X: any signal/feature map to be filtered
+G1~G3: three coefficient maps (e.g., left-top, left-center, left-bottom)
+
+Note:
+1. G1~G3 constitute the affinity, they can be a bounch of output maps coming from any CNN, with the input of any useful known information (e.g., RGB images)
+2. for any pixel i, |G1(i)| + |G2(i)| + |G3(i)| <= 1 is a sufficent condition for model stability (see paper)
+"""
+import torch
+from torch.autograd import Variable
+from pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind
+
+Propagator = GateRecurrent2dnoind(True,False)
+
+X = Variable(torch.randn(1,3,10,10))
+G1 = Variable(torch.randn(1,3,10,10))
+G2 = Variable(torch.randn(1,3,10,10))
+G3 = Variable(torch.randn(1,3,10,10))
+
+sum_abs = G1.abs() + G2.abs() + G3.abs()
+mask_need_norm = sum_abs.ge(1)
+mask_need_norm = mask_need_norm.float()
+G1_norm = torch.div(G1, sum_abs)
+G2_norm = torch.div(G2, sum_abs)
+G3_norm = torch.div(G3, sum_abs)
+
+G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm
+G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm
+G3 = torch.add(-mask_need_norm, 1) * G3 + mask_need_norm * G3_norm
+
+X = X.cuda()
+G1 = G1.cuda()
+G2 = G2.cuda()
+G3 = G3.cuda()
+
+output = Propagator.forward(X,G1,G2,G3)
+print(X)
+print(output)
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/make.sh b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/make.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7bafe812c4c8f94986f42e71fcf4f8ca6aa72904
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/make.sh
@@ -0,0 +1,9 @@
+#!/usr/bin/env bash
+
+CUDA_PATH=/usr/local/cuda/
+
+cd src/cuda/
+echo "Compiling gaterecurrent2dnoind layer kernels by nvcc..."
+nvcc -c -o gaterecurrent2dnoind_kernel.cu.o gaterecurrent2dnoind_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
+cd ../../
+python build.py
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/__init__.py b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/__init__.py
@@ -0,0 +1 @@
+
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/gaterecurrent2dnoind.py b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/gaterecurrent2dnoind.py
new file mode 100644
index 0000000000000000000000000000000000000000..46ffe82da20aea4546b13fd2b0bdd6ac7a548192
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/gaterecurrent2dnoind.py
@@ -0,0 +1,12 @@
+import torch.nn as nn
+from ..functions.gaterecurrent2dnoind import GateRecurrent2dnoindFunction
+
+class GateRecurrent2dnoind(nn.Module):
+ """docstring for ."""
+ def __init__(self, horizontal_, reverse_):
+ super(GateRecurrent2dnoind, self).__init__()
+ self.horizontal = horizontal_
+ self.reverse = reverse_
+
+ def forward(self, X, G1, G2, G3):
+ return GateRecurrent2dnoindFunction(self.horizontal, self.reverse)(X, G1, G2, G3)
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/.DS_Store b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..6c7b6212579ceb149ee62bdbb280ee4aa553a616
Binary files /dev/null and b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/.DS_Store differ
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ddad6f59cbbf7675a42f0679fe236b8745e5f1e9
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu
@@ -0,0 +1,697 @@
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include
+#include
+#include
+#include "gaterecurrent2dnoind_kernel.h"
+
+#define CUDA_1D_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
+ i += blockDim.x * gridDim.x)
+
+__device__ void get_gate_idx_sf(int h1, int w1, int h2, int w2, int * out, int horizontal, int reverse)
+{
+ if(horizontal && ! reverse) // left -> right
+ {
+ if(w1>w2)
+ {
+ out[0]=h1;
+ out[1]=w1;
+ }
+ else
+ {
+ out[0]=h2;
+ out[1]=w2;
+ }
+ }
+ if(horizontal && reverse) // right -> left
+ {
+ if(w1 bottom
+ {
+ if(h1>h2)
+ {
+ out[0]=h1;
+ out[1]=w1;
+ }
+ else
+ {
+ out[0]=h2;
+ out[1]=w2;
+ }
+ }
+ if(!horizontal && reverse) // bottom -> top
+ {
+ if(h1=height)
+ return 0;
+ if(w<0 || w >= width)
+ return 0;
+
+ return data[n*channels*height*width + c * height*width + h * width + w];
+}
+
+__device__ void set_data_sf(float * data, int num, int channels,int height, int width,int n,int c,int h,int w, float v)
+{
+ if(h<0 || h >=height)
+ return ;
+ if(w<0 || w >= width)
+ return ;
+
+ data[n*channels*height*width + c * height*width + h * width + w]=v;
+}
+
+__device__ float get_gate_sf(float * data, int num, int channels,int height, int width,int n,int c,int h1,int w1,int h2,int w2,int horizontal,int reverse)
+{
+ if(h1<0 || h1 >=height)
+ return 0;
+ if(w1<0 || w1 >= width)
+ return 0;
+ if(h2<0 || h2 >=height)
+ return 0;
+ if(w2<0 || w2 >= width)
+ return 0;
+ int idx[2];
+
+ get_gate_idx_sf(h1,w1,h2,w2, idx,horizontal, reverse);
+
+ int h = idx[0];
+ int w = idx[1];
+
+ return data[n*channels*height*width + c * height*width + h * width + w];
+}
+
+__device__ void set_gate_sf(float * data, int num, int channels,int height, int width,int n,int c,int h1,int w1,int h2,int w2,int horizontal,int reverse, float v)
+{
+ if(h1<0 || h1 >=height)
+ return ;
+ if(w1<0 || w1 >= width)
+ return ;
+ if(h2<0 || h2 >=height)
+ return ;
+ if(w2<0 || w2 >= width)
+ return ;
+ int idx[2];
+
+ get_gate_idx_sf(h1,w1,h2,w2, idx,horizontal, reverse);
+
+ int h = idx[0];
+ int w = idx[1];
+
+ data[n*channels*height*width + c * height*width + h * width + w]=v;
+}
+
+// we do not use set_gate_add_sf(...) in the caffe implimentation
+// avoid using atomicAdd
+
+__global__ void forward_one_col_left_right( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, int horizontal, int reverse) {
+ CUDA_1D_KERNEL_LOOP(index, count) {
+
+ int hc_count = height * channels;
+
+ int n,c,h,w;
+ int temp=index;
+ w = T;
+ n = temp / hc_count;
+ temp = temp % hc_count;
+ c = temp / height;
+ temp = temp % height;
+ h = temp;
+
+
+ float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
+
+ float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse);
+ float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w-1);
+ float h1_minus1 = g_data_1 * h_minus1_data_1;
+
+ float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w-1,horizontal,reverse);
+ float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h,w-1);
+ float h2_minus1 = g_data_2 * h_minus1_data_2;
+
+ float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse);
+ float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w-1);
+ float h3_minus1 = g_data_3 * h_minus1_data_3;
+
+ float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
+ float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;
+
+ float h_data = x_hype + h_hype;
+
+ set_data_sf(H,num,channels,height,width,n,c,h,w,h_data);
+
+ }
+}
+
+__global__ void forward_one_col_right_left( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) {
+ CUDA_1D_KERNEL_LOOP(index, count) {
+
+ int hc_count = height * channels;
+ int n,c,h,w;
+ int temp=index;
+ w = T;
+ n = temp / hc_count;
+ temp = temp % hc_count;
+ c = temp / height;
+ temp = temp % height;
+ h = temp;
+
+ float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
+
+ float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse);
+ float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w+1);
+ float h1_minus1 = g_data_1 * h_minus1_data_1;
+
+ float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w+1,horizontal,reverse);
+ float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h,w+1);
+ float h2_minus1 = g_data_2 * h_minus1_data_2;
+
+ float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse);
+ float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w+1);
+ float h3_minus1 = g_data_3 * h_minus1_data_3;
+
+ float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
+ float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;
+
+ float h_data = x_hype + h_hype;
+
+ set_data_sf(H,num,channels,height,width,n,c,h,w,h_data);
+
+ }
+}
+
+__global__ void forward_one_row_top_bottom( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) {
+ CUDA_1D_KERNEL_LOOP(index, count) {
+
+ int wc_count = width * channels;
+
+ int n,c,h,w;
+ int temp=index;
+ h = T;
+ n = temp / wc_count;
+ temp = temp % wc_count;
+ c = temp / width;
+ temp = temp % width;
+ w = temp;
+
+
+ float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
+
+
+ float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse);
+ float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w-1);
+ float h1_minus1 = g_data_1 * h_minus1_data_1;
+
+ float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h-1,w,horizontal,reverse);
+ float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h-1,w);
+ float h2_minus1 = g_data_2 * h_minus1_data_2;
+
+ float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse);
+ float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h-1,w+1);
+ float h3_minus1 = g_data_3 * h_minus1_data_3;
+
+ float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
+ float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;
+
+ float h_data = x_hype + h_hype;
+
+ set_data_sf(H,num,channels,height,width,n,c,h,w,h_data);
+
+ }
+}
+
+__global__ void forward_one_row_bottom_top( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) {
+ CUDA_1D_KERNEL_LOOP(index, count) {
+
+ int wc_count = width * channels;
+
+ int n,c,h,w;
+ int temp=index;
+ h = T;
+ n = temp / wc_count;
+ temp = temp % wc_count;
+ c = temp / width;
+ temp = temp % width;
+ w = temp;
+
+
+ float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
+
+
+ float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse);
+ float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h+1,w-1);
+ float h1_minus1 = g_data_1 * h_minus1_data_1;
+
+
+ float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h+1,w,horizontal,reverse);
+ float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h+1,w);
+ float h2_minus1 = g_data_2 * h_minus1_data_2;
+
+ float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse);
+ float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w+1);
+ float h3_minus1 = g_data_3 * h_minus1_data_3;
+
+ float h_hype = h1_minus1 + h2_minus1 + h3_minus1;
+ float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data;
+
+ float h_data = x_hype + h_hype;
+
+ set_data_sf(H,num,channels,height,width,n,c,h,w,h_data);
+
+ }
+}
+
+
+__global__ void backward_one_col_left_right( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, float * X_diff, float * G1_diff,float* G2_diff,float * G3_diff, float * Hdiff,int horizontal,int reverse) {
+ CUDA_1D_KERNEL_LOOP(index, count) {
+
+ int hc_count = height * channels;
+
+ int n,c,h,w;
+ int temp=index;
+
+ w = T;
+ n = temp / hc_count;
+ temp = temp % hc_count;
+ c = temp / height;
+ temp = temp % height;
+ h = temp;
+
+ float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w);
+
+ //h(t)_diff = top(t)_diff
+ float h_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h,w);
+
+ //h(t)_diff += h(t+1)_diff * g(t+1) if t>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
+
+ err = cudaGetLastError();
+ if(cudaSuccess != err)
+ {
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
+ exit( -1 );
+ }
+ }
+ return 1;
+}
+
+int Forward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
+{
+ int count = height_ * channels_ * num_;
+ int kThreadsPerBlock = 1024;
+ cudaError_t err;
+
+ for(int t = width_ - 1; t >= 0; t--) {
+ forward_one_col_right_left<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
+
+ err = cudaGetLastError();
+ if(cudaSuccess != err)
+ {
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
+ exit( -1 );
+ }
+ }
+ return 1;
+}
+
+int Forward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
+{
+ int count = width_ * channels_ * num_;
+ int kThreadsPerBlock = 1024;
+ cudaError_t err;
+
+ for(int t=0; t< height_; t++) {
+ forward_one_row_top_bottom<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
+
+ err = cudaGetLastError();
+ if(cudaSuccess != err)
+ {
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
+ exit( -1 );
+ }
+ }
+ return 1;
+}
+
+int Forward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream)
+{
+ int count = width_ * channels_ * num_;
+ int kThreadsPerBlock = 1024;
+ cudaError_t err;
+
+ for(int t = height_-1; t >= 0; t--) {
+ forward_one_row_bottom_top<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_);
+
+ err = cudaGetLastError();
+ if(cudaSuccess != err)
+ {
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
+ exit( -1 );
+ }
+ }
+ return 1;
+}
+
+int Backward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
+{
+ int count = height_ * channels_ * num_;
+ int kThreadsPerBlock = 1024;
+ cudaError_t err;
+
+ for(int t = width_ -1; t>=0; t--)
+ {
+ backward_one_col_left_right<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
+
+ err = cudaGetLastError();
+ if(cudaSuccess != err)
+ {
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
+ exit( -1 );
+ }
+ }
+ return 1;
+}
+
+int Backward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
+{
+ int count = height_ * channels_ * num_;
+ int kThreadsPerBlock = 1024;
+ cudaError_t err;
+
+ for(int t = 0; t>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
+
+ err = cudaGetLastError();
+ if(cudaSuccess != err)
+ {
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
+ exit( -1 );
+ }
+ }
+ return 1;
+}
+
+int Backward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
+{
+ int count = width_ * channels_ * num_;
+ int kThreadsPerBlock = 1024;
+ cudaError_t err;
+
+ for(int t = height_-1; t>=0; t--)
+ {
+ backward_one_row_top_bottom<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
+
+ err = cudaGetLastError();
+ if(cudaSuccess != err)
+ {
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
+ exit( -1 );
+ }
+ }
+ return 1;
+}
+
+int Backward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream)
+{
+ int count = width_ * channels_ * num_;
+ int kThreadsPerBlock = 1024;
+ cudaError_t err;
+
+ for(int t = 0; t>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_);
+
+ err = cudaGetLastError();
+ if(cudaSuccess != err)
+ {
+ fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
+ exit( -1 );
+ }
+ }
+ return 1;
+}
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu.o b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu.o
new file mode 100644
index 0000000000000000000000000000000000000000..5d12c0a4eb2089d47b24bb29aec8ded0b91f0517
Binary files /dev/null and b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu.o differ
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.h b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..ebfadd16a03ef4c8d34365940ce034ee0caea1fe
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.h
@@ -0,0 +1,28 @@
+#ifndef _GATERECURRENT2DNOIND_KERNEL
+#define _GATERECURRENT2DNOIND_KERNEL
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+int Forward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
+
+int Forward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
+
+int Forward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
+
+int Forward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream);
+
+int Backward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
+
+int Backward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
+
+int Backward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
+
+int Backward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.c b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.c
new file mode 100644
index 0000000000000000000000000000000000000000..2211c588e4a264c22b1f9e47cde53f3d986c14a3
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.c
@@ -0,0 +1,91 @@
+// gaterecurrent2dnoind_cuda.c
+#include
+#include
+#include "gaterecurrent2dnoind_cuda.h"
+#include "cuda/gaterecurrent2dnoind_kernel.h"
+
+// typedef bool boolean;
+
+// this symbol will be resolved automatically from PyTorch libs
+extern THCState *state;
+
+int gaterecurrent2dnoind_forward_cuda(int horizontal_, int reverse_, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * output)
+{
+ // Grab the input tensor to flat
+ float * X_data = THCudaTensor_data(state, X);
+ float * G1_data = THCudaTensor_data(state, G1);
+ float * G2_data = THCudaTensor_data(state, G2);
+ float * G3_data = THCudaTensor_data(state, G3);
+ float * H_data = THCudaTensor_data(state, output);
+
+ // dimensions
+ int num_ = THCudaTensor_size(state, X, 0);
+ int channels_ = THCudaTensor_size(state, X, 1);
+ int height_ = THCudaTensor_size(state, X, 2);
+ int width_ = THCudaTensor_size(state, X, 3);
+
+ cudaStream_t stream = THCState_getCurrentStream(state);
+
+ if(horizontal_ && !reverse_) // left to right
+ {
+ //const int count = height_ * channels_ * num_;
+ Forward_left_right(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
+ }
+ else if(horizontal_ && reverse_) // right to left
+ {
+ Forward_right_left(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
+ }
+ else if(!horizontal_ && !reverse_) // top to bottom
+ {
+ Forward_top_bottom(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
+ }
+ else
+ {
+ Forward_bottom_top(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream);
+ }
+
+ return 1;
+}
+
+int gaterecurrent2dnoind_backward_cuda(int horizontal_, int reverse_, THCudaTensor* top, THCudaTensor* top_grad, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * X_grad, THCudaTensor * G1_grad, THCudaTensor * G2_grad, THCudaTensor * G3_grad)
+{
+ //Grab the input tensor to flat
+ float * X_data = THCudaTensor_data(state, X);
+ float * G1_data = THCudaTensor_data(state, G1);
+ float * G2_data = THCudaTensor_data(state, G2);
+ float * G3_data = THCudaTensor_data(state, G3);
+ float * H_data = THCudaTensor_data(state, top);
+
+ float * H_diff = THCudaTensor_data(state, top_grad);
+
+ float * X_diff = THCudaTensor_data(state, X_grad);
+ float * G1_diff = THCudaTensor_data(state, G1_grad);
+ float * G2_diff = THCudaTensor_data(state, G2_grad);
+ float * G3_diff = THCudaTensor_data(state, G3_grad);
+
+ // dimensions
+ int num_ = THCudaTensor_size(state, X, 0);
+ int channels_ = THCudaTensor_size(state, X, 1);
+ int height_ = THCudaTensor_size(state, X, 2);
+ int width_ = THCudaTensor_size(state, X, 3);
+
+ cudaStream_t stream = THCState_getCurrentStream(state);
+
+ if(horizontal_ && ! reverse_) //left to right
+ {
+ Backward_left_right(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
+ }
+ else if(horizontal_ && reverse_) //right to left
+ {
+ Backward_right_left(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
+ }
+ else if(!horizontal_ && !reverse_) //top to bottom
+ {
+ Backward_top_bottom(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
+ }
+ else {
+ Backward_bottom_top(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream);
+ }
+
+ return 1;
+}
diff --git a/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.h b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.h
new file mode 100644
index 0000000000000000000000000000000000000000..da60473306ca9c3addcf0a2164bd13c86b226d7e
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.h
@@ -0,0 +1,6 @@
+
+// #include
+// gaterecurrent2dnoind_cuda.h
+int gaterecurrent2dnoind_forward_cuda(int horizontal_, int reverse_, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * output);
+
+int gaterecurrent2dnoind_backward_cuda(int horizontal_, int reverse_, THCudaTensor* top, THCudaTensor* top_grad, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * X_diff, THCudaTensor * G1_diff, THCudaTensor * G2_diff, THCudaTensor * G3_diff);
diff --git a/graph_networks/LinearStyleTransfer/libs/smooth_filter.py b/graph_networks/LinearStyleTransfer/libs/smooth_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..4968615b1883dc095fe1d5900fe15cd3c6fe8ecf
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/smooth_filter.py
@@ -0,0 +1,407 @@
+"""
+Code cc from https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py
+"""
+src = '''
+ #include "/usr/local/cuda/include/math_functions.h"
+ #define TB 256
+ #define EPS 1e-7
+
+ __device__ bool InverseMat4x4(double m_in[4][4], double inv_out[4][4]) {
+ double m[16], inv[16];
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 4; j++) {
+ m[i * 4 + j] = m_in[i][j];
+ }
+ }
+
+ inv[0] = m[5] * m[10] * m[15] -
+ m[5] * m[11] * m[14] -
+ m[9] * m[6] * m[15] +
+ m[9] * m[7] * m[14] +
+ m[13] * m[6] * m[11] -
+ m[13] * m[7] * m[10];
+
+ inv[4] = -m[4] * m[10] * m[15] +
+ m[4] * m[11] * m[14] +
+ m[8] * m[6] * m[15] -
+ m[8] * m[7] * m[14] -
+ m[12] * m[6] * m[11] +
+ m[12] * m[7] * m[10];
+
+ inv[8] = m[4] * m[9] * m[15] -
+ m[4] * m[11] * m[13] -
+ m[8] * m[5] * m[15] +
+ m[8] * m[7] * m[13] +
+ m[12] * m[5] * m[11] -
+ m[12] * m[7] * m[9];
+
+ inv[12] = -m[4] * m[9] * m[14] +
+ m[4] * m[10] * m[13] +
+ m[8] * m[5] * m[14] -
+ m[8] * m[6] * m[13] -
+ m[12] * m[5] * m[10] +
+ m[12] * m[6] * m[9];
+
+ inv[1] = -m[1] * m[10] * m[15] +
+ m[1] * m[11] * m[14] +
+ m[9] * m[2] * m[15] -
+ m[9] * m[3] * m[14] -
+ m[13] * m[2] * m[11] +
+ m[13] * m[3] * m[10];
+
+ inv[5] = m[0] * m[10] * m[15] -
+ m[0] * m[11] * m[14] -
+ m[8] * m[2] * m[15] +
+ m[8] * m[3] * m[14] +
+ m[12] * m[2] * m[11] -
+ m[12] * m[3] * m[10];
+
+ inv[9] = -m[0] * m[9] * m[15] +
+ m[0] * m[11] * m[13] +
+ m[8] * m[1] * m[15] -
+ m[8] * m[3] * m[13] -
+ m[12] * m[1] * m[11] +
+ m[12] * m[3] * m[9];
+
+ inv[13] = m[0] * m[9] * m[14] -
+ m[0] * m[10] * m[13] -
+ m[8] * m[1] * m[14] +
+ m[8] * m[2] * m[13] +
+ m[12] * m[1] * m[10] -
+ m[12] * m[2] * m[9];
+
+ inv[2] = m[1] * m[6] * m[15] -
+ m[1] * m[7] * m[14] -
+ m[5] * m[2] * m[15] +
+ m[5] * m[3] * m[14] +
+ m[13] * m[2] * m[7] -
+ m[13] * m[3] * m[6];
+
+ inv[6] = -m[0] * m[6] * m[15] +
+ m[0] * m[7] * m[14] +
+ m[4] * m[2] * m[15] -
+ m[4] * m[3] * m[14] -
+ m[12] * m[2] * m[7] +
+ m[12] * m[3] * m[6];
+
+ inv[10] = m[0] * m[5] * m[15] -
+ m[0] * m[7] * m[13] -
+ m[4] * m[1] * m[15] +
+ m[4] * m[3] * m[13] +
+ m[12] * m[1] * m[7] -
+ m[12] * m[3] * m[5];
+
+ inv[14] = -m[0] * m[5] * m[14] +
+ m[0] * m[6] * m[13] +
+ m[4] * m[1] * m[14] -
+ m[4] * m[2] * m[13] -
+ m[12] * m[1] * m[6] +
+ m[12] * m[2] * m[5];
+
+ inv[3] = -m[1] * m[6] * m[11] +
+ m[1] * m[7] * m[10] +
+ m[5] * m[2] * m[11] -
+ m[5] * m[3] * m[10] -
+ m[9] * m[2] * m[7] +
+ m[9] * m[3] * m[6];
+
+ inv[7] = m[0] * m[6] * m[11] -
+ m[0] * m[7] * m[10] -
+ m[4] * m[2] * m[11] +
+ m[4] * m[3] * m[10] +
+ m[8] * m[2] * m[7] -
+ m[8] * m[3] * m[6];
+
+ inv[11] = -m[0] * m[5] * m[11] +
+ m[0] * m[7] * m[9] +
+ m[4] * m[1] * m[11] -
+ m[4] * m[3] * m[9] -
+ m[8] * m[1] * m[7] +
+ m[8] * m[3] * m[5];
+
+ inv[15] = m[0] * m[5] * m[10] -
+ m[0] * m[6] * m[9] -
+ m[4] * m[1] * m[10] +
+ m[4] * m[2] * m[9] +
+ m[8] * m[1] * m[6] -
+ m[8] * m[2] * m[5];
+
+ double det = m[0] * inv[0] + m[1] * inv[4] + m[2] * inv[8] + m[3] * inv[12];
+
+ if (abs(det) < 1e-9) {
+ return false;
+ }
+
+
+ det = 1.0 / det;
+
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 4; j++) {
+ inv_out[i][j] = inv[i * 4 + j] * det;
+ }
+ }
+
+ return true;
+ }
+
+ extern "C"
+ __global__ void best_local_affine_kernel(
+ float *output, float *input, float *affine_model,
+ int h, int w, float epsilon, int kernel_radius
+ )
+ {
+ int size = h * w;
+ int id = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (id < size) {
+ int x = id % w, y = id / w;
+
+ double Mt_M[4][4] = {}; // 4x4
+ double invMt_M[4][4] = {};
+ double Mt_S[3][4] = {}; // RGB -> 1x4
+ double A[3][4] = {};
+ for (int i = 0; i < 4; i++)
+ for (int j = 0; j < 4; j++) {
+ Mt_M[i][j] = 0, invMt_M[i][j] = 0;
+ if (i != 3) {
+ Mt_S[i][j] = 0, A[i][j] = 0;
+ if (i == j)
+ Mt_M[i][j] = 1e-3;
+ }
+ }
+
+ for (int dy = -kernel_radius; dy <= kernel_radius; dy++) {
+ for (int dx = -kernel_radius; dx <= kernel_radius; dx++) {
+
+ int xx = x + dx, yy = y + dy;
+ int id2 = yy * w + xx;
+
+ if (0 <= xx && xx < w && 0 <= yy && yy < h) {
+
+ Mt_M[0][0] += input[id2 + 2*size] * input[id2 + 2*size];
+ Mt_M[0][1] += input[id2 + 2*size] * input[id2 + size];
+ Mt_M[0][2] += input[id2 + 2*size] * input[id2];
+ Mt_M[0][3] += input[id2 + 2*size];
+
+ Mt_M[1][0] += input[id2 + size] * input[id2 + 2*size];
+ Mt_M[1][1] += input[id2 + size] * input[id2 + size];
+ Mt_M[1][2] += input[id2 + size] * input[id2];
+ Mt_M[1][3] += input[id2 + size];
+
+ Mt_M[2][0] += input[id2] * input[id2 + 2*size];
+ Mt_M[2][1] += input[id2] * input[id2 + size];
+ Mt_M[2][2] += input[id2] * input[id2];
+ Mt_M[2][3] += input[id2];
+
+ Mt_M[3][0] += input[id2 + 2*size];
+ Mt_M[3][1] += input[id2 + size];
+ Mt_M[3][2] += input[id2];
+ Mt_M[3][3] += 1;
+
+ Mt_S[0][0] += input[id2 + 2*size] * output[id2 + 2*size];
+ Mt_S[0][1] += input[id2 + size] * output[id2 + 2*size];
+ Mt_S[0][2] += input[id2] * output[id2 + 2*size];
+ Mt_S[0][3] += output[id2 + 2*size];
+
+ Mt_S[1][0] += input[id2 + 2*size] * output[id2 + size];
+ Mt_S[1][1] += input[id2 + size] * output[id2 + size];
+ Mt_S[1][2] += input[id2] * output[id2 + size];
+ Mt_S[1][3] += output[id2 + size];
+
+ Mt_S[2][0] += input[id2 + 2*size] * output[id2];
+ Mt_S[2][1] += input[id2 + size] * output[id2];
+ Mt_S[2][2] += input[id2] * output[id2];
+ Mt_S[2][3] += output[id2];
+ }
+ }
+ }
+
+ bool success = InverseMat4x4(Mt_M, invMt_M);
+
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 4; j++) {
+ for (int k = 0; k < 4; k++) {
+ A[i][j] += invMt_M[j][k] * Mt_S[i][k];
+ }
+ }
+ }
+
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 4; j++) {
+ int affine_id = i * 4 + j;
+ affine_model[12 * id + affine_id] = A[i][j];
+ }
+ }
+ }
+ return ;
+ }
+
+ extern "C"
+ __global__ void bilateral_smooth_kernel(
+ float *affine_model, float *filtered_affine_model, float *guide,
+ int h, int w, int kernel_radius, float sigma1, float sigma2
+ )
+ {
+ int id = blockIdx.x * blockDim.x + threadIdx.x;
+ int size = h * w;
+ if (id < size) {
+ int x = id % w;
+ int y = id / w;
+
+ double sum_affine[12] = {};
+ double sum_weight = 0;
+ for (int dx = -kernel_radius; dx <= kernel_radius; dx++) {
+ for (int dy = -kernel_radius; dy <= kernel_radius; dy++) {
+ int yy = y + dy, xx = x + dx;
+ int id2 = yy * w + xx;
+ if (0 <= xx && xx < w && 0 <= yy && yy < h) {
+ float color_diff1 = guide[yy*w + xx] - guide[y*w + x];
+ float color_diff2 = guide[yy*w + xx + size] - guide[y*w + x + size];
+ float color_diff3 = guide[yy*w + xx + 2*size] - guide[y*w + x + 2*size];
+ float color_diff_sqr =
+ (color_diff1*color_diff1 + color_diff2*color_diff2 + color_diff3*color_diff3) / 3;
+
+ float v1 = exp(-(dx * dx + dy * dy) / (2 * sigma1 * sigma1));
+ float v2 = exp(-(color_diff_sqr) / (2 * sigma2 * sigma2));
+ float weight = v1 * v2;
+
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 4; j++) {
+ int affine_id = i * 4 + j;
+ sum_affine[affine_id] += weight * affine_model[id2*12 + affine_id];
+ }
+ }
+ sum_weight += weight;
+ }
+ }
+ }
+
+ for (int i = 0; i < 3; i++) {
+ for (int j = 0; j < 4; j++) {
+ int affine_id = i * 4 + j;
+ filtered_affine_model[id*12 + affine_id] = sum_affine[affine_id] / sum_weight;
+ }
+ }
+ }
+ return ;
+ }
+
+
+ extern "C"
+ __global__ void reconstruction_best_kernel(
+ float *input, float *filtered_affine_model, float *filtered_best_output,
+ int h, int w
+ )
+ {
+ int id = blockIdx.x * blockDim.x + threadIdx.x;
+ int size = h * w;
+ if (id < size) {
+ double out1 =
+ input[id + 2*size] * filtered_affine_model[id*12 + 0] + // A[0][0] +
+ input[id + size] * filtered_affine_model[id*12 + 1] + // A[0][1] +
+ input[id] * filtered_affine_model[id*12 + 2] + // A[0][2] +
+ filtered_affine_model[id*12 + 3]; //A[0][3];
+ double out2 =
+ input[id + 2*size] * filtered_affine_model[id*12 + 4] + //A[1][0] +
+ input[id + size] * filtered_affine_model[id*12 + 5] + //A[1][1] +
+ input[id] * filtered_affine_model[id*12 + 6] + //A[1][2] +
+ filtered_affine_model[id*12 + 7]; //A[1][3];
+ double out3 =
+ input[id + 2*size] * filtered_affine_model[id*12 + 8] + //A[2][0] +
+ input[id + size] * filtered_affine_model[id*12 + 9] + //A[2][1] +
+ input[id] * filtered_affine_model[id*12 + 10] + //A[2][2] +
+ filtered_affine_model[id*12 + 11]; // A[2][3];
+
+ filtered_best_output[id] = out1;
+ filtered_best_output[id + size] = out2;
+ filtered_best_output[id + 2*size] = out3;
+ }
+ return ;
+ }
+ '''
+
+import cv2
+import torch
+import numpy as np
+from PIL import Image
+from cupy.cuda import function
+from pynvrtc.compiler import Program
+from collections import namedtuple
+
+
+def smooth_local_affine(output_cpu, input_cpu, epsilon, patch, h, w, f_r, f_e):
+ # program = Program(src.encode('utf-8'), 'best_local_affine_kernel.cu'.encode('utf-8'))
+ # ptx = program.compile(['-I/usr/local/cuda/include'.encode('utf-8')])
+ program = Program(src, 'best_local_affine_kernel.cu')
+ ptx = program.compile(['-I/usr/local/cuda/include'])
+ m = function.Module()
+ m.load(bytes(ptx.encode()))
+
+ _reconstruction_best_kernel = m.get_function('reconstruction_best_kernel')
+ _bilateral_smooth_kernel = m.get_function('bilateral_smooth_kernel')
+ _best_local_affine_kernel = m.get_function('best_local_affine_kernel')
+ Stream = namedtuple('Stream', ['ptr'])
+ s = Stream(ptr=torch.cuda.current_stream().cuda_stream)
+
+ filter_radius = f_r
+ sigma1 = filter_radius / 3
+ sigma2 = f_e
+ radius = (patch - 1) / 2
+
+ filtered_best_output = torch.zeros(np.shape(input_cpu)).cuda()
+ affine_model = torch.zeros((h * w, 12)).cuda()
+ filtered_affine_model =torch.zeros((h * w, 12)).cuda()
+
+ input_ = torch.from_numpy(input_cpu).cuda()
+ output_ = torch.from_numpy(output_cpu).cuda()
+ _best_local_affine_kernel(
+ grid=(int((h * w) / 256 + 1), 1),
+ block=(256, 1, 1),
+ args=[output_.data_ptr(), input_.data_ptr(), affine_model.data_ptr(),
+ np.int32(h), np.int32(w), np.float32(epsilon), np.int32(radius)], stream=s
+ )
+
+ _bilateral_smooth_kernel(
+ grid=(int((h * w) / 256 + 1), 1),
+ block=(256, 1, 1),
+ args=[affine_model.data_ptr(), filtered_affine_model.data_ptr(), input_.data_ptr(), np.int32(h), np.int32(w), np.int32(f_r), np.float32(sigma1), np.float32(sigma2)], stream=s
+ )
+
+ _reconstruction_best_kernel(
+ grid=(int((h * w) / 256 + 1), 1),
+ block=(256, 1, 1),
+ args=[input_.data_ptr(), filtered_affine_model.data_ptr(), filtered_best_output.data_ptr(),
+ np.int32(h), np.int32(w)], stream=s
+ )
+ numpy_filtered_best_output = filtered_best_output.cpu().numpy()
+ return numpy_filtered_best_output
+
+
+def smooth_filter(initImg, contentImg, f_radius=15,f_edge=1e-1):
+ '''
+ :param initImg: intermediate output. Either image path or PIL Image
+ :param contentImg: content image output. Either path or PIL Image
+ :return: stylized output image. PIL Image
+ '''
+ if type(initImg) == str:
+ initImg = Image.open(initImg).convert("RGB")
+ best_image_bgr = np.array(initImg, dtype=np.float32)
+ bW, bH, bC = best_image_bgr.shape
+ best_image_bgr = best_image_bgr[:, :, ::-1]
+ best_image_bgr = best_image_bgr.transpose((2, 0, 1))
+
+ if type(contentImg) == str:
+ contentImg = Image.open(contentImg).convert("RGB")
+ content_input = contentImg.resize((bH,bW))
+ else:
+ content_input = cv2.resize(contentImg,(bH,bW))
+ content_input = np.array(content_input, dtype=np.float32)
+ content_input = content_input[:, :, ::-1]
+ content_input = content_input.transpose((2, 0, 1))
+ input_ = np.ascontiguousarray(content_input, dtype=np.float32) / 255.
+ _, H, W = np.shape(input_)
+ output_ = np.ascontiguousarray(best_image_bgr, dtype=np.float32) / 255.
+ best_ = smooth_local_affine(output_, input_, 1e-7, 3, H, W, f_radius, f_edge)
+ best_ = best_.transpose(1, 2, 0)
+ result = Image.fromarray(np.uint8(np.clip(best_ * 255., 0, 255.)))
+ return result
diff --git a/graph_networks/LinearStyleTransfer/libs/utils.py b/graph_networks/LinearStyleTransfer/libs/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..79c4a274028d3c396aa686476cb0ae400113a1fe
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/libs/utils.py
@@ -0,0 +1,92 @@
+from __future__ import division
+import os
+import cv2
+import time
+import torch
+import scipy.misc
+import numpy as np
+import scipy.sparse
+from PIL import Image
+import scipy.sparse.linalg
+from cv2.ximgproc import jointBilateralFilter
+from numpy.lib.stride_tricks import as_strided
+
+def whiten(cF):
+ cFSize = cF.size()
+ c_mean = torch.mean(cF,1) # c x (h x w)
+ c_mean = c_mean.unsqueeze(1).expand_as(cF)
+ cF = cF - c_mean
+
+ contentConv = torch.mm(cF,cF.t()).div(cFSize[1]-1) + torch.eye(cFSize[0]).double()
+ c_u,c_e,c_v = torch.svd(contentConv,some=False)
+
+ k_c = cFSize[0]
+ for i in range(cFSize[0]):
+ if c_e[i] < 0.00001:
+ k_c = i
+ break
+
+ c_d = (c_e[0:k_c]).pow(-0.5)
+ step1 = torch.mm(c_v[:,0:k_c],torch.diag(c_d))
+ step2 = torch.mm(step1,(c_v[:,0:k_c].t()))
+ whiten_cF = torch.mm(step2,cF)
+ return whiten_cF
+
+def numpy2cv2(cont,style,prop,width,height):
+ cont = cont.transpose((1,2,0))
+ cont = cont[...,::-1]
+ cont = cont * 255
+ cont = cv2.resize(cont,(width,height))
+ #cv2.resize(iimg,(width,height))
+ style = style.transpose((1,2,0))
+ style = style[...,::-1]
+ style = style * 255
+ style = cv2.resize(style,(width,height))
+
+ prop = prop.transpose((1,2,0))
+ prop = prop[...,::-1]
+ prop = prop * 255
+ prop = cv2.resize(prop,(width,height))
+
+ #return np.concatenate((cont,np.concatenate((style,prop),axis=1)),axis=1)
+ return prop,cont
+
+def makeVideo(content,style,props,outf):
+ print('Stack transferred frames back to video...')
+ layers,height,width = content[0].shape
+ fourcc = cv2.VideoWriter_fourcc(*'MJPG')
+ video = cv2.VideoWriter(os.path.join(outf,'transfer.avi'),fourcc,10.0,(width,height))
+ ori_video = cv2.VideoWriter(os.path.join(outf,'content.avi'),fourcc,10.0,(width,height))
+ for j in range(len(content)):
+ prop,cont = numpy2cv2(content[j],style,props[j],width,height)
+ cv2.imwrite('prop.png',prop)
+ cv2.imwrite('content.png',cont)
+ # TODO: this is ugly, fix this
+ imgj = cv2.imread('prop.png')
+ imgc = cv2.imread('content.png')
+
+ video.write(imgj)
+ ori_video.write(imgc)
+ # RGB or BRG, yuks
+ video.release()
+ ori_video.release()
+ os.remove('prop.png')
+ os.remove('content.png')
+ print('Transferred video saved at %s.'%outf)
+
+def print_options(opt):
+ message = ''
+ message += '----------------- Options ---------------\n'
+ for k, v in sorted(vars(opt).items()):
+ comment = ''
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+ message += '----------------- End -------------------'
+ print(message)
+
+ # save to the disk
+ expr_dir = os.path.join(opt.outf)
+ os.makedirs(expr_dir,exist_ok=True)
+ file_name = os.path.join(expr_dir, 'opt.txt')
+ with open(file_name, 'wt') as opt_file:
+ opt_file.write(message)
+ opt_file.write('\n')
diff --git a/graph_networks/LinearStyleTransfer/models/dec_r31.pth b/graph_networks/LinearStyleTransfer/models/dec_r31.pth
new file mode 100644
index 0000000000000000000000000000000000000000..b7ddc0933441ef7622c0b0e752c0644828b9db4c
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/models/dec_r31.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ccc3bbc97a15e1d002d0b13523543c518dda2d0346c8f4d39c1d381d8490f68
+size 2221888
diff --git a/graph_networks/LinearStyleTransfer/models/dec_r41.pth b/graph_networks/LinearStyleTransfer/models/dec_r41.pth
new file mode 100644
index 0000000000000000000000000000000000000000..7407cde48fb8b841852732487e5f9dad0b874cdb
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/models/dec_r41.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e6858f96d3d0882fa3b40652a0315928219086e4bcb0e3efbe43bd04ea631911
+size 14023509
diff --git a/graph_networks/LinearStyleTransfer/models/r31.pth b/graph_networks/LinearStyleTransfer/models/r31.pth
new file mode 100644
index 0000000000000000000000000000000000000000..d8c9a170e28f28d84dd6c1cbfb161422aeedb48c
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/models/r31.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8bb75be684b331a105a8fad266343556067e4eec888249f7abb693a66b5ad7e3
+size 11564438
diff --git a/graph_networks/LinearStyleTransfer/models/r41.pth b/graph_networks/LinearStyleTransfer/models/r41.pth
new file mode 100644
index 0000000000000000000000000000000000000000..d7ddefca5c34b0d5281c6fff1daeaa09570c6780
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/models/r41.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98cdafb2f553ea3071782255cc5e739eecf2a74fcdde280e5d7e09b8017fad4d
+size 20627360
diff --git a/graph_networks/LinearStyleTransfer/models/r41_spn.pth b/graph_networks/LinearStyleTransfer/models/r41_spn.pth
new file mode 100644
index 0000000000000000000000000000000000000000..b12cb57b30e0adc3e107ce6568603b89af640d90
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/models/r41_spn.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:79db437b39507a97b642b0f7ed00b86fb62c9bf49fd86d3eab0a62ce023a9db8
+size 3098678
diff --git a/graph_networks/LinearStyleTransfer/models/vgg_r31.pth b/graph_networks/LinearStyleTransfer/models/vgg_r31.pth
new file mode 100644
index 0000000000000000000000000000000000000000..0c058c9cd91bcb91957c390c981b5f39d629d522
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/models/vgg_r31.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f9bffd632f87360d4de1bcd1627246d2b41a6278aa46e1c1fe7212796b646b7e
+size 2223422
diff --git a/graph_networks/LinearStyleTransfer/models/vgg_r41.pth b/graph_networks/LinearStyleTransfer/models/vgg_r41.pth
new file mode 100644
index 0000000000000000000000000000000000000000..d2f5756f6c0669490b202621ca7d5dc69838cbcf
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/models/vgg_r41.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7949c5ea891cd75de2fce4a9cdb0a14f9bc1672053f5c92db80ded641b7e57d0
+size 14026238
diff --git a/graph_networks/LinearStyleTransfer/models/vgg_r51.pth b/graph_networks/LinearStyleTransfer/models/vgg_r51.pth
new file mode 100644
index 0000000000000000000000000000000000000000..ade4f4e2febab7f63b3a8aeeac1bd3f31a6dee66
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/models/vgg_r51.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:71041e0ee08540e690cfa08982f2da86b590c90335468808854092ff07c81cfc
+size 51784517
diff --git a/graph_networks/LinearStyleTransfer/real-time-demo.py b/graph_networks/LinearStyleTransfer/real-time-demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d3565ee2a84db8ced9e72f1d39c6473e0d2937a
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer/real-time-demo.py
@@ -0,0 +1,120 @@
+import os
+import cv2
+import torch
+import argparse
+import numpy as np
+from PIL import Image
+from libs.Loader import Dataset
+from libs.Matrix import MulLayer
+from libs.utils import makeVideo
+import torch.backends.cudnn as cudnn
+from libs.models import encoder3,encoder4
+from libs.models import decoder3,decoder4
+import torchvision.transforms as transforms
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--vgg_dir", default='models/vgg_r31.pth',
+ help='pre-trained encoder path')
+parser.add_argument("--decoder_dir", default='models/dec_r31.pth',
+ help='pre-trained decoder path')
+parser.add_argument("--style", default="data/style/in2.jpg",
+ help='path to style image')
+parser.add_argument("--matrixPath", default="models/r31.pth",
+ help='path to pre-trained model')
+parser.add_argument('--fineSize', type=int, default=256,
+ help='crop image size')
+parser.add_argument("--name",default="transferred_video",
+ help="name of generated video")
+parser.add_argument("--layer",default="r31",
+ help="features of which layer to transfer")
+parser.add_argument("--outf",default="real_time_demo_output",
+ help="output folder")
+
+################# PREPARATIONS #################
+opt = parser.parse_args()
+opt.cuda = torch.cuda.is_available()
+print(opt)
+os.makedirs(opt.outf,exist_ok=True)
+cudnn.benchmark = True
+
+################# DATA #################
+def loadImg(imgPath):
+ img = Image.open(imgPath).convert('RGB')
+ transform = transforms.Compose([
+ transforms.Scale(opt.fineSize),
+ transforms.ToTensor()])
+ return transform(img)
+style = loadImg(opt.style).unsqueeze(0)
+
+################# MODEL #################
+if(opt.layer == 'r31'):
+ matrix = MulLayer(layer='r31')
+ vgg = encoder3()
+ dec = decoder3()
+elif(opt.layer == 'r41'):
+ matrix = MulLayer(layer='r41')
+ vgg = encoder4()
+ dec = decoder4()
+vgg.load_state_dict(torch.load(opt.vgg_dir))
+dec.load_state_dict(torch.load(opt.dec_dir))
+matrix.load_state_dict(torch.load(opt.matrixPath))
+for param in vgg.parameters():
+ param.requires_grad = False
+for param in dec.parameters():
+ param.requires_grad = False
+for param in matrix.parameters():
+ param.requires_grad = False
+
+################# GLOBAL VARIABLE #################
+content = torch.Tensor(1,3,opt.fineSize,opt.fineSize)
+
+################# GPU #################
+if(opt.cuda):
+ vgg.cuda()
+ dec.cuda()
+ matrix.cuda()
+
+ style = style.cuda()
+ content = content.cuda()
+
+totalTime = 0
+imageCounter = 0
+result_frames = []
+contents = []
+styles = []
+cap = cv2.VideoCapture(0)
+cap.set(3,256)
+cap.set(4,512)
+fourcc = cv2.VideoWriter_fourcc(*'MJPG')
+out = cv2.VideoWriter(os.path.join(opt.outf,opt.name+'.avi'),fourcc,20.0,(512,256))
+
+with torch.no_grad():
+ sF = vgg(style)
+
+while(True):
+ ret,frame = cap.read()
+ frame = cv2.resize(frame,(512,256),interpolation=cv2.INTER_CUBIC)
+ frame = frame.transpose((2,0,1))
+ frame = frame[::-1,:,:]
+ frame = frame/255.0
+ frame = torch.from_numpy(frame.copy()).unsqueeze(0)
+ content.data.resize_(frame.size()).copy_(frame)
+ with torch.no_grad():
+ cF = vgg(content)
+ if(opt.layer == 'r41'):
+ feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer])
+ else:
+ feature,transmatrix = matrix(cF,sF)
+ transfer = dec(feature)
+ transfer = transfer.clamp(0,1).squeeze(0).data.cpu().numpy()
+ transfer = transfer.transpose((1,2,0))
+ transfer = transfer[...,::-1]
+ out.write(np.uint8(transfer*255))
+ cv2.imshow('frame',transfer)
+ if cv2.waitKey(1) & 0xFF == ord('q'):
+ break
+
+# When everything done, release the capture
+out.release()
+cap.release()
+cv2.destroyAllWindows()
diff --git a/graph_networks/LinearStyleTransfer_matrix.py b/graph_networks/LinearStyleTransfer_matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ed08fe4baed3fd97156a40c03ea0ae1ad8ece0f
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer_matrix.py
@@ -0,0 +1,79 @@
+import torch
+import torch.nn as nn
+
+from copy import deepcopy
+
+from selectionConv import SelectionConv
+
+class CNN(nn.Module):
+ def __init__(self,matrixSize=32):
+ super(CNN,self).__init__()
+
+ self.conv1 = SelectionConv(512,256,3,padding_mode="zeros")
+ self.conv2 = SelectionConv(256,128,3,padding_mode="zeros")
+ self.conv3 = SelectionConv(128,matrixSize,3,padding_mode="zeros")
+ self.relu = torch.nn.ReLU()
+
+ self.fc = nn.Linear(matrixSize*matrixSize,matrixSize*matrixSize)
+
+ def forward(self,x,edge_index,selections,interp_values=None):
+ out = self.relu(self.conv1(x,edge_index,selections,interp_values))
+ out = self.relu(self.conv2(out,edge_index,selections,interp_values))
+ out = self.conv3(out,edge_index,selections,interp_values)
+
+ n,ch = out.size()
+
+ out = torch.mm(out.t(), out).div(n)
+
+ out = out.view(-1)
+ return self.fc(out)
+
+class TransformLayer(nn.Module):
+ def __init__(self,matrixSize=32):
+ super(TransformLayer,self).__init__()
+ self.snet = CNN(matrixSize)
+ self.cnet = CNN(matrixSize)
+ self.matrixSize = matrixSize
+
+ self.compress = SelectionConv(512,matrixSize,1)
+ self.unzip = SelectionConv(matrixSize,512,1)
+
+ def forward(self,cF,sF,content_edge_index,content_selections,style_edge_index,style_selections,content_interps=None,style_interps=None,trans=True):
+ cMean = torch.mean(cF,dim=0,keepdim=True)
+ cF = cF - cMean
+
+ sMean = torch.mean(sF,dim=0,keepdim=True)
+ sF = sF - sMean
+
+ compress_content = self.compress(cF,content_edge_index,content_selections,content_interps)
+
+ if(trans):
+ cMatrix = self.cnet(cF,content_edge_index,content_selections,content_interps)
+ sMatrix = self.snet(sF,style_edge_index,style_selections,style_interps)
+
+ sMatrix = sMatrix.view(self.matrixSize,self.matrixSize)
+ cMatrix = cMatrix.view(self.matrixSize,self.matrixSize)
+ transmatrix = torch.mm(sMatrix,cMatrix)
+ transfeature = torch.mm(transmatrix,compress_content.transpose(1,0))
+ out = self.unzip(transfeature.transpose(1,0),content_edge_index,content_selections,content_interps)
+ out = out + sMean
+ return out, transmatrix
+ else:
+ out = self.unzip(compress_content,content_edge_index,content_selections,content_interps)
+ out = out + cMean
+ return out
+
+ def copy_weights(self, model):
+ self.cnet.conv1.copy_weights(model.cnet.convs[0].weight,model.cnet.convs[0].bias)
+ self.cnet.conv2.copy_weights(model.cnet.convs[2].weight,model.cnet.convs[2].bias)
+ self.cnet.conv3.copy_weights(model.cnet.convs[4].weight,model.cnet.convs[4].bias)
+
+ self.snet.conv1.copy_weights(model.snet.convs[0].weight,model.snet.convs[0].bias)
+ self.snet.conv2.copy_weights(model.snet.convs[2].weight,model.snet.convs[2].bias)
+ self.snet.conv3.copy_weights(model.snet.convs[4].weight,model.snet.convs[4].bias)
+
+ self.cnet.fc = deepcopy(model.cnet.fc)
+ self.snet.fc = deepcopy(model.snet.fc)
+
+ self.compress.copy_weights(model.compress.weight,model.compress.bias)
+ self.unzip.copy_weights(model.unzip.weight,model.unzip.bias)
diff --git a/graph_networks/LinearStyleTransfer_vgg.py b/graph_networks/LinearStyleTransfer_vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..179325787c65be0cf91e6d33d992850a5aa32f3b
--- /dev/null
+++ b/graph_networks/LinearStyleTransfer_vgg.py
@@ -0,0 +1,172 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import Sequential as Seq, Linear as Lin, ReLU
+
+from selectionConv import SelectionConv
+
+from pooling import unpoolCluster, maxPoolCluster
+
+class encoder(torch.nn.Module):
+
+ def __init__(self,padding_mode="reflect"):
+ super(encoder, self).__init__()
+ self.conv1 = SelectionConv(3,3,1,padding_mode=padding_mode)
+ self.conv2 = SelectionConv(3,64,3,padding_mode=padding_mode)
+ self.conv3 = SelectionConv(64,64,3,padding_mode=padding_mode)
+
+ self.conv4 = SelectionConv(64,128,3,padding_mode=padding_mode)
+ self.conv5 = SelectionConv(128,128,3,padding_mode=padding_mode)
+
+ self.conv6 = SelectionConv(128,256,3,padding_mode=padding_mode)
+ self.conv7 = SelectionConv(256,256,3,padding_mode=padding_mode)
+ self.conv8 = SelectionConv(256,256,3,padding_mode=padding_mode)
+ self.conv9 = SelectionConv(256,256,3,padding_mode=padding_mode)
+
+ self.conv10 = SelectionConv(256,512,3,padding_mode=padding_mode)
+
+ self.relu = torch.nn.ReLU()
+
+ def copy_weights(self, model):
+ self.conv1.copy_weights(model.conv1.weight,model.conv1.bias)
+ self.conv2.copy_weights(model.conv2.weight,model.conv2.bias)
+ self.conv3.copy_weights(model.conv3.weight,model.conv3.bias)
+ self.conv4.copy_weights(model.conv4.weight,model.conv4.bias)
+ self.conv5.copy_weights(model.conv5.weight,model.conv5.bias)
+ self.conv6.copy_weights(model.conv6.weight,model.conv6.bias)
+ self.conv7.copy_weights(model.conv7.weight,model.conv7.bias)
+ self.conv8.copy_weights(model.conv8.weight,model.conv8.bias)
+ self.conv9.copy_weights(model.conv9.weight,model.conv9.bias)
+ self.conv10.copy_weights(model.conv10.weight,model.conv10.bias)
+
+
+ def forward(self,graph):
+
+ edge_index = graph.edge_indexes[0]
+ selections = graph.selections_list[0]
+ interps = graph.interps_list[0] if hasattr(graph,'interps_list') else None
+
+ output = {}
+ out = self.conv1(graph.x,edge_index,selections,interps)
+ out = self.conv2(out,edge_index,selections,interps)
+ output['r11'] = self.relu(out)
+
+ out = self.conv3(output['r11'],edge_index,selections,interps)
+ output['r12'] = self.relu(out)
+
+ output['p1'] = maxPoolCluster(output['r12'],graph.clusters[0])
+ edge_index = graph.edge_indexes[1]
+ selections = graph.selections_list[1]
+ interps = graph.interps_list[1] if hasattr(graph,'interps_list') else None
+
+ out = self.conv4(output['p1'],edge_index,selections,interps)
+ output['r21'] = self.relu(out)
+
+ out = self.conv5(output['r21'],edge_index,selections,interps)
+ output['r22'] = self.relu(out)
+
+ output['p2'] = maxPoolCluster(output['r22'],graph.clusters[1])
+ edge_index = graph.edge_indexes[2]
+ selections = graph.selections_list[2]
+ interps = graph.interps_list[2] if hasattr(graph,'interps_list') else None
+
+ out = self.conv6(output['p2'],edge_index,selections,interps)
+ output['r31'] = self.relu(out)
+ #if(matrix31 is not None):
+ # feature3,transmatrix3 = matrix31(output['r31'],sF['r31'])
+ # out = self.reflecPad7(feature3)
+ #else:
+ # out = self.reflecPad7()
+ out = self.conv7(output['r31'],edge_index,selections,interps)
+ output['r32'] = self.relu(out)
+
+ out = self.conv8(output['r32'],edge_index,selections,interps)
+ output['r33'] = self.relu(out)
+
+ out = self.conv9(output['r33'],edge_index,selections,interps)
+ output['r34'] = self.relu(out)
+
+ output['p3'] = maxPoolCluster(output['r34'],graph.clusters[2])
+ edge_index = graph.edge_indexes[3]
+ selections = graph.selections_list[3]
+ interps = graph.interps_list[3] if hasattr(graph,'interps_list') else None
+
+ out = self.conv10(output['p3'],edge_index,selections,interps)
+ output['r41'] = self.relu(out)
+
+ return output
+
+class decoder(torch.nn.Module):
+
+ def __init__(self,padding_mode="reflect"):
+ super(decoder, self).__init__()
+ self.conv11 = SelectionConv(512,256,3,padding_mode=padding_mode)
+
+ self.conv12 = SelectionConv(256,256,3,padding_mode=padding_mode)
+ self.conv13 = SelectionConv(256,256,3,padding_mode=padding_mode)
+ self.conv14 = SelectionConv(256,256,3,padding_mode=padding_mode)
+ self.conv15 = SelectionConv(256,128,3,padding_mode=padding_mode)
+
+ self.conv16 = SelectionConv(128,128,3,padding_mode=padding_mode)
+ self.conv17 = SelectionConv(128,64,3,padding_mode=padding_mode)
+
+ self.conv18 = SelectionConv(64,64,3,padding_mode=padding_mode)
+ self.conv19 = SelectionConv(64,3,3,padding_mode=padding_mode)
+
+ self.relu = torch.nn.ReLU()
+
+ def copy_weights(self, model):
+ self.conv11.copy_weights(model.conv11.weight,model.conv11.bias)
+ self.conv12.copy_weights(model.conv12.weight,model.conv12.bias)
+ self.conv13.copy_weights(model.conv13.weight,model.conv13.bias)
+ self.conv14.copy_weights(model.conv14.weight,model.conv14.bias)
+ self.conv15.copy_weights(model.conv15.weight,model.conv15.bias)
+ self.conv16.copy_weights(model.conv16.weight,model.conv16.bias)
+ self.conv17.copy_weights(model.conv17.weight,model.conv17.bias)
+ self.conv18.copy_weights(model.conv18.weight,model.conv18.bias)
+ self.conv19.copy_weights(model.conv19.weight,model.conv19.bias)
+
+
+ def forward(self,x,graph):
+
+ edge_index = graph.edge_indexes[3]
+ selections = graph.selections_list[3]
+ interps = graph.interps_list[3] if hasattr(graph,'interps_list') else None
+
+ out = self.conv11(x,edge_index,selections,interps)
+ out = self.relu(out)
+
+ out = unpoolCluster(out,graph.clusters[2])
+ edge_index = graph.edge_indexes[2]
+ selections = graph.selections_list[2]
+ interps = graph.interps_list[2] if hasattr(graph,'interps_list') else None
+
+ out = self.conv12(out,edge_index,selections,interps)
+ out = self.relu(out)
+ out = self.conv13(out,edge_index,selections,interps)
+ out = self.relu(out)
+ out = self.conv14(out,edge_index,selections,interps)
+ out = self.relu(out)
+ out = self.conv15(out,edge_index,selections,interps)
+ out = self.relu(out)
+
+ out = unpoolCluster(out,graph.clusters[1])
+ edge_index = graph.edge_indexes[1]
+ selections = graph.selections_list[1]
+ interps = graph.interps_list[1] if hasattr(graph,'interps_list') else None
+
+ out = self.conv16(out,edge_index,selections,interps)
+ out = self.relu(out)
+ out = self.conv17(out,edge_index,selections,interps)
+ out = self.relu(out)
+
+ out = unpoolCluster(out,graph.clusters[0])
+ edge_index = graph.edge_indexes[0]
+ selections = graph.selections_list[0]
+ interps = graph.interps_list[0] if hasattr(graph,'interps_list') else None
+
+ out = self.conv18(out,edge_index,selections,interps)
+ out = self.relu(out)
+ out = self.conv19(out,edge_index,selections,interps)
+
+ return out
+
diff --git a/graph_networks/__init__.py b/graph_networks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/graph_networks/graph_transforms.py b/graph_networks/graph_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a89590343c25eb76cde1f0366920f3cf531a52f
--- /dev/null
+++ b/graph_networks/graph_transforms.py
@@ -0,0 +1,389 @@
+""" graph_transforms.py
+
+This is the implementation of transforming a traditional CNN
+to a SelectionConv-based graph CNN
+So far this is just used for segmentation
+"""
+from copy import deepcopy
+from typing import Dict, Iterable, OrderedDict, Tuple, Union #,Literal Only supported in Python 3.8+
+from typing_extensions import Literal
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.data import Data
+from torchvision.models.segmentation.fcn import FCN
+
+from selectionConv import SelectionConv
+import pooling as P
+
+def transform_network(network: nn.Module):
+ """ Transforms a neural network from a tensor based network to a graph based network
+
+ Parameters
+ ----------
+ - network: the network to transform
+
+ Returns
+ -------
+ - the transformed network
+ """
+ network = deepcopy(network)
+ if type(network) in __MAPPING:
+ with torch.no_grad():
+ return __MAPPING[type(network)].from_torch(network)
+ if not isinstance(network, nn.Module):
+ raise ValueError(f"Must be of type Module but got: {type(network)}")
+ if not list(network.children()):
+ raise NotImplementedError(f"{type(network)} is not implemented yet")
+ for name, child in network.named_children():
+ transformed_child = transform_network(child)
+ setattr(network, name, transformed_child)
+ return network
+
+
+class GraphTracker:
+ """ a wrapper around the graph data for easily overwriting the forward function of existing modules.
+
+ Parameters
+ ----------
+ - graph: the graph data
+ - level: the current depth the graph is being operated on
+ - x: the node data at the current level
+ """
+ def __init__(self, graph, x=None, level=0):
+ self.graph = graph
+ self.x = graph.x if x is None else x
+ self.level = level
+
+ def from_x(self, x):
+ """ create the same graph with different node values"""
+ return GraphTracker(self.graph,x,level=self.level)
+
+ def edge_index(self):
+ return self.graph.edge_indexes[self.level]
+
+ def selections(self):
+ return self.graph.selections_list[self.level]
+
+ def interps(self):
+ if hasattr(self.graph,"interps_list"):
+ return self.graph.interps_list[self.level]
+ else:
+ return None
+
+ def cluster(self):
+ return self.graph.clusters[self.level]
+
+ def __iadd__(self, other):
+ self.x = self.x + other.x
+ return self
+
+ def __repr__(self):
+ return f"GraphTracker(x={tuple(self.x.shape)},level={self.level})"
+
+
+
+def _single(pair, name):
+ """ converts a tuple into a single number
+
+ Parameters
+ ----------
+ - pair: the potential pair of values
+ - name: the name of the values for more readable errors
+
+ Returns
+ -------
+ - the single value
+ """
+ if isinstance(pair, int):
+ return pair
+ if not isinstance(pair, tuple):
+ raise ValueError(f"{name} must either be int or tuple but got: {type(pair)}")
+ if len(pair) != 2:
+ raise ValueError(f"{name} must be a 2-tuple but got: {pair}")
+ if pair[0] != pair[1]:
+ raise ValueError(f"{name} must be a square tuple")
+ return pair[0]
+
+
+class SelModule(nn.Module):
+ """ A super class for all graph based modules to inherit from
+ """
+ @classmethod
+ def from_torch(cls, network):
+ """ creates a new graph based module from an existing 2d based module and copies weights accordingly. Each child class should implement this method
+
+ Parameters
+ ----------
+ - network: the existing 2d based module
+
+ Returns
+ -------
+ - the new graph based module
+ """
+ raise NotImplementedError
+
+
+class SelConv(SelModule, nn.modules.conv._ConvNd):
+ """ A wrapper class around the SelectionConv class that allows for easy
+ use in a transformed network
+
+ Parameters
+ ----------
+ - in_channels: the number of incoming channels
+ - out_channels: the number of outgoing channels
+ - kernel_size: the size of the convolution kernel
+ - stride: the stride at which to perform convolution
+ - padding: the amount of padding to be used
+ - dilation: the dilation of the kernel
+ - groups: the groups of filters for the convolution
+ - bias: whether or not to include a bias
+ - padding_mode: the type of padding to be used
+ """
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]]=1,
+ padding: Union[int, Tuple[int, int]]=0,
+ dilation: Union[int, Tuple[int, int]]=1,
+ groups: int = 1,
+ bias: bool=True,
+ padding_mode: str='zeros',
+ device=None,
+ dtype=None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, False, (0,), groups, bias, padding_mode, **factory_kwargs)
+ self.single_stride = _single(stride, "stride")
+ if self.single_stride not in (1, 2):
+ raise NotImplementedError(f"Only strides of 1 and 2 are supported but got {stride}")
+ self.conv_operation = SelectionConv(
+ in_channels,
+ out_channels,
+ _single(kernel_size, "kernel_size"),
+ _single(dilation, "dilation"),
+ padding_mode,
+ )
+
+ def forward(self, inputs: GraphTracker):
+ x = self.conv_operation(inputs.x, inputs.edge_index(), inputs.selections(), inputs.interps())
+ ret = inputs.from_x(x)
+ if self.single_stride == 2:
+ x = P.stridePoolCluster(x, ret.cluster())
+ ret.x = x
+ ret.level += 1
+ return ret
+
+ @classmethod
+ def from_torch(cls, network):
+ ret = SelConv(
+ network.in_channels,
+ network.out_channels,
+ network.kernel_size,
+ network.stride,
+ network.padding,
+ network.dilation,
+ network.groups,
+ network.bias is not None,
+ network.padding_mode,
+ )
+ ret.conv_operation.copy_weights(network.weight, network.bias)
+ return ret
+
+class SelMaxPool(SelModule):
+ """ A graph based max pool module
+
+ Parameters
+ ----------
+ - kernel_size: the size of the maxpool kernel
+ """
+
+ def __init__(self, kernel_size):
+ super().__init__()
+ self.kernel_size = kernel_size
+
+ def forward(self, inputs):
+ x = P.maxPoolKernel(inputs.x, inputs.edge_index(), inputs.selections(), inputs.cluster(), self.kernel_size)
+ ret = inputs.from_x(x)
+ ret.level += 1
+ return ret
+
+ @classmethod
+ def from_torch(cls, network):
+ ret = SelMaxPool(network.kernel_size)
+ return ret
+
+class SelBatchNorm(SelModule):
+ """ A graph based BatchNorm module
+ """
+ def __init__(self,num_features):
+ super().__init__()
+ self.bn = nn.BatchNorm1d(num_features)
+ #self.bn = SimpleBatchNorm()
+
+ def forward(self, inputs):
+ x = self.bn(inputs.x)
+ ret = inputs.from_x(x)
+ return ret
+
+ def copyBatchNorm(self,source):
+ self.bn.weight = source.weight
+ self.bn.bias = source.bias
+ self.bn.running_mean = source.running_mean
+ self.bn.running_var = source.running_var
+ self.bn.eps = source.eps
+
+ @classmethod
+ def from_torch(cls, network):
+ ret = SelBatchNorm(network.num_features)
+ ret.copyBatchNorm(network)
+ #ret.bn.set_values(network)
+ return ret
+
+class SelReLU(SelModule):
+ """ A graph based ReLU module
+
+ Parameters
+ ----------
+ - inplace: whether or not to perform relu in place
+ """
+ def __init__(self, inplace=False):
+ super().__init__()
+ self.inplace = inplace
+
+ def forward(self, inputs):
+ if self.inplace:
+ inputs.x = F.relu(inputs.x, self.inplace)
+ return inputs
+ else:
+ x = F.relu(inputs.x, self.inplace)
+ ret = inputs.from_x(x)
+ return ret
+
+ @classmethod
+ def from_torch(cls, network):
+ return SelReLU(network.inplace)
+
+
+class SelSequential(SelModule, nn.Sequential):
+ """ A graph based Sequential module
+ """
+ @classmethod
+ def from_torch(cls, network: nn.Sequential):
+ return SelSequential(*map(transform_network, network))
+
+
+class SelDropout(SelModule):
+ """ A graph based dropout module
+ """
+ def forward(self, inputs):
+ return inputs
+
+ @classmethod
+ def from_torch(cls, network):
+ return SelDropout()
+
+def sel_binlinear_interp(
+ inputs: GraphTracker,
+ up_or_down: Literal["up", "down"]="up",
+ ) -> GraphTracker:
+ """ Performs bilinear interpolation as a single cluster step
+
+ Parameters
+ ----------
+ - inputs: the input graph
+ - up_or_down: either "up" or "down" indicating if it is upsampling or downsampling
+
+ Returns
+ -------
+ - the interpolated graph
+ """
+ supported_up_or_downs = ("up", "down")
+ if up_or_down not in supported_up_or_downs:
+ raise ValueError(f"up_or_down must either be 'up' or 'down' not: {up_or_down}")
+ ret = inputs.from_x(inputs.x)
+ dx = -1 if up_or_down == "up" else 1
+ ret.level += dx
+ cluster = ret.cluster()
+ up_edge_index = ret.edge_index()
+
+ #up_selections = ret.selections()
+ #ret.x = P.unpoolBilinear(ret.x, cluster, up_edge_index, up_selections)
+
+ up_interps = ret.interps()
+ ret.x = P.unpoolInterpolated(ret.x,cluster,up_edge_index,up_interps)
+
+ #ret.x = P.unpoolCluster(inputs.x, inputs.clusters[inputs.cluster_id])
+ return ret
+
+
+def sel_interpolate(
+ inputs: GraphTracker,
+ target_level: int,
+ ) -> GraphTracker:
+ """ interpolates a graph to a given cluster_id
+
+ Parameters
+ ----------
+ - inputs: the input graph data
+ - target_cluster_id: the target cluster
+
+ Returns
+ -------
+ - the interpolated graph
+ """
+ up_or_down = "up" if target_level < inputs.level else "down"
+ while inputs.level != target_level:
+ inputs = sel_binlinear_interp(inputs, up_or_down)
+ return inputs
+
+
+class SelSimpleSegmentationModel(SelModule):
+ """ A graph version of the simple segmentation model defined in torchvision's segmentation model. This is needed since the interpolate function we use needs different parameters than what is used in torch.
+ """
+ __constants__ = ["aux_classifier"]
+ def __init__(self, backbone, classifier, aux_classifier = None):
+ super().__init__()
+ self.backbone = backbone
+ self.classifier = classifier
+ self.aux_classifier = aux_classifier
+
+ def forward(self, x: GraphTracker) -> Dict[str, GraphTracker]:
+ starting_level = x.level
+ features = self.backbone(x)
+ result = OrderedDict()
+ x = features["out"]
+ x = self.classifier(x)
+ x = sel_interpolate(x, starting_level)
+ result["out"] = x
+
+ if self.aux_classifier is not None:
+ x = features["aux"]
+ x = self.aux_classifier(x)
+ x = sel_interpolate(x, starting_level)
+ result["aux"] = x
+ return result
+
+ @classmethod
+ def from_torch(cls, network):
+ ret = SelSimpleSegmentationModel(
+ backbone = transform_network(network.backbone),
+ classifier = transform_network(network.classifier),
+ aux_classifier=transform_network(network.aux_classifier) if network.aux_classifier is not None else None,
+ )
+ return ret
+
+
+__MAPPING = {
+ nn.Conv2d: SelConv,
+ nn.BatchNorm2d: SelBatchNorm,
+ nn.ReLU: SelReLU,
+ nn.Sequential: SelSequential,
+ nn.Dropout: SelDropout,
+ nn.MaxPool2d: SelMaxPool,
+ FCN: SelSimpleSegmentationModel,
+}
diff --git a/mesh_config.py b/mesh_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7c236bb9f3399c6452dc7852e9e7aa8303f5d8f
--- /dev/null
+++ b/mesh_config.py
@@ -0,0 +1,54 @@
+
+mesh_info = {}
+
+mesh_info["teddy"] = {
+ "load_directory":"mesh_data/original/teddy_mesh/",
+ "mesh_fn":"teddy.obj",
+ "texture_fn":"teddy.png",
+ "save_directory":"mesh_data/edited/teddy_mesh/"
+ }
+
+mesh_info["lime"] = {
+ "load_directory":"mesh_data/original/lime_mesh/",
+ "mesh_fn":"lime.obj",
+ "texture_fn":"lime.jpg",
+ "save_directory":"mesh_data/edited/lime_mesh/"
+ }
+
+mesh_info["crate"] = {
+ "load_directory":"mesh_data/original/crate_mesh/",
+ "mesh_fn":"wooden_crate_01_4k.gltf",
+ "texture_fn":"textures/wooden_crate_01_diff_4k.jpg",
+ "save_directory":"mesh_data/edited/crate_mesh/"
+ }
+
+mesh_info["barrel"] = {
+ "load_directory":"mesh_data/original/barrel_mesh/",
+ "mesh_fn":"barrel.obj",
+ #"mesh_fn":"barrel_03_4k.gltf",
+ "texture_fn":"textures/barrel_03_diff_4k.jpg",
+ "save_directory":"mesh_data/edited/barrel_mesh/"
+ }
+
+mesh_info["hammer"] = {
+ "load_directory":"mesh_data/original/hammer_mesh/",
+ "mesh_fn":"hammer.obj",
+ #"mesh_fn":"wooden_hammer_01_2k.gltf",
+ "texture_fn":"textures/wooden_hammer_01_diff_2k.jpg",
+ "save_directory":"mesh_data/edited/hammer_mesh/"
+ }
+
+mesh_info["crow"] = {
+ "load_directory":"mesh_data/original/crow_mesh/",
+ "mesh_fn":"Kruk.obj",
+ #"mesh_fn":"wooden_hammer_01_2k.gltf",
+ "texture_fn":"kruk-mat_Diffuse.jpg",
+ "save_directory":"mesh_data/edited/crow_mesh/"
+ }
+
+mesh_info["horse"] = {
+ "load_directory":"mesh_data/original/horse_mesh/",
+ "mesh_fn":"horse.obj",
+ "texture_fn":"teddy.png",
+ "save_directory":"mesh_data/edited/horse_mesh/"
+ }
\ No newline at end of file
diff --git a/mesh_helpers.py b/mesh_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a55bde9f304b4ec89e88d332bb1bf248e7930723
--- /dev/null
+++ b/mesh_helpers.py
@@ -0,0 +1,93 @@
+import trimesh
+import torch
+import numpy as np
+
+def loadMesh(mesh_fn):
+ mesh = trimesh.load(mesh_fn, force="mesh")
+ #mesh = trimesh.load_mesh(mesh_fn)
+ return mesh
+
+def getUVs(mesh):
+ if (isinstance(mesh.visual,trimesh.visual.color.ColorVisuals)):
+ uvs = mesh.visual.to_texture().uv
+ else:
+ uvs = mesh.visual.uv
+
+ return uvs
+
+def setTexture(mesh,texture):
+ from PIL import Image
+ if texture.dtype == np.float32 or texture.dtype == np.float64:
+ texture = (255*texture).astype(np.uint8)
+ im = Image.fromarray(texture)
+ tex = trimesh.visual.TextureVisuals(uv=getUVs(mesh),image=im)
+ new_mesh=trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, visual=tex, validate=True, process=True)
+ return new_mesh
+
+def sampleSurface(mesh,N,return_x=False,device='cpu'):
+ result = trimesh.sample.sample_surface(mesh,N,face_weight=None,sample_color=return_x)
+ pos3D = result[0]
+ face_indexes = result[1]
+ if return_x:
+ x = result[2][:,:3]/255 # No alpha channels, between 0-1
+
+ # Determine normals (the same as face normals)
+ normals = mesh.face_normals[face_indexes]
+
+ if return_x:
+
+ # Determine UVs associated with 3D points
+ vertices_all = mesh.vertices
+ uvs_all = getUVs(mesh)
+ faces = mesh.faces[face_indexes]
+
+ a = vertices_all[faces[:,0]]
+ b = vertices_all[faces[:,1]]
+ c = vertices_all[faces[:,2]]
+
+ w0, w1, w2 = getBarycentricWeights(pos3D,a,b,c,use_numpy=True)
+
+ w0 = np.expand_dims(w0,axis=1)
+ w1 = np.expand_dims(w1,axis=1)
+ w2 = np.expand_dims(w2,axis=1)
+
+ uvs = w0 * uvs_all[faces[:,0]] + w1 * uvs_all[faces[:,1]] + w2 * uvs_all[faces[:,2]]
+
+ uvs = torch.tensor(uvs,dtype=torch.float).to(device)
+
+ # Return as tensor
+ pos3D = torch.tensor(pos3D,dtype=torch.float).to(device)
+ normals = torch.tensor(normals,dtype=torch.float).to(device)
+
+ if return_x:
+ x = torch.tensor(x,dtype=torch.float).to(device)
+ return pos3D,normals,uvs,x
+ else:
+ return pos3D,normals
+
+def getBarycentricWeights(p,a,b,c,use_numpy=False):
+
+ # Taken from https://gamedev.stackexchange.com/questions/23743/whats-the-most-efficient-way-to-find-barycentric-coordinates
+ v0 = b-a
+ v1 = c-a
+ v2 = p-a
+
+ if use_numpy:
+ d00 = np.sum(v0*v0,axis=1)
+ d01 = np.sum(v0*v1,axis=1)
+ d11 = np.sum(v1*v1,axis=1)
+ d20 = np.sum(v2*v0,axis=1)
+ d21 = np.sum(v2*v1,axis=1)
+ else:
+ d00 = torch.sum(v0*v0,dim=1)
+ d01 = torch.sum(v0*v1,dim=1)
+ d11 = torch.sum(v1*v1,dim=1)
+ d20 = torch.sum(v2*v0,dim=1)
+ d21 = torch.sum(v2*v1,dim=1)
+ denom = d00*d11 - d01*d01
+ w1 = (d11 * d20 - d01 * d21)/denom
+ w2 = (d00 * d21 - d01 * d20)/denom
+
+ w0 = 1 - w1 - w2
+
+ return w0,w1,w2
\ No newline at end of file
diff --git a/output.splat b/output.splat
new file mode 100644
index 0000000000000000000000000000000000000000..ce04350c02b3cb30f7add04400d510b8cb32005b
--- /dev/null
+++ b/output.splat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c98911aa0af25cefc959d8db8da26fa4770986b630d26ea39d4c223ab478c45
+size 4149408
diff --git a/plyio.py b/plyio.py
new file mode 100644
index 0000000000000000000000000000000000000000..602271149c33d3eb10b7e6820f7be0e5d81cb219
--- /dev/null
+++ b/plyio.py
@@ -0,0 +1,136 @@
+# Modified from https://antimatter15.com/splat
+
+from plyfile import PlyData
+import numpy as np
+import argparse
+from io import BytesIO
+
+
+def splat_to_numpy(file_path):
+ with open(file_path, 'rb') as f:
+ splat_data = f.read()
+
+ splat_dtype = np.dtype([
+ ('position', np.float32, 3),
+ ('scale', np.float32, 3),
+ ('color', np.uint8, 4),
+ ('rotation', np.uint8, 4)
+ ])
+
+
+ splat_array = np.frombuffer(splat_data, dtype=splat_dtype)
+
+ points = splat_array["position"]
+ scales = splat_array["scale"]
+ rots = (splat_array["rotation"]/255)*2 - 1
+ color = splat_array["color"]/255
+
+
+ return points, scales, rots.astype(np.float32), color.astype(np.float32)
+
+
+def ply_to_numpy(ply_file_path):
+ plydata = PlyData.read(ply_file_path)
+ vert = plydata["vertex"]
+ sorted_indices = np.argsort(
+ -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
+ / (1 + np.exp(-vert["opacity"]))
+ )
+ buffer = BytesIO()
+
+ positions = np.zeros((len(sorted_indices), 3), dtype=np.float32)
+ scales = np.zeros((len(sorted_indices), 3), dtype=np.float32)
+ rots = np.zeros((len(sorted_indices), 4), dtype=np.float32)
+ colors = np.zeros((len(sorted_indices), 4), dtype=np.float32)
+
+ for idx in sorted_indices:
+ v = plydata["vertex"][idx]
+ position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32)
+ scale = np.exp(
+ np.array(
+ [v["scale_0"], v["scale_1"], v["scale_2"]],
+ dtype=np.float32,
+ )
+ )
+ rot = np.array(
+ [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]],
+ dtype=np.float32,
+ )
+ SH_C0 = 0.28209479177387814
+ color = np.array(
+ [
+ 0.5 + SH_C0 * v["f_dc_0"],
+ 0.5 + SH_C0 * v["f_dc_1"],
+ 0.5 + SH_C0 * v["f_dc_2"],
+ 1 / (1 + np.exp(-v["opacity"])),
+ ]
+ )
+
+
+ positions[idx] = position
+ scales[idx] = scale
+ rots[idx] = rot
+ colors[idx] = color
+ return positions, scales, rots, colors
+
+
+
+
+def numpy_to_splat(positions, scales, rots, colors, output_path, file_type):
+ buffer = BytesIO()
+ if file_type == 'ply':
+
+ for idx in range(len(positions)):
+ position = positions[idx]
+ scale = scales[idx]
+ rot = rots[idx]
+ color = colors[idx]
+ buffer.write(position.tobytes())
+ buffer.write(scale.tobytes())
+ buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
+ buffer.write(((rot / np.linalg.norm(rot)) * 128 + 128).clip(0, 255).astype(np.uint8).tobytes()
+ )
+
+ splat_data = buffer.getvalue()
+ with open(output_path, "wb") as f:
+ f.write(splat_data)
+ else:
+ for idx in range(len(positions)):
+ position = positions[idx]
+ scale = scales[idx]
+ rot = rots[idx]
+ color = colors[idx]
+ buffer.write(position.tobytes())
+ buffer.write(scale.tobytes())
+ buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
+ buffer.write(((rot / np.linalg.norm(rot)) * 128 + 128).clip(0, 255).astype(np.uint8).tobytes()
+ )
+
+ splat_data = buffer.getvalue()
+ with open(output_path, "wb") as f:
+ f.write(splat_data)
+
+ return splat_data
+
+
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Convert PLY files to SPLAT format.")
+ parser.add_argument(
+ "input_files", nargs="+", help="The input PLY files to process."
+ )
+ parser.add_argument(
+ "--output", "-o", default="output.splat", help="The output SPLAT file."
+ )
+ args = parser.parse_args()
+ for input_file in args.input_files:
+ print(f"Processing {input_file}...")
+ positions, scales, rotations, colors = ply_to_numpy(input_file)
+
+ numpy_to_splat(positions, scales, rotations, colors, args.output)
+
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/pointCloudToMesh.py b/pointCloudToMesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..6015a03c30c1014435e3095b1ea4445ae85dc5c2
--- /dev/null
+++ b/pointCloudToMesh.py
@@ -0,0 +1,201 @@
+import numpy as np
+import open3d as o3d
+from skimage import measure
+from scipy.spatial import cKDTree
+import splat_helpers as splt
+
+# based on code from https://towardsdatascience.com/transform-point-clouds-into-3d-meshes-a-python-guide-8b0407a780e6
+# credit Florent Poux
+# Towards Data Science (2024)
+
+def MarchingCubes_from_ply(dataset, voxel_size, iso_level_percentile):
+
+ pcd = o3d.io.read_point_cloud(dataset)
+
+ # Convert Open3D point cloud to numpy array
+ points = np.asarray(pcd.points)
+ # Compute the bounds of the point cloud
+ mins = np.min(points, axis=0)
+ maxs = np.max(points, axis=0)
+ # Create a 3D grid
+ x = np.arange(mins[0], maxs[0], voxel_size)
+ y = np.arange(mins[1], maxs[1], voxel_size)
+ z = np.arange(mins[2], maxs[2], voxel_size)
+ x, y, z = np.meshgrid(x, y, z, indexing='ij')
+
+ # Create a KD-tree for efficient nearest neighbor search
+ tree = cKDTree(points)
+
+ # Compute the scalar field (distance to nearest point)
+ grid_points = np.vstack([x.ravel(), y.ravel(), z.ravel()]).T
+ distances, _ = tree.query(grid_points)
+ scalar_field = distances.reshape(x.shape)
+
+ # Determine iso-level based on percentile of distances
+ iso_level = np.percentile(distances, iso_level_percentile)
+
+ # Apply Marching Cubes
+ verts, faces, _, _ = measure.marching_cubes(scalar_field, level=iso_level)
+
+ # Scale and translate vertices back to original coordinate system
+ verts = verts * voxel_size + mins
+ # Create mesh
+ mesh = o3d.geometry.TriangleMesh()
+ mesh.vertices = o3d.utility.Vector3dVector(verts)
+ mesh.triangles = o3d.utility.Vector3iVector(faces)
+ # Compute vertex normals
+ mesh.compute_vertex_normals()
+ # Visualize the result
+ o3d.visualization.draw_geometries([mesh], mesh_show_back_face=True)
+
+
+
+def MarchingCubes_with_filtering(dataset, voxel_size, iso_level_percentile, threshold=99, out_file = 'out.obj'):
+
+ pos3D, _, _, _, _, _, _ = splt.splat_unpacker_threshold(25, dataset, threshold)
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(pos3D.numpy())
+
+ # Convert Open3D point cloud to numpy array
+ points = np.asarray(pcd.points)
+ # Compute the bounds of the point cloud
+ mins = np.min(points, axis=0)
+ maxs = np.max(points, axis=0)
+ # Create a 3D grid
+ x = np.arange(mins[0], maxs[0], voxel_size)
+ y = np.arange(mins[1], maxs[1], voxel_size)
+ z = np.arange(mins[2], maxs[2], voxel_size)
+ x, y, z = np.meshgrid(x, y, z, indexing='ij')
+
+ # Create a KD-tree for efficient nearest neighbor search
+ tree = cKDTree(points)
+
+ # Compute the scalar field (distance to nearest point)
+ grid_points = np.vstack([x.ravel(), y.ravel(), z.ravel()]).T
+ distances, _ = tree.query(grid_points)
+ scalar_field = distances.reshape(x.shape)
+
+ # Determine iso-level based on percentile of distances
+ iso_level = np.percentile(distances, iso_level_percentile)
+
+ # Apply Marching Cubes
+ verts, faces, _, _ = measure.marching_cubes(scalar_field, level=iso_level)
+
+ # Scale and translate vertices back to original coordinate system
+ verts = verts * voxel_size + mins
+ # Create mesh
+ mesh = o3d.geometry.TriangleMesh()
+ mesh.vertices = o3d.utility.Vector3dVector(verts)
+ mesh.triangles = o3d.utility.Vector3iVector(faces)
+ # Compute vertex normals
+ mesh.compute_vertex_normals()
+ # Save the result
+ o3d.io.write_triangle_mesh(out_file, mesh)
+ # Visualize the result
+ o3d.visualization.draw_geometries([mesh], mesh_show_back_face=True)
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(np.asarray(mesh.vertices))
+ pcd.estimate_normals()
+ o3d.visualization.draw_geometries([pcd])
+
+
+
+def MarchingCubes_return_vertices(dataset, visualize = False):
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(dataset.numpy())
+ voxel_size_tensor = (abs(dataset.min()) + abs(dataset.max()))/100
+ voxel_size = voxel_size_tensor.item()
+ iso_level_percentile = 5
+
+ # Convert Open3D point cloud to numpy array
+ points = np.asarray(pcd.points)
+ # Compute the bounds of the point cloud
+ mins = np.min(points, axis=0)
+ maxs = np.max(points, axis=0)
+ # Create a 3D grid
+ x = np.arange(mins[0], maxs[0], voxel_size)
+ y = np.arange(mins[1], maxs[1], voxel_size)
+ z = np.arange(mins[2], maxs[2], voxel_size)
+ x, y, z = np.meshgrid(x, y, z, indexing='ij')
+
+ # Create a KD-tree for efficient nearest neighbor search
+ tree = cKDTree(points)
+
+ # Compute the scalar field (distance to nearest point)
+ grid_points = np.vstack([x.ravel(), y.ravel(), z.ravel()]).T
+ distances, _ = tree.query(grid_points)
+ scalar_field = distances.reshape(x.shape)
+
+ # Determine iso-level based on percentile of distances
+ iso_level = np.percentile(distances, iso_level_percentile)
+
+ # Apply Marching Cubes
+ verts, faces, _, _ = measure.marching_cubes(scalar_field, level=iso_level)
+
+ # Scale and translate vertices back to original coordinate system
+ verts = verts * voxel_size + mins
+ # Create mesh
+ mesh = o3d.geometry.TriangleMesh()
+ mesh.vertices = o3d.utility.Vector3dVector(verts)
+ mesh.triangles = o3d.utility.Vector3iVector(faces)
+ # Compute vertex normals
+ mesh.compute_vertex_normals()
+ # Save the result
+ #o3d.io.write_triangle_mesh(out_file, mesh)
+
+ pcd2 = o3d.geometry.PointCloud()
+ pcd2.points = o3d.utility.Vector3dVector(np.asarray(mesh.vertices))
+ pcd2.estimate_normals()
+ if visualize == True:
+ # Visualize the result
+ o3d.visualization.draw_geometries([mesh], mesh_show_back_face=True)
+ o3d.visualization.draw_geometries([pcd2])
+
+ vertices = np.asarray(pcd2.points, dtype= np.float32)
+ return vertices
+
+def graph_Points(points, colors):
+ import matplotlib.pyplot as plt
+
+ x = points[:,0]
+ y = points[:,1]
+ z = points[:,2]
+
+
+ # 3D Plot
+ fig = plt.figure()
+ ax = fig.add_subplot(projection='3d')
+
+ # Scatter plot with colors
+ ax.scatter(x, y, z, c=colors)
+
+ # Set labels for axes
+ ax.set_xlabel('X')
+ ax.set_ylabel('Y')
+ ax.set_zlabel('Z')
+
+ # Display the plot
+ plt.show()
+
+ return
+
+def Estimate_Normals(points, n = 25, threshold=75):
+
+ #pos3D, _, _, _, _ = splt.splat_unpacker_threshold(n, dataset, threshold)
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points.numpy())
+
+ pcd.estimate_normals()
+ pcd.normalize_normals()
+ normals = np.asarray(pcd.normals)
+
+ #scaleSize = normals.max()
+ #normalsBad = np.random.rand(normals.shape[0],normals.shape[1])
+ #normalsBad = normalsBad * scaleSize
+ #return normalsBad
+
+ return normals
+
+
+
\ No newline at end of file
diff --git a/pooling.py b/pooling.py
new file mode 100644
index 0000000000000000000000000000000000000000..17401e3e42bbc521e331c038a0403e02f3069787
--- /dev/null
+++ b/pooling.py
@@ -0,0 +1,143 @@
+import torch
+from torch_scatter import scatter
+
+def avgPoolKernel(x,edge_index,selections,cluster,kernel_size=2,even_dirs=[0,1,7,8]):
+
+ is_even = kernel_size % 2 == 0
+ full_passes = kernel_size//2 - int(is_even)
+
+ # Assumes the lowest number node index is the topleft most position in the cluster
+ indices = torch.arange(len(x)).to(x.device)
+
+ # Find the minimum node index in each cluster and select those x values
+ picks = scatter(indices, cluster, dim=0, reduce='min')
+
+ # Send max pool messages the appropriate number of times
+ for _ in range(full_passes):
+ message = x[edge_index[1]]
+ x = scatter(message,edge_index[0],dim=0,reduce='mean') # Aggregate
+
+ # Even kernel_sizes are not symetric and are lopsided towards the bottom right corner
+ # Repeat the process one more time going to the bottom right
+ if is_even:
+
+ # Prefilter edge_index
+ keep = torch.zeros_like(selections,dtype=torch.bool).to(x.device)
+ for i in even_dirs:
+ keep[torch.where(selections == i)] = True
+ even_edge_index = edge_index[:,torch.where(keep)[0]]
+
+ message = x[even_edge_index[1]]
+ x = scatter(message,even_edge_index[0],dim=0,reduce='mean') # Aggregate
+
+ # Take the previously selected nodes
+ x = x[picks]
+
+ return x
+
+def maxPoolKernel(x,edge_index,selections,cluster,kernel_size=2,even_dirs=[0,1,7,8]):
+
+ is_even = kernel_size % 2 == 0
+ full_passes = kernel_size//2 - int(is_even)
+
+ # Assumes the lowest number node index is the topleft most position in the cluster
+ indices = torch.arange(len(x)).to(x.device)
+
+ # Find the minimum node index in each cluster and select those x values
+ picks = scatter(indices, cluster, dim=0, reduce='min')
+
+ # Send max pool messages the appropriate number of times
+ for _ in range(full_passes):
+ message = x[edge_index[1]]
+ x = scatter(message,edge_index[0],dim=0,reduce='max') # Aggregate
+
+ # Even kernel_sizes are not symetric and are lopsided towards the bottom right corner
+ # Repeat the process one more time going to the bottom right
+ if is_even:
+
+ # Prefilter edge_index
+ keep = torch.zeros_like(selections,dtype=torch.bool).to(x.device)
+ for i in even_dirs:
+ keep[torch.where(selections == i)] = True
+ even_edge_index = edge_index[:,torch.where(keep)[0]]
+
+ message = x[even_edge_index[1]]
+ x = scatter(message,even_edge_index[0],dim=0,reduce='max') # Aggregate
+
+ # Take the previously selected nodes
+ x = x[picks]
+
+ return x
+
+
+def stridePoolCluster(x,cluster):
+
+ # Assumes the lowest number node index is the topleft most position in the cluster
+ indices = torch.arange(len(x)).to(x.device)
+
+ # Find the minimum node index in each cluster and select those x values
+ picks = scatter(indices, cluster, dim=0, reduce='min')
+ x = x[picks]
+
+ return x
+
+
+def maxPoolCluster(x,cluster):
+
+ x = scatter(x, cluster, dim=0, reduce='max')
+ return x
+
+def avgPoolCluster(x,cluster,edge_index=None, edge_weight=None):
+
+ x = scatter(x, cluster, dim=0, reduce='mean')
+ return x
+
+
+def unpoolInterpolated(x,cluster,up_edge_index,up_interps=None):
+
+ if up_interps is None:
+ return unpoolEdgeAverage(x,cluster,up_edge_index)
+
+ # Determine node averages based on based on interps
+ target_clusters = cluster[up_edge_index[1]]
+
+ node_vals = x[target_clusters]*up_interps.unsqueeze(1)
+ x = scatter(node_vals,up_edge_index[0],dim=0,reduce='add')
+ norm = scatter(up_interps,up_edge_index[0],dim=0)
+ x/=norm.unsqueeze(1)
+
+ return x
+
+def unpoolBilinear(x,cluster,up_edge_index,up_selections,selection_dirs=[0,1,7,8]):
+
+ # Remove edges that won't be used for the bilinear interpolation calculation
+ keep = torch.zeros_like(up_selections,dtype=torch.bool).to(x.device)
+ for i in selection_dirs:
+ keep[torch.where(up_selections == i)] = True
+
+ ref_edge_index = up_edge_index[:,torch.where(keep)[0]]
+
+ cluster_index = torch.vstack((ref_edge_index[0],cluster[ref_edge_index[1]]))
+ cluster_index = torch.unique(cluster_index,dim=1)
+ x = scatter(x[cluster_index[1]],cluster_index[0],dim=0,reduce='mean')
+
+ return x
+
+def unpoolEdgeAverage(x,cluster,up_edge_index,weighted=True):
+ # Interpolates based on the number of connections to each previous cluster. Works best with dense data.
+ # If weighted = False, clusters are weighted equally regardless of the number of connections
+
+ if weighted:
+ target_clusters = cluster[up_edge_index[1]]
+ x = scatter(x[target_clusters],up_edge_index[0],dim=0,reduce='mean')
+
+ else:
+ cluster_index = torch.vstack((up_edge_index[0],cluster[up_edge_index[1]]))
+ cluster_index = torch.unique(cluster_index,dim=1)
+ x = scatter(x[cluster_index[1]],cluster_index[0],dim=0,reduce='mean')
+
+ return x
+
+def unpoolCluster(x,cluster):
+
+ return x[cluster]
diff --git a/pyntcloud_io.py b/pyntcloud_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c51e45c6f1656f27a4de6a43318be0a5c47807d
--- /dev/null
+++ b/pyntcloud_io.py
@@ -0,0 +1,327 @@
+#
+#CREDIT TO https://github.com/daavoo/pyntcloud/blob/master/pyntcloud/io/ply.py
+#
+#User daavoo
+#
+
+# HAKUNA MATATA
+
+import sys
+import numpy as np
+import pandas as pd
+from collections import defaultdict
+
+sys_byteorder = ('>', '<')[sys.byteorder == 'little']
+
+ply_dtypes = dict([
+ (b'int8', 'i1'),
+ (b'char', 'i1'),
+ (b'uint8', 'u1'),
+ (b'uchar', 'b1'),
+ (b'uchar', 'u1'),
+ (b'int16', 'i2'),
+ (b'short', 'i2'),
+ (b'uint16', 'u2'),
+ (b'ushort', 'u2'),
+ (b'int32', 'i4'),
+ (b'int', 'i4'),
+ (b'uint32', 'u4'),
+ (b'uint', 'u4'),
+ (b'float32', 'f4'),
+ (b'float', 'f4'),
+ (b'float64', 'f8'),
+ (b'double', 'f8')
+])
+
+valid_formats = {'ascii': '', 'binary_big_endian': '>',
+ 'binary_little_endian': '<'}
+
+
+def read_ply(filename, allow_bool=False):
+ """ Read a .ply (binary or ascii) file and store the elements in pandas DataFrame.
+
+ Parameters
+ ----------
+ filename: str
+ Path to the filename
+ allow_bool: bool
+ flag to allow bool as a valid PLY dtype. False by default to mirror original PLY specification.
+
+ Returns
+ -------
+ data: dict
+ Elements as pandas DataFrames; comments and ob_info as list of string
+ """
+ if allow_bool:
+ ply_dtypes[b'bool'] = '?'
+
+ with open(filename, 'rb') as ply:
+
+ if b'ply' not in ply.readline():
+ raise ValueError('The file does not start with the word ply')
+ # get binary_little/big or ascii
+ fmt = ply.readline().split()[1].decode()
+ # get extension for building the numpy dtypes
+ ext = valid_formats[fmt]
+
+ line = []
+ dtypes = defaultdict(list)
+ count = 2
+ points_size = None
+ mesh_size = None
+ has_texture = False
+ comments = []
+ while b'end_header' not in line and line != b'':
+ line = ply.readline()
+
+ if b'element' in line:
+ line = line.split()
+ name = line[1].decode()
+ size = int(line[2])
+ if name == "vertex":
+ points_size = size
+ elif name == "face":
+ mesh_size = size
+
+ elif b'property' in line:
+ line = line.split()
+ # element mesh
+ if b'list' in line:
+
+ if b"vertex_indices" in line[-1] or b"vertex_index" in line[-1]:
+ mesh_names = ["n_points", "v1", "v2", "v3"]
+ else:
+ has_texture = True
+ mesh_names = ["n_coords"] + ["v1_u", "v1_v", "v2_u", "v2_v", "v3_u", "v3_v"]
+
+ if fmt == "ascii":
+ # the first number has different dtype than the list
+ dtypes[name].append(
+ (mesh_names[0], ply_dtypes[line[2]]))
+ # rest of the numbers have the same dtype
+ dt = ply_dtypes[line[3]]
+ else:
+ # the first number has different dtype than the list
+ dtypes[name].append(
+ (mesh_names[0], ext + ply_dtypes[line[2]]))
+ # rest of the numbers have the same dtype
+ dt = ext + ply_dtypes[line[3]]
+
+ for j in range(1, len(mesh_names)):
+ dtypes[name].append((mesh_names[j], dt))
+ else:
+ if fmt == "ascii":
+ dtypes[name].append(
+ (line[2].decode(), ply_dtypes[line[1]]))
+ else:
+ dtypes[name].append(
+ (line[2].decode(), ext + ply_dtypes[line[1]]))
+
+ elif b'comment' in line:
+ line = line.split(b" ", 1)
+ comment = line[1].decode().rstrip()
+ comments.append(comment)
+
+ count += 1
+
+ # for bin
+ end_header = ply.tell()
+
+ data = {}
+
+ if comments:
+ data["comments"] = comments
+
+ if fmt == 'ascii':
+ top = count
+ bottom = 0 if mesh_size is None else mesh_size
+
+ names = [x[0] for x in dtypes["vertex"]]
+
+ data["points"] = pd.read_csv(filename, sep=" ", header=None, engine="python",
+ skiprows=top, skipfooter=bottom, usecols=names, names=names)
+
+ for n, col in enumerate(data["points"].columns):
+ data["points"][col] = data["points"][col].astype(
+ dtypes["vertex"][n][1])
+
+ if mesh_size :
+ top = count + points_size
+
+ names = np.array([x[0] for x in dtypes["face"]])
+ usecols = [1, 2, 3, 5, 6, 7, 8, 9, 10] if has_texture else [1, 2, 3]
+ names = names[usecols]
+
+ data["mesh"] = pd.read_csv(
+ filename, sep=" ", header=None, engine="python", skiprows=top, usecols=usecols, names=names)
+
+ for n, col in enumerate(data["mesh"].columns):
+ data["mesh"][col] = data["mesh"][col].astype(
+ dtypes["face"][n + 1][1])
+
+ else:
+ with open(filename, 'rb') as ply:
+ ply.seek(end_header)
+ points_np = np.fromfile(ply, dtype=dtypes["vertex"], count=points_size)
+ if ext != sys_byteorder:
+ points_np = points_np.byteswap().newbyteorder()
+ data["points"] = pd.DataFrame(points_np)
+ if mesh_size:
+ mesh_np = np.fromfile(ply, dtype=dtypes["face"], count=mesh_size)
+ if ext != sys_byteorder:
+ mesh_np = mesh_np.byteswap().newbyteorder()
+ data["mesh"] = pd.DataFrame(mesh_np)
+ data["mesh"].drop('n_points', axis=1, inplace=True)
+
+ return data
+
+
+def write_ply(filename, points=None, mesh=None, as_text=False, comments=None):
+ """Write a PLY file populated with the given fields.
+
+ Parameters
+ ----------
+ filename: str
+ The created file will be named with this
+ points: ndarray
+ mesh: ndarray
+ as_text: boolean
+ Set the write mode of the file. Default: binary
+ comments: list of string
+
+ Returns
+ -------
+ boolean
+ True if no problems
+
+ """
+ if not filename.endswith('ply'):
+ filename += '.ply'
+
+ # open in text mode to write the header
+ with open(filename, 'w') as ply:
+ header = ['ply']
+
+ if as_text:
+ header.append('format ascii 1.0')
+ else:
+ header.append('format binary_' + sys.byteorder + '_endian 1.0')
+
+ if comments:
+ for comment in comments:
+ header.append('comment ' + comment)
+
+ if points is not None:
+ header.extend(describe_element('vertex', points))
+ if mesh is not None:
+ mesh = mesh.copy()
+ mesh.insert(loc=0, column="n_points", value=3)
+ mesh["n_points"] = mesh["n_points"].astype("u1")
+ header.extend(describe_element('face', mesh))
+
+ header.append('end_header')
+
+ for line in header:
+ ply.write("%s\n" % line)
+
+ if as_text:
+ if points is not None:
+ points.to_csv(filename, sep=" ", index=False, header=False, mode='a',
+ encoding='ascii')
+ if mesh is not None:
+ mesh.to_csv(filename, sep=" ", index=False, header=False, mode='a',
+ encoding='ascii')
+
+ else:
+ with open(filename, 'ab') as ply:
+ if points is not None:
+ points.to_records(index=False).tofile(ply)
+ if mesh is not None:
+ mesh.to_records(index=False).tofile(ply)
+
+ return True
+
+
+def describe_element(name, df):
+ """ Takes the columns of the dataframe and builds a ply-like description
+
+ Parameters
+ ----------
+ name: str
+ df: pandas DataFrame
+
+ Returns
+ -------
+ element: list[str]
+ """
+ property_formats = {'f': 'float', 'u': 'uchar', 'i': 'int', 'b': 'bool'}
+ element = ['element ' + name + ' ' + str(len(df))]
+
+ if name == 'face':
+ element.append("property list uchar int vertex_indices")
+
+ else:
+ for i in range(len(df.columns)):
+ # get first letter of dtype to infer format
+ f = property_formats[str(df.dtypes[i])[0]]
+ element.append('property ' + f + ' ' + df.columns.values[i])
+
+ return element
+
+
+
+
+
+def write_ply_float(filename, points=None, mesh=None, as_text=False, comments=None):
+ """Write a PLY file populated with the given fields.
+
+ Parameters
+ ----------
+ filename: str
+ The created file will be named with this
+ points: ndarray
+ mesh: ndarray
+ as_text: boolean
+ Set the write mode of the file. Default: binary
+ comments: list of string
+
+ Returns
+ -------
+ boolean
+ True if no problems
+
+ """
+ if not filename.endswith('ply'):
+ filename += '.ply'
+
+ # open in text mode to write the header
+ with open(filename, 'w') as ply:
+ header = ['ply']
+
+ #header.append('format ascii 1.0')
+ header.append('format binary_' + sys.byteorder + '_endian 1.0')
+
+ if comments:
+ for comment in comments:
+ header.append('comment ' + comment)
+
+ if points is not None:
+ header.extend(describe_element('vertex', points))
+ if mesh is not None:
+ mesh = mesh.copy()
+ mesh.insert(loc=0, column="n_points", value=3)
+ mesh["n_points"] = mesh["n_points"].astype("u1")
+ header.extend(describe_element('face', mesh))
+
+ header.append('end_header')
+
+ for line in header:
+ ply.write("%s\n" % line)
+
+ with open(filename, 'ab') as ply:
+ if points is not None:
+ pointsNumPy = points.to_numpy()
+ pointsNumPy.tofile(filename,format=' Tensor:
+ """"""
+
+ all_nodes = torch.arange(x.shape[0]).to(x.device)
+
+ if self.padding_mode == 'constant':
+ # Constant value of the average of the all the nodes
+ x_mean = torch.mean(x,dim=0)
+
+ out = torch.zeros((x.shape[0],self.out_channels)).to(x.device)
+
+ if self.padding_mode == 'normalize':
+ dir_count = torch.zeros((x.shape[0],1)).to(x.device)
+
+ if self.kernel_size == 1 or self.kernel_size == 3:
+
+ # Find the appropriate node for each selection by stepping through connecting edges
+ for s in range(self.selection_count):
+ cur_dir = torch.where(selections == s)[0]
+
+ cur_source = edge_index[0,cur_dir]
+ cur_target = edge_index[1,cur_dir]
+
+ if interps is not None:
+ cur_interps = interps[cur_dir]
+ cur_interps = torch.unsqueeze(cur_interps,dim=1)
+ #print(torch.amin(cur_interps),torch.amax(cur_interps))
+
+ if self.dilation > 1:
+ for _ in range(1, self.dilation):
+ vals, ind1, ind2 = intersect1d(cur_target,edge_index[0,cur_dir])
+ cur_source = cur_source[ind1]
+ cur_target = edge_index[1,cur_dir][ind2]
+ if interps is not None:
+ cur_interps = cur_interps[ind1]
+
+ # Main Calculation
+ if interps is None:
+ #out[cur_source] += torch.matmul(x[cur_target], self.weight[s])
+ result = torch.matmul(x[cur_target], self.weight[s])
+ else:
+ #out[cur_source] += cur_interps*torch.matmul(x[cur_target], self.weight[s])
+ result = cur_interps*torch.matmul(x[cur_target], self.weight[s])
+
+ # Adding with duplicate indices
+ out.index_add_(0,cur_source,result)
+
+ # Sanity check
+ #from tqdm import tqdm
+ #for i,node in enumerate(tqdm(cur_source)):
+ # out[node] += result[i]
+
+ if self.padding_mode == 'constant':
+ missed_nodes = setdiff1d(all_nodes, cur_source)
+ #out[missed_nodes] += torch.matmul(x_mean, self.weight[s])
+ out.index_add_(0,missed_nodes,torch.matmul(x_mean, self.weight[s]))
+
+ if self.padding_mode == 'replicate':
+ missed_nodes = setdiff1d(all_nodes, cur_source)
+ #out[missed_nodes] += torch.matmul(x[missed_nodes], self.weight[s])
+ out.index_add_(0,missed_nodes,torch.matmul(x[missed_nodes], self.weight[s]))
+
+ if self.padding_mode == 'reflect':
+ missed_nodes = setdiff1d(all_nodes, cur_source)
+
+ opposite = s+4
+ if opposite > 8:
+ opposite = opposite % 9 + 1
+
+ op_dir = torch.where(selections == opposite)[0]
+
+ op_source = edge_index[0,op_dir]
+ op_target = edge_index[1,op_dir]
+
+ if interps is not None:
+ op_interps = interps[op_dir]
+ op_interps = torch.unsqueeze(op_interps,dim=1)
+
+
+ # Only take edges that are part of missed nodes
+ vals, ind1, ind2 = intersect1d(op_source,missed_nodes)
+ op_source = op_source[ind1]
+ op_target = op_target[ind1]
+ if interps is not None:
+ op_interps = op_interps[ind1]
+
+ if self.dilation > 1:
+ for _ in range(1, self.dilation):
+ vals, ind1, ind2 = intersect1d(op_target,edge_index[0,op_dir])
+ op_source = op_source[ind1]
+ op_target = edge_index[1,op_dir][ind2]
+ if interps is not None:
+ op_interps = op_interps[ind1]
+
+ # Main Calculation
+ if interps is None:
+ result = torch.matmul(x[op_target], self.weight[s])
+ else:
+ result = op_interps * torch.matmul(x[op_target], self.weight[s])
+
+ out.index_add_(0,op_source,result)
+
+ if self.padding_mode == 'normalize':
+ dir_count[torch.unique(cur_source)] += 1
+
+ else:
+ width = self.kernel_size//2
+ horiz = torch.arange(-width,width+1).to(x.device)
+ vert = torch.arange(-width,width+1).to(x.device)
+
+ right = torch.where(selections == 1)[0]
+ left = torch.where(selections == 5)[0]
+ down = torch.where(selections == 7)[0]
+ up = torch.where(selections == 3)[0]
+
+ center = torch.where(selections == 0)[0]
+
+ # Find the appropriate node for each selection by stepping through connecting edges
+ s = 0
+ for i in range(self.kernel_size):
+ for j in range(self.kernel_size):
+ x_loc = horiz[j]
+ y_loc = vert[i]
+
+ cur_source = edge_index[0,center] #Starting location
+ cur_target = edge_index[1,center]
+
+ if interps is not None:
+ cur_interps = interps[center]
+ cur_interps = torch.unsqueeze(cur_interps,dim=1)
+
+ #print(torch.sum(cur_target-cur_source))
+
+ #print(cur_target.shape)
+
+ if x_loc < 0:
+ for _ in range(self.dilation*abs(x_loc)):
+ vals, ind1, ind2 = intersect1d(cur_target,edge_index[0,left])
+ cur_source = cur_source[ind1]
+ cur_target = edge_index[1,left][ind2]
+ if interps is not None:
+ cur_interps = cur_interps[ind1]
+ if x_loc > 0:
+ for _ in range(self.dilation*abs(x_loc)):
+ vals, ind1, ind2 = intersect1d(cur_target,edge_index[0,right])
+ cur_source = cur_source[ind1]
+ cur_target = edge_index[1,right][ind2]
+ if interps is not None:
+ cur_interps = cur_interps[ind1]
+
+ if y_loc < 0:
+ for _ in range(self.dilation*abs(y_loc)):
+ vals, ind1, ind2 = intersect1d(cur_target,edge_index[0,up])
+ cur_source = cur_source[ind1]
+ cur_target = edge_index[1,up][ind2]
+ if interps is not None:
+ cur_interps = cur_interps[ind1]
+
+ if y_loc > 0:
+ for _ in range(self.dilation*abs(y_loc)):
+ vals, ind1, ind2 = intersect1d(cur_target,edge_index[0,down])
+ cur_source = cur_source[ind1]
+ cur_target = edge_index[1,down][ind2]
+ if interps is not None:
+ cur_interps = cur_interps[ind1]
+
+ # Main Calculation
+ if interps is None:
+ #out[cur_source] += torch.matmul(x[cur_target], self.weight[s])
+ result = torch.matmul(x[cur_target], self.weight[s])
+ else:
+ #out[cur_source] += cur_interps*torch.matmul(x[cur_target], self.weight[s])
+ result = cur_interps*torch.matmul(x[cur_target], self.weight[s])
+
+ # Adding with duplicate indices
+ out.index_add_(0,cur_source,result)
+
+ if self.padding_mode == 'constant':
+ missed_nodes = setdiff1d(all_nodes, cur_source)
+ #out[missed_nodes] += torch.matmul(x_mean, self.weight[s])
+ out.index_add_(0,missed_nodes,torch.matmul(x_mean, self.weight[s]))
+
+ if self.padding_mode == 'replicate':
+ missed_nodes = setdiff1d(all_nodes, cur_source)
+ #out[missed_nodes] += torch.matmul(x[missed_nodes], self.weight[s])
+ out.index_add_(0,missed_nodes,torch.matmul(x[missed_nodes], self.weight[s]))
+
+ if self.padding_mode == 'reflect':
+ raise ValueError("Reflect padding not yet implemented for larger kernels")
+
+ if self.padding_mode == 'normalize':
+ dir_count[torch.unique(cur_source)] += 1
+
+ s+=1
+
+ #print(self.selection_count/(dir_count + 1e-8))
+ #test_val = self.selection_count/(dir_count + 1e-8)
+ # print(torch.max(test_val),torch.min(test_val),torch.mean(test_val))
+
+ if self.padding_mode == 'zeros':
+ pass # Already accounted for in the graph structure, no further computation needed
+ elif self.padding_mode == 'normalize':
+ out *= self.selection_count/(dir_count + 1e-8)
+ elif self.padding_mode == 'constant':
+ pass # Processed earlier
+ elif self.padding_mode == 'replicate':
+ pass
+ elif self.padding_mode == 'reflect':
+ pass
+ elif self.padding_mode == 'circular':
+ raise ValueError("Circular padding cannot be generalized on a graph. Instead, create a graph with edges connecting to the wrapped around nodes")
+ else:
+ raise ValueError(f"Unknown padding mode: {self.padding_mode}")
+
+ # Add bias if applicable
+ out += self.bias
+
+ return out
+
+
+ def copy_weightsNxN(self,weight,bias=None):
+
+ width = int(sqrt(self.selection_count))
+
+ # Assumes weight comes in as [output channels, input channels, row, col]
+ for i in range(self.selection_count):
+ self.weight[i] = weight[:,:,i//width,i%width].permute(1,0)
+
+
+ def copy_weights3x3(self,weight,bias=None):
+
+
+ # Assumes weight comes in as [output channels, input channels, row, col]
+ # Assumes weight is a 3x3
+
+ # Current Ordering
+ # 4 3 2
+ # 5 0 1
+ # 6 7 8
+
+ # Need to flip horizontally per implementation of convolution
+ #self.weight[5] = weight[:,:,1,2].permute(1,0)
+ #self.weight[7] = weight[:,:,0,1].permute(1,0)
+ #self.weight[1] = weight[:,:,1,0].permute(1,0)
+ #self.weight[3] = weight[:,:,2,1].permute(1,0)
+ #self.weight[6] = weight[:,:,0,2].permute(1,0)
+ #self.weight[8] = weight[:,:,0,0].permute(1,0)
+ #self.weight[2] = weight[:,:,2,0].permute(1,0)
+ #self.weight[4] = weight[:,:,2,2].permute(1,0)
+ #self.weight[0] = weight[:,:,1,1].permute(1,0)
+
+ self.weight[1] = weight[:,:,1,2].permute(1,0)
+ self.weight[3] = weight[:,:,0,1].permute(1,0)
+ self.weight[5] = weight[:,:,1,0].permute(1,0)
+ self.weight[7] = weight[:,:,2,1].permute(1,0)
+ self.weight[2] = weight[:,:,0,2].permute(1,0)
+ self.weight[4] = weight[:,:,0,0].permute(1,0)
+ self.weight[6] = weight[:,:,2,0].permute(1,0)
+ self.weight[8] = weight[:,:,2,2].permute(1,0)
+ self.weight[0] = weight[:,:,1,1].permute(1,0)
+
+
+ def copy_weights1x1(self, weight, bias=None):
+ self.weight[0] = weight[:,:,0,0].permute(1, 0)
+
+
+ def copy_weights(self,weight,bias=None):
+
+ if self.kernel_size == 3:
+ self.copy_weights3x3(weight,bias)
+ elif self.kernel_size == 1:
+ self.copy_weights1x1(weight, bias)
+ else:
+ self.copy_weightsNxN(weight,bias)
+
+ if bias is None:
+ self.bias[:] = 0.0
+ else:
+ self.bias = bias
+
+
+ def __repr__(self):
+ return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
+ self.out_channels)
+
diff --git a/sphere_helpers.py b/sphere_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..650be180309411fa7e36f1e8b54220c3ab137bfd
--- /dev/null
+++ b/sphere_helpers.py
@@ -0,0 +1,180 @@
+import torch
+import numpy as np
+import graph_helpers as gh
+from math import pi, sqrt, sin
+
+def equirec2spherical(rows, cols, device = 'cpu'):
+ theta_steps = torch.linspace(0, 2*pi, int(cols+1)).to(device)[:-1] # Avoid overlapping points
+ phi_steps = torch.linspace(0, pi, int(rows+1)).to(device)[:-1]
+ theta, phi = torch.meshgrid(theta_steps, phi_steps,indexing='xy')
+ return theta.flatten(),phi.flatten()
+
+def spherical2equirec(theta,phi,rows,cols):
+ x = theta*cols/(2*pi)
+ y = phi*rows/pi
+ return x,y
+
+def spherical2xyz(theta,phi):
+ x, y, z = torch.cos(theta) * torch.sin(phi), torch.cos(phi), torch.sin(theta) * torch.sin(phi),;
+ return x,y,z
+
+def sampleSphere_Equirec(rows,cols):
+ theta, phi = equirec2spherical(rows, cols)
+ x,y,z = spherical2xyz(theta,phi)
+
+ spherical = torch.stack((theta,phi),dim=1)
+ cartesian = torch.stack((x,y,z),dim=1)
+ return cartesian, spherical
+
+def sampleSphere_Layering(rows,cols=None):
+ N_phi = rows
+
+ # Make rows match the number of phi locations
+ N = 4/pi*N_phi*N_phi
+
+ # Alternative method presented in https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf
+ a = 4*pi/N
+ d = sqrt(a)
+ M_phi = int(round(pi/d))
+ d_phi = pi/M_phi
+ d_theta = a/d_phi
+
+ theta_list = []
+ phi_list = []
+
+ for m in range(M_phi):
+ phi = pi*(m+0.5)/M_phi
+ M_theta = int(round(2*pi*sin(phi)/d_theta))
+ for n in range(M_theta):
+ theta = 2*pi*n/M_theta
+
+ phi_list.append(phi)
+ theta_list.append(theta)
+
+ theta = torch.tensor(theta_list,dtype=torch.float)
+ phi = torch.tensor(phi_list,dtype=torch.float)
+ x,y,z = spherical2xyz(theta,phi)
+
+ spherical = torch.stack((theta,phi),dim=1)
+ cartesian = torch.stack((x,y,z),dim=1)
+ return cartesian, spherical
+
+def sampleSphere_Spiral(rows,cols):
+ # Get a sufficient number of points to represent the resolution of the image.
+ N = int(2/pi*cols*rows)
+
+ # Use the Fibonacci Spiral to equally space points
+ # Visualization at https://www.youtube.com/watch?v=Ua0kig6N3po
+ goldenRatio = (1 + 5**0.5)/2
+ i = torch.arange(0, N)
+ theta = (2 * pi * i / goldenRatio) % (2 * pi)
+ phi = torch.arccos(1 - 2*(i+0.5)/N)
+ x,y,z = spherical2xyz(theta,phi)
+
+ spherical = torch.stack((theta,phi),dim=1)
+ cartesian = torch.stack((x,y,z),dim=1)
+ return cartesian, spherical
+
+def sampleSphere_Icosphere(rows,cols=None):
+ num_subdiv = int(np.floor(np.log2(rows)) - 1)
+ cartesian = icosphere(num_subdiv)
+
+ # Spherical coordinates calculator
+ xy = cartesian[:,0]**2 + cartesian[:,1]**2
+ phi = np.arctan2(np.sqrt(xy), cartesian[:,2]) # for elevation angle defined from Z-axis down
+ theta = np.arctan2(cartesian[:,1], cartesian[:,0])
+ theta += np.pi
+
+ # Convert to torch
+ phi = torch.tensor(phi,dtype=torch.float)
+ theta = torch.tensor(theta,dtype=torch.float)
+
+ x,y,z = spherical2xyz(theta,phi) # Redefine for axis consistency
+
+ spherical = torch.stack((theta,phi),dim=1)
+ cartesian = torch.stack((x,y,z),dim=1)
+ return cartesian, spherical
+
+def sampleSphere_Random(rows,cols):
+ # Get a sufficient number of points to represent the resolution of the image.
+ N = int(2/pi*cols*rows)
+
+ # Alternative method presented in https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf
+ z = 2*torch.rand(N) - 1
+ phi = 2*pi*torch.rand(N)
+
+ x = torch.sqrt(1 - z**2)*torch.cos(phi)
+ y = torch.sqrt(1 - z**2)*torch.sin(phi)
+
+ # Spherical coordinates calculator
+ xy = x**2 + y**2
+ phi = torch.atan2(torch.sqrt(xy), z) # for elevation angle defined from Z-axis down
+ theta = torch.atan2(y, x)
+ theta += pi
+
+ x,y,z = spherical2xyz(theta,phi) # Redefine for axis consistency
+
+ spherical = torch.stack((theta,phi),dim=1)
+ cartesian = torch.stack((x,y,z),dim=1)
+ return cartesian, spherical
+
+#### Icosphere Methods #####
+# Code taken from https://yaweiliu.github.io/research_notes/notes/20210301_Creating%20an%20icosphere%20with%20Python.html
+from scipy.spatial.transform import Rotation as R
+
+def vertex(x, y, z):
+ """ Return vertex coordinates fixed to the unit sphere """
+ length = np.sqrt(x**2 + y**2 + z**2)
+ return [i / length for i in (x,y,z)]
+
+def middle_point(verts,middle_point_cache,point_1, point_2):
+ """ Find a middle point and project to the unit sphere """
+ # We check if we have already cut this edge first
+ # to avoid duplicated verts
+ smaller_index = min(point_1, point_2)
+ greater_index = max(point_1, point_2)
+ key = '{0}-{1}'.format(smaller_index, greater_index)
+ if key in middle_point_cache: return middle_point_cache[key]
+ # If it's not in cache, then we can cut it
+ vert_1 = verts[point_1]
+ vert_2 = verts[point_2]
+ middle = [sum(i)/2 for i in zip(vert_1, vert_2)]
+ verts.append(vertex(*middle))
+ index = len(verts) - 1
+ middle_point_cache[key] = index
+ return index
+
+def icosphere(subdiv):
+ # verts for icosahedron
+ r = (1.0 + np.sqrt(5.0)) / 2.0;
+ verts = np.array([[-1.0, r, 0.0],[ 1.0, r, 0.0],[-1.0, -r, 0.0],
+ [1.0, -r, 0.0],[0.0, -1.0, r],[0.0, 1.0, r],
+ [0.0, -1.0, -r],[0.0, 1.0, -r],[r, 0.0, -1.0],
+ [r, 0.0, 1.0],[ -r, 0.0, -1.0],[-r, 0.0, 1.0]]);
+ # rescale the size to radius of 0.5
+ verts /= np.linalg.norm(verts[0])
+ # adjust the orientation
+ r = R.from_quat([[0.19322862,-0.68019314,-0.19322862,0.68019314]])
+ verts = r.apply(verts)
+ verts = list(verts)
+
+ faces = [[0, 11, 5],[0, 5, 1],[0, 1, 7],[0, 7, 10],
+ [0, 10, 11],[1, 5, 9],[5, 11, 4],[11, 10, 2],
+ [10, 7, 6],[7, 1, 8],[3, 9, 4],[3, 4, 2],
+ [3, 2, 6],[3, 6, 8],[3, 8, 9],[5, 4, 9],
+ [2, 4, 11],[6, 2, 10],[8, 6, 7],[9, 8, 1],];
+
+ for i in range(subdiv):
+ middle_point_cache = {}
+ faces_subdiv = []
+ for tri in faces:
+ v1 = middle_point(verts,middle_point_cache,tri[0], tri[1])
+ v2 = middle_point(verts,middle_point_cache,tri[1], tri[2])
+ v3 = middle_point(verts,middle_point_cache,tri[2], tri[0])
+ faces_subdiv.append([tri[0], v1, v3])
+ faces_subdiv.append([tri[1], v2, v1])
+ faces_subdiv.append([tri[2], v3, v2])
+ faces_subdiv.append([v1, v2, v3])
+ faces = faces_subdiv
+
+ return np.array(verts)
diff --git a/splat_helpers.py b/splat_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca901d7655d6c35c26a6611793ef8c221cbfc3d0
--- /dev/null
+++ b/splat_helpers.py
@@ -0,0 +1,312 @@
+import torch
+import pyntcloud_io as plyio
+from torch_geometric.nn import knn
+import numpy as np
+from torch_scatter import scatter_mean
+import plyio as splatio
+import pandas as pd
+import plotly.graph_objects as go
+from sklearn.cluster import KMeans
+import matplotlib.pyplot as plt
+
+def splat_update_and_save(rawdata, results, fileName):
+
+
+ #df1[['f_dc_0','f_dc_1','f_dc_2']] = results
+ #plyNumpy = df1.to_numpy()
+ #plyNumpy.tofile('Scaniverse_chair_testOut2.ply')
+
+ #plyNumpy = df1.to_numpy()
+ #np.savetxt('Scaniverse_chair_test.csv', plyNumpy, delimiter=',')
+
+ #SH_C0 = 0.28209479177387814
+ #normedResults = 0.5 - results/SH_C0
+
+ #df1[['f_dc_0','f_dc_1','f_dc_2']] = normedResults
+
+
+ #plyNumpy = df1.to_numpy()
+ #np.savetxt('Scaniverse_chair_testOut.csv', plyNumpy, delimiter=',')
+ #plyNumpy = df1.to_numpy()
+ #plyNumpy.tofile('Scaniverse_chair_testOutNormed.ply')
+
+
+ outFileName='Scaniverse_chair_outputFloat.ply'
+ df1 = pd.DataFrame(rawdata, columns=['x','y','z',
+ 'nx','ny','nz',
+ 'f_dc_0','f_dc_1','f_dc_2',
+ 'opacity',
+ 'scale_0','scale_1','scale_2',
+ 'rot_0','rot_1','rot_2','rot_3'
+ ])
+
+ plyio.write_ply_float(filename=outFileName, points= df1, mesh=None, as_text=False)
+
+
+ outFilePath='C:\\Users\\Home\\Documents\\Thesis\\IntSelcConv_with_pyVista\\Scaniverse_out.splat'
+ opacity = rawdata[:,-1].reshape(224713, 1)
+ positions = rawdata[:,0:3]
+ scales = rawdata[:,3:6]
+ rots = rawdata[:,6:10]
+ colors = results
+
+ splatio.numpy_to_splat(positions, scales, rots, colors, opacity, outFilePath)
+
+ #save as csv
+ normalized = False
+ #normalize the color values if needed
+ if (normalized):
+ SH_C0 = 0.28209479177387814
+ colors = colors/SH_C0-0.5
+ opacity = -np.log(1/(opacity-1))
+
+ colors = np.concatenate((colors, opacity), axis=1)
+ splat = np.concatenate((positions, scales, rots, colors), axis=1)
+ #np.savetxt('C:\\Users\\Home\\Documents\\Thesis\\IntSelcConv_with_pyVista\\Scaniverse_out.csv', splat, delimiter=',', header=splat.dtype.names)
+ np.savetxt('C:\\Users\\Home\\Documents\\Thesis\\IntSelcConv_with_pyVista\\Scaniverse_out.csv', splat, delimiter=',')
+
+ return
+
+
+def splat_save(positions, scales, rots, colors, output_path):
+
+
+
+ splatio.numpy_to_splat(positions, scales, rots, colors, output_path)
+
+ return
+
+
+
+
+
+
+def splat_unpacker(neighbors, fileName, removeTopQ = 0):
+
+ '''
+ #Get the raw PLY data into a tensor
+ plyout = plyio.read_ply('Scaniverse_chair_StyleOutput.ply')
+ df1out = plyout.get('points')
+
+ #Get the raw PLY data into a tensor
+ ply = plyio.read_ply(fileName)
+ df1 = ply.get('points')
+ f1 = df1.to_numpy()
+ f2 = df1out.to_numpy()
+ '''
+
+ positions, scales, rots, colors = splatio.ply_to_numpy(fileName)
+
+
+ #Get the raw PLY data into a tensor
+ torchPoints = torch.Tensor(positions)
+ samples = torchPoints.size()[0]
+
+ y = torchPoints
+ x = torchPoints
+ assign_index = knn(x, y, neighbors)
+
+ indexSrc = assign_index[0:1][0]
+ indexSrcTrn = indexSrc.reshape(-1,1)
+ index = indexSrcTrn.expand(indexSrc.size()[0],3)
+
+ src = x[assign_index[1:2]][0]
+ out = src.new_zeros(src.size())
+ out = scatter_mean(src, index, 0, out=out)
+
+ diffvector = y - out[0:samples]
+
+ diffvectorNum =diffvector.numpy()
+
+ normals = torch.from_numpy(diffvectorNum)
+
+ pos3D = torch.from_numpy(positions)
+
+ torchColors = torch.Tensor(colors)
+ torchColors.clamp(0,1)
+
+ return pos3D, normals, colors, scales, rots
+
+
+
+def splat_downsampler(pos3D, colors, scaleV):
+
+ pointsNP = pos3D.numpy()
+ colorsNP = colors.numpy()
+
+ #cluster
+ kmeans = KMeans(n_clusters=scaleV*256, random_state=0, n_init="auto").fit(pointsNP)
+
+ centers = kmeans.cluster_centers_
+ labels = kmeans.labels_
+
+
+ x = np.indices(pointsNP.shape)[0]
+ y = np.indices(pointsNP.shape)[1]
+ z = np.indices(pointsNP.shape)[2]
+ col = pointsNP.flatten()
+
+ # 3D Plot
+ fig = plt.figure()
+ ax3D = fig.add_subplot(projection='3d')
+ cm = plt.colormaps['brg']
+ p3d = ax3D.scatter(x, y, z, c=col, cmap=cm)
+ plt.colorbar(p3d)
+ plt.colorbar(p3d)
+
+
+
+ return pos3D, colors
+
+
+
+def splat_unpacker_threshold(neighbors, fileName, threshold):
+
+ #check if .ply or .splat
+ if fileName[-3:] == 'ply':
+ positionsFile, scalesFile, rotsFile, colorsFile = splatio.ply_to_numpy(fileName)
+ fileType = 'ply'
+ else:
+ positionsFile, scalesFile, rotsFile, colorsFile = splatio.splat_to_numpy(fileName)
+ fileType = 'splat'
+
+
+
+ positionsNP = positionsFile.copy()
+ scalesNP = scalesFile.copy()
+ rotsNP = rotsFile.copy()
+ colorsNP = colorsFile.copy()
+
+
+ #Get the raw PLY data into tensors
+ pos3D = torch.from_numpy(positionsNP)
+ colors = torch.from_numpy(colorsNP)
+ colors.clamp(0,1)
+ rots = torch.from_numpy(rotsNP)
+ scales = torch.from_numpy(scalesNP)
+ points = torch.from_numpy(positionsNP)
+
+ samples = points.size()[0]
+ y = points
+ x = points
+
+ #knn find midpoint
+ assign_index = knn(x, y, neighbors)
+ indexSrc = assign_index[0:1][0]
+ indexSrcTrn = indexSrc.reshape(-1,1)
+ index = indexSrcTrn.expand(indexSrc.size()[0],3)
+
+ src = x[assign_index[1:2]][0]
+ out = src.new_zeros(src.size())
+ out = scatter_mean(src, index, 0, out=out)
+
+ #calculate normals
+ normals = y - out[0:samples]
+
+
+ #use Euclidian distances to create a mask
+ #distances = torch.sqrt((out[0:samples][:,0])**2 +(out[0:samples][:,1])**2+(out[0:samples][:,2])**2 )
+ distances = torch.sqrt((y[:,0] - out[0:samples][:,0])**2 +(y[:,1] - out[0:samples][:,1])**2+(y[:,2] - out[0:samples][:,2])**2 )
+ threshold = np.clip(threshold, 0, 100)
+ boundry = np.percentile(distances, threshold)
+ mask = distances < boundry
+
+
+ pos3DFiltered = pos3D[mask]
+ normalsFiltered = normals[mask]
+ colorsFiltered = colors[mask]
+ scalesFiltered = scales[mask]
+ rotsFiltered = rots[mask]
+
+ return pos3DFiltered, normalsFiltered, colorsFiltered[:,0:3], colorsFiltered[:,3], scalesFiltered, rotsFiltered, fileType
+
+
+def splat_unpacker_threshold_graph_normals(neighbors, fileName, threshold):
+
+ coneSize = 1
+
+ positionsNP, scalesNP, rotsNP, colorsNP = splatio.ply_to_numpy(fileName)
+
+ #Get the raw PLY data into tensors
+ pos3D = torch.from_numpy(positionsNP)
+ colors = torch.from_numpy(colorsNP)
+ colors.clamp(0,1)
+ rots = torch.from_numpy(rotsNP)
+ scales = torch.from_numpy(scalesNP)
+ points = torch.from_numpy(positionsNP)
+
+ samples = points.size()[0]
+ y = points
+ x = points
+
+ #knn find midpoint
+ assign_index = knn(x, y, neighbors)
+ indexSrc = assign_index[0:1][0]
+ indexSrcTrn = indexSrc.reshape(-1,1)
+ index = indexSrcTrn.expand(indexSrc.size()[0],3)
+
+ src = x[assign_index[1:2]][0]
+ out = src.new_zeros(src.size())
+ out = scatter_mean(src, index, 0, out=out)
+
+ #calculate normals
+ normals = y - out[0:samples]
+
+
+ #use Euclidian distances to create a mask
+ #distances = torch.sqrt((out[0:samples][:,0])**2 +(out[0:samples][:,1])**2+(out[0:samples][:,2])**2 )
+ distances = torch.sqrt((y[:,0] - out[0:samples][:,0])**2 +(y[:,1] - out[0:samples][:,1])**2+(y[:,2] - out[0:samples][:,2])**2 )
+ threshold = np.clip(threshold, 0, 100)
+ boundry = np.percentile(distances, threshold)
+ mask = distances < boundry
+
+
+ pos3DFiltered = pos3D[mask]
+ normalsFiltered = normals[mask]
+ colorsFiltered = colors[mask]
+ scalesFiltered = scales[mask]
+ rotsFiltered = rots[mask]
+
+
+
+ diffvector = y - out[0:samples]
+
+ diffvectorNum =diffvector.numpy()
+ diffNum = distances.numpy()
+ resultNum = out[0:samples].numpy()
+ # Normalised [0,1]
+ diffNumNorm = (diffNum - np.min(diffNum))/np.ptp(diffNum)
+
+ '''
+ #point cloud
+ marker_data = go.Scatter3d(
+ x=points[:, 0],
+ y=points[:, 2],
+ z=-points[:, 1],
+ marker=go.scatter3d.Marker(size=3, color= diffNumNorm),
+ opacity=0.8,
+ mode='markers'
+ )
+ fig=go.Figure(data=marker_data)
+ fig.show()
+ '''
+ #normals
+ fig = go.Figure(data=go.Cone(
+ x=points[:, 0],
+ y=points[:, 2],
+ z=-points[:, 1],
+ u=diffvectorNum[:, 0],
+ v=diffvectorNum[:, 2],
+ w=-diffvectorNum[:, 1],
+ sizemode="absolute",
+ sizeref=coneSize,
+ anchor="tail"))
+
+ fig.update_layout(
+ scene=dict(domain_x=[0, 1],
+ camera_eye=dict(x=-1.57, y=1.36, z=0.58)))
+
+ fig.show()
+
+
+ return pos3DFiltered, normalsFiltered, colorsFiltered, scalesFiltered, rotsFiltered
\ No newline at end of file
diff --git a/splat_mesh_helpers.py b/splat_mesh_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c6e4708221a633a0615f05d5e7e6e0f1f0c4f07
--- /dev/null
+++ b/splat_mesh_helpers.py
@@ -0,0 +1,687 @@
+import torch
+import torch.utils
+import pyntcloud_io as plyio
+from torch_geometric.nn import knn
+import numpy as np
+from torch_scatter import scatter_mean
+import plyio as splatio
+import pandas as pd
+import plotly.graph_objects as go
+from sklearn.cluster import KMeans
+import matplotlib.pyplot as plt
+
+
+
+def splat_update_and_save(rawdata, results, fileName):
+
+
+ #df1[['f_dc_0','f_dc_1','f_dc_2']] = results
+ #plyNumpy = df1.to_numpy()
+ #plyNumpy.tofile('Scaniverse_chair_testOut2.ply')
+
+ #plyNumpy = df1.to_numpy()
+ #np.savetxt('Scaniverse_chair_test.csv', plyNumpy, delimiter=',')
+
+ #SH_C0 = 0.28209479177387814
+ #normedResults = 0.5 - results/SH_C0
+
+ #df1[['f_dc_0','f_dc_1','f_dc_2']] = normedResults
+
+
+ #plyNumpy = df1.to_numpy()
+ #np.savetxt('Scaniverse_chair_testOut.csv', plyNumpy, delimiter=',')
+ #plyNumpy = df1.to_numpy()
+ #plyNumpy.tofile('Scaniverse_chair_testOutNormed.ply')
+
+
+ outFileName='Scaniverse_chair_outputFloat.ply'
+ df1 = pd.DataFrame(rawdata, columns=['x','y','z',
+ 'nx','ny','nz',
+ 'f_dc_0','f_dc_1','f_dc_2',
+ 'opacity',
+ 'scale_0','scale_1','scale_2',
+ 'rot_0','rot_1','rot_2','rot_3'
+ ])
+
+ plyio.write_ply_float(filename=outFileName, points= df1, mesh=None, as_text=False)
+
+
+ outFilePath='C:\\Users\\Home\\Documents\\Thesis\\IntSelcConv_with_pyVista\\Scaniverse_out.splat'
+ opacity = rawdata[:,-1].reshape(224713, 1)
+ positions = rawdata[:,0:3]
+ scales = rawdata[:,3:6]
+ rots = rawdata[:,6:10]
+ colors = results
+
+ splatio.numpy_to_splat(positions, scales, rots, colors, opacity, outFilePath)
+
+ #save as csv
+ normalized = False
+ #normalize the color values if needed
+ if (normalized):
+ SH_C0 = 0.28209479177387814
+ colors = colors/SH_C0-0.5
+ opacity = -np.log(1/(opacity-1))
+
+ colors = np.concatenate((colors, opacity), axis=1)
+ splat = np.concatenate((positions, scales, rots, colors), axis=1)
+ #np.savetxt('C:\\Users\\Home\\Documents\\Thesis\\IntSelcConv_with_pyVista\\Scaniverse_out.csv', splat, delimiter=',', header=splat.dtype.names)
+ np.savetxt('C:\\Users\\Home\\Documents\\Thesis\\IntSelcConv_with_pyVista\\Scaniverse_out.csv', splat, delimiter=',')
+
+ return
+
+
+def splat_save(positions, scales, rots, colors, output_path, fileType):
+
+ splatio.numpy_to_splat(positions, scales, rots, colors, output_path, fileType)
+
+ return
+
+
+
+def splat_randomsampler(pos3D):
+
+ pointsNP = pos3D.numpy().copy()
+ np.random.shuffle(pointsNP)
+ scaler = int(pointsNP.shape[0]/2)
+ valuesNP = pointsNP[:scaler]
+ values64 = torch.from_numpy(valuesNP)
+ values = values64.to(torch.float32)
+ return values
+
+
+def splat_GaussianSuperSampler(pos3D_fr, colors_fr, opacity_fr, scales_fr, rots_fr, GaussianSamples):
+
+ pos3D = pos3D_fr.clone()
+ colors = colors_fr.clone()
+ opacity = opacity_fr.clone()
+ scales = scales_fr.clone()
+ rots = rots_fr.clone()
+
+ #get all the splats' relative sizes/densities and opacities
+ importance = (torch.absolute(scales[:,0])+torch.absolute(scales[:,1])+torch.absolute(scales[:,2]))/scales.max()+torch.absolute(opacity)/opacity.max()
+ importanceConstant = importance.sum().to(torch.int32)
+
+ increaseNormalizer = GaussianSamples/importanceConstant.item()
+ copiesToMakeQuantity = (importance*increaseNormalizer).to(torch.int32)
+ #create duplicates of all the larger Gaussians, create matched tensors for all the other attributes too
+ duplicatePoints = torch.from_numpy(np.repeat(pos3D.numpy(), copiesToMakeQuantity.numpy(), axis=0))
+ duplicateColors = torch.from_numpy(np.repeat(colors.numpy(), copiesToMakeQuantity.numpy(), axis=0))
+ duplicateRots = torch.from_numpy(np.repeat(rots.numpy(), copiesToMakeQuantity.numpy(), axis=0))
+ duplicateScales = torch.from_numpy(np.repeat(scales.numpy(), copiesToMakeQuantity.numpy(), axis=0))
+
+ #create a copy for saving the rotation results
+ duplicateScalesRotated = duplicateScales.clone()
+
+ q0 = duplicateRots[:,0]
+ q1 = duplicateRots[:,1]
+ q2 = duplicateRots[:,2]
+ q3 = duplicateRots[:,3]
+ x00 = q0**2 + q1**2 - q2**2 - q3**2
+ x01 = 2*q1*q2 - 2*q0*q3
+ x02 = 2*q1*q3 + 2*q0*q2
+ x10 = 2*q1*q2 + 2*q0*q3
+ x11 = q0**2 - q1**2 + q2**2 - q3**2
+ x12 = 2*q2*q3 - 2*q0*q1
+ x20 = 2*q1*q3 - 2*q0*q2
+ x21 = 2*q2*q3 + 2*q0*q1
+ x22 = q0**2 - q1**2 - q2**2 + q3**2
+
+ X0 = torch.stack((x00, x01, x02), dim=1)
+ X1 = torch.stack((x10, x11, x12), dim=1)
+ X2 = torch.stack((x20, x21, x22), dim=1)
+
+
+ R_matrix = torch.stack((X0, X1, X2), dim=1)
+ duplicateScalesRotated = torch.bmm(R_matrix, duplicateScales.unsqueeze(2)).squeeze(2)
+ #torch.normal can't handle negative numbers, store the signs and add them back afterwards
+ duplicateScalesRotatedSigns = duplicateScalesRotated.sign()
+ duplicateScalesRotatedABS = torch.absolute(duplicateScalesRotated)
+
+ GaussianPointsNoSignScale = torch.normal(duplicatePoints, duplicateScalesRotatedABS)
+ #correct for signs
+ GaussianPoints = duplicatePoints + (GaussianPointsNoSignScale-duplicatePoints)*duplicateScalesRotatedSigns
+
+ pos3D = torch.cat((pos3D_fr, GaussianPoints
+ ), dim=0)
+ colors = torch.cat((colors_fr,duplicateColors
+ ), dim=0)
+ return pos3D, colors.to(torch.float32)
+
+
+
+
+def splat_unpacker_with_threshold(neighbors, fileName, threshold):
+
+ #check if .ply or .splat or gen
+ if fileName[-3:] == 'ply':
+ positionsFile, scalesFile, rotsFile, colorsFile = splatio.ply_to_numpy(fileName)
+ fileType = 'ply'
+ elif fileName[-3:] == 'gen':
+ resolution = 100*5
+ radius = 1
+ indices = np.arange(0, resolution, dtype=float) + 0.5
+ phi = np.arccos(1 - 2*indices/resolution)
+ theta = np.pi * (1 + 5**0.5) * indices
+
+ x = radius * np.sin(phi) * np.cos(theta)
+ y = radius * np.sin(phi) * np.sin(theta)
+ z = radius * np.cos(phi)
+
+ pos3DFiltered = torch.from_numpy(np.transpose(np.stack([x,y,z]).astype(np.float32)))
+ normalsFiltered = torch.full_like(pos3DFiltered, 0.0)
+ opacity = torch.full_like(pos3DFiltered, 0.5)
+ #colorsFiltered = torch.full_like(pos3DFiltered, 0.5)
+ scalesFiltered = torch.full_like(pos3DFiltered, 0.015)
+ rotsFiltered = torch.full((resolution, 4), 0.0)
+ rotsFiltered[:,0] = 1.0
+ fileType = 'splat'
+
+ #make is stripy
+ c1 = torch.tensor([0.9, 0.9, 0.9], dtype=torch.float32)
+ c2 = torch.tensor([0.9, 0.0, 0.0], dtype=torch.float32)
+ c3 = torch.tensor([0.0, 0.9, 0.0], dtype=torch.float32)
+ c4 = torch.tensor([0.0, 0.0, 0.9], dtype=torch.float32)
+ c5 = torch.tensor([0.9, 0.9, 0.0], dtype=torch.float32)
+
+ C1 = torch.from_numpy(np.tile(c1.numpy(), (int(resolution/5),1)))
+ C2 = torch.from_numpy(np.tile(c2.numpy(), (int(resolution/5),1)))
+ C3 = torch.from_numpy(np.tile(c3.numpy(), (int(resolution/5),1)))
+ C4 = torch.from_numpy(np.tile(c4.numpy(), (int(resolution/5),1)))
+ C5 = torch.from_numpy(np.tile(c5.numpy(), (int(resolution/5),1)))
+
+ colorsFiltered = torch.cat((C1, C2, C3, C4, C5), dim=0)
+
+ addnoise = True
+ if (addnoise):
+ noisePoints = torch.from_numpy(np.random.uniform(-1.5, 1.5, (50,3)).astype(np.float32))
+ pos3DFiltered = torch.cat((noisePoints, pos3DFiltered
+ ), dim=0)
+
+ noiseColors = torch.from_numpy(np.random.uniform(0, 9.99, (50,3)).astype(np.float32))
+ colorsFiltered = torch.cat((noiseColors, colorsFiltered
+ ), dim=0)
+
+ noiseNormals = torch.full_like(noisePoints, 0.0)
+ normalsFiltered = torch.cat((noiseNormals, normalsFiltered
+ ), dim=0)
+
+ noiseScales = torch.full_like(noisePoints, 0.015)
+ scalesFiltered = torch.cat((noiseScales, scalesFiltered
+ ), dim=0)
+
+ noiseopacity = torch.full_like(noisePoints, 0.5)
+ opacity = torch.cat((noiseopacity, opacity
+ ), dim=0)
+
+ noiseScales = torch.full((50, 4), 0.0)
+ rotsFiltered = torch.cat((noiseScales, rotsFiltered
+ ), dim=0)
+ rotsFiltered[:,0] = 1.0
+
+
+ return pos3DFiltered, normalsFiltered, colorsFiltered, opacity[:,0], scalesFiltered, rotsFiltered, fileType
+ elif fileName[-4:] == 'gen2':
+
+
+ z = torch.tensor([0.9, 0.9, 0.9, 0.9, 0.9], dtype=torch.float32)
+ y = torch.tensor([1.7, 1.4, 1.0, 0.5, 0.0], dtype=torch.float32)
+ x = torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5], dtype=torch.float32)
+
+
+ pos3DFiltered = torch.from_numpy(np.transpose(np.stack([x,y,z]).astype(np.float32)))
+ normalsFiltered = torch.full_like(pos3DFiltered, 0.0)
+ opacity = torch.full_like(pos3DFiltered, 0.99)
+ #colorsFiltered = torch.full_like(pos3DFiltered, 0.5)
+ #scalesFiltered = torch.full_like(pos3DFiltered, 0.02)
+ rotsFiltered = torch.full((5, 4), 0.0)
+ rotsFiltered[:,0] = 1.0
+ fileType = 'splat'
+
+ scalesFiltered = torch.tensor([[0.03, 0.03, 0.01],
+ [0.03, 0.03, 0.05],
+ [0.07, 0.07, 0.03],
+ [0.05, 0.05, 0.09],
+ [0.09, 0.12, 0.12,]
+ ], dtype= torch.float32)
+
+
+ #each Gaussian has its own color
+ c1 = torch.tensor([0.1, 0.2, 0.5, 0.7, 0.9], dtype=torch.float32)
+ c2 = torch.tensor([0.1, 0.1, 0.1, 0.1, 0.1], dtype=torch.float32)
+ c3 = torch.tensor([0.9, 0.7, 0.5, 0.2, 0.1], dtype=torch.float32)
+
+
+ colorsFiltered = torch.from_numpy(np.transpose(np.stack([c1,c2,c3]).astype(np.float32)))
+
+
+
+ return pos3DFiltered, normalsFiltered, colorsFiltered, opacity[:,0], scalesFiltered, rotsFiltered, fileType
+ else:
+ positionsFile, scalesFile, rotsFile, colorsFile = splatio.splat_to_numpy(fileName)
+ fileType = 'splat'
+
+
+ positionsNP = positionsFile.copy()
+ scalesNP = scalesFile.copy()
+ rotsNP = rotsFile.copy()
+ colorsNP = colorsFile.copy()
+
+
+ #Get the raw PLY data into tensors
+ pos3D = torch.from_numpy(positionsNP)
+ colors = torch.from_numpy(colorsNP)
+ colors.clamp(0,1)
+ rots = torch.from_numpy(rotsNP)
+ scales = torch.from_numpy(scalesNP)
+ points = torch.from_numpy(positionsNP)
+
+ samples = points.size()[0]
+ y = points
+ x = points
+
+ #knn find midpoint
+ assign_index = knn(x, y, neighbors)
+ indexSrc = assign_index[0:1][0]
+ indexSrcTrn = indexSrc.reshape(-1,1)
+ index = indexSrcTrn.expand(indexSrc.size()[0],3)
+
+ src = x[assign_index[1:2]][0]
+ out = src.new_zeros(src.size())
+ out = scatter_mean(src, index, 0, out=out)
+
+ #calculate normals
+ normals = y - out[0:samples]
+
+ if threshold == 100:
+
+ pos3DFiltered = pos3D
+ normalsFiltered = normals
+ colorsFiltered = colors
+ scalesFiltered = scales
+ rotsFiltered = rots
+
+ else:
+
+ #use Euclidian distances to create a mask
+ #distances = torch.sqrt((out[0:samples][:,0])**2 +(out[0:samples][:,1])**2+(out[0:samples][:,2])**2 )
+ distances = torch.sqrt((y[:,0] - out[0:samples][:,0])**2 +(y[:,1] - out[0:samples][:,1])**2+(y[:,2] - out[0:samples][:,2])**2 )
+ threshold = np.clip(threshold, 0, 100)
+ boundry = np.percentile(distances, threshold)
+ mask = distances < boundry
+
+
+ pos3DFiltered = pos3D[mask]
+ normalsFiltered = normals[mask]
+ colorsFiltered = colors[mask]
+ scalesFiltered = scales[mask]
+ rotsFiltered = rots[mask]
+
+
+ displayNormals = False
+ if (displayNormals):
+
+ coneSize = 0.1
+ diffvectorNum =normals.numpy()
+
+ showBoth = False
+ if (showBoth):
+
+
+ diffBoth = torch.cat((normals, normalsFiltered), dim=0).numpy().astype(np.float16)
+ pos3dShifted = pos3DFiltered
+ pos3dShifted[:,2] = pos3dShifted[:,2] + 1
+ posBoth = torch.cat((pos3D, pos3dShifted), dim=0).numpy().astype(np.float16)
+
+ distancesMasked = distances[mask]
+ diffBothNum = torch.cat((distances, distancesMasked), dim=0).numpy().astype(np.float16)
+
+ # Normalised [0,1]
+ diffNumNorm = (diffBothNum - np.min(diffBothNum))/np.ptp(diffBothNum)
+
+ #point cloud
+ marker_data = go.Scatter3d(
+ x=posBoth[:, 0],
+ y=posBoth[:, 2],
+ z=-posBoth[:, 1],
+ marker=go.scatter3d.Marker(size=1, color= diffNumNorm),
+ opacity=0.8,
+ mode='markers'
+ )
+ fig=go.Figure(data=marker_data)
+ fig.show()
+
+
+
+ '''
+ #Too high memory usage :(
+ diffBoth = torch.cat((normals, normalsFiltered), dim=0).numpy().astype(np.float16)
+ pos3dShifted = pos3DFiltered
+ pos3dShifted[:,2] = pos3dShifted[:,2] + 1
+ posBoth = torch.cat((pos3D, pos3dShifted), dim=0).numpy().astype(np.float16)
+
+
+ #normals
+ fig = go.Figure(data=go.Cone(
+ x=posBoth[:, 0],
+ y=posBoth[:, 2],
+ z=-posBoth[:, 1],
+ u=diffBoth[:, 0],
+ v=diffBoth[:, 2],
+ w=-diffBoth[:, 1],
+ sizemode="raw",
+ sizeref=coneSize,
+ anchor="tail"))
+
+ fig.update_layout(
+ scene=dict(domain_x=[0, 1],
+ camera_eye=dict(x=-1.57, y=1.36, z=0.58)))
+
+ fig.show()
+
+ exit()
+ '''
+
+
+ else:
+
+ diffNum = distances.numpy()
+ # Normalised [0,1]
+ diffNumNorm = (diffNum - np.min(diffNum))/np.ptp(diffNum)
+
+ #point cloud
+ marker_data = go.Scatter3d(
+ x=pos3DFiltered[:, 0],
+ y=pos3DFiltered[:, 2],
+ z=-pos3DFiltered[:, 1],
+ marker=go.scatter3d.Marker(size=1, color= diffNumNorm),
+ opacity=0.8,
+ mode='markers'
+ )
+ fig=go.Figure(data=marker_data)
+ fig.show()
+
+
+
+ #normals
+ fig = go.Figure(data=go.Cone(
+ x=pos3D[:, 0],
+ y=pos3D[:, 2],
+ z=-pos3D[:, 1],
+ u=diffvectorNum[:, 0],
+ v=diffvectorNum[:, 2],
+ w=-diffvectorNum[:, 1],
+ sizemode="absolute",
+ sizeref=coneSize,
+ anchor="tail"))
+
+ fig.update_layout(
+ scene=dict(domain_x=[0, 1],
+ camera_eye=dict(x=-1.57, y=1.36, z=0.58)))
+
+ fig.show()
+
+
+ return pos3DFiltered, normalsFiltered, colorsFiltered[:,0:3], colorsFiltered[:,3], scalesFiltered, rotsFiltered, fileType
+
+
+
+def generate_with_noise_ablation(neighbors, fileName, threshold):
+
+ resolution = 100*5
+ radius = 1
+ indices = np.arange(0, resolution, dtype=float) + 0.5
+ phi = np.arccos(1 - 2*indices/resolution)
+ theta = np.pi * (1 + 5**0.5) * indices
+
+ x = radius * np.sin(phi) * np.cos(theta)
+ y = radius * np.sin(phi) * np.sin(theta)
+ z = radius * np.cos(phi)
+
+ pos3D = torch.from_numpy(np.transpose(np.stack([x,y,z]).astype(np.float32)))
+ normalsFiltered = torch.full_like(pos3D, 0.0)
+ opacity = torch.full_like(pos3D, 0.5)
+ #colorsFiltered = torch.full_like(pos3DFiltered, 0.5)
+ scales = torch.full_like(pos3D, 0.015)
+ rots = torch.full((resolution, 4), 0.0)
+ rots[:,0] = 1.0
+ fileType = 'splat'
+
+ #make is stripy
+ c1 = torch.tensor([0.9, 0.9, 0.9], dtype=torch.float32)
+ c2 = torch.tensor([0.9, 0.0, 0.0], dtype=torch.float32)
+ c3 = torch.tensor([0.0, 0.9, 0.0], dtype=torch.float32)
+ c4 = torch.tensor([0.0, 0.0, 0.9], dtype=torch.float32)
+ c5 = torch.tensor([0.9, 0.9, 0.0], dtype=torch.float32)
+
+ C1 = torch.from_numpy(np.tile(c1.numpy(), (int(resolution/5),1)))
+ C2 = torch.from_numpy(np.tile(c2.numpy(), (int(resolution/5),1)))
+ C3 = torch.from_numpy(np.tile(c3.numpy(), (int(resolution/5),1)))
+ C4 = torch.from_numpy(np.tile(c4.numpy(), (int(resolution/5),1)))
+ C5 = torch.from_numpy(np.tile(c5.numpy(), (int(resolution/5),1)))
+
+ colors = torch.cat((C1, C2, C3, C4, C5), dim=0)
+
+ addnoise = True
+ if (addnoise):
+ noisePoints = torch.from_numpy(np.random.uniform(-1.01, 1.01, (50,3)).astype(np.float32))
+ pos3D = torch.cat((noisePoints, pos3D
+ ), dim=0)
+
+ noiseColors = torch.from_numpy(np.random.uniform(0, 9.99, (50,3)).astype(np.float32))
+ colors = torch.cat((noiseColors, colors
+ ), dim=0)
+
+ noiseScales = torch.full_like(noisePoints, 0.015)
+ scales = torch.cat((noiseScales, scales
+ ), dim=0)
+
+ noiseopacity = torch.full_like(noisePoints, 0.5)
+ opacity = torch.cat((noiseopacity, opacity
+ ), dim=0)
+
+ noiseRots = torch.full((50, 4), 0.0)
+ rots = torch.cat((noiseRots, rots
+ ), dim=0)
+ rots[:,0] = 1.0
+
+ points = pos3D
+
+ samples = points.size()[0]
+ y = points
+ x = points
+
+ #knn find midpoint
+ assign_index = knn(x, y, neighbors)
+ indexSrc = assign_index[0:1][0]
+ indexSrcTrn = indexSrc.reshape(-1,1)
+ index = indexSrcTrn.expand(indexSrc.size()[0],3)
+
+ src = x[assign_index[1:2]][0]
+ out = src.new_zeros(src.size())
+ out = scatter_mean(src, index, 0, out=out)
+
+ #calculate normals
+ normals = y - out[0:samples]
+
+ if threshold == 100:
+
+ pos3DFiltered = pos3D
+ normalsFiltered = normals
+ colorsFiltered = colors
+ scalesFiltered = scales
+ rotsFiltered = rots
+ opacity = opacity
+
+ else:
+
+ #use Euclidian distances to create a mask
+ #distances = torch.sqrt((out[0:samples][:,0])**2 +(out[0:samples][:,1])**2+(out[0:samples][:,2])**2 )
+ distances = torch.sqrt((y[:,0] - out[0:samples][:,0])**2 +(y[:,1] - out[0:samples][:,1])**2+(y[:,2] - out[0:samples][:,2])**2 )
+ threshold = np.clip(threshold, 0, 100)
+ boundry = np.percentile(distances, threshold)
+ mask = distances < boundry
+
+
+ pos3DFiltered = pos3D[mask]
+ normalsFiltered = normals[mask]
+ colorsFiltered = colors[mask]
+ scalesFiltered = scales[mask]
+ rotsFiltered = rots[mask]
+ opacity = opacity[mask]
+
+ displayNormals = True
+ if (displayNormals):
+
+ coneSize = 1
+ diffvectorNum = normals.numpy()
+ diffvectorfilteredNum = normalsFiltered.numpy()
+ diffNum = distances.numpy()
+
+ # Normalised [0,1]
+ diffNumNorm = (diffNum - np.min(diffNum))/np.ptp(diffNum)
+
+ #point cloud
+ marker_data = go.Scatter3d(
+ x=pos3D[:, 0],
+ y=pos3D[:, 2],
+ z=-pos3D[:, 1],
+ marker=go.scatter3d.Marker(size=5, color= diffNumNorm),
+ opacity=0.8,
+ mode='markers'
+ )
+ fig=go.Figure(data=marker_data)
+ fig.show()
+
+ #normals
+ fig2 = go.Figure(data=go.Cone(
+ x=pos3D[:, 0],
+ y=pos3D[:, 2],
+ z=-pos3D[:, 1],
+ u=diffvectorNum[:, 0],
+ v=diffvectorNum[:, 2],
+ w=-diffvectorNum[:, 1],
+ sizemode="raw",
+ sizeref=coneSize,
+ anchor="tail"))
+
+ fig2.update_layout(
+ scene=dict(domain_x=[0, 1],
+ camera_eye=dict(x=-1.57, y=1.36, z=0.58)))
+
+ fig2.show()
+
+
+ #normals Filtered
+ fig3 = go.Figure(data=go.Cone(
+ x=pos3DFiltered[:, 0],
+ y=pos3DFiltered[:, 2],
+ z=-pos3DFiltered[:, 1],
+ u=diffvectorfilteredNum[:, 0],
+ v=diffvectorfilteredNum[:, 2],
+ w=-diffvectorfilteredNum[:, 1],
+ sizemode="raw",
+ sizeref=coneSize,
+ anchor="tail"))
+
+ fig3.update_layout(
+ scene=dict(domain_x=[0, 1],
+ camera_eye=dict(x=-1.57, y=1.36, z=0.58)))
+
+ fig3.show()
+
+
+ return pos3DFiltered, normalsFiltered, colorsFiltered, opacity[:,0], scalesFiltered, rotsFiltered, fileType
+
+
+
+
+def splat_unpacker_threshold_graph_normals(neighbors, fileName, threshold):
+
+ coneSize = 1
+
+ positionsNP, scalesNP, rotsNP, colorsNP = splatio.ply_to_numpy(fileName)
+
+ #Get the raw PLY data into tensors
+ pos3D = torch.from_numpy(positionsNP)
+ colors = torch.from_numpy(colorsNP)
+ colors.clamp(0,1)
+ rots = torch.from_numpy(rotsNP)
+ scales = torch.from_numpy(scalesNP)
+ points = torch.from_numpy(positionsNP)
+
+ samples = points.size()[0]
+ y = points
+ x = points
+
+ #knn find midpoint
+ assign_index = knn(x, y, neighbors)
+ indexSrc = assign_index[0:1][0]
+ indexSrcTrn = indexSrc.reshape(-1,1)
+ index = indexSrcTrn.expand(indexSrc.size()[0],3)
+
+ src = x[assign_index[1:2]][0]
+ out = src.new_zeros(src.size())
+ out = scatter_mean(src, index, 0, out=out)
+
+ #calculate normals
+ normals = y - out[0:samples]
+
+
+ #use Euclidian distances to create a mask
+ #distances = torch.sqrt((out[0:samples][:,0])**2 +(out[0:samples][:,1])**2+(out[0:samples][:,2])**2 )
+ distances = torch.sqrt((y[:,0] - out[0:samples][:,0])**2 +(y[:,1] - out[0:samples][:,1])**2+(y[:,2] - out[0:samples][:,2])**2 )
+ threshold = np.clip(threshold, 0, 100)
+ boundry = np.percentile(distances, threshold)
+ mask = distances < boundry
+
+
+ pos3DFiltered = pos3D[mask]
+ normalsFiltered = normals[mask]
+ colorsFiltered = colors[mask]
+ scalesFiltered = scales[mask]
+ rotsFiltered = rots[mask]
+
+
+
+ diffvector = y - out[0:samples]
+
+ diffvectorNum =diffvector.numpy()
+ diffNum = distances.numpy()
+ resultNum = out[0:samples].numpy()
+ # Normalised [0,1]
+ diffNumNorm = (diffNum - np.min(diffNum))/np.ptp(diffNum)
+
+ '''
+ #point cloud
+ marker_data = go.Scatter3d(
+ x=points[:, 0],
+ y=points[:, 2],
+ z=-points[:, 1],
+ marker=go.scatter3d.Marker(size=3, color= diffNumNorm),
+ opacity=0.8,
+ mode='markers'
+ )
+ fig=go.Figure(data=marker_data)
+ fig.show()
+ '''
+ #normals
+ fig = go.Figure(data=go.Cone(
+ x=points[:, 0],
+ y=points[:, 2],
+ z=-points[:, 1],
+ u=diffvectorNum[:, 0],
+ v=diffvectorNum[:, 2],
+ w=-diffvectorNum[:, 1],
+ sizemode="absolute",
+ sizeref=coneSize,
+ anchor="tail"))
+
+ fig.update_layout(
+ scene=dict(domain_x=[0, 1],
+ camera_eye=dict(x=-1.57, y=1.36, z=0.58)))
+
+ fig.show()
+
+
+ return pos3DFiltered, normalsFiltered, colorsFiltered, scalesFiltered, rotsFiltered
\ No newline at end of file
diff --git a/style_ims/21.jpg b/style_ims/21.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..462c08de54e803c01fe3d86f48898272629625fc
Binary files /dev/null and b/style_ims/21.jpg differ
diff --git a/style_ims/31.jpg b/style_ims/31.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fa4f7071105d8fc25da420f158b8d4c13371b25e
Binary files /dev/null and b/style_ims/31.jpg differ
diff --git a/style_ims/47.jpg b/style_ims/47.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..464542fb9466390479a9409179cb53db9d80031a
Binary files /dev/null and b/style_ims/47.jpg differ
diff --git a/style_ims/6.jpg b/style_ims/6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d15ca0986ed0d66e0da2b497d0b929376e465e8b
Binary files /dev/null and b/style_ims/6.jpg differ
diff --git a/style_ims/img11.jpg b/style_ims/img11.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7167997e015ffbff43584315927d65f9a8a65e9f
Binary files /dev/null and b/style_ims/img11.jpg differ
diff --git a/style_ims/img66.jpg b/style_ims/img66.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d15ca0986ed0d66e0da2b497d0b929376e465e8b
Binary files /dev/null and b/style_ims/img66.jpg differ
diff --git a/style_ims/style-10.jpg b/style_ims/style-10.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..70f6adec87d980523fe1d59493a9556b076913b3
--- /dev/null
+++ b/style_ims/style-10.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:efe9c5fcf5d652ba574fd1079eda9ec26aa823fd630750469963435b552aae82
+size 144264
diff --git a/style_ims/style0.jpg b/style_ims/style0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..820b093f207d0884a0ca0e5d9236d29dc7aa730b
--- /dev/null
+++ b/style_ims/style0.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1af09b2c18a6674b7d88849cb87564dd77e1ce04d1517bb085449b614cc0c8d8
+size 376101
diff --git a/style_ims/style1.jpg b/style_ims/style1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..63aa06fe4296fa6e0feb39d5c6be57fc177283d7
Binary files /dev/null and b/style_ims/style1.jpg differ
diff --git a/style_ims/style100.JPG b/style_ims/style100.JPG
new file mode 100644
index 0000000000000000000000000000000000000000..0e6c11487e79fe59db161380d9b26ec34581d7f8
Binary files /dev/null and b/style_ims/style100.JPG differ
diff --git a/style_ims/style12.jpg b/style_ims/style12.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0da803b36c0d6e23922a0f133cc0a1d9fe35ffb3
Binary files /dev/null and b/style_ims/style12.jpg differ
diff --git a/style_ims/style19.jpg b/style_ims/style19.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a7b5df070485bacad6c1f0c9cfe3e48fe0995509
Binary files /dev/null and b/style_ims/style19.jpg differ
diff --git a/style_ims/style2.jpg b/style_ims/style2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c6e61e0ddbdbe6b9cf24d239c2aa64d966f82dba
--- /dev/null
+++ b/style_ims/style2.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:60192c2cb0eefc036ecc5bf04810f3155ec30239608b6ebefa4ae3ea96864ed2
+size 140682
diff --git a/style_ims/style3.jpg b/style_ims/style3.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3dbb29cf89e70f9421afb3341674eec0a991cf52
Binary files /dev/null and b/style_ims/style3.jpg differ
diff --git a/style_ims/style31.jpg b/style_ims/style31.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fa4f7071105d8fc25da420f158b8d4c13371b25e
Binary files /dev/null and b/style_ims/style31.jpg differ
diff --git a/style_ims/style4.jpg b/style_ims/style4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5f819f887e6d85dc19f334533ebb5c25b11d1f2f
Binary files /dev/null and b/style_ims/style4.jpg differ
diff --git a/style_ims/style40.jpg b/style_ims/style40.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6f086b087a8f895774866565f01b48b5902c197e
Binary files /dev/null and b/style_ims/style40.jpg differ
diff --git a/style_ims/style44.jpg b/style_ims/style44.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..356f707637fd225bef6c5d4d422713e486239138
--- /dev/null
+++ b/style_ims/style44.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28b09fc5f3ca18e8dda06475f72f3ec8d912c71f0b35a372dd309ea2cc4cfda9
+size 178110
diff --git a/style_ims/style5.jpg b/style_ims/style5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..732b7d5de2cc43b4a8d1e092b3a06fed496ec1d6
Binary files /dev/null and b/style_ims/style5.jpg differ
diff --git a/style_ims/style50.jpg b/style_ims/style50.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d625f87588f641c5fcb42fc930cb0ccd57bb72c5
Binary files /dev/null and b/style_ims/style50.jpg differ
diff --git a/style_ims/style6.jpg b/style_ims/style6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..70f6adec87d980523fe1d59493a9556b076913b3
--- /dev/null
+++ b/style_ims/style6.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:efe9c5fcf5d652ba574fd1079eda9ec26aa823fd630750469963435b552aae82
+size 144264
diff --git a/style_ims/style66.jpg b/style_ims/style66.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f077a49af8699a399340df8cbb9e477631b697cf
Binary files /dev/null and b/style_ims/style66.jpg differ
diff --git a/style_ims/style7.jpg b/style_ims/style7.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cb0e9091c8c18c589ff104a92c29d119ba1bf1b1
Binary files /dev/null and b/style_ims/style7.jpg differ
diff --git a/style_ims/style8.jpg b/style_ims/style8.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3462b70be0bfe4c4272ec3d86e4639a7000b1a16
Binary files /dev/null and b/style_ims/style8.jpg differ
diff --git a/style_ims/style9.jpg b/style_ims/style9.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b2d70e41daa0c72264c13d3e77291c7273837c89
Binary files /dev/null and b/style_ims/style9.jpg differ
diff --git a/style_ims/style99.jpg b/style_ims/style99.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1a07615bff581507454b934ae14c7f53f334eb4e
Binary files /dev/null and b/style_ims/style99.jpg differ
diff --git a/style_ims/waves.jpg b/style_ims/waves.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f077a49af8699a399340df8cbb9e477631b697cf
Binary files /dev/null and b/style_ims/waves.jpg differ
diff --git a/styletransfer_splat.py b/styletransfer_splat.py
new file mode 100644
index 0000000000000000000000000000000000000000..8575f9cfb7cc0ad3b7faeb733dc87ca244a5b9af
--- /dev/null
+++ b/styletransfer_splat.py
@@ -0,0 +1,218 @@
+import pointCloudToMesh as ply2M
+import argparse
+import utils
+import graph_io as gio
+from clusters import *
+#from tqdm import tqdm,trange
+import splat_mesh_helpers as splt
+import clusters as cl
+from torch_geometric.data import Data
+from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator
+import pointCloudToMesh as plyToMesh
+import plotly.graph_objects as go
+import pyvista as pv
+
+from time import time
+
+from graph_networks.LinearStyleTransfer_vgg import encoder,decoder
+from graph_networks.LinearStyleTransfer_matrix import TransformLayer
+
+from graph_networks.LinearStyleTransfer.libs.Matrix import MulLayer
+from graph_networks.LinearStyleTransfer.libs.models import encoder4, decoder4
+
+#import matplotlib.pyplot as plt
+
+
+def styletransfer_with_filtering_sampling(filename, stylePath, outPath, device = 'cpu', threshold=99.9, samplingRate=1.5, displayPointCloud = False):
+
+ print("Running on device:", device)
+
+ n = 25
+ style_ref = utils.loadImage(stylePath, shape=(256*2,256*2))
+ ratio=.25
+ depth = 3
+
+ pos3D_Original, _, colors_Original, opacity_Original, scales_Original, rots_Original, fileType = splt.splat_unpacker_with_threshold(n, filename, threshold)
+
+ time1_start = time()
+
+ #plyToMesh.graph_Points(pos3D_Original, torch.clamp(colors_Original, 0, 1))
+ if samplingRate > 1:
+ GaussianSamples = int(pos3D_Original.shape[0]*samplingRate)
+ pos3D, colors = splt.splat_GaussianSuperSampler(pos3D_Original.clone(), colors_Original.clone(), opacity_Original.clone(), scales_Original.clone(), rots_Original.clone(), GaussianSamples)
+ else:
+ pos3D, colors = pos3D_Original, colors_Original
+ #plyToMesh.graph_Points(pos3D, torch.clamp(colors, 0, 1))
+ #plyToMesh.graph_Points(pos3D_Original, torch.clamp(colors_Original, 0, 1))
+
+ time1_end = time()
+
+ print("Number of nodes in the graph:", pos3D.shape[0])
+
+ print(f"Time taken for Gaussian Super Sampling: {time1_end - time1_start}")
+
+
+ if (displayPointCloud):
+ #point cloud
+ point_cloud = pv.PolyData(pos3D.numpy())
+
+ # Add colors to the point data
+ point_cloud.point_data['colors'] = torch.clamp(colors, 0, 3).numpy()
+
+ # Plot the point cloud
+ plotter = pv.Plotter()
+ plotter.add_points(point_cloud, scalars='colors', rgb=True, point_size=0.05)
+ plotter.show_axes()
+ plotter.show()
+
+ time2_start = time()
+
+ #find normals
+ normalsNP = ply2M.Estimate_Normals(pos3D, threshold)
+ normals = torch.from_numpy(normalsNP)
+
+ #print("Time to compute normals:", time() - time2_start)
+
+ up_vector = torch.tensor([[1,1,1]],dtype=torch.float)
+ #up_vector = 2*torch.rand((1,3))-1
+ up_vector = up_vector/torch.linalg.norm(up_vector,dim=1)
+
+ pos3D.to(device)
+ colors.to(device)
+ normals.to(device)
+ up_vector.to(device)
+
+ # Build initial graph
+ #edge_index are neighbors of a point, directions are the directions from that point
+ edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16)
+ #directions need to be turned into selections "W sub n" from the star-like coordinate system from Dr. Hart's github interpolated-selectionconv
+ edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)
+
+ # Generate info for downsampled versions of the graph
+ clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)
+ #clusters, edge_indexes, selections_list, interps_list = cl.makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device)
+
+ time2_end = time()
+ print(f"Time taken for graph construction: {time2_end - time2_start}")
+
+ time3_start = time()
+
+ # Make final graph and metadata needed for mapping the result after going through the network
+ content = Data(x=colors,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list)
+ content_meta = Data(pos3D=pos3D)
+
+ style,_ = gio.image2Graph(style_ref,depth=3,device=device)
+
+ # Load original network
+ enc_ref = encoder4()
+ dec_ref = decoder4()
+ matrix_ref = MulLayer('r41')
+
+ enc_ref.load_state_dict(torch.load('graph_networks/LinearStyleTransfer/models/vgg_r41.pth'))
+ dec_ref.load_state_dict(torch.load('graph_networks/LinearStyleTransfer/models/dec_r41.pth'))
+ matrix_ref.load_state_dict(torch.load('graph_networks/LinearStyleTransfer/models/r41.pth',map_location=torch.device(device)))
+
+ # Copy weights to graph network
+ enc = encoder(padding_mode="replicate")
+ dec = decoder(padding_mode="replicate")
+ matrix = TransformLayer()
+
+ with torch.no_grad():
+ enc.copy_weights(enc_ref)
+ dec.copy_weights(dec_ref)
+ matrix.copy_weights(matrix_ref)
+
+ content = content.to(device)
+ style = style.to(device)
+ enc = enc.to(device)
+ dec = dec.to(device)
+ matrix = matrix.to(device)
+
+ # Run graph network
+ with torch.no_grad():
+ cF = enc(content)
+ sF = enc(style)
+ feature,transmatrix = matrix(cF['r41'],sF['r41'],
+ content.edge_indexes[3],content.selections_list[3],
+ style.edge_indexes[3],style.selections_list[3],
+ content.interps_list[3] if hasattr(content,'interps_list') else None)
+ result = dec(feature,content)
+ result = result.clamp(0,1)
+
+ colors[:, 0:3] = result
+
+ time3_end = time()
+ print(f"Time taken for stylization: {time3_end - time3_start}")
+
+ if (displayPointCloud):
+ #point cloud
+ point_cloud = pv.PolyData(pos3D.numpy())
+
+ # Add colors to the point data
+ point_cloud.point_data['colors'] = torch.clamp(colors, 0, 3).numpy()
+
+ # Plot the point cloud
+ plotter = pv.Plotter()
+ plotter.add_points(point_cloud, scalars='colors', rgb=True, point_size=0.25)
+ plotter.show_axes()
+ plotter.show()
+
+ time4_start = time()
+
+ #create the interpolator
+ interp2 = NearestNDInterpolator(pos3D.cpu(), colors.cpu())
+ results_OriginalNP = interp2(pos3D_Original)
+ results_OriginalNP64 = torch.from_numpy(results_OriginalNP)
+ results_Original = results_OriginalNP64.to(torch.float32)
+
+ colors_and_opacity_Original = torch.cat((results_Original, opacity_Original.unsqueeze(1)), dim=1)
+
+ time4_end = time()
+ print(f"Time taken for interpolation: {time4_end - time4_start}")
+
+ # Save/show result
+ splt.splat_save(pos3D_Original.numpy(), scales_Original.numpy(), rots_Original.numpy(), colors_and_opacity_Original.numpy(), outPath, fileType)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "filename",
+ type=str,
+ default=''
+ )
+ parser.add_argument(
+ "--stylePath",
+ type=str,
+ default="style_ims/style0.jpg"
+ )
+ parser.add_argument(
+ "--outPath",
+ type=str,
+ default='output.splat'
+ )
+ parser.add_argument(
+ "--device",
+ default= 0 if torch.cuda.is_available() else "cpu",
+ choices=list(range(torch.cuda.device_count())) + ["cpu"] or ["cpu"]
+ )
+ parser.add_argument(
+ "--threshold",
+ type=float,
+ default=99.8
+ )
+ parser.add_argument(
+ "--samplingRate",
+ type=float,
+ default=1.5
+ )
+ parser.add_argument(
+ "--displayPointCloud",
+ action='store_true'
+ )
+ args = parser.parse_args()
+ styletransfer_with_filtering_sampling(**vars(args))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e018f9b6cb1f99176e421ae7b5035a62a747c36b
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,222 @@
+import torch
+import numpy as np
+from imageio import imread, imwrite
+from skimage.transform import resize
+import os
+from scipy.ndimage import binary_erosion, binary_dilation, distance_transform_cdt, grey_dilation
+import matplotlib.pyplot as plt
+import json
+
+def ensure_dir(file_path):
+ directory = os.path.dirname(file_path)
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+
+def loadJSON(file):
+ with open(file) as f:
+ result = json.load(f)
+ return result
+
+def saveJSON(data,file):
+ with open(file, 'w') as outfile:
+ json.dump(data, outfile)
+
+def loadImage(filename, asTensor = True, imagenet_mean = False, shape = None):
+
+ image = imread(filename)
+
+ if image.ndim == 2:
+ image = np.stack((image,image,image),axis=2)
+
+ image = image[:,:,:3]/255 # No alpha channels
+
+ if shape is not None:
+ image = resize(image, shape)
+
+ if imagenet_mean:
+ mean = np.array([0.485, 0.456, 0.406])
+ std = np.array([0.229, 0.224, 0.225])
+ image = (image - mean)/std
+
+ # Reorganize for torch
+ if asTensor:
+ image = np.transpose(image,(2,0,1))
+
+ image = np.expand_dims(image,axis=0)
+
+ return torch.tensor(image,dtype=torch.float)
+ else:
+ return image
+
+def loadMask(filename, asTensor = True, shape = None):
+
+ image = imread(filename)
+
+ if image.ndim == 3:
+ image = image[:,:,0]
+
+ if shape is not None:
+ image = resize(image, shape)
+
+ mask = image >= 128
+
+ # Reorganize for torch
+ if asTensor:
+ mask = np.expand_dims(mask,axis=0)
+
+ return torch.tensor(mask,dtype=torch.bool)
+ else:
+ return mask
+
+def toTensor(numpy_data):
+ image = np.transpose(numpy_data,(2,0,1))
+ image = np.expand_dims(image,axis=0)
+ return torch.tensor(image,dtype=torch.float)
+
+def toTorch(numpy_data):
+ return toTensor(numpy_data)
+
+def toNumpy(tensor, permute=True):
+ image = np.squeeze(tensor.detach().clone().cpu().numpy())
+ if permute:
+ image = np.transpose(image,(1,2,0))
+ return image
+
+def makeCanvas(x,original):
+
+ if torch.is_tensor(original):
+ original = toNumpy(original)
+
+ if x.ndim > 1:
+ if x.shape[-1] == original.shape[-1]:
+ return original
+
+ rows,cols,_ = original.shape
+ if x.ndim > 1:
+ return np.zeros((rows,cols,x.shape[-1]))
+ else:
+ return np.zeros((rows,cols))
+
+def reverse_selection(s):
+ return [0, 5, 6, 7, 8, 1, 2, 3, 4][s]
+
+
+def extrapolate_image(im, numpy = False):
+
+ rows,cols,ch = im.shape
+ if numpy:
+ result = np.zeros((rows+1,cols+1,ch))
+ else:
+ result = torch.zeros((rows+1,cols+1,ch))
+
+ result[:rows,:cols] = im
+
+ # Extrapolate last column
+ result[:rows,cols] = 2*result[:rows,cols-1] - result[:rows,cols-2]
+ # Extrapolate last row
+ result[rows] = 2*result[rows-1] - result[rows-2]
+
+ return result
+
+def bilinear_interpolate(im, x, y, numpy = False):
+
+ if numpy:
+
+ im = extrapolate_image(im,numpy=True)
+
+ x = np.asarray(x)
+ y = np.asarray(y)
+
+ x0 = np.floor(x).astype(int)
+ x1 = x0 + 1
+ y0 = np.floor(y).astype(int)
+ y1 = y0 + 1
+
+ x0 = np.clip(x0, 0, im.shape[1]-1);
+ x1 = np.clip(x1, 0, im.shape[1]-1);
+ y0 = np.clip(y0, 0, im.shape[0]-1);
+ y1 = np.clip(y1, 0, im.shape[0]-1);
+
+ Ia = im[ y0, x0 ]
+ Ib = im[ y1, x0 ]
+ Ic = im[ y0, x1 ]
+ Id = im[ y1, x1 ]
+
+ wa = (x1-x) * (y1-y)
+ wb = (x1-x) * (y-y0)
+ wc = (x-x0) * (y1-y)
+ wd = (x-x0) * (y-y0)
+
+ wa = np.expand_dims(wa,axis=1)
+ wb = np.expand_dims(wb,axis=1)
+ wc = np.expand_dims(wc,axis=1)
+ wd = np.expand_dims(wd,axis=1)
+
+ else:
+ im = im[0].permute((1,2,0)).float()
+ im = extrapolate_image(im)
+
+ x0 = torch.floor(x).long()
+ x1 = x0 + 1
+ y0 = torch.floor(y).long()
+ y1 = y0 + 1
+
+ x0 = torch.clip(x0, 0, im.shape[1]-1);
+ x1 = torch.clip(x1, 0, im.shape[1]-1);
+ y0 = torch.clip(y0, 0, im.shape[0]-1);
+ y1 = torch.clip(y1, 0, im.shape[0]-1);
+
+ Ia = im[y0,x0]
+ Ib = im[y1,x0]
+ Ic = im[y0,x1]
+ Id = im[y1,x1]
+
+ wa = (x1-x) * (y1-y)
+ wb = (x1-x) * (y-y0)
+ wc = (x-x0) * (y1-y)
+ wd = (x-x0) * (y-y0)
+
+ wa = torch.unsqueeze(wa,dim=1)
+ wb = torch.unsqueeze(wb,dim=1)
+ wc = torch.unsqueeze(wc,dim=1)
+ wd = torch.unsqueeze(wd,dim=1)
+
+ return wa*Ia + wb*Ib + wc*Ic + wd*Id
+
+def cosineWeighting(rows,cols):
+ phi_vals = np.linspace(-np.pi/2,np.pi/2,rows)
+ cosines = np.cos(phi_vals)
+
+ result = np.zeros((rows,cols))
+ for i in range(rows):
+ result[i, :] = cosines[i]
+
+ return result
+
+def cross(a,b):
+ # Computes the cross product of two torch tensors
+ out_i = a[:,1]*b[:,2] - a[:,2]*b[:,1]
+ out_j = a[:,2]*b[:,0] - a[:,0]*b[:,2]
+ out_k = a[:,0]*b[:,1] - a[:,1]*b[:,0]
+
+ out = torch.stack((out_i,out_j,out_k),dim=1)
+
+ return out
+
+def interpolatePointCloud2D(source_points,source_features,target_x,target_y,extrapolate=True):
+ #from scipy.interpolate import LinearNDInterpolator as Interpolater
+ from scipy.interpolate import CloughTocher2DInterpolator as Interpolater
+ from scipy.interpolate import NearestNDInterpolator as Extrapolater
+ interp = Interpolater(source_points,source_features)
+ result = interp(target_x,target_y)
+
+ chk = np.isnan(result)
+
+ if chk.any() and extrapolate:
+ nearest = Extrapolater(source_points,source_features)
+ return np.where(chk, nearest(target_x,target_y), result)
+ else:
+ result = np.where(chk, 0, result)
+ return result
+
+
\ No newline at end of file
diff --git a/venv libs.txt b/venv libs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..27689b8f3043e9408989c0bc1564afb57beb6fcf
--- /dev/null
+++ b/venv libs.txt
@@ -0,0 +1,122 @@
+aiohttp 3.9.5
+aiosignal 1.3.1
+asttokens 3.0.0
+attrs 23.2.0
+blinker 1.9.0
+bqplot 0.12.44
+certifi 2024.7.4
+charset-normalizer 3.3.2
+click 8.1.8
+colorama 0.4.6
+comm 0.2.2
+ConfigArgParse 1.7
+contourpy 1.2.1
+cycler 0.12.1
+dash 2.18.2
+dash-core-components 2.0.0
+dash-html-components 2.0.0
+dash-table 5.0.0
+decorator 5.1.1
+equilib 0.0.1
+executing 2.2.0
+fastjsonschema 2.21.1
+filelock 3.13.1
+Flask 3.0.3
+fonttools 4.53.1
+frozenlist 1.4.1
+fsspec 2024.2.0
+idna 3.7
+imageio 2.34.2
+importlib_metadata 8.6.1
+intel-openmp 2021.4.0
+ipydatawidgets 4.3.5
+ipython 8.31.0
+ipyvue 1.11.2
+ipyvuetify 1.11.1
+ipywebrtc 0.6.0
+ipywidgets 8.1.5
+itsdangerous 2.2.0
+jedi 0.19.2
+Jinja2 3.1.3
+joblib 1.4.2
+jsonschema 4.23.0
+jsonschema-specifications 2024.10.1
+jupyter_core 5.7.2
+jupyterlab_widgets 3.0.13
+kiwisolver 1.4.5
+lazy_loader 0.4
+markdown-it-py 3.0.0
+MarkupSafe 2.1.5
+matplotlib 3.9.1
+matplotlib-inline 0.1.7
+mdurl 0.1.2
+meshio 5.3.5
+mkl 2021.4.0
+mpmath 1.3.0
+multidict 6.0.5
+nbformat 5.10.4
+nest-asyncio 1.6.0
+networkx 3.2.1
+numpy 1.26.3
+open3d 0.19.0
+opencv-python 4.10.0.84
+packaging 24.1
+pandas 2.2.2
+parso 0.8.4
+pillow 10.2.0
+pip 24.0
+platformdirs 4.2.2
+plotly 5.24.1
+plyfile 1.0.3
+pooch 1.8.2
+prompt_toolkit 3.0.50
+psutil 6.0.0
+pure_eval 0.2.3
+py360convert 0.1.0
+pyg_lib 0.4.0+pt23cu121
+pyglet 1.5.15
+Pygments 2.18.0
+pyparsing 3.1.2
+python-dateutil 2.9.0.post0
+pythreejs 2.4.2
+pytz 2024.1
+pyvista 0.44.1
+pywin32 308
+referencing 0.36.2
+requests 2.32.3
+retrying 1.3.4
+rich 13.7.1
+rpds-py 0.22.3
+scikit-image 0.24.0
+scikit-learn 1.5.1
+scipy 1.14.0
+scooby 0.10.0
+setuptools 75.8.0
+six 1.16.0
+stack-data 0.6.3
+sympy 1.12
+tbb 2021.11.0
+tenacity 9.0.0
+threadpoolctl 3.5.0
+tifffile 2024.7.2
+torch 2.3.1+cu121
+torch_cluster 1.6.3+pt23cu121
+torch_geometric 2.5.3
+torch_scatter 2.1.2+pt23cu121
+torch_sparse 0.6.18+pt23cu121
+torch_spline_conv 1.2.2+pt23cu121
+torchaudio 2.3.1+cu121
+torchvision 0.18.1+cu121
+tqdm 4.66.4
+traitlets 5.14.3
+traittypes 0.2.1
+trimesh 4.4.3
+typing_extensions 4.9.0
+tzdata 2024.1
+urllib3 2.2.2
+vtk 9.3.1
+wcwidth 0.2.13
+Werkzeug 3.0.6
+widgetsnbextension 4.0.13
+yarl 1.9.4
+zipp 3.21.0
\ No newline at end of file