diff --git a/FastSplatStyler_huggingface/.DS_Store b/FastSplatStyler_huggingface/.DS_Store deleted file mode 100644 index f51561a9f167555daae4d5080f30840e41db2d87..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/.DS_Store and /dev/null differ diff --git a/FastSplatStyler_huggingface/.gitignore b/FastSplatStyler_huggingface/.gitignore deleted file mode 100644 index ad134a25780be998c90dcb99e8b4e9f0ca712cb9..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/.gitignore +++ /dev/null @@ -1,210 +0,0 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[codz] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py.cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
-# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# UV -# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -#uv.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock -#poetry.toml - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. -# https://pdm-project.org/en/latest/usage/project/#working-with-version-control -#pdm.lock -#pdm.toml -.pdm-python -.pdm-build/ - -# pixi -# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. -#pixi.lock -# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one -# in the .venv directory. It is recommended not to include this directory in version control. -.pixi - -# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.envrc -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ - -# Abstra -# Abstra is an AI-powered process automation framework. -# Ignore directories containing user credentials, local state, and settings. -# Learn more at https://abstra.io/docs -.abstra/ - -# Visual Studio Code -# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore -# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore -# and can be added to the global gitignore or merged into this file. However, if you prefer, -# you could uncomment the following to ignore the entire vscode folder -# .vscode/ - -# Ruff stuff: -.ruff_cache/ - -# PyPI configuration file -.pypirc - -# Cursor -# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to -# exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data -# refer to https://docs.cursor.com/context/ignore-files -.cursorignore -.cursorindexingignore - -# Marimo -marimo/_static/ -marimo/_lsp/ -__marimo__/ -.DS_Store -/graph_networks -output.splat diff --git a/FastSplatStyler_huggingface/LICENSE b/FastSplatStyler_huggingface/LICENSE deleted file mode 100644 index dd3fbced4eaaebd97a05c54796d9d2afafd095b2..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2025 ECU Computer Vision Lab - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
diff --git a/FastSplatStyler_huggingface/README.md b/FastSplatStyler_huggingface/README.md deleted file mode 100644 index fdf34a3b83a36b9d9b1ac8f2495a9b921473ff9c..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/README.md +++ /dev/null @@ -1,36 +0,0 @@ -# FastSplatStyler -Official Implementation of "Optimization-Free Style Transfer of 3D Gaussian Splats" - -[arXiv Paper](https://arxiv.org/abs/2508.05813) - -![](example.jpg) - -## Example Outputs - -Example Outputs can be visualized using the [Antimatter WebGL viewer](https://antimatter15.com/splat/) at the following links. - -- Broche: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/broche-rose-gold_original.splat) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/broche-rose-gold_style3.splat) -- Crystal Lamp: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/crystal-lamp-original.splat) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/crystal-lamp-style2.splat) -- Family Statue: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/family.ply) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/family-style-6.splat) -- M60 Tanks: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/m60.ply) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/m60-style-31.splat) -- Table: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/Table.ply) and 
[Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/Table_style5.splat) -- Train: [Original](https://antimatter15.com/splat/) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/train_style1.splat) -- Truck: [Original](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/truck.splat) and [Stylized](https://antimatter15.com/splat/?url=https://huggingface.co/datasets/incrl/fast-splat-styler/resolve/main/truck-style-21.splat) - -Example inputs and outputs can be downloaded from [Google Drive](https://drive.google.com/drive/folders/10YmtcCOKGosXfPEi84ho1AfYRYioYo12?usp=drive_link) or [Hugging Face](https://huggingface.co/datasets/incrl/fast-splat-styler/tree/main) - -## Demo - -**Coming soon** - -## Install - -This work relies heavily on the [Pytorch](https://pytorch.org/) and [Pytorch Geometric](https://www.pyg.org/) libraries. - -This code was tested with Python 3.12, Pytorch 2.9.1 (CUDA Toolkit 12.8), and Pytorch Geometric 2.8 on Windows. It was also tested on Mac with the same setup (without Cuda). The provided requirements.txt file comes from the Mac configuration. - -This repository relies on a graph networks library that was presented in a [previous work](https://github.com/davidmhart/interpolated-selectionconv/tree/main). The library can be downloaded, with included model weights, at this [google drive link](https://drive.google.com/drive/folders/10YmtcCOKGosXfPEi84ho1AfYRYioYo12?usp=drive_link). Place the "graph_networks" folder in the main directory. - -To stylize an example splat, run `python styletransfer_splat.py example-broche-rose-gold.splat --stylePath style_ims/style0.jpg --samplingRate 1.5`. You can change the input splat and style image to your specific use case. - -Supports `.splat` and `.ply` files. 
diff --git a/FastSplatStyler_huggingface/app.py b/FastSplatStyler_huggingface/app.py deleted file mode 100644 index f3bca07be619b8f6a2f19d5470c513b84217c05b..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/app.py +++ /dev/null @@ -1,302 +0,0 @@ -import gradio as gr -import torch -import os -import tempfile -import shutil -from pathlib import Path -from time import time - -# ── Core style-transfer logic (adapted from styletransfer_splat.py) ────────── -import pointCloudToMesh as ply2M -import utils -import graph_io as gio -from clusters import * -import splat_mesh_helpers as splt -import clusters as cl -from torch_geometric.data import Data -from scipy.interpolate import NearestNDInterpolator - -from graph_networks.LinearStyleTransfer_vgg import encoder, decoder -from graph_networks.LinearStyleTransfer_matrix import TransformLayer -from graph_networks.LinearStyleTransfer.libs.Matrix import MulLayer -from graph_networks.LinearStyleTransfer.libs.models import encoder4, decoder4 - - -# ── Example assets (place your own files in ./examples/) ───────────────────── -EXAMPLE_SPLATS = [ - ["example-broche-rose-gold.splat", "style_ims/style2.jpg"], - ["example-broche-rose-gold.splat", "style_ims/style6.jpg"], -] - - -# ── Style-transfer function called by Gradio ───────────────────────────────── -def run_style_transfer( - splat_file, - style_image, - threshold: float, - sampling_rate: float, - device_choice: str, - progress=gr.Progress(track_tqdm=True), -): - if splat_file is None: - raise gr.Error("Please upload a 3D Gaussian Splat file (.ply or .splat).") - if style_image is None: - raise gr.Error("Please upload a style image.") - - device = device_choice if device_choice == "cpu" else f"cuda:{device_choice}" - - # ── Parameters ──────────────────────────────────────────────────────────── - n = 25 - ratio = 0.25 - depth = 3 - style_shape = (512, 512) - - logs = [] - - def log(msg): - logs.append(msg) - print(msg) - return "\n".join(logs) - - # ── 
1. Load splat ───────────────────────────────────────────────────────── - progress(0.05, desc="Loading splat…") - splat_path = splat_file.name if hasattr(splat_file, "name") else splat_file - log(f"Loading splat: {splat_path}") - - pos3D_Original, _, colors_Original, opacity_Original, scales_Original, rots_Original, fileType = \ - splt.splat_unpacker_with_threshold(n, splat_path, threshold) - - # ── 2. Gaussian super-sampling ──────────────────────────────────────────── - progress(0.15, desc="Super-sampling…") - t0 = time() - if sampling_rate > 1: - GaussianSamples = int(pos3D_Original.shape[0] * sampling_rate) - pos3D, colors = splt.splat_GaussianSuperSampler( - pos3D_Original.clone(), colors_Original.clone(), - opacity_Original.clone(), scales_Original.clone(), rots_Original.clone(), - GaussianSamples, - ) - else: - pos3D, colors = pos3D_Original, colors_Original - log(f"Nodes in graph: {pos3D.shape[0]} ({time()-t0:.1f}s)") - - # ── 3. Graph construction ───────────────────────────────────────────────── - progress(0.30, desc="Building surface graph…") - t0 = time() - style_ref = utils.loadImage(style_image, shape=style_shape) - - normalsNP = ply2M.Estimate_Normals(pos3D, threshold) - normals = torch.from_numpy(normalsNP) - - up_vector = torch.tensor([[1, 1, 1]], dtype=torch.float) - up_vector = up_vector / torch.linalg.norm(up_vector, dim=1) - - pos3D = pos3D.to(device) - colors = colors.to(device) - normals = normals.to(device) - up_vector = up_vector.to(device) - - edge_index, directions = gh.surface2Edges(pos3D, normals, up_vector, k_neighbors=16) - edge_index, selections, interps = gh.edges2Selections(edge_index, directions, interpolated=True) - - clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters( - pos3D, normals, edge_index, selections, interps, - ratio=ratio, up_vector=up_vector, depth=depth, device=device, - ) - log(f"Graph built ({time()-t0:.1f}s)") - - # ── 4. 
Load networks ────────────────────────────────────────────────────── - progress(0.50, desc="Loading networks…") - t0 = time() - - enc_ref = encoder4() - dec_ref = decoder4() - matrix_ref = MulLayer("r41") - - enc_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/vgg_r41.pth", map_location=device)) - dec_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/dec_r41.pth", map_location=device)) - matrix_ref.load_state_dict(torch.load("graph_networks/LinearStyleTransfer/models/r41.pth", map_location=device)) - - enc = encoder(padding_mode="replicate") - dec = decoder(padding_mode="replicate") - matrix = TransformLayer() - - with torch.no_grad(): - enc.copy_weights(enc_ref) - dec.copy_weights(dec_ref) - matrix.copy_weights(matrix_ref) - - content = Data( - x=colors, clusters=clusters, - edge_indexes=edge_indexes, - selections_list=selections_list, - interps_list=interps_list, - ).to(device) - - style, _ = gio.image2Graph(style_ref, depth=3, device=device) - - enc = enc.to(device) - dec = dec.to(device) - matrix = matrix.to(device) - log(f"Networks loaded ({time()-t0:.1f}s)") - - # ── 5. Style transfer ───────────────────────────────────────────────────── - progress(0.70, desc="Running style transfer…") - t0 = time() - - with torch.no_grad(): - cF = enc(content) - sF = enc(style) - feature, _ = matrix( - cF["r41"], sF["r41"], - content.edge_indexes[3], content.selections_list[3], - style.edge_indexes[3], style.selections_list[3], - content.interps_list[3] if hasattr(content, "interps_list") else None, - ) - result = dec(feature, content).clamp(0, 1) - - colors[:, 0:3] = result - log(f"Stylization done ({time()-t0:.1f}s)") - - # ── 6. 
Interpolate back to original resolution ──────────────────────────── - progress(0.88, desc="Interpolating back to original splat…") - t0 = time() - - interp2 = NearestNDInterpolator(pos3D.cpu(), colors.cpu()) - results_OriginalNP = interp2(pos3D_Original) - results_Original = torch.from_numpy(results_OriginalNP).to(torch.float32) - colors_and_opacity_Original = torch.cat( - (results_Original, opacity_Original.unsqueeze(1)), dim=1 - ) - log(f"Interpolation done ({time()-t0:.1f}s)") - - # ── 7. Save output ──────────────────────────────────────────────────────── - progress(0.95, desc="Saving output splat…") - suffix = ".splat" if fileType == "splat" else ".ply" - out_dir = tempfile.mkdtemp() - out_path = os.path.join(out_dir, f"stylized{suffix}") - - splt.splat_save( - pos3D_Original.numpy(), - scales_Original.numpy(), - rots_Original.numpy(), - colors_and_opacity_Original.numpy(), - out_path, - fileType, - ) - log(f"Saved to: {out_path}") - progress(1.0, desc="Done!") - - return out_path, "\n".join(logs) - - -# ── Gradio UI ───────────────────────────────────────────────────────────────── -def build_ui(): - available_devices = ( - [str(i) for i in range(torch.cuda.device_count())] + ["cpu"] - if torch.cuda.is_available() - else ["cpu"] - ) - - with gr.Blocks( - title="3DGS Style Transfer", - theme=gr.themes.Soft(primary_hue="violet"), - css=""" - #title { text-align: center; } - #subtitle { text-align: center; color: #666; margin-bottom: 1rem; } - .panel { border-radius: 12px; } - #run-btn { font-size: 1.1rem; } - """, - ) as demo: - - gr.Markdown("# 🎨 3D Gaussian Splat Style Transfer", elem_id="title") - gr.Markdown( - "Upload a 3DGS scene and a style image — the app will repaint the splat " - "with the artistic style of the image and give you a stylized splat to download. 
" - "After downloading, you can view your splat with an [online viewer](https://antimatter15.com/splat/).", - elem_id="subtitle", - ) - - with gr.Row(): - # ── Left column: inputs ─────────────────────────────────────────── - with gr.Column(scale=1, elem_classes="panel"): - gr.Markdown("### 📂 Inputs") - - splat_input = gr.File( - label="3D Gaussian Splat (.ply or .splat)", - file_types=[".ply", ".splat"], - type="filepath", - ) - - style_input = gr.Image( - label="Style Image", - type="filepath", - height=240, - ) - - with gr.Accordion("⚙️ Advanced Settings", open=False): - threshold_slider = gr.Slider( - minimum=90.0, maximum=100.0, value=99.8, step=0.1, - label="Opacity threshold (percentile)", - info="Points below this opacity percentile are removed.", - ) - sampling_slider = gr.Slider( - minimum=0.5, maximum=3.0, value=1.5, step=0.1, - label="Gaussian super-sampling rate", - info="Values > 1 add extra samples; 1.0 = no super-sampling.", - ) - device_radio = gr.Radio( - choices=available_devices, - value=available_devices[0], - label="Device", - ) - - run_btn = gr.Button("🚀 Run Style Transfer", variant="primary", elem_id="run-btn") - - # ── Right column: outputs ───────────────────────────────────────── - with gr.Column(scale=1, elem_classes="panel"): - gr.Markdown("### 📥 Output") - - output_file = gr.File( - label="Download Stylized Splat", - interactive=False, - ) - - log_box = gr.Textbox( - label="Progress log", - lines=12, - max_lines=20, - interactive=False, - placeholder="Logs will appear here once processing starts…", - ) - - # ── Examples ───────────────────────────────────────────────────────── - example_splat_paths = [row[0] for row in EXAMPLE_SPLATS] - example_style_paths = [row[1] for row in EXAMPLE_SPLATS] - - valid_examples = [ - row for row in EXAMPLE_SPLATS - if os.path.exists(row[0]) and os.path.exists(row[1]) - ] - - if valid_examples: - gr.Markdown("### 🖼️ Examples") - gr.Examples( - examples=valid_examples, - inputs=[splat_input, 
style_input], - label="Click an example to load it", - ) - - # ── Event wiring ────────────────────────────────────────────────────── - run_btn.click( - fn=run_style_transfer, - inputs=[splat_input, style_input, threshold_slider, sampling_slider, device_radio], - outputs=[output_file, log_box], - ) - - return demo - - -if __name__ == "__main__": - demo = build_ui() - demo.launch(share=False) diff --git a/FastSplatStyler_huggingface/clusters.py b/FastSplatStyler_huggingface/clusters.py deleted file mode 100644 index 556fe199df9f85abd6c4a3df52252753dd0d2461..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/clusters.py +++ /dev/null @@ -1,234 +0,0 @@ -import torch -from torch_scatter import scatter -from torch_geometric.nn.pool.consecutive import consecutive_cluster -from torch_geometric.utils import add_self_loops, add_remaining_self_loops, remove_self_loops -from torch_geometric.nn import fps, knn -from torch_sparse import coalesce -import graph_helpers as gh -import sphere_helpers as sh -import mesh_helpers as mh - -import math -from math import pi,sqrt - - -from warnings import warn - -def makeImageClusters(pos2D,Nx,Ny,edge_index,selections,depth=1,device='cpu',stride=2): - clusters = [] - edge_indexes = [torch.clone(edge_index).to(device)] - selections_list = [torch.clone(selections).to(device)] - - for _ in range(depth): - Nx = Nx//stride - Ny = Ny//stride - cx,cy = getGrid(pos2D,Nx,Ny) - cluster, pos2D = gridCluster(pos2D,cx,cy,Nx) - edge_index, selections = selectionAverage(cluster, edge_index, selections) - - clusters.append(torch.clone(cluster).to(device)) - edge_indexes.append(torch.clone(edge_index).to(device)) - selections_list.append(torch.clone(selections).to(device)) - - return clusters, edge_indexes, selections_list - -def makeSphereClusters(pos3D,edge_index,selections,interps,rows,cols,cluster_method="layering",stride=2,bary_d=None,depth=1,device='cpu'): - clusters = [] - edge_indexes = [torch.clone(edge_index).to(device)] - 
selections_list = [torch.clone(selections).to(device)] - interps_list = [torch.clone(interps).to(device)] - - for _ in range(depth): - - rows = rows//stride - cols = cols//stride - - if bary_d is not None: - bary_d = bary_d*stride - - if cluster_method == "equirec": - centroids, _ = sh.sampleSphere_Equirec(rows,cols) - - elif cluster_method == "layering": - centroids, _ = sh.sampleSphere_Layering(rows) - - elif cluster_method == "spiral": - centroids, _ = sh.sampleSphere_Spiral(rows,cols) - - elif cluster_method == "icosphere": - centroids, _ = sh.sampleSphere_Icosphere(rows) - - elif cluster_method == "random": - centroids, _ = sh.sampleSphere_Random(rows,cols) - - elif cluster_method == "random_nodes": - index = torch.multinomial(torch.ones(len(pos3D)),N) # close equivalent to np.random.choice - centroids = pos3D[index] - - elif cluster_method == "fps": - # Farthest Point Search used in PointNet++ - index = fps(pos3D, ratio=ratio) - centroids = pos3D[index] - else: - raise ValueError("Sphere cluster_method unknown") - - - # Find closest centriod to each current point - cluster = knn(centroids,pos3D,1)[1] - cluster, _ = consecutive_cluster(cluster) - pos3D = scatter(pos3D, cluster, dim=0, reduce='mean') - - # Regenerate surface graph - normals = pos3D/torch.linalg.norm(pos3D,dim=1,keepdims=True) # Make sure normals are unit vectors - edge_index,directions = gh.surface2Edges(pos3D,normals) - edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True,bary_d=bary_d) - - clusters.append(torch.clone(cluster).to(device)) - edge_indexes.append(torch.clone(edge_index).to(device)) - selections_list.append(torch.clone(selections).to(device)) - interps_list.append(torch.clone(interps).to(device)) - - return clusters, edge_indexes, selections_list, interps_list - -def makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,cluster_method="random",ratio=.25,up_vector=None,depth=1,device='cpu'): - clusters = [] - edge_indexes = 
[torch.clone(edge_index).to(device)] - selections_list = [torch.clone(selections).to(device)] - interps_list = [torch.clone(interps).to(device)] - - for _ in range(depth): - - #Desired number of clusters in the next level - N = int(len(pos3D) * ratio) - - if cluster_method == "random": - index = torch.multinomial(torch.ones(len(pos3D)),N) # close equivalent to np.random.choice - centroids = pos3D[index] - - elif cluster_method == "fps": - # Farthest Point Search used in PointNet++ - index = fps(pos3D, ratio=ratio) - centroids = pos3D[index] - - # Find closest centriod to each current point - cluster = knn(centroids,pos3D,1)[1] - cluster, _ = consecutive_cluster(cluster) - pos3D = scatter(pos3D, cluster, dim=0, reduce='mean') - normals = scatter(normals, cluster, dim=0, reduce='mean') - - # Regenerate surface graph - normals = normals/torch.linalg.norm(normals,dim=1,keepdims=True) # Make sure normals are unit vectors - edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16) - edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True) - - clusters.append(torch.clone(cluster).to(device)) - edge_indexes.append(torch.clone(edge_index).to(device)) - selections_list.append(torch.clone(selections).to(device)) - interps_list.append(torch.clone(interps).to(device)) - - return clusters, edge_indexes, selections_list, interps_list - - -def makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=.25,up_vector=None,depth=1,device='cpu'): - clusters = [] - edge_indexes = [torch.clone(edge_index).to(device)] - selections_list = [torch.clone(selections).to(device)] - interps_list = [torch.clone(interps).to(device)] - - for _ in range(depth): - - #Desired number of clusters in the next level - N = int(len(pos3D) * ratio) - - # Generate new point cloud from downsampled version of texture map - centroids, normals = mh.sampleSurface(mesh,N,return_x=False) - - # Find closest centriod to each current point - cluster = 
knn(centroids,pos3D,1)[1] - cluster, _ = consecutive_cluster(cluster) - pos3D = scatter(pos3D, cluster, dim=0, reduce='mean') - - # Regenerate surface graph - #normals = normals/torch.linalg.norm(normals,dim=1,keepdims=True) # Make sure normals are unit vectors - edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector) - edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True) - - clusters.append(torch.clone(cluster).to(device)) - edge_indexes.append(torch.clone(edge_index).to(device)) - selections_list.append(torch.clone(selections).to(device)) - interps_list.append(torch.clone(interps).to(device)) - - return clusters, edge_indexes, selections_list, interps_list - - - -def getGrid(pos,Nx,Ny,xrange=None,yrange=None): - xmin = torch.min(pos[:,0]) if xrange is None else xrange[0] - ymin = torch.min(pos[:,1]) if yrange is None else yrange[0] - xmax = torch.max(pos[:,0]) if xrange is None else xrange[1] - ymax = torch.max(pos[:,1]) if yrange is None else yrange[1] - - cx = torch.clamp(torch.floor((pos[:,0] - xmin)/(xmax-xmin) * Nx),0,Nx-1) - cy = torch.clamp(torch.floor((pos[:,1] - ymin)/(ymax-ymin) * Ny),0,Ny-1) - return cx, cy - -def gridCluster(pos,cx,cy,xmax): - cluster = cx + cy*xmax - cluster = cluster.type(torch.long) # Cast appropriately - cluster, _ = consecutive_cluster(cluster) - pos = scatter(pos, cluster, dim=0, reduce='mean') - - return cluster, pos - -def selectionAverage(cluster, edge_index, selections): - num_nodes = cluster.size(0) - edge_index = cluster[edge_index.contiguous().view(1, -1)].view(2, -1) - edge_index, selections = remove_self_loops(edge_index, selections) - if edge_index.numel() > 0: - - # To avoid means over discontinuities, do mean for two selections at at a time - final_edge_index, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean") - selections_check = torch.round(selections_check).type(torch.long) - - final_selections = 
torch.zeros_like(selections_check).to(selections.device) - - final_selections[torch.where(selections_check==4)] = 4 - final_selections[torch.where(selections_check==5)] = 5 - - #Rotate selection kernel - selections += 2 - selections = selections % 9 + torch.div(selections, 9, rounding_mode="floor") - - _, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean") - selections_check = torch.round(selections_check).type(torch.long) - final_selections[torch.where(selections_check==4)] = 2 - final_selections[torch.where(selections_check==5)] = 3 - - #Rotate selection kernel - selections += 2 - selections = selections % 9 + torch.div(selections, 9, rounding_mode="floor") - - _, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean") - selections_check = torch.round(selections_check).type(torch.long) - final_selections[torch.where(selections_check==4)] = 8 - final_selections[torch.where(selections_check==5)] = 1 - - #Rotate selection kernel - selections += 2 - selections = selections % 9 + torch.div(selections, 9, rounding_mode="floor") - - _, selections_check = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean") - selections_check = torch.round(selections_check).type(torch.long) - final_selections[torch.where(selections_check==4)] = 6 - final_selections[torch.where(selections_check==5)] = 7 - - #print(torch.min(final_selections), torch.max(final_selections)) - #print(torch.mean(final_selections.type(torch.float))) - - edge_index, selections = add_remaining_self_loops(final_edge_index,final_selections,fill_value=torch.tensor(0,dtype=torch.long)) - - else: - edge_index, selections = add_remaining_self_loops(edge_index,selections,fill_value=torch.tensor(0,dtype=torch.long)) - print("Warning: Edge Pool found no edges") - - return edge_index, selections diff --git a/FastSplatStyler_huggingface/example-broche-rose-gold.splat 
b/FastSplatStyler_huggingface/example-broche-rose-gold.splat deleted file mode 100644 index 4e96c18d7bd1e6436ad9fa64577dd91b1e4ad113..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/example-broche-rose-gold.splat +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:62cccf0adfb4e5713985e8ee874fa30a6a37ae222a23ead7c6e4639a5802ab62 -size 4157728 diff --git a/FastSplatStyler_huggingface/example.jpg b/FastSplatStyler_huggingface/example.jpg deleted file mode 100644 index 2a3fbbc38cb9027240bd9d50a506c1fb5def1dd7..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/example.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/graph_helpers.py b/FastSplatStyler_huggingface/graph_helpers.py deleted file mode 100644 index 2b53a78fd9ab2ac39855d9663e3343d3f1093b24..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_helpers.py +++ /dev/null @@ -1,400 +0,0 @@ -import torch -from torch_geometric.nn import radius_graph, knn_graph -import torch_geometric as tg -from torch_geometric.utils import subgraph -import utils -from math import sqrt - -def getImPos(rows,cols,start_row=0,start_col=0): - row_space = torch.arange(start_row,rows+start_row) - col_space = torch.arange(start_col,cols+start_col) - col_image,row_image = torch.meshgrid(col_space,row_space,indexing='xy') - im_pos = torch.reshape(torch.stack((row_image,col_image),dim=-1),(rows*cols,2)) - return im_pos - -def convertImPos(im_pos,flip_y=True): - - # Cast to float for clustering based methods - pos2D = im_pos.float() - - # Switch rows,cols to x,y - pos2D[:,[1,0]] = pos2D[:,[0,1]] - - if flip_y: - - # Flip to y-axis to match mathematical definition and edges2Selections settings - pos2D[:,1] = torch.amax(pos2D[:,1]) - pos2D[:,1] - - return pos2D - - -def grid2Edges(locs): - # Assume locs are already spaced at a distance of 1 structure - edge_index = radius_graph(locs,1.44,loop=True) - return 
edge_index - -def radius2Edges(locs,r=1.0): - edge_index = radius_graph(locs,r,loop=True) - return edge_index - -def knn2Edges(locs,knn=9): - edge_index = knn_graph(locs,knn,loop=True) - return edge_index - -def surface2Edges(pos3D,normals,up_vector=None,k_neighbors=9): - - if up_vector is None: - up_vector = torch.tensor([[0.0,1.0,0.0]]).to(pos3D.device) - - # K Nearest Neighbors graph - edge_index = knn_graph(pos3D,k_neighbors,loop=True) - - # Cull neighbors based on normals (dot them together) - culling = torch.sum(torch.multiply(normals[edge_index[1]],normals[edge_index[0]]),dim=1) - edge_index = edge_index[:,torch.where(culling>0)[0]] - - # For each node, rotate based on Grahm-Schmidt Orthognalization - norms = normals[edge_index[0]] - - z_dir = norms - z_dir = z_dir/torch.linalg.norm(z_dir,dim=1,keepdims=True) # Make sure it is a unit vector - #x_dir = torch.cross(up_vector,norms,dim=1) - x_dir = utils.cross(up_vector,norms) # torch.cross doesn't broadcast properly in some versions of torch - x_dir = x_dir/torch.linalg.norm(x_dir,dim=1,keepdims=True) - #y_dir = torch.cross(norms,x_dir,dim=1) - y_dir = utils.cross(norms,x_dir) - y_dir = y_dir/torch.linalg.norm(y_dir,dim=1,keepdims=True) - - directions = (pos3D[edge_index[1]] - pos3D[edge_index[0]]) - - # Perform rotation by multiplying out rotation matrix - temp = torch.clone(directions) # Buffer - directions[:,0] = temp[:,0] * x_dir[:,0] + temp[:,1] * x_dir[:,1] + temp[:,2] * x_dir[:,2] - directions[:,1] = temp[:,0] * y_dir[:,0] + temp[:,1] * y_dir[:,1] + temp[:,2] * y_dir[:,2] - #directions[:,2] = temp[:,0] * z_dir[:,0] + temp[:,1] * z_dir[:,1] + temp[:,2] * z_dir[:,2] - - # Drop z coordinate - directions = directions[:,:2] - - return edge_index, directions - -def edges2Selections(edge_index,directions,interpolated=True,bary_d=None,y_down=False): - - # Current Ordering - # 4 3 2 - # 5 0 1 - # 6 7 8 - if y_down: - vectorList = 
torch.tensor([[1,0],[sqrt(2)/2,-sqrt(2)/2],[0,-1],[-sqrt(2)/2,-sqrt(2)/2],[-1,0],[-sqrt(2)/2,sqrt(2)/2],[0,1],[sqrt(2)/2,sqrt(2)/2]],dtype=torch.float).transpose(1,0) - else: - vectorList = torch.tensor([[1,0],[sqrt(2)/2,sqrt(2)/2],[0,1],[-sqrt(2)/2,sqrt(2)/2],[-1,0],[-sqrt(2)/2,-sqrt(2)/2],[0,-1],[sqrt(2)/2,-sqrt(2)/2]],dtype=torch.float).transpose(1,0) - - if interpolated: - - if bary_d is None: - edge_index,selections,interps = interpolateSelections(edge_index,directions,vectorList) - else: - edge_index,selections,interps = interpolateSelections_barycentric(edge_index,directions,bary_d,vectorList) - interps = normalizeEdges(edge_index,selections,interps) - return edge_index,selections,interps - - else: - selections = torch.argmax(torch.matmul(directions,vectorList),dim=1) + 1 - selections[torch.where(torch.sum(torch.abs(directions),axis=1) == 0)] = 0 # Same cell selection - return selections - - -def makeEdges(prev_sources,prev_targets,prev_selections,sources,targets,selection,reverse=True): - - sources = sources.flatten() - targets = targets.flatten() - - prev_sources += sources.tolist() - prev_targets += targets.tolist() - prev_selections += len(sources)*[selection] - - if reverse: - prev_sources += targets - prev_targets += sources - prev_selections += len(sources)*[utils.reverse_selection(selection)] - - return prev_sources,prev_targets,prev_selections - -def maskNodes(mask,x): - node_mask = torch.where(mask) - x = x[node_mask] - return x - -def maskPoints(mask,x,y): - - mask = torch.squeeze(mask) - - x0 = torch.floor(x).long() - x1 = x0 + 1 - y0 = torch.floor(y).long() - y1 = y0 + 1 - - x0 = torch.clip(x0, 0, mask.shape[1]-1); - x1 = torch.clip(x1, 0, mask.shape[1]-1); - y0 = torch.clip(y0, 0, mask.shape[0]-1); - y1 = torch.clip(y1, 0, mask.shape[0]-1); - - Ma = mask[ y0, x0 ] - Mb = mask[ y1, x0 ] - Mc = mask[ y0, x1 ] - Md = mask[ y1, x1 ] - - node_mask = torch.where(torch.logical_and(torch.logical_and(torch.logical_and(Ma,Mb),Mc),Md))[0] - - return 
node_mask - - -def maskGraph(mask,edge_index,selections,interps=None): - - edge_index,_,edge_mask = subgraph(mask,edge_index,relabel_nodes=True,return_edge_mask=True) - selections = selections[edge_mask] - - if interps: - interps = interps[edge_mask] - return edge_index, selections, interps - else: - return edge_index, selections - -def interpolateSelections(edge_index,directions,vectorList=None): - - if vectorList is None: - # Current Ordering - # 4 3 2 - # 5 0 1 - # 6 7 8 - vectorList = torch.tensor([[1,0],[sqrt(2)/2,sqrt(2)/2],[0,1],[-sqrt(2)/2,sqrt(2)/2],[-1,0],[-sqrt(2)/2,-sqrt(2)/2],[0,-1],[sqrt(2)/2,-sqrt(2)/2]],dtype=torch.float).transpose(1,0) - - # Normalize directions for simplicity of calculations - dir_norm = torch.linalg.norm(directions,dim=1,keepdims=True) - directions = directions/dir_norm - #locs = torch.where(dir_norm > 1)[0] - #directions[locs] = directions[locs]/dir_norm[locs] - - values = torch.matmul(directions,vectorList) - best = torch.unsqueeze(torch.argmax(values,dim=1),1) - - best_val = torch.take_along_dim(values,best,dim=1) - - # Look at both neighbors to see who is closer - lower_val = torch.take_along_dim(values,(best-1) % 8,dim=1) - upper_val = torch.take_along_dim(values,(best+1) % 8,dim=1) - - comp_vals = torch.cat((lower_val,upper_val),dim=1) - - second_best_vals = torch.amax(comp_vals,dim=1) - second_best = torch.argmax(comp_vals,dim=1) - - # Find the interpolation value (in terms of angles) - best_val = torch.minimum(best_val[:,0],torch.tensor(1,device=directions.device)) # Prep for arccos function - angle_best = torch.arccos(best_val) - angle_second_best = torch.arccos(second_best_vals) - - angle_vals = angle_best/(angle_second_best + angle_best) - - # Use negative values for clockwise selections - clockwise = torch.where(second_best == 0)[0] - angle_vals[clockwise] = -angle_vals[clockwise] - - # Handle computation problems at the poles - angle_vals = torch.nan_to_num(angle_vals) - - # Make Selections - selections = best[:,0] + 
1 - - # Same cell selection - same_locs = torch.where(edge_index[0] == edge_index[1]) - selections[same_locs] = 0 - angle_vals[same_locs] = 0 - - # Make starting interp_values - interps = torch.ones_like(angle_vals) - interps -= torch.abs(angle_vals) - - # Add new edges - pos_interp_locs = torch.where(angle_vals > 1e-2)[0] - pos_interps = angle_vals[pos_interp_locs] - pos_edges = edge_index[:,pos_interp_locs] - pos_selections = selections[pos_interp_locs] + 1 - pos_selections[torch.where(pos_selections>8)] = 1 # Account for wrap around - - neg_interp_locs = torch.where(angle_vals < -1e-2)[0] - neg_interps = torch.abs(angle_vals[neg_interp_locs]) - neg_edges = edge_index[:,neg_interp_locs] - neg_selections = selections[neg_interp_locs] - 1 - neg_selections[torch.where(neg_selections<1)] = 8 # Account for wrap around - - edge_index = torch.cat((edge_index,pos_edges,neg_edges),dim=1) - selections = torch.cat((selections,pos_selections,neg_selections),dim=0) - interps = torch.cat((interps,pos_interps,neg_interps),dim=0) - - return edge_index,selections,interps - -def interpolateSelections_barycentric(edge_index,directions,d,vectorList=None): - - if vectorList is None: - # Current Ordering - # 4 3 2 - # 5 0 1 - # 6 7 8 - vectorList = torch.tensor([[1,0],[sqrt(2)/2,-sqrt(2)/2],[0,-1],[-sqrt(2)/2,-sqrt(2)/2],[-1,0],[-sqrt(2)/2,sqrt(2)/2],[0,1],[sqrt(2)/2,sqrt(2)/2]],dtype=torch.float).transpose(1,0).to(directions.device) - - # Preprune central selections and reappend them at the end - same_locs = torch.where(edge_index[0] == edge_index[1]) - same_edges = edge_index[:,same_locs[0]] - - different_locs = torch.where(edge_index[0] != edge_index[1]) - edge_index = edge_index[:,different_locs[0]] - directions = directions[different_locs[0]] - - # Normalize directions for simplicity of calculations - dir_norm = torch.linalg.norm(directions,dim=1,keepdims=True) - unit_directions = directions/dir_norm - #locs = torch.where(dir_norm > 1)[0] - #directions[locs] = 
directions[locs]/dir_norm[locs] - - values = torch.matmul(unit_directions,vectorList) - best = torch.unsqueeze(torch.argmax(values,dim=1),1) - #best_val = torch.take_along_dim(values,best,dim=1) - - # Look at both neighbors to see who is closer - lower_val = torch.take_along_dim(values,(best-1) % 8,dim=1) - upper_val = torch.take_along_dim(values,(best+1) % 8,dim=1) - - comp_vals = torch.cat((lower_val,upper_val),dim=1) - - second_best = torch.argmax(comp_vals,dim=1) - #second_best_vals = torch.amax(comp_vals,dim=1) - - # Convert into uv cooridnates for barycentric interpolation calculation - # /| - # / |v - # /__| - # u - - scaled_directions = torch.abs(directions/d) - u = torch.amax(scaled_directions,dim=1) - v = torch.amin(scaled_directions,dim=1) - - # Force coordinates to be within the triangle - boundary_check = torch.where(u > d) - v[boundary_check] /= u[boundary_check] - u[boundary_check] = 1.0 - - # Precalculated barycentric values from linear matrix solve - I0 = 1 - u - I1 = u - v - I2 = v - - # Make first selections and proper interps - selections = best[:,0] + 1 - interps = I1 - even_sels = torch.where(selections % 2 == 0) - interps[even_sels] = I2[even_sels] # Corners get different weights - - # Make new edges for the central selections - central_edges = torch.clone(edge_index).to(edge_index.device) - central_selections = torch.zeros_like(selections) - central_interps = I0 - - # Make new edges for the last selection - pos_locs = torch.where(second_best==1)[0] - pos_edges = edge_index[:,pos_locs] - pos_selections = selections[pos_locs] + 1 - pos_selections[torch.where(pos_selections>8)] = 1 #Account for wrap around - pos_interps = I1[pos_locs] - even_sels = torch.where(pos_selections % 2 == 0) - pos_interps[even_sels] = I2[pos_locs][even_sels] - - neg_locs = torch.where(second_best==0)[0] - neg_edges = edge_index[:,neg_locs] - neg_selections = selections[neg_locs] - 1 - neg_selections[torch.where(neg_selections<1)] = 8 # Account for wrap around - 
neg_interps = I1[neg_locs] - even_sels = torch.where(neg_selections % 2 == 0) - neg_interps[even_sels] = I2[neg_locs][even_sels] - - # Account for the previously pruned same node edges - same_selections = torch.zeros(same_edges.shape[1],dtype=torch.long) - same_interps = torch.ones(same_edges.shape[1],dtype=torch.float) - - # Combine - edge_index = torch.cat((edge_index,central_edges,pos_edges,neg_edges,same_edges),dim=1) - selections = torch.cat((selections,central_selections,pos_selections,neg_selections,same_selections),dim=0) - interps = torch.cat((interps,central_interps,pos_interps,neg_interps,same_interps),dim=0) - - #edge_index = torch.cat((edge_index,central_edges,pos_edges,neg_edges),dim=1) - #selections = torch.cat((selections,central_selections,pos_selections,neg_selections),dim=0) - #interps = torch.cat((interps,central_interps,pos_interps,neg_interps),dim=0) - - # Account for edges to the same node - #same_locs = torch.where(edge_index[0] == edge_index[1]) - #selections[same_locs] = 0 - #interps[same_locs] = 1 - - return edge_index,selections,interps - -def normalizeEdges(edge_index,selections,interps=None,kernel_norm=False): - '''Given an edge_index and selections, normalize the edges for each node so that - aggregation of edges with interps = 1. If interps is given, use a weighted average. 
- if kernel_norm = True, account for missing selections by increasing weight on other selections.''' - - N = torch.max(edge_index) + 1 - S = torch.max(selections) + 1 - - total_weight = torch.zeros((N,S),dtype=torch.float).to(edge_index.device) - - if interps is None: - interps = torch.ones(len(selections),dtype=torch.float).to(edge_index.device) - - # Aggregate all edges to determine normalizations per selection - nodes = edge_index[0] - #total_weight[nodes,selections] += interps - total_weight.index_put_((nodes,selections),interps,accumulate=True) - - # Reassign interps accordingly - if kernel_norm: - row_totals = torch.sum(total_weight,dim=1) - interps = interps * S/row_totals[nodes] - else: - norms = total_weight[nodes,selections] - norms[torch.where(norms < 1e-6)] = 1e-6 # Avoid divide by zero error - interps = interps/norms - - return interps - -def simplifyGraph(edge_index,selections,edge_lengths): - # Take the shortest edge for the set of the same selections on a given node - num_edges = edge_index.shape[1] - - # Keep track of which nodes have been visited - keep_edges = torch.zeros(num_edges,dtype=torch.bool).to(edge_index.device) - - previous_best_distance = 100000*torch.ones((torch.amax(edge_index)+1,torch.amax(selections)+1),dtype=torch.long).to(edge_index.device) - previous_best_edge = -1*torch.ones((torch.amax(edge_index)+1,torch.amax(selections)+1),dtype=torch.long).to(edge_index.device) - - for i in range(num_edges): - start_node = edge_index[0,i] - #end_node = edge_index[1,i] - selection = selections[i] - distance = edge_lengths[i] - - if distance < previous_best_distance[start_node,selection]: - previous_best_distance[start_node,selection] = distance - keep_edges[i] = True - - prev = previous_best_edge[start_node,selection] - if prev != -1: - keep_edges[prev] = False - - previous_best_edge[start_node,selection] = i - - edge_index = edge_index[:,torch.where(keep_edges)[0]] - selections = selections[torch.where(keep_edges)] - - return edge_index, 
selections - diff --git a/FastSplatStyler_huggingface/graph_io.py b/FastSplatStyler_huggingface/graph_io.py deleted file mode 100644 index bce2bb09aff7752716898e64e86fa519af7820b4..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_io.py +++ /dev/null @@ -1,306 +0,0 @@ -import numpy as np -import torch -from torch_geometric.nn import knn -from torch_geometric.data import Data -from torch_geometric.nn import radius_graph, knn_graph - -import graph_helpers as gh -import sphere_helpers as sh -import mesh_helpers as mh -import clusters as cl -import utils - -from torch_scatter import scatter - -import math -from math import pi, sqrt - -from warnings import warn - -def image2Graph(data, gt = None, mask = None, depth = 1, x_only = False, device = 'cpu'): - - _,ch,rows,cols = data.shape - - x = torch.reshape(data,(ch,rows*cols)).permute((1,0)).to(device) - - if mask is not None: - # Mask out nodes - node_mask = torch.where(mask.flatten()) - x = x[node_mask] - - if gt is not None: - y = gt.flatten().to(device) - if mask is not None: - y = y[node_mask] - - if x_only: - if gt is not None: - return x,y - else: - return x - - im_pos = gh.getImPos(rows,cols) - - if mask is not None: - im_pos = im_pos[node_mask] - - # Make "point cloud" for clustering - pos2D = gh.convertImPos(im_pos,flip_y=False) - - # Generate initial graph - edge_index = gh.grid2Edges(pos2D) - directions = pos2D[edge_index[1]] - pos2D[edge_index[0]] - selections = gh.edges2Selections(edge_index,directions,interpolated=False,y_down=True) - - # Generate info for downsampled versions of the graph - clusters, edge_indexes, selections_list = cl.makeImageClusters(pos2D,cols,rows,edge_index,selections,depth=depth,device=device) - - # Make final graph and metadata needed for mapping the result after going through the network - graph = Data(x=x,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=None) - metadata = 
Data(original=data,im_pos=im_pos.long(),rows=rows,cols=cols,ch=ch) - - if gt is not None: - graph.y = y - - return graph,metadata - -def graph2Image(result,metadata,canvas=None): - - x = utils.toNumpy(result,permute=False) - im_pos = utils.toNumpy(metadata.im_pos,permute=False) - if canvas is None: - canvas = utils.makeCanvas(x,metadata.original) - - # Paint over the original image (neccesary for masked images) - canvas[im_pos[:,0],im_pos[:,1]] = x - - return canvas - -### Begin Interpolated Methods ### - -def sphere2Graph(data, structure="layering", cluster_method="layering", scale=1.0, stride=2, interpolation_mode = "angle", gt = None, mask = None, depth = 1, x_only = False, device = 'cpu'): - - _,ch,rows,cols = data.shape - - if structure == "equirec": - # Use the original data to start with - cartesian, spherical = sh.sampleSphere_Equirec(scale*rows,scale*cols) - elif structure == "layering": - cartesian, spherical = sh.sampleSphere_Layering(scale*rows) - elif structure == "spiral": - cartesian, spherical = sh.sampleSphere_Spiral(scale*rows,scale*cols) - elif structure == "icosphere": - cartesian, spherical = sh.sampleSphere_Icosphere(scale*rows) - elif structure == "random": - cartesian, spherical = sh.sampleSphere_Random(scale*rows,scale*cols) - else: - raise ValueError("Sphere structure unknown") - - if interpolation_mode == "bary": - bary_d = pi/(scale*rows) - else: - bary_d = None - - # Get the landing point for each node - sample_x, sample_y = sh.spherical2equirec(spherical[:,0],spherical[:,1],rows,cols) - - if mask is not None: - - node_mask = gh.maskPoints(mask,sample_x,sample_y) - sample_x = sample_x[node_mask] - sample_y = sample_y[node_mask] - spherical = spherical[node_mask] - cartesian = cartesian[node_mask] - - features = utils.bilinear_interpolate(data, sample_x, sample_y).to(device) - - if gt is not None: - features_y = utils.bilinear_interpolate(gt.unsqueeze(0), sample_x, sample_y).to(device) - - if x_only: - if gt is not None: - return 
features,features_y - else: - return features - - # Build initial graph - edge_index,directions = gh.surface2Edges(cartesian,cartesian) - edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True,bary_d=bary_d) - - # Generate info for downsampled versions of the graph - clusters, edge_indexes, selections_list, interps_list = cl.makeSphereClusters(cartesian,edge_index,selections,interps,rows*scale,cols*scale,cluster_method,stride=stride,bary_d=bary_d,depth=depth,device=device) - - # Make final graph and metadata needed for mapping the result after going through the network - graph = Data(x=features,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list) - metadata = Data(original=data,pos3D=cartesian,mask=mask,rows=rows,cols=cols,ch=ch) - - if gt is not None: - graph.y = features_y - - return graph, metadata - -def graph2Sphere(features,metadata): - - # Generate equirectangular points and their 3D locations - theta, phi = sh.equirec2spherical(metadata.rows, metadata.cols) - x,y,z = sh.spherical2xyz(theta,phi) - - v = torch.stack((x,y,z),dim=1) - - # Find closest 3D point to each equirectangular point - nearest = torch.reshape(knn(metadata.pos3D,v,3)[1],(len(v),3)) - - #Interpolate based on proximty to each node - w0 = 1/torch.linalg.norm((v - metadata.pos3D[nearest[:,0]]),dim=1, keepdim=True).to(features.device) - w1 = 1/torch.linalg.norm((v - metadata.pos3D[nearest[:,1]]),dim=1, keepdim=True).to(features.device) - w2 = 1/torch.linalg.norm((v - metadata.pos3D[nearest[:,2]]),dim=1, keepdim=True).to(features.device) - - w0 = torch.nan_to_num(w0, nan=1e6) - w1 = torch.nan_to_num(w1, nan=1e6) - w2 = torch.nan_to_num(w2, nan=1e6) - - w0 = torch.clamp(w0,0,1e6) - w1 = torch.clamp(w1,0,1e6) - w2 = torch.clamp(w2,0,1e6) - - total = w0 + w1 + w2 - - #w0,w1,w2 = mh.getBarycentricWeights(v,metadata.pos3D[nearest[:,0]],metadata.pos3D[nearest[:,1]],metadata.pos3D[nearest[:,2]]) - - #w0 = 
w0.unsqueeze(1).to(features.device) - #w1 = w1.unsqueeze(1).to(features.device) - #w2 = w2.unsqueeze(1).to(features.device) - - result = (w0*features[nearest[:,0]] + w1*features[nearest[:,1]] + w2*features[nearest[:,2]])/total - - #result = result.clamp(0,1) - - if hasattr(metadata,"mask"): - mask = utils.toNumpy(metadata.mask.squeeze(),permute=False) - canvas = utils.makeCanvas(result,metadata.original) - result = np.reshape(result.data.cpu().numpy(),(metadata.rows,metadata.cols,features.shape[1])) - canvas[np.where(mask)] = result[np.where(mask)] - return canvas - else: - return np.reshape(result.data.cpu().numpy(),(metadata.rows,metadata.cols,features.shape[1])) - - - -def splat2Graph(data, mesh, up_vector = None, N = 100000, ratio=.25, depth = 1, device = 'cpu'): - """ Sample mesh faces to determine graph """ - - if up_vector == None: - up_vector = torch.tensor([[1,1,1]],dtype=torch.float) - #up_vector = 2*torch.rand((1,3))-1 - up_vector = up_vector/torch.linalg.norm(up_vector,dim=1) - - - #position, normal vector, uv coordinates in the texture map, x is color - pos3D, normals = mh.sampleSurface(mesh,N) - - # Build initial graph - #edge_index are neighbors of a point, directions are the directions from that point - edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16) - #directions need to be turned into selections "W sub n" from the star-like coordinate system from Dr. 
Hart's github interpolated-selectionconv - edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True) - - # Generate info for downsampled versions of the graph - clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device) - #clusters, edge_indexes, selections_list, interps_list = cl.makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device) - - # Make final graph and metadata needed for mapping the result after going through the network - graph = Data(clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list) - metadata = Data(original=data,pos3D=pos3D,mesh=mesh) - - return graph,metadata - -def mesh2Graph(data, mesh, up_vector = None, N = 100000, ratio=.25, mask = None, depth = 1, x_only = False, device = 'cpu'): - """ Sample mesh faces to determine graph """ - - if up_vector == None: - up_vector = torch.tensor([[1,1,1]],dtype=torch.float) - #up_vector = 2*torch.rand((1,3))-1 - up_vector = up_vector/torch.linalg.norm(up_vector,dim=1) - - if mask is not None: - warn("Masks are not currently implemented for mesh graphs") - - #position, normal vector, uv coordinates in the texture map, x is color - pos3D, normals, uvs, x = mh.sampleSurface(mesh,N,return_x=True) - - x = x.to(device) - - if x_only: - warn("x_only returns randomly selected points for mesh2Graph. Do not use with previous graph structures") - return x - - # Build initial graph - #edge_index are neighbors of a point, directions are the directions from that point - edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16) - #directions need to be turned into selections "W sub n" from the star-like coordinate system from Dr. 
Hart's github interpolated-selectionconv - edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True) - - # Generate info for downsampled versions of the graph - clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device) - #clusters, edge_indexes, selections_list, interps_list = cl.makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device) - - # Make final graph and metadata needed for mapping the result after going through the network - graph = Data(x=x,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list) - metadata = Data(original=data,pos3D=pos3D,uvs=uvs,mesh=mesh) - - return graph,metadata - -def graph2Splat(features,metadata,view3D=False): - - features = features.cpu().numpy() - - canvas = utils.toNumpy(metadata.original) - rows,cols,ch = canvas.shape - - # Get 2D positions by scaling uv - pos2D = metadata.uvs.cpu().numpy() - pos2D[:,0] = pos2D[:,0]*cols - pos2D[:,1] = 1-pos2D[:,1] # UV puts y=0 at the bottom - pos2D[:,1] = pos2D[:,1]*rows - - # Generate desired points - row_space = np.arange(rows) - col_space = np.arange(cols) - col_image,row_image = np.meshgrid(col_space,row_space) - - canvas = utils.interpolatePointCloud2D(pos2D,features,col_image,row_image) - canvas = np.clip(canvas,0,1) - - if view3D: - mesh = mh.setTexture(metadata.mesh,canvas) - mesh.show() - - return canvas - - -def graph2Mesh(features,metadata,view3D=False): - - features = features.cpu().numpy() - - canvas = utils.toNumpy(metadata.original) - rows,cols,ch = canvas.shape - - # Get 2D positions by scaling uv - pos2D = metadata.uvs.cpu().numpy() - pos2D[:,0] = pos2D[:,0]*cols - pos2D[:,1] = 1-pos2D[:,1] # UV puts y=0 at the bottom - pos2D[:,1] = pos2D[:,1]*rows - - # Generate desired points - row_space = np.arange(rows) - 
col_space = np.arange(cols) - col_image,row_image = np.meshgrid(col_space,row_space) - - canvas = utils.interpolatePointCloud2D(pos2D,features,col_image,row_image) - canvas = np.clip(canvas,0,1) - - if view3D: - mesh = mh.setTexture(metadata.mesh,canvas) - mesh.show() - - return canvas diff --git a/FastSplatStyler_huggingface/graph_networks/.DS_Store b/FastSplatStyler_huggingface/graph_networks/.DS_Store deleted file mode 100644 index 5562a58c1a663d343a951dca0a853bba3e02bafc..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/graph_networks/.DS_Store and /dev/null differ diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/.DS_Store b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/.DS_Store deleted file mode 100644 index 90435c4f3587ced903d7224eece2021505e2f881..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/.DS_Store and /dev/null differ diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/LICENSE b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/LICENSE deleted file mode 100644 index b3fb3a86103b77107cc211e4aa9224fd79076f97..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/LICENSE +++ /dev/null @@ -1,25 +0,0 @@ -BSD 2-Clause License - -Copyright (c) 2018, SunshineAtNoon -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. 
- -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/README.md b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/README.md deleted file mode 100644 index 8fc41f7b2d4f5ce38771709cc7730e93125a7098..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/README.md +++ /dev/null @@ -1,102 +0,0 @@ -## Learning Linear Transformations for Fast Image and Video Style Transfer -**[[Paper]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Li_Learning_Linear_Transformations_for_Fast_Image_and_Video_Style_Transfer_CVPR_2019_paper.pdf)** **[[Project Page]](https://sites.google.com/view/linear-style-transfer-cvpr19/)** - - - - -## Prerequisites -- [Pytorch](http://pytorch.org/) -- [torchvision](https://github.com/pytorch/vision) -- [opencv](https://opencv.org/) for video generation - -**All code tested on Ubuntu 16.04, pytorch 0.4.1, and opencv 3.4.2** - -## Style Transfer -- Clone from github: `git clone https://github.com/sunshineatnoon/LinearStyleTransfer` -- Download pre-trained models from [google drive](https://drive.google.com/file/d/1H9T5rfXGlGCUh04DGkpkMFbVnmscJAbs/view?usp=sharing). 
-- Uncompress to root folder : -``` -cd LinearStyleTransfer -unzip models.zip -rm models.zip -``` - -#### Artistic style transfer -``` -python TestArtistic.py -``` -or conduct style transfer on relu_31 features -``` -python TestArtistic.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --matrixPath models/r31.pth --layer r31 -``` - -#### Photo-realistic style transfer -For photo-realistic style transfer, we need first compile the [pytorch_spn](https://github.com/Liusifei/pytorch_spn) repository. -``` -cd libs/pytorch_spn -sh make.sh -cd ../.. -``` -Then: -``` -python TestPhotoReal.py -``` -Note: images with `_filtered.png` as postfix are images filtered by the SPN after style transfer, images with `_smooth.png` as postfix are images post process by a [smooth filter](https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py). - -#### Video style transfer -``` -python TestVideo.py -``` - -#### Real-time video demo -``` -python real-time-demo.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --matrixPath models/r31.pth --layer r31 -``` - -## Model Training -### Data Preparation -- MSCOCO -``` -wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip -``` -- WikiArt - - Either manually download from [kaggle](https://www.kaggle.com/c/painter-by-numbers). 
- - Or install [kaggle-cli](https://github.com/floydwch/kaggle-cli) and download by running: - ``` - kg download -u -p -c painter-by-numbers -f train.zip - ``` - -### Training -#### Train a style transfer model -To train a model that transfers relu4_1 features, run: -``` -python Train.py --vgg_dir models/vgg_r41.pth --decoder_dir models/dec_r41.pth --layer r41 --contentPath PATH_TO_MSCOCO --stylePath PATH_TO_WikiArt --outf OUTPUT_DIR -``` -or train a model that transfers relu3_1 features: -``` -python Train.py --vgg_dir models/vgg_r31.pth --decoder_dir models/dec_r31.pth --layer r31 --contentPath PATH_TO_MSCOCO --stylePath PATH_TO_WikiArt --outf OUTPUT_DIR -``` -Key hyper-parameters: -- style_layers: which features to compute style loss. -- style_weight: larger style weight leads to heavier style in transferred images. - -Intermediate results and weight will be stored in `OUTPUT_DIR` - -#### Train a SPN model to cancel distortions for photo-realistic style transfer -Run: -``` -python TrainSPN.py --contentPath PATH_TO_MSCOCO -``` - -### Acknowledgement -- We use the [smooth filter](https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py) by [LouieYang](https://github.com/LouieYang) in the photo-realistic style transfer. 
- -### Citation -``` -@inproceedings{li2018learning, - author = {Li, Xueting and Liu, Sifei and Kautz, Jan and Yang, Ming-Hsuan}, - title = {Learning Linear Transformations for Fast Arbitrary Style Transfer}, - booktitle = {IEEE Conference on Computer Vision and Pattern Recognition}, - year = {2019} -} -``` diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/TestArtistic.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/TestArtistic.py deleted file mode 100644 index 8435542c201b9448145a8642de0cd4111d715ee5..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/TestArtistic.py +++ /dev/null @@ -1,98 +0,0 @@ -import os -import torch -import argparse -from libs.Loader import Dataset -from libs.Matrix import MulLayer -import torchvision.utils as vutils -import torch.backends.cudnn as cudnn -from libs.utils import print_options -from libs.models import encoder3,encoder4, encoder5 -from libs.models import decoder3,decoder4, decoder5 - -parser = argparse.ArgumentParser() -parser.add_argument("--vgg_dir", default='models/vgg_r41.pth', - help='pre-trained encoder path') -parser.add_argument("--decoder_dir", default='models/dec_r41.pth', - help='pre-trained decoder path') -parser.add_argument("--matrixPath", default='models/r41.pth', - help='pre-trained model path') -parser.add_argument("--stylePath", default="./data/style/", - help='path to style image') -parser.add_argument("--contentPath", default="./data/content/", - help='path to frames') -parser.add_argument("--outf", default="Artistic/", - help='path to transferred images') -parser.add_argument("--batchSize", type=int,default=1, - help='batch size') -parser.add_argument('--loadSize', type=int, default=256, - help='scale image size') -parser.add_argument('--fineSize', type=int, default=256, - help='crop image size') -parser.add_argument("--layer", default="r41", - help='which features to transfer, either r31 or r41') - 
-################# PREPARATIONS ################# -opt = parser.parse_args() -opt.cuda = torch.cuda.is_available() -print_options(opt) - -os.makedirs(opt.outf,exist_ok=True) -cudnn.benchmark = True - -################# DATA ################# -content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize,test=True) -content_loader = torch.utils.data.DataLoader(dataset=content_dataset, - batch_size = opt.batchSize, - shuffle = False) - #num_workers = 1) -style_dataset = Dataset(opt.stylePath,opt.loadSize,opt.fineSize,test=True) -style_loader = torch.utils.data.DataLoader(dataset=style_dataset, - batch_size = opt.batchSize, - shuffle = False) - #num_workers = 1) - -################# MODEL ################# -if(opt.layer == 'r31'): - vgg = encoder3() - dec = decoder3() -elif(opt.layer == 'r41'): - vgg = encoder4() - dec = decoder4() -matrix = MulLayer(opt.layer) -vgg.load_state_dict(torch.load(opt.vgg_dir)) -dec.load_state_dict(torch.load(opt.decoder_dir)) -matrix.load_state_dict(torch.load(opt.matrixPath)) - -################# GLOBAL VARIABLE ################# -contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize) -styleV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize) - -################# GPU ################# -if(opt.cuda): - vgg.cuda() - dec.cuda() - matrix.cuda() - contentV = contentV.cuda() - styleV = styleV.cuda() - -for ci,(content,contentName) in enumerate(content_loader): - contentName = contentName[0] - contentV.resize_(content.size()).copy_(content) - for sj,(style,styleName) in enumerate(style_loader): - styleName = styleName[0] - styleV.resize_(style.size()).copy_(style) - - # forward - with torch.no_grad(): - sF = vgg(styleV) - cF = vgg(contentV) - - if(opt.layer == 'r41'): - feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer]) - else: - feature,transmatrix = matrix(cF,sF) - transfer = dec(feature) - - transfer = transfer.clamp(0,1) - 
vutils.save_image(transfer,'%s/%s_%s.png'%(opt.outf,contentName,styleName),normalize=True,scale_each=True,nrow=opt.batchSize) - print('Transferred image saved at %s%s_%s.png'%(opt.outf,contentName,styleName)) diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/TestPhotoReal.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/TestPhotoReal.py deleted file mode 100644 index e00c7bb3468dfd8e678076a7b384300c521bf8a7..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/TestPhotoReal.py +++ /dev/null @@ -1,118 +0,0 @@ -import os -import cv2 -import time -import torch -import argparse -import numpy as np -from PIL import Image -from libs.SPN import SPN -import torchvision.utils as vutils -from libs.utils import print_options -from libs.MatrixTest import MulLayer -import torch.backends.cudnn as cudnn -from libs.LoaderPhotoReal import Dataset -from libs.models import encoder3,encoder4 -from libs.models import decoder3,decoder4 -import torchvision.transforms as transforms -from libs.smooth_filter import smooth_filter - -parser = argparse.ArgumentParser() -parser.add_argument("--vgg_dir", default='models/vgg_r41.pth', - help='pre-trained encoder path') -parser.add_argument("--decoder_dir", default='models/dec_r41.pth', - help='pre-trained decoder path') -parser.add_argument("--matrixPath", default='models/r41.pth', - help='pre-trained model path') -parser.add_argument("--stylePath", default="data/photo_real/style/images/", - help='path to style image') -parser.add_argument("--styleSegPath", default="data/photo_real/styleSeg/", - help='path to style image masks') -parser.add_argument("--contentPath", default="data/photo_real/content/images/", - help='path to content image') -parser.add_argument("--contentSegPath", default="data/photo_real/contentSeg/", - help='path to content image masks') -parser.add_argument("--outf", default="PhotoReal/", - help='path to save output images') 
-parser.add_argument("--batchSize", type=int,default=1, - help='batch size') -parser.add_argument('--fineSize', type=int, default=512, - help='image size') -parser.add_argument("--layer", default="r41", - help='features of which layer to transform, either r31 or r41') -parser.add_argument("--spn_dir", default='models/r41_spn.pth', - help='path to pretrained SPN model') - -################# PREPARATIONS ################# -opt = parser.parse_args() -opt.cuda = torch.cuda.is_available() -print_options(opt) - -os.makedirs(opt.outf, exist_ok=True) - -cudnn.benchmark = True - -################# DATA ################# -dataset = Dataset(opt.contentPath,opt.stylePath,opt.contentSegPath,opt.styleSegPath,opt.fineSize) -loader = torch.utils.data.DataLoader(dataset=dataset, - batch_size=1, - shuffle=False) - -################# MODEL ################# -if(opt.layer == 'r31'): - vgg = encoder3() - dec = decoder3() -elif(opt.layer == 'r41'): - vgg = encoder4() - dec = decoder4() -matrix = MulLayer(opt.layer) -vgg.load_state_dict(torch.load(opt.vgg_dir)) -dec.load_state_dict(torch.load(opt.decoder_dir)) -matrix.load_state_dict(torch.load(opt.matrixPath)) -spn = SPN() -spn.load_state_dict(torch.load(opt.spn_dir)) - -################# GLOBAL VARIABLE ################# -contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize) -styleV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize) -whitenV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize) - -################# GPU ################# -if(opt.cuda): - vgg.cuda() - dec.cuda() - spn.cuda() - matrix.cuda() - contentV = contentV.cuda() - styleV = styleV.cuda() - whitenV = whitenV.cuda() - -for i,(contentImg,styleImg,whitenImg,cmasks,smasks,imname) in enumerate(loader): - imname = imname[0] - contentV.resize_(contentImg.size()).copy_(contentImg) - styleV.resize_(styleImg.size()).copy_(styleImg) - whitenV.resize_(whitenImg.size()).copy_(whitenImg) - - # forward - sF = vgg(styleV) - cF = vgg(contentV) - - with 
torch.no_grad(): - if(opt.layer == 'r41'): - feature = matrix(cF[opt.layer],sF[opt.layer],cmasks,smasks) - else: - feature = matrix(cF,sF,cmasks,smasks) - transfer = dec(feature) - filtered = spn(transfer,whitenV) - vutils.save_image(transfer,os.path.join(opt.outf,'%s_transfer.png'%(imname.split('.')[0]))) - - filtered = filtered.clamp(0,1) - filtered = filtered.cpu() - vutils.save_image(filtered,'%s/%s_filtered.png'%(opt.outf,imname.split('.')[0])) - out_img = filtered.squeeze(0).mul(255).clamp(0,255).byte().permute(1,2,0).cpu().numpy() - content = contentImg.squeeze(0).mul(255).clamp(0,255).byte().permute(1,2,0).cpu().numpy() - content = content.copy() - out_img = out_img.copy() - smoothed = smooth_filter(out_img, content, f_radius=15, f_edge=1e-1) - smoothed.save('%s/%s_smooth.png'%(opt.outf,imname.split('.')[0])) - print('Transferred image saved at %s%s, filtered image saved at %s%s_filtered.png' \ - %(opt.outf,imname,opt.outf,imname.split('.')[0])) diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/TestVideo.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/TestVideo.py deleted file mode 100644 index 94323aac35353cdd4723048c1d52aa1534608998..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/TestVideo.py +++ /dev/null @@ -1,108 +0,0 @@ -import os -import torch -import argparse -from PIL import Image -from libs.Loader import Dataset -from libs.Matrix import MulLayer -import torch.backends.cudnn as cudnn -from libs.models import encoder3,encoder4 -from libs.models import decoder3,decoder4 -import torchvision.transforms as transforms -from libs.utils import makeVideo, print_options - -parser = argparse.ArgumentParser() -parser.add_argument("--vgg_dir", default='models/vgg_r31.pth', - help='pre-trained encoder path') -parser.add_argument("--decoder_dir", default='models/dec_r31.pth', - help='pre-trained decoder path') -parser.add_argument("--matrix_dir", 
default="models/r31.pth", - help='path to pre-trained model') -parser.add_argument("--style", default="data/style/in2.jpg", - help='path to style image') -parser.add_argument("--content_dir", default="data/videos/content/mountain_2/", - help='path to video frames') -parser.add_argument('--loadSize', type=int, default=512, - help='scale image size') -parser.add_argument('--fineSize', type=int, default=512, - help='crop image size') -parser.add_argument("--name",default="transferred_video", - help="name of generated video") -parser.add_argument("--layer",default="r31", - help="features of which layer to transform") -parser.add_argument("--outf",default="videos", - help="output folder") - -################# PREPARATIONS ################# -opt = parser.parse_args() -opt.cuda = torch.cuda.is_available() -print_options(opt) - -os.makedirs(opt.outf,exist_ok=True) -cudnn.benchmark = True - -################# DATA ################# -def loadImg(imgPath): - img = Image.open(imgPath).convert('RGB') - transform = transforms.Compose([ - transforms.Scale(opt.fineSize), - transforms.ToTensor()]) - return transform(img) -styleV = loadImg(opt.style).unsqueeze(0) - -content_dataset = Dataset(opt.content_dir, - loadSize = opt.loadSize, - fineSize = opt.fineSize, - test = True, - video = True) -content_loader = torch.utils.data.DataLoader(dataset = content_dataset, - batch_size = 1, - shuffle = False) - -################# MODEL ################# -if(opt.layer == 'r31'): - vgg = encoder3() - dec = decoder3() -elif(opt.layer == 'r41'): - vgg = encoder4() - dec = decoder4() -matrix = MulLayer(layer=opt.layer) -vgg.load_state_dict(torch.load(opt.vgg_dir)) -dec.load_state_dict(torch.load(opt.decoder_dir)) -matrix.load_state_dict(torch.load(opt.matrix_dir)) - -################# GLOBAL VARIABLE ################# -contentV = torch.Tensor(1,3,opt.fineSize,opt.fineSize) - -################# GPU ################# -if(opt.cuda): - vgg.cuda() - dec.cuda() - matrix.cuda() - - styleV = styleV.cuda() 
- contentV = contentV.cuda() - -result_frames = [] -contents = [] -style = styleV.squeeze(0).cpu().numpy() -sF = vgg(styleV) - -for i,(content,contentName) in enumerate(content_loader): - print('Transfer frame %d...'%i) - contentName = contentName[0] - contentV.resize_(content.size()).copy_(content) - contents.append(content.squeeze(0).float().numpy()) - # forward - with torch.no_grad(): - cF = vgg(contentV) - - if(opt.layer == 'r41'): - feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer]) - else: - feature,transmatrix = matrix(cF,sF) - transfer = dec(feature) - - transfer = transfer.clamp(0,1) - result_frames.append(transfer.squeeze(0).cpu().numpy()) - -makeVideo(contents,style,result_frames,opt.outf) diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/Train.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/Train.py deleted file mode 100644 index 9c4a1bda9c99c63557f9d96739d5475cf56db3cb..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/Train.py +++ /dev/null @@ -1,185 +0,0 @@ -import os -import torch -import argparse -import torch.nn as nn -import torch.optim as optim -from libs.Loader import Dataset -from libs.Matrix import MulLayer -import torchvision.utils as vutils -import torch.backends.cudnn as cudnn -from libs.utils import print_options -from libs.Criterion import LossCriterion -from libs.models import encoder3,encoder4 -from libs.models import decoder3,decoder4 -from libs.models import encoder5 as loss_network - -parser = argparse.ArgumentParser() -parser.add_argument("--vgg_dir", default='models/vgg_r41.pth', - help='pre-trained encoder path') -parser.add_argument("--loss_network_dir", default='models/vgg_r51.pth', - help='used for loss network') -parser.add_argument("--decoder_dir", default='models/dec_r41.pth', - help='pre-trained decoder path') -parser.add_argument("--stylePath", default="/home/xtli/DATA/wikiArt/train/images/", - help='path to 
wikiArt dataset') -parser.add_argument("--contentPath", default="/home/xtli/DATA/MSCOCO/train2014/images/", - help='path to MSCOCO dataset') -parser.add_argument("--outf", default="trainingOutput/", - help='folder to output images and model checkpoints') -parser.add_argument("--content_layers", default="r41", - help='layers for content') -parser.add_argument("--style_layers", default="r11,r21,r31,r41", - help='layers for style') -parser.add_argument("--batchSize", type=int,default=8, - help='batch size') -parser.add_argument("--niter", type=int,default=100000, - help='iterations to train the model') -parser.add_argument('--loadSize', type=int, default=300, - help='scale image size') -parser.add_argument('--fineSize', type=int, default=256, - help='crop image size') -parser.add_argument("--lr", type=float, default=1e-4, - help='learning rate') -parser.add_argument("--content_weight", type=float, default=1.0, - help='content loss weight') -parser.add_argument("--style_weight", type=float, default=0.02, - help='style loss weight') -parser.add_argument("--log_interval", type=int, default=500, - help='log interval') -parser.add_argument("--gpu_id", type=int, default=0, - help='which gpu to use') -parser.add_argument("--save_interval", type=int, default=5000, - help='checkpoint save interval') -parser.add_argument("--layer", default="r41", - help='which features to transfer, either r31 or r41') - -################# PREPARATIONS ################# -opt = parser.parse_args() -opt.content_layers = opt.content_layers.split(',') -opt.style_layers = opt.style_layers.split(',') -opt.cuda = torch.cuda.is_available() -if(opt.cuda): - torch.cuda.set_device(opt.gpu_id) - -os.makedirs(opt.outf,exist_ok=True) -cudnn.benchmark = True -print_options(opt) - -################# DATA ################# -content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize) -content_loader_ = torch.utils.data.DataLoader(dataset = content_dataset, - batch_size = opt.batchSize, - shuffle = True, 
- num_workers = 1, - drop_last = True) -content_loader = iter(content_loader_) -style_dataset = Dataset(opt.stylePath,opt.loadSize,opt.fineSize) -style_loader_ = torch.utils.data.DataLoader(dataset = style_dataset, - batch_size = opt.batchSize, - shuffle = True, - num_workers = 1, - drop_last = True) -style_loader = iter(style_loader_) - -################# MODEL ################# -vgg5 = loss_network() -if(opt.layer == 'r31'): - matrix = MulLayer('r31') - vgg = encoder3() - dec = decoder3() -elif(opt.layer == 'r41'): - matrix = MulLayer('r41') - vgg = encoder4() - dec = decoder4() -vgg.load_state_dict(torch.load(opt.vgg_dir)) -dec.load_state_dict(torch.load(opt.decoder_dir)) -vgg5.load_state_dict(torch.load(opt.loss_network_dir)) - -for param in vgg.parameters(): - param.requires_grad = False -for param in vgg5.parameters(): - param.requires_grad = False -for param in dec.parameters(): - param.requires_grad = False - -################# LOSS & OPTIMIZER ################# -criterion = LossCriterion(opt.style_layers, - opt.content_layers, - opt.style_weight, - opt.content_weight) -optimizer = optim.Adam(matrix.parameters(), opt.lr) - -################# GLOBAL VARIABLE ################# -contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize) -styleV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize) - -################# GPU ################# -if(opt.cuda): - vgg.cuda() - dec.cuda() - vgg5.cuda() - matrix.cuda() - contentV = contentV.cuda() - styleV = styleV.cuda() - -################# TRAINING ################# -def adjust_learning_rate(optimizer, iteration): - """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" - for param_group in optimizer.param_groups: - param_group['lr'] = opt.lr / (1+iteration*1e-5) - -for iteration in range(1,opt.niter+1): - optimizer.zero_grad() - try: - content,_ = content_loader.next() - except IOError: - content,_ = content_loader.next() - except StopIteration: - content_loader = 
iter(content_loader_) - content,_ = content_loader.next() - except: - continue - - try: - style,_ = style_loader.next() - except IOError: - style,_ = style_loader.next() - except StopIteration: - style_loader = iter(style_loader_) - style,_ = style_loader.next() - except: - continue - - contentV.resize_(content.size()).copy_(content) - styleV.resize_(style.size()).copy_(style) - - # forward - sF = vgg(styleV) - cF = vgg(contentV) - - if(opt.layer == 'r41'): - feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer]) - else: - feature,transmatrix = matrix(cF,sF) - transfer = dec(feature) - - sF_loss = vgg5(styleV) - cF_loss = vgg5(contentV) - tF = vgg5(transfer) - loss,styleLoss,contentLoss = criterion(tF,sF_loss,cF_loss) - - # backward & optimization - loss.backward() - optimizer.step() - print('Iteration: [%d/%d] Loss: %.4f contentLoss: %.4f styleLoss: %.4f Learng Rate is %.6f'% - (opt.niter,iteration,loss,contentLoss,styleLoss,optimizer.param_groups[0]['lr'])) - - adjust_learning_rate(optimizer,iteration) - - if((iteration) % opt.log_interval == 0): - transfer = transfer.clamp(0,1) - concat = torch.cat((content,style,transfer.cpu()),dim=0) - vutils.save_image(concat,'%s/%d.png'%(opt.outf,iteration),normalize=True,scale_each=True,nrow=opt.batchSize) - - if(iteration > 0 and (iteration) % opt.save_interval == 0): - torch.save(matrix.state_dict(), '%s/%s.pth' % (opt.outf,opt.layer)) diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/TrainSPN.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/TrainSPN.py deleted file mode 100644 index b188fb40be531bd96d96aacb134edbd0bdebb7c5..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/TrainSPN.py +++ /dev/null @@ -1,141 +0,0 @@ -from __future__ import print_function -import os -import argparse - -from libs.SPN import SPN -from libs.Loader import Dataset -from libs.models import encoder4 -from libs.models import decoder4 -from 
libs.utils import print_options - -import torch -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F -import torchvision.utils as vutils -import torch.backends.cudnn as cudnn -import torchvision.transforms as transforms - -parser = argparse.ArgumentParser() -parser.add_argument("--vgg_dir", default='models/vgg_r41.pth', - help='pre-trained encoder path') -parser.add_argument("--decoder_dir", default='models/dec_r41.pth', - help='pre-trained decoder path') -parser.add_argument("--contentPath", default="/home/xtli/DATA/MSCOCO/train2014/images/", - help='path to MSCOCO dataset') -parser.add_argument("--outf", default="trainingSPNOutput/", - help='folder to output images and model checkpoints') -parser.add_argument("--layer", default="r41", - help='layers for content') -parser.add_argument("--batchSize", type=int,default=8, - help='batch size') -parser.add_argument("--niter", type=int,default=100000, - help='iterations to train the model') -parser.add_argument('--loadSize', type=int, default=512, - help='scale image size') -parser.add_argument('--fineSize', type=int, default=256, - help='crop image size') -parser.add_argument("--lr", type=float, default=1e-3, - help='learning rate') -parser.add_argument("--log_interval", type=int, default=500, - help='log interval') -parser.add_argument("--save_interval", type=int, default=5000, - help='checkpoint save interval') -parser.add_argument("--spn_num", type=int, default=1, - help='number of spn filters') - -################# PREPARATIONS ################# -opt = parser.parse_args() -opt.cuda = torch.cuda.is_available() -print_options(opt) - - -os.makedirs(opt.outf, exist_ok = True) - -cudnn.benchmark = True - -################# DATA ################# -content_dataset = Dataset(opt.contentPath,opt.loadSize,opt.fineSize) -content_loader_ = torch.utils.data.DataLoader(dataset=content_dataset, - batch_size = opt.batchSize, - shuffle = True, - num_workers = 4, - drop_last = True) -content_loader = 
iter(content_loader_) - -################# MODEL ################# -spn = SPN(spn=opt.spn_num) -if(opt.layer == 'r31'): - vgg = encoder3() - dec = decoder3() -elif(opt.layer == 'r41'): - vgg = encoder4() - dec = decoder4() -vgg.load_state_dict(torch.load(opt.vgg_dir)) -dec.load_state_dict(torch.load(opt.decoder_dir)) - -for param in vgg.parameters(): - param.requires_grad = False -for param in dec.parameters(): - param.requires_grad = False - -################# LOSS & OPTIMIZER ################# -criterion = nn.MSELoss(size_average=False) -#optimizer_spn = optim.SGD(spn.parameters(), opt.lr) -optimizer_spn = optim.Adam(spn.parameters(), opt.lr) - -################# GLOBAL VARIABLE ################# -contentV = torch.Tensor(opt.batchSize,3,opt.fineSize,opt.fineSize) - -################# GPU ################# -if(opt.cuda): - vgg.cuda() - dec.cuda() - spn.cuda() - contentV = contentV.cuda() - -################# TRAINING ################# -def adjust_learning_rate(optimizer, iteration): - for param_group in optimizer.param_groups: - param_group['lr'] = opt.lr / (1+iteration*1e-5) - -spn.train() -for iteration in range(1,opt.niter+1): - optimizer_spn.zero_grad() - try: - content,_ = content_loader.next() - except IOError: - content,_ = content_loader.next() - except StopIteration: - content_loader = iter(content_loader_) - content,_ = content_loader.next() - except: - continue - - contentV.resize_(content.size()).copy_(content) - - # forward - cF = vgg(contentV) - transfer = dec(cF['r41']) - - - propagated = spn(transfer,contentV) - loss = criterion(propagated,contentV) - - # backward & optimization - loss.backward() - #nn.utils.clip_grad_norm(spn.parameters(), 1000) - optimizer_spn.step() - print('Iteration: [%d/%d] Loss: %.4f Learng Rate is %.6f' - %(opt.niter,iteration,loss,optimizer_spn.param_groups[0]['lr'])) - - adjust_learning_rate(optimizer_spn,iteration) - - if((iteration) % opt.log_interval == 0): - transfer = transfer.clamp(0,1) - propagated = 
propagated.clamp(0,1) - vutils.save_image(transfer,'%s/%d_transfer.png'%(opt.outf,iteration)) - vutils.save_image(propagated,'%s/%d_propagated.png'%(opt.outf,iteration)) - - if(iteration > 0 and (iteration) % opt.save_interval == 0): - torch.save(spn.state_dict(), '%s/%s_spn.pth' % (opt.outf,opt.layer)) diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/__init__.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/.DS_Store b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/.DS_Store deleted file mode 100644 index 570ec19abcb7d1dfa0fd6f0f8eaac9d721f485e6..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/.DS_Store and /dev/null differ diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/Criterion.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/Criterion.py deleted file mode 100644 index 3e287d77c891268ef7b2c7ef9a5b55f6303205c0..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/Criterion.py +++ /dev/null @@ -1,62 +0,0 @@ -import torch -import torch.nn as nn - -class styleLoss(nn.Module): - def forward(self,input,target): - ib,ic,ih,iw = input.size() - iF = input.view(ib,ic,-1) - iMean = torch.mean(iF,dim=2) - iCov = GramMatrix()(input) - - tb,tc,th,tw = target.size() - tF = target.view(tb,tc,-1) - tMean = torch.mean(tF,dim=2) - tCov = GramMatrix()(target) - - loss = nn.MSELoss(size_average=False)(iMean,tMean) + nn.MSELoss(size_average=False)(iCov,tCov) - return loss/tb - -class GramMatrix(nn.Module): - def forward(self,input): - b, c, h, w = input.size() - f = input.view(b,c,h*w) # bxcx(hxw) - # torch.bmm(batch1, 
batch2, out=None) # - # batch1: bxmxp, batch2: bxpxn -> bxmxn # - G = torch.bmm(f,f.transpose(1,2)) # f: bxcx(hxw), f.transpose: bx(hxw)xc -> bxcxc - return G.div_(c*h*w) - -class LossCriterion(nn.Module): - def __init__(self,style_layers,content_layers,style_weight,content_weight): - super(LossCriterion,self).__init__() - - self.style_layers = style_layers - self.content_layers = content_layers - self.style_weight = style_weight - self.content_weight = content_weight - - self.styleLosses = [styleLoss()] * len(style_layers) - self.contentLosses = [nn.MSELoss()] * len(content_layers) - - def forward(self,tF,sF,cF): - # content loss - totalContentLoss = 0 - for i,layer in enumerate(self.content_layers): - cf_i = cF[layer] - cf_i = cf_i.detach() - tf_i = tF[layer] - loss_i = self.contentLosses[i] - totalContentLoss += loss_i(tf_i,cf_i) - totalContentLoss = totalContentLoss * self.content_weight - - # style loss - totalStyleLoss = 0 - for i,layer in enumerate(self.style_layers): - sf_i = sF[layer] - sf_i = sf_i.detach() - tf_i = tF[layer] - loss_i = self.styleLosses[i] - totalStyleLoss += loss_i(tf_i,sf_i) - totalStyleLoss = totalStyleLoss * self.style_weight - loss = totalStyleLoss + totalContentLoss - - return loss,totalStyleLoss,totalContentLoss diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/Loader.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/Loader.py deleted file mode 100644 index 0616bc43835ebb112c9a87f5b206e3d03a5c5fe0..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/Loader.py +++ /dev/null @@ -1,44 +0,0 @@ -import os -from PIL import Image -import torch.utils.data as data -import torchvision.transforms as transforms - -def is_image_file(filename): - return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) - -def default_loader(path): - return Image.open(path).convert('RGB') - -class Dataset(data.Dataset): - 
def __init__(self,dataPath,loadSize,fineSize,test=False,video=False): - super(Dataset,self).__init__() - self.dataPath = dataPath - self.image_list = [x for x in os.listdir(dataPath) if is_image_file(x)] - self.image_list = sorted(self.image_list) - if(video): - self.image_list = sorted(self.image_list) - if not test: - self.transform = transforms.Compose([ - transforms.Resize(fineSize), - transforms.RandomCrop(fineSize), - transforms.RandomHorizontalFlip(), - transforms.ToTensor()]) - else: - self.transform = transforms.Compose([ - transforms.Resize(fineSize), - transforms.ToTensor()]) - - self.test = test - - def __getitem__(self,index): - dataPath = os.path.join(self.dataPath,self.image_list[index]) - - Img = default_loader(dataPath) - ImgA = self.transform(Img) - - imgName = self.image_list[index] - imgName = imgName.split('.')[0] - return ImgA,imgName - - def __len__(self): - return len(self.image_list) diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/LoaderPhotoReal.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/LoaderPhotoReal.py deleted file mode 100644 index af1b8e07cff12ef8e72c76786d9d1714adda405a..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/LoaderPhotoReal.py +++ /dev/null @@ -1,162 +0,0 @@ -from PIL import Image -import torchvision.transforms as transforms -import torchvision.utils as vutils -import torch.utils.data as data -from os import listdir -from os.path import join -import numpy as np -import torch -import os -import torch.nn as nn -from torch.autograd import Variable -import numpy as np -from libs.utils import whiten - -def is_image_file(filename): - return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) - -def default_loader(path,fineSize): - img = Image.open(path).convert('RGB') - w,h = img.size - if(w < h): - neww = fineSize - newh = h * neww / w - newh = int(newh / 8) * 8 - else: - 
newh = fineSize - neww = w * newh / h - neww = int(neww / 8) * 8 - img = img.resize((neww,newh)) - return img - -def MaskHelper(seg,color): - # green - mask = torch.Tensor() - if(color == 'green'): - mask = torch.lt(seg[0],0.1) - mask = torch.mul(mask,torch.gt(seg[1],1-0.1)) - mask = torch.mul(mask,torch.lt(seg[2],0.1)) - elif(color == 'black'): - mask = torch.lt(seg[0], 0.1) - mask = torch.mul(mask,torch.lt(seg[1], 0.1)) - mask = torch.mul(mask,torch.lt(seg[2], 0.1)) - elif(color == 'white'): - mask = torch.gt(seg[0], 1-0.1) - mask = torch.mul(mask,torch.gt(seg[1], 1-0.1)) - mask = torch.mul(mask,torch.gt(seg[2], 1-0.1)) - elif(color == 'red'): - mask = torch.gt(seg[0], 1-0.1) - mask = torch.mul(mask,torch.lt(seg[1], 0.1)) - mask = torch.mul(mask,torch.lt(seg[2], 0.1)) - elif(color == 'blue'): - mask = torch.lt(seg[0], 0.1) - mask = torch.mul(mask,torch.lt(seg[1], 0.1)) - mask = torch.mul(mask,torch.gt(seg[2], 1-0.1)) - elif(color == 'yellow'): - mask = torch.gt(seg[0], 1-0.1) - mask = torch.mul(mask,torch.gt(seg[1], 1-0.1)) - mask = torch.mul(mask,torch.lt(seg[2], 0.1)) - elif(color == 'grey'): - mask = torch.lt(seg[0], 0.1) - mask = torch.mul(mask,torch.lt(seg[1], 0.1)) - mask = torch.mul(mask,torch.lt(seg[2], 0.1)) - elif(color == 'lightblue'): - mask = torch.lt(seg[0], 0.1) - mask = torch.mul(mask,torch.gt(seg[1], 1-0.1)) - mask = torch.mul(mask,torch.gt(seg[2], 1-0.1)) - elif(color == 'purple'): - mask = torch.gt(seg[0], 1-0.1) - mask = torch.mul(mask,torch.lt(seg[1], 0.1)) - mask = torch.mul(mask,torch.gt(seg[2], 1-0.1)) - else: - print('MaskHelper(): color not recognized, color = ' + color) - return mask.float() - -def ExtractMask(Seg): - # Given segmentation for content and style, we get a list of segmentation for each color - ''' - Test Code: - content_masks,style_masks = ExtractMask(contentSegImg,styleSegImg) - for i,mask in enumerate(content_masks): - vutils.save_image(mask,'samples/content_%d.png' % (i),normalize=True) - for i,mask in 
enumerate(style_masks): - vutils.save_image(mask,'samples/style_%d.png' % (i),normalize=True) - ''' - color_codes = ['blue', 'green', 'black', 'white', 'red', 'yellow', 'grey', 'lightblue', 'purple'] - masks = [] - for color in color_codes: - mask = MaskHelper(Seg,color) - masks.append(mask) - return masks - -def calculate_size(h,w,fineSize): - if(h > w): - newh = fineSize - neww = int(w * 1.0 * newh / h) - else: - neww = fineSize - newh = int(h * 1.0 * neww / w) - newh = (newh // 8) * 8 - neww = (neww // 8) * 8 - return neww, newh - -class Dataset(data.Dataset): - def __init__(self,contentPath,stylePath,contentSegPath,styleSegPath,fineSize): - super(Dataset,self).__init__() - self.contentPath = contentPath - self.image_list = [x for x in listdir(contentPath) if is_image_file(x)] - self.stylePath = stylePath - self.contentSegPath = contentSegPath - self.styleSegPath = styleSegPath - self.fineSize = fineSize - - def __getitem__(self,index): - contentImgPath = os.path.join(self.contentPath,self.image_list[index]) - styleImgPath = os.path.join(self.stylePath,self.image_list[index]) - contentImg = default_loader(contentImgPath,self.fineSize) - styleImg = default_loader(styleImgPath,self.fineSize) - - try: - contentSegImgPath = os.path.join(self.contentSegPath,self.image_list[index]) - contentSegImg = default_loader(contentSegImgPath,self.fineSize) - except : - print('no mask provided, fake a whole black one') - contentSegImg = Image.new('RGB', (contentImg.size)) - - try: - styleSegImgPath = os.path.join(self.styleSegPath,self.image_list[index]) - styleSegImg = default_loader(styleSegImgPath,self.fineSize) - except : - print('no mask provided, fake a whole black one') - styleSegImg = Image.new('RGB', (styleImg.size)) - - - hs, ws = styleImg.size - newhs, newws = calculate_size(hs,ws,self.fineSize) - - transform = transforms.Compose([ - transforms.Resize((newhs, newws)), - transforms.ToTensor()]) - # Turning segmentation images into masks - styleSegImg = 
transform(styleSegImg) - styleImgArbi = transform(styleImg) - - hc, wc = contentImg.size - newhc, newwc = calculate_size(hc,wc,self.fineSize) - - transform = transforms.Compose([ - transforms.Resize((newhc, newwc)), - transforms.ToTensor()]) - contentSegImg = transform(contentSegImg) - contentImgArbi = transform(contentImg) - - content_masks = ExtractMask(contentSegImg) - style_masks = ExtractMask(styleSegImg) - - ImgW = whiten(contentImgArbi.view(3,-1).double()) - ImgW = ImgW.view(contentImgArbi.size()).float() - - return contentImgArbi.squeeze(0),styleImgArbi.squeeze(0),ImgW,content_masks,style_masks,self.image_list[index] - - def __len__(self): - return len(self.image_list) diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/Matrix.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/Matrix.py deleted file mode 100644 index 2deb029457e6f1536e843efa2386bdce3fe86ad5..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/Matrix.py +++ /dev/null @@ -1,89 +0,0 @@ -import torch -import torch.nn as nn - -class CNN(nn.Module): - def __init__(self,layer,matrixSize=32): - super(CNN,self).__init__() - if(layer == 'r31'): - # 256x64x64 - self.convs = nn.Sequential(nn.Conv2d(256,128,3,1,1), - nn.ReLU(inplace=True), - nn.Conv2d(128,64,3,1,1), - nn.ReLU(inplace=True), - nn.Conv2d(64,matrixSize,3,1,1)) - elif(layer == 'r41'): - # 512x32x32 - self.convs = nn.Sequential(nn.Conv2d(512,256,3,1,1), - nn.ReLU(inplace=True), - nn.Conv2d(256,128,3,1,1), - nn.ReLU(inplace=True), - nn.Conv2d(128,matrixSize,3,1,1)) - - # 32x8x8 - self.fc = nn.Linear(matrixSize*matrixSize,matrixSize*matrixSize) - #self.fc = nn.Linear(32*64,256*256) - - def forward(self,x): - out = self.convs(x) - # 32x8x8 - b,c,h,w = out.size() - out = out.view(b,c,-1) - # 32x64 - out = torch.bmm(out,out.transpose(1,2)).div(h*w) - # 32x32 - out = out.view(out.size(0),-1) - return self.fc(out) - -class 
MulLayer(nn.Module): - def __init__(self,layer,matrixSize=32): - super(MulLayer,self).__init__() - self.snet = CNN(layer,matrixSize) - self.cnet = CNN(layer,matrixSize) - self.matrixSize = matrixSize - - if(layer == 'r41'): - self.compress = nn.Conv2d(512,matrixSize,1,1,0) - self.unzip = nn.Conv2d(matrixSize,512,1,1,0) - elif(layer == 'r31'): - self.compress = nn.Conv2d(256,matrixSize,1,1,0) - self.unzip = nn.Conv2d(matrixSize,256,1,1,0) - self.transmatrix = None - - def forward(self,cF,sF,trans=True): - cFBK = cF.clone() - cb,cc,ch,cw = cF.size() - cFF = cF.view(cb,cc,-1) - cMean = torch.mean(cFF,dim=2,keepdim=True) - cMean = cMean.unsqueeze(3) - cMean = cMean.expand_as(cF) - cF = cF - cMean - - sb,sc,sh,sw = sF.size() - sFF = sF.view(sb,sc,-1) - sMean = torch.mean(sFF,dim=2,keepdim=True) - sMean = sMean.unsqueeze(3) - sMeanC = sMean.expand_as(cF) - sMeanS = sMean.expand_as(sF) - sF = sF - sMeanS - - - compress_content = self.compress(cF) - b,c,h,w = compress_content.size() - compress_content = compress_content.view(b,c,-1) - - if(trans): - cMatrix = self.cnet(cF) - sMatrix = self.snet(sF) - - sMatrix = sMatrix.view(sMatrix.size(0),self.matrixSize,self.matrixSize) - cMatrix = cMatrix.view(cMatrix.size(0),self.matrixSize,self.matrixSize) - transmatrix = torch.bmm(sMatrix,cMatrix) - print(cMatrix) - transfeature = torch.bmm(transmatrix,compress_content).view(b,c,h,w) - out = self.unzip(transfeature.view(b,c,h,w)) - out = out + sMeanC - return out, transmatrix - else: - out = self.unzip(compress_content.view(b,c,h,w)) - out = out + cMean - return out diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/MatrixTest.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/MatrixTest.py deleted file mode 100644 index a5de4147ebda6c10a756566ccac485f0d83b0f1a..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/MatrixTest.py +++ /dev/null @@ -1,154 +0,0 @@ -import 
torch.nn as nn -import torch -import torch.nn.functional as F -import numpy as np -import cv2 -from torch.autograd import Variable -import torchvision.utils as vutils - - -class CNN(nn.Module): - def __init__(self,layer,matrixSize=32): - super(CNN,self).__init__() - # 256x64x64 - if(layer == 'r31'): - self.convs = nn.Sequential(nn.Conv2d(256,128,3,1,1), - nn.ReLU(inplace=True), - nn.Conv2d(128,64,3,1,1), - nn.ReLU(inplace=True), - nn.Conv2d(64,matrixSize,3,1,1)) - elif(layer == 'r41'): - # 512x32x32 - self.convs = nn.Sequential(nn.Conv2d(512,256,3,1,1), - nn.ReLU(inplace=True), - nn.Conv2d(256,128,3,1,1), - nn.ReLU(inplace=True), - nn.Conv2d(128,matrixSize,3,1,1)) - self.fc = nn.Linear(32*32,32*32) - - def forward(self,x,masks,style=False): - color_code_number = 9 - xb,xc,xh,xw = x.size() - x = x.view(xc,-1) - feature_sub_mean = x.clone() - for i in range(color_code_number): - mask = masks[i].clone().squeeze(0) - mask = cv2.resize(mask.numpy(),(xw,xh),interpolation=cv2.INTER_NEAREST) - mask = torch.FloatTensor(mask) - mask = mask.long() - if(torch.sum(mask) >= 10): - mask = mask.view(-1) - - # dilation here - """ - kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(5,5)) - mask = mask.cpu().numpy() - mask = cv2.dilate(mask.astype(np.float32), kernel) - mask = torch.from_numpy(mask) - mask = mask.squeeze() - """ - - fgmask = (mask>0).nonzero().squeeze(1) - fgmask = fgmask.cuda() - selectFeature = torch.index_select(x,1,fgmask) # 32x96 - # subtract mean - f_mean = torch.mean(selectFeature,1) - f_mean = f_mean.unsqueeze(1).expand_as(selectFeature) - selectFeature = selectFeature - f_mean - feature_sub_mean.index_copy_(1,fgmask,selectFeature) - - feature = self.convs(feature_sub_mean.view(xb,xc,xh,xw)) - # 32x16x16 - b,c,h,w = feature.size() - transMatrices = {} - feature = feature.view(c,-1) - - for i in range(color_code_number): - mask = masks[i].clone().squeeze(0) - mask = cv2.resize(mask.numpy(),(w,h),interpolation=cv2.INTER_NEAREST) - mask = 
torch.FloatTensor(mask) - mask = mask.long() - if(torch.sum(mask) >= 10): - mask = mask.view(-1) - fgmask = Variable((mask==1).nonzero().squeeze(1)) - fgmask = fgmask.cuda() - selectFeature = torch.index_select(feature,1,fgmask) # 32x96 - tc,tN = selectFeature.size() - - covMatrix = torch.mm(selectFeature,selectFeature.transpose(0,1)).div(tN) - transmatrix = self.fc(covMatrix.view(-1)) - transMatrices[i] = transmatrix - return transMatrices,feature_sub_mean - -class MulLayer(nn.Module): - def __init__(self,layer,matrixSize=32): - super(MulLayer,self).__init__() - self.snet = CNN(layer) - self.cnet = CNN(layer) - self.matrixSize = matrixSize - - if(layer == 'r41'): - self.compress = nn.Conv2d(512,matrixSize,1,1,0) - self.unzip = nn.Conv2d(matrixSize,512,1,1,0) - elif(layer == 'r31'): - self.compress = nn.Conv2d(256,matrixSize,1,1,0) - self.unzip = nn.Conv2d(matrixSize,256,1,1,0) - - def forward(self,cF,sF,cmasks,smasks): - - sb,sc,sh,sw = sF.size() - - sMatrices,sF_sub_mean = self.snet(sF,smasks,style=True) - cMatrices,cF_sub_mean = self.cnet(cF,cmasks,style=False) - - compress_content = self.compress(cF_sub_mean.view(cF.size())) - cb,cc,ch,cw = compress_content.size() - compress_content = compress_content.view(cc,-1) - transfeature = compress_content.clone() - color_code_number = 9 - finalSMean = Variable(torch.zeros(cF.size()).cuda(0)) - finalSMean = finalSMean.view(sc,-1) - for i in range(color_code_number): - cmask = cmasks[i].clone().squeeze(0) - smask = smasks[i].clone().squeeze(0) - - cmask = cv2.resize(cmask.numpy(),(cw,ch),interpolation=cv2.INTER_NEAREST) - cmask = torch.FloatTensor(cmask) - cmask = cmask.long() - smask = cv2.resize(smask.numpy(),(sw,sh),interpolation=cv2.INTER_NEAREST) - smask = torch.FloatTensor(smask) - smask = smask.long() - if(torch.sum(cmask) >= 10 and torch.sum(smask) >= 10 - and (i in sMatrices) and (i in cMatrices)): - cmask = cmask.view(-1) - fgcmask = Variable((cmask==1).nonzero().squeeze(1)) - fgcmask = fgcmask.cuda() - - smask 
= smask.view(-1) - fgsmask = Variable((smask==1).nonzero().squeeze(1)) - fgsmask = fgsmask.cuda() - - sFF = sF.view(sc,-1) - sFF_select = torch.index_select(sFF,1,fgsmask) - sMean = torch.mean(sFF_select,dim=1,keepdim=True) - sMean = sMean.view(1,sc,1,1) - sMean = sMean.expand_as(cF) - - sMatrix = sMatrices[i] - cMatrix = cMatrices[i] - - sMatrix = sMatrix.view(self.matrixSize,self.matrixSize) - cMatrix = cMatrix.view(self.matrixSize,self.matrixSize) - - transmatrix = torch.mm(sMatrix,cMatrix) # (C*C) - - compress_content_select = torch.index_select(compress_content,1,fgcmask) - - transfeatureFG = torch.mm(transmatrix,compress_content_select) - transfeature.index_copy_(1,fgcmask,transfeatureFG) - - sMean = sMean.contiguous() - sMean_select = torch.index_select(sMean.view(sc,-1),1,fgcmask) - finalSMean.index_copy_(1,fgcmask,sMean_select) - out = self.unzip(transfeature.view(cb,cc,ch,cw)) - return out + finalSMean.view(out.size()) diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/SPN.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/SPN.py deleted file mode 100644 index c9bf7c39411ec132afaa50e4889d78f4580c3a08..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/SPN.py +++ /dev/null @@ -1,156 +0,0 @@ -import torch -import torch.nn as nn -from torchvision.models import vgg16 -from torch.autograd import Variable -from collections import OrderedDict -import torch.nn.functional as F -import sys -sys.path.append('../') -from libs.pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind - -class spn_block(nn.Module): - def __init__(self, horizontal, reverse): - super(spn_block, self).__init__() - self.propagator = GateRecurrent2dnoind(horizontal,reverse) - - def forward(self,x,G1,G2,G3): - sum_abs = G1.abs() + G2.abs() + G3.abs() - sum_abs.data[sum_abs.data == 0] = 1e-6 - mask_need_norm = sum_abs.ge(1) - mask_need_norm = mask_need_norm.float() - 
G1_norm = torch.div(G1, sum_abs) - G2_norm = torch.div(G2, sum_abs) - G3_norm = torch.div(G3, sum_abs) - - G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm - G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm - G3 = torch.add(-mask_need_norm, 1) * G3 + mask_need_norm * G3_norm - - return self.propagator(x,G1,G2,G3) - -class VGG(nn.Module): - def __init__(self,nf): - super(VGG,self).__init__() - self.conv1 = nn.Conv2d(3,nf,3,padding = 1) - # 256 x 256 - self.pool1 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1) - self.conv2 = nn.Conv2d(nf,nf*2,3,padding = 1) - # 128 x 128 - self.pool2 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1) - self.conv3 = nn.Conv2d(nf*2,nf*4,3,padding = 1) - # 64 x 64 - self.pool3 = nn.MaxPool2d(kernel_size = 3, stride = 2,padding=1) - # 32 x 32 - self.conv4 = nn.Conv2d(nf*4,nf*8,3,padding = 1) - - def forward(self,x): - output = {} - output['conv1'] = self.conv1(x) - x = F.relu(output['conv1']) - x = self.pool1(x) - output['conv2'] = self.conv2(x) - # 128 x 128 - x = F.relu(output['conv2']) - x = self.pool2(x) - output['conv3'] = self.conv3(x) - # 64 x 64 - x = F.relu(output['conv3']) - output['pool3'] = self.pool3(x) - # 32 x 32 - output['conv4'] = self.conv4(output['pool3']) - return output - -class Decoder(nn.Module): - def __init__(self,nf=32,spn=1): - super(Decoder,self).__init__() - # 32 x 32 - self.layer0 = nn.Conv2d(nf*8,nf*4,1,1,0) # edge_conv5 - self.layer1 = nn.Upsample(scale_factor=2,mode='bilinear') - self.layer2 = nn.Sequential(nn.Conv2d(nf*4,nf*4,3,1,1), # edge_conv8 - nn.ELU(inplace=True)) - # 64 x 64 - self.layer3 = nn.Upsample(scale_factor=2,mode='bilinear') - self.layer4 = nn.Sequential(nn.Conv2d(nf*4,nf*2,3,1,1), # edge_conv8 - nn.ELU(inplace=True)) - # 128 x 128 - self.layer5 = nn.Upsample(scale_factor=2,mode='bilinear') - self.layer6 = nn.Sequential(nn.Conv2d(nf*2,nf,3,1,1), # edge_conv8 - nn.ELU(inplace=True)) - if(spn == 1): - self.layer7 = nn.Conv2d(nf,nf*12,3,1,1) - 
else: - self.layer7 = nn.Conv2d(nf,nf*24,3,1,1) - self.spn = spn - # 256 x 256 - - def forward(self,encode_feature): - output = {} - output['0'] = self.layer0(encode_feature['conv4']) - output['1'] = self.layer1(output['0']) - - output['2'] = self.layer2(output['1']) - output['2res'] = output['2'] + encode_feature['conv3'] - # 64 x 64 - - output['3'] = self.layer3(output['2res']) - output['4'] = self.layer4(output['3']) - output['4res'] = output['4'] + encode_feature['conv2'] - # 128 x 128 - - output['5'] = self.layer5(output['4res']) - output['6'] = self.layer6(output['5']) - output['6res'] = output['6'] + encode_feature['conv1'] - - output['7'] = self.layer7(output['6res']) - - return output['7'] - - -class SPN(nn.Module): - def __init__(self,nf=32,spn=1): - super(SPN,self).__init__() - # conv for mask - self.mask_conv = nn.Conv2d(3,nf,3,1,1) - - # guidance network - self.encoder = VGG(nf) - self.decoder = Decoder(nf,spn) - - # spn blocks - self.left_right = spn_block(True,False) - self.right_left = spn_block(True,True) - self.top_down = spn_block(False, False) - self.down_top = spn_block(False,True) - - # post upsample - self.post = nn.Conv2d(nf,3,3,1,1) - self.nf = nf - - def forward(self,x,rgb): - # feature for mask - X = self.mask_conv(x) - - # guidance - features = self.encoder(rgb) - guide = self.decoder(features) - - G = torch.split(guide,self.nf,1) - out1 = self.left_right(X,G[0],G[1],G[2]) - out2 = self.right_left(X,G[3],G[4],G[5]) - out3 = self.top_down(X,G[6],G[7],G[8]) - out4 = self.down_top(X,G[9],G[10],G[11]) - - out = torch.max(out1,out2) - out = torch.max(out,out3) - out = torch.max(out,out4) - - return self.post(out) - -if __name__ == '__main__': - spn = SPN() - spn = spn.cuda() - for i in range(100): - x = Variable(torch.Tensor(1,3,256,256)).cuda() - rgb = Variable(torch.Tensor(1,3,256,256)).cuda() - output = spn(x,rgb) - print(output.size()) diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__init__.py 
b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/Matrix.cpython-311.pyc b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/Matrix.cpython-311.pyc deleted file mode 100644 index d473f1aaf822fe9e25b06f58b3662de7ce57dfd8..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/Matrix.cpython-311.pyc and /dev/null differ diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/Matrix.cpython-312.pyc b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/Matrix.cpython-312.pyc deleted file mode 100644 index 951a8443fcf0ed40d088510a39766573e302444b..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/Matrix.cpython-312.pyc and /dev/null differ diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/Matrix.cpython-39.pyc b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/Matrix.cpython-39.pyc deleted file mode 100644 index f12f525fbf30d2bae06d54fc810ed6798108a650..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/Matrix.cpython-39.pyc and /dev/null differ diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/__init__.cpython-311.pyc b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 252f88d27c83ba5af3abe484d9d76db7e97ba73d..0000000000000000000000000000000000000000 Binary files 
a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/__init__.cpython-312.pyc b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/__init__.cpython-312.pyc deleted file mode 100644 index 10f35bd8b4aed9f4fc0b44fa431d3460e5c986f3..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/__init__.cpython-312.pyc and /dev/null differ diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/__init__.cpython-39.pyc b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 748966533cf24d96e0273fadb5057eb42d53559b..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/models.cpython-311.pyc b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/models.cpython-311.pyc deleted file mode 100644 index add77acd4d365438d2c8737b047e317cef539916..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/models.cpython-311.pyc and /dev/null differ diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/models.cpython-312.pyc b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/models.cpython-312.pyc deleted file mode 100644 index d9b7a52d0ad539a685aaefb7f431f642fc25be7e..0000000000000000000000000000000000000000 Binary files 
a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/models.cpython-312.pyc and /dev/null differ diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/models.cpython-39.pyc b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/models.cpython-39.pyc deleted file mode 100644 index a36b27e768da3ad951458294b654d457db4bbe9d..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/__pycache__/models.cpython-39.pyc and /dev/null differ diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/models.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/models.py deleted file mode 100644 index 964f273f2d28083b6aa15d09ed3876111c29a49d..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/models.py +++ /dev/null @@ -1,662 +0,0 @@ -import torch -import torch.nn as nn - -class encoder3(nn.Module): - def __init__(self): - super(encoder3,self).__init__() - # vgg - # 224 x 224 - self.conv1 = nn.Conv2d(3,3,1,1,0) - self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1)) - # 226 x 226 - - self.conv2 = nn.Conv2d(3,64,3,1,0) - self.relu2 = nn.ReLU(inplace=True) - # 224 x 224 - - self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1)) - self.conv3 = nn.Conv2d(64,64,3,1,0) - self.relu3 = nn.ReLU(inplace=True) - # 224 x 224 - - self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True) - # 112 x 112 - - self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1)) - self.conv4 = nn.Conv2d(64,128,3,1,0) - self.relu4 = nn.ReLU(inplace=True) - # 112 x 112 - - self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1)) - self.conv5 = nn.Conv2d(128,128,3,1,0) - self.relu5 = nn.ReLU(inplace=True) - # 112 x 112 - - self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2,return_indices = True) - # 56 x 56 - - self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1)) - 
self.conv6 = nn.Conv2d(128,256,3,1,0) - self.relu6 = nn.ReLU(inplace=True) - # 56 x 56 - def forward(self,x): - out = self.conv1(x) - out = self.reflecPad1(out) - out = self.conv2(out) - out = self.relu2(out) - out = self.reflecPad3(out) - out = self.conv3(out) - pool1 = self.relu3(out) - out,pool_idx = self.maxPool(pool1) - out = self.reflecPad4(out) - out = self.conv4(out) - out = self.relu4(out) - out = self.reflecPad5(out) - out = self.conv5(out) - pool2 = self.relu5(out) - out,pool_idx2 = self.maxPool2(pool2) - out = self.reflecPad6(out) - out = self.conv6(out) - out = self.relu6(out) - return out - -class decoder3(nn.Module): - def __init__(self): - super(decoder3,self).__init__() - # decoder - self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1)) - self.conv7 = nn.Conv2d(256,128,3,1,0) - self.relu7 = nn.ReLU(inplace=True) - # 56 x 56 - - self.unpool = nn.UpsamplingNearest2d(scale_factor=2) - # 112 x 112 - - self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1)) - self.conv8 = nn.Conv2d(128,128,3,1,0) - self.relu8 = nn.ReLU(inplace=True) - # 112 x 112 - - self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1)) - self.conv9 = nn.Conv2d(128,64,3,1,0) - self.relu9 = nn.ReLU(inplace=True) - - self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2) - # 224 x 224 - - self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1)) - self.conv10 = nn.Conv2d(64,64,3,1,0) - self.relu10 = nn.ReLU(inplace=True) - - self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1)) - self.conv11 = nn.Conv2d(64,3,3,1,0) - - def forward(self,x): - output = {} - out = self.reflecPad7(x) - out = self.conv7(out) - out = self.relu7(out) - out = self.unpool(out) - out = self.reflecPad8(out) - out = self.conv8(out) - out = self.relu8(out) - out = self.reflecPad9(out) - out = self.conv9(out) - out_relu9 = self.relu9(out) - out = self.unpool2(out_relu9) - out = self.reflecPad10(out) - out = self.conv10(out) - out = self.relu10(out) - out = self.reflecPad11(out) - out = self.conv11(out) - return out - -class encoder4(nn.Module): - def 
__init__(self): - super(encoder4,self).__init__() - # vgg - # 224 x 224 - self.conv1 = nn.Conv2d(3,3,1,1,0) - self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1)) - # 226 x 226 - - self.conv2 = nn.Conv2d(3,64,3,1,0) - self.relu2 = nn.ReLU(inplace=True) - # 224 x 224 - - self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1)) - self.conv3 = nn.Conv2d(64,64,3,1,0) - self.relu3 = nn.ReLU(inplace=True) - # 224 x 224 - - self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2) - # 112 x 112 - - self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1)) - self.conv4 = nn.Conv2d(64,128,3,1,0) - self.relu4 = nn.ReLU(inplace=True) - # 112 x 112 - - self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1)) - self.conv5 = nn.Conv2d(128,128,3,1,0) - self.relu5 = nn.ReLU(inplace=True) - # 112 x 112 - - self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2) - # 56 x 56 - - self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1)) - self.conv6 = nn.Conv2d(128,256,3,1,0) - self.relu6 = nn.ReLU(inplace=True) - # 56 x 56 - - self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1)) - self.conv7 = nn.Conv2d(256,256,3,1,0) - self.relu7 = nn.ReLU(inplace=True) - # 56 x 56 - - self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1)) - self.conv8 = nn.Conv2d(256,256,3,1,0) - self.relu8 = nn.ReLU(inplace=True) - # 56 x 56 - - self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1)) - self.conv9 = nn.Conv2d(256,256,3,1,0) - self.relu9 = nn.ReLU(inplace=True) - # 56 x 56 - - self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2) - # 28 x 28 - - self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1)) - self.conv10 = nn.Conv2d(256,512,3,1,0) - self.relu10 = nn.ReLU(inplace=True) - # 28 x 28 - def forward(self,x,sF=None,matrix11=None,matrix21=None,matrix31=None): - output = {} - out = self.conv1(x) - out = self.reflecPad1(out) - out = self.conv2(out) - output['r11'] = self.relu2(out) - out = self.reflecPad7(output['r11']) - - out = self.conv3(out) - output['r12'] = self.relu3(out) - - output['p1'] = self.maxPool(output['r12']) - out = self.reflecPad4(output['p1']) - out = 
self.conv4(out) - output['r21'] = self.relu4(out) - out = self.reflecPad7(output['r21']) - - out = self.conv5(out) - output['r22'] = self.relu5(out) - - output['p2'] = self.maxPool2(output['r22']) - out = self.reflecPad6(output['p2']) - out = self.conv6(out) - output['r31'] = self.relu6(out) - if(matrix31 is not None): - feature3,transmatrix3 = matrix31(output['r31'],sF['r31']) - out = self.reflecPad7(feature3) - else: - out = self.reflecPad7(output['r31']) - out = self.conv7(out) - output['r32'] = self.relu7(out) - - out = self.reflecPad8(output['r32']) - out = self.conv8(out) - output['r33'] = self.relu8(out) - - out = self.reflecPad9(output['r33']) - out = self.conv9(out) - output['r34'] = self.relu9(out) - - output['p3'] = self.maxPool3(output['r34']) - out = self.reflecPad10(output['p3']) - out = self.conv10(out) - output['r41'] = self.relu10(out) - - return output - -class decoder4(nn.Module): - def __init__(self): - super(decoder4,self).__init__() - # decoder - self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1)) - self.conv11 = nn.Conv2d(512,256,3,1,0) - self.relu11 = nn.ReLU(inplace=True) - # 28 x 28 - - self.unpool = nn.UpsamplingNearest2d(scale_factor=2) - # 56 x 56 - - self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1)) - self.conv12 = nn.Conv2d(256,256,3,1,0) - self.relu12 = nn.ReLU(inplace=True) - # 56 x 56 - - self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1)) - self.conv13 = nn.Conv2d(256,256,3,1,0) - self.relu13 = nn.ReLU(inplace=True) - # 56 x 56 - - self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1)) - self.conv14 = nn.Conv2d(256,256,3,1,0) - self.relu14 = nn.ReLU(inplace=True) - # 56 x 56 - - self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1)) - self.conv15 = nn.Conv2d(256,128,3,1,0) - self.relu15 = nn.ReLU(inplace=True) - # 56 x 56 - - self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2) - # 112 x 112 - - self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1)) - self.conv16 = nn.Conv2d(128,128,3,1,0) - self.relu16 = nn.ReLU(inplace=True) - # 112 x 112 - - 
self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1)) - self.conv17 = nn.Conv2d(128,64,3,1,0) - self.relu17 = nn.ReLU(inplace=True) - # 112 x 112 - - self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2) - # 224 x 224 - - self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1)) - self.conv18 = nn.Conv2d(64,64,3,1,0) - self.relu18 = nn.ReLU(inplace=True) - # 224 x 224 - - self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1)) - self.conv19 = nn.Conv2d(64,3,3,1,0) - - def forward(self,x): - # decoder - out = self.reflecPad11(x) - out = self.conv11(out) - out = self.relu11(out) - out = self.unpool(out) - out = self.reflecPad12(out) - out = self.conv12(out) - - out = self.relu12(out) - out = self.reflecPad13(out) - out = self.conv13(out) - out = self.relu13(out) - out = self.reflecPad14(out) - out = self.conv14(out) - out = self.relu14(out) - out = self.reflecPad15(out) - out = self.conv15(out) - out = self.relu15(out) - out = self.unpool2(out) - out = self.reflecPad16(out) - out = self.conv16(out) - out = self.relu16(out) - out = self.reflecPad17(out) - out = self.conv17(out) - out = self.relu17(out) - out = self.unpool3(out) - out = self.reflecPad18(out) - out = self.conv18(out) - out = self.relu18(out) - out = self.reflecPad19(out) - out = self.conv19(out) - return out - -class decoder4(nn.Module): - def __init__(self): - super(decoder4,self).__init__() - # decoder - self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1)) - self.conv11 = nn.Conv2d(512,256,3,1,0) - self.relu11 = nn.ReLU(inplace=True) - # 28 x 28 - - self.unpool = nn.UpsamplingNearest2d(scale_factor=2) - # 56 x 56 - - self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1)) - self.conv12 = nn.Conv2d(256,256,3,1,0) - self.relu12 = nn.ReLU(inplace=True) - # 56 x 56 - - self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1)) - self.conv13 = nn.Conv2d(256,256,3,1,0) - self.relu13 = nn.ReLU(inplace=True) - # 56 x 56 - - self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1)) - self.conv14 = nn.Conv2d(256,256,3,1,0) - self.relu14 = nn.ReLU(inplace=True) 
- # 56 x 56 - - self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1)) - self.conv15 = nn.Conv2d(256,128,3,1,0) - self.relu15 = nn.ReLU(inplace=True) - # 56 x 56 - - self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2) - # 112 x 112 - - self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1)) - self.conv16 = nn.Conv2d(128,128,3,1,0) - self.relu16 = nn.ReLU(inplace=True) - # 112 x 112 - - self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1)) - self.conv17 = nn.Conv2d(128,64,3,1,0) - self.relu17 = nn.ReLU(inplace=True) - # 112 x 112 - - self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2) - # 224 x 224 - - self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1)) - self.conv18 = nn.Conv2d(64,64,3,1,0) - self.relu18 = nn.ReLU(inplace=True) - # 224 x 224 - - self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1)) - self.conv19 = nn.Conv2d(64,3,3,1,0) - - def forward(self,x): - # decoder - out = self.reflecPad11(x) - out = self.conv11(out) - out = self.relu11(out) - out = self.unpool(out) - out = self.reflecPad12(out) - out = self.conv12(out) - - out = self.relu12(out) - out = self.reflecPad13(out) - out = self.conv13(out) - out = self.relu13(out) - out = self.reflecPad14(out) - out = self.conv14(out) - out = self.relu14(out) - out = self.reflecPad15(out) - out = self.conv15(out) - out = self.relu15(out) - out = self.unpool2(out) - out = self.reflecPad16(out) - out = self.conv16(out) - out = self.relu16(out) - out = self.reflecPad17(out) - out = self.conv17(out) - out = self.relu17(out) - out = self.unpool3(out) - out = self.reflecPad18(out) - out = self.conv18(out) - out = self.relu18(out) - out = self.reflecPad19(out) - out = self.conv19(out) - return out - -class encoder5(nn.Module): - def __init__(self): - super(encoder5,self).__init__() - # vgg - # 224 x 224 - self.conv1 = nn.Conv2d(3,3,1,1,0) - self.reflecPad1 = nn.ReflectionPad2d((1,1,1,1)) - # 226 x 226 - - self.conv2 = nn.Conv2d(3,64,3,1,0) - self.relu2 = nn.ReLU(inplace=True) - # 224 x 224 - - self.reflecPad3 = nn.ReflectionPad2d((1,1,1,1)) - 
self.conv3 = nn.Conv2d(64,64,3,1,0) - self.relu3 = nn.ReLU(inplace=True) - # 224 x 224 - - self.maxPool = nn.MaxPool2d(kernel_size=2,stride=2) - # 112 x 112 - - self.reflecPad4 = nn.ReflectionPad2d((1,1,1,1)) - self.conv4 = nn.Conv2d(64,128,3,1,0) - self.relu4 = nn.ReLU(inplace=True) - # 112 x 112 - - self.reflecPad5 = nn.ReflectionPad2d((1,1,1,1)) - self.conv5 = nn.Conv2d(128,128,3,1,0) - self.relu5 = nn.ReLU(inplace=True) - # 112 x 112 - - self.maxPool2 = nn.MaxPool2d(kernel_size=2,stride=2) - # 56 x 56 - - self.reflecPad6 = nn.ReflectionPad2d((1,1,1,1)) - self.conv6 = nn.Conv2d(128,256,3,1,0) - self.relu6 = nn.ReLU(inplace=True) - # 56 x 56 - - self.reflecPad7 = nn.ReflectionPad2d((1,1,1,1)) - self.conv7 = nn.Conv2d(256,256,3,1,0) - self.relu7 = nn.ReLU(inplace=True) - # 56 x 56 - - self.reflecPad8 = nn.ReflectionPad2d((1,1,1,1)) - self.conv8 = nn.Conv2d(256,256,3,1,0) - self.relu8 = nn.ReLU(inplace=True) - # 56 x 56 - - self.reflecPad9 = nn.ReflectionPad2d((1,1,1,1)) - self.conv9 = nn.Conv2d(256,256,3,1,0) - self.relu9 = nn.ReLU(inplace=True) - # 56 x 56 - - self.maxPool3 = nn.MaxPool2d(kernel_size=2,stride=2) - # 28 x 28 - - self.reflecPad10 = nn.ReflectionPad2d((1,1,1,1)) - self.conv10 = nn.Conv2d(256,512,3,1,0) - self.relu10 = nn.ReLU(inplace=True) - - self.reflecPad11 = nn.ReflectionPad2d((1,1,1,1)) - self.conv11 = nn.Conv2d(512,512,3,1,0) - self.relu11 = nn.ReLU(inplace=True) - - self.reflecPad12 = nn.ReflectionPad2d((1,1,1,1)) - self.conv12 = nn.Conv2d(512,512,3,1,0) - self.relu12 = nn.ReLU(inplace=True) - - self.reflecPad13 = nn.ReflectionPad2d((1,1,1,1)) - self.conv13 = nn.Conv2d(512,512,3,1,0) - self.relu13 = nn.ReLU(inplace=True) - - self.maxPool4 = nn.MaxPool2d(kernel_size=2,stride=2) - self.reflecPad14 = nn.ReflectionPad2d((1,1,1,1)) - self.conv14 = nn.Conv2d(512,512,3,1,0) - self.relu14 = nn.ReLU(inplace=True) - - def forward(self,x,sF=None,contentV256=None,styleV256=None,matrix11=None,matrix21=None,matrix31=None): - output = {} - out = 
self.conv1(x) - out = self.reflecPad1(out) - out = self.conv2(out) - output['r11'] = self.relu2(out) - out = self.reflecPad7(output['r11']) - - #out = self.reflecPad3(output['r11']) - out = self.conv3(out) - output['r12'] = self.relu3(out) - - output['p1'] = self.maxPool(output['r12']) - out = self.reflecPad4(output['p1']) - out = self.conv4(out) - output['r21'] = self.relu4(out) - out = self.reflecPad7(output['r21']) - - #out = self.reflecPad5(output['r21']) - out = self.conv5(out) - output['r22'] = self.relu5(out) - - output['p2'] = self.maxPool2(output['r22']) - out = self.reflecPad6(output['p2']) - out = self.conv6(out) - output['r31'] = self.relu6(out) - if(styleV256 is not None): - feature = matrix31(output['r31'],sF['r31'],contentV256,styleV256) - out = self.reflecPad7(feature) - else: - out = self.reflecPad7(output['r31']) - out = self.conv7(out) - output['r32'] = self.relu7(out) - - out = self.reflecPad8(output['r32']) - out = self.conv8(out) - output['r33'] = self.relu8(out) - - out = self.reflecPad9(output['r33']) - out = self.conv9(out) - output['r34'] = self.relu9(out) - - output['p3'] = self.maxPool3(output['r34']) - out = self.reflecPad10(output['p3']) - out = self.conv10(out) - output['r41'] = self.relu10(out) - - out = self.reflecPad11(output['r41']) - out = self.conv11(out) - output['r42'] = self.relu11(out) - - out = self.reflecPad12(output['r42']) - out = self.conv12(out) - output['r43'] = self.relu12(out) - - out = self.reflecPad13(output['r43']) - out = self.conv13(out) - output['r44'] = self.relu13(out) - - output['p4'] = self.maxPool4(output['r44']) - - out = self.reflecPad14(output['p4']) - out = self.conv14(out) - output['r51'] = self.relu14(out) - return output - -class decoder5(nn.Module): - def __init__(self): - super(decoder5,self).__init__() - - # decoder - self.reflecPad15 = nn.ReflectionPad2d((1,1,1,1)) - self.conv15 = nn.Conv2d(512,512,3,1,0) - self.relu15 = nn.ReLU(inplace=True) - - self.unpool = 
nn.UpsamplingNearest2d(scale_factor=2) - # 28 x 28 - - self.reflecPad16 = nn.ReflectionPad2d((1,1,1,1)) - self.conv16 = nn.Conv2d(512,512,3,1,0) - self.relu16 = nn.ReLU(inplace=True) - # 28 x 28 - - self.reflecPad17 = nn.ReflectionPad2d((1,1,1,1)) - self.conv17 = nn.Conv2d(512,512,3,1,0) - self.relu17 = nn.ReLU(inplace=True) - # 28 x 28 - - self.reflecPad18 = nn.ReflectionPad2d((1,1,1,1)) - self.conv18 = nn.Conv2d(512,512,3,1,0) - self.relu18 = nn.ReLU(inplace=True) - # 28 x 28 - - self.reflecPad19 = nn.ReflectionPad2d((1,1,1,1)) - self.conv19 = nn.Conv2d(512,256,3,1,0) - self.relu19 = nn.ReLU(inplace=True) - # 28 x 28 - - self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2) - # 56 x 56 - - self.reflecPad20 = nn.ReflectionPad2d((1,1,1,1)) - self.conv20 = nn.Conv2d(256,256,3,1,0) - self.relu20 = nn.ReLU(inplace=True) - # 56 x 56 - - self.reflecPad21 = nn.ReflectionPad2d((1,1,1,1)) - self.conv21 = nn.Conv2d(256,256,3,1,0) - self.relu21 = nn.ReLU(inplace=True) - - self.reflecPad22 = nn.ReflectionPad2d((1,1,1,1)) - self.conv22 = nn.Conv2d(256,256,3,1,0) - self.relu22 = nn.ReLU(inplace=True) - - self.reflecPad23 = nn.ReflectionPad2d((1,1,1,1)) - self.conv23 = nn.Conv2d(256,128,3,1,0) - self.relu23 = nn.ReLU(inplace=True) - - self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2) - # 112 X 112 - - self.reflecPad24 = nn.ReflectionPad2d((1,1,1,1)) - self.conv24 = nn.Conv2d(128,128,3,1,0) - self.relu24 = nn.ReLU(inplace=True) - - self.reflecPad25 = nn.ReflectionPad2d((1,1,1,1)) - self.conv25 = nn.Conv2d(128,64,3,1,0) - self.relu25 = nn.ReLU(inplace=True) - - self.unpool4 = nn.UpsamplingNearest2d(scale_factor=2) - - self.reflecPad26 = nn.ReflectionPad2d((1,1,1,1)) - self.conv26 = nn.Conv2d(64,64,3,1,0) - self.relu26 = nn.ReLU(inplace=True) - - self.reflecPad27 = nn.ReflectionPad2d((1,1,1,1)) - self.conv27 = nn.Conv2d(64,3,3,1,0) - - def forward(self,x): - # decoder - out = self.reflecPad15(x) - out = self.conv15(out) - out = self.relu15(out) - out = self.unpool(out) - out = 
self.reflecPad16(out) - out = self.conv16(out) - out = self.relu16(out) - out = self.reflecPad17(out) - out = self.conv17(out) - out = self.relu17(out) - out = self.reflecPad18(out) - out = self.conv18(out) - out = self.relu18(out) - out = self.reflecPad19(out) - out = self.conv19(out) - out = self.relu19(out) - out = self.unpool2(out) - out = self.reflecPad20(out) - out = self.conv20(out) - out = self.relu20(out) - out = self.reflecPad21(out) - out = self.conv21(out) - out = self.relu21(out) - out = self.reflecPad22(out) - out = self.conv22(out) - out = self.relu22(out) - out = self.reflecPad23(out) - out = self.conv23(out) - out = self.relu23(out) - out = self.unpool3(out) - out = self.reflecPad24(out) - out = self.conv24(out) - out = self.relu24(out) - out = self.reflecPad25(out) - out = self.conv25(out) - out = self.relu25(out) - out = self.unpool4(out) - out = self.reflecPad26(out) - out = self.conv26(out) - out = self.relu26(out) - out = self.reflecPad27(out) - out = self.conv27(out) - return out diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/README.md b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/README.md deleted file mode 100644 index f3c3c7e83fb5e80b0a63d14d4e546772f49cc1e6..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/README.md +++ /dev/null @@ -1,12 +0,0 @@ -# pytorch_spn -To build, install [pytorch](https://github.com/pytorch) and run: - -$ sh make.sh - -See left_right_demo.py for usage: - -$ mv left_right_demo.py ../ - -$ python left_right_demo.py - -The original codes (caffe) and models will be relesed [HERE](https://github.com/Liusifei/caffe-spn.git). 
diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/__init__.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/__init__.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/gaterecurrent2dnoind/__init__.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/gaterecurrent2dnoind/__init__.py deleted file mode 100644 index cab752cdaacefb3a3a2987dd235f75daf3df150d..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/_ext/gaterecurrent2dnoind/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ - -from torch.utils.ffi import _wrap_function -from ._gaterecurrent2dnoind import lib as _lib, ffi as _ffi - -__all__ = [] -def _import_symbols(locals): - for symbol in dir(_lib): - fn = getattr(_lib, symbol) - if callable(fn): - locals[symbol] = _wrap_function(fn, _ffi) - else: - locals[symbol] = fn - __all__.append(symbol) - -_import_symbols(locals()) diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/build.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/build.py deleted file mode 100644 index 70aecaecdc5b935cf2addd4e80786ee922d8d259..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/build.py +++ /dev/null @@ -1,34 +0,0 @@ -import os -import torch -from torch.utils.ffi import 
create_extension - -this_file = os.path.dirname(__file__) - -sources = [] -headers = [] -defines = [] -with_cuda = False - -if torch.cuda.is_available(): - print('Including CUDA code.') - sources += ['src/gaterecurrent2dnoind_cuda.c'] - headers += ['src/gaterecurrent2dnoind_cuda.h'] - defines += [('WITH_CUDA', None)] - with_cuda = True - -this_file = os.path.dirname(os.path.realpath(__file__)) -extra_objects = ['src/cuda/gaterecurrent2dnoind_kernel.cu.o'] -extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] - -ffi = create_extension( - '_ext.gaterecurrent2dnoind', - headers=headers, - sources=sources, - define_macros=defines, - relative_to=__file__, - with_cuda=with_cuda, - extra_objects=extra_objects -) - -if __name__ == '__main__': - ffi.build() diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/__init__.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/gaterecurrent2dnoind.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/gaterecurrent2dnoind.py deleted file mode 100644 index 972b726ce8093322c22ecd54aa0460c59273a985..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/functions/gaterecurrent2dnoind.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -from torch.autograd import Function -from .._ext import gaterecurrent2dnoind as gaterecurrent2d - -class GateRecurrent2dnoindFunction(Function): - def __init__(self, horizontal_, reverse_): - self.horizontal = horizontal_ - self.reverse = reverse_ - - def forward(self, X, G1, G2, G3): - num, channels, height, width = X.size() - output = torch.zeros(num, 
channels, height, width) - - if not X.is_cuda: - print("cpu version is not ready at this time") - return 0 - else: - output = output.cuda() - gaterecurrent2d.gaterecurrent2dnoind_forward_cuda(self.horizontal,self.reverse, X, G1, G2, G3, output) - - self.X = X - self.G1 = G1 - self.G2 = G2 - self.G3 = G3 - self.output = output - self.hiddensize = X.size() - return output - - def backward(self, grad_output): - assert(self.hiddensize is not None and grad_output.is_cuda) - num, channels, height, width = self.hiddensize - - grad_X = torch.zeros(num, channels, height, width).cuda() - grad_G1 = torch.zeros(num, channels, height, width).cuda() - grad_G2 = torch.zeros(num, channels, height, width).cuda() - grad_G3 = torch.zeros(num, channels, height, width).cuda() - - gaterecurrent2d.gaterecurrent2dnoind_backward_cuda(self.horizontal, self.reverse, self.output, grad_output, self.X, self.G1, self.G2, self.G3, grad_X, grad_G1, grad_G2, grad_G3) - - del self.hiddensize - del self.G1 - del self.G2 - del self.G3 - del self.output - del self.X - - return grad_X, grad_G1, grad_G2, grad_G3 diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/left_right_demo.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/left_right_demo.py deleted file mode 100644 index 7c8d58ebbafe29ed90e9f8530c68402156ecce74..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/left_right_demo.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -An example of left->right propagation - -Other direction settings: -left->right: Propagator = GateRecurrent2dnoind(True,False) -right->left: Propagator = GateRecurrent2dnoind(True,True) -top->bottom: Propagator = GateRecurrent2dnoind(False,False) -bottom->top: Propagator = GateRecurrent2dnoind(False,True) - -X: any signal/feature map to be filtered -G1~G3: three coefficient maps (e.g., left-top, left-center, left-bottom) - -Note: -1. 
G1~G3 constitute the affinity, they can be a bounch of output maps coming from any CNN, with the input of any useful known information (e.g., RGB images) -2. for any pixel i, |G1(i)| + |G2(i)| + |G3(i)| <= 1 is a sufficent condition for model stability (see paper) -""" -import torch -from torch.autograd import Variable -from pytorch_spn.modules.gaterecurrent2dnoind import GateRecurrent2dnoind - -Propagator = GateRecurrent2dnoind(True,False) - -X = Variable(torch.randn(1,3,10,10)) -G1 = Variable(torch.randn(1,3,10,10)) -G2 = Variable(torch.randn(1,3,10,10)) -G3 = Variable(torch.randn(1,3,10,10)) - -sum_abs = G1.abs() + G2.abs() + G3.abs() -mask_need_norm = sum_abs.ge(1) -mask_need_norm = mask_need_norm.float() -G1_norm = torch.div(G1, sum_abs) -G2_norm = torch.div(G2, sum_abs) -G3_norm = torch.div(G3, sum_abs) - -G1 = torch.add(-mask_need_norm, 1) * G1 + mask_need_norm * G1_norm -G2 = torch.add(-mask_need_norm, 1) * G2 + mask_need_norm * G2_norm -G3 = torch.add(-mask_need_norm, 1) * G3 + mask_need_norm * G3_norm - -X = X.cuda() -G1 = G1.cuda() -G2 = G2.cuda() -G3 = G3.cuda() - -output = Propagator.forward(X,G1,G2,G3) -print(X) -print(output) diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/make.sh b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/make.sh deleted file mode 100644 index 7bafe812c4c8f94986f42e71fcf4f8ca6aa72904..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/make.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/usr/bin/env bash - -CUDA_PATH=/usr/local/cuda/ - -cd src/cuda/ -echo "Compiling gaterecurrent2dnoind layer kernels by nvcc..." 
-nvcc -c -o gaterecurrent2dnoind_kernel.cu.o gaterecurrent2dnoind_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52 -cd ../../ -python build.py diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/__init__.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/__init__.py deleted file mode 100644 index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/gaterecurrent2dnoind.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/gaterecurrent2dnoind.py deleted file mode 100644 index 46ffe82da20aea4546b13fd2b0bdd6ac7a548192..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/modules/gaterecurrent2dnoind.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch.nn as nn -from ..functions.gaterecurrent2dnoind import GateRecurrent2dnoindFunction - -class GateRecurrent2dnoind(nn.Module): - """docstring for .""" - def __init__(self, horizontal_, reverse_): - super(GateRecurrent2dnoind, self).__init__() - self.horizontal = horizontal_ - self.reverse = reverse_ - - def forward(self, X, G1, G2, G3): - return GateRecurrent2dnoindFunction(self.horizontal, self.reverse)(X, G1, G2, G3) diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/.DS_Store b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/.DS_Store deleted file mode 100644 index 6c7b6212579ceb149ee62bdbb280ee4aa553a616..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/.DS_Store and /dev/null differ diff --git 
a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu deleted file mode 100644 index ddad6f59cbbf7675a42f0679fe236b8745e5f1e9..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu +++ /dev/null @@ -1,697 +0,0 @@ -#ifdef __cplusplus -extern "C" { -#endif - -#include -#include -#include -#include "gaterecurrent2dnoind_kernel.h" - -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ - i += blockDim.x * gridDim.x) - -__device__ void get_gate_idx_sf(int h1, int w1, int h2, int w2, int * out, int horizontal, int reverse) -{ - if(horizontal && ! reverse) // left -> right - { - if(w1>w2) - { - out[0]=h1; - out[1]=w1; - } - else - { - out[0]=h2; - out[1]=w2; - } - } - if(horizontal && reverse) // right -> left - { - if(w1 bottom - { - if(h1>h2) - { - out[0]=h1; - out[1]=w1; - } - else - { - out[0]=h2; - out[1]=w2; - } - } - if(!horizontal && reverse) // bottom -> top - { - if(h1=height) - return 0; - if(w<0 || w >= width) - return 0; - - return data[n*channels*height*width + c * height*width + h * width + w]; -} - -__device__ void set_data_sf(float * data, int num, int channels,int height, int width,int n,int c,int h,int w, float v) -{ - if(h<0 || h >=height) - return ; - if(w<0 || w >= width) - return ; - - data[n*channels*height*width + c * height*width + h * width + w]=v; -} - -__device__ float get_gate_sf(float * data, int num, int channels,int height, int width,int n,int c,int h1,int w1,int h2,int w2,int horizontal,int reverse) -{ - if(h1<0 || h1 >=height) - return 0; - if(w1<0 || w1 >= width) - return 0; - if(h2<0 || h2 >=height) - return 0; - if(w2<0 || w2 >= width) - return 0; - int idx[2]; - - 
get_gate_idx_sf(h1,w1,h2,w2, idx,horizontal, reverse); - - int h = idx[0]; - int w = idx[1]; - - return data[n*channels*height*width + c * height*width + h * width + w]; -} - -__device__ void set_gate_sf(float * data, int num, int channels,int height, int width,int n,int c,int h1,int w1,int h2,int w2,int horizontal,int reverse, float v) -{ - if(h1<0 || h1 >=height) - return ; - if(w1<0 || w1 >= width) - return ; - if(h2<0 || h2 >=height) - return ; - if(w2<0 || w2 >= width) - return ; - int idx[2]; - - get_gate_idx_sf(h1,w1,h2,w2, idx,horizontal, reverse); - - int h = idx[0]; - int w = idx[1]; - - data[n*channels*height*width + c * height*width + h * width + w]=v; -} - -// we do not use set_gate_add_sf(...) in the caffe implimentation -// avoid using atomicAdd - -__global__ void forward_one_col_left_right( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, int horizontal, int reverse) { - CUDA_1D_KERNEL_LOOP(index, count) { - - int hc_count = height * channels; - - int n,c,h,w; - int temp=index; - w = T; - n = temp / hc_count; - temp = temp % hc_count; - c = temp / height; - temp = temp % height; - h = temp; - - - float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w); - - float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse); - float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w-1); - float h1_minus1 = g_data_1 * h_minus1_data_1; - - float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w-1,horizontal,reverse); - float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h,w-1); - float h2_minus1 = g_data_2 * h_minus1_data_2; - - float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse); - float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w-1); - float h3_minus1 = g_data_3 * h_minus1_data_3; - - float h_hype = h1_minus1 + h2_minus1 + h3_minus1; - 
float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data; - - float h_data = x_hype + h_hype; - - set_data_sf(H,num,channels,height,width,n,c,h,w,h_data); - - } -} - -__global__ void forward_one_col_right_left( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) { - CUDA_1D_KERNEL_LOOP(index, count) { - - int hc_count = height * channels; - int n,c,h,w; - int temp=index; - w = T; - n = temp / hc_count; - temp = temp % hc_count; - c = temp / height; - temp = temp % height; - h = temp; - - float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w); - - float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse); - float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w+1); - float h1_minus1 = g_data_1 * h_minus1_data_1; - - float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h,w+1,horizontal,reverse); - float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h,w+1); - float h2_minus1 = g_data_2 * h_minus1_data_2; - - float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse); - float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w+1); - float h3_minus1 = g_data_3 * h_minus1_data_3; - - float h_hype = h1_minus1 + h2_minus1 + h3_minus1; - float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data; - - float h_data = x_hype + h_hype; - - set_data_sf(H,num,channels,height,width,n,c,h,w,h_data); - - } -} - -__global__ void forward_one_row_top_bottom( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) { - CUDA_1D_KERNEL_LOOP(index, count) { - - int wc_count = width * channels; - - int n,c,h,w; - int temp=index; - h = T; - n = temp / wc_count; - temp = temp % wc_count; - c = temp / width; - temp = temp % width; - w = temp; - - - float x_data = 
get_data_sf(X,num,channels,height,width,n,c,h,w); - - - float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h-1,w-1,horizontal,reverse); - float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h-1,w-1); - float h1_minus1 = g_data_1 * h_minus1_data_1; - - float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h-1,w,horizontal,reverse); - float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h-1,w); - float h2_minus1 = g_data_2 * h_minus1_data_2; - - float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h-1,w+1,horizontal,reverse); - float h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h-1,w+1); - float h3_minus1 = g_data_3 * h_minus1_data_3; - - float h_hype = h1_minus1 + h2_minus1 + h3_minus1; - float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data; - - float h_data = x_hype + h_hype; - - set_data_sf(H,num,channels,height,width,n,c,h,w,h_data); - - } -} - -__global__ void forward_one_row_bottom_top( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H,int horizontal,int reverse) { - CUDA_1D_KERNEL_LOOP(index, count) { - - int wc_count = width * channels; - - int n,c,h,w; - int temp=index; - h = T; - n = temp / wc_count; - temp = temp % wc_count; - c = temp / width; - temp = temp % width; - w = temp; - - - float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w); - - - float g_data_1 = get_gate_sf(G1,num,channels,height,width,n,c,h,w,h+1,w-1,horizontal,reverse); - float h_minus1_data_1 = get_data_sf(H,num,channels,height,width,n,c,h+1,w-1); - float h1_minus1 = g_data_1 * h_minus1_data_1; - - - float g_data_2 = get_gate_sf(G2,num,channels,height,width,n,c,h,w,h+1,w,horizontal,reverse); - float h_minus1_data_2 = get_data_sf(H,num,channels,height,width,n,c,h+1,w); - float h2_minus1 = g_data_2 * h_minus1_data_2; - - float g_data_3 = get_gate_sf(G3,num,channels,height,width,n,c,h,w,h+1,w+1,horizontal,reverse); - float 
h_minus1_data_3 = get_data_sf(H,num,channels,height,width,n,c,h+1,w+1); - float h3_minus1 = g_data_3 * h_minus1_data_3; - - float h_hype = h1_minus1 + h2_minus1 + h3_minus1; - float x_hype = (1 - g_data_1 - g_data_2 - g_data_3) * x_data; - - float h_data = x_hype + h_hype; - - set_data_sf(H,num,channels,height,width,n,c,h,w,h_data); - - } -} - - -__global__ void backward_one_col_left_right( int count, int T, int num,int channels, int height, int width, float* X, float* G1, float* G2, float* G3, float* H, float * X_diff, float * G1_diff,float* G2_diff,float * G3_diff, float * Hdiff,int horizontal,int reverse) { - CUDA_1D_KERNEL_LOOP(index, count) { - - int hc_count = height * channels; - - int n,c,h,w; - int temp=index; - - w = T; - n = temp / hc_count; - temp = temp % hc_count; - c = temp / height; - temp = temp % height; - h = temp; - - float x_data = get_data_sf(X,num,channels,height,width,n,c,h,w); - - //h(t)_diff = top(t)_diff - float h_diff = get_data_sf(Hdiff,num,channels,height,width,n,c,h,w); - - //h(t)_diff += h(t+1)_diff * g(t+1) if t>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_); - - err = cudaGetLastError(); - if(cudaSuccess != err) - { - fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); - exit( -1 ); - } - } - return 1; -} - -int Forward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream) -{ - int count = height_ * channels_ * num_; - int kThreadsPerBlock = 1024; - cudaError_t err; - - for(int t = width_ - 1; t >= 0; t--) { - forward_one_col_right_left<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_); - - err = cudaGetLastError(); - if(cudaSuccess != err) - { - fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); - exit( 
-1 ); - } - } - return 1; -} - -int Forward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream) -{ - int count = width_ * channels_ * num_; - int kThreadsPerBlock = 1024; - cudaError_t err; - - for(int t=0; t< height_; t++) { - forward_one_row_top_bottom<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_); - - err = cudaGetLastError(); - if(cudaSuccess != err) - { - fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); - exit( -1 ); - } - } - return 1; -} - -int Forward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream) -{ - int count = width_ * channels_ * num_; - int kThreadsPerBlock = 1024; - cudaError_t err; - - for(int t = height_-1; t >= 0; t--) { - forward_one_row_bottom_top<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, horizontal_, reverse_); - - err = cudaGetLastError(); - if(cudaSuccess != err) - { - fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); - exit( -1 ); - } - } - return 1; -} - -int Backward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream) -{ - int count = height_ * channels_ * num_; - int kThreadsPerBlock = 1024; - cudaError_t err; - - for(int t = width_ -1; t>=0; t--) - { - backward_one_col_left_right<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, 
G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_); - - err = cudaGetLastError(); - if(cudaSuccess != err) - { - fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); - exit( -1 ); - } - } - return 1; -} - -int Backward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream) -{ - int count = height_ * channels_ * num_; - int kThreadsPerBlock = 1024; - cudaError_t err; - - for(int t = 0; t>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_); - - err = cudaGetLastError(); - if(cudaSuccess != err) - { - fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); - exit( -1 ); - } - } - return 1; -} - -int Backward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream) -{ - int count = width_ * channels_ * num_; - int kThreadsPerBlock = 1024; - cudaError_t err; - - for(int t = height_-1; t>=0; t--) - { - backward_one_row_top_bottom<<<(count + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_); - - err = cudaGetLastError(); - if(cudaSuccess != err) - { - fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); - exit( -1 ); - } - } - return 1; -} - -int Backward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int 
horizontal_, int reverse_, cudaStream_t stream) -{ - int count = width_ * channels_ * num_; - int kThreadsPerBlock = 1024; - cudaError_t err; - - for(int t = 0; t>>(count, t, num_, channels_, height_, width_, X, G1, G2, G3, H, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_); - - err = cudaGetLastError(); - if(cudaSuccess != err) - { - fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); - exit( -1 ); - } - } - return 1; -} - -#ifdef __cplusplus -} -#endif diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu.o b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu.o deleted file mode 100644 index 5d12c0a4eb2089d47b24bb29aec8ded0b91f0517..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.cu.o and /dev/null differ diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.h b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.h deleted file mode 100644 index ebfadd16a03ef4c8d34365940ce034ee0caea1fe..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/cuda/gaterecurrent2dnoind_kernel.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef _GATERECURRENT2DNOIND_KERNEL -#define _GATERECURRENT2DNOIND_KERNEL - -#ifdef __cplusplus -extern "C" { -#endif - -int Forward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream); - -int Forward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, 
int horizontal_, int reverse_, cudaStream_t stream); - -int Forward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream); - -int Forward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, int horizontal_, int reverse_, cudaStream_t stream); - -int Backward_left_right(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream); - -int Backward_right_left(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream); - -int Backward_top_bottom(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream); - -int Backward_bottom_top(int num_, int channels_, int height_, int width_, float * X, float * G1, float * G2, float * G3, float * H, float * X_diff, float * G1_diff, float * G2_diff, float * G3_diff, float * H_diff, int horizontal_, int reverse_, cudaStream_t stream); - -#ifdef __cplusplus -} -#endif - -#endif diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.c b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.c deleted file mode 100644 index 2211c588e4a264c22b1f9e47cde53f3d986c14a3..0000000000000000000000000000000000000000 --- 
a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.c +++ /dev/null @@ -1,91 +0,0 @@ -// gaterecurrent2dnoind_cuda.c -#include -#include -#include "gaterecurrent2dnoind_cuda.h" -#include "cuda/gaterecurrent2dnoind_kernel.h" - -// typedef bool boolean; - -// this symbol will be resolved automatically from PyTorch libs -extern THCState *state; - -int gaterecurrent2dnoind_forward_cuda(int horizontal_, int reverse_, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * output) -{ - // Grab the input tensor to flat - float * X_data = THCudaTensor_data(state, X); - float * G1_data = THCudaTensor_data(state, G1); - float * G2_data = THCudaTensor_data(state, G2); - float * G3_data = THCudaTensor_data(state, G3); - float * H_data = THCudaTensor_data(state, output); - - // dimensions - int num_ = THCudaTensor_size(state, X, 0); - int channels_ = THCudaTensor_size(state, X, 1); - int height_ = THCudaTensor_size(state, X, 2); - int width_ = THCudaTensor_size(state, X, 3); - - cudaStream_t stream = THCState_getCurrentStream(state); - - if(horizontal_ && !reverse_) // left to right - { - //const int count = height_ * channels_ * num_; - Forward_left_right(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream); - } - else if(horizontal_ && reverse_) // right to left - { - Forward_right_left(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream); - } - else if(!horizontal_ && !reverse_) // top to bottom - { - Forward_top_bottom(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream); - } - else - { - Forward_bottom_top(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, horizontal_, reverse_, stream); - } - - return 1; -} - -int gaterecurrent2dnoind_backward_cuda(int horizontal_, int reverse_, THCudaTensor* 
top, THCudaTensor* top_grad, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * X_grad, THCudaTensor * G1_grad, THCudaTensor * G2_grad, THCudaTensor * G3_grad) -{ - //Grab the input tensor to flat - float * X_data = THCudaTensor_data(state, X); - float * G1_data = THCudaTensor_data(state, G1); - float * G2_data = THCudaTensor_data(state, G2); - float * G3_data = THCudaTensor_data(state, G3); - float * H_data = THCudaTensor_data(state, top); - - float * H_diff = THCudaTensor_data(state, top_grad); - - float * X_diff = THCudaTensor_data(state, X_grad); - float * G1_diff = THCudaTensor_data(state, G1_grad); - float * G2_diff = THCudaTensor_data(state, G2_grad); - float * G3_diff = THCudaTensor_data(state, G3_grad); - - // dimensions - int num_ = THCudaTensor_size(state, X, 0); - int channels_ = THCudaTensor_size(state, X, 1); - int height_ = THCudaTensor_size(state, X, 2); - int width_ = THCudaTensor_size(state, X, 3); - - cudaStream_t stream = THCState_getCurrentStream(state); - - if(horizontal_ && ! 
reverse_) //left to right - { - Backward_left_right(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream); - } - else if(horizontal_ && reverse_) //right to left - { - Backward_right_left(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream); - } - else if(!horizontal_ && !reverse_) //top to bottom - { - Backward_top_bottom(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream); - } - else { - Backward_bottom_top(num_, channels_, height_, width_, X_data, G1_data, G2_data, G3_data, H_data, X_diff, G1_diff, G2_diff, G3_diff, H_diff, horizontal_, reverse_, stream); - } - - return 1; -} diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.h b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.h deleted file mode 100644 index da60473306ca9c3addcf0a2164bd13c86b226d7e..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/pytorch_spn/src/gaterecurrent2dnoind_cuda.h +++ /dev/null @@ -1,6 +0,0 @@ - -// #include -// gaterecurrent2dnoind_cuda.h -int gaterecurrent2dnoind_forward_cuda(int horizontal_, int reverse_, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * output); - -int gaterecurrent2dnoind_backward_cuda(int horizontal_, int reverse_, THCudaTensor* top, THCudaTensor* top_grad, THCudaTensor * X, THCudaTensor * G1, THCudaTensor * G2, THCudaTensor * G3, THCudaTensor * X_diff, THCudaTensor * G1_diff, THCudaTensor * G2_diff, THCudaTensor * G3_diff); diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/smooth_filter.py 
b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/smooth_filter.py deleted file mode 100644 index 4968615b1883dc095fe1d5900fe15cd3c6fe8ecf..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/smooth_filter.py +++ /dev/null @@ -1,407 +0,0 @@ -""" -Code cc from https://github.com/LouieYang/deep-photo-styletransfer-tf/blob/master/smooth_local_affine.py -""" -src = ''' - #include "/usr/local/cuda/include/math_functions.h" - #define TB 256 - #define EPS 1e-7 - - __device__ bool InverseMat4x4(double m_in[4][4], double inv_out[4][4]) { - double m[16], inv[16]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - m[i * 4 + j] = m_in[i][j]; - } - } - - inv[0] = m[5] * m[10] * m[15] - - m[5] * m[11] * m[14] - - m[9] * m[6] * m[15] + - m[9] * m[7] * m[14] + - m[13] * m[6] * m[11] - - m[13] * m[7] * m[10]; - - inv[4] = -m[4] * m[10] * m[15] + - m[4] * m[11] * m[14] + - m[8] * m[6] * m[15] - - m[8] * m[7] * m[14] - - m[12] * m[6] * m[11] + - m[12] * m[7] * m[10]; - - inv[8] = m[4] * m[9] * m[15] - - m[4] * m[11] * m[13] - - m[8] * m[5] * m[15] + - m[8] * m[7] * m[13] + - m[12] * m[5] * m[11] - - m[12] * m[7] * m[9]; - - inv[12] = -m[4] * m[9] * m[14] + - m[4] * m[10] * m[13] + - m[8] * m[5] * m[14] - - m[8] * m[6] * m[13] - - m[12] * m[5] * m[10] + - m[12] * m[6] * m[9]; - - inv[1] = -m[1] * m[10] * m[15] + - m[1] * m[11] * m[14] + - m[9] * m[2] * m[15] - - m[9] * m[3] * m[14] - - m[13] * m[2] * m[11] + - m[13] * m[3] * m[10]; - - inv[5] = m[0] * m[10] * m[15] - - m[0] * m[11] * m[14] - - m[8] * m[2] * m[15] + - m[8] * m[3] * m[14] + - m[12] * m[2] * m[11] - - m[12] * m[3] * m[10]; - - inv[9] = -m[0] * m[9] * m[15] + - m[0] * m[11] * m[13] + - m[8] * m[1] * m[15] - - m[8] * m[3] * m[13] - - m[12] * m[1] * m[11] + - m[12] * m[3] * m[9]; - - inv[13] = m[0] * m[9] * m[14] - - m[0] * m[10] * m[13] - - m[8] * m[1] * m[14] + - m[8] * m[2] * m[13] + - m[12] * m[1] * m[10] - - m[12] * m[2] * 
m[9]; - - inv[2] = m[1] * m[6] * m[15] - - m[1] * m[7] * m[14] - - m[5] * m[2] * m[15] + - m[5] * m[3] * m[14] + - m[13] * m[2] * m[7] - - m[13] * m[3] * m[6]; - - inv[6] = -m[0] * m[6] * m[15] + - m[0] * m[7] * m[14] + - m[4] * m[2] * m[15] - - m[4] * m[3] * m[14] - - m[12] * m[2] * m[7] + - m[12] * m[3] * m[6]; - - inv[10] = m[0] * m[5] * m[15] - - m[0] * m[7] * m[13] - - m[4] * m[1] * m[15] + - m[4] * m[3] * m[13] + - m[12] * m[1] * m[7] - - m[12] * m[3] * m[5]; - - inv[14] = -m[0] * m[5] * m[14] + - m[0] * m[6] * m[13] + - m[4] * m[1] * m[14] - - m[4] * m[2] * m[13] - - m[12] * m[1] * m[6] + - m[12] * m[2] * m[5]; - - inv[3] = -m[1] * m[6] * m[11] + - m[1] * m[7] * m[10] + - m[5] * m[2] * m[11] - - m[5] * m[3] * m[10] - - m[9] * m[2] * m[7] + - m[9] * m[3] * m[6]; - - inv[7] = m[0] * m[6] * m[11] - - m[0] * m[7] * m[10] - - m[4] * m[2] * m[11] + - m[4] * m[3] * m[10] + - m[8] * m[2] * m[7] - - m[8] * m[3] * m[6]; - - inv[11] = -m[0] * m[5] * m[11] + - m[0] * m[7] * m[9] + - m[4] * m[1] * m[11] - - m[4] * m[3] * m[9] - - m[8] * m[1] * m[7] + - m[8] * m[3] * m[5]; - - inv[15] = m[0] * m[5] * m[10] - - m[0] * m[6] * m[9] - - m[4] * m[1] * m[10] + - m[4] * m[2] * m[9] + - m[8] * m[1] * m[6] - - m[8] * m[2] * m[5]; - - double det = m[0] * inv[0] + m[1] * inv[4] + m[2] * inv[8] + m[3] * inv[12]; - - if (abs(det) < 1e-9) { - return false; - } - - - det = 1.0 / det; - - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - inv_out[i][j] = inv[i * 4 + j] * det; - } - } - - return true; - } - - extern "C" - __global__ void best_local_affine_kernel( - float *output, float *input, float *affine_model, - int h, int w, float epsilon, int kernel_radius - ) - { - int size = h * w; - int id = blockIdx.x * blockDim.x + threadIdx.x; - - if (id < size) { - int x = id % w, y = id / w; - - double Mt_M[4][4] = {}; // 4x4 - double invMt_M[4][4] = {}; - double Mt_S[3][4] = {}; // RGB -> 1x4 - double A[3][4] = {}; - for (int i = 0; i < 4; i++) - for (int j = 0; j < 4; j++) { - 
Mt_M[i][j] = 0, invMt_M[i][j] = 0; - if (i != 3) { - Mt_S[i][j] = 0, A[i][j] = 0; - if (i == j) - Mt_M[i][j] = 1e-3; - } - } - - for (int dy = -kernel_radius; dy <= kernel_radius; dy++) { - for (int dx = -kernel_radius; dx <= kernel_radius; dx++) { - - int xx = x + dx, yy = y + dy; - int id2 = yy * w + xx; - - if (0 <= xx && xx < w && 0 <= yy && yy < h) { - - Mt_M[0][0] += input[id2 + 2*size] * input[id2 + 2*size]; - Mt_M[0][1] += input[id2 + 2*size] * input[id2 + size]; - Mt_M[0][2] += input[id2 + 2*size] * input[id2]; - Mt_M[0][3] += input[id2 + 2*size]; - - Mt_M[1][0] += input[id2 + size] * input[id2 + 2*size]; - Mt_M[1][1] += input[id2 + size] * input[id2 + size]; - Mt_M[1][2] += input[id2 + size] * input[id2]; - Mt_M[1][3] += input[id2 + size]; - - Mt_M[2][0] += input[id2] * input[id2 + 2*size]; - Mt_M[2][1] += input[id2] * input[id2 + size]; - Mt_M[2][2] += input[id2] * input[id2]; - Mt_M[2][3] += input[id2]; - - Mt_M[3][0] += input[id2 + 2*size]; - Mt_M[3][1] += input[id2 + size]; - Mt_M[3][2] += input[id2]; - Mt_M[3][3] += 1; - - Mt_S[0][0] += input[id2 + 2*size] * output[id2 + 2*size]; - Mt_S[0][1] += input[id2 + size] * output[id2 + 2*size]; - Mt_S[0][2] += input[id2] * output[id2 + 2*size]; - Mt_S[0][3] += output[id2 + 2*size]; - - Mt_S[1][0] += input[id2 + 2*size] * output[id2 + size]; - Mt_S[1][1] += input[id2 + size] * output[id2 + size]; - Mt_S[1][2] += input[id2] * output[id2 + size]; - Mt_S[1][3] += output[id2 + size]; - - Mt_S[2][0] += input[id2 + 2*size] * output[id2]; - Mt_S[2][1] += input[id2 + size] * output[id2]; - Mt_S[2][2] += input[id2] * output[id2]; - Mt_S[2][3] += output[id2]; - } - } - } - - bool success = InverseMat4x4(Mt_M, invMt_M); - - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 4; j++) { - for (int k = 0; k < 4; k++) { - A[i][j] += invMt_M[j][k] * Mt_S[i][k]; - } - } - } - - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 4; j++) { - int affine_id = i * 4 + j; - affine_model[12 * id + affine_id] = A[i][j]; - } - } - } 
- return ; - } - - extern "C" - __global__ void bilateral_smooth_kernel( - float *affine_model, float *filtered_affine_model, float *guide, - int h, int w, int kernel_radius, float sigma1, float sigma2 - ) - { - int id = blockIdx.x * blockDim.x + threadIdx.x; - int size = h * w; - if (id < size) { - int x = id % w; - int y = id / w; - - double sum_affine[12] = {}; - double sum_weight = 0; - for (int dx = -kernel_radius; dx <= kernel_radius; dx++) { - for (int dy = -kernel_radius; dy <= kernel_radius; dy++) { - int yy = y + dy, xx = x + dx; - int id2 = yy * w + xx; - if (0 <= xx && xx < w && 0 <= yy && yy < h) { - float color_diff1 = guide[yy*w + xx] - guide[y*w + x]; - float color_diff2 = guide[yy*w + xx + size] - guide[y*w + x + size]; - float color_diff3 = guide[yy*w + xx + 2*size] - guide[y*w + x + 2*size]; - float color_diff_sqr = - (color_diff1*color_diff1 + color_diff2*color_diff2 + color_diff3*color_diff3) / 3; - - float v1 = exp(-(dx * dx + dy * dy) / (2 * sigma1 * sigma1)); - float v2 = exp(-(color_diff_sqr) / (2 * sigma2 * sigma2)); - float weight = v1 * v2; - - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 4; j++) { - int affine_id = i * 4 + j; - sum_affine[affine_id] += weight * affine_model[id2*12 + affine_id]; - } - } - sum_weight += weight; - } - } - } - - for (int i = 0; i < 3; i++) { - for (int j = 0; j < 4; j++) { - int affine_id = i * 4 + j; - filtered_affine_model[id*12 + affine_id] = sum_affine[affine_id] / sum_weight; - } - } - } - return ; - } - - - extern "C" - __global__ void reconstruction_best_kernel( - float *input, float *filtered_affine_model, float *filtered_best_output, - int h, int w - ) - { - int id = blockIdx.x * blockDim.x + threadIdx.x; - int size = h * w; - if (id < size) { - double out1 = - input[id + 2*size] * filtered_affine_model[id*12 + 0] + // A[0][0] + - input[id + size] * filtered_affine_model[id*12 + 1] + // A[0][1] + - input[id] * filtered_affine_model[id*12 + 2] + // A[0][2] + - filtered_affine_model[id*12 + 
3]; //A[0][3]; - double out2 = - input[id + 2*size] * filtered_affine_model[id*12 + 4] + //A[1][0] + - input[id + size] * filtered_affine_model[id*12 + 5] + //A[1][1] + - input[id] * filtered_affine_model[id*12 + 6] + //A[1][2] + - filtered_affine_model[id*12 + 7]; //A[1][3]; - double out3 = - input[id + 2*size] * filtered_affine_model[id*12 + 8] + //A[2][0] + - input[id + size] * filtered_affine_model[id*12 + 9] + //A[2][1] + - input[id] * filtered_affine_model[id*12 + 10] + //A[2][2] + - filtered_affine_model[id*12 + 11]; // A[2][3]; - - filtered_best_output[id] = out1; - filtered_best_output[id + size] = out2; - filtered_best_output[id + 2*size] = out3; - } - return ; - } - ''' - -import cv2 -import torch -import numpy as np -from PIL import Image -from cupy.cuda import function -from pynvrtc.compiler import Program -from collections import namedtuple - - -def smooth_local_affine(output_cpu, input_cpu, epsilon, patch, h, w, f_r, f_e): - # program = Program(src.encode('utf-8'), 'best_local_affine_kernel.cu'.encode('utf-8')) - # ptx = program.compile(['-I/usr/local/cuda/include'.encode('utf-8')]) - program = Program(src, 'best_local_affine_kernel.cu') - ptx = program.compile(['-I/usr/local/cuda/include']) - m = function.Module() - m.load(bytes(ptx.encode())) - - _reconstruction_best_kernel = m.get_function('reconstruction_best_kernel') - _bilateral_smooth_kernel = m.get_function('bilateral_smooth_kernel') - _best_local_affine_kernel = m.get_function('best_local_affine_kernel') - Stream = namedtuple('Stream', ['ptr']) - s = Stream(ptr=torch.cuda.current_stream().cuda_stream) - - filter_radius = f_r - sigma1 = filter_radius / 3 - sigma2 = f_e - radius = (patch - 1) / 2 - - filtered_best_output = torch.zeros(np.shape(input_cpu)).cuda() - affine_model = torch.zeros((h * w, 12)).cuda() - filtered_affine_model =torch.zeros((h * w, 12)).cuda() - - input_ = torch.from_numpy(input_cpu).cuda() - output_ = torch.from_numpy(output_cpu).cuda() - _best_local_affine_kernel( - 
grid=(int((h * w) / 256 + 1), 1), - block=(256, 1, 1), - args=[output_.data_ptr(), input_.data_ptr(), affine_model.data_ptr(), - np.int32(h), np.int32(w), np.float32(epsilon), np.int32(radius)], stream=s - ) - - _bilateral_smooth_kernel( - grid=(int((h * w) / 256 + 1), 1), - block=(256, 1, 1), - args=[affine_model.data_ptr(), filtered_affine_model.data_ptr(), input_.data_ptr(), np.int32(h), np.int32(w), np.int32(f_r), np.float32(sigma1), np.float32(sigma2)], stream=s - ) - - _reconstruction_best_kernel( - grid=(int((h * w) / 256 + 1), 1), - block=(256, 1, 1), - args=[input_.data_ptr(), filtered_affine_model.data_ptr(), filtered_best_output.data_ptr(), - np.int32(h), np.int32(w)], stream=s - ) - numpy_filtered_best_output = filtered_best_output.cpu().numpy() - return numpy_filtered_best_output - - -def smooth_filter(initImg, contentImg, f_radius=15,f_edge=1e-1): - ''' - :param initImg: intermediate output. Either image path or PIL Image - :param contentImg: content image output. Either path or PIL Image - :return: stylized output image. PIL Image - ''' - if type(initImg) == str: - initImg = Image.open(initImg).convert("RGB") - best_image_bgr = np.array(initImg, dtype=np.float32) - bW, bH, bC = best_image_bgr.shape - best_image_bgr = best_image_bgr[:, :, ::-1] - best_image_bgr = best_image_bgr.transpose((2, 0, 1)) - - if type(contentImg) == str: - contentImg = Image.open(contentImg).convert("RGB") - content_input = contentImg.resize((bH,bW)) - else: - content_input = cv2.resize(contentImg,(bH,bW)) - content_input = np.array(content_input, dtype=np.float32) - content_input = content_input[:, :, ::-1] - content_input = content_input.transpose((2, 0, 1)) - input_ = np.ascontiguousarray(content_input, dtype=np.float32) / 255. - _, H, W = np.shape(input_) - output_ = np.ascontiguousarray(best_image_bgr, dtype=np.float32) / 255. 
- best_ = smooth_local_affine(output_, input_, 1e-7, 3, H, W, f_radius, f_edge) - best_ = best_.transpose(1, 2, 0) - result = Image.fromarray(np.uint8(np.clip(best_ * 255., 0, 255.))) - return result diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/utils.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/utils.py deleted file mode 100644 index 79c4a274028d3c396aa686476cb0ae400113a1fe..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/libs/utils.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import division -import os -import cv2 -import time -import torch -import scipy.misc -import numpy as np -import scipy.sparse -from PIL import Image -import scipy.sparse.linalg -from cv2.ximgproc import jointBilateralFilter -from numpy.lib.stride_tricks import as_strided - -def whiten(cF): - cFSize = cF.size() - c_mean = torch.mean(cF,1) # c x (h x w) - c_mean = c_mean.unsqueeze(1).expand_as(cF) - cF = cF - c_mean - - contentConv = torch.mm(cF,cF.t()).div(cFSize[1]-1) + torch.eye(cFSize[0]).double() - c_u,c_e,c_v = torch.svd(contentConv,some=False) - - k_c = cFSize[0] - for i in range(cFSize[0]): - if c_e[i] < 0.00001: - k_c = i - break - - c_d = (c_e[0:k_c]).pow(-0.5) - step1 = torch.mm(c_v[:,0:k_c],torch.diag(c_d)) - step2 = torch.mm(step1,(c_v[:,0:k_c].t())) - whiten_cF = torch.mm(step2,cF) - return whiten_cF - -def numpy2cv2(cont,style,prop,width,height): - cont = cont.transpose((1,2,0)) - cont = cont[...,::-1] - cont = cont * 255 - cont = cv2.resize(cont,(width,height)) - #cv2.resize(iimg,(width,height)) - style = style.transpose((1,2,0)) - style = style[...,::-1] - style = style * 255 - style = cv2.resize(style,(width,height)) - - prop = prop.transpose((1,2,0)) - prop = prop[...,::-1] - prop = prop * 255 - prop = cv2.resize(prop,(width,height)) - - #return np.concatenate((cont,np.concatenate((style,prop),axis=1)),axis=1) - return prop,cont - -def 
makeVideo(content,style,props,outf): - print('Stack transferred frames back to video...') - layers,height,width = content[0].shape - fourcc = cv2.VideoWriter_fourcc(*'MJPG') - video = cv2.VideoWriter(os.path.join(outf,'transfer.avi'),fourcc,10.0,(width,height)) - ori_video = cv2.VideoWriter(os.path.join(outf,'content.avi'),fourcc,10.0,(width,height)) - for j in range(len(content)): - prop,cont = numpy2cv2(content[j],style,props[j],width,height) - cv2.imwrite('prop.png',prop) - cv2.imwrite('content.png',cont) - # TODO: this is ugly, fix this - imgj = cv2.imread('prop.png') - imgc = cv2.imread('content.png') - - video.write(imgj) - ori_video.write(imgc) - # RGB or BRG, yuks - video.release() - ori_video.release() - os.remove('prop.png') - os.remove('content.png') - print('Transferred video saved at %s.'%outf) - -def print_options(opt): - message = '' - message += '----------------- Options ---------------\n' - for k, v in sorted(vars(opt).items()): - comment = '' - message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) - message += '----------------- End -------------------' - print(message) - - # save to the disk - expr_dir = os.path.join(opt.outf) - os.makedirs(expr_dir,exist_ok=True) - file_name = os.path.join(expr_dir, 'opt.txt') - with open(file_name, 'wt') as opt_file: - opt_file.write(message) - opt_file.write('\n') diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/dec_r31.pth b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/dec_r31.pth deleted file mode 100644 index b7ddc0933441ef7622c0b0e752c0644828b9db4c..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/dec_r31.pth +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3ccc3bbc97a15e1d002d0b13523543c518dda2d0346c8f4d39c1d381d8490f68 -size 2221888 diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/dec_r41.pth 
b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/dec_r41.pth deleted file mode 100644 index 7407cde48fb8b841852732487e5f9dad0b874cdb..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/dec_r41.pth +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e6858f96d3d0882fa3b40652a0315928219086e4bcb0e3efbe43bd04ea631911 -size 14023509 diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/r31.pth b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/r31.pth deleted file mode 100644 index d8c9a170e28f28d84dd6c1cbfb161422aeedb48c..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/r31.pth +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8bb75be684b331a105a8fad266343556067e4eec888249f7abb693a66b5ad7e3 -size 11564438 diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/r41.pth b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/r41.pth deleted file mode 100644 index d7ddefca5c34b0d5281c6fff1daeaa09570c6780..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/r41.pth +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:98cdafb2f553ea3071782255cc5e739eecf2a74fcdde280e5d7e09b8017fad4d -size 20627360 diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/r41_spn.pth b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/r41_spn.pth deleted file mode 100644 index b12cb57b30e0adc3e107ce6568603b89af640d90..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/r41_spn.pth +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid 
sha256:79db437b39507a97b642b0f7ed00b86fb62c9bf49fd86d3eab0a62ce023a9db8 -size 3098678 diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/vgg_r31.pth b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/vgg_r31.pth deleted file mode 100644 index 0c058c9cd91bcb91957c390c981b5f39d629d522..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/vgg_r31.pth +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f9bffd632f87360d4de1bcd1627246d2b41a6278aa46e1c1fe7212796b646b7e -size 2223422 diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/vgg_r41.pth b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/vgg_r41.pth deleted file mode 100644 index d2f5756f6c0669490b202621ca7d5dc69838cbcf..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/vgg_r41.pth +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7949c5ea891cd75de2fce4a9cdb0a14f9bc1672053f5c92db80ded641b7e57d0 -size 14026238 diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/vgg_r51.pth b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/vgg_r51.pth deleted file mode 100644 index ade4f4e2febab7f63b3a8aeeac1bd3f31a6dee66..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/models/vgg_r51.pth +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:71041e0ee08540e690cfa08982f2da86b590c90335468808854092ff07c81cfc -size 51784517 diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/real-time-demo.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/real-time-demo.py deleted file mode 100644 index 
4d3565ee2a84db8ced9e72f1d39c6473e0d2937a..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer/real-time-demo.py +++ /dev/null @@ -1,120 +0,0 @@ -import os -import cv2 -import torch -import argparse -import numpy as np -from PIL import Image -from libs.Loader import Dataset -from libs.Matrix import MulLayer -from libs.utils import makeVideo -import torch.backends.cudnn as cudnn -from libs.models import encoder3,encoder4 -from libs.models import decoder3,decoder4 -import torchvision.transforms as transforms - -parser = argparse.ArgumentParser() -parser.add_argument("--vgg_dir", default='models/vgg_r31.pth', - help='pre-trained encoder path') -parser.add_argument("--decoder_dir", default='models/dec_r31.pth', - help='pre-trained decoder path') -parser.add_argument("--style", default="data/style/in2.jpg", - help='path to style image') -parser.add_argument("--matrixPath", default="models/r31.pth", - help='path to pre-trained model') -parser.add_argument('--fineSize', type=int, default=256, - help='crop image size') -parser.add_argument("--name",default="transferred_video", - help="name of generated video") -parser.add_argument("--layer",default="r31", - help="features of which layer to transfer") -parser.add_argument("--outf",default="real_time_demo_output", - help="output folder") - -################# PREPARATIONS ################# -opt = parser.parse_args() -opt.cuda = torch.cuda.is_available() -print(opt) -os.makedirs(opt.outf,exist_ok=True) -cudnn.benchmark = True - -################# DATA ################# -def loadImg(imgPath): - img = Image.open(imgPath).convert('RGB') - transform = transforms.Compose([ - transforms.Scale(opt.fineSize), - transforms.ToTensor()]) - return transform(img) -style = loadImg(opt.style).unsqueeze(0) - -################# MODEL ################# -if(opt.layer == 'r31'): - matrix = MulLayer(layer='r31') - vgg = encoder3() - dec = decoder3() -elif(opt.layer == 'r41'): - matrix = 
MulLayer(layer='r41') - vgg = encoder4() - dec = decoder4() -vgg.load_state_dict(torch.load(opt.vgg_dir)) -dec.load_state_dict(torch.load(opt.dec_dir)) -matrix.load_state_dict(torch.load(opt.matrixPath)) -for param in vgg.parameters(): - param.requires_grad = False -for param in dec.parameters(): - param.requires_grad = False -for param in matrix.parameters(): - param.requires_grad = False - -################# GLOBAL VARIABLE ################# -content = torch.Tensor(1,3,opt.fineSize,opt.fineSize) - -################# GPU ################# -if(opt.cuda): - vgg.cuda() - dec.cuda() - matrix.cuda() - - style = style.cuda() - content = content.cuda() - -totalTime = 0 -imageCounter = 0 -result_frames = [] -contents = [] -styles = [] -cap = cv2.VideoCapture(0) -cap.set(3,256) -cap.set(4,512) -fourcc = cv2.VideoWriter_fourcc(*'MJPG') -out = cv2.VideoWriter(os.path.join(opt.outf,opt.name+'.avi'),fourcc,20.0,(512,256)) - -with torch.no_grad(): - sF = vgg(style) - -while(True): - ret,frame = cap.read() - frame = cv2.resize(frame,(512,256),interpolation=cv2.INTER_CUBIC) - frame = frame.transpose((2,0,1)) - frame = frame[::-1,:,:] - frame = frame/255.0 - frame = torch.from_numpy(frame.copy()).unsqueeze(0) - content.data.resize_(frame.size()).copy_(frame) - with torch.no_grad(): - cF = vgg(content) - if(opt.layer == 'r41'): - feature,transmatrix = matrix(cF[opt.layer],sF[opt.layer]) - else: - feature,transmatrix = matrix(cF,sF) - transfer = dec(feature) - transfer = transfer.clamp(0,1).squeeze(0).data.cpu().numpy() - transfer = transfer.transpose((1,2,0)) - transfer = transfer[...,::-1] - out.write(np.uint8(transfer*255)) - cv2.imshow('frame',transfer) - if cv2.waitKey(1) & 0xFF == ord('q'): - break - -# When everything done, release the capture -out.release() -cap.release() -cv2.destroyAllWindows() diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer_matrix.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer_matrix.py deleted file mode 
100644 index 3ed08fe4baed3fd97156a40c03ea0ae1ad8ece0f..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer_matrix.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -import torch.nn as nn - -from copy import deepcopy - -from selectionConv import SelectionConv - -class CNN(nn.Module): - def __init__(self,matrixSize=32): - super(CNN,self).__init__() - - self.conv1 = SelectionConv(512,256,3,padding_mode="zeros") - self.conv2 = SelectionConv(256,128,3,padding_mode="zeros") - self.conv3 = SelectionConv(128,matrixSize,3,padding_mode="zeros") - self.relu = torch.nn.ReLU() - - self.fc = nn.Linear(matrixSize*matrixSize,matrixSize*matrixSize) - - def forward(self,x,edge_index,selections,interp_values=None): - out = self.relu(self.conv1(x,edge_index,selections,interp_values)) - out = self.relu(self.conv2(out,edge_index,selections,interp_values)) - out = self.conv3(out,edge_index,selections,interp_values) - - n,ch = out.size() - - out = torch.mm(out.t(), out).div(n) - - out = out.view(-1) - return self.fc(out) - -class TransformLayer(nn.Module): - def __init__(self,matrixSize=32): - super(TransformLayer,self).__init__() - self.snet = CNN(matrixSize) - self.cnet = CNN(matrixSize) - self.matrixSize = matrixSize - - self.compress = SelectionConv(512,matrixSize,1) - self.unzip = SelectionConv(matrixSize,512,1) - - def forward(self,cF,sF,content_edge_index,content_selections,style_edge_index,style_selections,content_interps=None,style_interps=None,trans=True): - cMean = torch.mean(cF,dim=0,keepdim=True) - cF = cF - cMean - - sMean = torch.mean(sF,dim=0,keepdim=True) - sF = sF - sMean - - compress_content = self.compress(cF,content_edge_index,content_selections,content_interps) - - if(trans): - cMatrix = self.cnet(cF,content_edge_index,content_selections,content_interps) - sMatrix = self.snet(sF,style_edge_index,style_selections,style_interps) - - sMatrix = sMatrix.view(self.matrixSize,self.matrixSize) - cMatrix = 
cMatrix.view(self.matrixSize,self.matrixSize) - transmatrix = torch.mm(sMatrix,cMatrix) - transfeature = torch.mm(transmatrix,compress_content.transpose(1,0)) - out = self.unzip(transfeature.transpose(1,0),content_edge_index,content_selections,content_interps) - out = out + sMean - return out, transmatrix - else: - out = self.unzip(compress_content,content_edge_index,content_selections,content_interps) - out = out + cMean - return out - - def copy_weights(self, model): - self.cnet.conv1.copy_weights(model.cnet.convs[0].weight,model.cnet.convs[0].bias) - self.cnet.conv2.copy_weights(model.cnet.convs[2].weight,model.cnet.convs[2].bias) - self.cnet.conv3.copy_weights(model.cnet.convs[4].weight,model.cnet.convs[4].bias) - - self.snet.conv1.copy_weights(model.snet.convs[0].weight,model.snet.convs[0].bias) - self.snet.conv2.copy_weights(model.snet.convs[2].weight,model.snet.convs[2].bias) - self.snet.conv3.copy_weights(model.snet.convs[4].weight,model.snet.convs[4].bias) - - self.cnet.fc = deepcopy(model.cnet.fc) - self.snet.fc = deepcopy(model.snet.fc) - - self.compress.copy_weights(model.compress.weight,model.compress.bias) - self.unzip.copy_weights(model.unzip.weight,model.unzip.bias) diff --git a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer_vgg.py b/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer_vgg.py deleted file mode 100644 index 179325787c65be0cf91e6d33d992850a5aa32f3b..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/LinearStyleTransfer_vgg.py +++ /dev/null @@ -1,172 +0,0 @@ -import torch -import torch.nn.functional as F -from torch.nn import Sequential as Seq, Linear as Lin, ReLU - -from selectionConv import SelectionConv - -from pooling import unpoolCluster, maxPoolCluster - -class encoder(torch.nn.Module): - - def __init__(self,padding_mode="reflect"): - super(encoder, self).__init__() - self.conv1 = SelectionConv(3,3,1,padding_mode=padding_mode) - self.conv2 = 
SelectionConv(3,64,3,padding_mode=padding_mode) - self.conv3 = SelectionConv(64,64,3,padding_mode=padding_mode) - - self.conv4 = SelectionConv(64,128,3,padding_mode=padding_mode) - self.conv5 = SelectionConv(128,128,3,padding_mode=padding_mode) - - self.conv6 = SelectionConv(128,256,3,padding_mode=padding_mode) - self.conv7 = SelectionConv(256,256,3,padding_mode=padding_mode) - self.conv8 = SelectionConv(256,256,3,padding_mode=padding_mode) - self.conv9 = SelectionConv(256,256,3,padding_mode=padding_mode) - - self.conv10 = SelectionConv(256,512,3,padding_mode=padding_mode) - - self.relu = torch.nn.ReLU() - - def copy_weights(self, model): - self.conv1.copy_weights(model.conv1.weight,model.conv1.bias) - self.conv2.copy_weights(model.conv2.weight,model.conv2.bias) - self.conv3.copy_weights(model.conv3.weight,model.conv3.bias) - self.conv4.copy_weights(model.conv4.weight,model.conv4.bias) - self.conv5.copy_weights(model.conv5.weight,model.conv5.bias) - self.conv6.copy_weights(model.conv6.weight,model.conv6.bias) - self.conv7.copy_weights(model.conv7.weight,model.conv7.bias) - self.conv8.copy_weights(model.conv8.weight,model.conv8.bias) - self.conv9.copy_weights(model.conv9.weight,model.conv9.bias) - self.conv10.copy_weights(model.conv10.weight,model.conv10.bias) - - - def forward(self,graph): - - edge_index = graph.edge_indexes[0] - selections = graph.selections_list[0] - interps = graph.interps_list[0] if hasattr(graph,'interps_list') else None - - output = {} - out = self.conv1(graph.x,edge_index,selections,interps) - out = self.conv2(out,edge_index,selections,interps) - output['r11'] = self.relu(out) - - out = self.conv3(output['r11'],edge_index,selections,interps) - output['r12'] = self.relu(out) - - output['p1'] = maxPoolCluster(output['r12'],graph.clusters[0]) - edge_index = graph.edge_indexes[1] - selections = graph.selections_list[1] - interps = graph.interps_list[1] if hasattr(graph,'interps_list') else None - - out = 
self.conv4(output['p1'],edge_index,selections,interps) - output['r21'] = self.relu(out) - - out = self.conv5(output['r21'],edge_index,selections,interps) - output['r22'] = self.relu(out) - - output['p2'] = maxPoolCluster(output['r22'],graph.clusters[1]) - edge_index = graph.edge_indexes[2] - selections = graph.selections_list[2] - interps = graph.interps_list[2] if hasattr(graph,'interps_list') else None - - out = self.conv6(output['p2'],edge_index,selections,interps) - output['r31'] = self.relu(out) - #if(matrix31 is not None): - # feature3,transmatrix3 = matrix31(output['r31'],sF['r31']) - # out = self.reflecPad7(feature3) - #else: - # out = self.reflecPad7() - out = self.conv7(output['r31'],edge_index,selections,interps) - output['r32'] = self.relu(out) - - out = self.conv8(output['r32'],edge_index,selections,interps) - output['r33'] = self.relu(out) - - out = self.conv9(output['r33'],edge_index,selections,interps) - output['r34'] = self.relu(out) - - output['p3'] = maxPoolCluster(output['r34'],graph.clusters[2]) - edge_index = graph.edge_indexes[3] - selections = graph.selections_list[3] - interps = graph.interps_list[3] if hasattr(graph,'interps_list') else None - - out = self.conv10(output['p3'],edge_index,selections,interps) - output['r41'] = self.relu(out) - - return output - -class decoder(torch.nn.Module): - - def __init__(self,padding_mode="reflect"): - super(decoder, self).__init__() - self.conv11 = SelectionConv(512,256,3,padding_mode=padding_mode) - - self.conv12 = SelectionConv(256,256,3,padding_mode=padding_mode) - self.conv13 = SelectionConv(256,256,3,padding_mode=padding_mode) - self.conv14 = SelectionConv(256,256,3,padding_mode=padding_mode) - self.conv15 = SelectionConv(256,128,3,padding_mode=padding_mode) - - self.conv16 = SelectionConv(128,128,3,padding_mode=padding_mode) - self.conv17 = SelectionConv(128,64,3,padding_mode=padding_mode) - - self.conv18 = SelectionConv(64,64,3,padding_mode=padding_mode) - self.conv19 = 
SelectionConv(64,3,3,padding_mode=padding_mode) - - self.relu = torch.nn.ReLU() - - def copy_weights(self, model): - self.conv11.copy_weights(model.conv11.weight,model.conv11.bias) - self.conv12.copy_weights(model.conv12.weight,model.conv12.bias) - self.conv13.copy_weights(model.conv13.weight,model.conv13.bias) - self.conv14.copy_weights(model.conv14.weight,model.conv14.bias) - self.conv15.copy_weights(model.conv15.weight,model.conv15.bias) - self.conv16.copy_weights(model.conv16.weight,model.conv16.bias) - self.conv17.copy_weights(model.conv17.weight,model.conv17.bias) - self.conv18.copy_weights(model.conv18.weight,model.conv18.bias) - self.conv19.copy_weights(model.conv19.weight,model.conv19.bias) - - - def forward(self,x,graph): - - edge_index = graph.edge_indexes[3] - selections = graph.selections_list[3] - interps = graph.interps_list[3] if hasattr(graph,'interps_list') else None - - out = self.conv11(x,edge_index,selections,interps) - out = self.relu(out) - - out = unpoolCluster(out,graph.clusters[2]) - edge_index = graph.edge_indexes[2] - selections = graph.selections_list[2] - interps = graph.interps_list[2] if hasattr(graph,'interps_list') else None - - out = self.conv12(out,edge_index,selections,interps) - out = self.relu(out) - out = self.conv13(out,edge_index,selections,interps) - out = self.relu(out) - out = self.conv14(out,edge_index,selections,interps) - out = self.relu(out) - out = self.conv15(out,edge_index,selections,interps) - out = self.relu(out) - - out = unpoolCluster(out,graph.clusters[1]) - edge_index = graph.edge_indexes[1] - selections = graph.selections_list[1] - interps = graph.interps_list[1] if hasattr(graph,'interps_list') else None - - out = self.conv16(out,edge_index,selections,interps) - out = self.relu(out) - out = self.conv17(out,edge_index,selections,interps) - out = self.relu(out) - - out = unpoolCluster(out,graph.clusters[0]) - edge_index = graph.edge_indexes[0] - selections = graph.selections_list[0] - interps = 
graph.interps_list[0] if hasattr(graph,'interps_list') else None - - out = self.conv18(out,edge_index,selections,interps) - out = self.relu(out) - out = self.conv19(out,edge_index,selections,interps) - - return out - diff --git a/FastSplatStyler_huggingface/graph_networks/__init__.py b/FastSplatStyler_huggingface/graph_networks/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/FastSplatStyler_huggingface/graph_networks/graph_transforms.py b/FastSplatStyler_huggingface/graph_networks/graph_transforms.py deleted file mode 100644 index 0a89590343c25eb76cde1f0366920f3cf531a52f..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/graph_networks/graph_transforms.py +++ /dev/null @@ -1,389 +0,0 @@ -""" graph_transforms.py - -This is the implementation of transforming a traditional CNN -to a SelectionConv-based graph CNN -So far this is just used for segmentation -""" -from copy import deepcopy -from typing import Dict, Iterable, OrderedDict, Tuple, Union #,Literal Only supported in Python 3.8+ -from typing_extensions import Literal -import warnings - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch_geometric.data import Data -from torchvision.models.segmentation.fcn import FCN - -from selectionConv import SelectionConv -import pooling as P - -def transform_network(network: nn.Module): - """ Transforms a neural network from a tensor based network to a graph based network - - Parameters - ---------- - - network: the network to transform - - Returns - ------- - - the transformed network - """ - network = deepcopy(network) - if type(network) in __MAPPING: - with torch.no_grad(): - return __MAPPING[type(network)].from_torch(network) - if not isinstance(network, nn.Module): - raise ValueError(f"Must be of type Module but got: {type(network)}") - if not list(network.children()): - raise NotImplementedError(f"{type(network)} is not 
implemented yet") - for name, child in network.named_children(): - transformed_child = transform_network(child) - setattr(network, name, transformed_child) - return network - - -class GraphTracker: - """ a wrapper around the graph data for easily overwriting the forward function of existing modules. - - Parameters - ---------- - - graph: the graph data - - level: the current depth the graph is being operated on - - x: the node data at the current level - """ - def __init__(self, graph, x=None, level=0): - self.graph = graph - self.x = graph.x if x is None else x - self.level = level - - def from_x(self, x): - """ create the same graph with different node values""" - return GraphTracker(self.graph,x,level=self.level) - - def edge_index(self): - return self.graph.edge_indexes[self.level] - - def selections(self): - return self.graph.selections_list[self.level] - - def interps(self): - if hasattr(self.graph,"interps_list"): - return self.graph.interps_list[self.level] - else: - return None - - def cluster(self): - return self.graph.clusters[self.level] - - def __iadd__(self, other): - self.x = self.x + other.x - return self - - def __repr__(self): - return f"GraphTracker(x={tuple(self.x.shape)},level={self.level})" - - - -def _single(pair, name): - """ converts a tuple into a single number - - Parameters - ---------- - - pair: the potential pair of values - - name: the name of the values for more readable errors - - Returns - ------- - - the single value - """ - if isinstance(pair, int): - return pair - if not isinstance(pair, tuple): - raise ValueError(f"{name} must either be int or tuple but got: {type(pair)}") - if len(pair) != 2: - raise ValueError(f"{name} must be a 2-tuple but got: {pair}") - if pair[0] != pair[1]: - raise ValueError(f"{name} must be a square tuple") - return pair[0] - - -class SelModule(nn.Module): - """ A super class for all graph based modules to inherit from - """ - @classmethod - def from_torch(cls, network): - """ creates a new graph based 
module from an existing 2d based module and copies weights accordingly. Each child class should implement this method - - Parameters - ---------- - - network: the existing 2d based module - - Returns - ------- - - the new graph based module - """ - raise NotImplementedError - - -class SelConv(SelModule, nn.modules.conv._ConvNd): - """ A wrapper class around the SelectionConv class that allows for easy - use in a transformed network - - Parameters - ---------- - - in_channels: the number of incoming channels - - out_channels: the number of outgoing channels - - kernel_size: the size of the convolution kernel - - stride: the stride at which to perform convolution - - padding: the amount of padding to be used - - dilation: the dilation of the kernel - - groups: the groups of filters for the convolution - - bias: whether or not to include a bias - - padding_mode: the type of padding to be used - """ - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - stride: Union[int, Tuple[int, int]]=1, - padding: Union[int, Tuple[int, int]]=0, - dilation: Union[int, Tuple[int, int]]=1, - groups: int = 1, - bias: bool=True, - padding_mode: str='zeros', - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, False, (0,), groups, bias, padding_mode, **factory_kwargs) - self.single_stride = _single(stride, "stride") - if self.single_stride not in (1, 2): - raise NotImplementedError(f"Only strides of 1 and 2 are supported but got {stride}") - self.conv_operation = SelectionConv( - in_channels, - out_channels, - _single(kernel_size, "kernel_size"), - _single(dilation, "dilation"), - padding_mode, - ) - - def forward(self, inputs: GraphTracker): - x = self.conv_operation(inputs.x, inputs.edge_index(), inputs.selections(), inputs.interps()) - ret = inputs.from_x(x) - if self.single_stride == 2: - x = 
P.stridePoolCluster(x, ret.cluster()) - ret.x = x - ret.level += 1 - return ret - - @classmethod - def from_torch(cls, network): - ret = SelConv( - network.in_channels, - network.out_channels, - network.kernel_size, - network.stride, - network.padding, - network.dilation, - network.groups, - network.bias is not None, - network.padding_mode, - ) - ret.conv_operation.copy_weights(network.weight, network.bias) - return ret - -class SelMaxPool(SelModule): - """ A graph based max pool module - - Parameters - ---------- - - kernel_size: the size of the maxpool kernel - """ - - def __init__(self, kernel_size): - super().__init__() - self.kernel_size = kernel_size - - def forward(self, inputs): - x = P.maxPoolKernel(inputs.x, inputs.edge_index(), inputs.selections(), inputs.cluster(), self.kernel_size) - ret = inputs.from_x(x) - ret.level += 1 - return ret - - @classmethod - def from_torch(cls, network): - ret = SelMaxPool(network.kernel_size) - return ret - -class SelBatchNorm(SelModule): - """ A graph based BatchNorm module - """ - def __init__(self,num_features): - super().__init__() - self.bn = nn.BatchNorm1d(num_features) - #self.bn = SimpleBatchNorm() - - def forward(self, inputs): - x = self.bn(inputs.x) - ret = inputs.from_x(x) - return ret - - def copyBatchNorm(self,source): - self.bn.weight = source.weight - self.bn.bias = source.bias - self.bn.running_mean = source.running_mean - self.bn.running_var = source.running_var - self.bn.eps = source.eps - - @classmethod - def from_torch(cls, network): - ret = SelBatchNorm(network.num_features) - ret.copyBatchNorm(network) - #ret.bn.set_values(network) - return ret - -class SelReLU(SelModule): - """ A graph based ReLU module - - Parameters - ---------- - - inplace: whether or not to perform relu in place - """ - def __init__(self, inplace=False): - super().__init__() - self.inplace = inplace - - def forward(self, inputs): - if self.inplace: - inputs.x = F.relu(inputs.x, self.inplace) - return inputs - else: - x = 
F.relu(inputs.x, self.inplace) - ret = inputs.from_x(x) - return ret - - @classmethod - def from_torch(cls, network): - return SelReLU(network.inplace) - - -class SelSequential(SelModule, nn.Sequential): - """ A graph based Sequential module - """ - @classmethod - def from_torch(cls, network: nn.Sequential): - return SelSequential(*map(transform_network, network)) - - -class SelDropout(SelModule): - """ A graph based dropout module - """ - def forward(self, inputs): - return inputs - - @classmethod - def from_torch(cls, network): - return SelDropout() - -def sel_binlinear_interp( - inputs: GraphTracker, - up_or_down: Literal["up", "down"]="up", - ) -> GraphTracker: - """ Performs bilinear interpolation as a single cluster step - - Parameters - ---------- - - inputs: the input graph - - up_or_down: either "up" or "down" indicating if it is upsampling or downsampling - - Returns - ------- - - the interpolated graph - """ - supported_up_or_downs = ("up", "down") - if up_or_down not in supported_up_or_downs: - raise ValueError(f"up_or_down must either be 'up' or 'down' not: {up_or_down}") - ret = inputs.from_x(inputs.x) - dx = -1 if up_or_down == "up" else 1 - ret.level += dx - cluster = ret.cluster() - up_edge_index = ret.edge_index() - - #up_selections = ret.selections() - #ret.x = P.unpoolBilinear(ret.x, cluster, up_edge_index, up_selections) - - up_interps = ret.interps() - ret.x = P.unpoolInterpolated(ret.x,cluster,up_edge_index,up_interps) - - #ret.x = P.unpoolCluster(inputs.x, inputs.clusters[inputs.cluster_id]) - return ret - - -def sel_interpolate( - inputs: GraphTracker, - target_level: int, - ) -> GraphTracker: - """ interpolates a graph to a given cluster_id - - Parameters - ---------- - - inputs: the input graph data - - target_cluster_id: the target cluster - - Returns - ------- - - the interpolated graph - """ - up_or_down = "up" if target_level < inputs.level else "down" - while inputs.level != target_level: - inputs = sel_binlinear_interp(inputs, 
up_or_down) - return inputs - - -class SelSimpleSegmentationModel(SelModule): - """ A graph version of the simple segmentation model defined in torchvision's segmentation model. This is needed since the interpolate function we use needs different parameters than what is used in torch. - """ - __constants__ = ["aux_classifier"] - def __init__(self, backbone, classifier, aux_classifier = None): - super().__init__() - self.backbone = backbone - self.classifier = classifier - self.aux_classifier = aux_classifier - - def forward(self, x: GraphTracker) -> Dict[str, GraphTracker]: - starting_level = x.level - features = self.backbone(x) - result = OrderedDict() - x = features["out"] - x = self.classifier(x) - x = sel_interpolate(x, starting_level) - result["out"] = x - - if self.aux_classifier is not None: - x = features["aux"] - x = self.aux_classifier(x) - x = sel_interpolate(x, starting_level) - result["aux"] = x - return result - - @classmethod - def from_torch(cls, network): - ret = SelSimpleSegmentationModel( - backbone = transform_network(network.backbone), - classifier = transform_network(network.classifier), - aux_classifier=transform_network(network.aux_classifier) if network.aux_classifier is not None else None, - ) - return ret - - -__MAPPING = { - nn.Conv2d: SelConv, - nn.BatchNorm2d: SelBatchNorm, - nn.ReLU: SelReLU, - nn.Sequential: SelSequential, - nn.Dropout: SelDropout, - nn.MaxPool2d: SelMaxPool, - FCN: SelSimpleSegmentationModel, -} diff --git a/FastSplatStyler_huggingface/mesh_config.py b/FastSplatStyler_huggingface/mesh_config.py deleted file mode 100644 index e7c236bb9f3399c6452dc7852e9e7aa8303f5d8f..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/mesh_config.py +++ /dev/null @@ -1,54 +0,0 @@ - -mesh_info = {} - -mesh_info["teddy"] = { - "load_directory":"mesh_data/original/teddy_mesh/", - "mesh_fn":"teddy.obj", - "texture_fn":"teddy.png", - "save_directory":"mesh_data/edited/teddy_mesh/" - } - -mesh_info["lime"] = { - 
"load_directory":"mesh_data/original/lime_mesh/", - "mesh_fn":"lime.obj", - "texture_fn":"lime.jpg", - "save_directory":"mesh_data/edited/lime_mesh/" - } - -mesh_info["crate"] = { - "load_directory":"mesh_data/original/crate_mesh/", - "mesh_fn":"wooden_crate_01_4k.gltf", - "texture_fn":"textures/wooden_crate_01_diff_4k.jpg", - "save_directory":"mesh_data/edited/crate_mesh/" - } - -mesh_info["barrel"] = { - "load_directory":"mesh_data/original/barrel_mesh/", - "mesh_fn":"barrel.obj", - #"mesh_fn":"barrel_03_4k.gltf", - "texture_fn":"textures/barrel_03_diff_4k.jpg", - "save_directory":"mesh_data/edited/barrel_mesh/" - } - -mesh_info["hammer"] = { - "load_directory":"mesh_data/original/hammer_mesh/", - "mesh_fn":"hammer.obj", - #"mesh_fn":"wooden_hammer_01_2k.gltf", - "texture_fn":"textures/wooden_hammer_01_diff_2k.jpg", - "save_directory":"mesh_data/edited/hammer_mesh/" - } - -mesh_info["crow"] = { - "load_directory":"mesh_data/original/crow_mesh/", - "mesh_fn":"Kruk.obj", - #"mesh_fn":"wooden_hammer_01_2k.gltf", - "texture_fn":"kruk-mat_Diffuse.jpg", - "save_directory":"mesh_data/edited/crow_mesh/" - } - -mesh_info["horse"] = { - "load_directory":"mesh_data/original/horse_mesh/", - "mesh_fn":"horse.obj", - "texture_fn":"teddy.png", - "save_directory":"mesh_data/edited/horse_mesh/" - } \ No newline at end of file diff --git a/FastSplatStyler_huggingface/mesh_helpers.py b/FastSplatStyler_huggingface/mesh_helpers.py deleted file mode 100644 index a55bde9f304b4ec89e88d332bb1bf248e7930723..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/mesh_helpers.py +++ /dev/null @@ -1,93 +0,0 @@ -import trimesh -import torch -import numpy as np - -def loadMesh(mesh_fn): - mesh = trimesh.load(mesh_fn, force="mesh") - #mesh = trimesh.load_mesh(mesh_fn) - return mesh - -def getUVs(mesh): - if (isinstance(mesh.visual,trimesh.visual.color.ColorVisuals)): - uvs = mesh.visual.to_texture().uv - else: - uvs = mesh.visual.uv - - return uvs - -def 
setTexture(mesh,texture): - from PIL import Image - if texture.dtype == np.float32 or texture.dtype == np.float64: - texture = (255*texture).astype(np.uint8) - im = Image.fromarray(texture) - tex = trimesh.visual.TextureVisuals(uv=getUVs(mesh),image=im) - new_mesh=trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, visual=tex, validate=True, process=True) - return new_mesh - -def sampleSurface(mesh,N,return_x=False,device='cpu'): - result = trimesh.sample.sample_surface(mesh,N,face_weight=None,sample_color=return_x) - pos3D = result[0] - face_indexes = result[1] - if return_x: - x = result[2][:,:3]/255 # No alpha channels, between 0-1 - - # Determine normals (the same as face normals) - normals = mesh.face_normals[face_indexes] - - if return_x: - - # Determine UVs associated with 3D points - vertices_all = mesh.vertices - uvs_all = getUVs(mesh) - faces = mesh.faces[face_indexes] - - a = vertices_all[faces[:,0]] - b = vertices_all[faces[:,1]] - c = vertices_all[faces[:,2]] - - w0, w1, w2 = getBarycentricWeights(pos3D,a,b,c,use_numpy=True) - - w0 = np.expand_dims(w0,axis=1) - w1 = np.expand_dims(w1,axis=1) - w2 = np.expand_dims(w2,axis=1) - - uvs = w0 * uvs_all[faces[:,0]] + w1 * uvs_all[faces[:,1]] + w2 * uvs_all[faces[:,2]] - - uvs = torch.tensor(uvs,dtype=torch.float).to(device) - - # Return as tensor - pos3D = torch.tensor(pos3D,dtype=torch.float).to(device) - normals = torch.tensor(normals,dtype=torch.float).to(device) - - if return_x: - x = torch.tensor(x,dtype=torch.float).to(device) - return pos3D,normals,uvs,x - else: - return pos3D,normals - -def getBarycentricWeights(p,a,b,c,use_numpy=False): - - # Taken from https://gamedev.stackexchange.com/questions/23743/whats-the-most-efficient-way-to-find-barycentric-coordinates - v0 = b-a - v1 = c-a - v2 = p-a - - if use_numpy: - d00 = np.sum(v0*v0,axis=1) - d01 = np.sum(v0*v1,axis=1) - d11 = np.sum(v1*v1,axis=1) - d20 = np.sum(v2*v0,axis=1) - d21 = np.sum(v2*v1,axis=1) - else: - d00 = torch.sum(v0*v0,dim=1) - 
d01 = torch.sum(v0*v1,dim=1) - d11 = torch.sum(v1*v1,dim=1) - d20 = torch.sum(v2*v0,dim=1) - d21 = torch.sum(v2*v1,dim=1) - denom = d00*d11 - d01*d01 - w1 = (d11 * d20 - d01 * d21)/denom - w2 = (d00 * d21 - d01 * d20)/denom - - w0 = 1 - w1 - w2 - - return w0,w1,w2 \ No newline at end of file diff --git a/FastSplatStyler_huggingface/output.splat b/FastSplatStyler_huggingface/output.splat deleted file mode 100644 index ce04350c02b3cb30f7add04400d510b8cb32005b..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/output.splat +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9c98911aa0af25cefc959d8db8da26fa4770986b630d26ea39d4c223ab478c45 -size 4149408 diff --git a/FastSplatStyler_huggingface/plyio.py b/FastSplatStyler_huggingface/plyio.py deleted file mode 100644 index 602271149c33d3eb10b7e6820f7be0e5d81cb219..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/plyio.py +++ /dev/null @@ -1,136 +0,0 @@ -# Modified from https://antimatter15.com/splat - -from plyfile import PlyData -import numpy as np -import argparse -from io import BytesIO - - -def splat_to_numpy(file_path): - with open(file_path, 'rb') as f: - splat_data = f.read() - - splat_dtype = np.dtype([ - ('position', np.float32, 3), - ('scale', np.float32, 3), - ('color', np.uint8, 4), - ('rotation', np.uint8, 4) - ]) - - - splat_array = np.frombuffer(splat_data, dtype=splat_dtype) - - points = splat_array["position"] - scales = splat_array["scale"] - rots = (splat_array["rotation"]/255)*2 - 1 - color = splat_array["color"]/255 - - - return points, scales, rots.astype(np.float32), color.astype(np.float32) - - -def ply_to_numpy(ply_file_path): - plydata = PlyData.read(ply_file_path) - vert = plydata["vertex"] - sorted_indices = np.argsort( - -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"]) - / (1 + np.exp(-vert["opacity"])) - ) - buffer = BytesIO() - - positions = np.zeros((len(sorted_indices), 3), dtype=np.float32) - 
scales = np.zeros((len(sorted_indices), 3), dtype=np.float32) - rots = np.zeros((len(sorted_indices), 4), dtype=np.float32) - colors = np.zeros((len(sorted_indices), 4), dtype=np.float32) - - for idx in sorted_indices: - v = plydata["vertex"][idx] - position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32) - scale = np.exp( - np.array( - [v["scale_0"], v["scale_1"], v["scale_2"]], - dtype=np.float32, - ) - ) - rot = np.array( - [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], - dtype=np.float32, - ) - SH_C0 = 0.28209479177387814 - color = np.array( - [ - 0.5 + SH_C0 * v["f_dc_0"], - 0.5 + SH_C0 * v["f_dc_1"], - 0.5 + SH_C0 * v["f_dc_2"], - 1 / (1 + np.exp(-v["opacity"])), - ] - ) - - - positions[idx] = position - scales[idx] = scale - rots[idx] = rot - colors[idx] = color - return positions, scales, rots, colors - - - - -def numpy_to_splat(positions, scales, rots, colors, output_path, file_type): - buffer = BytesIO() - if file_type == 'ply': - - for idx in range(len(positions)): - position = positions[idx] - scale = scales[idx] - rot = rots[idx] - color = colors[idx] - buffer.write(position.tobytes()) - buffer.write(scale.tobytes()) - buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes()) - buffer.write(((rot / np.linalg.norm(rot)) * 128 + 128).clip(0, 255).astype(np.uint8).tobytes() - ) - - splat_data = buffer.getvalue() - with open(output_path, "wb") as f: - f.write(splat_data) - else: - for idx in range(len(positions)): - position = positions[idx] - scale = scales[idx] - rot = rots[idx] - color = colors[idx] - buffer.write(position.tobytes()) - buffer.write(scale.tobytes()) - buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes()) - buffer.write(((rot / np.linalg.norm(rot)) * 128 + 128).clip(0, 255).astype(np.uint8).tobytes() - ) - - splat_data = buffer.getvalue() - with open(output_path, "wb") as f: - f.write(splat_data) - - return splat_data - - - - -def main(): - parser = argparse.ArgumentParser(description="Convert PLY files 
to SPLAT format.") - parser.add_argument( - "input_files", nargs="+", help="The input PLY files to process." - ) - parser.add_argument( - "--output", "-o", default="output.splat", help="The output SPLAT file." - ) - args = parser.parse_args() - for input_file in args.input_files: - print(f"Processing {input_file}...") - positions, scales, rotations, colors = ply_to_numpy(input_file) - - numpy_to_splat(positions, scales, rotations, colors, args.output) - - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/FastSplatStyler_huggingface/pointCloudToMesh.py b/FastSplatStyler_huggingface/pointCloudToMesh.py deleted file mode 100644 index 6015a03c30c1014435e3095b1ea4445ae85dc5c2..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/pointCloudToMesh.py +++ /dev/null @@ -1,201 +0,0 @@ -import numpy as np -import open3d as o3d -from skimage import measure -from scipy.spatial import cKDTree -import splat_helpers as splt - -# based on code from https://towardsdatascience.com/transform-point-clouds-into-3d-meshes-a-python-guide-8b0407a780e6 -# credit Florent Poux -# Towards Data Science (2024) - -def MarchingCubes_from_ply(dataset, voxel_size, iso_level_percentile): - - pcd = o3d.io.read_point_cloud(dataset) - - # Convert Open3D point cloud to numpy array - points = np.asarray(pcd.points) - # Compute the bounds of the point cloud - mins = np.min(points, axis=0) - maxs = np.max(points, axis=0) - # Create a 3D grid - x = np.arange(mins[0], maxs[0], voxel_size) - y = np.arange(mins[1], maxs[1], voxel_size) - z = np.arange(mins[2], maxs[2], voxel_size) - x, y, z = np.meshgrid(x, y, z, indexing='ij') - - # Create a KD-tree for efficient nearest neighbor search - tree = cKDTree(points) - - # Compute the scalar field (distance to nearest point) - grid_points = np.vstack([x.ravel(), y.ravel(), z.ravel()]).T - distances, _ = tree.query(grid_points) - scalar_field = distances.reshape(x.shape) - - # Determine iso-level based on percentile 
of distances - iso_level = np.percentile(distances, iso_level_percentile) - - # Apply Marching Cubes - verts, faces, _, _ = measure.marching_cubes(scalar_field, level=iso_level) - - # Scale and translate vertices back to original coordinate system - verts = verts * voxel_size + mins - # Create mesh - mesh = o3d.geometry.TriangleMesh() - mesh.vertices = o3d.utility.Vector3dVector(verts) - mesh.triangles = o3d.utility.Vector3iVector(faces) - # Compute vertex normals - mesh.compute_vertex_normals() - # Visualize the result - o3d.visualization.draw_geometries([mesh], mesh_show_back_face=True) - - - -def MarchingCubes_with_filtering(dataset, voxel_size, iso_level_percentile, threshold=99, out_file = 'out.obj'): - - pos3D, _, _, _, _, _, _ = splt.splat_unpacker_threshold(25, dataset, threshold) - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(pos3D.numpy()) - - # Convert Open3D point cloud to numpy array - points = np.asarray(pcd.points) - # Compute the bounds of the point cloud - mins = np.min(points, axis=0) - maxs = np.max(points, axis=0) - # Create a 3D grid - x = np.arange(mins[0], maxs[0], voxel_size) - y = np.arange(mins[1], maxs[1], voxel_size) - z = np.arange(mins[2], maxs[2], voxel_size) - x, y, z = np.meshgrid(x, y, z, indexing='ij') - - # Create a KD-tree for efficient nearest neighbor search - tree = cKDTree(points) - - # Compute the scalar field (distance to nearest point) - grid_points = np.vstack([x.ravel(), y.ravel(), z.ravel()]).T - distances, _ = tree.query(grid_points) - scalar_field = distances.reshape(x.shape) - - # Determine iso-level based on percentile of distances - iso_level = np.percentile(distances, iso_level_percentile) - - # Apply Marching Cubes - verts, faces, _, _ = measure.marching_cubes(scalar_field, level=iso_level) - - # Scale and translate vertices back to original coordinate system - verts = verts * voxel_size + mins - # Create mesh - mesh = o3d.geometry.TriangleMesh() - mesh.vertices = 
o3d.utility.Vector3dVector(verts) - mesh.triangles = o3d.utility.Vector3iVector(faces) - # Compute vertex normals - mesh.compute_vertex_normals() - # Save the result - o3d.io.write_triangle_mesh(out_file, mesh) - # Visualize the result - o3d.visualization.draw_geometries([mesh], mesh_show_back_face=True) - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(np.asarray(mesh.vertices)) - pcd.estimate_normals() - o3d.visualization.draw_geometries([pcd]) - - - -def MarchingCubes_return_vertices(dataset, visualize = False): - - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(dataset.numpy()) - voxel_size_tensor = (abs(dataset.min()) + abs(dataset.max()))/100 - voxel_size = voxel_size_tensor.item() - iso_level_percentile = 5 - - # Convert Open3D point cloud to numpy array - points = np.asarray(pcd.points) - # Compute the bounds of the point cloud - mins = np.min(points, axis=0) - maxs = np.max(points, axis=0) - # Create a 3D grid - x = np.arange(mins[0], maxs[0], voxel_size) - y = np.arange(mins[1], maxs[1], voxel_size) - z = np.arange(mins[2], maxs[2], voxel_size) - x, y, z = np.meshgrid(x, y, z, indexing='ij') - - # Create a KD-tree for efficient nearest neighbor search - tree = cKDTree(points) - - # Compute the scalar field (distance to nearest point) - grid_points = np.vstack([x.ravel(), y.ravel(), z.ravel()]).T - distances, _ = tree.query(grid_points) - scalar_field = distances.reshape(x.shape) - - # Determine iso-level based on percentile of distances - iso_level = np.percentile(distances, iso_level_percentile) - - # Apply Marching Cubes - verts, faces, _, _ = measure.marching_cubes(scalar_field, level=iso_level) - - # Scale and translate vertices back to original coordinate system - verts = verts * voxel_size + mins - # Create mesh - mesh = o3d.geometry.TriangleMesh() - mesh.vertices = o3d.utility.Vector3dVector(verts) - mesh.triangles = o3d.utility.Vector3iVector(faces) - # Compute vertex normals - 
def graph_Points(points, colors):
    """Scatter-plot a 3D point array with per-point colors using matplotlib."""
    import matplotlib.pyplot as plt

    xs = points[:, 0]
    ys = points[:, 1]
    zs = points[:, 2]

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.scatter(xs, ys, zs, c=colors)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    plt.show()
    return


def Estimate_Normals(points, n=25, threshold=75):
    """Estimate unit normals for a point tensor with Open3D.

    NOTE(review): `n` and `threshold` are currently unused; they are kept for
    signature compatibility with existing callers.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points.numpy())
    cloud.estimate_normals()
    cloud.normalize_normals()
    return np.asarray(cloud.normals)
def _kernelPool(x, edge_index, selections, cluster, kernel_size, even_dirs, reduce):
    """Shared core for kernel-style graph pooling (avg/max differ only in `reduce`).

    Aggregates neighbor features along `edge_index` the number of times implied
    by `kernel_size`, then keeps one representative node per cluster — the one
    with the lowest index, assumed to be the top-left position of the cluster.
    """
    is_even = kernel_size % 2 == 0
    full_passes = kernel_size // 2 - int(is_even)

    # Representative node per cluster = its minimum node index.
    indices = torch.arange(len(x)).to(x.device)
    picks = scatter(indices, cluster, dim=0, reduce='min')

    # Propagate/aggregate the appropriate number of times.
    for _ in range(full_passes):
        message = x[edge_index[1]]
        x = scatter(message, edge_index[0], dim=0, reduce=reduce)

    # Even kernels are not symmetric: they are lopsided towards the bottom-right
    # corner, so run one extra pass restricted to the bottom-right directions.
    if is_even:
        keep = torch.zeros_like(selections, dtype=torch.bool).to(x.device)
        for d in even_dirs:
            keep[torch.where(selections == d)] = True
        even_edge_index = edge_index[:, torch.where(keep)[0]]

        message = x[even_edge_index[1]]
        x = scatter(message, even_edge_index[0], dim=0, reduce=reduce)

    return x[picks]


def avgPoolKernel(x, edge_index, selections, cluster, kernel_size=2, even_dirs=[0, 1, 7, 8]):
    """Average-pool node features over a kernel_size x kernel_size graph neighborhood."""
    # Mutable default kept for interface compatibility; it is never mutated.
    return _kernelPool(x, edge_index, selections, cluster, kernel_size, even_dirs, 'mean')


def maxPoolKernel(x, edge_index, selections, cluster, kernel_size=2, even_dirs=[0, 1, 7, 8]):
    """Max-pool node features over a kernel_size x kernel_size graph neighborhood."""
    return _kernelPool(x, edge_index, selections, cluster, kernel_size, even_dirs, 'max')
def stridePoolCluster(x, cluster):
    """Keep one node per cluster: the one with the lowest index (top-left position)."""
    indices = torch.arange(len(x)).to(x.device)
    picks = scatter(indices, cluster, dim=0, reduce='min')
    return x[picks]


def maxPoolCluster(x, cluster):
    """Max-reduce node features within each cluster."""
    return scatter(x, cluster, dim=0, reduce='max')


def avgPoolCluster(x, cluster, edge_index=None, edge_weight=None):
    """Mean-reduce node features within each cluster (edge args accepted but unused)."""
    return scatter(x, cluster, dim=0, reduce='mean')


def unpoolInterpolated(x, cluster, up_edge_index, up_interps=None):
    """Unpool coarse features to fine nodes as an interpolation-weighted average.

    Falls back to plain edge-average unpooling when no weights are given.
    """
    if up_interps is None:
        return unpoolEdgeAverage(x, cluster, up_edge_index)

    # Weighted sum of cluster values per fine node, normalized by the weight totals.
    target_clusters = cluster[up_edge_index[1]]
    weighted_vals = x[target_clusters] * up_interps.unsqueeze(1)
    out = scatter(weighted_vals, up_edge_index[0], dim=0, reduce='add')
    norm = scatter(up_interps, up_edge_index[0], dim=0)
    out /= norm.unsqueeze(1)
    return out


def unpoolBilinear(x, cluster, up_edge_index, up_selections, selection_dirs=[0, 1, 7, 8]):
    """Unpool by averaging over the distinct clusters reachable via the given edge directions."""
    # Keep only the edge directions that participate in the bilinear estimate.
    keep = torch.zeros_like(up_selections, dtype=torch.bool).to(x.device)
    for d in selection_dirs:
        keep[torch.where(up_selections == d)] = True
    ref_edge_index = up_edge_index[:, torch.where(keep)[0]]

    # Deduplicate (fine node, coarse cluster) pairs so each cluster counts once.
    pairs = torch.vstack((ref_edge_index[0], cluster[ref_edge_index[1]]))
    pairs = torch.unique(pairs, dim=1)
    return scatter(x[pairs[1]], pairs[0], dim=0, reduce='mean')
def unpoolEdgeAverage(x, cluster, up_edge_index, weighted=True):
    """Unpool coarse features to fine nodes by averaging over connected clusters.

    With weighted=True every connection counts once, so clusters with more
    connections to a fine node contribute more; works best with dense data.
    With weighted=False each distinct cluster is weighted equally.
    """
    if weighted:
        target_clusters = cluster[up_edge_index[1]]
        return scatter(x[target_clusters], up_edge_index[0], dim=0, reduce='mean')

    # Deduplicate (fine node, cluster) pairs so each cluster counts only once.
    pairs = torch.vstack((up_edge_index[0], cluster[up_edge_index[1]]))
    pairs = torch.unique(pairs, dim=1)
    return scatter(x[pairs[1]], pairs[0], dim=0, reduce='mean')


def unpoolCluster(x, cluster):
    """Broadcast each cluster's feature back to all of its member nodes."""
    return x[cluster]
# HAKUNA MATATA
# Credit: https://github.com/daavoo/pyntcloud/blob/master/pyntcloud/io/ply.py

sys_byteorder = ('>', '<')[sys.byteorder == 'little']

# PLY type name -> numpy dtype code. (The upstream duplicate b'uchar' -> 'i1'/'b1'
# entry is dropped: in a dict literal the later 'u1' entry always won anyway.)
ply_dtypes = dict([
    (b'int8', 'i1'),
    (b'char', 'i1'),
    (b'uint8', 'u1'),
    (b'uchar', 'u1'),
    (b'int16', 'i2'),
    (b'short', 'i2'),
    (b'uint16', 'u2'),
    (b'ushort', 'u2'),
    (b'int32', 'i4'),
    (b'int', 'i4'),
    (b'uint32', 'u4'),
    (b'uint', 'u4'),
    (b'float32', 'f4'),
    (b'float', 'f4'),
    (b'float64', 'f8'),
    (b'double', 'f8')
])

valid_formats = {'ascii': '', 'binary_big_endian': '>',
                 'binary_little_endian': '<'}


def read_ply(filename, allow_bool=False):
    """Read a .ply (binary or ascii) file into pandas DataFrames.

    Parameters
    ----------
    filename: str
        Path to the file.
    allow_bool: bool
        Allow bool as a valid PLY dtype. False by default to mirror the
        original PLY specification.

    Returns
    -------
    data: dict
        "points" / "mesh" as pandas DataFrames; "comments" as list of str.
    """
    if allow_bool:
        ply_dtypes[b'bool'] = '?'

    with open(filename, 'rb') as ply:
        if b'ply' not in ply.readline():
            raise ValueError('The file does not start with the word ply')
        # Second header line: ascii / binary_little_endian / binary_big_endian.
        fmt = ply.readline().split()[1].decode()
        ext = valid_formats[fmt]  # byte-order prefix for numpy dtypes

        line = []
        dtypes = defaultdict(list)
        count = 2  # header lines consumed so far
        points_size = None
        mesh_size = None
        has_texture = False
        comments = []
        while b'end_header' not in line and line != b'':
            line = ply.readline()

            if b'element' in line:
                line = line.split()
                name = line[1].decode()
                size = int(line[2])
                if name == "vertex":
                    points_size = size
                elif name == "face":
                    mesh_size = size

            elif b'property' in line:
                line = line.split()
                if b'list' in line:
                    # Face element: either vertex indices or texture coordinates.
                    if b"vertex_indices" in line[-1] or b"vertex_index" in line[-1]:
                        mesh_names = ["n_points", "v1", "v2", "v3"]
                    else:
                        has_texture = True
                        mesh_names = ["n_coords"] + ["v1_u", "v1_v", "v2_u", "v2_v", "v3_u", "v3_v"]

                    # The leading count has a different dtype than the list items.
                    if fmt == "ascii":
                        dtypes[name].append((mesh_names[0], ply_dtypes[line[2]]))
                        dt = ply_dtypes[line[3]]
                    else:
                        dtypes[name].append((mesh_names[0], ext + ply_dtypes[line[2]]))
                        dt = ext + ply_dtypes[line[3]]
                    for j in range(1, len(mesh_names)):
                        dtypes[name].append((mesh_names[j], dt))
                else:
                    if fmt == "ascii":
                        dtypes[name].append((line[2].decode(), ply_dtypes[line[1]]))
                    else:
                        dtypes[name].append((line[2].decode(), ext + ply_dtypes[line[1]]))

            elif b'comment' in line:
                line = line.split(b" ", 1)
                comments.append(line[1].decode().rstrip())

            count += 1

        # Data offset used by the binary branch below.
        end_header = ply.tell()

    data = {}
    if comments:
        data["comments"] = comments

    if fmt == 'ascii':
        top = count
        bottom = 0 if mesh_size is None else mesh_size

        names = [x[0] for x in dtypes["vertex"]]
        data["points"] = pd.read_csv(filename, sep=" ", header=None, engine="python",
                                     skiprows=top, skipfooter=bottom, usecols=names, names=names)
        for n, col in enumerate(data["points"].columns):
            data["points"][col] = data["points"][col].astype(dtypes["vertex"][n][1])

        if mesh_size:
            top = count + points_size
            names = np.array([x[0] for x in dtypes["face"]])
            usecols = [1, 2, 3, 5, 6, 7, 8, 9, 10] if has_texture else [1, 2, 3]
            names = names[usecols]
            data["mesh"] = pd.read_csv(filename, sep=" ", header=None, engine="python",
                                       skiprows=top, usecols=usecols, names=names)
            for n, col in enumerate(data["mesh"].columns):
                data["mesh"][col] = data["mesh"][col].astype(dtypes["face"][n + 1][1])
    else:
        with open(filename, 'rb') as ply:
            ply.seek(end_header)
            points_np = np.fromfile(ply, dtype=dtypes["vertex"], count=points_size)
            if ext != sys_byteorder:
                # numpy>=2.0 removed ndarray.newbyteorder(); view the swapped
                # bytes through a byte-order-flipped dtype instead (same result).
                points_np = points_np.byteswap().view(points_np.dtype.newbyteorder())
            data["points"] = pd.DataFrame(points_np)
            if mesh_size:
                mesh_np = np.fromfile(ply, dtype=dtypes["face"], count=mesh_size)
                if ext != sys_byteorder:
                    mesh_np = mesh_np.byteswap().view(mesh_np.dtype.newbyteorder())
                data["mesh"] = pd.DataFrame(mesh_np)
                data["mesh"].drop('n_points', axis=1, inplace=True)

    return data
def write_ply(filename, points=None, mesh=None, as_text=False, comments=None):
    """Write a PLY file populated with the given fields.

    Parameters
    ----------
    filename: str
        Output path; '.ply' is appended when missing.
    points: pandas.DataFrame or None
    mesh: pandas.DataFrame or None
    as_text: bool
        True -> ascii PLY; False -> native-endian binary (default).
    comments: list of str or None

    Returns
    -------
    bool
        True if no problems.
    """
    if not filename.endswith('ply'):
        filename += '.ply'

    # The header is always written in text mode.
    with open(filename, 'w') as ply:
        header = ['ply']
        if as_text:
            header.append('format ascii 1.0')
        else:
            header.append('format binary_' + sys.byteorder + '_endian 1.0')
        for comment in comments or []:
            header.append('comment ' + comment)

        if points is not None:
            header.extend(describe_element('vertex', points))
        if mesh is not None:
            # Faces are triangles: prepend the mandatory per-face vertex count.
            mesh = mesh.copy()
            mesh.insert(loc=0, column="n_points", value=3)
            mesh["n_points"] = mesh["n_points"].astype("u1")
            header.extend(describe_element('face', mesh))

        header.append('end_header')
        ply.write("\n".join(header) + "\n")

    if as_text:
        if points is not None:
            points.to_csv(filename, sep=" ", index=False, header=False, mode='a',
                          encoding='ascii')
        if mesh is not None:
            mesh.to_csv(filename, sep=" ", index=False, header=False, mode='a',
                        encoding='ascii')
    else:
        with open(filename, 'ab') as ply:
            if points is not None:
                points.to_records(index=False).tofile(ply)
            if mesh is not None:
                mesh.to_records(index=False).tofile(ply)

    return True


def describe_element(name, df):
    """Build the PLY header lines describing one element from a DataFrame.

    Faces always get the standard triangle vertex-index list property;
    other elements get one property line per column, typed from the first
    letter of the column's numpy dtype.
    """
    property_formats = {'f': 'float', 'u': 'uchar', 'i': 'int', 'b': 'bool'}
    element = ['element ' + name + ' ' + str(len(df))]

    if name == 'face':
        element.append("property list uchar int vertex_indices")
    else:
        for i in range(len(df.columns)):
            # .iloc: positional dtype lookup (plain [i] on a Series is deprecated).
            f = property_formats[str(df.dtypes.iloc[i])[0]]
            element.append('property ' + f + ' ' + df.columns.values[i])

    return element


def write_ply_float(filename, points=None, mesh=None, as_text=False, comments=None):
    """Write a binary PLY file from raw float data.

    Unlike `write_ply`, the payload is written with `ndarray.tofile` on the
    DataFrame values rather than via typed records. `as_text` is accepted for
    signature compatibility but the output is always native-endian binary.

    Bug fix: the data is now appended to the open file handle. Previously
    `tofile` was pointed at the file *path*, which reopens the file in 'wb'
    mode and clobbers the header that was just written.

    Returns
    -------
    bool
        True if no problems.
    """
    if not filename.endswith('ply'):
        filename += '.ply'

    with open(filename, 'w') as ply:
        header = ['ply']
        header.append('format binary_' + sys.byteorder + '_endian 1.0')
        for comment in comments or []:
            header.append('comment ' + comment)
        if points is not None:
            header.extend(describe_element('vertex', points))
        if mesh is not None:
            mesh = mesh.copy()
            mesh.insert(loc=0, column="n_points", value=3)
            mesh["n_points"] = mesh["n_points"].astype("u1")
            header.extend(describe_element('face', mesh))
        header.append('end_header')
        ply.write("\n".join(header) + "\n")

    with open(filename, 'ab') as ply:
        if points is not None:
            points.to_numpy().tofile(ply)
        if mesh is not None:
            mesh.to_numpy().tofile(ply)

    return True
#print(torch.amin(cur_interps),torch.amax(cur_interps)) - - if self.dilation > 1: - for _ in range(1, self.dilation): - vals, ind1, ind2 = intersect1d(cur_target,edge_index[0,cur_dir]) - cur_source = cur_source[ind1] - cur_target = edge_index[1,cur_dir][ind2] - if interps is not None: - cur_interps = cur_interps[ind1] - - # Main Calculation - if interps is None: - #out[cur_source] += torch.matmul(x[cur_target], self.weight[s]) - result = torch.matmul(x[cur_target], self.weight[s]) - else: - #out[cur_source] += cur_interps*torch.matmul(x[cur_target], self.weight[s]) - result = cur_interps*torch.matmul(x[cur_target], self.weight[s]) - - # Adding with duplicate indices - out.index_add_(0,cur_source,result) - - # Sanity check - #from tqdm import tqdm - #for i,node in enumerate(tqdm(cur_source)): - # out[node] += result[i] - - if self.padding_mode == 'constant': - missed_nodes = setdiff1d(all_nodes, cur_source) - #out[missed_nodes] += torch.matmul(x_mean, self.weight[s]) - out.index_add_(0,missed_nodes,torch.matmul(x_mean, self.weight[s])) - - if self.padding_mode == 'replicate': - missed_nodes = setdiff1d(all_nodes, cur_source) - #out[missed_nodes] += torch.matmul(x[missed_nodes], self.weight[s]) - out.index_add_(0,missed_nodes,torch.matmul(x[missed_nodes], self.weight[s])) - - if self.padding_mode == 'reflect': - missed_nodes = setdiff1d(all_nodes, cur_source) - - opposite = s+4 - if opposite > 8: - opposite = opposite % 9 + 1 - - op_dir = torch.where(selections == opposite)[0] - - op_source = edge_index[0,op_dir] - op_target = edge_index[1,op_dir] - - if interps is not None: - op_interps = interps[op_dir] - op_interps = torch.unsqueeze(op_interps,dim=1) - - - # Only take edges that are part of missed nodes - vals, ind1, ind2 = intersect1d(op_source,missed_nodes) - op_source = op_source[ind1] - op_target = op_target[ind1] - if interps is not None: - op_interps = op_interps[ind1] - - if self.dilation > 1: - for _ in range(1, self.dilation): - vals, ind1, ind2 = 
intersect1d(op_target,edge_index[0,op_dir]) - op_source = op_source[ind1] - op_target = edge_index[1,op_dir][ind2] - if interps is not None: - op_interps = op_interps[ind1] - - # Main Calculation - if interps is None: - result = torch.matmul(x[op_target], self.weight[s]) - else: - result = op_interps * torch.matmul(x[op_target], self.weight[s]) - - out.index_add_(0,op_source,result) - - if self.padding_mode == 'normalize': - dir_count[torch.unique(cur_source)] += 1 - - else: - width = self.kernel_size//2 - horiz = torch.arange(-width,width+1).to(x.device) - vert = torch.arange(-width,width+1).to(x.device) - - right = torch.where(selections == 1)[0] - left = torch.where(selections == 5)[0] - down = torch.where(selections == 7)[0] - up = torch.where(selections == 3)[0] - - center = torch.where(selections == 0)[0] - - # Find the appropriate node for each selection by stepping through connecting edges - s = 0 - for i in range(self.kernel_size): - for j in range(self.kernel_size): - x_loc = horiz[j] - y_loc = vert[i] - - cur_source = edge_index[0,center] #Starting location - cur_target = edge_index[1,center] - - if interps is not None: - cur_interps = interps[center] - cur_interps = torch.unsqueeze(cur_interps,dim=1) - - #print(torch.sum(cur_target-cur_source)) - - #print(cur_target.shape) - - if x_loc < 0: - for _ in range(self.dilation*abs(x_loc)): - vals, ind1, ind2 = intersect1d(cur_target,edge_index[0,left]) - cur_source = cur_source[ind1] - cur_target = edge_index[1,left][ind2] - if interps is not None: - cur_interps = cur_interps[ind1] - if x_loc > 0: - for _ in range(self.dilation*abs(x_loc)): - vals, ind1, ind2 = intersect1d(cur_target,edge_index[0,right]) - cur_source = cur_source[ind1] - cur_target = edge_index[1,right][ind2] - if interps is not None: - cur_interps = cur_interps[ind1] - - if y_loc < 0: - for _ in range(self.dilation*abs(y_loc)): - vals, ind1, ind2 = intersect1d(cur_target,edge_index[0,up]) - cur_source = cur_source[ind1] - cur_target = 
edge_index[1,up][ind2] - if interps is not None: - cur_interps = cur_interps[ind1] - - if y_loc > 0: - for _ in range(self.dilation*abs(y_loc)): - vals, ind1, ind2 = intersect1d(cur_target,edge_index[0,down]) - cur_source = cur_source[ind1] - cur_target = edge_index[1,down][ind2] - if interps is not None: - cur_interps = cur_interps[ind1] - - # Main Calculation - if interps is None: - #out[cur_source] += torch.matmul(x[cur_target], self.weight[s]) - result = torch.matmul(x[cur_target], self.weight[s]) - else: - #out[cur_source] += cur_interps*torch.matmul(x[cur_target], self.weight[s]) - result = cur_interps*torch.matmul(x[cur_target], self.weight[s]) - - # Adding with duplicate indices - out.index_add_(0,cur_source,result) - - if self.padding_mode == 'constant': - missed_nodes = setdiff1d(all_nodes, cur_source) - #out[missed_nodes] += torch.matmul(x_mean, self.weight[s]) - out.index_add_(0,missed_nodes,torch.matmul(x_mean, self.weight[s])) - - if self.padding_mode == 'replicate': - missed_nodes = setdiff1d(all_nodes, cur_source) - #out[missed_nodes] += torch.matmul(x[missed_nodes], self.weight[s]) - out.index_add_(0,missed_nodes,torch.matmul(x[missed_nodes], self.weight[s])) - - if self.padding_mode == 'reflect': - raise ValueError("Reflect padding not yet implemented for larger kernels") - - if self.padding_mode == 'normalize': - dir_count[torch.unique(cur_source)] += 1 - - s+=1 - - #print(self.selection_count/(dir_count + 1e-8)) - #test_val = self.selection_count/(dir_count + 1e-8) - # print(torch.max(test_val),torch.min(test_val),torch.mean(test_val)) - - if self.padding_mode == 'zeros': - pass # Already accounted for in the graph structure, no further computation needed - elif self.padding_mode == 'normalize': - out *= self.selection_count/(dir_count + 1e-8) - elif self.padding_mode == 'constant': - pass # Processed earlier - elif self.padding_mode == 'replicate': - pass - elif self.padding_mode == 'reflect': - pass - elif self.padding_mode == 'circular': - 
raise ValueError("Circular padding cannot be generalized on a graph. Instead, create a graph with edges connecting to the wrapped around nodes") - else: - raise ValueError(f"Unknown padding mode: {self.padding_mode}") - - # Add bias if applicable - out += self.bias - - return out - - - def copy_weightsNxN(self,weight,bias=None): - - width = int(sqrt(self.selection_count)) - - # Assumes weight comes in as [output channels, input channels, row, col] - for i in range(self.selection_count): - self.weight[i] = weight[:,:,i//width,i%width].permute(1,0) - - - def copy_weights3x3(self,weight,bias=None): - - - # Assumes weight comes in as [output channels, input channels, row, col] - # Assumes weight is a 3x3 - - # Current Ordering - # 4 3 2 - # 5 0 1 - # 6 7 8 - - # Need to flip horizontally per implementation of convolution - #self.weight[5] = weight[:,:,1,2].permute(1,0) - #self.weight[7] = weight[:,:,0,1].permute(1,0) - #self.weight[1] = weight[:,:,1,0].permute(1,0) - #self.weight[3] = weight[:,:,2,1].permute(1,0) - #self.weight[6] = weight[:,:,0,2].permute(1,0) - #self.weight[8] = weight[:,:,0,0].permute(1,0) - #self.weight[2] = weight[:,:,2,0].permute(1,0) - #self.weight[4] = weight[:,:,2,2].permute(1,0) - #self.weight[0] = weight[:,:,1,1].permute(1,0) - - self.weight[1] = weight[:,:,1,2].permute(1,0) - self.weight[3] = weight[:,:,0,1].permute(1,0) - self.weight[5] = weight[:,:,1,0].permute(1,0) - self.weight[7] = weight[:,:,2,1].permute(1,0) - self.weight[2] = weight[:,:,0,2].permute(1,0) - self.weight[4] = weight[:,:,0,0].permute(1,0) - self.weight[6] = weight[:,:,2,0].permute(1,0) - self.weight[8] = weight[:,:,2,2].permute(1,0) - self.weight[0] = weight[:,:,1,1].permute(1,0) - - - def copy_weights1x1(self, weight, bias=None): - self.weight[0] = weight[:,:,0,0].permute(1, 0) - - - def copy_weights(self,weight,bias=None): - - if self.kernel_size == 3: - self.copy_weights3x3(weight,bias) - elif self.kernel_size == 1: - self.copy_weights1x1(weight, bias) - else: - 
self.copy_weightsNxN(weight,bias) - - if bias is None: - self.bias[:] = 0.0 - else: - self.bias = bias - - - def __repr__(self): - return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, - self.out_channels) - diff --git a/FastSplatStyler_huggingface/sphere_helpers.py b/FastSplatStyler_huggingface/sphere_helpers.py deleted file mode 100644 index 650be180309411fa7e36f1e8b54220c3ab137bfd..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/sphere_helpers.py +++ /dev/null @@ -1,180 +0,0 @@ -import torch -import numpy as np -import graph_helpers as gh -from math import pi, sqrt, sin - -def equirec2spherical(rows, cols, device = 'cpu'): - theta_steps = torch.linspace(0, 2*pi, int(cols+1)).to(device)[:-1] # Avoid overlapping points - phi_steps = torch.linspace(0, pi, int(rows+1)).to(device)[:-1] - theta, phi = torch.meshgrid(theta_steps, phi_steps,indexing='xy') - return theta.flatten(),phi.flatten() - -def spherical2equirec(theta,phi,rows,cols): - x = theta*cols/(2*pi) - y = phi*rows/pi - return x,y - -def spherical2xyz(theta,phi): - x, y, z = torch.cos(theta) * torch.sin(phi), torch.cos(phi), torch.sin(theta) * torch.sin(phi),; - return x,y,z - -def sampleSphere_Equirec(rows,cols): - theta, phi = equirec2spherical(rows, cols) - x,y,z = spherical2xyz(theta,phi) - - spherical = torch.stack((theta,phi),dim=1) - cartesian = torch.stack((x,y,z),dim=1) - return cartesian, spherical - -def sampleSphere_Layering(rows,cols=None): - N_phi = rows - - # Make rows match the number of phi locations - N = 4/pi*N_phi*N_phi - - # Alternative method presented in https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf - a = 4*pi/N - d = sqrt(a) - M_phi = int(round(pi/d)) - d_phi = pi/M_phi - d_theta = a/d_phi - - theta_list = [] - phi_list = [] - - for m in range(M_phi): - phi = pi*(m+0.5)/M_phi - M_theta = int(round(2*pi*sin(phi)/d_theta)) - for n in range(M_theta): - theta = 2*pi*n/M_theta - - phi_list.append(phi) - theta_list.append(theta) - - 
def sampleSphere_Layering(rows, cols=None):
    """Near-uniform sphere sampling by stacking latitude rings.

    Equal-area layering scheme from
    https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf
    `rows` sets the number of phi (latitude) locations; `cols` is unused.
    """
    N_phi = rows
    N = 4 / pi * N_phi * N_phi  # total target count so rows matches phi resolution

    a = 4 * pi / N              # target area per sample
    d = sqrt(a)
    M_phi = int(round(pi / d))
    d_phi = pi / M_phi
    d_theta = a / d_phi

    theta_list = []
    phi_list = []
    for m in range(M_phi):
        phi_m = pi * (m + 0.5) / M_phi
        M_theta = int(round(2 * pi * sin(phi_m) / d_theta))
        for n in range(M_theta):
            phi_list.append(phi_m)
            theta_list.append(2 * pi * n / M_theta)

    theta = torch.tensor(theta_list, dtype=torch.float)
    phi = torch.tensor(phi_list, dtype=torch.float)
    x, y, z = spherical2xyz(theta, phi)
    return torch.stack((x, y, z), dim=1), torch.stack((theta, phi), dim=1)


def sampleSphere_Spiral(rows, cols):
    """Fibonacci-spiral sphere sampling.

    Visualization at https://www.youtube.com/watch?v=Ua0kig6N3po
    """
    # Enough points to represent the image resolution.
    N = int(2 / pi * cols * rows)

    goldenRatio = (1 + 5 ** 0.5) / 2
    i = torch.arange(0, N)
    theta = (2 * pi * i / goldenRatio) % (2 * pi)
    phi = torch.arccos(1 - 2 * (i + 0.5) / N)
    x, y, z = spherical2xyz(theta, phi)
    return torch.stack((x, y, z), dim=1), torch.stack((theta, phi), dim=1)


def sampleSphere_Icosphere(rows, cols=None):
    """Sphere sampling from a subdivided icosahedron sized to ~`rows` resolution."""
    num_subdiv = int(np.floor(np.log2(rows)) - 1)
    cartesian = icosphere(num_subdiv)

    # Back out spherical angles (elevation measured down from the Z axis).
    xy = cartesian[:, 0] ** 2 + cartesian[:, 1] ** 2
    phi = np.arctan2(np.sqrt(xy), cartesian[:, 2])
    theta = np.arctan2(cartesian[:, 1], cartesian[:, 0]) + np.pi

    phi = torch.tensor(phi, dtype=torch.float)
    theta = torch.tensor(theta, dtype=torch.float)
    x, y, z = spherical2xyz(theta, phi)  # re-project for axis consistency
    return torch.stack((x, y, z), dim=1), torch.stack((theta, phi), dim=1)
def sampleSphere_Random(rows, cols):
    """Uniform random sphere sampling sized to the image resolution.

    Equal-area method from https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf
    """
    N = int(2 / pi * cols * rows)

    z = 2 * torch.rand(N) - 1
    phi = 2 * pi * torch.rand(N)
    x = torch.sqrt(1 - z ** 2) * torch.cos(phi)
    y = torch.sqrt(1 - z ** 2) * torch.sin(phi)

    # Convert to spherical angles (elevation measured down from the Z axis).
    xy = x ** 2 + y ** 2
    phi = torch.atan2(torch.sqrt(xy), z)
    theta = torch.atan2(y, x) + pi

    x, y, z = spherical2xyz(theta, phi)  # re-project for axis consistency
    return torch.stack((x, y, z), dim=1), torch.stack((theta, phi), dim=1)


#### Icosphere methods ####
# Adapted from https://yaweiliu.github.io/research_notes/notes/20210301_Creating%20an%20icosphere%20with%20Python.html

def vertex(x, y, z):
    """Project (x, y, z) onto the unit sphere; returns a list of 3 floats."""
    length = np.sqrt(x ** 2 + y ** 2 + z ** 2)
    return [c / length for c in (x, y, z)]


def middle_point(verts, middle_point_cache, point_1, point_2):
    """Midpoint of an edge, projected to the unit sphere, with caching.

    The cache is keyed on the sorted endpoint indices so each edge is only
    cut once; returns the index of the (possibly pre-existing) vertex.
    """
    lo = min(point_1, point_2)
    hi = max(point_1, point_2)
    key = '{0}-{1}'.format(lo, hi)
    if key in middle_point_cache:
        return middle_point_cache[key]

    midpoint = [sum(pair) / 2 for pair in zip(verts[point_1], verts[point_2])]
    verts.append(vertex(*midpoint))
    index = len(verts) - 1
    middle_point_cache[key] = index
    return index


def icosphere(subdiv):
    """Vertices of a unit icosphere after `subdiv` triangle subdivisions."""
    # Icosahedron vertices via the golden-ratio construction.
    r = (1.0 + np.sqrt(5.0)) / 2.0
    verts = np.array([[-1.0, r, 0.0], [1.0, r, 0.0], [-1.0, -r, 0.0],
                      [1.0, -r, 0.0], [0.0, -1.0, r], [0.0, 1.0, r],
                      [0.0, -1.0, -r], [0.0, 1.0, -r], [r, 0.0, -1.0],
                      [r, 0.0, 1.0], [-r, 0.0, -1.0], [-r, 0.0, 1.0]])
    # Rescale to unit radius and fix the orientation.
    verts /= np.linalg.norm(verts[0])
    rot = R.from_quat([[0.19322862, -0.68019314, -0.19322862, 0.68019314]])
    verts = rot.apply(verts)
    verts = list(verts)

    faces = [[0, 11, 5], [0, 5, 1], [0, 1, 7], [0, 7, 10],
             [0, 10, 11], [1, 5, 9], [5, 11, 4], [11, 10, 2],
             [10, 7, 6], [7, 1, 8], [3, 9, 4], [3, 4, 2],
             [3, 2, 6], [3, 6, 8], [3, 8, 9], [5, 4, 9],
             [2, 4, 11], [6, 2, 10], [8, 6, 7], [9, 8, 1]]

    for _ in range(subdiv):
        middle_point_cache = {}
        faces_subdiv = []
        for tri in faces:
            v1 = middle_point(verts, middle_point_cache, tri[0], tri[1])
            v2 = middle_point(verts, middle_point_cache, tri[1], tri[2])
            v3 = middle_point(verts, middle_point_cache, tri[2], tri[0])
            faces_subdiv.append([tri[0], v1, v3])
            faces_subdiv.append([tri[1], v2, v1])
            faces_subdiv.append([tri[2], v3, v2])
            faces_subdiv.append([v1, v2, v3])
        faces = faces_subdiv

    return np.array(verts)
def splat_save(positions, scales, rots, colors, output_path):
    """Serialize splat arrays to a .splat file via the splatio helper.

    NOTE(review): unlike splat_update_and_save, this call passes no opacity
    argument to numpy_to_splat — confirm the helper supports both arities.
    """
    splatio.numpy_to_splat(positions, scales, rots, colors, output_path)
    return
torch.Tensor(positions) - samples = torchPoints.size()[0] - - y = torchPoints - x = torchPoints - assign_index = knn(x, y, neighbors) - - indexSrc = assign_index[0:1][0] - indexSrcTrn = indexSrc.reshape(-1,1) - index = indexSrcTrn.expand(indexSrc.size()[0],3) - - src = x[assign_index[1:2]][0] - out = src.new_zeros(src.size()) - out = scatter_mean(src, index, 0, out=out) - - diffvector = y - out[0:samples] - - diffvectorNum =diffvector.numpy() - - normals = torch.from_numpy(diffvectorNum) - - pos3D = torch.from_numpy(positions) - - torchColors = torch.Tensor(colors) - torchColors.clamp(0,1) - - return pos3D, normals, colors, scales, rots - - - -def splat_downsampler(pos3D, colors, scaleV): - - pointsNP = pos3D.numpy() - colorsNP = colors.numpy() - - #cluster - kmeans = KMeans(n_clusters=scaleV*256, random_state=0, n_init="auto").fit(pointsNP) - - centers = kmeans.cluster_centers_ - labels = kmeans.labels_ - - - x = np.indices(pointsNP.shape)[0] - y = np.indices(pointsNP.shape)[1] - z = np.indices(pointsNP.shape)[2] - col = pointsNP.flatten() - - # 3D Plot - fig = plt.figure() - ax3D = fig.add_subplot(projection='3d') - cm = plt.colormaps['brg'] - p3d = ax3D.scatter(x, y, z, c=col, cmap=cm) - plt.colorbar(p3d) - plt.colorbar(p3d) - - - - return pos3D, colors - - - -def splat_unpacker_threshold(neighbors, fileName, threshold): - - #check if .ply or .splat - if fileName[-3:] == 'ply': - positionsFile, scalesFile, rotsFile, colorsFile = splatio.ply_to_numpy(fileName) - fileType = 'ply' - else: - positionsFile, scalesFile, rotsFile, colorsFile = splatio.splat_to_numpy(fileName) - fileType = 'splat' - - - - positionsNP = positionsFile.copy() - scalesNP = scalesFile.copy() - rotsNP = rotsFile.copy() - colorsNP = colorsFile.copy() - - - #Get the raw PLY data into tensors - pos3D = torch.from_numpy(positionsNP) - colors = torch.from_numpy(colorsNP) - colors.clamp(0,1) - rots = torch.from_numpy(rotsNP) - scales = torch.from_numpy(scalesNP) - points = 
torch.from_numpy(positionsNP) - - samples = points.size()[0] - y = points - x = points - - #knn find midpoint - assign_index = knn(x, y, neighbors) - indexSrc = assign_index[0:1][0] - indexSrcTrn = indexSrc.reshape(-1,1) - index = indexSrcTrn.expand(indexSrc.size()[0],3) - - src = x[assign_index[1:2]][0] - out = src.new_zeros(src.size()) - out = scatter_mean(src, index, 0, out=out) - - #calculate normals - normals = y - out[0:samples] - - - #use Euclidian distances to create a mask - #distances = torch.sqrt((out[0:samples][:,0])**2 +(out[0:samples][:,1])**2+(out[0:samples][:,2])**2 ) - distances = torch.sqrt((y[:,0] - out[0:samples][:,0])**2 +(y[:,1] - out[0:samples][:,1])**2+(y[:,2] - out[0:samples][:,2])**2 ) - threshold = np.clip(threshold, 0, 100) - boundry = np.percentile(distances, threshold) - mask = distances < boundry - - - pos3DFiltered = pos3D[mask] - normalsFiltered = normals[mask] - colorsFiltered = colors[mask] - scalesFiltered = scales[mask] - rotsFiltered = rots[mask] - - return pos3DFiltered, normalsFiltered, colorsFiltered[:,0:3], colorsFiltered[:,3], scalesFiltered, rotsFiltered, fileType - - -def splat_unpacker_threshold_graph_normals(neighbors, fileName, threshold): - - coneSize = 1 - - positionsNP, scalesNP, rotsNP, colorsNP = splatio.ply_to_numpy(fileName) - - #Get the raw PLY data into tensors - pos3D = torch.from_numpy(positionsNP) - colors = torch.from_numpy(colorsNP) - colors.clamp(0,1) - rots = torch.from_numpy(rotsNP) - scales = torch.from_numpy(scalesNP) - points = torch.from_numpy(positionsNP) - - samples = points.size()[0] - y = points - x = points - - #knn find midpoint - assign_index = knn(x, y, neighbors) - indexSrc = assign_index[0:1][0] - indexSrcTrn = indexSrc.reshape(-1,1) - index = indexSrcTrn.expand(indexSrc.size()[0],3) - - src = x[assign_index[1:2]][0] - out = src.new_zeros(src.size()) - out = scatter_mean(src, index, 0, out=out) - - #calculate normals - normals = y - out[0:samples] - - - #use Euclidian distances to create 
a mask - #distances = torch.sqrt((out[0:samples][:,0])**2 +(out[0:samples][:,1])**2+(out[0:samples][:,2])**2 ) - distances = torch.sqrt((y[:,0] - out[0:samples][:,0])**2 +(y[:,1] - out[0:samples][:,1])**2+(y[:,2] - out[0:samples][:,2])**2 ) - threshold = np.clip(threshold, 0, 100) - boundry = np.percentile(distances, threshold) - mask = distances < boundry - - - pos3DFiltered = pos3D[mask] - normalsFiltered = normals[mask] - colorsFiltered = colors[mask] - scalesFiltered = scales[mask] - rotsFiltered = rots[mask] - - - - diffvector = y - out[0:samples] - - diffvectorNum =diffvector.numpy() - diffNum = distances.numpy() - resultNum = out[0:samples].numpy() - # Normalised [0,1] - diffNumNorm = (diffNum - np.min(diffNum))/np.ptp(diffNum) - - ''' - #point cloud - marker_data = go.Scatter3d( - x=points[:, 0], - y=points[:, 2], - z=-points[:, 1], - marker=go.scatter3d.Marker(size=3, color= diffNumNorm), - opacity=0.8, - mode='markers' - ) - fig=go.Figure(data=marker_data) - fig.show() - ''' - #normals - fig = go.Figure(data=go.Cone( - x=points[:, 0], - y=points[:, 2], - z=-points[:, 1], - u=diffvectorNum[:, 0], - v=diffvectorNum[:, 2], - w=-diffvectorNum[:, 1], - sizemode="absolute", - sizeref=coneSize, - anchor="tail")) - - fig.update_layout( - scene=dict(domain_x=[0, 1], - camera_eye=dict(x=-1.57, y=1.36, z=0.58))) - - fig.show() - - - return pos3DFiltered, normalsFiltered, colorsFiltered, scalesFiltered, rotsFiltered \ No newline at end of file diff --git a/FastSplatStyler_huggingface/splat_mesh_helpers.py b/FastSplatStyler_huggingface/splat_mesh_helpers.py deleted file mode 100644 index 7c6e4708221a633a0615f05d5e7e6e0f1f0c4f07..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/splat_mesh_helpers.py +++ /dev/null @@ -1,687 +0,0 @@ -import torch -import torch.utils -import pyntcloud_io as plyio -from torch_geometric.nn import knn -import numpy as np -from torch_scatter import scatter_mean -import plyio as splatio -import pandas as pd -import 
plotly.graph_objects as go -from sklearn.cluster import KMeans -import matplotlib.pyplot as plt - - - -def splat_update_and_save(rawdata, results, fileName): - - - #df1[['f_dc_0','f_dc_1','f_dc_2']] = results - #plyNumpy = df1.to_numpy() - #plyNumpy.tofile('Scaniverse_chair_testOut2.ply') - - #plyNumpy = df1.to_numpy() - #np.savetxt('Scaniverse_chair_test.csv', plyNumpy, delimiter=',') - - #SH_C0 = 0.28209479177387814 - #normedResults = 0.5 - results/SH_C0 - - #df1[['f_dc_0','f_dc_1','f_dc_2']] = normedResults - - - #plyNumpy = df1.to_numpy() - #np.savetxt('Scaniverse_chair_testOut.csv', plyNumpy, delimiter=',') - #plyNumpy = df1.to_numpy() - #plyNumpy.tofile('Scaniverse_chair_testOutNormed.ply') - - - outFileName='Scaniverse_chair_outputFloat.ply' - df1 = pd.DataFrame(rawdata, columns=['x','y','z', - 'nx','ny','nz', - 'f_dc_0','f_dc_1','f_dc_2', - 'opacity', - 'scale_0','scale_1','scale_2', - 'rot_0','rot_1','rot_2','rot_3' - ]) - - plyio.write_ply_float(filename=outFileName, points= df1, mesh=None, as_text=False) - - - outFilePath='C:\\Users\\Home\\Documents\\Thesis\\IntSelcConv_with_pyVista\\Scaniverse_out.splat' - opacity = rawdata[:,-1].reshape(224713, 1) - positions = rawdata[:,0:3] - scales = rawdata[:,3:6] - rots = rawdata[:,6:10] - colors = results - - splatio.numpy_to_splat(positions, scales, rots, colors, opacity, outFilePath) - - #save as csv - normalized = False - #normalize the color values if needed - if (normalized): - SH_C0 = 0.28209479177387814 - colors = colors/SH_C0-0.5 - opacity = -np.log(1/(opacity-1)) - - colors = np.concatenate((colors, opacity), axis=1) - splat = np.concatenate((positions, scales, rots, colors), axis=1) - #np.savetxt('C:\\Users\\Home\\Documents\\Thesis\\IntSelcConv_with_pyVista\\Scaniverse_out.csv', splat, delimiter=',', header=splat.dtype.names) - np.savetxt('C:\\Users\\Home\\Documents\\Thesis\\IntSelcConv_with_pyVista\\Scaniverse_out.csv', splat, delimiter=',') - - return - - -def splat_save(positions, scales, rots, 
colors, output_path, fileType): - - splatio.numpy_to_splat(positions, scales, rots, colors, output_path, fileType) - - return - - - -def splat_randomsampler(pos3D): - - pointsNP = pos3D.numpy().copy() - np.random.shuffle(pointsNP) - scaler = int(pointsNP.shape[0]/2) - valuesNP = pointsNP[:scaler] - values64 = torch.from_numpy(valuesNP) - values = values64.to(torch.float32) - return values - - -def splat_GaussianSuperSampler(pos3D_fr, colors_fr, opacity_fr, scales_fr, rots_fr, GaussianSamples): - - pos3D = pos3D_fr.clone() - colors = colors_fr.clone() - opacity = opacity_fr.clone() - scales = scales_fr.clone() - rots = rots_fr.clone() - - #get all the splats' relative sizes/densities and opacities - importance = (torch.absolute(scales[:,0])+torch.absolute(scales[:,1])+torch.absolute(scales[:,2]))/scales.max()+torch.absolute(opacity)/opacity.max() - importanceConstant = importance.sum().to(torch.int32) - - increaseNormalizer = GaussianSamples/importanceConstant.item() - copiesToMakeQuantity = (importance*increaseNormalizer).to(torch.int32) - #create duplicates of all the larger Gaussians, create matched tensors for all the other attributes too - duplicatePoints = torch.from_numpy(np.repeat(pos3D.numpy(), copiesToMakeQuantity.numpy(), axis=0)) - duplicateColors = torch.from_numpy(np.repeat(colors.numpy(), copiesToMakeQuantity.numpy(), axis=0)) - duplicateRots = torch.from_numpy(np.repeat(rots.numpy(), copiesToMakeQuantity.numpy(), axis=0)) - duplicateScales = torch.from_numpy(np.repeat(scales.numpy(), copiesToMakeQuantity.numpy(), axis=0)) - - #create a copy for saving the rotation results - duplicateScalesRotated = duplicateScales.clone() - - q0 = duplicateRots[:,0] - q1 = duplicateRots[:,1] - q2 = duplicateRots[:,2] - q3 = duplicateRots[:,3] - x00 = q0**2 + q1**2 - q2**2 - q3**2 - x01 = 2*q1*q2 - 2*q0*q3 - x02 = 2*q1*q3 + 2*q0*q2 - x10 = 2*q1*q2 + 2*q0*q3 - x11 = q0**2 - q1**2 + q2**2 - q3**2 - x12 = 2*q2*q3 - 2*q0*q1 - x20 = 2*q1*q3 - 2*q0*q2 - x21 = 2*q2*q3 + 
2*q0*q1 - x22 = q0**2 - q1**2 - q2**2 + q3**2 - - X0 = torch.stack((x00, x01, x02), dim=1) - X1 = torch.stack((x10, x11, x12), dim=1) - X2 = torch.stack((x20, x21, x22), dim=1) - - - R_matrix = torch.stack((X0, X1, X2), dim=1) - duplicateScalesRotated = torch.bmm(R_matrix, duplicateScales.unsqueeze(2)).squeeze(2) - #torch.normal can't handle negative numbers, store the signs and add them back afterwards - duplicateScalesRotatedSigns = duplicateScalesRotated.sign() - duplicateScalesRotatedABS = torch.absolute(duplicateScalesRotated) - - GaussianPointsNoSignScale = torch.normal(duplicatePoints, duplicateScalesRotatedABS) - #correct for signs - GaussianPoints = duplicatePoints + (GaussianPointsNoSignScale-duplicatePoints)*duplicateScalesRotatedSigns - - pos3D = torch.cat((pos3D_fr, GaussianPoints - ), dim=0) - colors = torch.cat((colors_fr,duplicateColors - ), dim=0) - return pos3D, colors.to(torch.float32) - - - - -def splat_unpacker_with_threshold(neighbors, fileName, threshold): - - #check if .ply or .splat or gen - if fileName[-3:] == 'ply': - positionsFile, scalesFile, rotsFile, colorsFile = splatio.ply_to_numpy(fileName) - fileType = 'ply' - elif fileName[-3:] == 'gen': - resolution = 100*5 - radius = 1 - indices = np.arange(0, resolution, dtype=float) + 0.5 - phi = np.arccos(1 - 2*indices/resolution) - theta = np.pi * (1 + 5**0.5) * indices - - x = radius * np.sin(phi) * np.cos(theta) - y = radius * np.sin(phi) * np.sin(theta) - z = radius * np.cos(phi) - - pos3DFiltered = torch.from_numpy(np.transpose(np.stack([x,y,z]).astype(np.float32))) - normalsFiltered = torch.full_like(pos3DFiltered, 0.0) - opacity = torch.full_like(pos3DFiltered, 0.5) - #colorsFiltered = torch.full_like(pos3DFiltered, 0.5) - scalesFiltered = torch.full_like(pos3DFiltered, 0.015) - rotsFiltered = torch.full((resolution, 4), 0.0) - rotsFiltered[:,0] = 1.0 - fileType = 'splat' - - #make is stripy - c1 = torch.tensor([0.9, 0.9, 0.9], dtype=torch.float32) - c2 = torch.tensor([0.9, 0.0, 0.0], 
dtype=torch.float32) - c3 = torch.tensor([0.0, 0.9, 0.0], dtype=torch.float32) - c4 = torch.tensor([0.0, 0.0, 0.9], dtype=torch.float32) - c5 = torch.tensor([0.9, 0.9, 0.0], dtype=torch.float32) - - C1 = torch.from_numpy(np.tile(c1.numpy(), (int(resolution/5),1))) - C2 = torch.from_numpy(np.tile(c2.numpy(), (int(resolution/5),1))) - C3 = torch.from_numpy(np.tile(c3.numpy(), (int(resolution/5),1))) - C4 = torch.from_numpy(np.tile(c4.numpy(), (int(resolution/5),1))) - C5 = torch.from_numpy(np.tile(c5.numpy(), (int(resolution/5),1))) - - colorsFiltered = torch.cat((C1, C2, C3, C4, C5), dim=0) - - addnoise = True - if (addnoise): - noisePoints = torch.from_numpy(np.random.uniform(-1.5, 1.5, (50,3)).astype(np.float32)) - pos3DFiltered = torch.cat((noisePoints, pos3DFiltered - ), dim=0) - - noiseColors = torch.from_numpy(np.random.uniform(0, 9.99, (50,3)).astype(np.float32)) - colorsFiltered = torch.cat((noiseColors, colorsFiltered - ), dim=0) - - noiseNormals = torch.full_like(noisePoints, 0.0) - normalsFiltered = torch.cat((noiseNormals, normalsFiltered - ), dim=0) - - noiseScales = torch.full_like(noisePoints, 0.015) - scalesFiltered = torch.cat((noiseScales, scalesFiltered - ), dim=0) - - noiseopacity = torch.full_like(noisePoints, 0.5) - opacity = torch.cat((noiseopacity, opacity - ), dim=0) - - noiseScales = torch.full((50, 4), 0.0) - rotsFiltered = torch.cat((noiseScales, rotsFiltered - ), dim=0) - rotsFiltered[:,0] = 1.0 - - - return pos3DFiltered, normalsFiltered, colorsFiltered, opacity[:,0], scalesFiltered, rotsFiltered, fileType - elif fileName[-4:] == 'gen2': - - - z = torch.tensor([0.9, 0.9, 0.9, 0.9, 0.9], dtype=torch.float32) - y = torch.tensor([1.7, 1.4, 1.0, 0.5, 0.0], dtype=torch.float32) - x = torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5], dtype=torch.float32) - - - pos3DFiltered = torch.from_numpy(np.transpose(np.stack([x,y,z]).astype(np.float32))) - normalsFiltered = torch.full_like(pos3DFiltered, 0.0) - opacity = torch.full_like(pos3DFiltered, 0.99) - 
#colorsFiltered = torch.full_like(pos3DFiltered, 0.5) - #scalesFiltered = torch.full_like(pos3DFiltered, 0.02) - rotsFiltered = torch.full((5, 4), 0.0) - rotsFiltered[:,0] = 1.0 - fileType = 'splat' - - scalesFiltered = torch.tensor([[0.03, 0.03, 0.01], - [0.03, 0.03, 0.05], - [0.07, 0.07, 0.03], - [0.05, 0.05, 0.09], - [0.09, 0.12, 0.12,] - ], dtype= torch.float32) - - - #each Gaussian has its own color - c1 = torch.tensor([0.1, 0.2, 0.5, 0.7, 0.9], dtype=torch.float32) - c2 = torch.tensor([0.1, 0.1, 0.1, 0.1, 0.1], dtype=torch.float32) - c3 = torch.tensor([0.9, 0.7, 0.5, 0.2, 0.1], dtype=torch.float32) - - - colorsFiltered = torch.from_numpy(np.transpose(np.stack([c1,c2,c3]).astype(np.float32))) - - - - return pos3DFiltered, normalsFiltered, colorsFiltered, opacity[:,0], scalesFiltered, rotsFiltered, fileType - else: - positionsFile, scalesFile, rotsFile, colorsFile = splatio.splat_to_numpy(fileName) - fileType = 'splat' - - - positionsNP = positionsFile.copy() - scalesNP = scalesFile.copy() - rotsNP = rotsFile.copy() - colorsNP = colorsFile.copy() - - - #Get the raw PLY data into tensors - pos3D = torch.from_numpy(positionsNP) - colors = torch.from_numpy(colorsNP) - colors.clamp(0,1) - rots = torch.from_numpy(rotsNP) - scales = torch.from_numpy(scalesNP) - points = torch.from_numpy(positionsNP) - - samples = points.size()[0] - y = points - x = points - - #knn find midpoint - assign_index = knn(x, y, neighbors) - indexSrc = assign_index[0:1][0] - indexSrcTrn = indexSrc.reshape(-1,1) - index = indexSrcTrn.expand(indexSrc.size()[0],3) - - src = x[assign_index[1:2]][0] - out = src.new_zeros(src.size()) - out = scatter_mean(src, index, 0, out=out) - - #calculate normals - normals = y - out[0:samples] - - if threshold == 100: - - pos3DFiltered = pos3D - normalsFiltered = normals - colorsFiltered = colors - scalesFiltered = scales - rotsFiltered = rots - - else: - - #use Euclidian distances to create a mask - #distances = torch.sqrt((out[0:samples][:,0])**2 
+(out[0:samples][:,1])**2+(out[0:samples][:,2])**2 ) - distances = torch.sqrt((y[:,0] - out[0:samples][:,0])**2 +(y[:,1] - out[0:samples][:,1])**2+(y[:,2] - out[0:samples][:,2])**2 ) - threshold = np.clip(threshold, 0, 100) - boundry = np.percentile(distances, threshold) - mask = distances < boundry - - - pos3DFiltered = pos3D[mask] - normalsFiltered = normals[mask] - colorsFiltered = colors[mask] - scalesFiltered = scales[mask] - rotsFiltered = rots[mask] - - - displayNormals = False - if (displayNormals): - - coneSize = 0.1 - diffvectorNum =normals.numpy() - - showBoth = False - if (showBoth): - - - diffBoth = torch.cat((normals, normalsFiltered), dim=0).numpy().astype(np.float16) - pos3dShifted = pos3DFiltered - pos3dShifted[:,2] = pos3dShifted[:,2] + 1 - posBoth = torch.cat((pos3D, pos3dShifted), dim=0).numpy().astype(np.float16) - - distancesMasked = distances[mask] - diffBothNum = torch.cat((distances, distancesMasked), dim=0).numpy().astype(np.float16) - - # Normalised [0,1] - diffNumNorm = (diffBothNum - np.min(diffBothNum))/np.ptp(diffBothNum) - - #point cloud - marker_data = go.Scatter3d( - x=posBoth[:, 0], - y=posBoth[:, 2], - z=-posBoth[:, 1], - marker=go.scatter3d.Marker(size=1, color= diffNumNorm), - opacity=0.8, - mode='markers' - ) - fig=go.Figure(data=marker_data) - fig.show() - - - - ''' - #Too high memory usage :( - diffBoth = torch.cat((normals, normalsFiltered), dim=0).numpy().astype(np.float16) - pos3dShifted = pos3DFiltered - pos3dShifted[:,2] = pos3dShifted[:,2] + 1 - posBoth = torch.cat((pos3D, pos3dShifted), dim=0).numpy().astype(np.float16) - - - #normals - fig = go.Figure(data=go.Cone( - x=posBoth[:, 0], - y=posBoth[:, 2], - z=-posBoth[:, 1], - u=diffBoth[:, 0], - v=diffBoth[:, 2], - w=-diffBoth[:, 1], - sizemode="raw", - sizeref=coneSize, - anchor="tail")) - - fig.update_layout( - scene=dict(domain_x=[0, 1], - camera_eye=dict(x=-1.57, y=1.36, z=0.58))) - - fig.show() - - exit() - ''' - - - else: - - diffNum = distances.numpy() - # 
Normalised [0,1] - diffNumNorm = (diffNum - np.min(diffNum))/np.ptp(diffNum) - - #point cloud - marker_data = go.Scatter3d( - x=pos3DFiltered[:, 0], - y=pos3DFiltered[:, 2], - z=-pos3DFiltered[:, 1], - marker=go.scatter3d.Marker(size=1, color= diffNumNorm), - opacity=0.8, - mode='markers' - ) - fig=go.Figure(data=marker_data) - fig.show() - - - - #normals - fig = go.Figure(data=go.Cone( - x=pos3D[:, 0], - y=pos3D[:, 2], - z=-pos3D[:, 1], - u=diffvectorNum[:, 0], - v=diffvectorNum[:, 2], - w=-diffvectorNum[:, 1], - sizemode="absolute", - sizeref=coneSize, - anchor="tail")) - - fig.update_layout( - scene=dict(domain_x=[0, 1], - camera_eye=dict(x=-1.57, y=1.36, z=0.58))) - - fig.show() - - - return pos3DFiltered, normalsFiltered, colorsFiltered[:,0:3], colorsFiltered[:,3], scalesFiltered, rotsFiltered, fileType - - - -def generate_with_noise_ablation(neighbors, fileName, threshold): - - resolution = 100*5 - radius = 1 - indices = np.arange(0, resolution, dtype=float) + 0.5 - phi = np.arccos(1 - 2*indices/resolution) - theta = np.pi * (1 + 5**0.5) * indices - - x = radius * np.sin(phi) * np.cos(theta) - y = radius * np.sin(phi) * np.sin(theta) - z = radius * np.cos(phi) - - pos3D = torch.from_numpy(np.transpose(np.stack([x,y,z]).astype(np.float32))) - normalsFiltered = torch.full_like(pos3D, 0.0) - opacity = torch.full_like(pos3D, 0.5) - #colorsFiltered = torch.full_like(pos3DFiltered, 0.5) - scales = torch.full_like(pos3D, 0.015) - rots = torch.full((resolution, 4), 0.0) - rots[:,0] = 1.0 - fileType = 'splat' - - #make is stripy - c1 = torch.tensor([0.9, 0.9, 0.9], dtype=torch.float32) - c2 = torch.tensor([0.9, 0.0, 0.0], dtype=torch.float32) - c3 = torch.tensor([0.0, 0.9, 0.0], dtype=torch.float32) - c4 = torch.tensor([0.0, 0.0, 0.9], dtype=torch.float32) - c5 = torch.tensor([0.9, 0.9, 0.0], dtype=torch.float32) - - C1 = torch.from_numpy(np.tile(c1.numpy(), (int(resolution/5),1))) - C2 = torch.from_numpy(np.tile(c2.numpy(), (int(resolution/5),1))) - C3 = 
torch.from_numpy(np.tile(c3.numpy(), (int(resolution/5),1))) - C4 = torch.from_numpy(np.tile(c4.numpy(), (int(resolution/5),1))) - C5 = torch.from_numpy(np.tile(c5.numpy(), (int(resolution/5),1))) - - colors = torch.cat((C1, C2, C3, C4, C5), dim=0) - - addnoise = True - if (addnoise): - noisePoints = torch.from_numpy(np.random.uniform(-1.01, 1.01, (50,3)).astype(np.float32)) - pos3D = torch.cat((noisePoints, pos3D - ), dim=0) - - noiseColors = torch.from_numpy(np.random.uniform(0, 9.99, (50,3)).astype(np.float32)) - colors = torch.cat((noiseColors, colors - ), dim=0) - - noiseScales = torch.full_like(noisePoints, 0.015) - scales = torch.cat((noiseScales, scales - ), dim=0) - - noiseopacity = torch.full_like(noisePoints, 0.5) - opacity = torch.cat((noiseopacity, opacity - ), dim=0) - - noiseRots = torch.full((50, 4), 0.0) - rots = torch.cat((noiseRots, rots - ), dim=0) - rots[:,0] = 1.0 - - points = pos3D - - samples = points.size()[0] - y = points - x = points - - #knn find midpoint - assign_index = knn(x, y, neighbors) - indexSrc = assign_index[0:1][0] - indexSrcTrn = indexSrc.reshape(-1,1) - index = indexSrcTrn.expand(indexSrc.size()[0],3) - - src = x[assign_index[1:2]][0] - out = src.new_zeros(src.size()) - out = scatter_mean(src, index, 0, out=out) - - #calculate normals - normals = y - out[0:samples] - - if threshold == 100: - - pos3DFiltered = pos3D - normalsFiltered = normals - colorsFiltered = colors - scalesFiltered = scales - rotsFiltered = rots - opacity = opacity - - else: - - #use Euclidian distances to create a mask - #distances = torch.sqrt((out[0:samples][:,0])**2 +(out[0:samples][:,1])**2+(out[0:samples][:,2])**2 ) - distances = torch.sqrt((y[:,0] - out[0:samples][:,0])**2 +(y[:,1] - out[0:samples][:,1])**2+(y[:,2] - out[0:samples][:,2])**2 ) - threshold = np.clip(threshold, 0, 100) - boundry = np.percentile(distances, threshold) - mask = distances < boundry - - - pos3DFiltered = pos3D[mask] - normalsFiltered = normals[mask] - colorsFiltered = 
colors[mask] - scalesFiltered = scales[mask] - rotsFiltered = rots[mask] - opacity = opacity[mask] - - displayNormals = True - if (displayNormals): - - coneSize = 1 - diffvectorNum = normals.numpy() - diffvectorfilteredNum = normalsFiltered.numpy() - diffNum = distances.numpy() - - # Normalised [0,1] - diffNumNorm = (diffNum - np.min(diffNum))/np.ptp(diffNum) - - #point cloud - marker_data = go.Scatter3d( - x=pos3D[:, 0], - y=pos3D[:, 2], - z=-pos3D[:, 1], - marker=go.scatter3d.Marker(size=5, color= diffNumNorm), - opacity=0.8, - mode='markers' - ) - fig=go.Figure(data=marker_data) - fig.show() - - #normals - fig2 = go.Figure(data=go.Cone( - x=pos3D[:, 0], - y=pos3D[:, 2], - z=-pos3D[:, 1], - u=diffvectorNum[:, 0], - v=diffvectorNum[:, 2], - w=-diffvectorNum[:, 1], - sizemode="raw", - sizeref=coneSize, - anchor="tail")) - - fig2.update_layout( - scene=dict(domain_x=[0, 1], - camera_eye=dict(x=-1.57, y=1.36, z=0.58))) - - fig2.show() - - - #normals Filtered - fig3 = go.Figure(data=go.Cone( - x=pos3DFiltered[:, 0], - y=pos3DFiltered[:, 2], - z=-pos3DFiltered[:, 1], - u=diffvectorfilteredNum[:, 0], - v=diffvectorfilteredNum[:, 2], - w=-diffvectorfilteredNum[:, 1], - sizemode="raw", - sizeref=coneSize, - anchor="tail")) - - fig3.update_layout( - scene=dict(domain_x=[0, 1], - camera_eye=dict(x=-1.57, y=1.36, z=0.58))) - - fig3.show() - - - return pos3DFiltered, normalsFiltered, colorsFiltered, opacity[:,0], scalesFiltered, rotsFiltered, fileType - - - - -def splat_unpacker_threshold_graph_normals(neighbors, fileName, threshold): - - coneSize = 1 - - positionsNP, scalesNP, rotsNP, colorsNP = splatio.ply_to_numpy(fileName) - - #Get the raw PLY data into tensors - pos3D = torch.from_numpy(positionsNP) - colors = torch.from_numpy(colorsNP) - colors.clamp(0,1) - rots = torch.from_numpy(rotsNP) - scales = torch.from_numpy(scalesNP) - points = torch.from_numpy(positionsNP) - - samples = points.size()[0] - y = points - x = points - - #knn find midpoint - assign_index = knn(x, 
y, neighbors) - indexSrc = assign_index[0:1][0] - indexSrcTrn = indexSrc.reshape(-1,1) - index = indexSrcTrn.expand(indexSrc.size()[0],3) - - src = x[assign_index[1:2]][0] - out = src.new_zeros(src.size()) - out = scatter_mean(src, index, 0, out=out) - - #calculate normals - normals = y - out[0:samples] - - - #use Euclidian distances to create a mask - #distances = torch.sqrt((out[0:samples][:,0])**2 +(out[0:samples][:,1])**2+(out[0:samples][:,2])**2 ) - distances = torch.sqrt((y[:,0] - out[0:samples][:,0])**2 +(y[:,1] - out[0:samples][:,1])**2+(y[:,2] - out[0:samples][:,2])**2 ) - threshold = np.clip(threshold, 0, 100) - boundry = np.percentile(distances, threshold) - mask = distances < boundry - - - pos3DFiltered = pos3D[mask] - normalsFiltered = normals[mask] - colorsFiltered = colors[mask] - scalesFiltered = scales[mask] - rotsFiltered = rots[mask] - - - - diffvector = y - out[0:samples] - - diffvectorNum =diffvector.numpy() - diffNum = distances.numpy() - resultNum = out[0:samples].numpy() - # Normalised [0,1] - diffNumNorm = (diffNum - np.min(diffNum))/np.ptp(diffNum) - - ''' - #point cloud - marker_data = go.Scatter3d( - x=points[:, 0], - y=points[:, 2], - z=-points[:, 1], - marker=go.scatter3d.Marker(size=3, color= diffNumNorm), - opacity=0.8, - mode='markers' - ) - fig=go.Figure(data=marker_data) - fig.show() - ''' - #normals - fig = go.Figure(data=go.Cone( - x=points[:, 0], - y=points[:, 2], - z=-points[:, 1], - u=diffvectorNum[:, 0], - v=diffvectorNum[:, 2], - w=-diffvectorNum[:, 1], - sizemode="absolute", - sizeref=coneSize, - anchor="tail")) - - fig.update_layout( - scene=dict(domain_x=[0, 1], - camera_eye=dict(x=-1.57, y=1.36, z=0.58))) - - fig.show() - - - return pos3DFiltered, normalsFiltered, colorsFiltered, scalesFiltered, rotsFiltered \ No newline at end of file diff --git a/FastSplatStyler_huggingface/style_ims/21.jpg b/FastSplatStyler_huggingface/style_ims/21.jpg deleted file mode 100644 index 
462c08de54e803c01fe3d86f48898272629625fc..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/21.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/31.jpg b/FastSplatStyler_huggingface/style_ims/31.jpg deleted file mode 100644 index fa4f7071105d8fc25da420f158b8d4c13371b25e..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/31.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/47.jpg b/FastSplatStyler_huggingface/style_ims/47.jpg deleted file mode 100644 index 464542fb9466390479a9409179cb53db9d80031a..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/47.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/6.jpg b/FastSplatStyler_huggingface/style_ims/6.jpg deleted file mode 100644 index d15ca0986ed0d66e0da2b497d0b929376e465e8b..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/6.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/img11.jpg b/FastSplatStyler_huggingface/style_ims/img11.jpg deleted file mode 100644 index 7167997e015ffbff43584315927d65f9a8a65e9f..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/img11.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/img66.jpg b/FastSplatStyler_huggingface/style_ims/img66.jpg deleted file mode 100644 index d15ca0986ed0d66e0da2b497d0b929376e465e8b..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/img66.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style-10.jpg b/FastSplatStyler_huggingface/style_ims/style-10.jpg deleted file mode 100644 index 70f6adec87d980523fe1d59493a9556b076913b3..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/style_ims/style-10.jpg +++ /dev/null @@ -1,3 +0,0 @@ 
-version https://git-lfs.github.com/spec/v1 -oid sha256:efe9c5fcf5d652ba574fd1079eda9ec26aa823fd630750469963435b552aae82 -size 144264 diff --git a/FastSplatStyler_huggingface/style_ims/style0.jpg b/FastSplatStyler_huggingface/style_ims/style0.jpg deleted file mode 100644 index 820b093f207d0884a0ca0e5d9236d29dc7aa730b..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/style_ims/style0.jpg +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:1af09b2c18a6674b7d88849cb87564dd77e1ce04d1517bb085449b614cc0c8d8 -size 376101 diff --git a/FastSplatStyler_huggingface/style_ims/style1.jpg b/FastSplatStyler_huggingface/style_ims/style1.jpg deleted file mode 100644 index 63aa06fe4296fa6e0feb39d5c6be57fc177283d7..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style1.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style100.JPG b/FastSplatStyler_huggingface/style_ims/style100.JPG deleted file mode 100644 index 0e6c11487e79fe59db161380d9b26ec34581d7f8..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style100.JPG and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style12.jpg b/FastSplatStyler_huggingface/style_ims/style12.jpg deleted file mode 100644 index 0da803b36c0d6e23922a0f133cc0a1d9fe35ffb3..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style12.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style19.jpg b/FastSplatStyler_huggingface/style_ims/style19.jpg deleted file mode 100644 index a7b5df070485bacad6c1f0c9cfe3e48fe0995509..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style19.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style2.jpg b/FastSplatStyler_huggingface/style_ims/style2.jpg deleted file mode 100644 index 
c6e61e0ddbdbe6b9cf24d239c2aa64d966f82dba..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/style_ims/style2.jpg +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:60192c2cb0eefc036ecc5bf04810f3155ec30239608b6ebefa4ae3ea96864ed2 -size 140682 diff --git a/FastSplatStyler_huggingface/style_ims/style3.jpg b/FastSplatStyler_huggingface/style_ims/style3.jpg deleted file mode 100644 index 3dbb29cf89e70f9421afb3341674eec0a991cf52..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style3.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style31.jpg b/FastSplatStyler_huggingface/style_ims/style31.jpg deleted file mode 100644 index fa4f7071105d8fc25da420f158b8d4c13371b25e..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style31.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style4.jpg b/FastSplatStyler_huggingface/style_ims/style4.jpg deleted file mode 100644 index 5f819f887e6d85dc19f334533ebb5c25b11d1f2f..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style4.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style40.jpg b/FastSplatStyler_huggingface/style_ims/style40.jpg deleted file mode 100644 index 6f086b087a8f895774866565f01b48b5902c197e..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style40.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style44.jpg b/FastSplatStyler_huggingface/style_ims/style44.jpg deleted file mode 100644 index 356f707637fd225bef6c5d4d422713e486239138..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/style_ims/style44.jpg +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:28b09fc5f3ca18e8dda06475f72f3ec8d912c71f0b35a372dd309ea2cc4cfda9 -size 
178110 diff --git a/FastSplatStyler_huggingface/style_ims/style5.jpg b/FastSplatStyler_huggingface/style_ims/style5.jpg deleted file mode 100644 index 732b7d5de2cc43b4a8d1e092b3a06fed496ec1d6..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style5.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style50.jpg b/FastSplatStyler_huggingface/style_ims/style50.jpg deleted file mode 100644 index d625f87588f641c5fcb42fc930cb0ccd57bb72c5..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style50.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style6.jpg b/FastSplatStyler_huggingface/style_ims/style6.jpg deleted file mode 100644 index 70f6adec87d980523fe1d59493a9556b076913b3..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/style_ims/style6.jpg +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:efe9c5fcf5d652ba574fd1079eda9ec26aa823fd630750469963435b552aae82 -size 144264 diff --git a/FastSplatStyler_huggingface/style_ims/style66.jpg b/FastSplatStyler_huggingface/style_ims/style66.jpg deleted file mode 100644 index f077a49af8699a399340df8cbb9e477631b697cf..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style66.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style7.jpg b/FastSplatStyler_huggingface/style_ims/style7.jpg deleted file mode 100644 index cb0e9091c8c18c589ff104a92c29d119ba1bf1b1..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style7.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style8.jpg b/FastSplatStyler_huggingface/style_ims/style8.jpg deleted file mode 100644 index 3462b70be0bfe4c4272ec3d86e4639a7000b1a16..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style8.jpg 
and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style9.jpg b/FastSplatStyler_huggingface/style_ims/style9.jpg deleted file mode 100644 index b2d70e41daa0c72264c13d3e77291c7273837c89..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style9.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/style99.jpg b/FastSplatStyler_huggingface/style_ims/style99.jpg deleted file mode 100644 index 1a07615bff581507454b934ae14c7f53f334eb4e..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/style99.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/style_ims/waves.jpg b/FastSplatStyler_huggingface/style_ims/waves.jpg deleted file mode 100644 index f077a49af8699a399340df8cbb9e477631b697cf..0000000000000000000000000000000000000000 Binary files a/FastSplatStyler_huggingface/style_ims/waves.jpg and /dev/null differ diff --git a/FastSplatStyler_huggingface/styletransfer_splat.py b/FastSplatStyler_huggingface/styletransfer_splat.py deleted file mode 100644 index 8575f9cfb7cc0ad3b7faeb733dc87ca244a5b9af..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/styletransfer_splat.py +++ /dev/null @@ -1,218 +0,0 @@ -import pointCloudToMesh as ply2M -import argparse -import utils -import graph_io as gio -from clusters import * -#from tqdm import tqdm,trange -import splat_mesh_helpers as splt -import clusters as cl -from torch_geometric.data import Data -from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator -import pointCloudToMesh as plyToMesh -import plotly.graph_objects as go -import pyvista as pv - -from time import time - -from graph_networks.LinearStyleTransfer_vgg import encoder,decoder -from graph_networks.LinearStyleTransfer_matrix import TransformLayer - -from graph_networks.LinearStyleTransfer.libs.Matrix import MulLayer -from graph_networks.LinearStyleTransfer.libs.models import encoder4, 
decoder4 - -#import matplotlib.pyplot as plt - - -def styletransfer_with_filtering_sampling(filename, stylePath, outPath, device = 'cpu', threshold=99.9, samplingRate=1.5, displayPointCloud = False): - - print("Running on device:", device) - - n = 25 - style_ref = utils.loadImage(stylePath, shape=(256*2,256*2)) - ratio=.25 - depth = 3 - - pos3D_Original, _, colors_Original, opacity_Original, scales_Original, rots_Original, fileType = splt.splat_unpacker_with_threshold(n, filename, threshold) - - time1_start = time() - - #plyToMesh.graph_Points(pos3D_Original, torch.clamp(colors_Original, 0, 1)) - if samplingRate > 1: - GaussianSamples = int(pos3D_Original.shape[0]*samplingRate) - pos3D, colors = splt.splat_GaussianSuperSampler(pos3D_Original.clone(), colors_Original.clone(), opacity_Original.clone(), scales_Original.clone(), rots_Original.clone(), GaussianSamples) - else: - pos3D, colors = pos3D_Original, colors_Original - #plyToMesh.graph_Points(pos3D, torch.clamp(colors, 0, 1)) - #plyToMesh.graph_Points(pos3D_Original, torch.clamp(colors_Original, 0, 1)) - - time1_end = time() - - print("Number of nodes in the graph:", pos3D.shape[0]) - - print(f"Time taken for Gaussian Super Sampling: {time1_end - time1_start}") - - - if (displayPointCloud): - #point cloud - point_cloud = pv.PolyData(pos3D.numpy()) - - # Add colors to the point data - point_cloud.point_data['colors'] = torch.clamp(colors, 0, 3).numpy() - - # Plot the point cloud - plotter = pv.Plotter() - plotter.add_points(point_cloud, scalars='colors', rgb=True, point_size=0.05) - plotter.show_axes() - plotter.show() - - time2_start = time() - - #find normals - normalsNP = ply2M.Estimate_Normals(pos3D, threshold) - normals = torch.from_numpy(normalsNP) - - #print("Time to compute normals:", time() - time2_start) - - up_vector = torch.tensor([[1,1,1]],dtype=torch.float) - #up_vector = 2*torch.rand((1,3))-1 - up_vector = up_vector/torch.linalg.norm(up_vector,dim=1) - - pos3D.to(device) - colors.to(device) - 
normals.to(device) - up_vector.to(device) - - # Build initial graph - #edge_index are neighbors of a point, directions are the directions from that point - edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16) - #directions need to be turned into selections "W sub n" from the star-like coordinate system from Dr. Hart's github interpolated-selectionconv - edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True) - - # Generate info for downsampled versions of the graph - clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device) - #clusters, edge_indexes, selections_list, interps_list = cl.makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=ratio,up_vector=up_vector,depth=depth,device=device) - - time2_end = time() - print(f"Time taken for graph construction: {time2_end - time2_start}") - - time3_start = time() - - # Make final graph and metadata needed for mapping the result after going through the network - content = Data(x=colors,clusters=clusters,edge_indexes=edge_indexes,selections_list=selections_list,interps_list=interps_list) - content_meta = Data(pos3D=pos3D) - - style,_ = gio.image2Graph(style_ref,depth=3,device=device) - - # Load original network - enc_ref = encoder4() - dec_ref = decoder4() - matrix_ref = MulLayer('r41') - - enc_ref.load_state_dict(torch.load('graph_networks/LinearStyleTransfer/models/vgg_r41.pth')) - dec_ref.load_state_dict(torch.load('graph_networks/LinearStyleTransfer/models/dec_r41.pth')) - matrix_ref.load_state_dict(torch.load('graph_networks/LinearStyleTransfer/models/r41.pth',map_location=torch.device(device))) - - # Copy weights to graph network - enc = encoder(padding_mode="replicate") - dec = decoder(padding_mode="replicate") - matrix = TransformLayer() - - with torch.no_grad(): - enc.copy_weights(enc_ref) - dec.copy_weights(dec_ref) 
- matrix.copy_weights(matrix_ref) - - content = content.to(device) - style = style.to(device) - enc = enc.to(device) - dec = dec.to(device) - matrix = matrix.to(device) - - # Run graph network - with torch.no_grad(): - cF = enc(content) - sF = enc(style) - feature,transmatrix = matrix(cF['r41'],sF['r41'], - content.edge_indexes[3],content.selections_list[3], - style.edge_indexes[3],style.selections_list[3], - content.interps_list[3] if hasattr(content,'interps_list') else None) - result = dec(feature,content) - result = result.clamp(0,1) - - colors[:, 0:3] = result - - time3_end = time() - print(f"Time taken for stylization: {time3_end - time3_start}") - - if (displayPointCloud): - #point cloud - point_cloud = pv.PolyData(pos3D.numpy()) - - # Add colors to the point data - point_cloud.point_data['colors'] = torch.clamp(colors, 0, 3).numpy() - - # Plot the point cloud - plotter = pv.Plotter() - plotter.add_points(point_cloud, scalars='colors', rgb=True, point_size=0.25) - plotter.show_axes() - plotter.show() - - time4_start = time() - - #create the interpolator - interp2 = NearestNDInterpolator(pos3D.cpu(), colors.cpu()) - results_OriginalNP = interp2(pos3D_Original) - results_OriginalNP64 = torch.from_numpy(results_OriginalNP) - results_Original = results_OriginalNP64.to(torch.float32) - - colors_and_opacity_Original = torch.cat((results_Original, opacity_Original.unsqueeze(1)), dim=1) - - time4_end = time() - print(f"Time taken for interpolation: {time4_end - time4_start}") - - # Save/show result - splt.splat_save(pos3D_Original.numpy(), scales_Original.numpy(), rots_Original.numpy(), colors_and_opacity_Original.numpy(), outPath, fileType) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "filename", - type=str, - default='' - ) - parser.add_argument( - "--stylePath", - type=str, - default="style_ims/style0.jpg" - ) - parser.add_argument( - "--outPath", - type=str, - default='output.splat' - ) - parser.add_argument( - "--device", - 
default= 0 if torch.cuda.is_available() else "cpu", - choices=list(range(torch.cuda.device_count())) + ["cpu"] or ["cpu"] - ) - parser.add_argument( - "--threshold", - type=float, - default=99.8 - ) - parser.add_argument( - "--samplingRate", - type=float, - default=1.5 - ) - parser.add_argument( - "--displayPointCloud", - action='store_true' - ) - args = parser.parse_args() - styletransfer_with_filtering_sampling(**vars(args)) - - -if __name__ == "__main__": - main() diff --git a/FastSplatStyler_huggingface/utils.py b/FastSplatStyler_huggingface/utils.py deleted file mode 100644 index e018f9b6cb1f99176e421ae7b5035a62a747c36b..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/utils.py +++ /dev/null @@ -1,222 +0,0 @@ -import torch -import numpy as np -from imageio import imread, imwrite -from skimage.transform import resize -import os -from scipy.ndimage import binary_erosion, binary_dilation, distance_transform_cdt, grey_dilation -import matplotlib.pyplot as plt -import json - -def ensure_dir(file_path): - directory = os.path.dirname(file_path) - if not os.path.exists(directory): - os.makedirs(directory) - -def loadJSON(file): - with open(file) as f: - result = json.load(f) - return result - -def saveJSON(data,file): - with open(file, 'w') as outfile: - json.dump(data, outfile) - -def loadImage(filename, asTensor = True, imagenet_mean = False, shape = None): - - image = imread(filename) - - if image.ndim == 2: - image = np.stack((image,image,image),axis=2) - - image = image[:,:,:3]/255 # No alpha channels - - if shape is not None: - image = resize(image, shape) - - if imagenet_mean: - mean = np.array([0.485, 0.456, 0.406]) - std = np.array([0.229, 0.224, 0.225]) - image = (image - mean)/std - - # Reorganize for torch - if asTensor: - image = np.transpose(image,(2,0,1)) - - image = np.expand_dims(image,axis=0) - - return torch.tensor(image,dtype=torch.float) - else: - return image - -def loadMask(filename, asTensor = True, shape = None): - - 
image = imread(filename) - - if image.ndim == 3: - image = image[:,:,0] - - if shape is not None: - image = resize(image, shape) - - mask = image >= 128 - - # Reorganize for torch - if asTensor: - mask = np.expand_dims(mask,axis=0) - - return torch.tensor(mask,dtype=torch.bool) - else: - return mask - -def toTensor(numpy_data): - image = np.transpose(numpy_data,(2,0,1)) - image = np.expand_dims(image,axis=0) - return torch.tensor(image,dtype=torch.float) - -def toTorch(numpy_data): - return toTensor(numpy_data) - -def toNumpy(tensor, permute=True): - image = np.squeeze(tensor.detach().clone().cpu().numpy()) - if permute: - image = np.transpose(image,(1,2,0)) - return image - -def makeCanvas(x,original): - - if torch.is_tensor(original): - original = toNumpy(original) - - if x.ndim > 1: - if x.shape[-1] == original.shape[-1]: - return original - - rows,cols,_ = original.shape - if x.ndim > 1: - return np.zeros((rows,cols,x.shape[-1])) - else: - return np.zeros((rows,cols)) - -def reverse_selection(s): - return [0, 5, 6, 7, 8, 1, 2, 3, 4][s] - - -def extrapolate_image(im, numpy = False): - - rows,cols,ch = im.shape - if numpy: - result = np.zeros((rows+1,cols+1,ch)) - else: - result = torch.zeros((rows+1,cols+1,ch)) - - result[:rows,:cols] = im - - # Extrapolate last column - result[:rows,cols] = 2*result[:rows,cols-1] - result[:rows,cols-2] - # Extrapolate last row - result[rows] = 2*result[rows-1] - result[rows-2] - - return result - -def bilinear_interpolate(im, x, y, numpy = False): - - if numpy: - - im = extrapolate_image(im,numpy=True) - - x = np.asarray(x) - y = np.asarray(y) - - x0 = np.floor(x).astype(int) - x1 = x0 + 1 - y0 = np.floor(y).astype(int) - y1 = y0 + 1 - - x0 = np.clip(x0, 0, im.shape[1]-1); - x1 = np.clip(x1, 0, im.shape[1]-1); - y0 = np.clip(y0, 0, im.shape[0]-1); - y1 = np.clip(y1, 0, im.shape[0]-1); - - Ia = im[ y0, x0 ] - Ib = im[ y1, x0 ] - Ic = im[ y0, x1 ] - Id = im[ y1, x1 ] - - wa = (x1-x) * (y1-y) - wb = (x1-x) * (y-y0) - wc = (x-x0) * 
(y1-y) - wd = (x-x0) * (y-y0) - - wa = np.expand_dims(wa,axis=1) - wb = np.expand_dims(wb,axis=1) - wc = np.expand_dims(wc,axis=1) - wd = np.expand_dims(wd,axis=1) - - else: - im = im[0].permute((1,2,0)).float() - im = extrapolate_image(im) - - x0 = torch.floor(x).long() - x1 = x0 + 1 - y0 = torch.floor(y).long() - y1 = y0 + 1 - - x0 = torch.clip(x0, 0, im.shape[1]-1); - x1 = torch.clip(x1, 0, im.shape[1]-1); - y0 = torch.clip(y0, 0, im.shape[0]-1); - y1 = torch.clip(y1, 0, im.shape[0]-1); - - Ia = im[y0,x0] - Ib = im[y1,x0] - Ic = im[y0,x1] - Id = im[y1,x1] - - wa = (x1-x) * (y1-y) - wb = (x1-x) * (y-y0) - wc = (x-x0) * (y1-y) - wd = (x-x0) * (y-y0) - - wa = torch.unsqueeze(wa,dim=1) - wb = torch.unsqueeze(wb,dim=1) - wc = torch.unsqueeze(wc,dim=1) - wd = torch.unsqueeze(wd,dim=1) - - return wa*Ia + wb*Ib + wc*Ic + wd*Id - -def cosineWeighting(rows,cols): - phi_vals = np.linspace(-np.pi/2,np.pi/2,rows) - cosines = np.cos(phi_vals) - - result = np.zeros((rows,cols)) - for i in range(rows): - result[i, :] = cosines[i] - - return result - -def cross(a,b): - # Computes the cross product of two torch tensors - out_i = a[:,1]*b[:,2] - a[:,2]*b[:,1] - out_j = a[:,2]*b[:,0] - a[:,0]*b[:,2] - out_k = a[:,0]*b[:,1] - a[:,1]*b[:,0] - - out = torch.stack((out_i,out_j,out_k),dim=1) - - return out - -def interpolatePointCloud2D(source_points,source_features,target_x,target_y,extrapolate=True): - #from scipy.interpolate import LinearNDInterpolator as Interpolater - from scipy.interpolate import CloughTocher2DInterpolator as Interpolater - from scipy.interpolate import NearestNDInterpolator as Extrapolater - interp = Interpolater(source_points,source_features) - result = interp(target_x,target_y) - - chk = np.isnan(result) - - if chk.any() and extrapolate: - nearest = Extrapolater(source_points,source_features) - return np.where(chk, nearest(target_x,target_y), result) - else: - result = np.where(chk, 0, result) - return result - - \ No newline at end of file diff --git 
a/FastSplatStyler_huggingface/venv libs.txt b/FastSplatStyler_huggingface/venv libs.txt deleted file mode 100644 index 27689b8f3043e9408989c0bc1564afb57beb6fcf..0000000000000000000000000000000000000000 --- a/FastSplatStyler_huggingface/venv libs.txt +++ /dev/null @@ -1,122 +0,0 @@ -aiohttp 3.9.5 -aiosignal 1.3.1 -asttokens 3.0.0 -attrs 23.2.0 -blinker 1.9.0 -bqplot 0.12.44 -certifi 2024.7.4 -charset-normalizer 3.3.2 -click 8.1.8 -colorama 0.4.6 -comm 0.2.2 -ConfigArgParse 1.7 -contourpy 1.2.1 -cycler 0.12.1 -dash 2.18.2 -dash-core-components 2.0.0 -dash-html-components 2.0.0 -dash-table 5.0.0 -decorator 5.1.1 -equilib 0.0.1 -executing 2.2.0 -fastjsonschema 2.21.1 -filelock 3.13.1 -Flask 3.0.3 -fonttools 4.53.1 -frozenlist 1.4.1 -fsspec 2024.2.0 -idna 3.7 -imageio 2.34.2 -importlib_metadata 8.6.1 -intel-openmp 2021.4.0 -ipydatawidgets 4.3.5 -ipython 8.31.0 -ipyvue 1.11.2 -ipyvuetify 1.11.1 -ipywebrtc 0.6.0 -ipywidgets 8.1.5 -itsdangerous 2.2.0 -jedi 0.19.2 -Jinja2 3.1.3 -joblib 1.4.2 -jsonschema 4.23.0 -jsonschema-specifications 2024.10.1 -jupyter_core 5.7.2 -jupyterlab_widgets 3.0.13 -kiwisolver 1.4.5 -lazy_loader 0.4 -markdown-it-py 3.0.0 -MarkupSafe 2.1.5 -matplotlib 3.9.1 -matplotlib-inline 0.1.7 -mdurl 0.1.2 -meshio 5.3.5 -mkl 2021.4.0 -mpmath 1.3.0 -multidict 6.0.5 -nbformat 5.10.4 -nest-asyncio 1.6.0 -networkx 3.2.1 -numpy 1.26.3 -open3d 0.19.0 -opencv-python 4.10.0.84 -packaging 24.1 -pandas 2.2.2 -parso 0.8.4 -pillow 10.2.0 -pip 24.0 -platformdirs 4.2.2 -plotly 5.24.1 -plyfile 1.0.3 -pooch 1.8.2 -prompt_toolkit 3.0.50 -psutil 6.0.0 -pure_eval 0.2.3 -py360convert 0.1.0 -pyg_lib 0.4.0+pt23cu121 -pyglet 1.5.15 -Pygments 2.18.0 -pyparsing 3.1.2 -python-dateutil 2.9.0.post0 -pythreejs 2.4.2 -pytz 2024.1 -pyvista 0.44.1 -pywin32 308 -referencing 0.36.2 -requests 2.32.3 -retrying 1.3.4 -rich 13.7.1 -rpds-py 0.22.3 -scikit-image 0.24.0 -scikit-learn 1.5.1 -scipy 1.14.0 -scooby 0.10.0 -setuptools 75.8.0 -six 1.16.0 -stack-data 0.6.3 -sympy 1.12 -tbb 2021.11.0 
-tenacity 9.0.0 -threadpoolctl 3.5.0 -tifffile 2024.7.2 -torch 2.3.1+cu121 -torch_cluster 1.6.3+pt23cu121 -torch_geometric 2.5.3 -torch_scatter 2.1.2+pt23cu121 -torch_sparse 0.6.18+pt23cu121 -torch_spline_conv 1.2.2+pt23cu121 -torchaudio 2.3.1+cu121 -torchvision 0.18.1+cu121 -tqdm 4.66.4 -traitlets 5.14.3 -traittypes 0.2.1 -trimesh 4.4.3 -typing_extensions 4.9.0 -tzdata 2024.1 -urllib3 2.2.2 -vtk 9.3.1 -wcwidth 0.2.13 -Werkzeug 3.0.6 -widgetsnbextension 4.0.13 -yarl 1.9.4 -zipp 3.21.0 \ No newline at end of file