thanks to One-2-3-45 ❤
This view is limited to 50 files because the commit contains too many changes.
- .gitignore +4 -0
- README.md +34 -0
- SparseNeuS_demo_v1/confs/one2345_lod0_val_demo.conf +135 -0
- SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data.py +394 -0
- SparseNeuS_demo_v1/data/scene.py +101 -0
- SparseNeuS_demo_v1/exp/lod0/.gitignore +1 -0
- SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_215000.pth +3 -0
- SparseNeuS_demo_v1/exp_runner_generic_blender_val.py +629 -0
- SparseNeuS_demo_v1/loss/__init__.py +0 -0
- SparseNeuS_demo_v1/loss/color_loss.py +152 -0
- SparseNeuS_demo_v1/loss/depth_loss.py +71 -0
- SparseNeuS_demo_v1/loss/depth_metric.py +240 -0
- SparseNeuS_demo_v1/loss/ncc.py +65 -0
- SparseNeuS_demo_v1/models/__init__.py +0 -0
- SparseNeuS_demo_v1/models/embedder.py +101 -0
- SparseNeuS_demo_v1/models/fast_renderer.py +316 -0
- SparseNeuS_demo_v1/models/featurenet.py +91 -0
- SparseNeuS_demo_v1/models/fields.py +333 -0
- SparseNeuS_demo_v1/models/patch_projector.py +211 -0
- SparseNeuS_demo_v1/models/projector.py +425 -0
- SparseNeuS_demo_v1/models/rays.py +320 -0
- SparseNeuS_demo_v1/models/render_utils.py +120 -0
- SparseNeuS_demo_v1/models/rendering_network.py +129 -0
- SparseNeuS_demo_v1/models/sparse_neus_renderer.py +985 -0
- SparseNeuS_demo_v1/models/sparse_sdf_network.py +907 -0
- SparseNeuS_demo_v1/models/trainer_generic.py +1207 -0
- SparseNeuS_demo_v1/ops/__init__.py +0 -0
- SparseNeuS_demo_v1/ops/back_project.py +175 -0
- SparseNeuS_demo_v1/ops/generate_grids.py +33 -0
- SparseNeuS_demo_v1/ops/grid_sampler.py +467 -0
- SparseNeuS_demo_v1/tsparse/__init__.py +0 -0
- SparseNeuS_demo_v1/tsparse/modules.py +326 -0
- SparseNeuS_demo_v1/tsparse/torchsparse_utils.py +137 -0
- SparseNeuS_demo_v1/utils/__init__.py +0 -0
- SparseNeuS_demo_v1/utils/misc_utils.py +219 -0
- configs/sd-objaverse-finetune-c_concat-256.yaml +117 -0
- ldm/data/__init__.py +0 -0
- ldm/data/base.py +40 -0
- ldm/data/coco.py +253 -0
- ldm/data/dummy.py +34 -0
- ldm/data/imagenet.py +394 -0
- ldm/data/inpainting/__init__.py +0 -0
- ldm/data/inpainting/synthetic_mask.py +166 -0
- ldm/data/laion.py +537 -0
- ldm/data/lsun.py +92 -0
- ldm/data/nerf_like.py +165 -0
- ldm/data/simple.py +526 -0
- ldm/extras.py +77 -0
- ldm/guidance.py +96 -0
- ldm/lr_scheduler.py +98 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
__pycache__/
*.DS_Store
*.ipynb
*.egg-info/
README.md
ADDED
@@ -0,0 +1,34 @@
---
license: mit
---


# One-2-3-45's Inference Model

<div>
KPMp5EVGQuCyLc8LfjUhQ8fSs63P9aVDYZBDhO8oWSI2Lbk7cRpKJ38ww9dD0b0OjvucHkJl1zIwyQFqKKEfIN7RPvV8Q1Xxot6Y5f8/UqOCOVZRt+IM1JFcJ4AstPMOXs9hAyZzEs1EY9lv3976/18LNNvL8K7RPNH1uz3qwAajMXLaOTEK7IzCjex3YZQ0LCICPzWVKMNbkSFpmy5ow1A54fK4F45T0apL1FE8dc/Jy6ERymiJ8ZvT+BJHUtbS5oB72w8NeIb0zTuqTzYwMQiKeCI+DlJTd6R3dgbvDETb7XtLT1L5quVxBiyJLxgARoeWU1DY3eWTFJkicFp/UqIFCYgLUhQgGm/1gAxylWf4wZmbQy6RGlY3/pfn5qxqFq8Xmza7Unght3AckydGZ6u6yWcooxZwILsHaklA/Bu2HRlCLzLer57IQWfvHUjJ8pqEoZ/TE0WqZc4SF6CBVC4KGEIqyPnH/+chaIQRfGuKg0rKAAc5tB+7vGl4ck72A+dA9iW0UUwXqD6Y333q9MEdov8hbXuiRkRMv1CEm0h1N8yhxOEe1SLWxlCmvUHcVvhojM6S4ODYr2rxHxOqx63MVVCk6PpQAB2gn4D/9+QHVBBqAxLV8Pggh6aRlEmPuHNEc+b/1Zqh4lvwxUgyMFngOgTAhqZAZqBpRRD41KfU7wEbzruYQhOIxxPMbGIe93LRFgYZLz21sLmS/02JhQ1eY6cSu/iumOWleWujzjCcB9qxDpOBlsugpiveHwunOO9Lu66uGNeKuw6Aqo+zTcdEX+BOlserNtSeyYmhQrwLA+mnEqaAtKv5eAyTC03krdlSEI+++xVMU+kqsGF+6H9yNBQj5aZxmOzd7BejBdBBInEjlj868zR80jlgVKb+yQ7XkdiFIvQl/XvaFpPGqYb4UR70U0jNe/I9UuFggu8M90wyOi7Ihm3t9FZTBPv4zmee4ue5pVKpdsOOZLwBSODTdpUJb6ctU3I9n0KAUBM4+RkNoQzCyb+iXoXl22CL2eQWlOlBi8IG84Y2bMIiLnPs5qeUth9zlniN14oQNTtVJibuIgkylT50ExHyuqz1ra2+wW3QDltErT6yyrKnL8rmkPesI3aPAL880z4U6TWXqcU6hkryL8W5gdI94KYuDTBEim0GM6IAAKf8JZNX3sM/OIB9h3XbFUuNXRocJY9iqQAGdinm3YPLbRBxP5S5EWwlTdIVK5yjUpV+tCN0HXOVf7xj9pnyIMPDz/Znf4zufz+0zonywFQLgAdKiVwC5a6EDC2rmxYC4L82QIO17SKc8NCAJZPTWwwrPGgb0nhQdi3g32QzHUAqE2qhq2jyM7WINI34P28PN5IE50uRx/XFn2a8h2Qgla55PnsIT7KbDBo0Nd4XUkCWINxReQK0/NZEDZrUuZghyZYnnIuIi0pTpecJWliTLvfxyiRkIsb9t2mT6VzM8H2HN8nq0rF7BC27r0JoLl/5YgZQZmw763cQ625wkmPOX0vr1M35fZYv06zKm1ux/L+W6O3ju3VdFudKgEgRIeT+bIOQKoKaT+knRugmPDGt1JAt6bKTT2bvIYnf5OvZs9id5x+qy5UeotL3uxYiBj7SyGxTCHdovbak2BG5hGmuVWxRojEJS9IEqUKwy133zg24keiFy0bXsG125D1XgQ/uI3IM8dijJ4N6jHObWneJl3zHvKb+cX97XFAv5VV5ySEfm0Iglkir9QaTOXP9SuTeCao7Q3fULO6Jcp+sOU6x+jCjlsmHiF7oDCAb/RYITB9oIcAGGQJkT6ccTNJPlyPiNf4/nefe5V4j5RWLTDm7Wb/kt426CIGzE2ekBjrvExlch914MGYjMJcBllUj/LTfXTYLSqPwPzU/xBVSUOR9o7wFnFBZTaI0ApgY11rRsoEgTu9yRgBxw0h71O/RjpN5Ku/U/er87C9/jHzucfXpRDcP1JxOOxoJziSE01YxnPjmyDigmCcur63bY/xXdZeNQNprWvE3mAIP14fFkdJ4+0vwkAP+BXokPPQBkZuEWFAUEz1H/YQf4Q9bCQZXl/WSpUpG/TjBo8EpZLTJ2Jwa1G3H2hVIUlifUnV/SvKDYbpUvl6mKuwdgglJxkJOXjtf84FjvjeHUOzf8ZhLw3PH53rUrDz0INySaGJ/n4a/iuvLMaL146Ii98kND4sM0nElTIxnJe+LJF/8aimynAAiTshwnKc7MHqCtuDaUFfEQCGw0tmys5yKZM5zzawr6LW7MdQs9XUDiyTrX+YcI0uPZZ43oMnO737u5Tmc/sAeKCNGIWt8kw87EMQ+BP5NMrf8b8wDvSD2XVEZu7xqwNCizeSYQJGVJVSAdJ27XwXrFfHtdHHrlojW+3BFzE5rOzDsUsA00zYHxt+e9zo9Yn0sImcxGhbDFBGD892Rgz9G+eor3huRF8h4p1qYpjTe/ykVkhWyvHRjNNevOV7Gk1jhgiwOajzrXwNsIJNUvAQQB017GRYgey7MAEeBoAx5RuxYU+oMH6DNk5eYcrJDxo48XGbO4QhCMRgAA"></a>
<a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/2306.16928"><img src="https://img.shields.io/badge/2306.16928-f9f7f7?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAADcAAABMCAYAAADJPi9EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAuIwAALiMBeKU/dgAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAAa2SURBVHja3Zt7bBRFGMAXUCDGF4rY7m7bAwuhlggKStFgLBgFEkCIIRJEEoOBYHwRFYKilUgEReVNJEGCJJpehHI3M9vZvd3bUP1DjNhEIRQQsQgSHiJgQZ5dv7krWEvvdmZ7d7vHJN+ft/f99pv5XvOtJMFCqvoCUpTdIEeRLC+L9Ox5i3Q9LACaCeK0kXoSChVcD3C/tQPHpAEsquQ73IkUcEz2kcLCknyGW5MGjkljRFVL8xJOKyi4CwCOuQAeAkfTP1+tNxLkogvgEbDgffkJqKqvuMA5ifOpqg/5qWecRstNg7xoUTI1Fovdxg8oy2s5AP8CGeYHmGngeZaOL4I4LXLcpHg4149/GDz4xqgsb+UAbMKKUpkrqHA43MUyyJpWUK0EHeG2YKRXr7tB+QMcgGewLD+ebTDbtrtbBt7UPlhS4rV4IvcDI7J8P1OeA/AcAI7LHljN7aB8XTowJmZt9EFRD/o0SDMH4HlwMhMyDWZZSAHFf3YDs3RS49WDLuaAY3IJq+qzmQKLxXAZKN7oDoYbdV3v5elPqiSpMyiOuAEVZVqHXb1OhloUH+MA+ztO0cAO/RkrfyBE7OAEbAZvO8vzVtTRWFD6DAfY5biBM3PWiaL0a4lvXICwnV8WjmE6ntYmhqX2jjp5LbMZjCw/wbYeN6CizOa2GMVzQOlmHjB4Ceuyk6LJ8huccEmR5Xddg7OOV/NAtchW+E3XbOag60QA4Qwuarca0bRuEJyr+cFQwzcY98huxhAKdQelt4kAQpj4qJ3gvFXAYn+aJumXk1yPlpQUgtIHhbYoFMUstNRRWgjnpl4A7IKlayNymqFHFaWCpV9CFry3LGxR1CgA5kB5M8OX2goApwpaz6mdOMGxtAgXWJySxb4WuQD4qTDgU+N5AAnzpr7ChSWpCyisiQJqY0Y7FtmSKpbV23b45kC0KHBxcQ9QeI8w4KgnHRPVtIU7rOtbioLVg5Hl/qDwSVFAMqLSMSObroCdZYlzIJtMRFVHCaRo/wFWPgaAXzdbBpkc2A4aKzCNd97+URQuESYGDDhIVfWOQIKZJu4D2+oXlgDTV1865gUQZDts756BArMNMoR1oa46BYqbyPixZz1ZUFV3sgwoGBajuBKATl3btIn8QYYMuezRgrsiRUWyr2BxA40EkPMpA/Hm6gbUu7fjEXA3azP6AsbKD9bxdUuhjM9W7fII52BF+daRpE4+WA3P501+jbfmHvQKyFqMuXf7Ot4mkN2fr50y+bRH61X7AXdUpHSxaPQ4GVbR5AGw3g+434XgQGKfr72I+vQRhfsu92dOx7WicInzt3CBg1RVpMm0NveWo2SqFzgmdNZMbriILD+S+zoueWf2vSdAipzacWN5nMl6XxNlUHa/J8DoJodUDE0HR8Ll5V0lPxcrLEHZPV4AzS83OLis7FowVa3RSku7BSNxJqQAlN3hBTC2apmDSkpaw22wJemGQFUG7J4MlP3JC6A+f96V7vRyX9It3nzT/GrjIU8edM7rMSnIi10f476lzbE1K7yEiEuWro0OJBguLCwDuFOJc1Na6sRWL/cCeMIwUN9ggSVbe3v/5/EgzTKWLvEAiBrYRUkgwNI2ZaFQNT75UDxEUEx97zYnzpmiLEmbaYCbNxYtFAb0/Z4AztgUrhyxuNgxPnhfHFDHz/vTgFWUQZxTRkkJhQ6YNdVUEPAfO6ZV5BRss6LcCVb7VaAma9giy0XJZBt9IQh42NY0NSdgbLIPlLUF6rEdrdt0CUCK1wsCbkcI3ZSLc7ZSwGLbmJXbPsNxnE5xilYKAobZ77LpGZ8TAIun+/iCKQoF71IxQDI3K2CCd+ARNvXg9sykBcnHAoCZG4u66hlDoQLe6QV4CRtFSxZQ+D0BwNO2jgdkzoGoah1nj3FVlSR19taTSYxI8QLut23U8dsgzqHulJNCQpcqBnpTALCuQ6NSYLHpmR5i42gZzuIdcrMMvMJbQlxe3jXxyZnLACl7ARm/FjPIDOY8ODtpM71sxwfcZpvBeUzKWmfNINM5AS+wO0Khh7dMqKccu4+qatarZjYAwDlgetzStHtEt+XedsBOQtU9XMrRgjg4KTnc5nr+dmqadit/4C4uLm8DuA9koJTj1TL7fI5nDL+qqoo/FLGAzL7dYT17PzvAcQONYSUQRxW/QMrHZVIyik0ZuQA2mzp+Ji8BW4YM3Mbzm9inaHkJCGfrUZZjujiYailfFwA8DHIy3acwUj4v9vUVa+SmgNsl5fuyDTKovW9/IAmfLV0Pi2UncA515kjYdrwC9i9rpuHiq3JwtAAAAABJRU5ErkJggg=="></a>
<a style="display:inline-block; margin-left: .5em" href='https://github.com/One-2-3-45/One-2-3-45'><img src='https://img.shields.io/github/stars/One-2-3-45/One-2-3-45?style=social' /></a>
</div>

This inference model supports the demo for [One-2-3-45](http://One-2-3-45.com).

Try out our 🤗 Hugging Face Demo:
<a target="_blank" href="https://huggingface.co/spaces/One-2-3-45/One-2-3-45">
<img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg" alt="Open in HuggingFace"/>
</a>

Please refer to our [GitHub repo](https://github.com/One-2-3-45/One-2-3-45) for full code release and local deployment.

## Citation

```bibtex
@misc{liu2023one2345,
      title={One-2-3-45: Any Single Image to 3D Mesh in 45 Seconds without Per-Shape Optimization},
      author={Minghua Liu and Chao Xu and Haian Jin and Linghao Chen and Mukund Varma T and Zexiang Xu and Hao Su},
      year={2023},
      eprint={2306.16928},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
SparseNeuS_demo_v1/confs/one2345_lod0_val_demo.conf
ADDED
@@ -0,0 +1,135 @@
# - for the lod1 geometry network, using adaptive cost for sparse cost regularization network
# - for the lod1 rendering network, using depth-adaptive render

general {

    base_exp_dir = exp/lod0  # !!! where you store the results and checkpoints to be used
    recording = [
        ./,
        ./data
        ./ops
        ./models
        ./loss
    ]
}

dataset {
    trainpath = ../
    valpath = ../  # !!! where you store the validation data
    testpath = ../

    imgScale_train = 1.0
    imgScale_test = 1.0
    nviews = 5
    clean_image = True
    importance_sample = True
    test_ref_views = [23]

    # test dataset
    test_n_views = 2
    test_img_wh = [256, 256]
    test_clip_wh = [0, 0]
    test_scan_id = scan110
    train_img_idx = [49, 50, 52, 53, 54, 56, 58]  # [21, 22, 23, 24, 25]
    test_img_idx = [51, 55, 57]  # [32, 33, 34]

    test_dir_comment = train
}

train {
    learning_rate = 2e-4
    learning_rate_milestone = [100000, 150000, 200000]
    learning_rate_factor = 0.5
    end_iter = 200000
    save_freq = 5000
    val_freq = 1
    val_mesh_freq = 1
    report_freq = 100

    N_rays = 512

    validate_resolution_level = 4
    anneal_start = 0
    anneal_end = 25000
    anneal_start_lod1 = 0
    anneal_end_lod1 = 15000

    use_white_bkgd = True

    # Loss
    # ! for training the lod1 network, don't use this regularization in the first 10k steps; then use the regularization
    sdf_igr_weight = 0.1
    sdf_sparse_weight = 0.02  # 0.002 for the lod1 network; 0.02 for the lod0 network
    sdf_decay_param = 100  # should not be too large; it decides the tsdf range
    fg_bg_weight = 0.01  # first 0.01
    bg_ratio = 0.3

    if_fix_lod0_networks = False
}

model {
    num_lods = 1

    sdf_network_lod0 {
        lod = 0,
        ch_in = 56,  # the channel num of fused pyramid features
        voxel_size = 0.02105263,  # 0.02083333, should be 2/95
        vol_dims = [96, 96, 96],
        hidden_dim = 128,
        cost_type = variance_mean
        d_pyramid_feature_compress = 16,
        regnet_d_out = 16,
        num_sdf_layers = 4,
        # position embedding
        multires = 6
    }


    sdf_network_lod1 {
        lod = 1,
        ch_in = 56,  # the channel num of fused pyramid features
        voxel_size = 0.0104712,  # 0.01041667, should be 2/191
        vol_dims = [192, 192, 192],
        hidden_dim = 128,
        cost_type = variance_mean
        d_pyramid_feature_compress = 8,
        regnet_d_out = 16,
        num_sdf_layers = 4,

        # position embedding
        multires = 6
    }


    variance_network {
        init_val = 0.2
    }

    variance_network_lod1 {
        init_val = 0.2
    }

    rendering_network {
        in_geometry_feat_ch = 16
        in_rendering_feat_ch = 56
        anti_alias_pooling = True
    }

    rendering_network_lod1 {
        in_geometry_feat_ch = 16  # default 8
        in_rendering_feat_ch = 56
        anti_alias_pooling = True
    }


    trainer {
        n_samples_lod0 = 64
        n_importance_lod0 = 64
        n_samples_lod1 = 64
        n_importance_lod1 = 64
        n_outside = 0  # 128 if render_outside_uniform_sampling
        perturb = 1.0
        alpha_type = div
    }
}
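Aside: the runner added later in this commit consumes this file through pyhocon (`ConfigFactory.parse_file`, typed getters, and `**conf['...']` splatting). Below is a minimal standalone sketch, not part of the commit, showing those access patterns on a trimmed-down HOCON snippet; the keys mirror the config above.

```python
# Minimal sketch (not part of the commit) of how the HOCON config above is
# read by exp_runner_generic_blender_val.py via pyhocon.
from pyhocon import ConfigFactory

conf = ConfigFactory.parse_string("""
train {
    learning_rate = 2e-4
    end_iter = 200000
    learning_rate_milestone = [100000, 150000, 200000]
    use_white_bkgd = True
}
model {
    num_lods = 1
    sdf_network_lod0 { lod = 0, ch_in = 56, vol_dims = [96, 96, 96] }
}
""")

# Typed getters, as used in the runner.
assert conf.get_int('train.end_iter') == 200000
assert conf.get_float('train.learning_rate') == 2e-4
assert conf.get_list('train.learning_rate_milestone') == [100000, 150000, 200000]
assert conf.get_bool('train.use_white_bkgd') is True

# Sub-trees behave like dicts and can be splatted into a constructor,
# as the runner does with SparseSdfNetwork(**conf['model.sdf_network_lod0']).
sdf_kwargs = dict(conf['model.sdf_network_lod0'])
print(sdf_kwargs['vol_dims'])  # [96, 96, 96]
```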
SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data.py
ADDED
@@ -0,0 +1,394 @@
from torch.utils.data import Dataset
import os
import json
import numpy as np
import cv2
from PIL import Image
import torch
from torchvision import transforms as T
from data.scene import get_boundingbox

from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image
from kornia import create_meshgrid

def get_ray_directions(H, W, focal, center=None):
    """
    Get ray directions for all pixels in camera coordinate.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        H, W, focal: image height, width and focal length
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5  # 1xHxWx2

    i, j = grid.unbind(-1)
    # the direction here is without +0.5 pixel centering as calibration is not so accurate
    # see https://github.com/bmild/nerf/issues/24
    cent = center if center is not None else [W / 2, H / 2]
    directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1)  # (H, W, 3)

    return directions

def load_K_Rt_from_P(filename, P=None):
    if P is None:
        lines = open(filename).read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
        P = np.asarray(lines).astype(np.float32).squeeze()

    out = cv2.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()  # ? why need transpose here
    pose[:3, 3] = (t[:3] / t[3])[:, 0]

    return intrinsics, pose  # ! return cam2world matrix here


# ! load one ref-image with multiple src-images in camera coordinate system
class BlenderPerView(Dataset):
    def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0,
                 split_filepath=None, pair_filepath=None,
                 N_rays=512,
                 vol_dims=[128, 128, 128], batch_size=1,
                 clean_image=False, importance_sample=False, test_ref_views=[],
                 specific_dataset_name='GSO'
                 ):

        # print("root_dir: ", root_dir)
        self.root_dir = root_dir
        self.split = split
        self.downSample = downSample  # needed by read_mask()

        self.specific_dataset_name = specific_dataset_name
        self.n_views = n_views
        self.N_rays = N_rays
        self.batch_size = batch_size  # - used to construct new metas for gru fusion training

        self.clean_image = clean_image
        self.importance_sample = importance_sample
        self.test_ref_views = test_ref_views  # used for testing
        self.scale_factor = 1.0
        self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0]))
        assert self.split in ('val', 'export_mesh'), 'only support val or export_mesh'
        # find all subfolders
        main_folder = os.path.join(root_dir, self.specific_dataset_name)
        self.shape_list = [""]  # os.listdir(main_folder)  # MODIFIED
        self.shape_list.sort()

        # self.shape_list = ['barrel_render']
        # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"]  # TO BE DELETED

        self.lvis_paths = []
        for shape_name in self.shape_list:
            self.lvis_paths.append(os.path.join(main_folder, shape_name))

        if img_wh is not None:
            assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \
                'img_wh must both be multiples of 32!'

        # * bounding box for rendering
        self.bbox_min = np.array([-1.0, -1.0, -1.0])
        self.bbox_max = np.array([1.0, 1.0, 1.0])

        # - used for cost volume regularization
        self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32)
        self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32)

    def define_transforms(self):
        self.transform = T.Compose([T.ToTensor()])

    def load_cam_info(self):
        for vid, img_id in enumerate(self.img_ids):
            intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far
            self.all_intrinsics.append(intrinsic)
            self.all_extrinsics.append(extrinsic)
            self.all_near_fars.append(near_far)

    def read_mask(self, filename):
        mask_h = cv2.imread(filename, 0)
        mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample,
                            interpolation=cv2.INTER_NEAREST)
        mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25,
                          interpolation=cv2.INTER_NEAREST)

        mask[mask > 0] = 1  # the masks stored in png are not binary
        mask_h[mask_h > 0] = 1

        return mask, mask_h

    def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.):

        center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars)

        radius = radius * factor
        scale_mat = np.diag([radius, radius, radius, 1.0])
        scale_mat[:3, 3] = center.cpu().numpy()
        scale_mat = scale_mat.astype(np.float32)

        return scale_mat, 1. / radius.cpu().numpy()

    def __len__(self):
        # return 8 * len(self.lvis_paths)
        return len(self.lvis_paths)

    def __getitem__(self, idx):
        sample = {}
        idx = idx * 8  # to be deleted
        origin_idx = idx
        imgs, depths_h, masks_h = [], [], []  # full size (256, 256)
        intrinsics, w2cs, c2ws, near_fars = [], [], [], []  # record proj-mats between views

        folder_path = self.lvis_paths[idx // 8]
        idx = idx % 8  # [0, 7]

        # last subdir name
        shape_name = os.path.split(folder_path)[-1]

        pose_json_path = os.path.join(folder_path, "pose.json")
        with open(pose_json_path, 'r') as f:
            meta = json.load(f)

        self.img_ids = list(meta["c2ws"].keys())  # e.g. "view_0", "view_7", "view_0_2_10"
        self.img_wh = (256, 256)
        self.input_poses = np.array(list(meta["c2ws"].values()))
        intrinsic = np.eye(4)
        intrinsic[:3, :3] = np.array(meta["intrinsics"])
        self.intrinsic = intrinsic
        self.near_far = np.array(meta["near_far"])
        self.near_far[1] = 1.8
        self.define_transforms()
        self.blender2opencv = np.array(
            [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
        )

        self.c2ws = []
        self.w2cs = []
        self.near_fars = []
        for image_idx, img_id in enumerate(self.img_ids):
            pose = self.input_poses[image_idx]
            c2w = pose @ self.blender2opencv
            self.c2ws.append(c2w)
            self.w2cs.append(np.linalg.inv(c2w))
            self.near_fars.append(self.near_far)
        self.c2ws = np.stack(self.c2ws, axis=0)
        self.w2cs = np.stack(self.w2cs, axis=0)

        self.all_intrinsics = []  # the cam info of the whole scene
        self.all_extrinsics = []
        self.all_near_fars = []
        self.load_cam_info()

        # target view
        c2w = self.c2ws[idx]
        w2c = np.linalg.inv(c2w)
        w2c_ref = w2c
        w2c_ref_inv = np.linalg.inv(w2c_ref)

        w2cs.append(w2c @ w2c_ref_inv)
        c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv))

        img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}')

        img = Image.open(img_filename)
        img = self.transform(img)  # (4, h, w)

        if img.shape[0] == 4:
            img = img[:3] * img[-1:] + (1 - img[-1:])  # blend A to RGB
        imgs += [img]

        depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32)
        depth_h = depth_h.fill_(-1.0)
        mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32)

        depths_h.append(depth_h)
        masks_h.append(mask_h)

        intrinsic = self.intrinsic
        intrinsics.append(intrinsic)

        near_fars.append(self.near_fars[idx])
        image_perm = 0  # only supervised on reference view

        mask_dilated = None

        src_views = range(8, 8 + 8 * 4)

        for vid in src_views:
            img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}')
            img = Image.open(img_filename)
            img_wh = self.img_wh

            img = self.transform(img)
            if img.shape[0] == 4:
                img = img[:3] * img[-1:] + (1 - img[-1:])  # blend A to RGB

            imgs += [img]
            depth_h = np.ones(img.shape[1:], dtype=np.float32)
            depths_h.append(depth_h)
            masks_h.append(np.ones(img.shape[1:], dtype=np.int32))

            near_fars.append(self.all_near_fars[vid])
            intrinsics.append(self.all_intrinsics[vid])

            w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv)

        # ! estimate scale_mat
        scale_mat, scale_factor = self.cal_scale_mat(
            img_hw=[img_wh[1], img_wh[0]],
            intrinsics=intrinsics, extrinsics=w2cs,
            near_fars=near_fars, factor=1.1
        )

        new_near_fars = []
        new_w2cs = []
        new_c2ws = []
        new_affine_mats = []
        new_depths_h = []
        for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h):

            P = intrinsic @ extrinsic @ scale_mat
            P = P[:3, :4]
            # - should use load_K_Rt_from_P() to obtain c2w
            c2w = load_K_Rt_from_P(None, P)[1]
            w2c = np.linalg.inv(c2w)
            new_w2cs.append(w2c)
            new_c2ws.append(c2w)
            affine_mat = np.eye(4)
            affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4]
            new_affine_mats.append(affine_mat)

            camera_o = c2w[:3, 3]
            dist = np.sqrt(np.sum(camera_o ** 2))
            near = dist - 1
            far = dist + 1

            new_near_fars.append([0.95 * near, 1.05 * far])
            new_depths_h.append(depth * scale_factor)

        imgs = torch.stack(imgs).float()
        depths_h = np.stack(new_depths_h)
        masks_h = np.stack(masks_h)

        affine_mats = np.stack(new_affine_mats)
        intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack(
            new_near_fars)

        if self.split == 'train':
            start_idx = 0
        else:
            start_idx = 1

        target_w2cs = []
        target_intrinsics = []
        new_target_w2cs = []
        for i_idx in range(8):
            target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv)
            target_intrinsics.append(self.all_intrinsics[i_idx])

        for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs):

            P = intrinsic @ extrinsic @ scale_mat
            P = P[:3, :4]
            # - should use load_K_Rt_from_P() to obtain c2w
            c2w = load_K_Rt_from_P(None, P)[1]
            w2c = np.linalg.inv(c2w)
            new_target_w2cs.append(w2c)
        target_w2cs = np.stack(new_target_w2cs)

        view_ids = [idx] + list(src_views)
        sample['origin_idx'] = origin_idx
        sample['images'] = imgs  # (V, 3, H, W)
        sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32))  # (V, H, W)
        sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32))  # (V, H, W)
        sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32))  # (V, 4, 4)
        sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32))  # (V, 4, 4)
        sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32))  # (8, 4, 4)
        sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32))  # (V, 2)
        sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3]  # (V, 3, 3)
        sample['view_ids'] = torch.from_numpy(np.array(view_ids))
        sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32))  # ! in world space

        sample['scan'] = shape_name

        sample['scale_factor'] = torch.tensor(scale_factor)
        sample['img_wh'] = torch.from_numpy(np.array(img_wh))
        sample['render_img_idx'] = torch.tensor(image_perm)
        sample['partial_vol_origin'] = self.partial_vol_origin
        sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0])
        # print("meta: ", sample['meta'])

        # - image to render
        sample['query_image'] = sample['images'][0]
        sample['query_c2w'] = sample['c2ws'][0]
        sample['query_w2c'] = sample['w2cs'][0]
        sample['query_intrinsic'] = sample['intrinsics'][0]
        sample['query_depth'] = sample['depths_h'][0]
        sample['query_mask'] = sample['masks_h'][0]
        sample['query_near_far'] = sample['near_fars'][0]

        sample['images'] = sample['images'][start_idx:]  # (V, 3, H, W)
        sample['depths_h'] = sample['depths_h'][start_idx:]  # (V, H, W)
        sample['masks_h'] = sample['masks_h'][start_idx:]  # (V, H, W)
        sample['w2cs'] = sample['w2cs'][start_idx:]  # (V, 4, 4)
        sample['c2ws'] = sample['c2ws'][start_idx:]  # (V, 4, 4)
        sample['intrinsics'] = sample['intrinsics'][start_idx:]  # (V, 3, 3)
        sample['view_ids'] = sample['view_ids'][start_idx:]
        sample['affine_mats'] = sample['affine_mats'][start_idx:]  # ! in world space

        sample['scale_mat'] = torch.from_numpy(scale_mat)
        sample['trans_mat'] = torch.from_numpy(w2c_ref_inv)

        # - generate rays
        if ('val' in self.split) or ('test' in self.split):
            sample_rays = gen_rays_from_single_image(
                img_wh[1], img_wh[0],
                sample['query_image'],
                sample['query_intrinsic'],
                sample['query_c2w'],
                depth=sample['query_depth'],
                mask=sample['query_mask'] if self.clean_image else None)
        else:
            sample_rays = gen_random_rays_from_single_image(
                img_wh[1], img_wh[0],
                self.N_rays,
                sample['query_image'],
                sample['query_intrinsic'],
                sample['query_c2w'],
                depth=sample['query_depth'],
                mask=sample['query_mask'] if self.clean_image else None,
                dilated_mask=mask_dilated,
                importance_sample=self.importance_sample)

        sample['rays'] = sample_rays

        return sample
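Aside: `load_K_Rt_from_P` above relies on `cv2.decomposeProjectionMatrix` returning the intrinsics, the world-to-camera rotation, and the homogeneous camera center. The sketch below, not part of the commit and using made-up camera values, round-trips a projection matrix to confirm those conventions.

```python
# Hypothetical round-trip check (not part of the commit) for load_K_Rt_from_P:
# build P = K @ [R | t] with t = -R @ C, then verify that
# cv2.decomposeProjectionMatrix recovers K, R, and the camera center C.
import numpy as np
import cv2

K = np.array([[280.0, 0.0, 128.0],
              [0.0, 280.0, 128.0],
              [0.0, 0.0, 1.0]])
rvec = np.array([0.1, -0.2, 0.05]).reshape(3, 1)
R = cv2.Rodrigues(rvec)[0]              # world2cam rotation
C = np.array([0.3, -0.2, 1.5])          # camera center in world coordinates
t = -R @ C.reshape(3, 1)                # world2cam translation
P = K @ np.hstack([R, t])               # (3, 4) projection matrix

out = cv2.decomposeProjectionMatrix(P)
K_rec = out[0] / out[0][2, 2]           # normalize so K[2, 2] == 1, as above
R_rec = out[1]                          # world2cam rotation
C_rec = (out[2][:3] / out[2][3])[:, 0]  # dehomogenize -> camera center

assert np.allclose(K_rec, K, atol=1e-6)
assert np.allclose(R_rec, R, atol=1e-6)
assert np.allclose(C_rec, C, atol=1e-6)
# load_K_Rt_from_P therefore returns pose[:3, :3] = R_rec.T (cam2world
# rotation) and pose[:3, 3] = C_rec, i.e. a cam2world matrix.
```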
SparseNeuS_demo_v1/data/scene.py
ADDED
@@ -0,0 +1,101 @@
import numpy as np
import torch


def rigid_transform(xyz, transform):
    """Applies a rigid transform (c2w) to an (N, 3) pointcloud.
    """
    device = xyz.device
    xyz_h = torch.cat([xyz, torch.ones((len(xyz), 1)).to(device)], dim=1)  # (N, 4)
    xyz_t_h = (transform @ xyz_h.T).T  # * checked: the same as the below

    return xyz_t_h[:, :3]


def get_view_frustum(min_depth, max_depth, size, cam_intr, c2w):
    """Get corners of 3D camera view frustum of depth image
    """
    device = cam_intr.device
    im_h, im_w = size
    im_h = int(im_h)
    im_w = int(im_w)
    view_frust_pts = torch.stack([
        (torch.tensor([0, 0, im_w, im_w, 0, 0, im_w, im_w]).to(device) - cam_intr[0, 2]) * torch.tensor(
            [min_depth, min_depth, min_depth, min_depth, max_depth, max_depth, max_depth, max_depth]).to(device) /
        cam_intr[0, 0],
        (torch.tensor([0, im_h, 0, im_h, 0, im_h, 0, im_h]).to(device) - cam_intr[1, 2]) * torch.tensor(
            [min_depth, min_depth, min_depth, min_depth, max_depth, max_depth, max_depth, max_depth]).to(device) /
        cam_intr[1, 1],
        torch.tensor([min_depth, min_depth, min_depth, min_depth, max_depth, max_depth, max_depth, max_depth]).to(
            device)
    ])
    view_frust_pts = view_frust_pts.type(torch.float32)
    c2w = c2w.type(torch.float32)
    view_frust_pts = rigid_transform(view_frust_pts.T, c2w).T
    return view_frust_pts


def set_pixel_coords(h, w):
    i_range = torch.arange(0, h).view(1, h, 1).expand(1, h, w).type(torch.float32)  # [1, H, W]
    j_range = torch.arange(0, w).view(1, 1, w).expand(1, h, w).type(torch.float32)  # [1, H, W]
    ones = torch.ones(1, h, w).type(torch.float32)

    pixel_coords = torch.stack((j_range, i_range, ones), dim=1)  # [1, 3, H, W]

    return pixel_coords


def get_boundingbox(img_hw, intrinsics, extrinsics, near_fars):
    """
    # get the minimum bounding box of all visual hulls
    :param img_hw:
    :param intrinsics:
    :param extrinsics:
    :param near_fars:
    :return:
    """

    bnds = torch.zeros((3, 2))
    bnds[:, 0] = np.inf
    bnds[:, 1] = -np.inf

    if isinstance(intrinsics, list):
        num = len(intrinsics)
    else:
        num = intrinsics.shape[0]
    # print("num: ", num)
    view_frust_pts_list = []
    for i in range(num):
        if not isinstance(intrinsics[i], torch.Tensor):
            cam_intr = torch.tensor(intrinsics[i])
            w2c = torch.tensor(extrinsics[i])
            c2w = torch.inverse(w2c)
        else:
            cam_intr = intrinsics[i]
            w2c = extrinsics[i]
            c2w = torch.inverse(w2c)
        min_depth, max_depth = near_fars[i][0], near_fars[i][1]
        # todo: check that the corresponding points are matched

        view_frust_pts = get_view_frustum(min_depth, max_depth, img_hw, cam_intr, c2w)
        bnds[:, 0] = torch.min(bnds[:, 0], torch.min(view_frust_pts, dim=1)[0])
        bnds[:, 1] = torch.max(bnds[:, 1], torch.max(view_frust_pts, dim=1)[0])
        view_frust_pts_list.append(view_frust_pts)
    all_view_frust_pts = torch.cat(view_frust_pts_list, dim=1)

    # print("all_view_frust_pts: ", all_view_frust_pts.shape)
    # distance = torch.norm(all_view_frust_pts, dim=0)
    # print("distance: ", distance)

    # print("all_view_frust_pts_z: ", all_view_frust_pts[2, :])

    center = torch.tensor(((bnds[0, 1] + bnds[0, 0]) / 2, (bnds[1, 1] + bnds[1, 0]) / 2,
                           (bnds[2, 1] + bnds[2, 0]) / 2))

    lengths = bnds[:, 1] - bnds[:, 0]

    max_length, _ = torch.max(lengths, dim=0)
    radius = max_length / 2

    # print("radius: ", radius)
    return center, radius, bnds
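Aside: an illustrative usage sketch for `get_boundingbox`, not part of the commit, with toy cameras and made-up values; it assumes it is run from inside `SparseNeuS_demo_v1/` so that `data.scene` is importable, the same way `BlenderPerView.cal_scale_mat` above calls it.

```python
# Toy example (not part of the commit): union the view frusta of two opposed
# cameras and recover the center/radius that cal_scale_mat turns into scale_mat.
import numpy as np
from data.scene import get_boundingbox

K = np.eye(4, dtype=np.float32)
K[0, 0] = K[1, 1] = 280.0       # focal lengths
K[0, 2] = K[1, 2] = 128.0       # principal point
intrinsics = [K, K]

# world2cam extrinsics (OpenCV convention, camera looks along its +z axis):
w2c_a = np.eye(4, dtype=np.float32)
w2c_a[2, 3] = 2.0               # camera center at z = -2, looking along +z
w2c_b = np.diag([-1.0, 1.0, -1.0, 1.0]).astype(np.float32)
w2c_b[2, 3] = 2.0               # camera center at z = +2, looking along -z
extrinsics = [w2c_a, w2c_b]

near_fars = [[1.0, 3.0], [1.0, 3.0]]
center, radius, bounds = get_boundingbox((256, 256), intrinsics, extrinsics, near_fars)
print(center, radius)           # center near the origin; radius = half the
                                # longest side of the frusta's bounding box
```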
SparseNeuS_demo_v1/exp/lod0/.gitignore
ADDED
@@ -0,0 +1 @@
checkpoints_*/
SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_215000.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:888aaa8abde948358c26e4ef63df99f666438345c1dee301059967c5ce77b6ea
size 5312111
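Aside: the three lines above are a Git LFS pointer, not the checkpoint itself; the real ~5.3 MB `.pth` file is fetched by LFS, and the `oid` is the SHA-256 of its contents. A small sketch, not part of the commit, for verifying a downloaded checkpoint against that oid:

```python
# Verify a Git LFS-tracked file against the sha256 oid in its pointer.
import hashlib

def lfs_oid(path):
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):  # 1 MiB chunks
            h.update(chunk)
    return h.hexdigest()

# Expected to hold once LFS has replaced the pointer with the real file:
# lfs_oid('SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_215000.pth') \
#     == '888aaa8abde948358c26e4ef63df99f666438345c1dee301059967c5ce77b6ea'
```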
SparseNeuS_demo_v1/exp_runner_generic_blender_val.py
ADDED
@@ -0,0 +1,629 @@
import os
import logging
import argparse
import numpy as np
from shutil import copyfile
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from rich import print
from tqdm import tqdm
from pyhocon import ConfigFactory

import sys
sys.path.append(os.path.dirname(__file__))

from models.fields import SingleVarianceNetwork
from models.featurenet import FeatureNet
from models.trainer_generic import GenericTrainer
from models.sparse_sdf_network import SparseSdfNetwork
from models.rendering_network import GeneralRenderingNetwork
from data.blender_general_narrow_all_eval_new_data import BlenderPerView


from datetime import datetime

class Runner:
    def __init__(self, conf_path, mode='train', is_continue=False,
                 is_restore=False, restore_lod0=False, local_rank=0):

        # Initial setting
        self.device = torch.device('cuda:%d' % local_rank)
        # self.device = torch.device('cuda')
        self.num_devices = torch.cuda.device_count()
        self.is_continue = is_continue or (mode == "export_mesh")
        self.is_restore = is_restore
        self.restore_lod0 = restore_lod0
        self.mode = mode
        self.model_list = []
        self.logger = logging.getLogger('exp_logger')

        print("detected %d GPUs" % self.num_devices)

        self.conf_path = conf_path
        self.conf = ConfigFactory.parse_file(conf_path)
        self.timestamp = None
        if not self.is_continue:
            self.timestamp = '_{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
            self.base_exp_dir = self.conf['general.base_exp_dir'] + self.timestamp  # jha comment this when testing and use this when training
        else:
            self.base_exp_dir = self.conf['general.base_exp_dir']
            self.conf['general.base_exp_dir'] = self.base_exp_dir  # jha use this when testing
        print("base_exp_dir: " + self.base_exp_dir)
        os.makedirs(self.base_exp_dir, exist_ok=True)
        self.iter_step = 0
        self.val_step = 0

        # training parameters
        self.end_iter = self.conf.get_int('train.end_iter')
        self.save_freq = self.conf.get_int('train.save_freq')
        self.report_freq = self.conf.get_int('train.report_freq')
        self.val_freq = self.conf.get_int('train.val_freq')
        self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')
        self.batch_size = self.num_devices  # use DataParallel to wrap
        self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level')
        self.learning_rate = self.conf.get_float('train.learning_rate')
        self.learning_rate_milestone = self.conf.get_list('train.learning_rate_milestone')
        self.learning_rate_factor = self.conf.get_float('train.learning_rate_factor')
        self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd')
        self.N_rays = self.conf.get_int('train.N_rays')

        # warmup params for sdf gradient
        self.anneal_start_lod0 = self.conf.get_float('train.anneal_start', default=0)
        self.anneal_end_lod0 = self.conf.get_float('train.anneal_end', default=0)
        self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0)
        self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0)

        self.writer = None

        # Networks
        self.num_lods = self.conf.get_int('model.num_lods')

        self.rendering_network_outside = None
        self.sdf_network_lod0 = None
        self.sdf_network_lod1 = None
        self.variance_network_lod0 = None
        self.variance_network_lod1 = None
        self.rendering_network_lod0 = None
        self.rendering_network_lod1 = None
        self.pyramid_feature_network = None  # extract 2d pyramid feature maps from images, used for geometry
        self.pyramid_feature_network_lod1 = None  # may use different feature network for different lod

        # * pyramid_feature_network
        self.pyramid_feature_network = FeatureNet().to(self.device)
        self.sdf_network_lod0 = SparseSdfNetwork(**self.conf['model.sdf_network_lod0']).to(self.device)
        self.variance_network_lod0 = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device)

        if self.num_lods > 1:
            self.sdf_network_lod1 = SparseSdfNetwork(**self.conf['model.sdf_network_lod1']).to(self.device)
            self.variance_network_lod1 = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device)

        self.rendering_network_lod0 = GeneralRenderingNetwork(**self.conf['model.rendering_network']).to(
            self.device)

        if self.num_lods > 1:
            self.pyramid_feature_network_lod1 = FeatureNet().to(self.device)
            self.rendering_network_lod1 = GeneralRenderingNetwork(
                **self.conf['model.rendering_network_lod1']).to(self.device)
        if self.mode == 'export_mesh' or self.mode == 'val':
            # base_exp_dir_to_store = os.path.join(self.base_exp_dir, '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now()))
            base_exp_dir_to_store = os.path.join("../", args.specific_dataset_name)  # "../gradio_tmp"  # MODIFIED
        else:
            base_exp_dir_to_store = self.base_exp_dir

        print(f"Store in: {base_exp_dir_to_store}")
        # Renderer model
        self.trainer = GenericTrainer(
            self.rendering_network_outside,
            self.pyramid_feature_network,
            self.pyramid_feature_network_lod1,
            self.sdf_network_lod0,
            self.sdf_network_lod1,
            self.variance_network_lod0,
            self.variance_network_lod1,
            self.rendering_network_lod0,
            self.rendering_network_lod1,
            **self.conf['model.trainer'],
            timestamp=self.timestamp,
            base_exp_dir=base_exp_dir_to_store,
            conf=self.conf)

        self.data_setup()  # * data setup

        self.optimizer_setup()

        # Load checkpoint
        latest_model_name = None
        if self.is_continue:
            model_list_raw = os.listdir(os.path.join(self.base_exp_dir, 'checkpoints'))
            model_list = []
            for model_name in model_list_raw:
                if model_name.startswith('ckpt'):
                    if model_name[-3:] == 'pth':  # and int(model_name[5:-4]) <= self.end_iter:
                        model_list.append(model_name)
            model_list.sort()
            latest_model_name = model_list[-1]

        if latest_model_name is not None:
            self.logger.info('Find checkpoint: {}'.format(latest_model_name))
            self.load_checkpoint(latest_model_name)

        self.trainer = torch.nn.DataParallel(self.trainer).to(self.device)

        if self.mode[:5] == 'train':
            self.file_backup()

    def optimizer_setup(self):
        self.params_to_train = self.trainer.get_trainable_params()
        self.optimizer = torch.optim.Adam(self.params_to_train, lr=self.learning_rate)

    def data_setup(self):
        """
        if use ddp, use setup() not prepare_data(),
        prepare_data() only called on 1 GPU/TPU in distributed
        :return:
        """

        self.train_dataset = BlenderPerView(
            root_dir=self.conf['dataset.trainpath'],
            split=self.conf.get_string('dataset.train_split', default='train'),
            split_filepath=self.conf.get_string('dataset.train_split_filepath', default=None),
            n_views=self.conf['dataset.nviews'],
            downSample=self.conf['dataset.imgScale_train'],
            N_rays=self.N_rays,
            batch_size=self.batch_size,
            clean_image=True,  # True for training
            importance_sample=self.conf.get_bool('dataset.importance_sample', default=False),
            specific_dataset_name=args.specific_dataset_name
        )

        self.val_dataset = BlenderPerView(
            root_dir=self.conf['dataset.valpath'],
            split=self.conf.get_string('dataset.test_split', default='test'),
            split_filepath=self.conf.get_string('dataset.val_split_filepath', default=None),
            n_views=3,
            downSample=self.conf['dataset.imgScale_test'],
            N_rays=self.N_rays,
            batch_size=self.batch_size,
            clean_image=self.conf.get_bool('dataset.mask_out_image',
                                           default=False) if self.mode != 'train' else False,
            importance_sample=self.conf.get_bool('dataset.importance_sample', default=False),
            test_ref_views=self.conf.get_list('dataset.test_ref_views', default=[]),
            specific_dataset_name=args.specific_dataset_name
        )

        # item = self.train_dataset.__getitem__(0)
        self.train_dataloader = DataLoader(self.train_dataset,
                                           shuffle=True,
                                           num_workers=4 * self.batch_size,
                                           # num_workers=1,
                                           batch_size=self.batch_size,
                                           pin_memory=True,
                                           drop_last=True
                                           )

        self.val_dataloader = DataLoader(self.val_dataset,
                                         # shuffle=False if self.mode == 'train' else True,
                                         shuffle=False,
                                         num_workers=4 * self.batch_size,
                                         # num_workers=1,
                                         batch_size=self.batch_size,
                                         pin_memory=True,
                                         drop_last=False
                                         )

        self.val_dataloader_iterator = iter(self.val_dataloader)  # - should be after "reconstruct_metas_for_gru_fusion"

    def train(self):
        self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'logs'))
        res_step = self.end_iter - self.iter_step

        dataloader = self.train_dataloader

        epochs = int(1 + res_step // len(dataloader))

        self.adjust_learning_rate()
        print("starting training learning rate: {:.5f}".format(self.optimizer.param_groups[0]['lr']))

        background_rgb = None
        if self.use_white_bkgd:
            # background_rgb = torch.ones([1, 3]).to(self.device)
            background_rgb = 1.0

        for epoch_i in range(epochs):

            print("current epoch %d" % epoch_i)
            dataloader = tqdm(dataloader)

            for batch in dataloader:
                # print("Checker1:, fetch data")
                batch['batch_idx'] = torch.tensor([x for x in range(self.batch_size)])  # used to get meta

                # - warmup params
                if self.num_lods == 1:
                    alpha_inter_ratio_lod0 = self.get_alpha_inter_ratio(self.anneal_start_lod0, self.anneal_end_lod0)
                else:
                    alpha_inter_ratio_lod0 = 1.
                alpha_inter_ratio_lod1 = self.get_alpha_inter_ratio(self.anneal_start_lod1, self.anneal_end_lod1)

                losses = self.trainer(
                    batch,
                    background_rgb=background_rgb,
                    alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
                    alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
                    iter_step=self.iter_step,
                    mode='train',
                )

                loss_types = ['loss_lod0', 'loss_lod1']
                # print("[TEST]: weights_sum in trainer return", losses['losses_lod0']['weights_sum'].mean())

                losses_lod0 = losses['losses_lod0']
                losses_lod1 = losses['losses_lod1']
                # import ipdb; ipdb.set_trace()
                loss = 0
                for loss_type in loss_types:
                    if losses[loss_type] is not None:
                        loss = loss + losses[loss_type].mean()
                # print("Checker4:, begin BP")
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.params_to_train, 1.0)
                self.optimizer.step()
                # print("Checker5:, end BP")
                self.iter_step += 1

                if self.iter_step % self.report_freq == 0:
                    self.writer.add_scalar('Loss/loss', loss, self.iter_step)

                    if losses_lod0 is not None:
                        self.writer.add_scalar('Loss/d_loss_lod0',
                                               losses_lod0['depth_loss'].mean() if losses_lod0 is not None else 0,
                                               self.iter_step)
                        self.writer.add_scalar('Loss/sparse_loss_lod0',
                                               losses_lod0[
                                                   'sparse_loss'].mean() if losses_lod0 is not None else 0,
                                               self.iter_step)
                        self.writer.add_scalar('Loss/color_loss_lod0',
                                               losses_lod0['color_fine_loss'].mean()
                                               if losses_lod0['color_fine_loss'] is not None else 0,
                                               self.iter_step)

                        self.writer.add_scalar('statis/psnr_lod0',
                                               losses_lod0['psnr'].mean()
                                               if losses_lod0['psnr'] is not None else 0,
                                               self.iter_step)

                        self.writer.add_scalar('param/variance_lod0',
                                               1. / torch.exp(self.variance_network_lod0.variance * 10),
                                               self.iter_step)
                        self.writer.add_scalar('param/eikonal_loss',
                                               losses_lod0['gradient_error_loss'].mean() if losses_lod0 is not None else 0,
                                               self.iter_step)

                    ######## - lod 1
                    if self.num_lods > 1:
                        self.writer.add_scalar('Loss/d_loss_lod1',
                                               losses_lod1['depth_loss'].mean() if losses_lod1 is not None else 0,
                                               self.iter_step)
                        self.writer.add_scalar('Loss/sparse_loss_lod1',
                                               losses_lod1[
                                                   'sparse_loss'].mean() if losses_lod1 is not None else 0,
                                               self.iter_step)
                        self.writer.add_scalar('Loss/color_loss_lod1',
                                               losses_lod1['color_fine_loss'].mean()
                                               if losses_lod1['color_fine_loss'] is not None else 0,
                                               self.iter_step)
                        self.writer.add_scalar('statis/sdf_mean_lod1',
                                               losses_lod1['sdf_mean'].mean() if losses_lod1 is not None else 0,
                                               self.iter_step)
                        self.writer.add_scalar('statis/psnr_lod1',
                                               losses_lod1['psnr'].mean()
                                               if losses_lod1['psnr'] is not None else 0,
                                               self.iter_step)
                        self.writer.add_scalar('statis/sparseness_0.01_lod1',
                                               losses_lod1['sparseness_1'].mean()
                                               if losses_lod1['sparseness_1'] is not None else 0,
                                               self.iter_step)
                        self.writer.add_scalar('statis/sparseness_0.02_lod1',
                                               losses_lod1['sparseness_2'].mean()
                                               if losses_lod1['sparseness_2'] is not None else 0,
                                               self.iter_step)
                        self.writer.add_scalar('param/variance_lod1',
                                               1. / torch.exp(self.variance_network_lod1.variance * 10),
                                               self.iter_step)

                    print(self.base_exp_dir)
                    print(
                        'iter:{:>8d} '
                        'loss = {:.4f} '
                        'd_loss_lod0 = {:.4f} '
                        'color_loss_lod0 = {:.4f} '
                        'sparse_loss_lod0 = {:.4f} '
                        'd_loss_lod1 = {:.4f} '
                        'color_loss_lod1 = {:.4f} '
                        ' lr = {:.5f}'.format(
                            self.iter_step, loss,
                            losses_lod0['depth_loss'].mean() if losses_lod0 is not None else 0,
                            losses_lod0['color_fine_loss'].mean() if losses_lod0 is not None else 0,
                            losses_lod0['sparse_loss'].mean() if losses_lod0 is not None else 0,
                            losses_lod1['depth_loss'].mean() if losses_lod1 is not None else 0,
                            losses_lod1['color_fine_loss'].mean() if losses_lod1 is not None else 0,
                            self.optimizer.param_groups[0]['lr']))

                    print('alpha_inter_ratio_lod0 = {:.4f} alpha_inter_ratio_lod1 = {:.4f}\n'.format(
                        alpha_inter_ratio_lod0, alpha_inter_ratio_lod1))

                    if losses_lod0 is not None:
                        # print("[TEST]: weights_sum in print", losses_lod0['weights_sum'].mean())
                        # import ipdb; ipdb.set_trace()
                        print(
                            'iter:{:>8d} '
                            'variance = {:.5f} '
                            'weights_sum = {:.4f} '
                            'weights_sum_fg = {:.4f} '
                            'alpha_sum = {:.4f} '
                            'sparse_weight = {:.4f} '
                            'background_loss = {:.4f} '
                            'background_weight = {:.4f} '
                            .format(
                                self.iter_step,
                                losses_lod0['variance'].mean(),
                                losses_lod0['weights_sum'].mean(),
                                losses_lod0['weights_sum_fg'].mean(),
                                losses_lod0['alpha_sum'].mean(),
                                losses_lod0['sparse_weight'].mean(),
                                losses_lod0['fg_bg_loss'].mean(),
                                losses_lod0['fg_bg_weight'].mean(),
                            ))

                    if losses_lod1 is not None:
                        print(
                            'iter:{:>8d} '
                            'variance = {:.5f} '
                            'weights_sum = {:.4f} '
                            'alpha_sum = {:.4f} '
                            'sparse_weight = {:.4f} '
                            'fg_bg_loss = {:.4f} '
                            'fg_bg_weight = {:.4f} '
                            .format(
                                self.iter_step,
                                losses_lod1['variance'].mean(),
                                losses_lod1['weights_sum'].mean(),
                                losses_lod1['alpha_sum'].mean(),
                                losses_lod1['sparse_weight'].mean(),
                                losses_lod1['fg_bg_loss'].mean(),
                                losses_lod1['fg_bg_weight'].mean(),
                            ))

                if self.iter_step % self.save_freq == 0:
                    self.save_checkpoint()

                if self.iter_step % self.val_freq == 0:
                    self.validate()

                # - adjust learning rate
                self.adjust_learning_rate()

    def adjust_learning_rate(self):
        # - adjust learning rate, cosine learning schedule
        learning_rate = (np.cos(np.pi * self.iter_step / self.end_iter) + 1.0) * 0.5 * 0.9 + 0.1
        learning_rate = self.learning_rate * learning_rate
        for g in self.optimizer.param_groups:
            g['lr'] = learning_rate

    def get_alpha_inter_ratio(self, start, end):
        if end == 0.0:
            return 1.0
        elif self.iter_step < start:
            return 0.0
        else:
            return np.min([1.0, (self.iter_step - start) / (end - start)])

    def file_backup(self):
        # copy python files
        dir_lis = self.conf['general.recording']
        os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True)
os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True)
|
430 |
+
for dir_name in dir_lis:
|
431 |
+
cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name)
|
432 |
+
os.makedirs(cur_dir, exist_ok=True)
|
433 |
+
files = os.listdir(dir_name)
|
434 |
+
for f_name in files:
|
435 |
+
if f_name[-3:] == '.py':
|
436 |
+
copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name))
|
437 |
+
|
438 |
+
# copy configs
|
439 |
+
copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf'))
|
440 |
+
|
441 |
+
def load_checkpoint(self, checkpoint_name):
|
442 |
+
|
443 |
+
def load_state_dict(network, checkpoint, comment):
|
444 |
+
if network is not None:
|
445 |
+
try:
|
446 |
+
pretrained_dict = checkpoint[comment]
|
447 |
+
|
448 |
+
model_dict = network.state_dict()
|
449 |
+
|
450 |
+
# 1. filter out unnecessary keys
|
451 |
+
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
452 |
+
# 2. overwrite entries in the existing state dict
|
453 |
+
model_dict.update(pretrained_dict)
|
454 |
+
# 3. load the new state dict
|
455 |
+
network.load_state_dict(pretrained_dict)
|
456 |
+
except:
|
457 |
+
print(comment + " load fails")
|
458 |
+
|
459 |
+
checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name),
|
460 |
+
map_location=self.device)
|
461 |
+
|
462 |
+
load_state_dict(self.rendering_network_outside, checkpoint, 'rendering_network_outside')
|
463 |
+
|
464 |
+
load_state_dict(self.sdf_network_lod0, checkpoint, 'sdf_network_lod0')
|
465 |
+
load_state_dict(self.sdf_network_lod1, checkpoint, 'sdf_network_lod1')
|
466 |
+
|
467 |
+
load_state_dict(self.pyramid_feature_network, checkpoint, 'pyramid_feature_network')
|
468 |
+
load_state_dict(self.pyramid_feature_network_lod1, checkpoint, 'pyramid_feature_network_lod1')
|
469 |
+
|
470 |
+
load_state_dict(self.variance_network_lod0, checkpoint, 'variance_network_lod0')
|
471 |
+
load_state_dict(self.variance_network_lod1, checkpoint, 'variance_network_lod1')
|
472 |
+
|
473 |
+
load_state_dict(self.rendering_network_lod0, checkpoint, 'rendering_network_lod0')
|
474 |
+
load_state_dict(self.rendering_network_lod1, checkpoint, 'rendering_network_lod1')
|
475 |
+
|
476 |
+
if self.restore_lod0: # use the trained lod0 networks to initialize lod1 networks
|
477 |
+
load_state_dict(self.sdf_network_lod1, checkpoint, 'sdf_network_lod0')
|
478 |
+
load_state_dict(self.pyramid_feature_network_lod1, checkpoint, 'pyramid_feature_network')
|
479 |
+
load_state_dict(self.rendering_network_lod1, checkpoint, 'rendering_network_lod0')
|
480 |
+
|
481 |
+
if self.is_continue and (not self.restore_lod0):
|
482 |
+
try:
|
483 |
+
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
484 |
+
except:
|
485 |
+
print("load optimizer fails")
|
486 |
+
self.iter_step = checkpoint['iter_step']
|
487 |
+
self.val_step = checkpoint['val_step'] if 'val_step' in checkpoint.keys() else 0
|
488 |
+
|
489 |
+
self.logger.info('End')
|
490 |
+
|
491 |
+
def save_checkpoint(self):
|
492 |
+
|
493 |
+
def save_state_dict(network, checkpoint, comment):
|
494 |
+
if network is not None:
|
495 |
+
checkpoint[comment] = network.state_dict()
|
496 |
+
|
497 |
+
checkpoint = {
|
498 |
+
'optimizer': self.optimizer.state_dict(),
|
499 |
+
'iter_step': self.iter_step,
|
500 |
+
'val_step': self.val_step,
|
501 |
+
}
|
502 |
+
|
503 |
+
save_state_dict(self.sdf_network_lod0, checkpoint, "sdf_network_lod0")
|
504 |
+
save_state_dict(self.sdf_network_lod1, checkpoint, "sdf_network_lod1")
|
505 |
+
|
506 |
+
save_state_dict(self.rendering_network_outside, checkpoint, 'rendering_network_outside')
|
507 |
+
save_state_dict(self.rendering_network_lod0, checkpoint, "rendering_network_lod0")
|
508 |
+
save_state_dict(self.rendering_network_lod1, checkpoint, "rendering_network_lod1")
|
509 |
+
|
510 |
+
save_state_dict(self.variance_network_lod0, checkpoint, 'variance_network_lod0')
|
511 |
+
save_state_dict(self.variance_network_lod1, checkpoint, 'variance_network_lod1')
|
512 |
+
|
513 |
+
save_state_dict(self.pyramid_feature_network, checkpoint, 'pyramid_feature_network')
|
514 |
+
save_state_dict(self.pyramid_feature_network_lod1, checkpoint, 'pyramid_feature_network_lod1')
|
515 |
+
|
516 |
+
os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True)
|
517 |
+
torch.save(checkpoint,
|
518 |
+
os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step)))
|
519 |
+
|
520 |
+
def validate(self, resolution_level=-1):
|
521 |
+
# validate image
|
522 |
+
print("iter_step: ", self.iter_step)
|
523 |
+
self.logger.info('Validate begin')
|
524 |
+
self.val_step += 1
|
525 |
+
|
526 |
+
try:
|
527 |
+
batch = next(self.val_dataloader_iterator)
|
528 |
+
except:
|
529 |
+
self.val_dataloader_iterator = iter(self.val_dataloader) # reset
|
530 |
+
|
531 |
+
batch = next(self.val_dataloader_iterator)
|
532 |
+
|
533 |
+
|
534 |
+
background_rgb = None
|
535 |
+
if self.use_white_bkgd:
|
536 |
+
# background_rgb = torch.ones([1, 3]).to(self.device)
|
537 |
+
background_rgb = 1.0
|
538 |
+
|
539 |
+
batch['batch_idx'] = torch.tensor([x for x in range(self.batch_size)])
|
540 |
+
|
541 |
+
# - warmup params
|
542 |
+
if self.num_lods == 1:
|
543 |
+
alpha_inter_ratio_lod0 = self.get_alpha_inter_ratio(self.anneal_start_lod0, self.anneal_end_lod0)
|
544 |
+
else:
|
545 |
+
alpha_inter_ratio_lod0 = 1.
|
546 |
+
alpha_inter_ratio_lod1 = self.get_alpha_inter_ratio(self.anneal_start_lod1, self.anneal_end_lod1)
|
547 |
+
|
548 |
+
self.trainer(
|
549 |
+
batch,
|
550 |
+
background_rgb=background_rgb,
|
551 |
+
alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
|
552 |
+
alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
|
553 |
+
iter_step=self.iter_step,
|
554 |
+
save_vis=True,
|
555 |
+
mode='val',
|
556 |
+
)
|
557 |
+
|
558 |
+
|
559 |
+
def export_mesh(self, resolution_level=-1):
|
560 |
+
print("iter_step: ", self.iter_step)
|
561 |
+
self.logger.info('Validate begin')
|
562 |
+
self.val_step += 1
|
563 |
+
|
564 |
+
try:
|
565 |
+
batch = next(self.val_dataloader_iterator)
|
566 |
+
except:
|
567 |
+
self.val_dataloader_iterator = iter(self.val_dataloader) # reset
|
568 |
+
|
569 |
+
batch = next(self.val_dataloader_iterator)
|
570 |
+
|
571 |
+
|
572 |
+
background_rgb = None
|
573 |
+
if self.use_white_bkgd:
|
574 |
+
background_rgb = 1.0
|
575 |
+
|
576 |
+
batch['batch_idx'] = torch.tensor([x for x in range(self.batch_size)])
|
577 |
+
|
578 |
+
# - warmup params
|
579 |
+
if self.num_lods == 1:
|
580 |
+
alpha_inter_ratio_lod0 = self.get_alpha_inter_ratio(self.anneal_start_lod0, self.anneal_end_lod0)
|
581 |
+
else:
|
582 |
+
alpha_inter_ratio_lod0 = 1.
|
583 |
+
alpha_inter_ratio_lod1 = self.get_alpha_inter_ratio(self.anneal_start_lod1, self.anneal_end_lod1)
|
584 |
+
self.trainer(
|
585 |
+
batch,
|
586 |
+
background_rgb=background_rgb,
|
587 |
+
alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
|
588 |
+
alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
|
589 |
+
iter_step=self.iter_step,
|
590 |
+
save_vis=True,
|
591 |
+
mode='export_mesh',
|
592 |
+
)
|
593 |
+
|
594 |
+
|
595 |
+
if __name__ == '__main__':
|
596 |
+
# torch.set_default_tensor_type('torch.cuda.FloatTensor')
|
597 |
+
torch.set_default_dtype(torch.float32)
|
598 |
+
FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
|
599 |
+
logging.basicConfig(level=logging.INFO, format=FORMAT)
|
600 |
+
|
601 |
+
parser = argparse.ArgumentParser()
|
602 |
+
parser.add_argument('--conf', type=str, default='./confs/base.conf')
|
603 |
+
parser.add_argument('--mode', type=str, default='train')
|
604 |
+
parser.add_argument('--threshold', type=float, default=0.0)
|
605 |
+
parser.add_argument('--is_continue', default=False, action="store_true")
|
606 |
+
parser.add_argument('--is_restore', default=False, action="store_true")
|
607 |
+
parser.add_argument('--is_finetune', default=False, action="store_true")
|
608 |
+
parser.add_argument('--train_from_scratch', default=False, action="store_true")
|
609 |
+
parser.add_argument('--restore_lod0', default=False, action="store_true")
|
610 |
+
parser.add_argument('--local_rank', type=int, default=0)
|
611 |
+
parser.add_argument('--specific_dataset_name', type=str, default='GSO')
|
612 |
+
|
613 |
+
|
614 |
+
args = parser.parse_args()
|
615 |
+
|
616 |
+
torch.cuda.set_device(args.local_rank)
|
617 |
+
torch.backends.cudnn.benchmark = True # ! make training 2x faster
|
618 |
+
|
619 |
+
runner = Runner(args.conf, args.mode, args.is_continue, args.is_restore, args.restore_lod0,
|
620 |
+
args.local_rank)
|
621 |
+
|
622 |
+
if args.mode == 'train':
|
623 |
+
runner.train()
|
624 |
+
elif args.mode == 'val':
|
625 |
+
for i in range(len(runner.val_dataset)):
|
626 |
+
runner.validate()
|
627 |
+
elif args.mode == 'export_mesh':
|
628 |
+
for i in range(len(runner.val_dataset)):
|
629 |
+
runner.export_mesh()
|
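For reference, adjust_learning_rate above decays the learning rate from the full base value down to 10% of it over end_iter steps. A minimal standalone sketch of the schedule (the base_lr and end_iter values are illustrative, not taken from the shipped config):

    import numpy as np

    base_lr, end_iter = 2e-4, 200000  # illustrative values only
    for step in (0, 50000, 100000, 150000, 200000):
        factor = (np.cos(np.pi * step / end_iter) + 1.0) * 0.5 * 0.9 + 0.1
        print(step, base_lr * factor)
    # factor goes 1.0 -> 0.87 -> 0.55 -> 0.23 -> 0.1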
SparseNeuS_demo_v1/loss/__init__.py
ADDED
File without changes
SparseNeuS_demo_v1/loss/color_loss.py
ADDED
@@ -0,0 +1,152 @@
import torch
import torch.nn as nn
from loss.ncc import NCC


class Normalize(nn.Module):
    def __init__(self):
        super(Normalize, self).__init__()

    def forward(self, bottom):
        qn = torch.norm(bottom, p=2, dim=1).unsqueeze(dim=1) + 1e-12
        top = bottom.div(qn)

        return top


class OcclusionColorLoss(nn.Module):
    def __init__(self, alpha=1, beta=0.025, gama=0.01, occlusion_aware=True, weight_thred=[0.6]):
        super(OcclusionColorLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gama = gama
        self.occlusion_aware = occlusion_aware
        self.eps = 1e-4

        self.weight_thred = weight_thred
        self.adjuster = ParamAdjuster(self.weight_thred, self.beta)

    def forward(self, pred, gt, weight, mask, detach=False, occlusion_aware=True):
        """
        :param pred: [N_pts, 3]
        :param gt: [N_pts, 3]
        :param weight: [N_pts]
        :param mask: [N_pts]
        :return:
        """
        if detach:
            weight = weight.detach()

        error = torch.abs(pred - gt).sum(dim=-1, keepdim=False)  # [N_pts]
        error = error[mask]

        if not (self.occlusion_aware and occlusion_aware):
            return torch.mean(error), torch.mean(error)

        beta = self.adjuster(weight.mean())

        weight = weight.clamp(0.0, 1.0)
        term1 = self.alpha * torch.mean(weight[mask] * error)
        term2 = beta * torch.log(1 - weight + self.eps).mean()
        term3 = self.gama * torch.log(weight + self.eps).mean()

        return term1 + term2 + term3, term1


class OcclusionColorPatchLoss(nn.Module):
    def __init__(self, alpha=1, beta=0.025, gama=0.015,
                 occlusion_aware=True, type='l1', h_patch_size=3, weight_thred=[0.6]):
        super(OcclusionColorPatchLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gama = gama
        self.occlusion_aware = occlusion_aware
        self.type = type  # 'l1' or 'ncc' loss
        self.ncc = NCC(h_patch_size=h_patch_size)
        self.eps = 1e-4
        self.weight_thred = weight_thred

        self.adjuster = ParamAdjuster(self.weight_thred, self.beta)

        print("type {} patch_size {} beta {} gama {} weight_thred {}".format(type, h_patch_size, beta, gama,
                                                                             weight_thred))

    def forward(self, pred, gt, weight, mask, penalize_ratio=0.9, detach=False, occlusion_aware=True):
        """
        :param pred: [N_pts, Npx, 3]
        :param gt: [N_pts, Npx, 3]
        :param weight: [N_pts]
        :param mask: [N_pts]
        :return:
        """
        if detach:
            weight = weight.detach()

        if self.type == 'l1':
            error = torch.abs(pred - gt).mean(dim=-1, keepdim=False).sum(dim=-1, keepdim=False)  # [N_pts]
        elif self.type == 'ncc':
            error = 1 - self.ncc(pred[:, None, :, :], gt)[:, 0]  # ncc: 1 positive, -1 negative
            error, indices = torch.sort(error)
            mask = torch.index_select(mask, 0, index=indices)
            mask[int(penalize_ratio * mask.shape[0]):] = False  # can help boundaries
        elif self.type == 'ssd':
            error = ((pred - gt) ** 2).mean(dim=-1, keepdim=False).sum(dim=-1, keepdim=False)

        error = error[mask]
        if not (self.occlusion_aware and occlusion_aware):
            return torch.mean(error), torch.mean(error), 0.

        # * weight adjuster
        beta = self.adjuster(weight.mean())

        weight = weight.clamp(0.0, 1.0)

        term1 = self.alpha * torch.mean(weight[mask] * error)
        term2 = beta * torch.log(1 - weight + self.eps).mean()
        term3 = self.gama * torch.log(weight + self.eps).mean()

        return term1 + term2 + term3, term1, beta


class ParamAdjuster(nn.Module):
    def __init__(self, weight_thred, param):
        super(ParamAdjuster, self).__init__()
        self.weight_thred = weight_thred
        self.thred_num = len(weight_thred)
        self.param = param
        self.global_step = 0
        self.statis_window = 100
        self.counter = 0
        self.adjusted = False
        self.adjusted_step = 0
        self.thred_idx = 0

    def reset(self):
        self.counter = 0
        self.adjusted = False

    def adjust(self):
        if (self.counter / self.statis_window) > 0.3:
            self.param = self.param + 0.005
            self.adjusted = True
            self.adjusted_step = self.global_step
            self.thred_idx += 1
            print("adjusted param, now {}".format(self.param))

    def forward(self, weight_mean):
        self.global_step += 1

        if (self.global_step % self.statis_window == 0) and self.adjusted is False:
            self.adjust()
            self.reset()

        if self.thred_idx < self.thred_num:
            if weight_mean < self.weight_thred[self.thred_idx] and (not self.adjusted):
                self.counter += 1

        return self.param
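A minimal usage sketch for the per-pixel loss above; shapes follow the docstrings, and the tensors are random placeholders:

    import torch
    from loss.color_loss import OcclusionColorLoss

    loss_fn = OcclusionColorLoss()
    pred = torch.rand(1024, 3)          # predicted colors [N_pts, 3]
    gt = torch.rand(1024, 3)            # ground-truth colors [N_pts, 3]
    weight = torch.rand(1024)           # occlusion weights [N_pts]
    mask = torch.rand(1024) > 0.1       # valid-pixel mask [N_pts]
    total_loss, color_term = loss_fn(pred, gt, weight, mask)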
SparseNeuS_demo_v1/loss/depth_loss.py
ADDED
@@ -0,0 +1,71 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DepthLoss(nn.Module):
    def __init__(self, type='l1'):
        super(DepthLoss, self).__init__()
        self.type = type

    def forward(self, depth_pred, depth_gt, mask=None):
        if (depth_gt < 0).sum() > 0:
            # invalid ground-truth depth: skip the depth loss
            return torch.tensor(0.0).to(depth_pred.device)
        if mask is not None:
            mask_d = (depth_gt > 0).float()

            mask = mask * mask_d

            mask_sum = mask.sum() + 1e-5
            depth_error = (depth_pred - depth_gt) * mask
            depth_loss = F.l1_loss(depth_error, torch.zeros_like(depth_error).to(depth_error.device),
                                   reduction='sum') / mask_sum
        else:
            depth_error = depth_pred - depth_gt
            depth_loss = F.l1_loss(depth_error, torch.zeros_like(depth_error).to(depth_error.device),
                                   reduction='mean')
        return depth_loss


class DepthSmoothLoss(nn.Module):
    def __init__(self):
        super(DepthSmoothLoss, self).__init__()

    def forward(self, disp, img, mask):
        """
        Computes the smoothness loss for a disparity image.
        The color image is used for edge-aware smoothness.
        :param disp: [B, 1, H, W]
        :param img: [B, C, H, W]
        :param mask: [B, 1, H, W]
        :return:
        """
        grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
        grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])

        grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
        grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)

        grad_disp_x *= torch.exp(-grad_img_x)
        grad_disp_y *= torch.exp(-grad_img_y)

        grad_disp = (grad_disp_x * mask[:, :, :, :-1]).mean() + (grad_disp_y * mask[:, :, :-1, :]).mean()

        return grad_disp
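A quick sketch of the masked depth loss above (random placeholder tensors; the mask zeroes out pixels without valid ground truth):

    import torch
    from loss.depth_loss import DepthLoss

    depth_loss_fn = DepthLoss()
    depth_pred = torch.rand(2, 1, 64, 64)
    depth_gt = torch.rand(2, 1, 64, 64)
    mask = (torch.rand(2, 1, 64, 64) > 0.2).float()
    print(depth_loss_fn(depth_pred, depth_gt, mask))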
SparseNeuS_demo_v1/loss/depth_metric.py
ADDED
@@ -0,0 +1,240 @@
import numpy as np


def l1(depth1, depth2):
    """
    Computes the L1 error between two depth maps.
    Assumes preprocessed depths (no NaNs, infs, or non-positive values).

    depth1: one depth map
    depth2: another depth map

    Returns:
        L1 error
    """
    assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0)))
    diff = depth1 - depth2
    num_pixels = float(diff.size)

    if num_pixels == 0:
        return np.nan
    else:
        return np.sum(np.absolute(diff)) / num_pixels


def l1_inverse(depth1, depth2):
    """
    Computes the L1 error between the inverses of two depth maps.
    Assumes preprocessed depths (no NaNs, infs, or non-positive values).

    depth1: one depth map
    depth2: another depth map

    Returns:
        L1 error of the inverse depths
    """
    assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0)))
    diff = np.reciprocal(depth1) - np.reciprocal(depth2)
    num_pixels = float(diff.size)

    if num_pixels == 0:
        return np.nan
    else:
        return np.sum(np.absolute(diff)) / num_pixels


def rmse_log(depth1, depth2):
    """
    Computes the root mean square error between the logs of two depth maps.
    Assumes preprocessed depths (no NaNs, infs, or non-positive values).

    depth1: one depth map
    depth2: another depth map

    Returns:
        RMSE(log)
    """
    assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0)))
    log_diff = np.log(depth1) - np.log(depth2)
    num_pixels = float(log_diff.size)

    if num_pixels == 0:
        return np.nan
    else:
        return np.sqrt(np.sum(np.square(log_diff)) / num_pixels)


def rmse(depth1, depth2):
    """
    Computes the root mean square error between two depth maps.
    Assumes preprocessed depths (no NaNs, infs, or non-positive values).

    depth1: one depth map
    depth2: another depth map

    Returns:
        RMSE
    """
    assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0)))
    diff = depth1 - depth2
    num_pixels = float(diff.size)

    if num_pixels == 0:
        return np.nan
    else:
        return np.sqrt(np.sum(np.square(diff)) / num_pixels)


def scale_invariant(depth1, depth2):
    """
    Computes the scale-invariant loss based on differences of logs of depth maps.
    Assumes preprocessed depths (no NaNs, infs, or non-positive values).

    depth1: one depth map
    depth2: another depth map

    Returns:
        scale_invariant_distance
    """
    # sqrt(Eq. 3)
    assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0)))
    log_diff = np.log(depth1) - np.log(depth2)
    num_pixels = float(log_diff.size)

    if num_pixels == 0:
        return np.nan
    else:
        return np.sqrt(np.sum(np.square(log_diff)) / num_pixels - np.square(np.sum(log_diff)) / np.square(num_pixels))


def abs_relative(depth_pred, depth_gt):
    """
    Computes the relative absolute distance.
    Assumes preprocessed depths (no NaNs, infs, or non-positive values).

    depth_pred: depth map prediction
    depth_gt: depth map ground truth

    Returns:
        abs_relative_distance
    """
    assert (np.all(np.isfinite(depth_pred) & np.isfinite(depth_gt) & (depth_pred >= 0) & (depth_gt >= 0)))
    diff = depth_pred - depth_gt
    num_pixels = float(diff.size)

    if num_pixels == 0:
        return np.nan
    else:
        return np.sum(np.absolute(diff) / depth_gt) / num_pixels


def avg_log10(depth1, depth2):
    """
    Computes the average log10 error (Liu, Neural Fields, 2015).
    Assumes preprocessed depths (no NaNs, infs, or non-positive values).

    depth1: one depth map
    depth2: another depth map

    Returns:
        average log10 error
    """
    assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0)))
    log_diff = np.log10(depth1) - np.log10(depth2)
    num_pixels = float(log_diff.size)

    if num_pixels == 0:
        return np.nan
    else:
        return np.sum(np.absolute(log_diff)) / num_pixels


def sq_relative(depth_pred, depth_gt):
    """
    Computes the relative squared distance.
    Assumes preprocessed depths (no NaNs, infs, or non-positive values).

    depth_pred: depth map prediction
    depth_gt: depth map ground truth

    Returns:
        squared_relative_distance
    """
    assert (np.all(np.isfinite(depth_pred) & np.isfinite(depth_gt) & (depth_pred >= 0) & (depth_gt >= 0)))
    diff = depth_pred - depth_gt
    num_pixels = float(diff.size)

    if num_pixels == 0:
        return np.nan
    else:
        return np.sum(np.square(diff) / depth_gt) / num_pixels


def ratio_threshold(depth1, depth2, threshold):
    """
    Computes the percentage of pixels for which the ratio of the two depth maps is less than a given threshold.
    Assumes preprocessed depths (no NaNs, infs, or non-positive values).

    depth1: one depth map
    depth2: another depth map

    Returns:
        percentage of pixels with a ratio less than the threshold
    """
    assert (threshold > 0.)
    assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0)))
    log_diff = np.log(depth1) - np.log(depth2)
    num_pixels = float(log_diff.size)

    if num_pixels == 0:
        return np.nan
    else:
        return float(np.sum(np.absolute(log_diff) < np.log(threshold))) / num_pixels


def compute_depth_errors(depth_pred, depth_gt, valid_mask):
    """
    Computes different distance measures between two depth maps.

    depth_pred: depth map prediction
    depth_gt: depth map ground truth
    valid_mask: mask of pixels to evaluate

    Returns:
        a dictionary with the computed distances, and the number of valid pixels
    """
    depth_pred = depth_pred[valid_mask]
    depth_gt = depth_gt[valid_mask]
    num_valid = np.sum(valid_mask)

    # thresholds are 1.25, 1.25**2 and 1.25**3
    distances_to_compute = ['l1',
                            'l1_inverse',
                            'scale_invariant',
                            'abs_relative',
                            'sq_relative',
                            'avg_log10',
                            'rmse_log',
                            'rmse',
                            'ratio_threshold_1.25',
                            'ratio_threshold_1.5625',
                            'ratio_threshold_1.953125']

    results = {'num_valid': num_valid}
    for dist in distances_to_compute:
        if dist.startswith('ratio_threshold'):
            threshold = float(dist.split('_')[-1])
            results[dist] = ratio_threshold(depth_pred, depth_gt, threshold)
        else:
            results[dist] = globals()[dist](depth_pred, depth_gt)

    return results
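A sanity-check sketch for the metric suite above, on synthetic depths with roughly 10% multiplicative noise:

    import numpy as np
    from loss.depth_metric import compute_depth_errors

    depth_gt = np.random.uniform(0.5, 2.0, size=(64, 64))
    depth_pred = depth_gt * np.random.uniform(0.9, 1.1, size=depth_gt.shape)
    metrics = compute_depth_errors(depth_pred, depth_gt, valid_mask=depth_gt > 0)
    print(metrics['abs_relative'], metrics['ratio_threshold_1.25'])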
SparseNeuS_demo_v1/loss/ncc.py
ADDED
@@ -0,0 +1,65 @@
import torch
import torch.nn.functional as F
from math import exp, sqrt


class NCC(torch.nn.Module):
    def __init__(self, h_patch_size, mode='rgb'):
        super(NCC, self).__init__()
        self.window_size = 2 * h_patch_size + 1
        self.mode = mode  # 'rgb' or 'gray'
        self.channel = 3
        self.register_buffer("window", create_window(self.window_size, self.channel))

    def forward(self, img_pred, img_gt):
        """
        :param img_pred: [Npx, nviews, npatch, c]
        :param img_gt: [Npx, npatch, c]
        :return:
        """
        ntotpx, nviews, npatch, channels = img_pred.shape

        patch_size = int(sqrt(npatch))
        patch_img_pred = img_pred.reshape(ntotpx, nviews, patch_size, patch_size, channels).permute(
            0, 1, 4, 2, 3).contiguous()
        patch_img_gt = img_gt.reshape(ntotpx, patch_size, patch_size, channels).permute(0, 3, 1, 2)

        return _ncc(patch_img_pred, patch_img_gt, self.window, self.channel)


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel, std=1.5):
    _1D_window = gaussian(window_size, std).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def _ncc(pred, gt, window, channel):
    ntotpx, nviews, nc, h, w = pred.shape
    flat_pred = pred.view(-1, nc, h, w)
    mu1 = F.conv2d(flat_pred, window, padding=0, groups=channel).view(ntotpx, nviews, nc)
    mu2 = F.conv2d(gt, window, padding=0, groups=channel).view(ntotpx, nc)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2).unsqueeze(1)  # (ntotpx, 1, nc)

    sigma1_sq = F.conv2d(flat_pred * flat_pred, window, padding=0, groups=channel).view(ntotpx, nviews, nc) - mu1_sq
    sigma2_sq = F.conv2d(gt * gt, window, padding=0, groups=channel).view(ntotpx, 1, nc) - mu2_sq

    sigma1 = torch.sqrt(sigma1_sq + 1e-4)
    sigma2 = torch.sqrt(sigma2_sq + 1e-4)

    pred_norm = (pred - mu1[:, :, :, None, None]) / (sigma1[:, :, :, None, None] + 1e-8)  # [ntotpx, nviews, nc, h, w]
    gt_norm = (gt[:, None, :, :, :] - mu2[:, None, :, None, None]) / (
            sigma2[:, :, :, None, None] + 1e-8)  # [ntotpx, 1, nc, h, w]

    ncc = F.conv2d((pred_norm * gt_norm).view(-1, nc, h, w), window, padding=0, groups=channel).view(
        ntotpx, nviews, nc)

    return torch.mean(ncc, dim=2)
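A shape-checking sketch for the NCC module: with h_patch_size=3 the window is 7x7, so npatch must be 49 and each patch reduces to a single score per view:

    import torch
    from loss.ncc import NCC

    ncc = NCC(h_patch_size=3)                # 7x7 Gaussian window
    img_pred = torch.rand(128, 4, 49, 3)     # [Npx, nviews, npatch, c]
    img_gt = torch.rand(128, 49, 3)          # [Npx, npatch, c]
    scores = ncc(img_pred, img_gt)           # [Npx, nviews], roughly in [-1, 1]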
SparseNeuS_demo_v1/models/__init__.py
ADDED
File without changes
SparseNeuS_demo_v1/models/embedder.py
ADDED
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn

""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """


class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                if self.kwargs['normalize']:
                    embed_fns.append(lambda x, p_fn=p_fn,
                                            freq=freq: p_fn(x * freq) / freq)
                else:
                    embed_fns.append(lambda x, p_fn=p_fn,
                                            freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


def get_embedder(multires, normalize=False, input_dims=3):
    embed_kwargs = {
        'include_input': True,
        'input_dims': input_dims,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'normalize': normalize,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)

    def embed(x, eo=embedder_obj): return eo.embed(x)

    return embed, embedder_obj.out_dim


class Embedding(nn.Module):
    def __init__(self, in_channels, N_freqs, logscale=True, normalize=False):
        """
        Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
        in_channels: number of input channels (3 for both xyz and direction)
        """
        super(Embedding, self).__init__()
        self.N_freqs = N_freqs
        self.in_channels = in_channels
        self.funcs = [torch.sin, torch.cos]
        self.out_channels = in_channels * (len(self.funcs) * N_freqs + 1)
        self.normalize = normalize

        if logscale:
            self.freq_bands = 2 ** torch.linspace(0, N_freqs - 1, N_freqs)
        else:
            self.freq_bands = torch.linspace(1, 2 ** (N_freqs - 1), N_freqs)

    def forward(self, x):
        """
        Embeds x to (x, sin(2^k x), cos(2^k x), ...)
        Different from the paper, "x" is also in the output
        See https://github.com/bmild/nerf/issues/12

        Inputs:
            x: (B, self.in_channels)

        Outputs:
            out: (B, self.out_channels)
        """
        out = [x]
        for freq in self.freq_bands:
            for func in self.funcs:
                if self.normalize:
                    out += [func(freq * x) / freq]
                else:
                    out += [func(freq * x)]

        return torch.cat(out, -1)
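With include_input=True, get_embedder returns out_dim = input_dims * (2 * multires + 1); a quick check:

    import torch
    from models.embedder import get_embedder

    embed_fn, out_dim = get_embedder(multires=6, input_dims=3)
    x = torch.rand(5, 3)
    print(out_dim, embed_fn(x).shape)  # 39 torch.Size([5, 39])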
SparseNeuS_demo_v1/models/fast_renderer.py
ADDED
@@ -0,0 +1,316 @@
import torch
import torch.nn.functional as F
import torch.nn as nn


# - neus: use sphere tracing to speed up depth map extraction
# This code snippet is heavily borrowed from IDR.
class FastRenderer(nn.Module):
    def __init__(self):
        super(FastRenderer, self).__init__()

        self.sdf_threshold = 5e-5
        self.line_search_step = 0.5
        self.line_step_iters = 1
        self.sphere_tracing_iters = 10
        self.n_steps = 100
        self.n_secant_steps = 8

        # - use sdf_network to infer sdf values, or directly interpolate them from a precomputed sdf_volume
        self.network_inference = False

    def extract_depth_maps(self, rays_o, rays_d, near, far, sdf_network, conditional_volume):
        with torch.no_grad():
            curr_start_points, network_object_mask, acc_start_dis = self.get_intersection(
                rays_o, rays_d, near, far,
                sdf_network, conditional_volume)

        network_object_mask = network_object_mask.reshape(-1)

        return network_object_mask, acc_start_dis

    def get_intersection(self, rays_o, rays_d, near, far, sdf_network, conditional_volume):
        device = rays_o.device
        num_pixels, _ = rays_d.shape

        curr_start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis = \
            self.sphere_tracing(rays_o, rays_d, near, far, sdf_network, conditional_volume)

        network_object_mask = (acc_start_dis < acc_end_dis)

        # The non-convergent rays should be handled by the sampler
        sampler_mask = unfinished_mask_start
        sampler_net_obj_mask = torch.zeros_like(sampler_mask).bool().to(device)
        if sampler_mask.sum() > 0:
            sampler_pts, sampler_net_obj_mask, sampler_dists = self.ray_sampler(rays_o,
                                                                                rays_d,
                                                                                acc_start_dis,
                                                                                acc_end_dis,
                                                                                sampler_mask,
                                                                                sdf_network,
                                                                                conditional_volume
                                                                                )

            curr_start_points[sampler_mask] = sampler_pts[sampler_mask]
            acc_start_dis[sampler_mask] = sampler_dists[sampler_mask][:, None]
            network_object_mask[sampler_mask] = sampler_net_obj_mask[sampler_mask][:, None]

        return curr_start_points, network_object_mask, acc_start_dis

    def sphere_tracing(self, rays_o, rays_d, near, far, sdf_network, conditional_volume):
        ''' Run the sphere tracing algorithm for max iterations from both sides of the unit sphere intersection '''

        device = rays_o.device

        unfinished_mask_start = (near < far).reshape(-1).clone()
        unfinished_mask_end = (near < far).reshape(-1).clone()

        # Initialize start current points
        curr_start_points = rays_o + rays_d * near
        acc_start_dis = near.clone()

        # Initialize end current points
        curr_end_points = rays_o + rays_d * far
        acc_end_dis = far.clone()

        # Initialize min and max depth
        min_dis = acc_start_dis.clone()
        max_dis = acc_end_dis.clone()

        # Iterate on the rays (from both sides) till finding a surface
        iters = 0

        next_sdf_start = torch.zeros_like(acc_start_dis).to(device)

        if self.network_inference:
            sdf_func = sdf_network.sdf
        else:
            sdf_func = sdf_network.sdf_from_sdfvolume

        next_sdf_start[unfinished_mask_start] = sdf_func(
            curr_start_points[unfinished_mask_start],
            conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0]

        next_sdf_end = torch.zeros_like(acc_end_dis).to(device)
        next_sdf_end[unfinished_mask_end] = sdf_func(curr_end_points[unfinished_mask_end],
                                                     conditional_volume, lod=0, gru_fusion=False)[
            'sdf_pts_scale%d' % 0]

        while True:
            # Update sdf
            curr_sdf_start = torch.zeros_like(acc_start_dis).to(device)
            curr_sdf_start[unfinished_mask_start] = next_sdf_start[unfinished_mask_start]
            curr_sdf_start[curr_sdf_start <= self.sdf_threshold] = 0

            curr_sdf_end = torch.zeros_like(acc_end_dis).to(device)
            curr_sdf_end[unfinished_mask_end] = next_sdf_end[unfinished_mask_end]
            curr_sdf_end[curr_sdf_end <= self.sdf_threshold] = 0

            # Update masks
            unfinished_mask_start = unfinished_mask_start & (curr_sdf_start > self.sdf_threshold).reshape(-1)
            unfinished_mask_end = unfinished_mask_end & (curr_sdf_end > self.sdf_threshold).reshape(-1)

            if (unfinished_mask_start.sum() == 0 and unfinished_mask_end.sum() == 0) \
                    or iters == self.sphere_tracing_iters:
                break
            iters += 1

            # Make step
            # Update distances
            acc_start_dis = acc_start_dis + curr_sdf_start
            acc_end_dis = acc_end_dis - curr_sdf_end

            # Update points
            curr_start_points = rays_o + acc_start_dis * rays_d
            curr_end_points = rays_o + acc_end_dis * rays_d

            # Fix points which wrongly crossed the surface
            next_sdf_start = torch.zeros_like(acc_start_dis).to(device)
            if unfinished_mask_start.sum() > 0:
                next_sdf_start[unfinished_mask_start] = sdf_func(curr_start_points[unfinished_mask_start],
                                                                 conditional_volume, lod=0, gru_fusion=False)[
                    'sdf_pts_scale%d' % 0]

            next_sdf_end = torch.zeros_like(acc_end_dis).to(device)
            if unfinished_mask_end.sum() > 0:
                next_sdf_end[unfinished_mask_end] = sdf_func(curr_end_points[unfinished_mask_end],
                                                             conditional_volume, lod=0, gru_fusion=False)[
                    'sdf_pts_scale%d' % 0]

            not_projected_start = (next_sdf_start < 0).reshape(-1)
            not_projected_end = (next_sdf_end < 0).reshape(-1)
            not_proj_iters = 0

            while (not_projected_start.sum() > 0 or not_projected_end.sum() > 0) \
                    and not_proj_iters < self.line_step_iters:
                # Step backwards
                if not_projected_start.sum() > 0:
                    acc_start_dis[not_projected_start] -= ((1 - self.line_search_step) / (2 ** not_proj_iters)) * \
                                                          curr_sdf_start[not_projected_start]
                    curr_start_points[not_projected_start] = (rays_o + acc_start_dis * rays_d)[not_projected_start]

                    next_sdf_start[not_projected_start] = sdf_func(
                        curr_start_points[not_projected_start],
                        conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0]

                if not_projected_end.sum() > 0:
                    acc_end_dis[not_projected_end] += ((1 - self.line_search_step) / (2 ** not_proj_iters)) * \
                                                      curr_sdf_end[not_projected_end]
                    curr_end_points[not_projected_end] = (rays_o + acc_end_dis * rays_d)[not_projected_end]

                    # Calc sdf
                    next_sdf_end[not_projected_end] = sdf_func(
                        curr_end_points[not_projected_end],
                        conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0]

                # Update mask
                not_projected_start = (next_sdf_start < 0).reshape(-1)
                not_projected_end = (next_sdf_end < 0).reshape(-1)
                not_proj_iters += 1

            unfinished_mask_start = unfinished_mask_start & (acc_start_dis < acc_end_dis).reshape(-1)
            unfinished_mask_end = unfinished_mask_end & (acc_start_dis < acc_end_dis).reshape(-1)

        return curr_start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis

    def ray_sampler(self, rays_o, rays_d, near, far, sampler_mask, sdf_network, conditional_volume):
        ''' Sample the ray in a given range and run the secant method on rays which have a sign transition '''
        device = rays_o.device
        num_pixels, _ = rays_d.shape
        sampler_pts = torch.zeros(num_pixels, 3).to(device).float()
        sampler_dists = torch.zeros(num_pixels).to(device).float()

        intervals_dist = torch.linspace(0, 1, steps=self.n_steps).to(device).view(1, -1)

        pts_intervals = near + intervals_dist * (far - near)
        points = rays_o[:, None, :] + pts_intervals[:, :, None] * rays_d[:, None, :]

        # Get the non-convergent rays
        mask_intersect_idx = torch.nonzero(sampler_mask).flatten()
        points = points.reshape((-1, self.n_steps, 3))[sampler_mask, :, :]
        pts_intervals = pts_intervals.reshape((-1, self.n_steps))[sampler_mask]

        if self.network_inference:
            sdf_func = sdf_network.sdf
        else:
            sdf_func = sdf_network.sdf_from_sdfvolume

        sdf_val_all = []
        for pnts in torch.split(points.reshape(-1, 3), 100000, dim=0):
            sdf_val_all.append(sdf_func(pnts,
                                        conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0])
        sdf_val = torch.cat(sdf_val_all).reshape(-1, self.n_steps)

        tmp = torch.sign(sdf_val) * torch.arange(self.n_steps, 0, -1).to(device).float().reshape(
            (1, self.n_steps))  # Force argmin to return the first min value
        sampler_pts_ind = torch.argmin(tmp, -1)
        sampler_pts[mask_intersect_idx] = points[torch.arange(points.shape[0]), sampler_pts_ind, :]
        sampler_dists[mask_intersect_idx] = pts_intervals[torch.arange(pts_intervals.shape[0]), sampler_pts_ind]

        net_surface_pts = (sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind] < 0)

        # Take points with the minimal SDF value for P_out pixels
        p_out_mask = ~net_surface_pts
        n_p_out = p_out_mask.sum()
        if n_p_out > 0:
            out_pts_idx = torch.argmin(sdf_val[p_out_mask, :], -1)
            sampler_pts[mask_intersect_idx[p_out_mask]] = points[p_out_mask, :, :][torch.arange(n_p_out),
                                                                                   out_pts_idx, :]
            sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[p_out_mask, :][
                torch.arange(n_p_out), out_pts_idx]

        # Get the network object mask
        sampler_net_obj_mask = sampler_mask.clone()
        sampler_net_obj_mask[mask_intersect_idx[~net_surface_pts]] = False

        # Run the secant method
        secant_pts = net_surface_pts
        n_secant_pts = secant_pts.sum()
        if n_secant_pts > 0:
            # Get secant z predictions
            z_high = pts_intervals[torch.arange(pts_intervals.shape[0]), sampler_pts_ind][secant_pts]
            sdf_high = sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind][secant_pts]
            z_low = pts_intervals[secant_pts][torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1]
            sdf_low = sdf_val[secant_pts][torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1]

            cam_loc_secant = rays_o[mask_intersect_idx[secant_pts]]
            ray_directions_secant = rays_d[mask_intersect_idx[secant_pts]]
            z_pred_secant = self.secant(sdf_low, sdf_high, z_low, z_high, cam_loc_secant, ray_directions_secant,
                                        sdf_network, conditional_volume)

            # Get points
            sampler_pts[mask_intersect_idx[secant_pts]] = cam_loc_secant + z_pred_secant[:,
                                                                           None] * ray_directions_secant
            sampler_dists[mask_intersect_idx[secant_pts]] = z_pred_secant

        return sampler_pts, sampler_net_obj_mask, sampler_dists

    def secant(self, sdf_low, sdf_high, z_low, z_high, rays_o, rays_d, sdf_network, conditional_volume):
        ''' Run the secant method on the interval [z_low, z_high] for n_secant_steps '''

        if self.network_inference:
            sdf_func = sdf_network.sdf
        else:
            sdf_func = sdf_network.sdf_from_sdfvolume

        z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low
        for i in range(self.n_secant_steps):
            p_mid = rays_o + z_pred[:, None] * rays_d
            sdf_mid = sdf_func(p_mid,
                               conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0].reshape(-1)
            ind_low = (sdf_mid > 0).reshape(-1)
            if ind_low.sum() > 0:
                z_low[ind_low] = z_pred[ind_low]
                sdf_low[ind_low] = sdf_mid[ind_low]
            ind_high = sdf_mid < 0
            if ind_high.sum() > 0:
                z_high[ind_high] = z_pred[ind_high]
                sdf_high[ind_high] = sdf_mid[ind_high]

            z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low

        return z_pred  # 1D tensor

    def minimal_sdf_points(self, num_pixels, sdf, cam_loc, ray_directions, mask, min_dis, max_dis):
        ''' Find points with the minimal SDF value on rays for P_out pixels '''
        device = sdf.device
        n_mask_points = mask.sum()

        n = self.n_steps
        # steps = torch.linspace(0.0, 1.0, n).to(device)
        steps = torch.empty(n).uniform_(0.0, 1.0).to(device)
        mask_max_dis = max_dis[mask].unsqueeze(-1)
        mask_min_dis = min_dis[mask].unsqueeze(-1)
        steps = steps.unsqueeze(0).repeat(n_mask_points, 1) * (mask_max_dis - mask_min_dis) + mask_min_dis

        mask_points = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)[mask]
        mask_rays = ray_directions[mask, :]

        mask_points_all = mask_points.unsqueeze(1).repeat(1, n, 1) + steps.unsqueeze(-1) * mask_rays.unsqueeze(
            1).repeat(1, n, 1)
        points = mask_points_all.reshape(-1, 3)

        mask_sdf_all = []
        for pnts in torch.split(points, 100000, dim=0):
            mask_sdf_all.append(sdf(pnts))

        mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n)
        min_vals, min_idx = mask_sdf_all.min(-1)
        min_mask_points = mask_points_all.reshape(-1, n, 3)[torch.arange(0, n_mask_points), min_idx]
        min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points), min_idx]

        return min_mask_points, min_mask_dist
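The secant refinement above repeatedly re-brackets the SDF zero crossing along the ray with the update z = z_low - sdf_low * (z_high - z_low) / (sdf_high - sdf_low). A standalone numeric sketch of the same update on a toy 1-D signed distance (positive outside, negative inside, root at 1.3):

    def secant_root(f, z_low, z_high, n_steps=8):
        f_low, f_high = f(z_low), f(z_high)
        for _ in range(n_steps):
            z_mid = -f_low * (z_high - z_low) / (f_high - f_low) + z_low
            f_mid = f(z_mid)
            if f_mid > 0:
                z_low, f_low = z_mid, f_mid    # still outside: move the lower bracket up
            else:
                z_high, f_high = z_mid, f_mid  # inside: move the upper bracket down
        return -f_low * (z_high - z_low) / (f_high - f_low) + z_low

    print(secant_root(lambda z: 1.3 - z, 1.0, 2.0))  # ~1.3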
SparseNeuS_demo_v1/models/featurenet.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch

# ! amazing!!!! autograd.grad with set_detect_anomaly(True) will cause memory leak
# ! https://github.com/pytorch/pytorch/issues/51349
# torch.autograd.set_detect_anomaly(True)
import torch.nn as nn
import torch.nn.functional as F
from inplace_abn import InPlaceABN


############################################# MVS Net models ################################################
class ConvBnReLU(nn.Module):
    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, pad=1,
                 norm_act=InPlaceABN):
        super(ConvBnReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = norm_act(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))


class ConvBnReLU3D(nn.Module):
    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, pad=1,
                 norm_act=InPlaceABN):
        super(ConvBnReLU3D, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels,
                              kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = norm_act(out_channels)
        # self.bn = nn.ReLU()

    def forward(self, x):
        return self.bn(self.conv(x))


################################### feature net ######################################
class FeatureNet(nn.Module):
    """
    output 3 levels of features using a FPN structure
    """

    def __init__(self, norm_act=InPlaceABN):
        super(FeatureNet, self).__init__()

        self.conv0 = nn.Sequential(
            ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act),
            ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act))

        self.conv1 = nn.Sequential(
            ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act),
            ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act),
            ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act))

        self.conv2 = nn.Sequential(
            ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act),
            ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act),
            ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act))

        self.toplayer = nn.Conv2d(32, 32, 1)
        self.lat1 = nn.Conv2d(16, 32, 1)
        self.lat0 = nn.Conv2d(8, 32, 1)

        # to reduce channel size of the outputs from FPN
        self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
        self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)

    def _upsample_add(self, x, y):
        return F.interpolate(x, scale_factor=2,
                             mode="bilinear", align_corners=True) + y

    def forward(self, x):
        # x: (B, 3, H, W)
        conv0 = self.conv0(x)  # (B, 8, H, W)
        conv1 = self.conv1(conv0)  # (B, 16, H//2, W//2)
        conv2 = self.conv2(conv1)  # (B, 32, H//4, W//4)
        feat2 = self.toplayer(conv2)  # (B, 32, H//4, W//4)
        feat1 = self._upsample_add(feat2, self.lat1(conv1))  # (B, 32, H//2, W//2)
        feat0 = self._upsample_add(feat1, self.lat0(conv0))  # (B, 32, H, W)

        # reduce output channels
        feat1 = self.smooth1(feat1)  # (B, 16, H//2, W//2)
        feat0 = self.smooth0(feat0)  # (B, 8, H, W)

        # feats = {"level_0": feat0,
        #          "level_1": feat1,
        #          "level_2": feat2}

        return [feat2, feat1, feat0]  # coarser to finer features
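A quick smoke test of the FPN above, assuming the module is importable as models.featurenet; nn.BatchNorm2d is substituted for InPlaceABN here only so the sketch runs without the inplace_abn package (this drops the fused activation, so it is a shape check, not a numerical reproduction):

import torch
import torch.nn as nn
from models.featurenet import FeatureNet

# assumption: BatchNorm2d stand-in for InPlaceABN (shape check only)
net = FeatureNet(norm_act=nn.BatchNorm2d)
imgs = torch.randn(2, 3, 64, 64)             # (B, 3, H, W)
feat2, feat1, feat0 = net(imgs)              # coarser to finer
print(feat2.shape, feat1.shape, feat0.shape)
# torch.Size([2, 32, 16, 16]) torch.Size([2, 16, 32, 32]) torch.Size([2, 8, 64, 64])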
SparseNeuS_demo_v1/models/fields.py
ADDED
@@ -0,0 +1,333 @@
# The codes are from NeuS

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.embedder import get_embedder


class SDFNetwork(nn.Module):
    def __init__(self,
                 d_in,
                 d_out,
                 d_hidden,
                 n_layers,
                 skip_in=(4,),
                 multires=0,
                 bias=0.5,
                 scale=1,
                 geometric_init=True,
                 weight_norm=True,
                 activation='softplus',
                 conditional_type='multiply'):
        super(SDFNetwork, self).__init__()

        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None

        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            if geometric_init:
                if l == self.num_layers - 2:
                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                    torch.nn.init.constant_(lin.bias, -bias)
                elif multires > 0 and l == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)  # ? why dims[0] - 3
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        if activation == 'softplus':
            self.activation = nn.Softplus(beta=100)
        else:
            assert activation == 'relu'
            self.activation = nn.ReLU()

    def forward(self, inputs):
        inputs = inputs * self.scale
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        x = inputs
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, inputs], 1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)
        return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1)

    def sdf(self, x):
        return self.forward(x)[:, :1]

    def sdf_hidden_appearance(self, x):
        return self.forward(x)

    def gradient(self, x):
        x.requires_grad_(True)
        y = self.sdf(x)
        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients.unsqueeze(1)


class VarianceNetwork(nn.Module):
    def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0):
        super(VarianceNetwork, self).__init__()

        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None

        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, normalize=False)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch

        self.num_layers = len(dims)
        self.skip_in = skip_in

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)
            setattr(self, "lin" + str(l), lin)

        self.relu = nn.ReLU()
        self.softplus = nn.Softplus(beta=100)

    def forward(self, inputs):
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        x = inputs
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, inputs], 1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.relu(x)

        # return torch.exp(x)
        return 1.0 / (self.softplus(x + 0.5) + 1e-3)

    def coarse(self, inputs):
        return self.forward(inputs)[:, :1]

    def fine(self, inputs):
        return self.forward(inputs)[:, 1:]


class FixVarianceNetwork(nn.Module):
    def __init__(self, base):
        super(FixVarianceNetwork, self).__init__()
        self.base = base
        self.iter_step = 0

    def set_iter_step(self, iter_step):
        self.iter_step = iter_step

    def forward(self, x):
        return torch.ones([len(x), 1]) * np.exp(-self.iter_step / self.base)


class SingleVarianceNetwork(nn.Module):
    def __init__(self, init_val=1.0):
        super(SingleVarianceNetwork, self).__init__()
        self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))

    def forward(self, x):
        return torch.ones([len(x), 1]).to(x.device) * torch.exp(self.variance * 10.0)


class RenderingNetwork(nn.Module):
    def __init__(
            self,
            d_feature,
            mode,
            d_in,
            d_out,
            d_hidden,
            n_layers,
            weight_norm=True,
            multires_view=0,
            squeeze_out=True,
            d_conditional_colors=0
    ):
        super().__init__()

        self.mode = mode
        self.squeeze_out = squeeze_out
        dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embedview_fn = None
        if multires_view > 0:
            embedview_fn, input_ch = get_embedder(multires_view)
            self.embedview_fn = embedview_fn
            dims[0] += (input_ch - 3)

        self.num_layers = len(dims)

        for l in range(0, self.num_layers - 1):
            out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.relu = nn.ReLU()

    def forward(self, points, normals, view_dirs, feature_vectors):
        if self.embedview_fn is not None:
            view_dirs = self.embedview_fn(view_dirs)

        rendering_input = None

        if self.mode == 'idr':
            rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_view_dir':
            rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_normal':
            rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1)
        elif self.mode == 'no_points':
            rendering_input = torch.cat([view_dirs, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_points_no_view_dir':
            rendering_input = torch.cat([normals, feature_vectors], dim=-1)

        x = rendering_input

        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.relu(x)

        if self.squeeze_out:
            x = torch.sigmoid(x)
        return x


# Code from nerf-pytorch
class NeRF(nn.Module):
    def __init__(self, D=8, W=256, d_in=3, d_in_view=3, multires=0, multires_view=0, output_ch=4, skips=[4],
                 use_viewdirs=False):
        super(NeRF, self).__init__()
        self.D = D
        self.W = W
        self.d_in = d_in
        self.d_in_view = d_in_view
        self.input_ch = 3
        self.input_ch_view = 3
        self.embed_fn = None
        self.embed_fn_view = None

        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False)
            self.embed_fn = embed_fn
            self.input_ch = input_ch

        if multires_view > 0:
            embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view, normalize=False)
            self.embed_fn_view = embed_fn_view
            self.input_ch_view = input_ch_view

        self.skips = skips
        self.use_viewdirs = use_viewdirs

        self.pts_linears = nn.ModuleList(
            [nn.Linear(self.input_ch, W)] +
            [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W)
             for i in range(D - 1)])

        ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
        self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)])

        ### Implementation according to the paper
        # self.views_linears = nn.ModuleList(
        #     [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])

        if use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W // 2, 3)
        else:
            self.output_linear = nn.Linear(W, output_ch)

    def forward(self, input_pts, input_views):
        if self.embed_fn is not None:
            input_pts = self.embed_fn(input_pts)
        if self.embed_fn_view is not None:
            input_views = self.embed_fn_view(input_views)

        h = input_pts
        for i, l in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = F.relu(h)
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        if self.use_viewdirs:
            alpha = self.alpha_linear(h)
            feature = self.feature_linear(h)
            h = torch.cat([feature, input_views], -1)

            for i, l in enumerate(self.views_linears):
                h = self.views_linears[i](h)
                h = F.relu(h)

            rgb = self.rgb_linear(h)
            return alpha + 1.0, rgb
        else:
            assert False
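For orientation, a minimal sketch of the SDFNetwork interface above (hypothetical layer sizes, not the values from the config files; multires=0 keeps the positional embedder out of the picture, so the sketch does not depend on models.embedder):

import torch
from models.fields import SDFNetwork

# illustrative sizes only; first output channel is the signed distance
sdf_net = SDFNetwork(d_in=3, d_out=257, d_hidden=256, n_layers=8, multires=0)

pts = torch.rand(1024, 3) * 2 - 1            # query points in [-1, 1]^3
sdf_vals = sdf_net.sdf(pts)                  # (1024, 1) signed distances
normals = sdf_net.gradient(pts).squeeze(1)   # (1024, 3) gradients via autograd
print(sdf_vals.shape, normals.shape)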
SparseNeuS_demo_v1/models/patch_projector.py
ADDED
@@ -0,0 +1,211 @@
"""
Patch Projector
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.render_utils import sample_ptsFeatures_from_featureMaps


class PatchProjector():
    def __init__(self, patch_size):
        self.h_patch_size = patch_size
        self.offsets = build_patch_offset(patch_size)  # the warping patch offsets index

        self.z_axis = torch.tensor([0, 0, 1]).float()

        self.plane_dist_thresh = 0.001

    # * correctness checked
    def pixel_warp(self, pts, imgs, intrinsics,
                   w2cs, img_wh=None):
        """
        :param pts: [N_rays, n_samples, 3]
        :param imgs: [N_views, 3, H, W]
        :param intrinsics: [N_views, 4, 4]
        :param w2cs: [N_views, 4, 4]
        :param img_wh:
        :return:
        """
        if img_wh is None:
            N_views, _, sizeH, sizeW = imgs.shape
            img_wh = [sizeW, sizeH]

        pts_color, valid_mask = sample_ptsFeatures_from_featureMaps(
            pts, imgs, w2cs, intrinsics, img_wh,
            proj_matrix=None, return_mask=True)  # [N_views, c, N_rays, n_samples], [N_views, N_rays, n_samples]

        pts_color = pts_color.permute(2, 3, 0, 1)
        valid_mask = valid_mask.permute(1, 2, 0)

        return pts_color, valid_mask  # [N_rays, n_samples, N_views, 3], [N_rays, n_samples, N_views]

    def patch_warp(self, pts, uv, normals, src_imgs,
                   ref_intrinsic, src_intrinsics,
                   ref_c2w, src_c2ws, img_wh=None
                   ):
        """
        :param pts: [N_rays, n_samples, 3]
        :param uv: [N_rays, 2] normalized in (-1, 1)
        :param normals: [N_rays, n_samples, 3] the normal of each pt in world space
        :param src_imgs: [N_src, 3, h, w]
        :param ref_intrinsic: [4,4]
        :param src_intrinsics: [N_src, 4, 4]
        :param ref_c2w: [4,4]
        :param src_c2ws: [N_src, 4, 4]
        :return:
        """
        device = pts.device

        N_rays, n_samples, _ = pts.shape
        N_pts = N_rays * n_samples

        N_src, _, sizeH, sizeW = src_imgs.shape

        if img_wh is not None:
            sizeW, sizeH = img_wh[0], img_wh[1]

        # scale uv from (-1, 1) to (0, W/H)
        uv[:, 0] = (uv[:, 0] + 1) / 2. * (sizeW - 1)
        uv[:, 1] = (uv[:, 1] + 1) / 2. * (sizeH - 1)

        ref_intr = ref_intrinsic[:3, :3]
        inv_ref_intr = torch.inverse(ref_intr)
        src_intrs = src_intrinsics[:, :3, :3]
        inv_src_intrs = torch.inverse(src_intrs)

        ref_pose = ref_c2w
        inv_ref_pose = torch.inverse(ref_pose)
        src_poses = src_c2ws
        inv_src_poses = torch.inverse(src_poses)

        ref_cam_loc = ref_pose[:3, 3].unsqueeze(0)  # [1, 3]
        sampled_dists = torch.norm(pts - ref_cam_loc, dim=-1)  # [N_rays, n_samples]

        relative_proj = inv_src_poses @ ref_pose
        R_rel = relative_proj[:, :3, :3]
        t_rel = relative_proj[:, :3, 3:]
        R_ref = inv_ref_pose[:3, :3]
        t_ref = inv_ref_pose[:3, 3:]

        pts = pts.view(-1, 3)
        normals = normals.view(-1, 3)

        with torch.no_grad():
            rot_normals = R_ref @ normals.unsqueeze(-1)  # [N_pts, 3, 1]
            points_in_ref = R_ref @ pts.unsqueeze(-1) + t_ref  # [N_pts, 3, 1] points in the reference frame coordinate system
            d1 = torch.sum(rot_normals * points_in_ref, dim=1).unsqueeze(1)  # distance from the plane to ref camera center

            d2 = torch.sum(rot_normals.unsqueeze(1) * (-R_rel.transpose(1, 2) @ t_rel).unsqueeze(0),
                           dim=2)  # distance from the plane to src camera center
            valid_hom = (torch.abs(d1) > self.plane_dist_thresh) & (
                    torch.abs(d1 - d2) > self.plane_dist_thresh) & ((d2 / d1) < 1)

            d1 = d1.squeeze()
            sign = torch.sign(d1)
            sign[sign == 0] = 1
            d = torch.clamp(torch.abs(d1), 1e-8) * sign

            H = src_intrs.unsqueeze(1) @ (
                    R_rel.unsqueeze(1) + t_rel.unsqueeze(1) @ rot_normals.view(1, N_pts, 1, 3) / d.view(1, N_pts, 1, 1)
            ) @ inv_ref_intr.view(1, 1, 3, 3)

            # replace invalid homs with fronto-parallel homographies
            H_invalid = src_intrs.unsqueeze(1) @ (
                    R_rel.unsqueeze(1) + t_rel.unsqueeze(1) @ self.z_axis.to(device).view(1, 1, 1, 3).expand(-1, N_pts, -1, -1)
                    / sampled_dists.view(1, N_pts, 1, 1)
            ) @ inv_ref_intr.view(1, 1, 3, 3)
            tmp_m = ~valid_hom.view(-1, N_src).t()
            H[tmp_m] = H_invalid[tmp_m]

        pixels = uv.view(N_rays, 1, 2) + self.offsets.float().to(device)
        Npx = pixels.shape[1]
        grid, warp_mask_full = self.patch_homography(H, pixels)

        warp_mask_full = warp_mask_full & (grid[..., 0] < (sizeW - self.h_patch_size)) & (
                grid[..., 1] < (sizeH - self.h_patch_size)) & (grid >= self.h_patch_size).all(dim=-1)
        warp_mask_full = warp_mask_full.view(N_src, N_rays, n_samples, Npx)

        grid = torch.clamp(normalize(grid, sizeH, sizeW), -10, 10)

        sampled_rgb_val = F.grid_sample(src_imgs, grid.view(N_src, -1, 1, 2), align_corners=True).squeeze(
            -1).transpose(1, 2)
        sampled_rgb_val = sampled_rgb_val.view(N_src, N_rays, n_samples, Npx, 3)

        warp_mask_full = warp_mask_full.permute(1, 2, 0, 3).contiguous()  # (N_rays, n_samples, N_src, Npx)
        sampled_rgb_val = sampled_rgb_val.permute(1, 2, 0, 3, 4).contiguous()  # (N_rays, n_samples, N_src, Npx, 3)

        return sampled_rgb_val, warp_mask_full

    def patch_homography(self, H, uv):
        N, Npx = uv.shape[:2]
        Nsrc = H.shape[0]
        H = H.view(Nsrc, N, -1, 3, 3)
        hom_uv = add_hom(uv)

        # einsum is 30 times faster
        # tmp = (H.view(Nsrc, N, -1, 1, 3, 3) @ hom_uv.view(1, N, 1, -1, 3, 1)).squeeze(-1).view(Nsrc, -1, 3)
        tmp = torch.einsum("vprik,pok->vproi", H, hom_uv).reshape(Nsrc, -1, 3)

        grid = tmp[..., :2] / torch.clamp(tmp[..., 2:], 1e-8)
        mask = tmp[..., 2] > 0
        return grid, mask


def add_hom(pts):
    try:
        dev = pts.device
        ones = torch.ones(pts.shape[:-1], device=dev).unsqueeze(-1)
        return torch.cat((pts, ones), dim=-1)

    except AttributeError:
        ones = np.ones((pts.shape[0], 1))
        return np.concatenate((pts, ones), axis=1)


def normalize(flow, h, w, clamp=None):
    # either h and w are simple floats, or N torch tensors where N is the batch size
    try:
        h.device

    except AttributeError:
        h = torch.tensor(h, device=flow.device).float().unsqueeze(0)
        w = torch.tensor(w, device=flow.device).float().unsqueeze(0)

    if len(flow.shape) == 4:
        w = w.unsqueeze(1).unsqueeze(2)
        h = h.unsqueeze(1).unsqueeze(2)
    elif len(flow.shape) == 3:
        w = w.unsqueeze(1)
        h = h.unsqueeze(1)
    elif len(flow.shape) == 5:
        w = w.unsqueeze(0).unsqueeze(2).unsqueeze(2)
        h = h.unsqueeze(0).unsqueeze(2).unsqueeze(2)

    res = torch.empty_like(flow)
    if res.shape[-1] == 3:
        res[..., 2] = 1

    # for grid_sample with align_corners=True
    # https://github.com/pytorch/pytorch/blob/c371542efc31b1abfe6f388042aa3ab0cef935f2/aten/src/ATen/native/GridSampler.h#L33
    res[..., 0] = 2 * flow[..., 0] / (w - 1) - 1
    res[..., 1] = 2 * flow[..., 1] / (h - 1) - 1

    if clamp:
        return torch.clamp(res, -clamp, clamp)
    else:
        return res


def build_patch_offset(h_patch_size):
    offsets = torch.arange(-h_patch_size, h_patch_size + 1)
    return torch.stack(torch.meshgrid(offsets, offsets, indexing="ij")[::-1], dim=-1).view(1, -1, 2)  # nb_pixels_patch * 2
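A quick standalone check of the offset convention used by build_patch_offset (this just inlines the function body for h_patch_size=1; the [::-1] makes the x offset vary fastest):

import torch

offsets = torch.arange(-1, 2)  # h_patch_size=1 -> the 9 offsets of a 3x3 patch
grid = torch.stack(torch.meshgrid(offsets, offsets, indexing="ij")[::-1], dim=-1).view(1, -1, 2)
print(grid.shape)   # torch.Size([1, 9, 2])
print(grid[0, :3])  # tensor([[-1, -1], [ 0, -1], [ 1, -1]])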
|
SparseNeuS_demo_v1/models/projector.py
ADDED
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The codes are partly from IBRNet

import torch
import torch.nn.functional as F
from models.render_utils import sample_ptsFeatures_from_featureMaps, sample_ptsFeatures_from_featureVolume


def safe_l2_normalize(x, dim=None, eps=1e-6):
    return F.normalize(x, p=2, dim=dim, eps=eps)


class Projector():
    """
    Obtain features from geometryVolume and rendering_feature_maps for generalized rendering
    """

    def compute_angle(self, xyz, query_c2w, supporting_c2ws):
        """
        :param xyz: [N_rays, n_samples, 3]
        :param query_c2w: [1, 4, 4]
        :param supporting_c2ws: [n, 4, 4]
        :return:
        """
        N_rays, n_samples, _ = xyz.shape
        num_views = supporting_c2ws.shape[0]
        xyz = xyz.reshape(-1, 3)

        ray2tar_pose = (query_c2w[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
        ray2tar_pose /= (torch.norm(ray2tar_pose, dim=-1, keepdim=True) + 1e-6)
        ray2support_pose = (supporting_c2ws[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
        ray2support_pose /= (torch.norm(ray2support_pose, dim=-1, keepdim=True) + 1e-6)
        ray_diff = ray2tar_pose - ray2support_pose
        ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True)
        ray_diff_dot = torch.sum(ray2tar_pose * ray2support_pose, dim=-1, keepdim=True)
        ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6)
        ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1)
        ray_diff = ray_diff.reshape((num_views, N_rays, n_samples, 4))  # the last dimension (4) is the dot-product
        return ray_diff.detach()

    def compute_angle_view_independent(self, xyz, surface_normals, supporting_c2ws):
        """
        :param xyz: [N_rays, n_samples, 3]
        :param surface_normals: [N_rays, n_samples, 3]
        :param supporting_c2ws: [n, 4, 4]
        :return:
        """
        N_rays, n_samples, _ = xyz.shape
        num_views = supporting_c2ws.shape[0]
        xyz = xyz.reshape(-1, 3)

        ray2tar_pose = surface_normals
        ray2support_pose = (supporting_c2ws[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0))
        ray2support_pose /= (torch.norm(ray2support_pose, dim=-1, keepdim=True) + 1e-6)
        ray_diff = ray2tar_pose - ray2support_pose
        ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True)
        ray_diff_dot = torch.sum(ray2tar_pose * ray2support_pose, dim=-1, keepdim=True)
        ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6)
        ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1)
        ray_diff = ray_diff.reshape((num_views, N_rays, n_samples, 4))  # the last dimension (4) is the dot-product,
        # and the first three dimensions are the normalized ray diff vector
        return ray_diff.detach()

    @torch.no_grad()
    def compute_z_diff(self, xyz, w2cs, intrinsics, pred_depth_values):
        """
        compute the depth difference between query pts projected on the image and the predicted depth values of the image
        :param xyz: [N_rays, n_samples, 3]
        :param w2cs: [N_views, 4, 4]
        :param intrinsics: [N_views, 3, 3]
        :param pred_depth_values: [N_views, N_rays, n_samples, 1]
        :return:
        """
        device = xyz.device
        N_views = w2cs.shape[0]
        N_rays, n_samples, _ = xyz.shape
        proj_matrix = torch.matmul(intrinsics, w2cs[:, :3, :])

        proj_rot = proj_matrix[:, :3, :3]
        proj_trans = proj_matrix[:, :3, 3:]

        batch_xyz = xyz.permute(2, 0, 1).contiguous().view(1, 3, N_rays * n_samples).repeat(N_views, 1, 1)

        proj_xyz = proj_rot.bmm(batch_xyz) + proj_trans

        # X = proj_xyz[:, 0]
        # Y = proj_xyz[:, 1]
        Z = proj_xyz[:, 2].clamp(min=1e-3)  # [N_views, N_rays*n_samples]
        proj_z = Z.view(N_views, N_rays, n_samples, 1)

        z_diff = proj_z - pred_depth_values  # [N_views, N_rays, n_samples, 1]

        return z_diff

    def compute(self,
                pts,
                # * 3d geometry feature volumes
                geometryVolume=None,
                geometryVolumeMask=None,
                vol_dims=None,
                partial_vol_origin=None,
                vol_size=None,
                # * 2d rendering feature maps
                rendering_feature_maps=None,
                color_maps=None,
                w2cs=None,
                intrinsics=None,
                img_wh=None,
                query_img_idx=0,  # the index of the N_views dim for rendering
                query_c2w=None,
                pred_depth_maps=None,  # no use here
                pred_depth_masks=None  # no use here
                ):
        """
        extract features of pts for rendering
        :param pts:
        :param geometryVolume:
        :param vol_dims:
        :param partial_vol_origin:
        :param vol_size:
        :param rendering_feature_maps:
        :param color_maps:
        :param w2cs:
        :param intrinsics:
        :param img_wh:
        :param query_img_idx: by default, we render the first view of w2cs
        :return:
        """
        device = pts.device
        c2ws = torch.inverse(w2cs)

        if len(pts.shape) == 2:
            pts = pts[None, :, :]

        N_rays, n_samples, _ = pts.shape
        N_views = rendering_feature_maps.shape[0]  # shape (N_views, C, H, W)

        supporting_img_idxs = torch.LongTensor([x for x in range(N_views) if x != query_img_idx]).to(device)
        query_img_idx = torch.LongTensor([query_img_idx]).to(device)

        if query_c2w is None and query_img_idx > -1:
            query_c2w = torch.index_select(c2ws, 0, query_img_idx)
            supporting_c2ws = torch.index_select(c2ws, 0, supporting_img_idxs)
            supporting_w2cs = torch.index_select(w2cs, 0, supporting_img_idxs)
            supporting_rendering_feature_maps = torch.index_select(rendering_feature_maps, 0, supporting_img_idxs)
            supporting_color_maps = torch.index_select(color_maps, 0, supporting_img_idxs)
            supporting_intrinsics = torch.index_select(intrinsics, 0, supporting_img_idxs)

            if pred_depth_maps is not None:
                supporting_depth_maps = torch.index_select(pred_depth_maps, 0, supporting_img_idxs)
                supporting_depth_masks = torch.index_select(pred_depth_masks, 0, supporting_img_idxs)
            N_supporting_views = N_views - 1
        else:
            supporting_c2ws = c2ws
            supporting_w2cs = w2cs
            supporting_rendering_feature_maps = rendering_feature_maps
            supporting_color_maps = color_maps
            supporting_intrinsics = intrinsics
            supporting_depth_maps = pred_depth_maps
            supporting_depth_masks = pred_depth_masks
            N_supporting_views = N_views

        if geometryVolume is not None:
            # * sample feature of pts from 3D feature volume
            pts_geometry_feature, pts_geometry_masks_0 = sample_ptsFeatures_from_featureVolume(
                pts, geometryVolume, vol_dims,
                partial_vol_origin, vol_size)  # [N_rays, n_samples, C], [N_rays, n_samples]

            if len(geometryVolumeMask.shape) == 3:
                geometryVolumeMask = geometryVolumeMask[None, :, :, :]

            pts_geometry_masks_1, _ = sample_ptsFeatures_from_featureVolume(
                pts, geometryVolumeMask.to(geometryVolume.dtype), vol_dims,
                partial_vol_origin, vol_size)  # [N_rays, n_samples, C]

            pts_geometry_masks = pts_geometry_masks_0 & (pts_geometry_masks_1[..., 0] > 0)
        else:
            pts_geometry_feature = None
            pts_geometry_masks = None

        # * sample feature of pts from 2D feature maps
        pts_rendering_feats, pts_rendering_mask = sample_ptsFeatures_from_featureMaps(
            pts, supporting_rendering_feature_maps, supporting_w2cs,
            supporting_intrinsics, img_wh,
            return_mask=True)  # [N_views, C, N_rays, n_samples], [N_views, N_rays, n_samples]

        # * size (N_views, N_rays*n_samples, c)
        pts_rendering_feats = pts_rendering_feats.permute(0, 2, 3, 1).contiguous()

        pts_rendering_colors = sample_ptsFeatures_from_featureMaps(pts, supporting_color_maps, supporting_w2cs,
                                                                   supporting_intrinsics, img_wh)
        # * size (N_views, N_rays*n_samples, c)
        pts_rendering_colors = pts_rendering_colors.permute(0, 2, 3, 1).contiguous()

        rgb_feats = torch.cat([pts_rendering_colors, pts_rendering_feats], dim=-1)  # [N_views, N_rays, n_samples, 3+c]

        ray_diff = self.compute_angle(pts, query_c2w, supporting_c2ws)  # [N_views, N_rays, n_samples, 4]

        if pts_geometry_masks is not None:
            final_mask = pts_geometry_masks[None, :, :].repeat(N_supporting_views, 1, 1) & \
                         pts_rendering_mask  # [N_views, N_rays, n_samples]
        else:
            final_mask = pts_rendering_mask

        z_diff, pts_pred_depth_masks = None, None

        if pred_depth_maps is not None:
            pts_pred_depth_values = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_maps, supporting_w2cs,
                                                                        supporting_intrinsics, img_wh)
            pts_pred_depth_values = pts_pred_depth_values.permute(0, 2, 3, 1).contiguous()  # (N_views, N_rays*n_samples, 1)

            # - pts_pred_depth_masks are more critical than final_mask;
            # - a ray containing a few invalid pts will be treated as invalid
            pts_pred_depth_masks = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_masks.float(),
                                                                       supporting_w2cs,
                                                                       supporting_intrinsics, img_wh)

            pts_pred_depth_masks = pts_pred_depth_masks.permute(0, 2, 3, 1).contiguous()[:, :, :, 0]  # (N_views, N_rays*n_samples)

            z_diff = self.compute_z_diff(pts, supporting_w2cs, supporting_intrinsics, pts_pred_depth_values)

        return pts_geometry_feature, rgb_feats, ray_diff, final_mask, z_diff, pts_pred_depth_masks

    def compute_view_independent(
            self,
            pts,
            # * 3d geometry feature volumes
            geometryVolume=None,
            geometryVolumeMask=None,
            sdf_network=None,
            lod=0,
            vol_dims=None,
            partial_vol_origin=None,
            vol_size=None,
            # * 2d rendering feature maps
            rendering_feature_maps=None,
            color_maps=None,
            w2cs=None,
            target_candidate_w2cs=None,
            intrinsics=None,
            img_wh=None,
            query_img_idx=0,  # the index of the N_views dim for rendering
            query_c2w=None,
            pred_depth_maps=None,  # no use here
            pred_depth_masks=None  # no use here
    ):
        """
        extract features of pts for rendering; same as compute(), except that the
        ray-difference feature is built from SDF surface normals instead of the
        query viewing direction, which makes the rendering input view-independent
        :param pts:
        :param geometryVolume:
        :param vol_dims:
        :param partial_vol_origin:
        :param vol_size:
        :param rendering_feature_maps:
        :param color_maps:
        :param w2cs:
        :param intrinsics:
        :param img_wh:
        :param query_img_idx: by default, we render the first view of w2cs
        :return:
        """
        device = pts.device
        c2ws = torch.inverse(w2cs)

        if len(pts.shape) == 2:
            pts = pts[None, :, :]

        N_rays, n_samples, _ = pts.shape
        N_views = rendering_feature_maps.shape[0]  # shape (N_views, C, H, W)

        supporting_img_idxs = torch.LongTensor([x for x in range(N_views) if x != query_img_idx]).to(device)
        query_img_idx = torch.LongTensor([query_img_idx]).to(device)

        if query_c2w is None and query_img_idx > -1:
            query_c2w = torch.index_select(c2ws, 0, query_img_idx)
            supporting_c2ws = torch.index_select(c2ws, 0, supporting_img_idxs)
            supporting_w2cs = torch.index_select(w2cs, 0, supporting_img_idxs)
            supporting_rendering_feature_maps = torch.index_select(rendering_feature_maps, 0, supporting_img_idxs)
            supporting_color_maps = torch.index_select(color_maps, 0, supporting_img_idxs)
            supporting_intrinsics = torch.index_select(intrinsics, 0, supporting_img_idxs)

            if pred_depth_maps is not None:
                supporting_depth_maps = torch.index_select(pred_depth_maps, 0, supporting_img_idxs)
                supporting_depth_masks = torch.index_select(pred_depth_masks, 0, supporting_img_idxs)
            N_supporting_views = N_views - 1
        else:
            supporting_c2ws = c2ws
            supporting_w2cs = w2cs
            supporting_rendering_feature_maps = rendering_feature_maps
            supporting_color_maps = color_maps
            supporting_intrinsics = intrinsics
            supporting_depth_maps = pred_depth_maps
            supporting_depth_masks = pred_depth_masks
            N_supporting_views = N_views

        if geometryVolume is not None:
            # * sample feature of pts from 3D feature volume
            pts_geometry_feature, pts_geometry_masks_0 = sample_ptsFeatures_from_featureVolume(
                pts, geometryVolume, vol_dims,
                partial_vol_origin, vol_size)  # [N_rays, n_samples, C], [N_rays, n_samples]

            if len(geometryVolumeMask.shape) == 3:
                geometryVolumeMask = geometryVolumeMask[None, :, :, :]

            pts_geometry_masks_1, _ = sample_ptsFeatures_from_featureVolume(
                pts, geometryVolumeMask.to(geometryVolume.dtype), vol_dims,
                partial_vol_origin, vol_size)  # [N_rays, n_samples, C]

            pts_geometry_masks = pts_geometry_masks_0 & (pts_geometry_masks_1[..., 0] > 0)
        else:
            pts_geometry_feature = None
            pts_geometry_masks = None

        # * sample feature of pts from 2D feature maps
        pts_rendering_feats, pts_rendering_mask = sample_ptsFeatures_from_featureMaps(
            pts, supporting_rendering_feature_maps, supporting_w2cs,
            supporting_intrinsics, img_wh,
            return_mask=True)  # [N_views, C, N_rays, n_samples], [N_views, N_rays, n_samples]

        # * size (N_views, N_rays*n_samples, c)
        pts_rendering_feats = pts_rendering_feats.permute(0, 2, 3, 1).contiguous()

        pts_rendering_colors = sample_ptsFeatures_from_featureMaps(pts, supporting_color_maps, supporting_w2cs,
                                                                   supporting_intrinsics, img_wh)
        # * size (N_views, N_rays*n_samples, c)
        pts_rendering_colors = pts_rendering_colors.permute(0, 2, 3, 1).contiguous()

        rgb_feats = torch.cat([pts_rendering_colors, pts_rendering_feats], dim=-1)  # [N_views, N_rays, n_samples, 3+c]

        gradients = sdf_network.gradient(
            pts.reshape(-1, 3),  # pts.squeeze(0),
            geometryVolume.unsqueeze(0),
            lod=lod
        ).squeeze()

        surface_normals = safe_l2_normalize(gradients, dim=-1)  # [npts, 3]
        # use the surface normals in place of the per-view direction
        ren_ray_diff = self.compute_angle_view_independent(
            xyz=pts,
            surface_normals=surface_normals,
            supporting_c2ws=supporting_c2ws
        )

        # # alternative: choose the closest source view direction from the 32 candidate views
        # # (use the closest source view as view direction instead of the normal vectors)
        # pts2src_centers = safe_l2_normalize((supporting_c2ws[:, :3, 3].unsqueeze(1) - pts))  # [N_views, npts, 3]
        # cosine_distance = torch.sum(pts2src_centers * surface_normals, dim=-1, keepdim=True)  # [N_views, npts, 1]
        # # choose the largest cosine distance as the view direction
        # max_idx = torch.argmax(cosine_distance, dim=0)  # [npts, 1]
        # chosen_view_direction = pts2src_centers[max_idx.squeeze(), torch.arange(pts.shape[1]), :]  # [npts, 3]
        # ren_ray_diff = self.compute_angle_view_independent(
        #     xyz=pts, surface_normals=chosen_view_direction, supporting_c2ws=supporting_c2ws)

        # # alternative: choose the closest target view direction from the 8 candidate views
        # target_candidate_c2ws = torch.inverse(target_candidate_w2cs)
        # pts2src_centers = safe_l2_normalize((target_candidate_c2ws[:, :3, 3].unsqueeze(1) - pts))  # [N_views, npts, 3]
        # cosine_distance = torch.sum(pts2src_centers * surface_normals, dim=-1, keepdim=True)  # [N_views, npts, 1]
        # max_idx = torch.argmax(cosine_distance, dim=0)  # [npts, 1]
        # chosen_view_direction = pts2src_centers[max_idx.squeeze(), torch.arange(pts.shape[1]), :]  # [npts, 3]
        # ren_ray_diff = self.compute_angle_view_independent(
        #     xyz=pts, surface_normals=chosen_view_direction, supporting_c2ws=supporting_c2ws)

        # ray_diff = self.compute_angle(pts, query_c2w, supporting_c2ws)  # [N_views, N_rays, n_samples, 4]

        # input_directions = safe_l2_normalize(pts)
        # ren_ray_diff = self.compute_angle_view_independent(
        #     xyz=pts, surface_normals=input_directions, supporting_c2ws=supporting_c2ws)

        if pts_geometry_masks is not None:
            final_mask = pts_geometry_masks[None, :, :].repeat(N_supporting_views, 1, 1) & \
                         pts_rendering_mask  # [N_views, N_rays, n_samples]
        else:
            final_mask = pts_rendering_mask

        z_diff, pts_pred_depth_masks = None, None

        if pred_depth_maps is not None:
            pts_pred_depth_values = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_maps, supporting_w2cs,
                                                                        supporting_intrinsics, img_wh)
            pts_pred_depth_values = pts_pred_depth_values.permute(0, 2, 3, 1).contiguous()  # (N_views, N_rays*n_samples, 1)

            # - pts_pred_depth_masks are more critical than final_mask;
            # - a ray containing a few invalid pts will be treated as invalid
            pts_pred_depth_masks = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_masks.float(),
                                                                       supporting_w2cs,
                                                                       supporting_intrinsics, img_wh)

            pts_pred_depth_masks = pts_pred_depth_masks.permute(0, 2, 3, 1).contiguous()[:, :, :, 0]  # (N_views, N_rays*n_samples)

            z_diff = self.compute_z_diff(pts, supporting_w2cs, supporting_intrinsics, pts_pred_depth_values)

        return pts_geometry_feature, rgb_feats, ren_ray_diff, final_mask, z_diff, pts_pred_depth_masks
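The 4-dimensional ray-difference feature computed by compute_angle above (inherited from IBRNet) can be reproduced for a single point and one supporting camera; the poses below are toy values, not repo data:

import torch
import torch.nn.functional as F

xyz = torch.tensor([[0.0, 0.0, 0.5]])             # one 3D point
query_center = torch.tensor([[0.0, 0.0, 2.0]])    # query camera center (toy)
support_center = torch.tensor([[1.0, 0.0, 2.0]])  # supporting camera center (toy)

to_query = F.normalize(query_center - xyz, dim=-1)
to_support = F.normalize(support_center - xyz, dim=-1)
diff = to_query - to_support
# feature = [normalized direction difference (3), dot product (1)]
feat = torch.cat([diff / diff.norm(dim=-1, keepdim=True).clamp(min=1e-6),
                  (to_query * to_support).sum(-1, keepdim=True)], dim=-1)
print(feat)  # small direction difference, dot product close to 1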
SparseNeuS_demo_v1/models/rays.py
ADDED
@@ -0,0 +1,320 @@
1 |
+
import os, torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
def build_patch_offset(h_patch_size):
|
7 |
+
offsets = torch.arange(-h_patch_size, h_patch_size + 1)
|
8 |
+
return torch.stack(torch.meshgrid(offsets, offsets)[::-1], dim=-1).view(1, -1, 2) # nb_pixels_patch * 2
|
9 |
+
|
10 |
+
|
11 |
+
def gen_rays_from_single_image(H, W, image, intrinsic, c2w, depth=None, mask=None):
|
12 |
+
"""
|
13 |
+
generate rays in world space, for image image
|
14 |
+
:param H:
|
15 |
+
:param W:
|
16 |
+
:param intrinsics: [3,3]
|
17 |
+
:param c2ws: [4,4]
|
18 |
+
:return:
|
19 |
+
"""
|
20 |
+
device = image.device
|
21 |
+
ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H),
|
22 |
+
torch.linspace(0, W - 1, W), indexing="ij") # pytorch's meshgrid has indexing='ij'
|
23 |
+
p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1) # H, W, 3
|
24 |
+
|
25 |
+
# normalized ndc uv coordinates, (-1, 1)
|
26 |
+
ndc_u = 2 * xs / (W - 1) - 1
|
27 |
+
ndc_v = 2 * ys / (H - 1) - 1
|
28 |
+
rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float().to(device)
|
29 |
+
|
30 |
+
intrinsic_inv = torch.inverse(intrinsic)
|
31 |
+
|
32 |
+
p = p.view(-1, 3).float().to(device) # N_rays, 3
|
33 |
+
p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze() # N_rays, 3
|
34 |
+
rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # N_rays, 3
|
35 |
+
rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze() # N_rays, 3
|
36 |
+
rays_o = c2w[None, :3, 3].expand(rays_v.shape) # N_rays, 3
|
37 |
+
|
38 |
+
image = image.permute(1, 2, 0)
|
39 |
+
color = image.view(-1, 3)
|
40 |
+
depth = depth.view(-1, 1) if depth is not None else None
|
41 |
+
mask = mask.view(-1, 1) if mask is not None else torch.ones([H * W, 1]).to(device)
|
42 |
+
sample = {
|
43 |
+
'rays_o': rays_o,
|
44 |
+
'rays_v': rays_v,
|
45 |
+
'rays_ndc_uv': rays_ndc_uv,
|
46 |
+
'rays_color': color,
|
47 |
+
# 'rays_depth': depth,
|
48 |
+
'rays_mask': mask,
|
49 |
+
'rays_norm_XYZ_cam': p # - XYZ_cam, before multiply depth
|
50 |
+
}
|
51 |
+
if depth is not None:
|
52 |
+
sample['rays_depth'] = depth
|
53 |
+
|
54 |
+
return sample
|
55 |
+
|
56 |
+
|
57 |
+
def gen_random_rays_from_single_image(H, W, N_rays, image, intrinsic, c2w, depth=None, mask=None, dilated_mask=None,
|
58 |
+
importance_sample=False, h_patch_size=3):
|
59 |
+
"""
|
60 |
+
generate random rays in world space, for a single image
|
61 |
+
:param H:
|
62 |
+
:param W:
|
63 |
+
:param N_rays:
|
64 |
+
:param image: [3, H, W]
|
65 |
+
:param intrinsic: [3,3]
|
66 |
+
:param c2w: [4,4]
|
67 |
+
:param depth: [H, W]
|
68 |
+
:param mask: [H, W]
|
69 |
+
:return:
|
70 |
+
"""
|
71 |
+
device = image.device
|
72 |
+
|
73 |
+
if dilated_mask is None:
|
74 |
+
dilated_mask = mask
|
75 |
+
|
76 |
+
if not importance_sample:
|
77 |
+
pixels_x = torch.randint(low=0, high=W, size=[N_rays])
|
78 |
+
pixels_y = torch.randint(low=0, high=H, size=[N_rays])
|
79 |
+
elif importance_sample and dilated_mask is not None: # sample more pts in the valid mask regions
|
80 |
+
pixels_x_1 = torch.randint(low=0, high=W, size=[N_rays // 4])
|
81 |
+
pixels_y_1 = torch.randint(low=0, high=H, size=[N_rays // 4])
|
82 |
+
|
83 |
+
ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H),
|
84 |
+
torch.linspace(0, W - 1, W), indexing="ij") # pytorch's meshgrid has indexing='ij'
|
85 |
+
p = torch.stack([xs, ys], dim=-1) # H, W, 2
|
86 |
+
|
87 |
+
try:
|
88 |
+
p_valid = p[dilated_mask > 0] # [num, 2]
|
89 |
+
random_idx = torch.randint(low=0, high=p_valid.shape[0], size=[N_rays // 4 * 3])
|
90 |
+
except:
|
91 |
+
print("dilated_mask.shape: ", dilated_mask.shape)
|
92 |
+
print("dilated_mask valid number", dilated_mask.sum())
|
93 |
+
|
94 |
+
raise ValueError("hhhh")
|
95 |
+
p_select = p_valid[random_idx] # [N_rays//2, 2]
|
96 |
+
pixels_x_2 = p_select[:, 0]
|
97 |
+
pixels_y_2 = p_select[:, 1]
|
98 |
+
|
99 |
+
pixels_x = torch.cat([pixels_x_1, pixels_x_2], dim=0).to(torch.int64)
|
100 |
+
pixels_y = torch.cat([pixels_y_1, pixels_y_2], dim=0).to(torch.int64)
|
101 |
+
|
102 |
+
# - crop patch from images
|
103 |
+
offsets = build_patch_offset(h_patch_size).to(device)
|
104 |
+
grid_patch = torch.stack([pixels_x, pixels_y], dim=-1).view(-1, 1, 2) + offsets.float() # [N_pts, Npx, 2]
|
105 |
+
patch_mask = (pixels_x > h_patch_size) * (pixels_x < (W - h_patch_size)) * (pixels_y > h_patch_size) * (
|
106 |
+
pixels_y < H - h_patch_size) # [N_pts]
|
107 |
+
grid_patch_u = 2 * grid_patch[:, :, 0] / (W - 1) - 1
|
108 |
+
grid_patch_v = 2 * grid_patch[:, :, 1] / (H - 1) - 1
|
109 |
+
grid_patch_uv = torch.stack([grid_patch_u, grid_patch_v], dim=-1) # [N_pts, Npx, 2]
|
110 |
+
patch_color = F.grid_sample(image[None, :, :, :], grid_patch_uv[None, :, :, :], mode='bilinear',
|
111 |
+
padding_mode='zeros',align_corners=True)[0] # [3, N_pts, Npx]
|
112 |
+
patch_color = patch_color.permute(1, 2, 0).contiguous()
|
113 |
+
|
114 |
+
# normalized ndc uv coordinates, (-1, 1)
|
115 |
+
ndc_u = 2 * pixels_x / (W - 1) - 1
|
116 |
+
ndc_v = 2 * pixels_y / (H - 1) - 1
|
117 |
+
rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float().to(device)
|
118 |
+
|
119 |
+
image = image.permute(1, 2, 0) # H ,W, C
|
120 |
+
color = image[(pixels_y, pixels_x)] # N_rays, 3
|
121 |
+
|
122 |
+
if mask is not None:
|
123 |
+
mask = mask[(pixels_y, pixels_x)] # N_rays
|
124 |
+
patch_mask = patch_mask * mask # N_rays
|
125 |
+
mask = mask.view(-1, 1)
|
126 |
+
else:
|
127 |
+
mask = torch.ones([N_rays, 1])
|
128 |
+
|
129 |
+
if depth is not None:
|
130 |
+
depth = depth[(pixels_y, pixels_x)] # N_rays
|
131 |
+
depth = depth.view(-1, 1)
|
132 |
+
|
133 |
+
intrinsic_inv = torch.inverse(intrinsic)
|
134 |
+
|
135 |
+
p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float().to(device) # N_rays, 3
|
136 |
+
p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze() # N_rays, 3
|
137 |
+
rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # N_rays, 3
|
138 |
+
rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze() # N_rays, 3
|
139 |
+
rays_o = c2w[None, :3, 3].expand(rays_v.shape) # N_rays, 3
|
140 |
+
|
141 |
+
sample = {
|
142 |
+
'rays_o': rays_o,
|
143 |
+
'rays_v': rays_v,
|
144 |
+
'rays_ndc_uv': rays_ndc_uv,
|
145 |
+
'rays_color': color,
|
146 |
+
# 'rays_depth': depth,
|
147 |
+
'rays_mask': mask,
|
148 |
+
'rays_norm_XYZ_cam': p, # - XYZ_cam, before multiply depth,
|
149 |
+
'rays_patch_color': patch_color,
|
150 |
+
'rays_patch_mask': patch_mask.view(-1, 1)
|
151 |
+
}
|
152 |
+
|
153 |
+
if depth is not None:
|
154 |
+
sample['rays_depth'] = depth
|
155 |
+
|
156 |
+
return sample
|
157 |
+
|
158 |
+
|
159 |
+
def gen_random_rays_of_patch_from_single_image(H, W, N_rays, num_neighboring_pts, patch_size,
|
160 |
+
image, intrinsic, c2w, depth=None, mask=None):
|
161 |
+
"""
|
162 |
+
generate random rays in world space, for a single image
|
163 |
+
sample rays from local patches
|
164 |
+
:param H:
|
165 |
+
:param W:
|
166 |
+
:param N_rays: the number of center rays of patches
|
167 |
+
:param image: [3, H, W]
|
168 |
+
:param intrinsic: [3,3]
|
169 |
+
:param c2w: [4,4]
|
170 |
+
:param depth: [H, W]
|
171 |
+
:param mask: [H, W]
|
172 |
+
:return:
|
173 |
+
"""
|
174 |
+
device = image.device
|
175 |
+
patch_radius_max = patch_size // 2
|
176 |
+
|
177 |
+
unit_u = 2 / (W - 1)
|
    unit_v = 2 / (H - 1)

    pixels_x_center = torch.randint(low=patch_size, high=W - patch_size, size=[N_rays])
    pixels_y_center = torch.randint(low=patch_size, high=H - patch_size, size=[N_rays])

    # normalized ndc uv coordinates, (-1, 1)
    ndc_u_center = 2 * pixels_x_center / (W - 1) - 1
    ndc_v_center = 2 * pixels_y_center / (H - 1) - 1
    ndc_uv_center = torch.stack([ndc_u_center, ndc_v_center], dim=-1).view(-1, 2).float().to(device)[:, None, :]  # [N_rays, 1, 2]

    shift_u, shift_v = torch.rand([N_rays, num_neighboring_pts, 1]), torch.rand(
        [N_rays, num_neighboring_pts, 1])  # uniform distribution of [0,1)
    shift_u = 2 * (shift_u - 0.5)  # mapping to [-1, 1)
    shift_v = 2 * (shift_v - 0.5)

    # - avoid sample points which are too close to the center point
    shift_uv = torch.cat([(shift_u * patch_radius_max) * unit_u, (shift_v * patch_radius_max) * unit_v],
                         dim=-1)  # [N_rays, num_npts, 2]
    neighboring_pts_uv = ndc_uv_center + shift_uv  # [N_rays, num_npts, 2]

    sampled_pts_uv = torch.cat([ndc_uv_center, neighboring_pts_uv], dim=1)  # concat the center point

    # sample the gts
    color = F.grid_sample(image[None, :, :, :], sampled_pts_uv[None, :, :, :], mode='bilinear',
                          align_corners=True)[0]  # [3, N_rays, num_npts]
    depth = F.grid_sample(depth[None, None, :, :], sampled_pts_uv[None, :, :, :], mode='bilinear',
                          align_corners=True)[0]  # [1, N_rays, num_npts]

    mask = F.grid_sample(mask[None, None, :, :].to(torch.float32), sampled_pts_uv[None, :, :, :], mode='nearest',
                         align_corners=True).to(torch.int64)[0]  # [1, N_rays, num_npts]

    intrinsic_inv = torch.inverse(intrinsic)

    sampled_pts_uv = sampled_pts_uv.view(N_rays * (1 + num_neighboring_pts), 2)
    color = color.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 3)
    depth = depth.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 1)
    mask = mask.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 1)

    pixels_x = (sampled_pts_uv[:, 0] + 1) * (W - 1) / 2
    pixels_y = (sampled_pts_uv[:, 1] + 1) * (H - 1) / 2
    p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float().to(device)  # N_rays*num_pts, 3
    p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze()  # N_rays*num_pts, 3
    rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # N_rays*num_pts, 3
    rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze()  # N_rays*num_pts, 3
    rays_o = c2w[None, :3, 3].expand(rays_v.shape)  # N_rays*num_pts, 3

    sample = {
        'rays_o': rays_o,
        'rays_v': rays_v,
        'rays_ndc_uv': sampled_pts_uv,
        'rays_color': color,
        'rays_depth': depth,
        'rays_mask': mask,
        # 'rays_norm_XYZ_cam': p  # - XYZ_cam, before multiplying by depth
    }

    return sample


def gen_random_rays_from_batch_images(H, W, N_rays, images, intrinsics, c2ws, depths=None, masks=None):
    """
    :param H:
    :param W:
    :param N_rays:
    :param images: [B,3,H,W]
    :param intrinsics: [B, 3, 3]
    :param c2ws: [B, 4, 4]
    :param depths: [B,H,W]
    :param masks: [B,H,W]
    :return:
    """
    assert len(images.shape) == 4

    rays_o = []
    rays_v = []
    rays_color = []
    rays_depth = []
    rays_mask = []
    for i in range(images.shape[0]):
        sample = gen_random_rays_from_single_image(H, W, N_rays, images[i], intrinsics[i], c2ws[i],
                                                   depth=depths[i] if depths is not None else None,
                                                   mask=masks[i] if masks is not None else None)
        rays_o.append(sample['rays_o'])
        rays_v.append(sample['rays_v'])
        rays_color.append(sample['rays_color'])
        if depths is not None:
            rays_depth.append(sample['rays_depth'])
        if masks is not None:
            rays_mask.append(sample['rays_mask'])

    sample = {
        'rays_o': torch.stack(rays_o, dim=0),  # [batch, N_rays, 3]
        'rays_v': torch.stack(rays_v, dim=0),
        'rays_color': torch.stack(rays_color, dim=0),
        'rays_depth': torch.stack(rays_depth, dim=0) if depths is not None else None,
        'rays_mask': torch.stack(rays_mask, dim=0) if masks is not None else None
    }
    return sample


from scipy.spatial.transform import Rotation as Rot
from scipy.spatial.transform import Slerp


def gen_rays_between(c2w_0, c2w_1, intrinsic, ratio, H, W, resolution_level=1):
    device = c2w_0.device

    l = resolution_level
    tx = torch.linspace(0, W - 1, W // l)
    ty = torch.linspace(0, H - 1, H // l)
    pixels_x, pixels_y = torch.meshgrid(tx, ty, indexing="ij")
    p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).to(device)  # W, H, 3

    intrinsic_inv = torch.inverse(intrinsic[:3, :3])
    p = torch.matmul(intrinsic_inv[None, None, :3, :3], p[:, :, :, None]).squeeze()  # W, H, 3
    rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # W, H, 3
    trans = c2w_0[:3, 3] * (1.0 - ratio) + c2w_1[:3, 3] * ratio

    # interpolate the rotation between the two poses with slerp
    pose_0 = c2w_0.detach().cpu().numpy()
    pose_1 = c2w_1.detach().cpu().numpy()
    pose_0 = np.linalg.inv(pose_0)
    pose_1 = np.linalg.inv(pose_1)
    rot_0 = pose_0[:3, :3]
    rot_1 = pose_1[:3, :3]
    rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
    key_times = [0, 1]
    key_rots = [rot_0, rot_1]
    slerp = Slerp(key_times, rots)
    rot = slerp(ratio)
    pose = np.diag([1.0, 1.0, 1.0, 1.0])
    pose = pose.astype(np.float32)
    pose[:3, :3] = rot.as_matrix()
    pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
    pose = np.linalg.inv(pose)

    c2w = torch.from_numpy(pose).to(device)
    rot = torch.from_numpy(pose[:3, :3]).cuda()
    trans = torch.from_numpy(pose[:3, 3]).cuda()
    rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze()  # W, H, 3
    rays_o = trans[None, None, :3].expand(rays_v.shape)  # W, H, 3
    return c2w, rays_o.transpose(0, 1).contiguous().view(-1, 3), rays_v.transpose(0, 1).contiguous().view(-1, 3)
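A minimal usage sketch of the batch helper and the pose-interpolation helper above (all tensor values and shapes here are illustrative assumptions, not values from the repo; note that gen_rays_between moves the interpolated pose to CUDA internally, so it assumes a GPU is available):

    import torch
    from models.rays import gen_random_rays_from_batch_images, gen_rays_between  # assumed import path

    H, W, N_rays = 256, 256, 512
    images = torch.rand(2, 3, H, W)             # [B, 3, H, W]
    intrinsics = torch.eye(3).repeat(2, 1, 1)   # [B, 3, 3], toy pinhole intrinsics
    c2ws = torch.eye(4).repeat(2, 1, 1)         # [B, 4, 4], toy camera-to-world poses
    depths = torch.rand(2, H, W)
    masks = torch.ones(2, H, W)

    rays = gen_random_rays_from_batch_images(H, W, N_rays, images, intrinsics, c2ws,
                                             depths=depths, masks=masks)
    print(rays['rays_o'].shape)  # torch.Size([2, 512, 3])

    # camera interpolated halfway between the two poses; returns a full image of rays
    c2w, rays_o, rays_v = gen_rays_between(c2ws[0], c2ws[1], intrinsics[0], ratio=0.5, H=H, W=W)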
SparseNeuS_demo_v1/models/render_utils.py
ADDED
@@ -0,0 +1,120 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from ops.back_project import cam2pixel


def sample_pdf(bins, weights, n_samples, det=False):
    '''
    :param bins: tensor of shape [N_rays, M+1], M is the number of bins
    :param weights: tensor of shape [N_rays, M]
    :param n_samples: number of samples along each ray
    :param det: if True, will perform deterministic sampling
    :return: [N_rays, n_samples]
    '''
    device = weights.device

    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]).to(device), cdf], -1)

    # if bins.shape[1] != weights.shape[1]:  # - minor modification, add this constraint
    #     cdf = torch.cat([torch.zeros_like(cdf[..., :1]).to(device), cdf], -1)
    # Take uniform samples
    if det:
        u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(device)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(device)

    # Invert CDF
    u = u.contiguous()
    # inds = searchsorted(cdf, u, side='right')
    inds = torch.searchsorted(cdf, u, right=True)

    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, n_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples


def sample_ptsFeatures_from_featureVolume(pts, featureVolume, vol_dims=None, partial_vol_origin=None, vol_size=None):
    """
    sample features of pts_wrd from featureVolume, all in world space
    :param pts: [N_rays, n_samples, 3]
    :param featureVolume: [C,wX,wY,wZ]
    :param vol_dims: [3] "3" for dimX, dimY, dimZ
    :param partial_vol_origin: [3]
    :return: pts_feature: [N_rays, n_samples, C]
    :return: valid_mask: [N_rays]
    """

    N_rays, n_samples, _ = pts.shape

    if vol_dims is None:
        pts_normalized = pts
    else:
        # normalized to (-1, 1)
        pts_normalized = 2 * (pts - partial_vol_origin[None, None, :]) / (vol_size * (vol_dims[None, None, :] - 1)) - 1

    valid_mask = (torch.abs(pts_normalized[:, :, 0]) < 1.0) & (
            torch.abs(pts_normalized[:, :, 1]) < 1.0) & (
                         torch.abs(pts_normalized[:, :, 2]) < 1.0)  # (N_rays, n_samples)

    pts_normalized = torch.flip(pts_normalized, dims=[-1])  # ! reverse the xyz for grid_sample

    # ! checked grid_sample, (x,y,z) is for (D,H,W), reverse for (W,H,D)
    pts_feature = F.grid_sample(featureVolume[None, :, :, :, :], pts_normalized[None, None, :, :, :],
                                padding_mode='zeros',
                                align_corners=True).view(-1, N_rays, n_samples)  # [C, N_rays, n_samples]

    pts_feature = pts_feature.permute(1, 2, 0)  # [N_rays, n_samples, C]
    return pts_feature, valid_mask


def sample_ptsFeatures_from_featureMaps(pts, featureMaps, w2cs, intrinsics, WH, proj_matrix=None, return_mask=False):
    """
    sample features of pts from 2d feature maps
    :param pts: [N_rays, N_samples, 3]
    :param featureMaps: [N_views, C, H, W]
    :param w2cs: [N_views, 4, 4]
    :param intrinsics: [N_views, 3, 3]
    :param proj_matrix: [N_views, 4, 4]
    :param WH: image width and height
    :return:
    """
    # normalized to (-1, 1)
    N_rays, n_samples, _ = pts.shape
    N_views = featureMaps.shape[0]

    if proj_matrix is None:
        proj_matrix = torch.matmul(intrinsics, w2cs[:, :3, :])

    pts = pts.permute(2, 0, 1).contiguous().view(1, 3, N_rays, n_samples).repeat(N_views, 1, 1, 1)
    pixel_grids = cam2pixel(pts, proj_matrix[:, :3, :3], proj_matrix[:, :3, 3:],
                            'zeros', sizeH=WH[1], sizeW=WH[0])  # (nviews, N_rays, n_samples, 2)

    valid_mask = (torch.abs(pixel_grids[:, :, :, 0]) < 1.0) & (
            torch.abs(pixel_grids[:, :, :, 1]) < 1.00)  # (nviews, N_rays, n_samples)

    pts_feature = F.grid_sample(featureMaps, pixel_grids,
                                padding_mode='zeros',
                                align_corners=True)  # [N_views, C, N_rays, n_samples]

    if return_mask:
        return pts_feature, valid_mask
    else:
        return pts_feature
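For intuition, sample_pdf above is the standard NeRF-style inverse-transform sampler: the weights are normalized into a PDF over the bins, the CDF is built by cumulative summation, uniform samples u are drawn (stratified when det=True), and each u is mapped back through the piecewise-linear inverse of the CDF. A toy sanity check (values made up for illustration):

    import torch

    bins = torch.linspace(0.0, 1.0, steps=9).unsqueeze(0)  # [1, 9] -> 8 bins on one ray
    weights = torch.zeros(1, 8)
    weights[0, 5] = 1.0                                    # put all probability mass in bin 5
    samples = sample_pdf(bins, weights, n_samples=16, det=True)
    # with stratified u, essentially all samples land inside bin 5, i.e. in (0.625, 0.75)
    print(samples.min().item(), samples.max().item())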
SparseNeuS_demo_v1/models/rendering_network.py
ADDED
@@ -0,0 +1,129 @@
# the codes are partly borrowed from IBRNet

import torch
import torch.nn as nn
import torch.nn.functional as F

torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)


# default tensorflow initialization of linear layers
def weights_init(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight.data)
        if m.bias is not None:
            nn.init.zeros_(m.bias.data)


@torch.jit.script
def fused_mean_variance(x, weight):
    mean = torch.sum(x * weight, dim=2, keepdim=True)
    var = torch.sum(weight * (x - mean) ** 2, dim=2, keepdim=True)
    return mean, var


class GeneralRenderingNetwork(nn.Module):
    """
    This model is not sensitive to finetuning
    """

    def __init__(self, in_geometry_feat_ch=8, in_rendering_feat_ch=56, anti_alias_pooling=True):
        super(GeneralRenderingNetwork, self).__init__()

        self.in_geometry_feat_ch = in_geometry_feat_ch
        self.in_rendering_feat_ch = in_rendering_feat_ch
        self.anti_alias_pooling = anti_alias_pooling

        if self.anti_alias_pooling:
            self.s = nn.Parameter(torch.tensor(0.2), requires_grad=True)
        activation_func = nn.ELU(inplace=True)

        self.ray_dir_fc = nn.Sequential(nn.Linear(4, 16),
                                        activation_func,
                                        nn.Linear(16, in_rendering_feat_ch + 3),
                                        activation_func)

        self.base_fc = nn.Sequential(nn.Linear((in_rendering_feat_ch + 3) * 3 + in_geometry_feat_ch, 64),
                                     activation_func,
                                     nn.Linear(64, 32),
                                     activation_func)

        self.vis_fc = nn.Sequential(nn.Linear(32, 32),
                                    activation_func,
                                    nn.Linear(32, 33),
                                    activation_func,
                                    )

        self.vis_fc2 = nn.Sequential(nn.Linear(32, 32),
                                     activation_func,
                                     nn.Linear(32, 1),
                                     nn.Sigmoid()
                                     )

        self.rgb_fc = nn.Sequential(nn.Linear(32 + 1 + 4, 16),
                                    activation_func,
                                    nn.Linear(16, 8),
                                    activation_func,
                                    nn.Linear(8, 1))

        self.base_fc.apply(weights_init)
        self.vis_fc2.apply(weights_init)
        self.vis_fc.apply(weights_init)
        self.rgb_fc.apply(weights_init)

    def forward(self, geometry_feat, rgb_feat, ray_diff, mask):
        '''
        :param geometry_feat: geometry features indicating sdf [n_rays, n_samples, n_feat]
        :param rgb_feat: rgbs and image features [n_views, n_rays, n_samples, n_feat]
        :param ray_diff: ray direction difference [n_views, n_rays, n_samples, 4], first 3 channels are directions,
        last channel is inner product
        :param mask: mask for whether each projection is valid or not. [n_views, n_rays, n_samples]
        :return: blended rgb [n_rays, n_samples, 3] and a per-ray validity mask [n_rays, 1]
        '''

        rgb_feat = rgb_feat.permute(1, 2, 0, 3).contiguous()
        ray_diff = ray_diff.permute(1, 2, 0, 3).contiguous()
        mask = mask[:, :, :, None].permute(1, 2, 0, 3).contiguous()
        num_views = rgb_feat.shape[2]
        geometry_feat = geometry_feat[:, :, None, :].repeat(1, 1, num_views, 1)

        direction_feat = self.ray_dir_fc(ray_diff)
        rgb_in = rgb_feat[..., :3]
        rgb_feat = rgb_feat + direction_feat

        if self.anti_alias_pooling:
            _, dot_prod = torch.split(ray_diff, [3, 1], dim=-1)
            exp_dot_prod = torch.exp(torch.abs(self.s) * (dot_prod - 1))
            weight = (exp_dot_prod - torch.min(exp_dot_prod, dim=2, keepdim=True)[0]) * mask
            weight = weight / (torch.sum(weight, dim=2, keepdim=True) + 1e-8)
        else:
            weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8)

        # compute mean and variance across different views for each point
        mean, var = fused_mean_variance(rgb_feat, weight)  # [n_rays, n_samples, 1, n_feat]
        globalfeat = torch.cat([mean, var], dim=-1)  # [n_rays, n_samples, 1, 2*n_feat]

        x = torch.cat([geometry_feat, globalfeat.expand(-1, -1, num_views, -1), rgb_feat],
                      dim=-1)  # [n_rays, n_samples, n_views, 3*n_feat+n_geo_feat]
        x = self.base_fc(x)

        x_vis = self.vis_fc(x * weight)
        x_res, vis = torch.split(x_vis, [x_vis.shape[-1] - 1, 1], dim=-1)
        vis = F.sigmoid(vis) * mask
        x = x + x_res
        vis = self.vis_fc2(x * vis) * mask

        # rgb computation
        x = torch.cat([x, vis, ray_diff], dim=-1)
        x = self.rgb_fc(x)
        x = x.masked_fill(mask == 0, -1e9)
        blending_weights_valid = F.softmax(x, dim=2)  # color blending
        rgb_out = torch.sum(rgb_in * blending_weights_valid, dim=2)

        mask = mask.detach().to(rgb_out.dtype)  # [n_rays, n_samples, n_views, 1]
        mask = torch.sum(mask, dim=2, keepdim=False)
        mask = mask >= 2  # the point is seen by at least 2 views
        mask = torch.sum(mask.to(rgb_out.dtype), dim=1, keepdim=False)
        valid_mask = mask > 8  # valid rays, more than 8 valid samples
        return rgb_out, valid_mask  # (N_rays, n_samples, 3), (N_rays, 1)
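A dummy forward pass through GeneralRenderingNetwork to make the expected tensor layout concrete (all shapes here are illustrative assumptions using the default channel sizes):

    import torch

    net = GeneralRenderingNetwork(in_geometry_feat_ch=8, in_rendering_feat_ch=56)
    n_views, n_rays, n_samples = 4, 32, 64
    geometry_feat = torch.rand(n_rays, n_samples, 8)
    rgb_feat = torch.rand(n_views, n_rays, n_samples, 3 + 56)  # first 3 channels are rgb, then image features
    ray_diff = torch.rand(n_views, n_rays, n_samples, 4)
    mask = torch.ones(n_views, n_rays, n_samples)
    rgb_out, valid_mask = net(geometry_feat, rgb_feat, ray_diff, mask)
    print(rgb_out.shape, valid_mask.shape)  # torch.Size([32, 64, 3]) torch.Size([32, 1])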
SparseNeuS_demo_v1/models/sparse_neus_renderer.py
ADDED
@@ -0,0 +1,985 @@
"""
The codes are heavily borrowed from NeuS
"""

import os
import cv2 as cv
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging
import mcubes
from icecream import ic
from models.render_utils import sample_pdf

from models.projector import Projector
from tsparse.torchsparse_utils import sparse_to_dense_channel

from models.fast_renderer import FastRenderer

from models.patch_projector import PatchProjector


class SparseNeuSRenderer(nn.Module):
    """
    conditional neus renderer;
    optimizes in normalized world space;
    wrapped in nn.Module to support DataParallel training
    """

    def __init__(self,
                 rendering_network_outside,
                 sdf_network,
                 variance_network,
                 rendering_network,
                 n_samples,
                 n_importance,
                 n_outside,
                 perturb,
                 alpha_type='div',
                 conf=None
                 ):
        super(SparseNeuSRenderer, self).__init__()

        self.conf = conf
        self.base_exp_dir = conf['general.base_exp_dir']

        # network setups
        self.rendering_network_outside = rendering_network_outside
        self.sdf_network = sdf_network
        self.variance_network = variance_network
        self.rendering_network = rendering_network

        self.n_samples = n_samples
        self.n_importance = n_importance
        self.n_outside = n_outside
        self.perturb = perturb
        self.alpha_type = alpha_type

        self.rendering_projector = Projector()  # used to obtain features for generalized rendering

        self.h_patch_size = self.conf.get_int('model.h_patch_size', default=3)
        self.patch_projector = PatchProjector(self.h_patch_size)

        self.ray_tracer = FastRenderer()  # ray_tracer to extract depth maps from sdf_volume

        # - fitted rendering or general rendering
        try:
            self.if_fitted_rendering = self.sdf_network.if_fitted_rendering
        except AttributeError:
            self.if_fitted_rendering = False

    def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_variance,
                  conditional_valid_mask_volume=None):
        device = rays_o.device
        batch_size, n_samples = z_vals.shape
        pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]  # n_rays, n_samples, 3

        if conditional_valid_mask_volume is not None:
            pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume)
            pts_mask = pts_mask.reshape(batch_size, n_samples)
            pts_mask = pts_mask[:, :-1] * pts_mask[:, 1:]  # [batch_size, n_samples-1]
        else:
            pts_mask = torch.ones([batch_size, n_samples]).to(pts.device)

        sdf = sdf.reshape(batch_size, n_samples)
        prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
        prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
        mid_sdf = (prev_sdf + next_sdf) * 0.5
        if self.alpha_type == 'uniform':
            # keep the constant derivative on the same device as the other tensors
            dot_val = torch.ones([batch_size, n_samples - 1]).to(device) * -1.0
        else:
            dot_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
            prev_dot_val = torch.cat([torch.zeros([batch_size, 1]).to(device), dot_val[:, :-1]], dim=-1)
            dot_val = torch.stack([prev_dot_val, dot_val], dim=-1)
            dot_val, _ = torch.min(dot_val, dim=-1, keepdim=False)
        dot_val = dot_val.clip(-10.0, 0.0) * pts_mask
        dist = (next_z_vals - prev_z_vals)
        prev_esti_sdf = mid_sdf - dot_val * dist * 0.5
        next_esti_sdf = mid_sdf + dot_val * dist * 0.5
        prev_cdf = torch.sigmoid(prev_esti_sdf * inv_variance)
        next_cdf = torch.sigmoid(next_esti_sdf * inv_variance)
        alpha_sdf = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)

        alpha = alpha_sdf

        # - apply pts_mask
        alpha = pts_mask * alpha

        weights = alpha * torch.cumprod(
            torch.cat([torch.ones([batch_size, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, :-1]

        z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
        return z_samples
|
117 |
+
def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, lod,
|
118 |
+
sdf_network, gru_fusion,
|
119 |
+
# * related to conditional feature
|
120 |
+
conditional_volume=None,
|
121 |
+
conditional_valid_mask_volume=None
|
122 |
+
):
|
123 |
+
device = rays_o.device
|
124 |
+
batch_size, n_samples = z_vals.shape
|
125 |
+
_, n_importance = new_z_vals.shape
|
126 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
|
127 |
+
|
128 |
+
if conditional_valid_mask_volume is not None:
|
129 |
+
pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume)
|
130 |
+
pts_mask = pts_mask.reshape(batch_size, n_importance)
|
131 |
+
pts_mask_bool = (pts_mask > 0).view(-1)
|
132 |
+
else:
|
133 |
+
pts_mask = torch.ones([batch_size, n_importance]).to(pts.device)
|
134 |
+
|
135 |
+
new_sdf = torch.ones([batch_size * n_importance, 1]).to(pts.dtype).to(device) * 100
|
136 |
+
|
137 |
+
if torch.sum(pts_mask) > 1:
|
138 |
+
new_outputs = sdf_network.sdf(pts.reshape(-1, 3)[pts_mask_bool], conditional_volume, lod=lod)
|
139 |
+
new_sdf[pts_mask_bool] = new_outputs['sdf_pts_scale%d' % lod] # .reshape(batch_size, n_importance)
|
140 |
+
|
141 |
+
new_sdf = new_sdf.view(batch_size, n_importance)
|
142 |
+
|
143 |
+
z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
|
144 |
+
sdf = torch.cat([sdf, new_sdf], dim=-1)
|
145 |
+
|
146 |
+
z_vals, index = torch.sort(z_vals, dim=-1)
|
147 |
+
xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
|
148 |
+
index = index.reshape(-1)
|
149 |
+
sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
|
150 |
+
|
151 |
+
return z_vals, sdf
|
152 |
+
|
153 |
+
@torch.no_grad()
|
154 |
+
def get_pts_mask_for_conditional_volume(self, pts, mask_volume):
|
155 |
+
"""
|
156 |
+
|
157 |
+
:param pts: [N, 3]
|
158 |
+
:param mask_volume: [1, 1, X, Y, Z]
|
159 |
+
:return:
|
160 |
+
"""
|
161 |
+
num_pts = pts.shape[0]
|
162 |
+
pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1)
|
163 |
+
|
164 |
+
pts = torch.flip(pts, dims=[-1])
|
165 |
+
|
166 |
+
pts_mask = F.grid_sample(mask_volume, pts, mode='nearest') # [1, c, 1, 1, num_pts]
|
167 |
+
pts_mask = pts_mask.view(-1, num_pts).permute(1, 0).contiguous() # [num_pts, 1]
|
168 |
+
|
169 |
+
return pts_mask
|
170 |
+
|
171 |
+
def render_core(self,
|
172 |
+
rays_o,
|
173 |
+
rays_d,
|
174 |
+
z_vals,
|
175 |
+
sample_dist,
|
176 |
+
lod,
|
177 |
+
sdf_network,
|
178 |
+
rendering_network,
|
179 |
+
background_alpha=None, # - no use here
|
180 |
+
background_sampled_color=None, # - no use here
|
181 |
+
background_rgb=None, # - no use here
|
182 |
+
alpha_inter_ratio=0.0,
|
183 |
+
# * related to conditional feature
|
184 |
+
conditional_volume=None,
|
185 |
+
conditional_valid_mask_volume=None,
|
186 |
+
# * 2d feature maps
|
187 |
+
feature_maps=None,
|
188 |
+
color_maps=None,
|
189 |
+
w2cs=None,
|
190 |
+
intrinsics=None,
|
191 |
+
img_wh=None,
|
192 |
+
query_c2w=None, # - used for testing
|
193 |
+
if_general_rendering=True,
|
194 |
+
if_render_with_grad=True,
|
195 |
+
# * used for blending mlp rendering network
|
196 |
+
img_index=None,
|
197 |
+
rays_uv=None,
|
198 |
+
# * used for clear bg and fg
|
199 |
+
bg_num=0
|
200 |
+
):
|
201 |
+
device = rays_o.device
|
202 |
+
N_rays = rays_o.shape[0]
|
203 |
+
_, n_samples = z_vals.shape
|
204 |
+
dists = z_vals[..., 1:] - z_vals[..., :-1]
|
205 |
+
dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape).to(device)], -1)
|
206 |
+
|
207 |
+
mid_z_vals = z_vals + dists * 0.5
|
208 |
+
mid_dists = mid_z_vals[..., 1:] - mid_z_vals[..., :-1]
|
209 |
+
|
210 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
|
211 |
+
dirs = rays_d[:, None, :].expand(pts.shape)
|
212 |
+
|
213 |
+
pts = pts.reshape(-1, 3)
|
214 |
+
dirs = dirs.reshape(-1, 3)
|
215 |
+
|
216 |
+
# * if conditional_volume is restored from sparse volume, need mask for pts
|
217 |
+
if conditional_valid_mask_volume is not None:
|
218 |
+
pts_mask = self.get_pts_mask_for_conditional_volume(pts, conditional_valid_mask_volume)
|
219 |
+
pts_mask = pts_mask.reshape(N_rays, n_samples).float().detach()
|
220 |
+
pts_mask_bool = (pts_mask > 0).view(-1)
|
221 |
+
|
222 |
+
if torch.sum(pts_mask_bool.float()) < 1: # ! when render out image, may meet this problem
|
223 |
+
pts_mask_bool[:100] = True
|
224 |
+
|
225 |
+
else:
|
226 |
+
pts_mask = torch.ones([N_rays, n_samples]).to(pts.device)
|
227 |
+
# import ipdb; ipdb.set_trace()
|
228 |
+
# pts_valid = pts[pts_mask_bool]
|
229 |
+
sdf_nn_output = sdf_network.sdf(pts[pts_mask_bool], conditional_volume, lod=lod)
|
230 |
+
|
231 |
+
sdf = torch.ones([N_rays * n_samples, 1]).to(pts.dtype).to(device) * 100
|
232 |
+
sdf[pts_mask_bool] = sdf_nn_output['sdf_pts_scale%d' % lod] # [N_rays*n_samples, 1]
|
233 |
+
feature_vector_valid = sdf_nn_output['sdf_features_pts_scale%d' % lod]
|
234 |
+
feature_vector = torch.zeros([N_rays * n_samples, feature_vector_valid.shape[1]]).to(pts.dtype).to(device)
|
235 |
+
feature_vector[pts_mask_bool] = feature_vector_valid
|
236 |
+
|
237 |
+
# * estimate alpha from sdf
|
238 |
+
gradients = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device)
|
239 |
+
# import ipdb; ipdb.set_trace()
|
240 |
+
gradients[pts_mask_bool] = sdf_network.gradient(
|
241 |
+
pts[pts_mask_bool], conditional_volume, lod=lod).squeeze()
|
242 |
+
|
243 |
+
sampled_color_mlp = None
|
244 |
+
rendering_valid_mask_mlp = None
|
245 |
+
sampled_color_patch = None
|
246 |
+
rendering_patch_mask = None
|
247 |
+
|
248 |
+
if self.if_fitted_rendering: # used for fine-tuning
|
249 |
+
position_latent = sdf_nn_output['sampled_latent_scale%d' % lod]
|
250 |
+
sampled_color_mlp = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device)
|
251 |
+
sampled_color_mlp_mask = torch.zeros([N_rays * n_samples, 1]).to(pts.dtype).to(device)
|
252 |
+
|
253 |
+
# - extract pixel
|
254 |
+
pts_pixel_color, pts_pixel_mask = self.patch_projector.pixel_warp(
|
255 |
+
pts[pts_mask_bool][:, None, :], color_maps, intrinsics,
|
256 |
+
w2cs, img_wh=None) # [N_rays * n_samples,1, N_views, 3] , [N_rays*n_samples, 1, N_views]
|
257 |
+
pts_pixel_color = pts_pixel_color[:, 0, :, :] # [N_rays * n_samples, N_views, 3]
|
258 |
+
pts_pixel_mask = pts_pixel_mask[:, 0, :] # [N_rays*n_samples, N_views]
|
259 |
+
|
260 |
+
# - extract patch
|
261 |
+
if_patch_blending = False if rays_uv is None else True
|
262 |
+
pts_patch_color, pts_patch_mask = None, None
|
263 |
+
if if_patch_blending:
|
264 |
+
pts_patch_color, pts_patch_mask = self.patch_projector.patch_warp(
|
265 |
+
pts.reshape([N_rays, n_samples, 3]),
|
266 |
+
rays_uv, gradients.reshape([N_rays, n_samples, 3]),
|
267 |
+
color_maps,
|
268 |
+
intrinsics[0], intrinsics,
|
269 |
+
query_c2w[0], torch.inverse(w2cs), img_wh=None
|
270 |
+
) # (N_rays, n_samples, N_src, Npx, 3), (N_rays, n_samples, N_src, Npx)
|
271 |
+
N_src, Npx = pts_patch_mask.shape[2:]
|
272 |
+
pts_patch_color = pts_patch_color.view(N_rays * n_samples, N_src, Npx, 3)[pts_mask_bool]
|
273 |
+
pts_patch_mask = pts_patch_mask.view(N_rays * n_samples, N_src, Npx)[pts_mask_bool]
|
274 |
+
|
275 |
+
sampled_color_patch = torch.zeros([N_rays * n_samples, Npx, 3]).to(device)
|
276 |
+
sampled_color_patch_mask = torch.zeros([N_rays * n_samples, 1]).to(device)
|
277 |
+
|
278 |
+
sampled_color_mlp_, sampled_color_mlp_mask_, \
|
279 |
+
sampled_color_patch_, sampled_color_patch_mask_ = sdf_network.color_blend(
|
280 |
+
pts[pts_mask_bool],
|
281 |
+
position_latent,
|
282 |
+
gradients[pts_mask_bool],
|
283 |
+
dirs[pts_mask_bool],
|
284 |
+
feature_vector[pts_mask_bool],
|
285 |
+
img_index=img_index,
|
286 |
+
pts_pixel_color=pts_pixel_color,
|
287 |
+
pts_pixel_mask=pts_pixel_mask,
|
288 |
+
pts_patch_color=pts_patch_color,
|
289 |
+
pts_patch_mask=pts_patch_mask
|
290 |
+
|
291 |
+
) # [n, 3], [n, 1]
|
292 |
+
sampled_color_mlp[pts_mask_bool] = sampled_color_mlp_
|
293 |
+
sampled_color_mlp_mask[pts_mask_bool] = sampled_color_mlp_mask_.float()
|
294 |
+
sampled_color_mlp = sampled_color_mlp.view(N_rays, n_samples, 3)
|
295 |
+
sampled_color_mlp_mask = sampled_color_mlp_mask.view(N_rays, n_samples)
|
296 |
+
rendering_valid_mask_mlp = torch.mean(pts_mask * sampled_color_mlp_mask, dim=-1, keepdim=True) > 0.5
|
297 |
+
|
298 |
+
# patch blending
|
299 |
+
if if_patch_blending:
|
300 |
+
sampled_color_patch[pts_mask_bool] = sampled_color_patch_
|
301 |
+
sampled_color_patch_mask[pts_mask_bool] = sampled_color_patch_mask_.float()
|
302 |
+
sampled_color_patch = sampled_color_patch.view(N_rays, n_samples, Npx, 3)
|
303 |
+
sampled_color_patch_mask = sampled_color_patch_mask.view(N_rays, n_samples)
|
304 |
+
rendering_patch_mask = torch.mean(pts_mask * sampled_color_patch_mask, dim=-1,
|
305 |
+
keepdim=True) > 0.5 # [N_rays, 1]
|
306 |
+
else:
|
307 |
+
sampled_color_patch, rendering_patch_mask = None, None
|
308 |
+
|
309 |
+
if if_general_rendering: # used for general training
|
310 |
+
# [512, 128, 16]; [4, 512, 128, 59]; [4, 512, 128, 4]
|
311 |
+
ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = self.rendering_projector.compute(
|
312 |
+
pts.view(N_rays, n_samples, 3),
|
313 |
+
# * 3d geometry feature volumes
|
314 |
+
geometryVolume=conditional_volume[0],
|
315 |
+
geometryVolumeMask=conditional_valid_mask_volume[0],
|
316 |
+
# * 2d rendering feature maps
|
317 |
+
rendering_feature_maps=feature_maps, # [n_views, 56, 256, 256]
|
318 |
+
color_maps=color_maps,
|
319 |
+
w2cs=w2cs,
|
320 |
+
intrinsics=intrinsics,
|
321 |
+
img_wh=img_wh,
|
322 |
+
query_img_idx=0, # the index of the N_views dim for rendering
|
323 |
+
query_c2w=query_c2w,
|
324 |
+
)
|
325 |
+
|
326 |
+
# (N_rays, n_samples, 3)
|
327 |
+
if if_render_with_grad:
|
328 |
+
# import ipdb; ipdb.set_trace()
|
329 |
+
# [nrays, 3] [nrays, 1]
|
330 |
+
sampled_color, rendering_valid_mask = rendering_network(
|
331 |
+
ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask)
|
332 |
+
# import ipdb; ipdb.set_trace()
|
333 |
+
else:
|
334 |
+
with torch.no_grad():
|
335 |
+
sampled_color, rendering_valid_mask = rendering_network(
|
336 |
+
ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask)
|
337 |
+
else:
|
338 |
+
sampled_color, rendering_valid_mask = None, None
|
339 |
+
|
340 |
+
inv_variance = self.variance_network(feature_vector)[:, :1].clip(1e-6, 1e6)
|
341 |
+
|
342 |
+
true_dot_val = (dirs * gradients).sum(-1, keepdim=True) # * calculate
|
343 |
+
|
344 |
+
iter_cos = -(F.relu(-true_dot_val * 0.5 + 0.5) * (1.0 - alpha_inter_ratio) + F.relu(
|
345 |
+
-true_dot_val) * alpha_inter_ratio) # always non-positive
|
346 |
+
|
347 |
+
iter_cos = iter_cos * pts_mask.view(-1, 1)
|
348 |
+
|
349 |
+
true_estimate_sdf_half_next = sdf + iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5
|
350 |
+
true_estimate_sdf_half_prev = sdf - iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5
|
351 |
+
|
352 |
+
prev_cdf = torch.sigmoid(true_estimate_sdf_half_prev * inv_variance)
|
353 |
+
next_cdf = torch.sigmoid(true_estimate_sdf_half_next * inv_variance)
|
354 |
+
|
355 |
+
p = prev_cdf - next_cdf
|
356 |
+
c = prev_cdf
|
357 |
+
|
358 |
+
if self.alpha_type == 'div':
|
359 |
+
alpha_sdf = ((p + 1e-5) / (c + 1e-5)).reshape(N_rays, n_samples).clip(0.0, 1.0)
|
360 |
+
elif self.alpha_type == 'uniform':
|
361 |
+
uniform_estimate_sdf_half_next = sdf - dists.reshape(-1, 1) * 0.5
|
362 |
+
uniform_estimate_sdf_half_prev = sdf + dists.reshape(-1, 1) * 0.5
|
363 |
+
uniform_prev_cdf = torch.sigmoid(uniform_estimate_sdf_half_prev * inv_variance)
|
364 |
+
uniform_next_cdf = torch.sigmoid(uniform_estimate_sdf_half_next * inv_variance)
|
365 |
+
uniform_alpha = F.relu(
|
366 |
+
(uniform_prev_cdf - uniform_next_cdf + 1e-5) / (uniform_prev_cdf + 1e-5)).reshape(
|
367 |
+
N_rays, n_samples).clip(0.0, 1.0)
|
368 |
+
alpha_sdf = uniform_alpha
|
369 |
+
else:
|
370 |
+
assert False
|
371 |
+
|
372 |
+
alpha = alpha_sdf
|
373 |
+
|
374 |
+
# - apply pts_mask
|
375 |
+
alpha = alpha * pts_mask
|
376 |
+
|
377 |
+
# pts_radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(N_rays, n_samples)
|
378 |
+
# inside_sphere = (pts_radius < 1.0).float().detach()
|
379 |
+
# relax_inside_sphere = (pts_radius < 1.2).float().detach()
|
380 |
+
inside_sphere = pts_mask
|
381 |
+
relax_inside_sphere = pts_mask
|
382 |
+
|
383 |
+
weights = alpha * torch.cumprod(torch.cat([torch.ones([N_rays, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:,
|
384 |
+
:-1] # n_rays, n_samples
|
385 |
+
weights_sum = weights.sum(dim=-1, keepdim=True)
|
386 |
+
alpha_sum = alpha.sum(dim=-1, keepdim=True)
|
387 |
+
|
388 |
+
if bg_num > 0:
|
389 |
+
weights_sum_fg = weights[:, :-bg_num].sum(dim=-1, keepdim=True)
|
390 |
+
else:
|
391 |
+
weights_sum_fg = weights_sum
|
392 |
+
|
393 |
+
if sampled_color is not None:
|
394 |
+
color = (sampled_color * weights[:, :, None]).sum(dim=1)
|
395 |
+
else:
|
396 |
+
color = None
|
397 |
+
# import ipdb; ipdb.set_trace()
|
398 |
+
|
399 |
+
if background_rgb is not None and color is not None:
|
400 |
+
color = color + background_rgb * (1.0 - weights_sum)
|
401 |
+
# print("color device:" + str(color.device))
|
402 |
+
# if color is not None:
|
403 |
+
# # import ipdb; ipdb.set_trace()
|
404 |
+
# color = color + (1.0 - weights_sum)
|
405 |
+
|
406 |
+
|
407 |
+
###################* mlp color rendering #####################
|
408 |
+
color_mlp = None
|
409 |
+
# import ipdb; ipdb.set_trace()
|
410 |
+
if sampled_color_mlp is not None:
|
411 |
+
color_mlp = (sampled_color_mlp * weights[:, :, None]).sum(dim=1)
|
412 |
+
|
413 |
+
if background_rgb is not None and color_mlp is not None:
|
414 |
+
color_mlp = color_mlp + background_rgb * (1.0 - weights_sum)
|
415 |
+
|
416 |
+
############################ * patch blending ################
|
417 |
+
blended_color_patch = None
|
418 |
+
if sampled_color_patch is not None:
|
419 |
+
blended_color_patch = (sampled_color_patch * weights[:, :, None, None]).sum(dim=1) # [N_rays, Npx, 3]
|
420 |
+
|
421 |
+
######################################################
|
422 |
+
|
423 |
+
gradient_error = (torch.linalg.norm(gradients.reshape(N_rays, n_samples, 3), ord=2,
|
424 |
+
dim=-1) - 1.0) ** 2
|
425 |
+
# ! the gradient normal should be masked out, the pts out of the bounding box should also be penalized
|
426 |
+
gradient_error = (pts_mask * gradient_error).sum() / (
|
427 |
+
(pts_mask).sum() + 1e-5)
|
428 |
+
|
429 |
+
depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True)
|
430 |
+
# print("[TEST]: weights_sum in render_core", weights_sum.mean())
|
431 |
+
# print("[TEST]: weights_sum in render_core NAN number", weights_sum.isnan().sum())
|
432 |
+
# if weights_sum.isnan().sum() > 0:
|
433 |
+
# import ipdb; ipdb.set_trace()
|
434 |
+
return {
|
435 |
+
'color': color,
|
436 |
+
'color_mask': rendering_valid_mask, # (N_rays, 1)
|
437 |
+
'color_mlp': color_mlp,
|
438 |
+
'color_mlp_mask': rendering_valid_mask_mlp,
|
439 |
+
'sdf': sdf, # (N_rays, n_samples)
|
440 |
+
'depth': depth, # (N_rays, 1)
|
441 |
+
'dists': dists,
|
442 |
+
'gradients': gradients.reshape(N_rays, n_samples, 3),
|
443 |
+
'variance': 1.0 / inv_variance,
|
444 |
+
'mid_z_vals': mid_z_vals,
|
445 |
+
'weights': weights,
|
446 |
+
'weights_sum': weights_sum,
|
447 |
+
'alpha_sum': alpha_sum,
|
448 |
+
'alpha_mean': alpha.mean(),
|
449 |
+
'cdf': c.reshape(N_rays, n_samples),
|
450 |
+
'gradient_error': gradient_error,
|
451 |
+
'inside_sphere': inside_sphere,
|
452 |
+
'blended_color_patch': blended_color_patch,
|
453 |
+
'blended_color_patch_mask': rendering_patch_mask,
|
454 |
+
'weights_sum_fg': weights_sum_fg
|
455 |
+
}
|
456 |
+
|
457 |
+
def render(self, rays_o, rays_d, near, far, sdf_network, rendering_network,
|
458 |
+
perturb_overwrite=-1,
|
459 |
+
background_rgb=None,
|
460 |
+
alpha_inter_ratio=0.0,
|
461 |
+
# * related to conditional feature
|
462 |
+
lod=None,
|
463 |
+
conditional_volume=None,
|
464 |
+
conditional_valid_mask_volume=None,
|
465 |
+
# * 2d feature maps
|
466 |
+
feature_maps=None,
|
467 |
+
color_maps=None,
|
468 |
+
w2cs=None,
|
469 |
+
intrinsics=None,
|
470 |
+
img_wh=None,
|
471 |
+
query_c2w=None, # -used for testing
|
472 |
+
if_general_rendering=True,
|
473 |
+
if_render_with_grad=True,
|
474 |
+
# * used for blending mlp rendering network
|
475 |
+
img_index=None,
|
476 |
+
rays_uv=None,
|
477 |
+
# * importance sample for second lod network
|
478 |
+
pre_sample=False, # no use here
|
479 |
+
# * for clear foreground
|
480 |
+
bg_ratio=0.0
|
481 |
+
):
|
482 |
+
device = rays_o.device
|
483 |
+
N_rays = len(rays_o)
|
484 |
+
# sample_dist = 2.0 / self.n_samples
|
485 |
+
sample_dist = ((far - near) / self.n_samples).mean().item()
|
486 |
+
z_vals = torch.linspace(0.0, 1.0, self.n_samples).to(device)
|
487 |
+
z_vals = near + (far - near) * z_vals[None, :]
|
488 |
+
|
489 |
+
bg_num = int(self.n_samples * bg_ratio)
|
490 |
+
|
491 |
+
if z_vals.shape[0] == 1:
|
492 |
+
z_vals = z_vals.repeat(N_rays, 1)
|
493 |
+
|
494 |
+
if bg_num > 0:
|
495 |
+
z_vals_bg = z_vals[:, self.n_samples - bg_num:]
|
496 |
+
z_vals = z_vals[:, :self.n_samples - bg_num]
|
497 |
+
|
498 |
+
n_samples = self.n_samples - bg_num
|
499 |
+
perturb = self.perturb
|
500 |
+
|
501 |
+
# - significantly speed up training, for the second lod network
|
502 |
+
if pre_sample:
|
503 |
+
z_vals = self.sample_z_vals_from_maskVolume(rays_o, rays_d, near, far,
|
504 |
+
conditional_valid_mask_volume)
|
505 |
+
|
506 |
+
if perturb_overwrite >= 0:
|
507 |
+
perturb = perturb_overwrite
|
508 |
+
if perturb > 0:
|
509 |
+
# get intervals between samples
|
510 |
+
mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
|
511 |
+
upper = torch.cat([mids, z_vals[..., -1:]], -1)
|
512 |
+
lower = torch.cat([z_vals[..., :1], mids], -1)
|
513 |
+
# stratified samples in those intervals
|
514 |
+
t_rand = torch.rand(z_vals.shape).to(device)
|
515 |
+
z_vals = lower + (upper - lower) * t_rand
|
516 |
+
|
517 |
+
background_alpha = None
|
518 |
+
background_sampled_color = None
|
519 |
+
z_val_before = z_vals.clone()
|
520 |
+
# Up sample
|
521 |
+
if self.n_importance > 0:
|
522 |
+
with torch.no_grad():
|
523 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
|
524 |
+
|
525 |
+
sdf_outputs = sdf_network.sdf(
|
526 |
+
pts.reshape(-1, 3), conditional_volume, lod=lod)
|
527 |
+
# pdb.set_trace()
|
528 |
+
sdf = sdf_outputs['sdf_pts_scale%d' % lod].reshape(N_rays, self.n_samples - bg_num)
|
529 |
+
|
530 |
+
n_steps = 4
|
531 |
+
for i in range(n_steps):
|
532 |
+
new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_importance // n_steps,
|
533 |
+
64 * 2 ** i,
|
534 |
+
conditional_valid_mask_volume=conditional_valid_mask_volume,
|
535 |
+
)
|
536 |
+
|
537 |
+
# if new_z_vals.isnan().sum() > 0:
|
538 |
+
# import ipdb; ipdb.set_trace()
|
539 |
+
|
540 |
+
z_vals, sdf = self.cat_z_vals(
|
541 |
+
rays_o, rays_d, z_vals, new_z_vals, sdf, lod,
|
542 |
+
sdf_network, gru_fusion=False,
|
543 |
+
conditional_volume=conditional_volume,
|
544 |
+
conditional_valid_mask_volume=conditional_valid_mask_volume,
|
545 |
+
)
|
546 |
+
|
547 |
+
del sdf
|
548 |
+
|
549 |
+
n_samples = self.n_samples + self.n_importance
|
550 |
+
|
551 |
+
# Background
|
552 |
+
ret_outside = None
|
553 |
+
|
554 |
+
# Render
|
555 |
+
if bg_num > 0:
|
556 |
+
z_vals = torch.cat([z_vals, z_vals_bg], dim=1)
|
557 |
+
# if z_vals.isnan().sum() > 0:
|
558 |
+
# import ipdb; ipdb.set_trace()
|
559 |
+
ret_fine = self.render_core(rays_o,
|
560 |
+
rays_d,
|
561 |
+
z_vals,
|
562 |
+
sample_dist,
|
563 |
+
lod,
|
564 |
+
sdf_network,
|
565 |
+
rendering_network,
|
566 |
+
background_rgb=background_rgb,
|
567 |
+
background_alpha=background_alpha,
|
568 |
+
background_sampled_color=background_sampled_color,
|
569 |
+
alpha_inter_ratio=alpha_inter_ratio,
|
570 |
+
# * related to conditional feature
|
571 |
+
conditional_volume=conditional_volume,
|
572 |
+
conditional_valid_mask_volume=conditional_valid_mask_volume,
|
573 |
+
# * 2d feature maps
|
574 |
+
feature_maps=feature_maps,
|
575 |
+
color_maps=color_maps,
|
576 |
+
w2cs=w2cs,
|
577 |
+
intrinsics=intrinsics,
|
578 |
+
img_wh=img_wh,
|
579 |
+
query_c2w=query_c2w,
|
580 |
+
if_general_rendering=if_general_rendering,
|
581 |
+
if_render_with_grad=if_render_with_grad,
|
582 |
+
# * used for blending mlp rendering network
|
583 |
+
img_index=img_index,
|
584 |
+
rays_uv=rays_uv
|
585 |
+
)
|
586 |
+
|
587 |
+
color_fine = ret_fine['color']
|
588 |
+
|
589 |
+
if self.n_outside > 0:
|
590 |
+
color_fine_mask = torch.logical_or(ret_fine['color_mask'], ret_outside['color_mask'])
|
591 |
+
else:
|
592 |
+
color_fine_mask = ret_fine['color_mask']
|
593 |
+
|
594 |
+
weights = ret_fine['weights']
|
595 |
+
weights_sum = ret_fine['weights_sum']
|
596 |
+
|
597 |
+
gradients = ret_fine['gradients']
|
598 |
+
mid_z_vals = ret_fine['mid_z_vals']
|
599 |
+
|
600 |
+
# depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True)
|
601 |
+
depth = ret_fine['depth']
|
602 |
+
depth_varaince = ((mid_z_vals - depth) ** 2 * weights[:, :n_samples]).sum(dim=-1, keepdim=True)
|
603 |
+
variance = ret_fine['variance'].reshape(N_rays, n_samples).mean(dim=-1, keepdim=True)
|
604 |
+
|
605 |
+
# - randomly sample points from the volume, and maximize the sdf
|
606 |
+
pts_random = torch.rand([1024, 3]).float().to(device) * 2 - 1 # normalized to (-1, 1)
|
607 |
+
sdf_random = sdf_network.sdf(pts_random, conditional_volume, lod=lod)['sdf_pts_scale%d' % lod]
|
608 |
+
|
609 |
+
result = {
|
610 |
+
'depth': depth,
|
611 |
+
'color_fine': color_fine,
|
612 |
+
'color_fine_mask': color_fine_mask,
|
613 |
+
'color_outside': ret_outside['color'] if ret_outside is not None else None,
|
614 |
+
'color_outside_mask': ret_outside['color_mask'] if ret_outside is not None else None,
|
615 |
+
'color_mlp': ret_fine['color_mlp'],
|
616 |
+
'color_mlp_mask': ret_fine['color_mlp_mask'],
|
617 |
+
'variance': variance.mean(),
|
618 |
+
'cdf_fine': ret_fine['cdf'],
|
619 |
+
'depth_variance': depth_varaince,
|
620 |
+
'weights_sum': weights_sum,
|
621 |
+
'weights_max': torch.max(weights, dim=-1, keepdim=True)[0],
|
622 |
+
'alpha_sum': ret_fine['alpha_sum'].mean(),
|
623 |
+
'alpha_mean': ret_fine['alpha_mean'],
|
624 |
+
'gradients': gradients,
|
625 |
+
'weights': weights,
|
626 |
+
'gradient_error_fine': ret_fine['gradient_error'],
|
627 |
+
'inside_sphere': ret_fine['inside_sphere'],
|
628 |
+
'sdf': ret_fine['sdf'],
|
629 |
+
'sdf_random': sdf_random,
|
630 |
+
'blended_color_patch': ret_fine['blended_color_patch'],
|
631 |
+
'blended_color_patch_mask': ret_fine['blended_color_patch_mask'],
|
632 |
+
'weights_sum_fg': ret_fine['weights_sum_fg']
|
633 |
+
}
|
634 |
+
|
635 |
+
return result
|
636 |
+
|
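    # Hypothetical call sketch for render() (comment added for illustration; all tensor
    # names below are assumptions). With the conditional feature volume and its validity
    # mask already produced by the sparse sdf network, a typical query looks like:
    #     out = renderer.render(rays_o, rays_d, near, far,
    #                           sdf_network, rendering_network,
    #                           lod=0,
    #                           conditional_volume=con_volume,           # [1, C, X, Y, Z]
    #                           conditional_valid_mask_volume=con_mask,  # [1, 1, X, Y, Z]
    #                           feature_maps=feats, color_maps=imgs,
    #                           w2cs=w2cs, intrinsics=intrinsics, img_wh=(W, H),
    #                           query_c2w=query_c2w)
    #     color, depth = out['color_fine'], out['depth']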
    @torch.no_grad()
    def sample_z_vals_from_sdfVolume(self, rays_o, rays_d, near, far, sdf_volume, mask_volume):
        # ? importance sampling based on the sdf seems too biased towards the pre-estimation
        device = rays_o.device
        N_rays = len(rays_o)
        n_samples = self.n_samples * 2

        z_vals = torch.linspace(0.0, 1.0, n_samples).to(device)
        z_vals = near + (far - near) * z_vals[None, :]

        if z_vals.shape[0] == 1:
            z_vals = z_vals.repeat(N_rays, 1)

        pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]

        sdf = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), sdf_volume).reshape([N_rays, n_samples])

        new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_samples,
                                    200,
                                    conditional_valid_mask_volume=mask_volume,
                                    )
        return new_z_vals

    @torch.no_grad()
    def sample_z_vals_from_maskVolume(self, rays_o, rays_d, near, far, mask_volume):  # don't use
        device = rays_o.device
        N_rays = len(rays_o)
        n_samples = self.n_samples * 2

        z_vals = torch.linspace(0.0, 1.0, n_samples).to(device)
        z_vals = near + (far - near) * z_vals[None, :]

        if z_vals.shape[0] == 1:
            z_vals = z_vals.repeat(N_rays, 1)

        mid_z_vals = (z_vals[:, 1:] + z_vals[:, :-1]) * 0.5

        pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]

        pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), mask_volume).reshape(
            [N_rays, n_samples - 1])

        # empty voxels are weighted 0.1, non-empty voxels 1
        weights = torch.where(pts_mask > 0, torch.ones_like(pts_mask).to(device),
                              0.1 * torch.ones_like(pts_mask).to(device))

        # sample more pts in non-empty voxels
        z_samples = sample_pdf(z_vals, weights, self.n_samples, det=True).detach()
        return z_samples

    @torch.no_grad()
    def filter_pts_by_depthmaps(self, coords, pred_depth_maps, proj_matrices,
                                partial_vol_origin, voxel_size,
                                near, far, depth_interval, d_plane_nums):
        """
        Use the predicted depth maps to remove redundant pts (pruned by sdf; an sdf always has two sides, and the back side is useless)
        :param coords: [n, 3] int coords
        :param pred_depth_maps: [N_views, 1, h, w]
        :param proj_matrices: [N_views, 4, 4]
        :param partial_vol_origin: [3]
        :param voxel_size: 1
        :param near: 1
        :param far: 1
        :param depth_interval: 1
        :param d_plane_nums: 1
        :return:
        """
        device = pred_depth_maps.device
        n_views, _, sizeH, sizeW = pred_depth_maps.shape

        if len(partial_vol_origin.shape) == 1:
            partial_vol_origin = partial_vol_origin[None, :]
        pts = coords * voxel_size + partial_vol_origin

        rs_grid = pts.unsqueeze(0).expand(n_views, -1, -1)
        rs_grid = rs_grid.permute(0, 2, 1).contiguous()  # [n_views, 3, n_pts]
        nV = rs_grid.shape[-1]
        rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).to(device)], dim=1)  # [n_views, 4, n_pts]

        # Project grid
        im_p = proj_matrices @ rs_grid  # - transform world pts to image UV space  # [n_views, 4, n_pts]
        im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2]
        im_x = im_x / im_z
        im_y = im_y / im_z

        im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1)

        im_grid = im_grid.view(n_views, 1, -1, 2)
        sampled_depths = torch.nn.functional.grid_sample(pred_depth_maps, im_grid, mode='bilinear',
                                                         padding_mode='zeros',
                                                         align_corners=True)[:, 0, 0, :]  # [n_views, n_pts]
        sampled_depths_valid = (sampled_depths > 0.5 * near).float()
        valid_d_min = (sampled_depths - d_plane_nums * depth_interval).clamp(near.item(),
                                                                             far.item()) * sampled_depths_valid
        valid_d_max = (sampled_depths + d_plane_nums * depth_interval).clamp(near.item(),
                                                                             far.item()) * sampled_depths_valid

        mask = im_grid.abs() <= 1
        mask = mask[:, 0]  # [n_views, n_pts, 2]
        mask = (mask.sum(dim=-1) == 2) & (im_z > valid_d_min) & (im_z < valid_d_max)

        mask = mask.view(n_views, -1)
        mask = mask.permute(1, 0).contiguous()  # [num_pts, nviews]

        mask_final = torch.sum(mask.float(), dim=1, keepdim=False) > 0

        return mask_final

    @torch.no_grad()
    def get_valid_sparse_coords_by_sdf_depthfilter(self, sdf_volume, coords_volume, mask_volume, feature_volume,
                                                   pred_depth_maps, proj_matrices,
                                                   partial_vol_origin, voxel_size,
                                                   near, far, depth_interval, d_plane_nums,
                                                   threshold=0.02, maximum_pts=110000):
        """
        assume batch size == 1; get sparse voxels from the first lod
        :param sdf_volume: [1, X, Y, Z]
        :param coords_volume: [3, X, Y, Z]
        :param mask_volume: [1, X, Y, Z]
        :param feature_volume: [C, X, Y, Z]
        :param threshold:
        :return:
        """
        device = coords_volume.device
        _, dX, dY, dZ = coords_volume.shape

        def prune(sdf_pts, coords_pts, mask_volume, threshold):
            occupancy_mask = (torch.abs(sdf_pts) < threshold).squeeze(1)  # [num_pts]
            valid_coords = coords_pts[occupancy_mask]

            # - filter the backside surface with the depth maps
            mask_filtered = self.filter_pts_by_depthmaps(valid_coords, pred_depth_maps, proj_matrices,
                                                         partial_vol_origin, voxel_size,
                                                         near, far, depth_interval, d_plane_nums)
            valid_coords = valid_coords[mask_filtered]

            occupancy_mask = sparse_to_dense_channel(valid_coords, 1, [dX, dY, dZ], 1, 0, device)  # [dX, dY, dZ, 1]

            # - dilate
            occupancy_mask = occupancy_mask.float()
            occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ)
            occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3)
            occupancy_mask = occupancy_mask.view(-1, 1) > 0

            final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0]  # [num_pts]

            return final_mask, torch.sum(final_mask.float())

        C, dX, dY, dZ = feature_volume.shape
        sdf_volume = sdf_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1)
        coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3)
        mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1)
        feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C)

        final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold)

        while (valid_num > maximum_pts) and (threshold > 0.003):
            threshold = threshold - 0.002
            final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold)

        valid_coords = coords_volume[final_mask]  # [N, 3]
        valid_feature = feature_volume[final_mask]  # [N, C]

        valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0,
                                  valid_coords], dim=1)  # [N, 4], append batch idx

        # ! if valid_num is still larger than maximum_pts, randomly sample a subset of pts
        if valid_num > maximum_pts:
            valid_num = valid_num.long()
            occupancy = torch.ones([valid_num]).to(device) > 0
            choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts,
                                      replace=False)
            ind = torch.nonzero(occupancy).to(device)
            occupancy[ind[choice]] = False
            valid_coords = valid_coords[occupancy]
            valid_feature = valid_feature[occupancy]

            print(threshold, "randomly sample to save memory")

        return valid_coords, valid_feature

    @torch.no_grad()
    def get_valid_sparse_coords_by_sdf(self, sdf_volume, coords_volume, mask_volume, feature_volume, threshold=0.02,
                                       maximum_pts=110000):
        """
        assume batch size == 1; get sparse voxels from the first lod
        :param sdf_volume: [num_pts, 1]
        :param coords_volume: [3, X, Y, Z]
        :param mask_volume: [1, X, Y, Z]
        :param feature_volume: [C, X, Y, Z]
        :param threshold:
        :return:
        """

        def prune(sdf_volume, mask_volume, threshold):
            occupancy_mask = torch.abs(sdf_volume) < threshold  # [num_pts, 1]

            # - dilate
            occupancy_mask = occupancy_mask.float()
            occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ)
            occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3)
            occupancy_mask = occupancy_mask.view(-1, 1) > 0

            final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0]  # [num_pts]

            return final_mask, torch.sum(final_mask.float())

        C, dX, dY, dZ = feature_volume.shape
        coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3)
        mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1)
        feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C)

        final_mask, valid_num = prune(sdf_volume, mask_volume, threshold)

        while (valid_num > maximum_pts) and (threshold > 0.003):
            threshold = threshold - 0.002
            final_mask, valid_num = prune(sdf_volume, mask_volume, threshold)

        valid_coords = coords_volume[final_mask]  # [N, 3]
        valid_feature = feature_volume[final_mask]  # [N, C]

        valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0,
                                  valid_coords], dim=1)  # [N, 4], append batch idx

        # ! if valid_num is still larger than maximum_pts, randomly sample a subset of pts
        if valid_num > maximum_pts:
            device = sdf_volume.device
            valid_num = valid_num.long()
            occupancy = torch.ones([valid_num]).to(device) > 0
            choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts,
                                      replace=False)
            ind = torch.nonzero(occupancy).to(device)
            occupancy[ind[choice]] = False
            valid_coords = valid_coords[occupancy]
            valid_feature = valid_feature[occupancy]

            print(threshold, "randomly sample to save memory")

        return valid_coords, valid_feature

    @torch.no_grad()
    def extract_fields(self, bound_min, bound_max, resolution, query_func, device,
                       # * related to conditional feature
                       **kwargs
                       ):
        N = 64
        X = torch.linspace(bound_min[0], bound_max[0], resolution).to(device).split(N)
        Y = torch.linspace(bound_min[1], bound_max[1], resolution).to(device).split(N)
        Z = torch.linspace(bound_min[2], bound_max[2], resolution).to(device).split(N)

        u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
        with torch.no_grad():
            for xi, xs in enumerate(X):
                for yi, ys in enumerate(Y):
                    for zi, zs in enumerate(Z):
                        xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing="ij")
                        pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)

                        # ! attention, the query function differs between extract_geometry and extract_fields
                        output = query_func(pts, **kwargs)
                        sdf = output['sdf_pts_scale%d' % kwargs['lod']].reshape(len(xs), len(ys),
                                                                                len(zs)).detach().cpu().numpy()

                        u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = -1 * sdf
        return u

    @torch.no_grad()
    def extract_geometry(self, sdf_network, bound_min, bound_max, resolution, threshold, device, occupancy_mask=None,
                         # * 3d feature volume
                         **kwargs
                         ):
        # logging.info('threshold: {}'.format(threshold))

        u = self.extract_fields(bound_min, bound_max, resolution,
                                lambda pts, **kwargs: sdf_network.sdf(pts, **kwargs),
                                # - the sdf needs to be multiplied by -1
                                device,
                                # * 3d feature volume
                                **kwargs
                                )
        if occupancy_mask is not None:
            dX, dY, dZ = occupancy_mask.shape
            empty_mask = 1 - occupancy_mask
            empty_mask = empty_mask.view(1, 1, dX, dY, dZ)
            # - dilation
            # empty_mask = F.avg_pool3d(empty_mask, kernel_size=7, stride=1, padding=3)
            empty_mask = F.interpolate(empty_mask, [resolution, resolution, resolution], mode='nearest')
            empty_mask = empty_mask.view(resolution, resolution, resolution).cpu().numpy() > 0
            u[empty_mask] = -100
            del empty_mask

        vertices, triangles = mcubes.marching_cubes(u, threshold)
        b_max_np = bound_max.detach().cpu().numpy()
        b_min_np = bound_min.detach().cpu().numpy()

        vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
        return vertices, triangles, u

    @torch.no_grad()
    def extract_depth_maps(self, sdf_network, con_volume, intrinsics, c2ws, H, W, near, far):
|
941 |
+
"""
|
942 |
+
extract depth maps from the density volume
|
943 |
+
:param con_volume: [1, 1+C, dX, dY, dZ] can by con_volume or sdf_volume
|
944 |
+
:param c2ws: [B, 4, 4]
|
945 |
+
:param H:
|
946 |
+
:param W:
|
947 |
+
:param near:
|
948 |
+
:param far:
|
949 |
+
:return:
|
950 |
+
"""
|
951 |
+
device = con_volume.device
|
952 |
+
batch_size = intrinsics.shape[0]
|
953 |
+
|
954 |
+
with torch.no_grad():
|
955 |
+
ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H),
|
956 |
+
torch.linspace(0, W - 1, W), indexing="ij") # pytorch's meshgrid has indexing='ij'
|
957 |
+
p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1) # H, W, 3
|
958 |
+
|
959 |
+
intrinsics_inv = torch.inverse(intrinsics)
|
960 |
+
|
961 |
+
p = p.view(-1, 3).float().to(device) # N_rays, 3
|
962 |
+
p = torch.matmul(intrinsics_inv[:, None, :3, :3], p[:, :, None]).squeeze() # Batch, N_rays, 3
|
963 |
+
rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # Batch, N_rays, 3
|
964 |
+
rays_v = torch.matmul(c2ws[:, None, :3, :3], rays_v[:, :, :, None]).squeeze() # Batch, N_rays, 3
|
965 |
+
rays_o = c2ws[:, None, :3, 3].expand(rays_v.shape) # Batch, N_rays, 3
|
966 |
+
rays_d = rays_v
|
967 |
+
|
968 |
+
rays_o = rays_o.contiguous().view(-1, 3)
|
969 |
+
rays_d = rays_d.contiguous().view(-1, 3)
|
970 |
+
|
971 |
+
################## - sphere tracer to extract depth maps ######################
|
972 |
+
depth_masks_sphere, depth_maps_sphere = self.ray_tracer.extract_depth_maps(
|
973 |
+
rays_o, rays_d,
|
974 |
+
near[None, :].repeat(rays_o.shape[0], 1),
|
975 |
+
far[None, :].repeat(rays_o.shape[0], 1),
|
976 |
+
sdf_network, con_volume
|
977 |
+
)
|
978 |
+
|
979 |
+
depth_maps = depth_maps_sphere.view(batch_size, 1, H, W)
|
980 |
+
depth_masks = depth_masks_sphere.view(batch_size, 1, H, W)
|
981 |
+
|
982 |
+
depth_maps = torch.where(depth_masks, depth_maps,
|
983 |
+
torch.zeros_like(depth_masks.float()).to(device)) # fill invalid pixels by 0
|
984 |
+
|
985 |
+
return depth_maps, depth_masks
|
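As a quick orientation before the next file, here is a hedged usage sketch (not part of the diff) of how extract_geometry's output could be written out as a mesh. The `renderer`, `sdf_network`, and conditional-volume `kwargs` are assumed to be set up as elsewhere in this demo; `trimesh` is already a dependency of trainer_generic.py.

import torch
import trimesh

# illustrative bounds for the unit cube used throughout this demo
bound_min = torch.tensor([-1.0, -1.0, -1.0])
bound_max = torch.tensor([1.0, 1.0, 1.0])

# extract_fields negates the SDF, so the marching-cubes level for the
# surface is 0.0; vertices come back already mapped into world bounds
vertices, triangles, field = renderer.extract_geometry(
    sdf_network, bound_min, bound_max, resolution=256, threshold=0.0,
    device=bound_min.device, **kwargs)
trimesh.Trimesh(vertices, triangles).export('mesh.ply')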
SparseNeuS_demo_v1/models/sparse_sdf_network.py
ADDED
@@ -0,0 +1,907 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsparse.tensor import PointTensor, SparseTensor
import torchsparse.nn as spnn

from tsparse.modules import SparseCostRegNet
from tsparse.torchsparse_utils import sparse_to_dense_channel
from ops.grid_sampler import grid_sample_3d, tricubic_sample_3d

# from .gru_fusion import GRUFusion
from ops.back_project import back_project_sparse_type
from ops.generate_grids import generate_grid

from inplace_abn import InPlaceABN

from models.embedder import Embedding
from models.featurenet import ConvBnReLU

import pdb
import random

torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)


@torch.jit.script
def fused_mean_variance(x, weight):
    mean = torch.sum(x * weight, dim=1, keepdim=True)
    var = torch.sum(weight * (x - mean) ** 2, dim=1, keepdim=True)
    return mean, var

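A minimal sketch of what fused_mean_variance computes, with illustrative shapes: given per-point, per-view features and weights that sum to one over the view axis, it returns the weighted mean and weighted variance per channel. Low variance means the views agree, which is the photo-consistency cue the cost aggregation below builds on.

import torch

x = torch.randn(4, 3, 8)            # 4 points, 3 views, 8 feature channels
w = torch.full((4, 3, 1), 1.0 / 3)  # uniform weights over the view axis

mean, var = fused_mean_variance(x, w)   # mean/var: [4, 1, 8]
# identical features across views -> var == 0; disagreement -> large var
assert torch.allclose(mean, x.mean(dim=1, keepdim=True))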
class LatentSDFLayer(nn.Module):
    def __init__(self,
                 d_in=3,
                 d_out=129,
                 d_hidden=128,
                 n_layers=4,
                 skip_in=(4,),
                 multires=0,
                 bias=0.5,
                 geometric_init=True,
                 weight_norm=True,
                 activation='softplus',
                 d_conditional_feature=16):
        super(LatentSDFLayer, self).__init__()

        self.d_conditional_feature = d_conditional_feature

        # concat the latent code with each layer input, except for the first and the last layer
        dims_in = [d_in] + [d_hidden + d_conditional_feature for _ in range(n_layers - 2)] + [d_hidden]
        dims_out = [d_hidden for _ in range(n_layers - 1)] + [d_out]

        self.embed_fn_fine = None

        if multires > 0:
            embed_fn = Embedding(in_channels=d_in, N_freqs=multires)  # * includes the input
            self.embed_fn_fine = embed_fn
            dims_in[0] = embed_fn.out_channels

        self.num_layers = n_layers
        self.skip_in = skip_in

        for l in range(0, self.num_layers - 1):
            if l in self.skip_in:
                in_dim = dims_in[l] + dims_in[0]
            else:
                in_dim = dims_in[l]

            out_dim = dims_out[l]
            lin = nn.Linear(in_dim, out_dim)

            if geometric_init:  # - from the IDR code
                if l == self.num_layers - 2:
                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(in_dim), std=0.0001)
                    torch.nn.init.constant_(lin.bias, -bias)
                    # the channels for latent codes are set to 0
                    torch.nn.init.constant_(lin.weight[:, -d_conditional_feature:], 0.0)
                    torch.nn.init.constant_(lin.bias[-d_conditional_feature:], 0.0)

                elif multires > 0 and l == 0:  # the first layer
                    torch.nn.init.constant_(lin.bias, 0.0)
                    # * the channels for position embeddings are set to 0
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    # * the channels for the xyz coordinate (3 channels) are initialized from a normal distribution
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    # * the channels for position embeddings (and conditional_feature) are initialized to 0
                    torch.nn.init.constant_(lin.weight[:, -(dims_in[0] - 3 + d_conditional_feature):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    # the channels for the latent code are initialized to 0
                    torch.nn.init.constant_(lin.weight[:, -d_conditional_feature:], 0.0)

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        if activation == 'softplus':
            self.activation = nn.Softplus(beta=100)
        else:
            assert activation == 'relu'
            self.activation = nn.ReLU()

    def forward(self, inputs, latent):
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        # - only the lod1 network can reuse the pretrained params of the lod0 network
        if latent.shape[1] != self.d_conditional_feature:
            latent = torch.cat([latent, latent], dim=1)

        x = inputs
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            # * due to the conditional bias, different from the original NeuS version
            if l in self.skip_in:
                x = torch.cat([x, inputs], 1) / np.sqrt(2)

            if 0 < l < self.num_layers - 1:
                x = torch.cat([x, latent], 1)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)

        return x

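Before the full network, a hedged instantiation sketch of LatentSDFLayer on its own (not part of the diff; dimensions are the defaults above and the input points and latent codes are random placeholders):

import torch

layer = LatentSDFLayer(d_in=3, d_out=129, d_hidden=128, n_layers=4,
                       multires=6, d_conditional_feature=16)
pts = torch.rand(1024, 3) * 2 - 1    # query points, expected in (-1, 1)
latent = torch.randn(1024, 16)       # per-point conditional feature
out = layer(pts, latent)             # [1024, 129]
sdf, feat = out[:, :1], out[:, 1:]   # channel 0 is the SDF value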
class SparseSdfNetwork(nn.Module):
    '''
    Coarse-to-fine sparse cost regularization network;
    returns a sparse volume feature for extracting the sdf
    '''

    def __init__(self, lod, ch_in, voxel_size, vol_dims,
                 hidden_dim=128, activation='softplus',
                 cost_type='variance_mean',
                 d_pyramid_feature_compress=16,
                 regnet_d_out=8, num_sdf_layers=4,
                 multires=6,
                 ):
        super(SparseSdfNetwork, self).__init__()

        self.lod = lod  # - gradual training, the current regularization lod
        self.ch_in = ch_in
        self.voxel_size = voxel_size  # - the voxel size of the current volume
        self.vol_dims = torch.tensor(vol_dims)  # - the dims of the current volume

        self.selected_views_num = 2  # the number of selected views for feature aggregation
        self.hidden_dim = hidden_dim
        self.activation = activation
        self.cost_type = cost_type
        self.d_pyramid_feature_compress = d_pyramid_feature_compress
        self.gru_fusion = None

        self.regnet_d_out = regnet_d_out
        self.multires = multires

        self.pos_embedder = Embedding(3, self.multires)

        self.compress_layer = ConvBnReLU(
            self.ch_in, self.d_pyramid_feature_compress, 3, 1, 1,
            norm_act=InPlaceABN)
        sparse_ch_in = self.d_pyramid_feature_compress * 2

        sparse_ch_in = sparse_ch_in + 16 if self.lod > 0 else sparse_ch_in
        self.sparse_costreg_net = SparseCostRegNet(
            d_in=sparse_ch_in, d_out=self.regnet_d_out)
        # self.regnet_d_out = self.sparse_costreg_net.d_out

        if activation == 'softplus':
            self.activation = nn.Softplus(beta=100)
        else:
            assert activation == 'relu'
            self.activation = nn.ReLU()

        self.sdf_layer = LatentSDFLayer(d_in=3,
                                        d_out=self.hidden_dim + 1,
                                        d_hidden=self.hidden_dim,
                                        n_layers=num_sdf_layers,
                                        multires=multires,
                                        geometric_init=True,
                                        weight_norm=True,
                                        activation=activation,
                                        d_conditional_feature=16  # self.regnet_d_out
                                        )

    def upsample(self, pre_feat, pre_coords, interval, num=8):
        '''

        :param pre_feat: (Tensor), features from the last level, (N, C)
        :param pre_coords: (Tensor), coordinates from the last level, (N, 4) (4 : batch ind, x, y, z)
        :param interval: interval of voxels, interval = scale ** 2
        :param num: 1 -> 8
        :return: up_feat : (Tensor), upsampled features, (N*8, C)
        :return: up_coords: (N*8, 4), upsampled coordinates, (4 : batch ind, x, y, z)
        '''
        with torch.no_grad():
            pos_list = [1, 2, 3, [1, 2], [1, 3], [2, 3], [1, 2, 3]]
            n, c = pre_feat.shape
            up_feat = pre_feat.unsqueeze(1).expand(-1, num, -1).contiguous()
            up_coords = pre_coords.unsqueeze(1).repeat(1, num, 1).contiguous()
            for i in range(num - 1):
                up_coords[:, i + 1, pos_list[i]] += interval

            up_feat = up_feat.view(-1, c)
            up_coords = up_coords.view(-1, 4)

        return up_feat, up_coords

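The pos_list indexing in upsample is terse; the following equivalent sketch (illustrative, not from the diff) spells out what it does. Each parent voxel (batch, x, y, z) is expanded into the 8 children of an octree cell by adding `interval` to every subset of the x/y/z slots:

import torch

parent = torch.tensor([[0, 4, 4, 4]])            # (batch, x, y, z)
offsets = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
                        [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1]])
children = parent.repeat(8, 1)
children[:, 1:] += offsets * 1                   # interval == 1
# children now holds the 8 corner coordinates of the subdivided voxel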
    def aggregate_multiview_features(self, multiview_features, multiview_masks):
        """
        aggregate multi-view features by computing their cost variance
        :param multiview_features: (num of voxels, num_of_views, c)
        :param multiview_masks: (num of voxels, num_of_views)
        :return:
        """
        num_pts, n_views, C = multiview_features.shape

        counts = torch.sum(multiview_masks, dim=1, keepdim=False)  # [num_pts]

        assert torch.all(counts > 0)  # each point is visible in at least 1 view

        volume_sum = torch.sum(multiview_features, dim=1, keepdim=False)  # [num_pts, C]
        volume_sq_sum = torch.sum(multiview_features ** 2, dim=1, keepdim=False)

        if volume_sum.isnan().sum() > 0:
            import ipdb; ipdb.set_trace()

        del multiview_features

        counts = 1. / (counts + 1e-5)
        costvar = volume_sq_sum * counts[:, None] - (volume_sum * counts[:, None]) ** 2

        costvar_mean = torch.cat([costvar, volume_sum * counts[:, None]], dim=1)
        del volume_sum, volume_sq_sum, counts

        return costvar_mean

    def sparse_to_dense_volume(self, coords, feature, vol_dims, interval, device=None):
        """
        convert the sparse volume into a dense volume to enable trilinear sampling,
        and to save GPU memory
        :param coords: [num_pts, 3]
        :param feature: [num_pts, C]
        :param vol_dims: [3] dX, dY, dZ
        :param interval:
        :return:
        """

        # * assume batch size is 1
        if device is None:
            device = feature.device

        coords_int = (coords / interval).to(torch.int64)
        vol_dims = (vol_dims / interval).to(torch.int64)

        # - if stored on the CPU, this is too slow
        dense_volume = sparse_to_dense_channel(
            coords_int.to(device), feature.to(device), vol_dims.to(device),
            feature.shape[1], 0, device)  # [X, Y, Z, C]

        valid_mask_volume = sparse_to_dense_channel(
            coords_int.to(device),
            torch.ones([feature.shape[0], 1]).to(feature.device),
            vol_dims.to(device),
            1, 0, device)  # [X, Y, Z, 1]

        dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0)  # [1, C, X, Y, Z]
        valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0)  # [1, 1, X, Y, Z]

        return dense_volume, valid_mask_volume

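sparse_to_dense_channel comes from tsparse.torchsparse_utils; conceptually it scatters per-voxel features into a dense grid and leaves a default value elsewhere. A plain-torch sketch of the idea, with illustrative sizes:

import torch

coords = torch.tensor([[0, 0, 0], [2, 1, 3]])   # [num_pts, 3] integer indices
feats = torch.randn(2, 8)                       # [num_pts, C]

dense = torch.zeros(4, 4, 4, 8)                 # [X, Y, Z, C], default 0
dense[coords[:, 0], coords[:, 1], coords[:, 2]] = feats
mask = torch.zeros(4, 4, 4, 1)
mask[coords[:, 0], coords[:, 1], coords[:, 2]] = 1.0   # valid-voxel mask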
    def get_conditional_volume(self, feature_maps, partial_vol_origin, proj_mats, sizeH=None, sizeW=None, lod=0,
                               pre_coords=None, pre_feats=None,
                               ):
        """

        :param feature_maps: fused pyramid features (B, V, C0+C1+C2, H, W)
        :param partial_vol_origin: [B, 3] the world coordinates of the volume origin (0,0,0)
        :param proj_mats: projection matrices transforming world pts into image space, [B, V, 4, 4], suitable for the original image size
        :param sizeH: the H of the original image size
        :param sizeW: the W of the original image size
        :param pre_coords: the coordinates of the sparse volume from the prior lod
        :param pre_feats: the features of the sparse volume from the prior lod
        :return:
        """
        device = proj_mats.device
        bs = feature_maps.shape[0]
        N_views = feature_maps.shape[1]
        minimum_visible_views = np.min([1, N_views - 1])
        outputs = {}
        pts_samples = []

        # ----coarse to fine----

        # * using fused pyramid feature maps is very important
        if self.compress_layer is not None:
            feats = self.compress_layer(feature_maps[0])
        else:
            feats = feature_maps[0]
        feats = feats[:, None, :, :, :]  # [V, B, C, H, W]
        KRcam = proj_mats.permute(1, 0, 2, 3).contiguous()  # [V, B, 4, 4]
        interval = 1

        if self.lod == 0:
            # ----generate new coords----
            coords = generate_grid(self.vol_dims, 1)[0]
            coords = coords.view(3, -1).to(device)  # [3, num_pts]
            up_coords = []
            for b in range(bs):
                up_coords.append(torch.cat([torch.ones(1, coords.shape[-1]).to(coords.device) * b, coords]))
            up_coords = torch.cat(up_coords, dim=1).permute(1, 0).contiguous()
            # * since we only estimate the geometry of the input reference image at one time,
            # * mask out whatever lies outside the camera frustum
            frustum_mask = back_project_sparse_type(
                up_coords, partial_vol_origin, self.voxel_size,
                feats, KRcam, sizeH=sizeH, sizeW=sizeW, only_mask=True)  # [num_pts, n_views]
            frustum_mask = torch.sum(frustum_mask, dim=-1) > minimum_visible_views  # ! here should be large
            up_coords = up_coords[frustum_mask]  # [num_pts_valid, 4]

        else:
            # ----upsample coords----
            assert pre_feats is not None
            assert pre_coords is not None
            up_feat, up_coords = self.upsample(pre_feats, pre_coords, 1)

        # ----back project----
        # give each valid 3d grid point all valid 2D features and masks
        multiview_features, multiview_masks = back_project_sparse_type(
            up_coords, partial_vol_origin, self.voxel_size, feats,
            KRcam, sizeH=sizeH, sizeW=sizeW)  # (num of voxels, num_of_views, c), (num of voxels, num_of_views)
        # num_of_views = all views

        if self.lod > 0:
            # ! needs another invalid-voxel filtering pass
            frustum_mask = torch.sum(multiview_masks, dim=-1) > 1
            up_feat = up_feat[frustum_mask]
            up_coords = up_coords[frustum_mask]
            multiview_features = multiview_features[frustum_mask]
            multiview_masks = multiview_masks[frustum_mask]

        volume = self.aggregate_multiview_features(multiview_features, multiview_masks)  # compute the variance over all image features

        del multiview_features, multiview_masks

        # ----concat feature from the last stage----
        if self.lod != 0:
            feat = torch.cat([volume, up_feat], dim=1)
        else:
            feat = volume

        # batch index is in the last position
        r_coords = up_coords[:, [1, 2, 3, 0]]

        sparse_feat = SparseTensor(feat, r_coords.to(
            torch.int32))  # - directly use a sparse tensor to avoid point2voxel operations
        feat = self.sparse_costreg_net(sparse_feat)

        dense_volume, valid_mask_volume = self.sparse_to_dense_volume(up_coords[:, 1:], feat, self.vol_dims, interval,
                                                                      device=None)  # [1, C/1, X, Y, Z]

        outputs['dense_volume_scale%d' % self.lod] = dense_volume  # [1, 16, 96, 96, 96]
        outputs['valid_mask_volume_scale%d' % self.lod] = valid_mask_volume  # [1, 1, 96, 96, 96]
        outputs['visible_mask_scale%d' % self.lod] = valid_mask_volume  # [1, 1, 96, 96, 96]
        outputs['coords_scale%d' % self.lod] = generate_grid(self.vol_dims, interval).to(device)
        return outputs

    def sdf(self, pts, conditional_volume, lod):
        num_pts = pts.shape[0]
        device = pts.device
        pts_ = pts.clone()
        pts = pts.view(1, 1, 1, num_pts, 3)  # - should be in range (-1, 1)

        pts = torch.flip(pts, dims=[-1])
        sampled_feature = grid_sample_3d(conditional_volume, pts)  # [1, c, 1, 1, num_pts]
        sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous().to(device)

        sdf_pts = self.sdf_layer(pts_, sampled_feature)

        outputs = {}
        outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1]
        outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:]
        outputs['sampled_latent_scale%d' % lod] = sampled_feature

        return outputs

    @torch.no_grad()
    def sdf_from_sdfvolume(self, pts, sdf_volume, lod=0):
        num_pts = pts.shape[0]
        device = pts.device
        pts_ = pts.clone()
        pts = pts.view(1, 1, 1, num_pts, 3)  # - should be in range (-1, 1)

        pts = torch.flip(pts, dims=[-1])

        sdf = torch.nn.functional.grid_sample(sdf_volume, pts, mode='bilinear', align_corners=True,
                                              padding_mode='border')
        sdf = sdf.view(-1, num_pts).permute(1, 0).contiguous().to(device)

        outputs = {}
        outputs['sdf_pts_scale%d' % lod] = sdf

        return outputs

    @torch.no_grad()
    def get_sdf_volume(self, conditional_volume, mask_volume, coords_volume, partial_origin):
        """

        :param conditional_volume: [1, C, dX, dY, dZ]
        :param mask_volume: [1, 1, dX, dY, dZ]
        :param coords_volume: [1, 3, dX, dY, dZ]
        :return:
        """
        device = conditional_volume.device
        chunk_size = 10240

        _, C, dX, dY, dZ = conditional_volume.shape
        conditional_volume = conditional_volume.view(C, dX * dY * dZ).permute(1, 0).contiguous()
        mask_volume = mask_volume.view(-1)
        coords_volume = coords_volume.view(3, dX * dY * dZ).permute(1, 0).contiguous()

        pts = coords_volume * self.voxel_size + partial_origin  # [dX*dY*dZ, 3]

        sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(device)

        conditional_volume = conditional_volume[mask_volume > 0]
        pts = pts[mask_volume > 0]
        conditional_volume = conditional_volume.split(chunk_size)
        pts = pts.split(chunk_size)

        sdf_all = []
        for pts_part, feature_part in zip(pts, conditional_volume):
            sdf_part = self.sdf_layer(pts_part, feature_part)[:, :1]
            sdf_all.append(sdf_part)

        sdf_all = torch.cat(sdf_all, dim=0)
        sdf_volume[mask_volume > 0] = sdf_all
        sdf_volume = sdf_volume.view(1, 1, dX, dY, dZ)
        return sdf_volume

    def gradient(self, x, conditional_volume, lod):
        """
        return the gradient at the specified lod
        :param x:
        :param lod:
        :return:
        """
        x.requires_grad_(True)
        with torch.enable_grad():
            output = self.sdf(x, conditional_volume, lod)
            y = output['sdf_pts_scale%d' % lod]

        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
        # ! Distributed Data Parallel doesn't work with torch.autograd.grad()
        # ! (i.e. it will only work if gradients are to be accumulated in .grad attributes of parameters).
        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients.unsqueeze(1)


def sparse_to_dense_volume(coords, feature, vol_dims, interval, device=None):
    """
    convert the sparse volume into a dense volume to enable trilinear sampling,
    and to save GPU memory
    :param coords: [num_pts, 3]
    :param feature: [num_pts, C]
    :param vol_dims: [3] dX, dY, dZ
    :param interval:
    :return:
    """

    # * assume batch size is 1
    if device is None:
        device = feature.device

    coords_int = (coords / interval).to(torch.int64)
    vol_dims = (vol_dims / interval).to(torch.int64)

    # - if stored on the CPU, this is too slow
    dense_volume = sparse_to_dense_channel(
        coords_int.to(device), feature.to(device), vol_dims.to(device),
        feature.shape[1], 0, device)  # [X, Y, Z, C]

    valid_mask_volume = sparse_to_dense_channel(
        coords_int.to(device),
        torch.ones([feature.shape[0], 1]).to(feature.device),
        vol_dims.to(device),
        1, 0, device)  # [X, Y, Z, 1]

    dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0)  # [1, C, X, Y, Z]
    valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0)  # [1, 1, X, Y, Z]

    return dense_volume, valid_mask_volume


class SdfVolume(nn.Module):
    def __init__(self, volume, coords=None, type='dense'):
        super(SdfVolume, self).__init__()
        self.volume = torch.nn.Parameter(volume, requires_grad=True)
        self.coords = coords
        self.type = type

    def forward(self):
        return self.volume


class FinetuneOctreeSdfNetwork(nn.Module):
    '''
    After obtaining the conditional volume from the generalized network,
    directly optimize the conditional volume.
    The conditional volume is still sparse.
    '''

    def __init__(self, voxel_size, vol_dims,
                 origin=[-1., -1., -1.],
                 hidden_dim=128, activation='softplus',
                 regnet_d_out=8,
                 multires=6,
                 if_fitted_rendering=True,
                 num_sdf_layers=4,
                 ):
        super(FinetuneOctreeSdfNetwork, self).__init__()

        self.voxel_size = voxel_size  # - the voxel size of the current volume
        self.vol_dims = torch.tensor(vol_dims)  # - the dims of the current volume

        self.origin = torch.tensor(origin).to(torch.float32)

        self.hidden_dim = hidden_dim
        self.activation = activation

        self.regnet_d_out = regnet_d_out

        self.if_fitted_rendering = if_fitted_rendering
        self.multires = multires
        # d_in_embedding = self.regnet_d_out if self.pos_add_type == 'latent' else 3
        # self.pos_embedder = Embedding(d_in_embedding, self.multires)

        # - the optimized parameters
        self.sparse_volume_lod0 = None
        self.sparse_coords_lod0 = None

        if activation == 'softplus':
            self.activation = nn.Softplus(beta=100)
        else:
            assert activation == 'relu'
            self.activation = nn.ReLU()

        self.sdf_layer = LatentSDFLayer(d_in=3,
                                        d_out=self.hidden_dim + 1,
                                        d_hidden=self.hidden_dim,
                                        n_layers=num_sdf_layers,
                                        multires=multires,
                                        geometric_init=True,
                                        weight_norm=True,
                                        activation=activation,
                                        d_conditional_feature=16  # self.regnet_d_out
                                        )

        # - add mlp rendering when finetuning
        self.renderer = None

        d_in_renderer = 3 + self.regnet_d_out + 3 + 3
        self.renderer = BlendingRenderingNetwork(
            d_feature=self.hidden_dim - 1,
            mode='idr',  # ! the view direction influences the result a lot
            d_in=d_in_renderer,
            d_out=50,  # maximum 50 images
            d_hidden=self.hidden_dim,
            n_layers=3,
            weight_norm=True,
            multires_view=4,
            squeeze_out=True,
        )

    def initialize_conditional_volumes(self, dense_volume_lod0, dense_volume_mask_lod0,
                                       sparse_volume_lod0=None, sparse_coords_lod0=None):
        """

        :param dense_volume_lod0: [1, C, dX, dY, dZ]
        :param dense_volume_mask_lod0: [1, 1, dX, dY, dZ]
        :param dense_volume_lod1:
        :param dense_volume_mask_lod1:
        :return:
        """

        if sparse_volume_lod0 is None:
            device = dense_volume_lod0.device
            _, C, dX, dY, dZ = dense_volume_lod0.shape

            dense_volume_lod0 = dense_volume_lod0.view(C, dX * dY * dZ).permute(1, 0).contiguous()
            mask_lod0 = dense_volume_mask_lod0.view(dX * dY * dZ) > 0

            self.sparse_volume_lod0 = SdfVolume(dense_volume_lod0[mask_lod0], type='sparse')

            coords = generate_grid(self.vol_dims, 1)[0]  # [3, dX, dY, dZ]
            coords = coords.view(3, dX * dY * dZ).permute(1, 0).to(device)
            self.sparse_coords_lod0 = torch.nn.Parameter(coords[mask_lod0], requires_grad=False)
        else:
            self.sparse_volume_lod0 = SdfVolume(sparse_volume_lod0, type='sparse')
            self.sparse_coords_lod0 = torch.nn.Parameter(sparse_coords_lod0, requires_grad=False)

    def get_conditional_volume(self):
        dense_volume, valid_mask_volume = sparse_to_dense_volume(
            self.sparse_coords_lod0,
            self.sparse_volume_lod0(), self.vol_dims, interval=1,
            device=None)  # [1, C/1, X, Y, Z]

        # valid_mask_volume = self.dense_volume_mask_lod0

        outputs = {}
        outputs['dense_volume_scale%d' % 0] = dense_volume
        outputs['valid_mask_volume_scale%d' % 0] = valid_mask_volume

        return outputs

    def tv_regularizer(self):
        dense_volume, valid_mask_volume = sparse_to_dense_volume(
            self.sparse_coords_lod0,
            self.sparse_volume_lod0(), self.vol_dims, interval=1,
            device=None)  # [1, C/1, X, Y, Z]

        dx = (dense_volume[:, :, 1:, :, :] - dense_volume[:, :, :-1, :, :]) ** 2  # [1, C/1, X-1, Y, Z]
        dy = (dense_volume[:, :, :, 1:, :] - dense_volume[:, :, :, :-1, :]) ** 2  # [1, C/1, X, Y-1, Z]
        dz = (dense_volume[:, :, :, :, 1:] - dense_volume[:, :, :, :, :-1]) ** 2  # [1, C/1, X, Y, Z-1]

        tv = dx[:, :, :, :-1, :-1] + dy[:, :, :-1, :, :-1] + dz[:, :, :-1, :-1, :]  # [1, C/1, X-1, Y-1, Z-1]

        mask = valid_mask_volume[:, :, :-1, :-1, :-1] * valid_mask_volume[:, :, 1:, :-1, :-1] * \
               valid_mask_volume[:, :, :-1, 1:, :-1] * valid_mask_volume[:, :, :-1, :-1, 1:]

        tv = torch.sqrt(tv + 1e-6).mean(dim=1, keepdim=True) * mask
        # tv = tv.mean(dim=1, keepdim=True) * mask

        assert torch.all(~torch.isnan(tv))

        return torch.mean(tv)

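tv_regularizer above is an anisotropic total-variation penalty on the optimized feature volume: squared forward differences along x, y, z are summed, square-rooted, and averaged over valid neighborhoods, discouraging high-frequency noise in the finetuned volume. A mask-free sketch with illustrative sizes:

import torch

vol = torch.randn(1, 8, 16, 16, 16)                       # [1, C, X, Y, Z]
dx = (vol[:, :, 1:, :, :] - vol[:, :, :-1, :, :]) ** 2
dy = (vol[:, :, :, 1:, :] - vol[:, :, :, :-1, :]) ** 2
dz = (vol[:, :, :, :, 1:] - vol[:, :, :, :, :-1]) ** 2
tv = torch.sqrt(dx[:, :, :, :-1, :-1] + dy[:, :, :-1, :, :-1]
                + dz[:, :, :-1, :-1, :] + 1e-6).mean()    # scalar loss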
    def sdf(self, pts, conditional_volume, lod):

        outputs = {}

        num_pts = pts.shape[0]
        device = pts.device
        pts_ = pts.clone()
        pts = pts.view(1, 1, 1, num_pts, 3)  # - should be in range (-1, 1)

        pts = torch.flip(pts, dims=[-1])

        sampled_feature = grid_sample_3d(conditional_volume, pts)  # [1, c, 1, 1, num_pts]
        sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous()
        outputs['sampled_latent_scale%d' % lod] = sampled_feature

        sdf_pts = self.sdf_layer(pts_, sampled_feature)

        lod = 0
        outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1]
        outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:]

        return outputs

    def color_blend(self, pts, position, normals, view_dirs, feature_vectors, img_index,
                    pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None):

        return self.renderer(torch.cat([pts, position], dim=-1), normals, view_dirs, feature_vectors,
                             img_index, pts_pixel_color, pts_pixel_mask,
                             pts_patch_color=pts_patch_color, pts_patch_mask=pts_patch_mask)

    def gradient(self, x, conditional_volume, lod):
        """
        return the gradient at the specified lod
        :param x:
        :param lod:
        :return:
        """
        x.requires_grad_(True)
        output = self.sdf(x, conditional_volume, lod)
        y = output['sdf_pts_scale%d' % 0]

        d_output = torch.ones_like(y, requires_grad=False, device=y.device)

        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients.unsqueeze(1)

    @torch.no_grad()
    def prune_dense_mask(self, threshold=0.02):
        """
        Gradually prune the mask of the dense volume to decrease the number of sdf network inferences
        :return:
        """
        chunk_size = 10240
        coords = generate_grid(self.vol_dims_lod0, 1)[0]  # [3, dX, dY, dZ]

        _, dX, dY, dZ = coords.shape

        pts = coords.view(3, -1).permute(1, 0).contiguous() * self.voxel_size_lod0 + self.origin[None, :]  # [dX*dY*dZ, 3]

        # dense_volume = self.dense_volume_lod0()  # [1, C, dX, dY, dZ]
        dense_volume, _ = sparse_to_dense_volume(
            self.sparse_coords_lod0,
            self.sparse_volume_lod0(), self.vol_dims_lod0, interval=1,
            device=None)  # [1, C/1, X, Y, Z]

        sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(dense_volume.device) * 100

        mask = self.dense_volume_mask_lod0.view(-1) > 0

        pts_valid = pts[mask].to(dense_volume.device)
        feature_valid = dense_volume.view(self.regnet_d_out, -1).permute(1, 0).contiguous()[mask]

        pts_valid = pts_valid.split(chunk_size)
        feature_valid = feature_valid.split(chunk_size)

        sdf_list = []

        for pts_part, feature_part in zip(pts_valid, feature_valid):
            sdf_part = self.sdf_layer(pts_part, feature_part)[:, :1]
            sdf_list.append(sdf_part)

        sdf_list = torch.cat(sdf_list, dim=0)

        sdf_volume[mask] = sdf_list

        occupancy_mask = torch.abs(sdf_volume) < threshold  # [num_pts, 1]

        # - dilate
        occupancy_mask = occupancy_mask.float()
        occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ)
        occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3)
        occupancy_mask = occupancy_mask > 0

        self.dense_volume_mask_lod0 = torch.logical_and(self.dense_volume_mask_lod0,
                                                        occupancy_mask).float()  # (1, 1, dX, dY, dZ)


class BlendingRenderingNetwork(nn.Module):
    def __init__(
            self,
            d_feature,
            mode,
            d_in,
            d_out,
            d_hidden,
            n_layers,
            weight_norm=True,
            multires_view=0,
            squeeze_out=True,
    ):
        super(BlendingRenderingNetwork, self).__init__()

        self.mode = mode
        self.squeeze_out = squeeze_out
        dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embedder = None
        if multires_view > 0:
            self.embedder = Embedding(3, multires_view)
            dims[0] += (self.embedder.out_channels - 3)

        self.num_layers = len(dims)

        for l in range(0, self.num_layers - 1):
            out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.relu = nn.ReLU()

        self.color_volume = None

        self.softmax = nn.Softmax(dim=1)

        self.type = 'blending'

    def sample_pts_from_colorVolume(self, pts):
        device = pts.device
        num_pts = pts.shape[0]
        pts_ = pts.clone()
        pts = pts.view(1, 1, 1, num_pts, 3)  # - should be in range (-1, 1)

        pts = torch.flip(pts, dims=[-1])

        sampled_color = grid_sample_3d(self.color_volume, pts)  # [1, c, 1, 1, num_pts]
        sampled_color = sampled_color.view(-1, num_pts).permute(1, 0).contiguous().to(device)

        return sampled_color

    def forward(self, position, normals, view_dirs, feature_vectors, img_index,
                pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None):
        """

        :param position: can be a 3d coord or an interpolated volume latent
        :param normals:
        :param view_dirs:
        :param feature_vectors:
        :param img_index: [N_views], used to extract the corresponding weights
        :param pts_pixel_color: [N_pts, N_views, 3]
        :param pts_pixel_mask: [N_pts, N_views]
        :param pts_patch_color: [N_pts, N_views, Npx, 3]
        :return:
        """
        if self.embedder is not None:
            view_dirs = self.embedder(view_dirs)

        rendering_input = None

        if self.mode == 'idr':
            rendering_input = torch.cat([position, view_dirs, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_view_dir':
            rendering_input = torch.cat([position, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_normal':
            rendering_input = torch.cat([position, view_dirs, feature_vectors], dim=-1)
        elif self.mode == 'no_points':
            rendering_input = torch.cat([view_dirs, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_points_no_view_dir':
            rendering_input = torch.cat([normals, feature_vectors], dim=-1)

        x = rendering_input

        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.relu(x)  # [n_pts, d_out]

        # extract values based on img_index
        x_extracted = torch.index_select(x, 1, img_index.long())

        weights_pixel = self.softmax(x_extracted)  # [n_pts, N_views]
        weights_pixel = weights_pixel * pts_pixel_mask
        weights_pixel = weights_pixel / (
                torch.sum(weights_pixel.float(), dim=1, keepdim=True) + 1e-8)  # [n_pts, N_views]
        final_pixel_color = torch.sum(pts_pixel_color * weights_pixel[:, :, None], dim=1,
                                      keepdim=False)  # [N_pts, 3]

        final_pixel_mask = torch.sum(pts_pixel_mask.float(), dim=1, keepdim=True) > 0  # [N_pts, 1]

        final_patch_color, final_patch_mask = None, None
        # pts_patch_color [N_pts, N_views, Npx, 3]; pts_patch_mask [N_pts, N_views, Npx]
        if pts_patch_color is not None:
            N_pts, N_views, Npx, _ = pts_patch_color.shape
            patch_mask = torch.sum(pts_patch_mask, dim=-1, keepdim=False) > Npx - 1  # [N_pts, N_views]

            weights_patch = self.softmax(x_extracted)  # [N_pts, N_views]
            weights_patch = weights_patch * patch_mask
            weights_patch = weights_patch / (
                    torch.sum(weights_patch.float(), dim=1, keepdim=True) + 1e-8)  # [n_pts, N_views]

            final_patch_color = torch.sum(pts_patch_color * weights_patch[:, :, None, None], dim=1,
                                          keepdim=False)  # [N_pts, Npx, 3]
            final_patch_mask = torch.sum(patch_mask, dim=1, keepdim=True) > 0  # [N_pts, 1] at least one image sees it

        return final_pixel_color, final_pixel_mask, final_patch_color, final_patch_mask
SparseNeuS_demo_v1/models/trainer_generic.py
ADDED
@@ -0,0 +1,1207 @@
1 |
+
"""
|
2 |
+
decouple the trainer with the renderer
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import cv2 as cv
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import trimesh
|
12 |
+
from icecream import ic
|
13 |
+
|
14 |
+
from utils.misc_utils import visualize_depth_numpy
|
15 |
+
|
16 |
+
from loss.depth_metric import compute_depth_errors
|
17 |
+
|
18 |
+
from loss.depth_loss import DepthLoss, DepthSmoothLoss
|
19 |
+
|
20 |
+
from models.sparse_neus_renderer import SparseNeuSRenderer
|
21 |
+
|
22 |
+
class GenericTrainer(nn.Module):
|
23 |
+
def __init__(self,
|
24 |
+
rendering_network_outside,
|
25 |
+
pyramid_feature_network_lod0,
|
26 |
+
pyramid_feature_network_lod1,
|
27 |
+
sdf_network_lod0,
|
28 |
+
sdf_network_lod1,
|
29 |
+
variance_network_lod0,
|
30 |
+
variance_network_lod1,
|
31 |
+
rendering_network_lod0,
|
32 |
+
rendering_network_lod1,
|
33 |
+
n_samples_lod0,
|
34 |
+
n_importance_lod0,
|
35 |
+
n_samples_lod1,
|
36 |
+
n_importance_lod1,
|
37 |
+
n_outside,
|
38 |
+
perturb,
|
39 |
+
alpha_type='div',
|
40 |
+
conf=None,
|
41 |
+
timestamp="",
|
42 |
+
mode='train',
|
43 |
+
base_exp_dir=None,
|
44 |
+
):
|
45 |
+
super(GenericTrainer, self).__init__()
|
46 |
+
|
47 |
+
self.conf = conf
|
48 |
+
self.timestamp = timestamp
|
49 |
+
|
50 |
+
|
51 |
+
self.base_exp_dir = base_exp_dir
|
52 |
+
|
53 |
+
|
54 |
+
self.anneal_start = self.conf.get_float('train.anneal_start', default=0.0)
|
55 |
+
self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0)
|
56 |
+
self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0.0)
|
57 |
+
self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0.0)
|
58 |
+
|
59 |
+
# network setups
|
60 |
+
self.rendering_network_outside = rendering_network_outside
|
61 |
+
self.pyramid_feature_network_geometry_lod0 = pyramid_feature_network_lod0 # 2D pyramid feature network for geometry
|
62 |
+
self.pyramid_feature_network_geometry_lod1 = pyramid_feature_network_lod1 # use differnet networks for the two lods
|
63 |
+
|
64 |
+
# when num_lods==2, may consume too much memeory
|
65 |
+
self.sdf_network_lod0 = sdf_network_lod0
|
66 |
+
self.sdf_network_lod1 = sdf_network_lod1
|
67 |
+
|
68 |
+
# - warpped by ModuleList to support DataParallel
|
69 |
+
self.variance_network_lod0 = variance_network_lod0
|
70 |
+
self.variance_network_lod1 = variance_network_lod1
|
71 |
+
|
72 |
+
self.rendering_network_lod0 = rendering_network_lod0
|
73 |
+
self.rendering_network_lod1 = rendering_network_lod1
|
74 |
+
|
75 |
+
self.n_samples_lod0 = n_samples_lod0
|
76 |
+
self.n_importance_lod0 = n_importance_lod0
|
77 |
+
self.n_samples_lod1 = n_samples_lod1
|
78 |
+
self.n_importance_lod1 = n_importance_lod1
|
79 |
+
self.n_outside = n_outside
|
80 |
+
self.num_lods = conf.get_int('model.num_lods') # the number of octree lods
|
81 |
+
self.perturb = perturb
|
82 |
+
self.alpha_type = alpha_type
|
83 |
+
|
84 |
+
# - the two renderers
|
85 |
+
self.sdf_renderer_lod0 = SparseNeuSRenderer(
|
86 |
+
self.rendering_network_outside,
|
87 |
+
self.sdf_network_lod0,
|
88 |
+
self.variance_network_lod0,
|
89 |
+
self.rendering_network_lod0,
|
90 |
+
self.n_samples_lod0,
|
91 |
+
self.n_importance_lod0,
|
92 |
+
self.n_outside,
|
93 |
+
self.perturb,
|
94 |
+
alpha_type='div',
|
95 |
+
conf=self.conf)
|
96 |
+
|
97 |
+
self.sdf_renderer_lod1 = SparseNeuSRenderer(
|
98 |
+
self.rendering_network_outside,
|
99 |
+
self.sdf_network_lod1,
|
100 |
+
self.variance_network_lod1,
|
101 |
+
self.rendering_network_lod1,
|
102 |
+
self.n_samples_lod1,
|
103 |
+
self.n_importance_lod1,
|
104 |
+
self.n_outside,
|
105 |
+
self.perturb,
|
106 |
+
alpha_type='div',
|
107 |
+
conf=self.conf)
|
108 |
+
|
109 |
+
self.if_fix_lod0_networks = self.conf.get_bool('train.if_fix_lod0_networks')
|
110 |
+
|
111 |
+
# sdf network weights
|
112 |
+
self.sdf_igr_weight = self.conf.get_float('train.sdf_igr_weight')
|
113 |
+
self.sdf_sparse_weight = self.conf.get_float('train.sdf_sparse_weight', default=0)
|
114 |
+
self.sdf_decay_param = self.conf.get_float('train.sdf_decay_param', default=100)
|
115 |
+
self.fg_bg_weight = self.conf.get_float('train.fg_bg_weight', default=0.00)
|
116 |
+
self.bg_ratio = self.conf.get_float('train.bg_ratio', default=0.0)
|
117 |
+
|
118 |
+
self.depth_criterion = DepthLoss()
|
119 |
+
|
120 |
+
# - DataParallel mode, cannot modify attributes in forward()
|
121 |
+
# self.iter_step = 0
|
122 |
+
self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')
|
123 |
+
|
124 |
+
# - True for finetuning; False for general training
|
125 |
+
self.if_fitted_rendering = self.conf.get_bool('train.if_fitted_rendering', default=False)
|
126 |
+
|
127 |
+
self.prune_depth_filter = self.conf.get_bool('model.prune_depth_filter', default=False)
|
128 |
+
|
    def get_trainable_params(self):
        # set trainable params

        self.params_to_train = []

        if not self.if_fix_lod0_networks:
            # load pretrained featurenet
            self.params_to_train += list(self.pyramid_feature_network_geometry_lod0.parameters())
            self.params_to_train += list(self.sdf_network_lod0.parameters())
            self.params_to_train += list(self.variance_network_lod0.parameters())

            if self.rendering_network_lod0 is not None:
                self.params_to_train += list(self.rendering_network_lod0.parameters())

        if self.sdf_network_lod1 is not None:
            # load pretrained featurenet
            self.params_to_train += list(self.pyramid_feature_network_geometry_lod1.parameters())

            self.params_to_train += list(self.sdf_network_lod1.parameters())
            self.params_to_train += list(self.variance_network_lod1.parameters())
            if self.rendering_network_lod1 is not None:
                self.params_to_train += list(self.rendering_network_lod1.parameters())

        return self.params_to_train
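
    # A minimal usage sketch (hypothetical glue code, not part of this file): the
    # returned parameter list plugs directly into a torch optimizer, so freezing
    # the lod0 networks is just a matter of the if_fix_lod0_networks flag above.
    #
    #   import torch
    #   optimizer = torch.optim.Adam(trainer.get_trainable_params(), lr=1e-4)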
    def train_step(self, sample,
                   perturb_overwrite=-1,
                   background_rgb=None,
                   alpha_inter_ratio_lod0=0.0,
                   alpha_inter_ratio_lod1=0.0,
                   iter_step=0,
                   ):
        # * only supports batch_size==1
        # ! attention: a list of strings cannot be split by DataParallel
        batch_idx = sample['batch_idx'][0]
        meta = sample['meta'][batch_idx]  # the scan / lighting / ref_view info

        sizeW = sample['img_wh'][0][0]
        sizeH = sample['img_wh'][0][1]
        partial_vol_origin = sample['partial_vol_origin']  # [B, 3]
        near, far = sample['near_fars'][0, 0, :1], sample['near_fars'][0, 0, 1:]

        # the full-size ray variables
        sample_rays = sample['rays']
        rays_o = sample_rays['rays_o'][0]
        rays_d = sample_rays['rays_v'][0]

        imgs = sample['images'][0]
        intrinsics = sample['intrinsics'][0]
        intrinsics_l_4x = intrinsics.clone()
        intrinsics_l_4x[:, :2] *= 0.25
        w2cs = sample['w2cs'][0]
        c2ws = sample['c2ws'][0]
        proj_matrices = sample['affine_mats']
        scale_mat = sample['scale_mat']
        trans_mat = sample['trans_mat']

        # *********************** Lod==0 ***********************
        if not self.if_fix_lod0_networks:
            geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs)

            conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
                feature_maps=geometry_feature_maps[None, 1:, :, :, :],
                partial_vol_origin=partial_vol_origin,
                proj_mats=proj_matrices[:, 1:],
                # proj_mats=proj_matrices,
                sizeH=sizeH,
                sizeW=sizeW,
                lod=0,
            )
        else:
            with torch.no_grad():
                geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
                conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
                    feature_maps=geometry_feature_maps[None, 1:, :, :, :],
                    partial_vol_origin=partial_vol_origin,
                    proj_mats=proj_matrices[:, 1:],
                    # proj_mats=proj_matrices,
                    sizeH=sizeH,
                    sizeW=sizeW,
                    lod=0,
                )
        # print("Checker2: construct cost volume")
        con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
        con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
        coords_lod0 = conditional_features_lod0['coords_scale0']  # [1,3,wX,wY,wZ]

        # * extract depth maps for all the images
        depth_maps_lod0, depth_masks_lod0 = None, None
        if self.num_lods > 1:
            sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
                con_volume_lod0, con_valid_mask_volume_lod0,
                coords_lod0, partial_vol_origin)  # [1, 1, dX, dY, dZ]

            if self.prune_depth_filter:
                depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps(
                    self.sdf_network_lod0, sdf_volume_lod0, intrinsics_l_4x, c2ws,
                    sizeH // 4, sizeW // 4, near * 1.5, far)
                depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear',
                                                align_corners=True)
                depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest')

        # *************** losses
        loss_lod0, losses_lod0, depth_statis_lod0 = None, None, None

        if not self.if_fix_lod0_networks:

            render_out = self.sdf_renderer_lod0.render(
                rays_o, rays_d, near, far,
                self.sdf_network_lod0,
                self.rendering_network_lod0,
                background_rgb=background_rgb,
                alpha_inter_ratio=alpha_inter_ratio_lod0,
                # * related to conditional feature
                lod=0,
                conditional_volume=con_volume_lod0,
                conditional_valid_mask_volume=con_valid_mask_volume_lod0,
                # * 2d feature maps
                feature_maps=geometry_feature_maps,
                color_maps=imgs,
                w2cs=w2cs,
                intrinsics=intrinsics,
                img_wh=[sizeW, sizeH],
                if_general_rendering=True,
                if_render_with_grad=True,
            )

            loss_lod0, losses_lod0, depth_statis_lod0 = self.cal_losses_sdf(render_out, sample_rays,
                                                                            iter_step, lod=0)

        # *********************** Lod==1 ***********************

        loss_lod1, losses_lod1, depth_statis_lod1 = None, None, None

        if self.num_lods > 1:
            geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)
            if self.prune_depth_filter:
                pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
                    sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
                    depth_maps_lod0, proj_matrices[0],
                    partial_vol_origin, self.sdf_network_lod0.voxel_size,
                    near, far, self.sdf_network_lod0.voxel_size, 12)
            else:
                pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
                    sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])

            pre_coords[:, 1:] = pre_coords[:, 1:] * 2

            # ? it seems that when training gru_fusion, this part should be trainable too
            conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
                feature_maps=geometry_feature_maps_lod1[None, 1:, :, :, :],
                partial_vol_origin=partial_vol_origin,
                proj_mats=proj_matrices[:, 1:],
                # proj_mats=proj_matrices,
                sizeH=sizeH,
                sizeW=sizeW,
                pre_coords=pre_coords,
                pre_feats=pre_feats,
            )

            con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
            con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']

            # if not self.if_gru_fusion_lod1:
            render_out_lod1 = self.sdf_renderer_lod1.render(
                rays_o, rays_d, near, far,
                self.sdf_network_lod1,
                self.rendering_network_lod1,
                background_rgb=background_rgb,
                alpha_inter_ratio=alpha_inter_ratio_lod1,
                # * related to conditional feature
                lod=1,
                conditional_volume=con_volume_lod1,
                conditional_valid_mask_volume=con_valid_mask_volume_lod1,
                # * 2d feature maps
                feature_maps=geometry_feature_maps_lod1,
                color_maps=imgs,
                w2cs=w2cs,
                intrinsics=intrinsics,
                img_wh=[sizeW, sizeH],
                bg_ratio=self.bg_ratio,
            )
            loss_lod1, losses_lod1, depth_statis_lod1 = self.cal_losses_sdf(render_out_lod1, sample_rays,
                                                                            iter_step, lod=1)

        # print("Checker3: compute losses")
        # - extract mesh
        if iter_step % self.val_mesh_freq == 0:
            torch.cuda.empty_cache()
            self.validate_mesh(self.sdf_network_lod0,
                               self.sdf_renderer_lod0.extract_geometry,
                               conditional_volume=con_volume_lod0, lod=0,
                               threshold=0,
                               # occupancy_mask=con_valid_mask_volume_lod0[0, 0],
                               mode='train_bg', meta=meta,
                               iter_step=iter_step, scale_mat=scale_mat,
                               trans_mat=trans_mat)
            torch.cuda.empty_cache()

            if self.num_lods > 1:
                self.validate_mesh(self.sdf_network_lod1,
                                   self.sdf_renderer_lod1.extract_geometry,
                                   conditional_volume=con_volume_lod1, lod=1,
                                   # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(),
                                   mode='train_bg', meta=meta,
                                   iter_step=iter_step, scale_mat=scale_mat,
                                   trans_mat=trans_mat)

        losses = {
            # - lod 0
            'loss_lod0': loss_lod0,
            'losses_lod0': losses_lod0,
            'depth_statis_lod0': depth_statis_lod0,

            # - lod 1
            'loss_lod1': loss_lod1,
            'losses_lod1': losses_lod1,
            'depth_statis_lod1': depth_statis_lod1,
        }

        return losses
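
    # Minimal training-loop sketch (hypothetical glue code, not part of this
    # file): losses come back as a dict so a DataParallel wrapper can gather
    # them per GPU before the backward pass.
    #
    #   losses = trainer(sample, alpha_inter_ratio_lod0=ratio, iter_step=it, mode='train')
    #   loss = losses['loss_lod0'].mean()
    #   if losses['loss_lod1'] is not None:
    #       loss = loss + losses['loss_lod1'].mean()
    #   optimizer.zero_grad(); loss.backward(); optimizer.step()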
    def val_step(self, sample,
                 perturb_overwrite=-1,
                 background_rgb=None,
                 alpha_inter_ratio_lod0=0.0,
                 alpha_inter_ratio_lod1=0.0,
                 iter_step=0,
                 chunk_size=512,
                 save_vis=False,
                 ):
        # * only supports batch_size==1
        # ! attention: a list of strings cannot be split by DataParallel
        batch_idx = sample['batch_idx'][0]
        meta = sample['meta'][batch_idx]  # the scan / lighting / ref_view info

        sizeW = sample['img_wh'][0][0]
        sizeH = sample['img_wh'][0][1]
        H, W = sizeH, sizeW

        partial_vol_origin = sample['partial_vol_origin']  # [B, 3]
        near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:]

        # the ray variables
        sample_rays = sample['rays']
        rays_o = sample_rays['rays_o'][0]
        rays_d = sample_rays['rays_v'][0]
        rays_ndc_uv = sample_rays['rays_ndc_uv'][0]

        imgs = sample['images'][0]
        intrinsics = sample['intrinsics'][0]
        intrinsics_l_4x = intrinsics.clone()
        intrinsics_l_4x[:, :2] *= 0.25
        w2cs = sample['w2cs'][0]
        c2ws = sample['c2ws'][0]
        proj_matrices = sample['affine_mats']

        # render_img_idx = sample['render_img_idx'][0]
        # true_img = sample['images'][0][render_img_idx]

        # - the image to render
        scale_mat = sample['scale_mat']  # [1,4,4] used to convert the mesh into true scale
        trans_mat = sample['trans_mat']
        query_c2w = sample['query_c2w']  # [1,4,4]
        query_w2c = sample['query_w2c']  # [1,4,4]
        true_img = sample['query_image'][0]
        true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255)

        depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy()

        scale_factor = sample['scale_factor'][0].cpu().numpy()
        true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None
        if true_depth is not None:
            true_depth = true_depth[0].cpu().numpy()
            true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0]
        else:
            true_depth_colored = None

        rays_o = rays_o.reshape(-1, 3).split(chunk_size)
        rays_d = rays_d.reshape(-1, 3).split(chunk_size)

        # - obtain conditional features
        with torch.no_grad():
            geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
            # - lod 0
            conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
                feature_maps=geometry_feature_maps[None, :, :, :, :],
                partial_vol_origin=partial_vol_origin,
                proj_mats=proj_matrices,
                sizeH=sizeH,
                sizeW=sizeW,
                lod=0,
            )

        con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
        con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
        coords_lod0 = conditional_features_lod0['coords_scale0']  # [1,3,wX,wY,wZ]

        if self.num_lods > 1:
            sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
                con_volume_lod0, con_valid_mask_volume_lod0,
                coords_lod0, partial_vol_origin)  # [1, 1, dX, dY, dZ]

        depth_maps_lod0, depth_masks_lod0 = None, None
        if self.prune_depth_filter:
            depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps(
                self.sdf_network_lod0, sdf_volume_lod0,
                intrinsics_l_4x, c2ws,
                sizeH // 4, sizeW // 4, near * 1.5, far)  # - near*1.5 is an empirical value
            depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear',
                                            align_corners=True)
            depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest')

            # - visualize the depth_maps_lod0 for checking
            colored_depth_maps_lod0 = []
            for i in range(depth_maps_lod0.shape[0]):
                colored_depth_maps_lod0.append(
                    visualize_depth_numpy(depth_maps_lod0[i, 0].cpu().numpy(), [depth_min, depth_max])[0])

            colored_depth_maps_lod0 = np.concatenate(colored_depth_maps_lod0, axis=0).astype(np.uint8)
            os.makedirs(os.path.join(self.base_exp_dir, 'depth_maps_lod0'), exist_ok=True)
            cv.imwrite(os.path.join(self.base_exp_dir, 'depth_maps_lod0',
                                    '{:0>8d}_{}.png'.format(iter_step, meta)),
                       colored_depth_maps_lod0[:, :, ::-1])

        if self.num_lods > 1:
            geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)

            if self.prune_depth_filter:
                pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
                    sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
                    depth_maps_lod0, proj_matrices[0],
                    partial_vol_origin, self.sdf_network_lod0.voxel_size,
                    near, far, self.sdf_network_lod0.voxel_size, 12)
            else:
                pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
                    sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])

            pre_coords[:, 1:] = pre_coords[:, 1:] * 2

            with torch.no_grad():
                conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
                    feature_maps=geometry_feature_maps_lod1[None, :, :, :, :],
                    partial_vol_origin=partial_vol_origin,
                    proj_mats=proj_matrices,
                    sizeH=sizeH,
                    sizeW=sizeW,
                    pre_coords=pre_coords,
                    pre_feats=pre_feats,
                )

            con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
            con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']

        out_rgb_fine = []
        out_normal_fine = []
        out_depth_fine = []

        out_rgb_fine_lod1 = []
        out_normal_fine_lod1 = []
        out_depth_fine_lod1 = []

        # out_depth_fine_explicit = []
        if save_vis:
            for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):

                # ****** lod 0 ****
                render_out = self.sdf_renderer_lod0.render(
                    rays_o_batch, rays_d_batch, near, far,
                    self.sdf_network_lod0,
                    self.rendering_network_lod0,
                    background_rgb=background_rgb,
                    alpha_inter_ratio=alpha_inter_ratio_lod0,
                    # * related to conditional feature
                    lod=0,
                    conditional_volume=con_volume_lod0,
                    conditional_valid_mask_volume=con_valid_mask_volume_lod0,
                    # * 2d feature maps
                    feature_maps=geometry_feature_maps,
                    color_maps=imgs,
                    w2cs=w2cs,
                    intrinsics=intrinsics,
                    img_wh=[sizeW, sizeH],
                    query_c2w=query_c2w,
                    if_render_with_grad=False,
                )

                feasible = lambda key: ((key in render_out) and (render_out[key] is not None))

                if feasible('depth'):
                    out_depth_fine.append(render_out['depth'].detach().cpu().numpy())

                # if render_out['color_coarse'] is not None:
                if feasible('color_fine'):
                    out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
                if feasible('gradients') and feasible('weights'):
                    if render_out['inside_sphere'] is not None:
                        out_normal_fine.append(
                            (render_out['gradients']
                             * render_out['weights'][:, :self.n_samples_lod0 + self.n_importance_lod0, None]
                             * render_out['inside_sphere'][..., None]).sum(dim=1).detach().cpu().numpy())
                    else:
                        out_normal_fine.append(
                            (render_out['gradients']
                             * render_out['weights'][:, :self.n_samples_lod0 + self.n_importance_lod0, None]
                             ).sum(dim=1).detach().cpu().numpy())
                del render_out

            # ****************** lod 1 **************************
            if self.num_lods > 1:
                for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
                    render_out_lod1 = self.sdf_renderer_lod1.render(
                        rays_o_batch, rays_d_batch, near, far,
                        self.sdf_network_lod1,
                        self.rendering_network_lod1,
                        background_rgb=background_rgb,
                        alpha_inter_ratio=alpha_inter_ratio_lod1,
                        # * related to conditional feature
                        lod=1,
                        conditional_volume=con_volume_lod1,
                        conditional_valid_mask_volume=con_valid_mask_volume_lod1,
                        # * 2d feature maps
                        feature_maps=geometry_feature_maps_lod1,
                        color_maps=imgs,
                        w2cs=w2cs,
                        intrinsics=intrinsics,
                        img_wh=[sizeW, sizeH],
                        query_c2w=query_c2w,
                        if_render_with_grad=False,
                    )

                    feasible = lambda key: ((key in render_out_lod1) and (render_out_lod1[key] is not None))

                    if feasible('depth'):
                        out_depth_fine_lod1.append(render_out_lod1['depth'].detach().cpu().numpy())

                    # if render_out['color_coarse'] is not None:
                    if feasible('color_fine'):
                        out_rgb_fine_lod1.append(render_out_lod1['color_fine'].detach().cpu().numpy())
                    if feasible('gradients') and feasible('weights'):
                        if render_out_lod1['inside_sphere'] is not None:
                            out_normal_fine_lod1.append(
                                (render_out_lod1['gradients']
                                 * render_out_lod1['weights'][:, :self.n_samples_lod1 + self.n_importance_lod1, None]
                                 * render_out_lod1['inside_sphere'][..., None]).sum(dim=1).detach().cpu().numpy())
                        else:
                            out_normal_fine_lod1.append(
                                (render_out_lod1['gradients']
                                 * render_out_lod1['weights'][:, :self.n_samples_lod1 + self.n_importance_lod1, None]
                                 ).sum(dim=1).detach().cpu().numpy())
                    del render_out_lod1

            # - save visualization of lod 0
            self.save_visualization(true_img, true_depth_colored, out_depth_fine, out_normal_fine,
                                    query_w2c[0], out_rgb_fine, H, W,
                                    depth_min, depth_max, iter_step, meta, "val_lod0",
                                    true_depth=true_depth, scale_factor=scale_factor)

            if self.num_lods > 1:
                self.save_visualization(true_img, true_depth_colored, out_depth_fine_lod1, out_normal_fine_lod1,
                                        query_w2c[0], out_rgb_fine_lod1, H, W,
                                        depth_min, depth_max, iter_step, meta, "val_lod1",
                                        true_depth=true_depth, scale_factor=scale_factor)

        # - extract mesh
        if (iter_step % self.val_mesh_freq == 0):
            torch.cuda.empty_cache()
            self.validate_mesh(self.sdf_network_lod0,
                               self.sdf_renderer_lod0.extract_geometry,
                               conditional_volume=con_volume_lod0, lod=0,
                               threshold=0,
                               # occupancy_mask=con_valid_mask_volume_lod0[0, 0],
                               mode='val_bg', meta=meta,
                               iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat)
            torch.cuda.empty_cache()

            if self.num_lods > 1:
                self.validate_mesh(self.sdf_network_lod1,
                                   self.sdf_renderer_lod1.extract_geometry,
                                   conditional_volume=con_volume_lod1, lod=1,
                                   # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(),
                                   mode='val_bg', meta=meta,
                                   iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat)

        torch.cuda.empty_cache()
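
    # Chunked inference as used above (a standalone sketch): rays are flattened
    # and split so each render call stays within GPU memory, then the per-chunk
    # outputs are stitched back into an H x W image.
    #
    #   rays_o = rays_o.reshape(-1, 3).split(chunk_size)   # tuple of [<=chunk, 3]
    #   rays_d = rays_d.reshape(-1, 3).split(chunk_size)
    #   outs = [render(o, d) for o, d in zip(rays_o, rays_d)]
    #   depth = np.concatenate(outs, axis=0).reshape(H, W)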
    def export_mesh_step(self, sample,
                         perturb_overwrite=-1,
                         background_rgb=None,
                         alpha_inter_ratio_lod0=0.0,
                         alpha_inter_ratio_lod1=0.0,
                         iter_step=0,
                         chunk_size=512,
                         save_vis=False,
                         ):
        # * only supports batch_size==1
        # ! attention: a list of strings cannot be split by DataParallel
        batch_idx = sample['batch_idx'][0]
        meta = sample['meta'][batch_idx]  # the scan / lighting / ref_view info

        sizeW = sample['img_wh'][0][0]
        sizeH = sample['img_wh'][0][1]
        H, W = sizeH, sizeW

        partial_vol_origin = sample['partial_vol_origin']  # [B, 3]
        near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:]

        # the ray variables
        sample_rays = sample['rays']
        rays_o = sample_rays['rays_o'][0]
        rays_d = sample_rays['rays_v'][0]
        rays_ndc_uv = sample_rays['rays_ndc_uv'][0]

        imgs = sample['images'][0]
        intrinsics = sample['intrinsics'][0]
        intrinsics_l_4x = intrinsics.clone()
        intrinsics_l_4x[:, :2] *= 0.25
        w2cs = sample['w2cs'][0]
        c2ws = sample['c2ws'][0]
        # target_candidate_w2cs = sample['target_candidate_w2cs'][0]
        proj_matrices = sample['affine_mats']

        # - the image to render
        scale_mat = sample['scale_mat']  # [1,4,4] used to convert the mesh into true scale
        trans_mat = sample['trans_mat']
        query_c2w = sample['query_c2w']  # [1,4,4]
        query_w2c = sample['query_w2c']  # [1,4,4]
        true_img = sample['query_image'][0]
        true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255)

        depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy()

        scale_factor = sample['scale_factor'][0].cpu().numpy()
        true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None
        if true_depth is not None:
            true_depth = true_depth[0].cpu().numpy()
            true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0]
        else:
            true_depth_colored = None

        rays_o = rays_o.reshape(-1, 3).split(chunk_size)
        rays_d = rays_d.reshape(-1, 3).split(chunk_size)
        # import time
        # jha_begin1 = time.time()
        # - obtain conditional features
        with torch.no_grad():
            geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0)
            # - lod 0
            conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume(
                feature_maps=geometry_feature_maps[None, :, :, :, :],
                partial_vol_origin=partial_vol_origin,
                proj_mats=proj_matrices,
                sizeH=sizeH,
                sizeW=sizeW,
                lod=0,
            )
        # jha_end1 = time.time()
        # print("get_conditional_volume: ", jha_end1 - jha_begin1)
        # jha_begin2 = time.time()
        con_volume_lod0 = conditional_features_lod0['dense_volume_scale0']
        con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0']
        coords_lod0 = conditional_features_lod0['coords_scale0']  # [1,3,wX,wY,wZ]

        if self.num_lods > 1:
            sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume(
                con_volume_lod0, con_valid_mask_volume_lod0,
                coords_lod0, partial_vol_origin)  # [1, 1, dX, dY, dZ]

        depth_maps_lod0, depth_masks_lod0 = None, None

        if self.num_lods > 1:
            geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1)

            if self.prune_depth_filter:
                pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter(
                    sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0],
                    depth_maps_lod0, proj_matrices[0],
                    partial_vol_origin, self.sdf_network_lod0.voxel_size,
                    near, far, self.sdf_network_lod0.voxel_size, 12)
            else:
                pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf(
                    sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0])

            pre_coords[:, 1:] = pre_coords[:, 1:] * 2

            with torch.no_grad():
                conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume(
                    feature_maps=geometry_feature_maps_lod1[None, :, :, :, :],
                    partial_vol_origin=partial_vol_origin,
                    proj_mats=proj_matrices,
                    sizeH=sizeH,
                    sizeW=sizeW,
                    pre_coords=pre_coords,
                    pre_feats=pre_feats,
                )

            con_volume_lod1 = conditional_features_lod1['dense_volume_scale1']
            con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1']

        out_rgb_fine = []
        out_normal_fine = []
        out_depth_fine = []

        out_rgb_fine_lod1 = []
        out_normal_fine_lod1 = []
        out_depth_fine_lod1 = []

        # jha_end2 = time.time()
        # print("interval before starting mesh export: ", jha_end2 - jha_begin2)

        # - extract mesh
        if (iter_step % self.val_mesh_freq == 0):
            torch.cuda.empty_cache()
            # jha_begin3 = time.time()
            self.validate_colored_mesh(
                density_or_sdf_network=self.sdf_network_lod0,
                func_extract_geometry=self.sdf_renderer_lod0.extract_geometry,
                conditional_volume=con_volume_lod0,
                conditional_valid_mask_volume=con_valid_mask_volume_lod0,
                feature_maps=geometry_feature_maps,
                color_maps=imgs,
                w2cs=w2cs,
                target_candidate_w2cs=None,
                intrinsics=intrinsics,
                rendering_network=self.rendering_network_lod0,
                rendering_projector=self.sdf_renderer_lod0.rendering_projector,
                lod=0,
                threshold=0,
                query_c2w=query_c2w,
                mode='val_bg', meta=meta,
                iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat
            )
            torch.cuda.empty_cache()
            # jha_end3 = time.time()
            # print("validate_colored_mesh_test_time: ", jha_end3 - jha_begin3)

            if self.num_lods > 1:
                self.validate_colored_mesh(
                    density_or_sdf_network=self.sdf_network_lod1,
                    func_extract_geometry=self.sdf_renderer_lod1.extract_geometry,
                    conditional_volume=con_volume_lod1,
                    conditional_valid_mask_volume=con_valid_mask_volume_lod1,
                    feature_maps=geometry_feature_maps,
                    color_maps=imgs,
                    w2cs=w2cs,
                    target_candidate_w2cs=None,
                    intrinsics=intrinsics,
                    rendering_network=self.rendering_network_lod1,
                    rendering_projector=self.sdf_renderer_lod1.rendering_projector,
                    lod=1,
                    threshold=0,
                    query_c2w=query_c2w,
                    mode='val_bg', meta=meta,
                    iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat
                )
                torch.cuda.empty_cache()
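
    # The exported colored mesh can be inspected directly (a usage sketch, not
    # part of this file; the filename pattern comes from validate_colored_mesh
    # below):
    #
    #   import trimesh
    #   m = trimesh.load(os.path.join(base_exp_dir, 'meshes_val_bg', 'lod0',
    #                                 'mesh_00000000_gradio_lod0.ply'))
    #   print(m.vertices.shape, m.visual.vertex_colors.shape)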
    def save_visualization(self, true_img, true_colored_depth, out_depth, out_normal, w2cs, out_color, H, W,
                           depth_min, depth_max, iter_step, meta, comment, out_color_mlp=[], true_depth=None,
                           scale_factor=1.0):
        if len(out_color) > 0:
            img_fine = (np.concatenate(out_color, axis=0).reshape([H, W, 3]) * 256).clip(0, 255)

        if len(out_color_mlp) > 0:
            img_mlp = (np.concatenate(out_color_mlp, axis=0).reshape([H, W, 3]) * 256).clip(0, 255)

        if len(out_normal) > 0:
            normal_img = np.concatenate(out_normal, axis=0)
            rot = w2cs[:3, :3].detach().cpu().numpy()
            # - convert normals from world space to camera space
            normal_img = (np.matmul(rot[None, :, :],
                                    normal_img[:, :, None]).reshape([H, W, 3]) * 128 + 128).clip(0, 255)
        if len(out_depth) > 0:
            pred_depth = np.concatenate(out_depth, axis=0).reshape([H, W])
            pred_depth_colored = visualize_depth_numpy(pred_depth, [depth_min, depth_max])[0]

        if len(out_depth) > 0:
            os.makedirs(os.path.join(self.base_exp_dir, 'depths_' + comment), exist_ok=True)
            if true_colored_depth is not None:

                if true_depth is not None:
                    depth_error_map = np.abs(true_depth - pred_depth) * 2.0 / scale_factor
                    # [256, 256, 1] -> [256, 256, 3]
                    depth_error_map = np.tile(depth_error_map[:, :, None], [1, 1, 3])
                    print("meta: ", meta)
                    print("scale_factor: ", scale_factor)
                    print("depth_error_mean: ", depth_error_map.mean())
                    depth_visualized = np.concatenate(
                        [(depth_error_map * 255).astype(np.uint8), true_colored_depth, pred_depth_colored,
                         true_img], axis=1)[:, :, ::-1]
                    # print("depth_visualized.shape: ", depth_visualized.shape)
                    # write the depth error as text on the image; the input is a numpy array of [256, 1024, 3]
                    # cv.putText(depth_visualized.copy(), "depth_error_mean: {:.4f}".format(depth_error_map.mean()), (10, 30), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
                else:
                    depth_visualized = np.concatenate(
                        [true_colored_depth, pred_depth_colored, true_img])[:, :, ::-1]
                cv.imwrite(
                    os.path.join(self.base_exp_dir, 'depths_' + comment,
                                 '{:0>8d}_{}.png'.format(iter_step, meta)), depth_visualized
                )
            else:
                cv.imwrite(
                    os.path.join(self.base_exp_dir, 'depths_' + comment,
                                 '{:0>8d}_{}.png'.format(iter_step, meta)),
                    np.concatenate(
                        [pred_depth_colored, true_img])[:, :, ::-1])
        if len(out_color) > 0:
            os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment), exist_ok=True)
            cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment,
                                    '{:0>8d}_{}.png'.format(iter_step, meta)),
                       np.concatenate(
                           [img_fine, true_img])[:, :, ::-1])  # bgr2rgb
            # compute psnr (image pixels lie in [0, 255])
            mse_loss = np.mean((img_fine - true_img) ** 2)
            psnr = 10 * np.log10(255 ** 2 / mse_loss)

            print("PSNR: ", psnr)

        if len(out_color_mlp) > 0:
            os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment), exist_ok=True)
            cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment,
                                    '{:0>8d}_{}.png'.format(iter_step, meta)),
                       np.concatenate(
                           [img_mlp, true_img])[:, :, ::-1])  # bgr2rgb

        if len(out_normal) > 0:
            os.makedirs(os.path.join(self.base_exp_dir, 'normals_' + comment), exist_ok=True)
            cv.imwrite(os.path.join(self.base_exp_dir, 'normals_' + comment,
                                    '{:0>8d}_{}.png'.format(iter_step, meta)),
                       normal_img[:, :, ::-1])
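
    # Sanity check for the PSNR formula above (values in [0, 255]): if every
    # pixel is off by 16 gray levels, mse = 256 and
    # psnr = 10 * log10(255**2 / 256) ~= 24.05 dB. The equivalent torch form
    # for images scaled to [0, 1] (as used in cal_losses_sdf below) is
    #
    #   psnr = 20.0 * torch.log10(1.0 / rmse)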
    def forward(self, sample,
                perturb_overwrite=-1,
                background_rgb=None,
                alpha_inter_ratio_lod0=0.0,
                alpha_inter_ratio_lod1=0.0,
                iter_step=0,
                mode='train',
                save_vis=False,
                ):

        if mode == 'train':
            return self.train_step(sample,
                                   perturb_overwrite=perturb_overwrite,
                                   background_rgb=background_rgb,
                                   alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
                                   alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
                                   iter_step=iter_step
                                   )
        elif mode == 'val':
            import time
            begin = time.time()
            result = self.val_step(sample,
                                   perturb_overwrite=perturb_overwrite,
                                   background_rgb=background_rgb,
                                   alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
                                   alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
                                   iter_step=iter_step,
                                   save_vis=save_vis,
                                   )
            end = time.time()
            print("val_step time: ", end - begin)
            return result
        elif mode == 'export_mesh':
            import time
            begin = time.time()
            result = self.export_mesh_step(sample,
                                           perturb_overwrite=perturb_overwrite,
                                           background_rgb=background_rgb,
                                           alpha_inter_ratio_lod0=alpha_inter_ratio_lod0,
                                           alpha_inter_ratio_lod1=alpha_inter_ratio_lod1,
                                           iter_step=iter_step,
                                           save_vis=save_vis,
                                           )
            end = time.time()
            print("export mesh time: ", end - begin)
            return result
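
    # Design note with a usage sketch (an assumption consistent with the mode
    # flags above): dispatching on `mode` keeps all three phases behind a single
    # nn.Module entry point, which is what a DataParallel wrapper scatters and
    # gathers on.
    #
    #   _ = model(sample, iter_step=0, mode='export_mesh', save_vis=False)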
    def obtain_pyramid_feature_maps(self, imgs, lod=0):
        """
        Get the feature maps of all conditional images.
        :param imgs:
        :return:
        """

        if lod == 0:
            extractor = self.pyramid_feature_network_geometry_lod0
        elif lod >= 1:
            extractor = self.pyramid_feature_network_geometry_lod1

        pyramid_feature_maps = extractor(imgs)

        # * the pyramid features are very important; using only the coarsest features makes optimization hard
        fused_feature_maps = torch.cat([
            F.interpolate(pyramid_feature_maps[0], scale_factor=4, mode='bilinear', align_corners=True),
            F.interpolate(pyramid_feature_maps[1], scale_factor=2, mode='bilinear', align_corners=True),
            pyramid_feature_maps[2]
        ], dim=1)

        return fused_feature_maps
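
    # Shape sketch of the fusion above: the pyramid level at 1/4 resolution and
    # the level at 1/2 resolution are upsampled to full resolution and then
    # concatenated with the full-resolution level along the channel dim. Per the
    # "[n_view, 56, 256, 256]" comment in validate_colored_mesh below, the fused
    # map has 56 channels (the per-level channel split is not shown in this file).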
    def cal_losses_sdf(self, render_out, sample_rays, iter_step=-1, lod=0):

        # loss weight schedule; the regularization terms should be added in the later training stage
        def get_weight(iter_step, weight):
            if lod == 1:
                # note: for lod1 the ramp *starts* at anneal_end_lod1 (not anneal_start_lod1)
                anneal_start = self.anneal_end if lod == 0 else self.anneal_end_lod1
                anneal_end = self.anneal_end if lod == 0 else self.anneal_end_lod1
                anneal_end = anneal_end * 2
            else:
                anneal_start = self.anneal_start if lod == 0 else self.anneal_start_lod1
                anneal_end = self.anneal_end if lod == 0 else self.anneal_end_lod1
                anneal_end = anneal_end * 2

            if iter_step < 0:
                return weight

            if anneal_end == 0.0:
                return weight
            elif iter_step < anneal_start:
                return 0.0
            else:
                return np.min(
                    [1.0,
                     (iter_step - anneal_start) / (anneal_end - anneal_start)]) * weight

        rays_o = sample_rays['rays_o'][0]
        rays_d = sample_rays['rays_v'][0]
        true_rgb = sample_rays['rays_color'][0]

        if 'rays_depth' in sample_rays.keys():
            true_depth = sample_rays['rays_depth'][0]
        else:
            true_depth = None
        mask = sample_rays['rays_mask'][0]

        color_fine = render_out['color_fine']
        color_fine_mask = render_out['color_fine_mask']
        depth_pred = render_out['depth']

        variance = render_out['variance']
        cdf_fine = render_out['cdf_fine']
        weight_sum = render_out['weights_sum']

        gradient_error_fine = render_out['gradient_error_fine']

        sdf = render_out['sdf']

        # * color generated by the mlp
        color_mlp = render_out['color_mlp']
        color_mlp_mask = render_out['color_mlp_mask']

        if color_fine is not None:
            # color loss
            color_mask = color_fine_mask if color_fine_mask is not None else mask
            color_mask = color_mask[..., 0]
            color_error = (color_fine[color_mask] - true_rgb[color_mask])
            # print("Nan number", torch.isnan(color_error).sum())
            # print("Color error shape", color_error.shape)
            color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error).to(color_error.device),
                                        reduction='mean')
            # print(color_fine_loss)
            psnr = 20.0 * torch.log10(
                1.0 / (((color_fine[color_mask] - true_rgb[color_mask]) ** 2).mean() / (3.0)).sqrt())
        else:
            color_fine_loss = 0.
            psnr = 0.

        if color_mlp is not None:
            # color loss
            color_mlp_mask = color_mlp_mask[..., 0]
            color_error_mlp = (color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask])
            color_mlp_loss = F.l1_loss(color_error_mlp,
                                       torch.zeros_like(color_error_mlp).to(color_error_mlp.device),
                                       reduction='mean')

            psnr_mlp = 20.0 * torch.log10(
                1.0 / (((color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask]) ** 2).mean() / (3.0)).sqrt())
        else:
            color_mlp_loss = 0.
            psnr_mlp = 0.

        # the depth loss is only used for evaluation; it is not included in the total loss
        if true_depth is not None:
            # depth_loss = self.depth_criterion(depth_pred, true_depth, mask)
            depth_loss = self.depth_criterion(depth_pred, true_depth)

            # # depth evaluation
            # depth_statis = compute_depth_errors(depth_pred.detach().cpu().numpy(), true_depth.cpu().numpy())
            # depth_statis = numpy2tensor(depth_statis, device=rays_o.device)
            depth_statis = None
        else:
            depth_loss = 0.
            depth_statis = None

        sparse_loss_1 = torch.exp(
            -1 * torch.abs(render_out['sdf_random']) * self.sdf_decay_param).mean()  # - should be equal
        sparse_loss_2 = torch.exp(-1 * torch.abs(sdf) * self.sdf_decay_param).mean()
        sparse_loss = (sparse_loss_1 + sparse_loss_2) / 2

        sdf_mean = torch.abs(sdf).mean()
        sparseness_1 = (torch.abs(sdf) < 0.01).to(torch.float32).mean()
        sparseness_2 = (torch.abs(sdf) < 0.02).to(torch.float32).mean()

        # eikonal loss
        gradient_error_loss = gradient_error_fine

        # ! for the first 50k iterations, don't use the bg constraint
        fg_bg_weight = 0.0 if iter_step < 50000 else get_weight(iter_step, self.fg_bg_weight)

        # mask loss, optional
        # the images of the DTU dataset contain large black regions (0 rgb values);
        # this data prior can be used to make the fg cleaner
        background_loss = 0.0
        fg_bg_loss = 0.0
        if self.fg_bg_weight > 0 and torch.mean((mask < 0.5).to(torch.float32)) > 0.02:
            weights_sum_fg = render_out['weights_sum_fg']
            fg_bg_error = (weights_sum_fg - mask)[mask < 0.5]
            fg_bg_loss = F.l1_loss(fg_bg_error,
                                   torch.zeros_like(fg_bg_error).to(fg_bg_error.device),
                                   reduction='mean')

        loss = 1.0 * depth_loss + color_fine_loss + color_mlp_loss + \
               sparse_loss * get_weight(iter_step, self.sdf_sparse_weight) + \
               fg_bg_loss * fg_bg_weight + \
               gradient_error_loss * self.sdf_igr_weight  # ! gradient_error_loss needs a mask

        losses = {
            "loss": loss,
            "depth_loss": depth_loss,
            "color_fine_loss": color_fine_loss,
            "color_mlp_loss": color_mlp_loss,
            "gradient_error_loss": gradient_error_loss,
            "background_loss": background_loss,
            "sparse_loss": sparse_loss,
            "sparseness_1": sparseness_1,
            "sparseness_2": sparseness_2,
            "sdf_mean": sdf_mean,
            "psnr": psnr,
            "psnr_mlp": psnr_mlp,
            "weights_sum": render_out['weights_sum'],
            "weights_sum_fg": render_out['weights_sum_fg'],
            "alpha_sum": render_out['alpha_sum'],
            "variance": render_out['variance'],
            "sparse_weight": get_weight(iter_step, self.sdf_sparse_weight),
            "fg_bg_weight": fg_bg_weight,
            "fg_bg_loss": fg_bg_loss,  # added by jha; fixes a bug of sparseNeuS
        }
        # torch.tensor() cannot consume a dict; convert entry-wise instead.
        # numpy2tensor (also referenced for depth_statis above) is assumed to
        # handle dicts of scalars / tensors.
        losses = numpy2tensor(losses, device=rays_o.device)
        return loss, losses, depth_statis
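
    # Worked example of get_weight above (illustrative numbers): with
    # anneal_start = 20000, anneal_end (already doubled) = 40000, weight = 0.02:
    #   iter_step = 10000 -> 0.0           (before the ramp starts)
    #   iter_step = 30000 -> 0.5 * 0.02    (halfway through the ramp)
    #   iter_step = 50000 -> 0.02          (clamped at the full weight)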
    @torch.no_grad()
    def validate_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360,
                      threshold=0.0, mode='val',
                      # * 3d feature volume
                      conditional_volume=None, lod=None, occupancy_mask=None,
                      bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None,
                      trans_mat=None
                      ):

        bound_min = torch.tensor(bound_min, dtype=torch.float32)
        bound_max = torch.tensor(bound_max, dtype=torch.float32)

        vertices, triangles, fields = func_extract_geometry(
            density_or_sdf_network,
            bound_min, bound_max, resolution=resolution,
            threshold=threshold, device=conditional_volume.device,
            # * 3d feature volume
            conditional_volume=conditional_volume, lod=lod,
            occupancy_mask=occupancy_mask
        )

        if scale_mat is not None:
            scale_mat_np = scale_mat.cpu().numpy()
            vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None]

        if trans_mat is not None:  # w2c_ref_inv
            trans_mat_np = trans_mat.cpu().numpy()
            vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1)
            vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0]

        mesh = trimesh.Trimesh(vertices, triangles)
        os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode), exist_ok=True)
        mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode,
                                 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod)))
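
    # The two vertex transforms above, written out (same math, sketch form):
    # scale_mat maps the unit-cube mesh back to true scale,
    #   v <- v * s + t        with s = scale_mat[0][0, 0], t = scale_mat[0][:3, 3]
    # and trans_mat (w2c_ref_inv) applies a homogeneous transform,
    #   v <- (trans_mat @ [v; 1])[:3]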
    def validate_colored_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360,
                              threshold=0.0, mode='val',
                              # * 3d feature volume
                              conditional_volume=None,
                              conditional_valid_mask_volume=None,
                              feature_maps=None,
                              color_maps=None,
                              w2cs=None,
                              target_candidate_w2cs=None,
                              intrinsics=None,
                              rendering_network=None,
                              rendering_projector=None,
                              query_c2w=None,
                              lod=None, occupancy_mask=None,
                              bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None,
                              trans_mat=None
                              ):

        bound_min = torch.tensor(bound_min, dtype=torch.float32)
        bound_max = torch.tensor(bound_max, dtype=torch.float32)
        # import time
        # jha_begin4 = time.time()
        vertices, triangles, fields = func_extract_geometry(
            density_or_sdf_network,
            bound_min, bound_max, resolution=resolution,
            threshold=threshold, device=conditional_volume.device,
            # * 3d feature volume
            conditional_volume=conditional_volume, lod=lod,
            occupancy_mask=occupancy_mask
        )
        # jha_end4 = time.time()
        # print("[TEST]: func_extract_geometry time", jha_end4 - jha_begin4)

        # import time
        # jha_begin5 = time.time()

        with torch.no_grad():
            ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = rendering_projector.compute_view_independent(
                torch.tensor(vertices).to(conditional_volume),
                lod=lod,  # JHA EDITED
                # * 3d geometry feature volumes
                geometryVolume=conditional_volume[0],
                geometryVolumeMask=conditional_valid_mask_volume[0],
                sdf_network=density_or_sdf_network,
                # * 2d rendering feature maps
                rendering_feature_maps=feature_maps,  # [n_view, 56, 256, 256]
                color_maps=color_maps,
                w2cs=w2cs,
                target_candidate_w2cs=target_candidate_w2cs,
                intrinsics=intrinsics,
                img_wh=[256, 256],
                query_img_idx=0,  # the index of the N_views dim for rendering
                query_c2w=query_c2w,
            )

            vertices_color, rendering_valid_mask = rendering_network(
                ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask)

        # jha_end5 = time.time()
        # print("[TEST]: rendering_network time", jha_end5 - jha_begin5)

        if scale_mat is not None:
            scale_mat_np = scale_mat.cpu().numpy()
            vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None]

        if trans_mat is not None:  # w2c_ref_inv
            trans_mat_np = trans_mat.cpu().numpy()
            vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1)
            vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0]

        vertices_color = np.array(vertices_color.squeeze(0).cpu() * 255, dtype=np.uint8)
        mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertices_color)
        os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod)), exist_ok=True)
        # mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod),
        #                          'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod)))
        # MODIFIED
        mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod),
                                 'mesh_{:0>8d}_gradio_lod{:0>1d}.ply'.format(iter_step, lod)))

SparseNeuS_demo_v1/ops/__init__.py
ADDED
File without changes
SparseNeuS_demo_v1/ops/back_project.py
ADDED
@@ -0,0 +1,175 @@
import torch
from torch.nn.functional import grid_sample


def back_project_sparse_type(coords, origin, voxel_size, feats, KRcam, sizeH=None, sizeW=None, only_mask=False,
                             with_proj_z=False):
    # - modified version from NeuRecon
    '''
    Unproject the image features to form a 3D (sparse) feature volume

    :param coords: coordinates of voxels,
        dim: (num of voxels, 4) (4: batch ind, x, y, z)
    :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0))
        dim: (batch size, 3) (3: x, y, z)
    :param voxel_size: float specifying the size of a voxel
    :param feats: image features
        dim: (num of views, batch size, C, H, W)
    :param KRcam: projection matrix
        dim: (num of views, batch size, 4, 4)
    :return: feature_volume_all: 3D feature volumes
        dim: (num of voxels, num_of_views, c)
    :return: mask_volume_all: indicates whether each sampled voxel feature is valid
        dim: (num of voxels, num_of_views)
    '''
    n_views, bs, c, h, w = feats.shape
    device = feats.device

    if sizeH is None:
        sizeH, sizeW = h, w  # - if the KRcam is not suitable for the current feats

    feature_volume_all = torch.zeros(coords.shape[0], n_views, c).to(device)
    mask_volume_all = torch.zeros([coords.shape[0], n_views], dtype=torch.int32).to(device)
    # import ipdb; ipdb.set_trace()
    for batch in range(bs):
        # import ipdb; ipdb.set_trace()
        batch_ind = torch.nonzero(coords[:, 0] == batch).squeeze(1)
        coords_batch = coords[batch_ind][:, 1:]

        coords_batch = coords_batch.view(-1, 3)
        origin_batch = origin[batch].unsqueeze(0)
        feats_batch = feats[:, batch]
        proj_batch = KRcam[:, batch]

        grid_batch = coords_batch * voxel_size + origin_batch.float()
        rs_grid = grid_batch.unsqueeze(0).expand(n_views, -1, -1)
        rs_grid = rs_grid.permute(0, 2, 1).contiguous()
        nV = rs_grid.shape[-1]
        rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).to(device)], dim=1)

        # project the grid
        im_p = proj_batch @ rs_grid  # - transform world pts to image UV space
        im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2]

        im_z[im_z >= 0] = im_z[im_z >= 0].clamp(min=1e-6)

        im_x = im_x / im_z
        im_y = im_y / im_z

        im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1)
        mask = im_grid.abs() <= 1
        mask = (mask.sum(dim=-1) == 2) & (im_z > 0)

        mask = mask.view(n_views, -1)
        mask = mask.permute(1, 0).contiguous()  # [num_pts, nviews]

        mask_volume_all[batch_ind] = mask.to(torch.int32)

        if only_mask:
            return mask_volume_all

        feats_batch = feats_batch.view(n_views, c, h, w)
        im_grid = im_grid.view(n_views, 1, -1, 2)
        features = grid_sample(feats_batch, im_grid, padding_mode='zeros', align_corners=True)
        # if features.isnan().sum() > 0:
        #     import ipdb; ipdb.set_trace()
        features = features.view(n_views, c, -1)
        features = features.permute(2, 0, 1).contiguous()  # [num_pts, nviews, c]

        feature_volume_all[batch_ind] = features

    if with_proj_z:
        im_z = im_z.view(n_views, 1, -1).permute(2, 0, 1).contiguous()  # [num_pts, nviews, 1]
        return feature_volume_all, mask_volume_all, im_z
    # if feature_volume_all.isnan().sum() > 0:
    #     import ipdb; ipdb.set_trace()
    return feature_volume_all, mask_volume_all
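
# A toy shape check for back_project_sparse_type (hypothetical values, not part
# of this file): 100 voxels, 3 views, batch size 1, 32-channel 64x64 features.
#
#   coords = torch.zeros(100, 4); coords[:, 1:] = torch.randint(0, 10, (100, 3))
#   origin = torch.zeros(1, 3)
#   feats = torch.rand(3, 1, 32, 64, 64)
#   KRcam = torch.eye(4).expand(3, 1, 4, 4)
#   vol, mask = back_project_sparse_type(coords, origin, 0.05, feats, KRcam)
#   # vol: [100, 3, 32], mask: [100, 3]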

def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode, sizeH=None, sizeW=None, with_depth=False):
    """Transform coordinates in the camera frame to the pixel frame.
    Args:
        cam_coords: pixel coordinates defined in the first camera coordinate system -- [B, 3, H, W]
        proj_c2p_rot: rotation matrices of cameras -- [B, 3, 3]
        proj_c2p_tr: translation vectors of cameras -- [B, 3, 1]
    Returns:
        array of [-1,1] coordinates -- [B, H, W, 2]
    """
    b, _, h, w = cam_coords.size()
    if sizeH is None:
        sizeH = h
        sizeW = w

    cam_coords_flat = cam_coords.view(b, 3, -1)  # [B, 3, H*W]
    if proj_c2p_rot is not None:
        pcoords = proj_c2p_rot.bmm(cam_coords_flat)
    else:
        pcoords = cam_coords_flat

    if proj_c2p_tr is not None:
        pcoords = pcoords + proj_c2p_tr  # [B, 3, H*W]
    X = pcoords[:, 0]
    Y = pcoords[:, 1]
    Z = pcoords[:, 2].clamp(min=1e-3)

    X_norm = 2 * (X / Z) / (sizeW - 1) - 1  # normalized: -1 on the extreme left,
    # 1 on the extreme right (x = w-1) [B, H*W]
    Y_norm = 2 * (Y / Z) / (sizeH - 1) - 1  # idem [B, H*W]
    if padding_mode == 'zeros':
        X_mask = ((X_norm > 1) + (X_norm < -1)).detach()
        X_norm[X_mask] = 2  # make sure that no point in the warped image is a combination of im and gray
        Y_mask = ((Y_norm > 1) + (Y_norm < -1)).detach()
        Y_norm[Y_mask] = 2

    if with_depth:
        pixel_coords = torch.stack([X_norm, Y_norm, Z], dim=2)  # [B, H*W, 3]
        return pixel_coords.view(b, h, w, 3)
    else:
        pixel_coords = torch.stack([X_norm, Y_norm], dim=2)  # [B, H*W, 2]
        return pixel_coords.view(b, h, w, 2)


# * already checked; should still verify whether proj_matrix matches the right coordinate system and resolution
def back_project_dense_type(coords, origin, voxel_size, feats, proj_matrix, sizeH=None, sizeW=None):
    '''
    Unproject the image features to form a 3D (dense) feature volume

    :param coords: coordinates of voxels,
        dim: (batch, nviews, 3, X,Y,Z)
    :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0))
        dim: (batch size, 3) (3: x, y, z)
    :param voxel_size: float specifying the size of a voxel
    :param feats: image features
        dim: (batch size, num of views, C, H, W)
    :param proj_matrix: projection matrix
        dim: (batch size, num of views, 4, 4)
    :return: feature_volume_all: 3D feature volumes
        dim: (batch, nviews, C, X,Y,Z)
    :return: count: number of times each voxel can be seen
        dim: (batch, nviews, 1, X,Y,Z)
    '''

    batch, nviews, _, wX, wY, wZ = coords.shape

    if sizeH is None:
        sizeH, sizeW = feats.shape[-2:]
    proj_matrix = proj_matrix.view(batch * nviews, *proj_matrix.shape[2:])

    coords_wrd = coords * voxel_size + origin.view(batch, 1, 3, 1, 1, 1)
    coords_wrd = coords_wrd.view(batch * nviews, 3, wX * wY * wZ, 1)  # (b*nviews,3,wX*wY*wZ, 1)

    pixel_grids = cam2pixel(coords_wrd, proj_matrix[:, :3, :3], proj_matrix[:, :3, 3:],
                            'zeros', sizeH=sizeH, sizeW=sizeW)  # (b*nviews,wX*wY*wZ, 2)
    pixel_grids = pixel_grids.view(batch * nviews, 1, wX * wY * wZ, 2)

    feats = feats.view(batch * nviews, *feats.shape[2:])  # (b*nviews,c,h,w)

    ones = torch.ones((batch * nviews, 1, *feats.shape[2:])).to(feats.dtype).to(feats.device)

    features_volume = torch.nn.functional.grid_sample(feats, pixel_grids, padding_mode='zeros', align_corners=True)
    counts_volume = torch.nn.functional.grid_sample(ones, pixel_grids, padding_mode='zeros', align_corners=True)

    features_volume = features_volume.view(batch, nviews, -1, wX, wY, wZ)  # (batch, nviews, C, X,Y,Z)
    counts_volume = counts_volume.view(batch, nviews, -1, wX, wY, wZ)
    return features_volume, counts_volume
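
# Worked example of the normalization in cam2pixel: for sizeW = 256, a camera
# point projecting to pixel x = 0 gives X_norm = -1, x = 255 gives X_norm = 1,
# and x = 127.5 gives X_norm = 0 -- exactly the convention F.grid_sample
# expects with align_corners=True.
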
SparseNeuS_demo_v1/ops/generate_grids.py
ADDED
@@ -0,0 +1,33 @@
import torch


def generate_grid(n_vox, interval):
    """
    Generate a coordinate grid.
    For a 3D volume, grid[:, :, x, y, z] = (x, y, z)
    :param n_vox:
    :param interval:
    :return:
    """
    with torch.no_grad():
        # create the voxel grid
        grid_range = [torch.arange(0, n_vox[axis], interval) for axis in range(3)]
        grid = torch.stack(torch.meshgrid(grid_range[0], grid_range[1], grid_range[2], indexing="ij"))  # 3 dx dy dz
        # ! don't create the tensor on gpu; it causes imbalanced gpu memory in ddp mode
        grid = grid.unsqueeze(0).type(torch.float32)  # 1 3 dx dy dz

    return grid


if __name__ == "__main__":
    import torch.nn.functional as F

    grid = generate_grid([5, 6, 8], 1)

    pts = 2 * torch.tensor([1, 2, 3]) / (torch.tensor([5, 6, 8]) - 1) - 1
    pts = pts.view(1, 1, 1, 1, 3)

    pts = torch.flip(pts, dims=[-1])

    sampled = F.grid_sample(grid, pts, mode='nearest')

    print(sampled)
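
# Expected output of the self-test above (by reasoning, not verified here): pts
# indexes voxel (1, 2, 3), and grid_sample expects (z, y, x) ordering, hence
# the flip; the nearest sample should therefore return that voxel's own
# coordinates, i.e. a [1, 3, 1, 1, 1] tensor holding (1., 2., 3.).
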
SparseNeuS_demo_v1/ops/grid_sampler.py
ADDED
@@ -0,0 +1,467 @@
1 |
+
"""
|
2 |
+
pytorch grid_sample doesn't support second-order derivative
|
3 |
+
implement custom version
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
|
11 |
+
def grid_sample_2d(image, optical):
|
12 |
+
N, C, IH, IW = image.shape
|
13 |
+
_, H, W, _ = optical.shape
|
14 |
+
|
15 |
+
ix = optical[..., 0]
|
16 |
+
iy = optical[..., 1]
|
17 |
+
|
18 |
+
ix = ((ix + 1) / 2) * (IW - 1);
|
19 |
+
iy = ((iy + 1) / 2) * (IH - 1);
|
20 |
+
with torch.no_grad():
|
21 |
+
ix_nw = torch.floor(ix);
|
22 |
+
iy_nw = torch.floor(iy);
|
23 |
+
ix_ne = ix_nw + 1;
|
24 |
+
iy_ne = iy_nw;
|
25 |
+
ix_sw = ix_nw;
|
26 |
+
iy_sw = iy_nw + 1;
|
27 |
+
ix_se = ix_nw + 1;
|
28 |
+
iy_se = iy_nw + 1;
|
29 |
+
|
30 |
+
nw = (ix_se - ix) * (iy_se - iy)
|
31 |
+
ne = (ix - ix_sw) * (iy_sw - iy)
|
32 |
+
sw = (ix_ne - ix) * (iy - iy_ne)
|
33 |
+
se = (ix - ix_nw) * (iy - iy_nw)
|
34 |
+
|
35 |
+
with torch.no_grad():
|
36 |
+
torch.clamp(ix_nw, 0, IW - 1, out=ix_nw)
|
37 |
+
torch.clamp(iy_nw, 0, IH - 1, out=iy_nw)
|
38 |
+
|
39 |
+
torch.clamp(ix_ne, 0, IW - 1, out=ix_ne)
|
40 |
+
torch.clamp(iy_ne, 0, IH - 1, out=iy_ne)
|
41 |
+
|
42 |
+
torch.clamp(ix_sw, 0, IW - 1, out=ix_sw)
|
43 |
+
torch.clamp(iy_sw, 0, IH - 1, out=iy_sw)
|
44 |
+
|
45 |
+
torch.clamp(ix_se, 0, IW - 1, out=ix_se)
|
46 |
+
torch.clamp(iy_se, 0, IH - 1, out=iy_se)
|
47 |
+
|
48 |
+
image = image.view(N, C, IH * IW)
|
49 |
+
|
50 |
+
nw_val = torch.gather(image, 2, (iy_nw * IW + ix_nw).long().view(N, 1, H * W).repeat(1, C, 1))
|
51 |
+
ne_val = torch.gather(image, 2, (iy_ne * IW + ix_ne).long().view(N, 1, H * W).repeat(1, C, 1))
|
52 |
+
sw_val = torch.gather(image, 2, (iy_sw * IW + ix_sw).long().view(N, 1, H * W).repeat(1, C, 1))
|
53 |
+
se_val = torch.gather(image, 2, (iy_se * IW + ix_se).long().view(N, 1, H * W).repeat(1, C, 1))
|
54 |
+
|
55 |
+
out_val = (nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) +
|
56 |
+
ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) +
|
57 |
+
sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) +
|
58 |
+
se_val.view(N, C, H, W) * se.view(N, 1, H, W))
|
59 |
+
|
60 |
+
return out_val
|
61 |
+
|
62 |
+
|
63 |
+
# - checked for correctness
|
64 |
+
def grid_sample_3d(volume, optical):
|
65 |
+
"""
|
66 |
+
bilinear sampling cannot guarantee continuous first-order gradient
|
67 |
+
mimic pytorch grid_sample function
|
68 |
+
The 8 corner points of a volume noted as: 4 points (front view); 4 points (back view)
|
69 |
+
fnw (front north west) point
|
70 |
+
bse (back south east) point
|
71 |
+
:param volume: [B, C, X, Y, Z]
|
72 |
+
:param optical: [B, x, y, z, 3]
|
73 |
+
:return:
|
74 |
+
"""
|
75 |
+
N, C, ID, IH, IW = volume.shape
|
76 |
+
_, D, H, W, _ = optical.shape
|
77 |
+
|
78 |
+
ix = optical[..., 0]
|
79 |
+
iy = optical[..., 1]
|
80 |
+
iz = optical[..., 2]
|
81 |
+
|
82 |
+
ix = ((ix + 1) / 2) * (IW - 1)
|
83 |
+
iy = ((iy + 1) / 2) * (IH - 1)
|
84 |
+
iz = ((iz + 1) / 2) * (ID - 1)
|
85 |
+
|
86 |
+
mask_x = (ix > 0) & (ix < IW)
|
87 |
+
mask_y = (iy > 0) & (iy < IH)
|
88 |
+
mask_z = (iz > 0) & (iz < ID)
|
89 |
+
|
90 |
+
mask = mask_x & mask_y & mask_z # [B, x, y, z]
|
91 |
+
mask = mask[:, None, :, :, :].repeat(1, C, 1, 1, 1) # [B, C, x, y, z]
|
92 |
+
|
93 |
+
with torch.no_grad():
|
94 |
+
# back north west
|
95 |
+
ix_bnw = torch.floor(ix)
|
96 |
+
iy_bnw = torch.floor(iy)
|
97 |
+
iz_bnw = torch.floor(iz)
|
98 |
+
|
99 |
+
ix_bne = ix_bnw + 1
|
100 |
+
iy_bne = iy_bnw
|
101 |
+
iz_bne = iz_bnw
|
102 |
+
|
103 |
+
ix_bsw = ix_bnw
|
104 |
+
iy_bsw = iy_bnw + 1
|
105 |
+
iz_bsw = iz_bnw
|
106 |
+
|
107 |
+
ix_bse = ix_bnw + 1
|
108 |
+
iy_bse = iy_bnw + 1
|
109 |
+
iz_bse = iz_bnw
|
110 |
+
|
111 |
+
# front view
|
112 |
+
ix_fnw = ix_bnw
|
113 |
+
iy_fnw = iy_bnw
|
114 |
+
iz_fnw = iz_bnw + 1
|
115 |
+
|
116 |
+
ix_fne = ix_bnw + 1
|
117 |
+
iy_fne = iy_bnw
|
118 |
+
iz_fne = iz_bnw + 1
|
119 |
+
|
120 |
+
ix_fsw = ix_bnw
|
121 |
+
iy_fsw = iy_bnw + 1
|
122 |
+
iz_fsw = iz_bnw + 1
|
123 |
+
|
124 |
+
ix_fse = ix_bnw + 1
|
125 |
+
iy_fse = iy_bnw + 1
|
126 |
+
iz_fse = iz_bnw + 1
|
127 |
+
|
128 |
+
# back view
|
129 |
+
bnw = (ix_fse - ix) * (iy_fse - iy) * (iz_fse - iz) # smaller volume, larger weight
|
130 |
+
bne = (ix - ix_fsw) * (iy_fsw - iy) * (iz_fsw - iz)
|
131 |
+
bsw = (ix_fne - ix) * (iy - iy_fne) * (iz_fne - iz)
|
132 |
+
bse = (ix - ix_fnw) * (iy - iy_fnw) * (iz_fnw - iz)
|
133 |
+
|
134 |
+
# front view
|
135 |
+
fnw = (ix_bse - ix) * (iy_bse - iy) * (iz - iz_bse) # smaller volume, larger weight
|
136 |
+
fne = (ix - ix_bsw) * (iy_bsw - iy) * (iz - iz_bsw)
|
137 |
+
fsw = (ix_bne - ix) * (iy - iy_bne) * (iz - iz_bne)
|
138 |
+
fse = (ix - ix_bnw) * (iy - iy_bnw) * (iz - iz_bnw)
|
139 |
+
|
140 |
+
with torch.no_grad():
|
141 |
+
# back view
|
142 |
+
torch.clamp(ix_bnw, 0, IW - 1, out=ix_bnw)
|
143 |
+
torch.clamp(iy_bnw, 0, IH - 1, out=iy_bnw)
|
144 |
+
torch.clamp(iz_bnw, 0, ID - 1, out=iz_bnw)
|
145 |
+
|
146 |
+
torch.clamp(ix_bne, 0, IW - 1, out=ix_bne)
|
147 |
+
torch.clamp(iy_bne, 0, IH - 1, out=iy_bne)
|
148 |
+
torch.clamp(iz_bne, 0, ID - 1, out=iz_bne)
|
149 |
+
|
150 |
+
torch.clamp(ix_bsw, 0, IW - 1, out=ix_bsw)
|
151 |
+
torch.clamp(iy_bsw, 0, IH - 1, out=iy_bsw)
|
152 |
+
torch.clamp(iz_bsw, 0, ID - 1, out=iz_bsw)
|
153 |
+
|
154 |
+
torch.clamp(ix_bse, 0, IW - 1, out=ix_bse)
|
155 |
+
torch.clamp(iy_bse, 0, IH - 1, out=iy_bse)
|
156 |
+
torch.clamp(iz_bse, 0, ID - 1, out=iz_bse)
|
157 |
+
|
158 |
+
# front view
|
159 |
+
torch.clamp(ix_fnw, 0, IW - 1, out=ix_fnw)
|
160 |
+
torch.clamp(iy_fnw, 0, IH - 1, out=iy_fnw)
|
161 |
+
torch.clamp(iz_fnw, 0, ID - 1, out=iz_fnw)
|
162 |
+
|
163 |
+
torch.clamp(ix_fne, 0, IW - 1, out=ix_fne)
|
164 |
+
torch.clamp(iy_fne, 0, IH - 1, out=iy_fne)
|
165 |
+
torch.clamp(iz_fne, 0, ID - 1, out=iz_fne)
|
166 |
+
|
167 |
+
torch.clamp(ix_fsw, 0, IW - 1, out=ix_fsw)
|
168 |
+
torch.clamp(iy_fsw, 0, IH - 1, out=iy_fsw)
|
169 |
+
torch.clamp(iz_fsw, 0, ID - 1, out=iz_fsw)
|
170 |
+
|
171 |
+
torch.clamp(ix_fse, 0, IW - 1, out=ix_fse)
|
172 |
+
torch.clamp(iy_fse, 0, IH - 1, out=iy_fse)
|
173 |
+
torch.clamp(iz_fse, 0, ID - 1, out=iz_fse)
|
174 |
+
|
175 |
+
# xxx = volume[:, :, iz_bnw.long(), iy_bnw.long(), ix_bnw.long()]
|
176 |
+
volume = volume.view(N, C, ID * IH * IW)
|
177 |
+
# yyy = volume[:, :, (iz_bnw * ID + iy_bnw * IW + ix_bnw).long()]
|
178 |
+
|
179 |
+
# back view
|
180 |
+
bnw_val = torch.gather(volume, 2,
|
181 |
+
(iz_bnw * ID ** 2 + iy_bnw * IW + ix_bnw).long().view(N, 1, D * H * W).repeat(1, C, 1))
|
182 |
+
bne_val = torch.gather(volume, 2,
|
183 |
+
(iz_bne * ID ** 2 + iy_bne * IW + ix_bne).long().view(N, 1, D * H * W).repeat(1, C, 1))
|
184 |
+
bsw_val = torch.gather(volume, 2,
|
185 |
+
(iz_bsw * ID ** 2 + iy_bsw * IW + ix_bsw).long().view(N, 1, D * H * W).repeat(1, C, 1))
|
186 |
+
bse_val = torch.gather(volume, 2,
|
187 |
+
(iz_bse * ID ** 2 + iy_bse * IW + ix_bse).long().view(N, 1, D * H * W).repeat(1, C, 1))
|
188 |
+
|
189 |
+
# front view
|
190 |
+
fnw_val = torch.gather(volume, 2,
|
191 |
+
(iz_fnw * ID ** 2 + iy_fnw * IW + ix_fnw).long().view(N, 1, D * H * W).repeat(1, C, 1))
|
192 |
+
fne_val = torch.gather(volume, 2,
|
193 |
+
(iz_fne * ID ** 2 + iy_fne * IW + ix_fne).long().view(N, 1, D * H * W).repeat(1, C, 1))
|
194 |
+
fsw_val = torch.gather(volume, 2,
|
195 |
+
(iz_fsw * ID ** 2 + iy_fsw * IW + ix_fsw).long().view(N, 1, D * H * W).repeat(1, C, 1))
|
196 |
+
fse_val = torch.gather(volume, 2,
|
197 |
+
(iz_fse * ID ** 2 + iy_fse * IW + ix_fse).long().view(N, 1, D * H * W).repeat(1, C, 1))
|
198 |
+
|
199 |
+
out_val = (
|
200 |
+
# back
|
201 |
+
bnw_val.view(N, C, D, H, W) * bnw.view(N, 1, D, H, W) +
|
202 |
+
bne_val.view(N, C, D, H, W) * bne.view(N, 1, D, H, W) +
|
203 |
+
bsw_val.view(N, C, D, H, W) * bsw.view(N, 1, D, H, W) +
|
204 |
+
bse_val.view(N, C, D, H, W) * bse.view(N, 1, D, H, W) +
|
205 |
+
# front
|
206 |
+
fnw_val.view(N, C, D, H, W) * fnw.view(N, 1, D, H, W) +
|
207 |
+
fne_val.view(N, C, D, H, W) * fne.view(N, 1, D, H, W) +
|
208 |
+
fsw_val.view(N, C, D, H, W) * fsw.view(N, 1, D, H, W) +
|
209 |
+
fse_val.view(N, C, D, H, W) * fse.view(N, 1, D, H, W)
|
210 |
+
|
211 |
+
)
|
212 |
+
|
213 |
+
# * zero padding
|
214 |
+
out_val = torch.where(mask, out_val, torch.zeros_like(out_val).float().to(out_val.device))
|
215 |
+
|
216 |
+
return out_val
|
217 |
+
|
218 |
+
|
219 |
+
# Interpolation kernel
|
220 |
+
def get_weight(s, a=-0.5):
|
221 |
+
mask_0 = (torch.abs(s) >= 0) & (torch.abs(s) <= 1)
|
222 |
+
mask_1 = (torch.abs(s) > 1) & (torch.abs(s) <= 2)
|
223 |
+
mask_2 = torch.abs(s) > 2
|
224 |
+
|
225 |
+
weight = torch.zeros_like(s).to(s.device)
|
226 |
+
weight = torch.where(mask_0, (a + 2) * (torch.abs(s) ** 3) - (a + 3) * (torch.abs(s) ** 2) + 1, weight)
|
227 |
+
weight = torch.where(mask_1,
|
228 |
+
a * (torch.abs(s) ** 3) - (5 * a) * (torch.abs(s) ** 2) + (8 * a) * torch.abs(s) - 4 * a,
|
229 |
+
weight)
|
230 |
+
|
231 |
+
# if (torch.abs(s) >= 0) & (torch.abs(s) <= 1):
|
232 |
+
# return (a + 2) * (torch.abs(s) ** 3) - (a + 3) * (torch.abs(s) ** 2) + 1
|
233 |
+
#
|
234 |
+
# elif (torch.abs(s) > 1) & (torch.abs(s) <= 2):
|
235 |
+
# return a * (torch.abs(s) ** 3) - (5 * a) * (torch.abs(s) ** 2) + (8 * a) * torch.abs(s) - 4 * a
|
236 |
+
# return 0
|
237 |
+
|
238 |
+
return weight
|
239 |
+
|
240 |
+
|
241 |
+
def cubic_interpolate(p, x):
|
242 |
+
"""
|
243 |
+
one dimensional cubic interpolation
|
244 |
+
:param p: [N, 4] (4) should be in order
|
245 |
+
:param x: [N]
|
246 |
+
:return:
|
247 |
+
"""
|
248 |
+
return p[:, 1] + 0.5 * x * (p[:, 2] - p[:, 0] + x * (
|
249 |
+
2.0 * p[:, 0] - 5.0 * p[:, 1] + 4.0 * p[:, 2] - p[:, 3] + x * (
|
250 |
+
3.0 * (p[:, 1] - p[:, 2]) + p[:, 3] - p[:, 0])))
|
251 |
+
|
252 |
+
|
253 |
+
def bicubic_interpolate(p, x, y, if_batch=True):
|
254 |
+
"""
|
255 |
+
two dimensional cubic interpolation
|
256 |
+
:param p: [N, 4, 4]
|
257 |
+
:param x: [N]
|
258 |
+
:param y: [N]
|
259 |
+
:return:
|
260 |
+
"""
|
261 |
+
num = p.shape[0]
|
262 |
+
|
263 |
+
if not if_batch:
|
264 |
+
arr0 = cubic_interpolate(p[:, 0, :], x) # [N]
|
265 |
+
arr1 = cubic_interpolate(p[:, 1, :], x)
|
266 |
+
arr2 = cubic_interpolate(p[:, 2, :], x)
|
267 |
+
arr3 = cubic_interpolate(p[:, 3, :], x)
|
268 |
+
return cubic_interpolate(torch.stack([arr0, arr1, arr2, arr3], dim=-1), y) # [N]
|
269 |
+
else:
|
270 |
+
x = x[:, None].repeat(1, 4).view(-1)
|
271 |
+
p = p.contiguous().view(num * 4, 4)
|
272 |
+
arr = cubic_interpolate(p, x)
|
273 |
+
arr = arr.view(num, 4)
|
274 |
+
|
275 |
+
return cubic_interpolate(arr, y)
|
276 |
+
|
277 |
+
|
278 |
+
def tricubic_interpolate(p, x, y, z):
|
279 |
+
"""
|
280 |
+
three dimensional cubic interpolation
|
281 |
+
:param p: [N,4,4,4]
|
282 |
+
:param x: [N]
|
283 |
+
:param y: [N]
|
284 |
+
:param z: [N]
|
285 |
+
:return:
|
286 |
+
"""
|
287 |
+
num = p.shape[0]
|
288 |
+
|
289 |
+
arr0 = bicubic_interpolate(p[:, 0, :, :], x, y) # [N]
|
290 |
+
arr1 = bicubic_interpolate(p[:, 1, :, :], x, y)
|
291 |
+
arr2 = bicubic_interpolate(p[:, 2, :, :], x, y)
|
292 |
+
arr3 = bicubic_interpolate(p[:, 3, :, :], x, y)
|
293 |
+
|
294 |
+
return cubic_interpolate(torch.stack([arr0, arr1, arr2, arr3], dim=-1), z) # [N]
|
295 |
+
|
296 |
+
|
297 |
+
def cubic_interpolate_batch(p, x):
|
298 |
+
"""
|
299 |
+
one dimensional cubic interpolation
|
300 |
+
:param p: [B, N, 4] (4) should be in order
|
301 |
+
:param x: [B, N]
|
302 |
+
:return:
|
303 |
+
"""
|
304 |
+
return p[:, :, 1] + 0.5 * x * (p[:, :, 2] - p[:, :, 0] + x * (
|
305 |
+
2.0 * p[:, :, 0] - 5.0 * p[:, :, 1] + 4.0 * p[:, :, 2] - p[:, :, 3] + x * (
|
306 |
+
3.0 * (p[:, :, 1] - p[:, :, 2]) + p[:, :, 3] - p[:, :, 0])))
|
307 |
+
|
308 |
+
|
309 |
+
def bicubic_interpolate_batch(p, x, y):
|
310 |
+
"""
|
311 |
+
two dimensional cubic interpolation
|
312 |
+
:param p: [B, N, 4, 4]
|
313 |
+
:param x: [B, N]
|
314 |
+
:param y: [B, N]
|
315 |
+
:return:
|
316 |
+
"""
|
317 |
+
B, N, _, _ = p.shape
|
318 |
+
|
319 |
+
x = x[:, :, None].repeat(1, 1, 4).view(B, N * 4) # [B, N*4]
|
320 |
+
arr = cubic_interpolate_batch(p.contiguous().view(B, N * 4, 4), x)
|
321 |
+
arr = arr.view(B, N, 4)
|
322 |
+
return cubic_interpolate_batch(arr, y) # [B, N]
|
323 |
+
|
324 |
+
|
325 |
+
# * batch version cannot speed up training
|
326 |
+
def tricubic_interpolate_batch(p, x, y, z):
|
327 |
+
"""
|
328 |
+
three dimensional cubic interpolation
|
329 |
+
:param p: [N,4,4,4]
|
330 |
+
:param x: [N]
|
331 |
+
:param y: [N]
|
332 |
+
:param z: [N]
|
333 |
+
:return:
|
334 |
+
"""
|
335 |
+
N = p.shape[0]
|
336 |
+
|
337 |
+
x = x[None, :].repeat(4, 1)
|
338 |
+
y = y[None, :].repeat(4, 1)
|
339 |
+
|
340 |
+
p = p.permute(1, 0, 2, 3).contiguous()
|
341 |
+
|
342 |
+
arr = bicubic_interpolate_batch(p[:, :, :, :], x, y) # [4, N]
|
343 |
+
|
344 |
+
arr = arr.permute(1, 0).contiguous() # [N, 4]
|
345 |
+
|
346 |
+
return cubic_interpolate(arr, z) # [N]
|
347 |
+
|
348 |
+
|
349 |
+
def tricubic_sample_3d(volume, optical):
|
350 |
+
"""
|
351 |
+
tricubic sampling; can guarantee continuous gradient (interpolation border)
|
352 |
+
:param volume: [B, C, ID, IH, IW]
|
353 |
+
:param optical: [B, D, H, W, 3]
|
354 |
+
:param sample_num:
|
355 |
+
:return:
|
356 |
+
"""
|
357 |
+
|
358 |
+
@torch.no_grad()
|
359 |
+
def get_shifts(x):
|
360 |
+
x1 = -1 * (1 + x - torch.floor(x))
|
361 |
+
x2 = -1 * (x - torch.floor(x))
|
362 |
+
x3 = torch.floor(x) + 1 - x
|
363 |
+
x4 = torch.floor(x) + 2 - x
|
364 |
+
|
365 |
+
return torch.stack([x1, x2, x3, x4], dim=-1) # (B,d,h,w,4)
|
366 |
+
|
367 |
+
N, C, ID, IH, IW = volume.shape
|
368 |
+
_, D, H, W, _ = optical.shape
|
369 |
+
|
370 |
+
device = volume.device
|
371 |
+
|
372 |
+
ix = optical[..., 0]
|
373 |
+
iy = optical[..., 1]
|
374 |
+
iz = optical[..., 2]
|
375 |
+
|
376 |
+
ix = ((ix + 1) / 2) * (IW - 1) # (B,d,h,w)
|
377 |
+
iy = ((iy + 1) / 2) * (IH - 1)
|
378 |
+
iz = ((iz + 1) / 2) * (ID - 1)
|
379 |
+
|
380 |
+
ix = ix.view(-1)
|
381 |
+
iy = iy.view(-1)
|
382 |
+
iz = iz.view(-1)
|
383 |
+
|
384 |
+
with torch.no_grad():
|
385 |
+
shifts_x = get_shifts(ix).view(-1, 4) # (B*d*h*w,4)
|
386 |
+
shifts_y = get_shifts(iy).view(-1, 4)
|
387 |
+
shifts_z = get_shifts(iz).view(-1, 4)
|
388 |
+
|
389 |
+
perm_weights = torch.ones([N * D * H * W, 4 * 4 * 4]).long().to(device)
|
390 |
+
perm = torch.cumsum(perm_weights, dim=-1) - 1 # (B*d*h*w,64)
|
391 |
+
|
392 |
+
perm_z = perm // 16 # [N*D*H*W, num]
|
393 |
+
perm_y = (perm - perm_z * 16) // 4
|
394 |
+
perm_x = (perm - perm_z * 16 - perm_y * 4)
|
395 |
+
|
396 |
+
shifts_x = torch.gather(shifts_x, 1, perm_x) # [N*D*H*W, num]
|
397 |
+
shifts_y = torch.gather(shifts_y, 1, perm_y)
|
398 |
+
shifts_z = torch.gather(shifts_z, 1, perm_z)
|
399 |
+
|
400 |
+
ix_target = (ix[:, None] + shifts_x).long() # [N*D*H*W, num]
|
401 |
+
iy_target = (iy[:, None] + shifts_y).long()
|
402 |
+
iz_target = (iz[:, None] + shifts_z).long()
|
403 |
+
|
404 |
+
torch.clamp(ix_target, 0, IW - 1, out=ix_target)
|
405 |
+
torch.clamp(iy_target, 0, IH - 1, out=iy_target)
|
406 |
+
torch.clamp(iz_target, 0, ID - 1, out=iz_target)
|
407 |
+
|
408 |
+
local_dist_x = ix - ix_target[:, 1] # ! attention here is [:, 1]
|
409 |
+
local_dist_y = iy - iy_target[:, 1 + 4]
|
410 |
+
local_dist_z = iz - iz_target[:, 1 + 16]
|
411 |
+
|
412 |
+
local_dist_x = local_dist_x.view(N, 1, D * H * W).repeat(1, C, 1).view(-1)
|
413 |
+
local_dist_y = local_dist_y.view(N, 1, D * H * W).repeat(1, C, 1).view(-1)
|
414 |
+
local_dist_z = local_dist_z.view(N, 1, D * H * W).repeat(1, C, 1).view(-1)
|
415 |
+
|
416 |
+
# ! attention: IW is correct
|
417 |
+
idx_target = iz_target * ID ** 2 + iy_target * IW + ix_target # [N*D*H*W, num]
|
418 |
+
|
419 |
+
volume = volume.view(N, C, ID * IH * IW)
|
420 |
+
|
421 |
+
out = torch.gather(volume, 2,
|
422 |
+
idx_target.view(N, 1, D * H * W * 64).repeat(1, C, 1))
|
423 |
+
out = out.view(N * C * D * H * W, 4, 4, 4)
|
424 |
+
|
425 |
+
# - tricubic_interpolate() is a bit faster than tricubic_interpolate_batch()
|
426 |
+
final = tricubic_interpolate(out, local_dist_x, local_dist_y, local_dist_z).view(N, C, D, H, W) # [N,C,D,H,W]
|
427 |
+
|
428 |
+
return final
|
429 |
+
|
430 |
+
|
431 |
+
|
432 |
+
if __name__ == "__main__":
|
433 |
+
# image = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).view(1, 3, 1, 3)
|
434 |
+
#
|
435 |
+
# optical = torch.Tensor([0.9, 0.5, 0.6, -0.7]).view(1, 1, 2, 2)
|
436 |
+
#
|
437 |
+
# print(grid_sample_2d(image, optical))
|
438 |
+
#
|
439 |
+
# print(F.grid_sample(image, optical, padding_mode='border', align_corners=True))
|
440 |
+
|
441 |
+
from ops.generate_grids import generate_grid
|
442 |
+
|
443 |
+
p = torch.tensor([x for x in range(4)]).view(1, 4).float()
|
444 |
+
|
445 |
+
v = cubic_interpolate(p, torch.tensor([0.5]).view(1))
|
446 |
+
# v = bicubic_interpolate(p, torch.tensor([2/3]).view(1) , torch.tensor([2/3]).view(1))
|
447 |
+
|
448 |
+
vsize = 9
|
449 |
+
volume = generate_grid([vsize, vsize, vsize], 1) # [1,3,10,10,10]
|
450 |
+
# volume = torch.tensor([x for x in range(1000)]).view(1, 1, 10, 10, 10).float()
|
451 |
+
X, Y, Z = 0, 0, 6
|
452 |
+
x = 2 * X / (vsize - 1) - 1
|
453 |
+
y = 2 * Y / (vsize - 1) - 1
|
454 |
+
z = 2 * Z / (vsize - 1) - 1
|
455 |
+
|
456 |
+
# print(volume[:, :, Z, Y, X])
|
457 |
+
|
458 |
+
# volume = volume.view(1, 3, -1)
|
459 |
+
# xx = volume[:, :, Z * 9*9 + Y * 9 + X]
|
460 |
+
|
461 |
+
optical = torch.Tensor([-0.6, -0.7, 0.5, 0.3, 0.5, 0.5]).view(1, 1, 1, 2, 3)
|
462 |
+
|
463 |
+
print(F.grid_sample(volume, optical, padding_mode='border', align_corners=True))
|
464 |
+
print(grid_sample_3d(volume, optical))
|
465 |
+
print(tricubic_sample_3d(volume, optical))
|
466 |
+
# target, relative_coords = implicit_sample_3d(volume, optical, 1)
|
467 |
+
# print(target)
|
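The point of this file is that (at the time it was written) PyTorch's native grid_sample lacked second-order derivatives, which SDF-style pipelines need when they differentiate sampled gradients again. A quick double-backward sanity check (a sketch, assuming grid_sample_3d from the file above is in scope):

import torch

vol = torch.randn(1, 2, 4, 4, 4, requires_grad=True)
pts = torch.empty(1, 1, 1, 5, 3).uniform_(-1, 1).requires_grad_(True)

out = grid_sample_3d(vol, pts)                   # [1, 2, 1, 1, 5]
g, = torch.autograd.grad(out.sum(), pts, create_graph=True)   # first order w.r.t. sample points
gg, = torch.autograd.grad(g.sum(), vol)          # second order: d(d out / d pts) / d vol
print(g.shape, gg.shape)                         # both backward passes succeed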
SparseNeuS_demo_v1/tsparse/__init__.py
ADDED
File without changes
SparseNeuS_demo_v1/tsparse/modules.py
ADDED
@@ -0,0 +1,326 @@
import torch
import torch.nn as nn
import torchsparse
import torchsparse.nn as spnn
from torchsparse.tensor import PointTensor

from tsparse.torchsparse_utils import *


# __all__ = ['SPVCNN', 'SConv3d', 'SparseConvGRU']


class ConvBnReLU(nn.Module):
    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, pad=1):
        super(ConvBnReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.activation(self.bn(self.conv(x)))


class ConvBnReLU3D(nn.Module):
    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, pad=1):
        super(ConvBnReLU3D, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels,
                              kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm3d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.activation(self.bn(self.conv(x)))


################################### feature net ######################################
class FeatureNet(nn.Module):
    """
    output 3 levels of features using a FPN structure
    """

    def __init__(self):
        super(FeatureNet, self).__init__()

        self.conv0 = nn.Sequential(
            ConvBnReLU(3, 8, 3, 1, 1),
            ConvBnReLU(8, 8, 3, 1, 1))

        self.conv1 = nn.Sequential(
            ConvBnReLU(8, 16, 5, 2, 2),
            ConvBnReLU(16, 16, 3, 1, 1),
            ConvBnReLU(16, 16, 3, 1, 1))

        self.conv2 = nn.Sequential(
            ConvBnReLU(16, 32, 5, 2, 2),
            ConvBnReLU(32, 32, 3, 1, 1),
            ConvBnReLU(32, 32, 3, 1, 1))

        self.toplayer = nn.Conv2d(32, 32, 1)
        self.lat1 = nn.Conv2d(16, 32, 1)
        self.lat0 = nn.Conv2d(8, 32, 1)

        # to reduce channel size of the outputs from FPN
        self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
        self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)

    def _upsample_add(self, x, y):
        return torch.nn.functional.interpolate(x, scale_factor=2,
                                               mode="bilinear", align_corners=True) + y

    def forward(self, x):
        # x: (B, 3, H, W)
        conv0 = self.conv0(x)  # (B, 8, H, W)
        conv1 = self.conv1(conv0)  # (B, 16, H//2, W//2)
        conv2 = self.conv2(conv1)  # (B, 32, H//4, W//4)
        feat2 = self.toplayer(conv2)  # (B, 32, H//4, W//4)
        feat1 = self._upsample_add(feat2, self.lat1(conv1))  # (B, 32, H//2, W//2)
        feat0 = self._upsample_add(feat1, self.lat0(conv0))  # (B, 32, H, W)

        # reduce output channels
        feat1 = self.smooth1(feat1)  # (B, 16, H//2, W//2)
        feat0 = self.smooth0(feat0)  # (B, 8, H, W)

        # feats = {"level_0": feat0,
        #          "level_1": feat1,
        #          "level_2": feat2}

        return [feat2, feat1, feat0]  # coarser to finer features


class BasicSparseConvolutionBlock(nn.Module):
    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        self.net = nn.Sequential(
            spnn.Conv3d(inc,
                        outc,
                        kernel_size=ks,
                        dilation=dilation,
                        stride=stride),
            spnn.BatchNorm(outc),
            spnn.ReLU(True))

    def forward(self, x):
        out = self.net(x)
        return out


class BasicSparseDeconvolutionBlock(nn.Module):
    def __init__(self, inc, outc, ks=3, stride=1):
        super().__init__()
        self.net = nn.Sequential(
            spnn.Conv3d(inc,
                        outc,
                        kernel_size=ks,
                        stride=stride,
                        transposed=True),
            spnn.BatchNorm(outc),
            spnn.ReLU(True))

    def forward(self, x):
        return self.net(x)


class SparseResidualBlock(nn.Module):
    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        self.net = nn.Sequential(
            spnn.Conv3d(inc,
                        outc,
                        kernel_size=ks,
                        dilation=dilation,
                        stride=stride), spnn.BatchNorm(outc),
            spnn.ReLU(True),
            spnn.Conv3d(outc,
                        outc,
                        kernel_size=ks,
                        dilation=dilation,
                        stride=1), spnn.BatchNorm(outc))

        self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \
            nn.Sequential(
                spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride),
                spnn.BatchNorm(outc)
            )

        self.relu = spnn.ReLU(True)

    def forward(self, x):
        out = self.relu(self.net(x) + self.downsample(x))
        return out


class SPVCNN(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()

        self.dropout = kwargs['dropout']

        cr = kwargs.get('cr', 1.0)
        cs = [32, 64, 128, 96, 96]
        cs = [int(cr * x) for x in cs]

        if 'pres' in kwargs and 'vres' in kwargs:
            self.pres = kwargs['pres']
            self.vres = kwargs['vres']

        self.stem = nn.Sequential(
            spnn.Conv3d(kwargs['in_channels'], cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]), spnn.ReLU(True)
        )

        self.stage1 = nn.Sequential(
            BasicSparseConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1),
            SparseResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1),
            SparseResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1),
        )

        self.stage2 = nn.Sequential(
            BasicSparseConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1),
            SparseResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1),
            SparseResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1),
        )

        self.up1 = nn.ModuleList([
            BasicSparseDeconvolutionBlock(cs[2], cs[3], ks=2, stride=2),
            nn.Sequential(
                SparseResidualBlock(cs[3] + cs[1], cs[3], ks=3, stride=1,
                                    dilation=1),
                SparseResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1),
            )
        ])

        self.up2 = nn.ModuleList([
            BasicSparseDeconvolutionBlock(cs[3], cs[4], ks=2, stride=2),
            nn.Sequential(
                SparseResidualBlock(cs[4] + cs[0], cs[4], ks=3, stride=1,
                                    dilation=1),
                SparseResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1),
            )
        ])

        self.point_transforms = nn.ModuleList([
            nn.Sequential(
                nn.Linear(cs[0], cs[2]),
                nn.BatchNorm1d(cs[2]),
                nn.ReLU(True),
            ),
            nn.Sequential(
                nn.Linear(cs[2], cs[4]),
                nn.BatchNorm1d(cs[4]),
                nn.ReLU(True),
            )
        ])

        self.weight_initialization()

        if self.dropout:
            self.dropout = nn.Dropout(0.3, True)

    def weight_initialization(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, z):
        # x: SparseTensor z: PointTensor
        x0 = initial_voxelize(z, self.pres, self.vres)

        x0 = self.stem(x0)
        z0 = voxel_to_point(x0, z, nearest=False)
        z0.F = z0.F

        x1 = point_to_voxel(x0, z0)
        x1 = self.stage1(x1)
        x2 = self.stage2(x1)
        z1 = voxel_to_point(x2, z0)
        z1.F = z1.F + self.point_transforms[0](z0.F)

        y3 = point_to_voxel(x2, z1)
        if self.dropout:
            y3.F = self.dropout(y3.F)
        y3 = self.up1[0](y3)
        y3 = torchsparse.cat([y3, x1])
        y3 = self.up1[1](y3)

        y4 = self.up2[0](y3)
        y4 = torchsparse.cat([y4, x0])
        y4 = self.up2[1](y4)
        z3 = voxel_to_point(y4, z1)
        z3.F = z3.F + self.point_transforms[1](z1.F)

        return z3.F


class SparseCostRegNet(nn.Module):
    """
    Sparse cost regularization network;
    require sparse tensors as input
    """

    def __init__(self, d_in, d_out=8):
        super(SparseCostRegNet, self).__init__()
        self.d_in = d_in
        self.d_out = d_out

        self.conv0 = BasicSparseConvolutionBlock(d_in, d_out)

        self.conv1 = BasicSparseConvolutionBlock(d_out, 16, stride=2)
        self.conv2 = BasicSparseConvolutionBlock(16, 16)

        self.conv3 = BasicSparseConvolutionBlock(16, 32, stride=2)
        self.conv4 = BasicSparseConvolutionBlock(32, 32)

        self.conv5 = BasicSparseConvolutionBlock(32, 64, stride=2)
        self.conv6 = BasicSparseConvolutionBlock(64, 64)

        self.conv7 = BasicSparseDeconvolutionBlock(64, 32, ks=3, stride=2)

        self.conv9 = BasicSparseDeconvolutionBlock(32, 16, ks=3, stride=2)

        self.conv11 = BasicSparseDeconvolutionBlock(16, d_out, ks=3, stride=2)

    def forward(self, x):
        """
        :param x: sparse tensor
        :return: sparse tensor
        """
        conv0 = self.conv0(x)
        conv2 = self.conv2(self.conv1(conv0))
        conv4 = self.conv4(self.conv3(conv2))

        x = self.conv6(self.conv5(conv4))
        x = conv4 + self.conv7(x)
        del conv4
        x = conv2 + self.conv9(x)
        del conv2
        x = conv0 + self.conv11(x)
        del conv0
        return x.F


class SConv3d(nn.Module):
    def __init__(self, inc, outc, pres, vres, ks=3, stride=1, dilation=1):
        super().__init__()
        self.net = spnn.Conv3d(inc,
                               outc,
                               kernel_size=ks,
                               dilation=dilation,
                               stride=stride)
        self.point_transforms = nn.Sequential(
            nn.Linear(inc, outc),
        )
        self.pres = pres
        self.vres = vres

    def forward(self, z):
        x = initial_voxelize(z, self.pres, self.vres)
        x = self.net(x)
        out = voxel_to_point(x, z, nearest=False)
        out.F = out.F + self.point_transforms(z.F)
        return out
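Unlike the sparse blocks, FeatureNet above is a plain 2D FPN with no torchsparse dependency, so it can be smoke-tested standalone. A minimal shape check (sketch, assuming the FeatureNet class above is in scope):

import torch

net = FeatureNet().eval()
with torch.no_grad():
    feats = net(torch.randn(2, 3, 64, 64))
for f in feats:
    print(f.shape)
# torch.Size([2, 32, 16, 16])  -> feat2, coarsest
# torch.Size([2, 16, 32, 32])  -> feat1
# torch.Size([2, 8, 64, 64])   -> feat0, finest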
SparseNeuS_demo_v1/tsparse/torchsparse_utils.py
ADDED
@@ -0,0 +1,137 @@
"""
Copied from:
https://github.com/mit-han-lab/spvnas/blob/b24f50379ed888d3a0e784508a809d4e92e820c0/core/models/utils.py
"""
import torch
import torchsparse.nn.functional as F
from torchsparse import PointTensor, SparseTensor
from torchsparse.nn.utils import get_kernel_offsets

import numpy as np

# __all__ = ['initial_voxelize', 'point_to_voxel', 'voxel_to_point']


# z: PointTensor
# return: SparseTensor
def initial_voxelize(z, init_res, after_res):
    new_float_coord = torch.cat(
        [(z.C[:, :3] * init_res) / after_res, z.C[:, -1].view(-1, 1)], 1)

    pc_hash = F.sphash(torch.floor(new_float_coord).int())
    sparse_hash = torch.unique(pc_hash)
    idx_query = F.sphashquery(pc_hash, sparse_hash)
    counts = F.spcount(idx_query.int(), len(sparse_hash))

    inserted_coords = F.spvoxelize(torch.floor(new_float_coord), idx_query,
                                   counts)
    inserted_coords = torch.round(inserted_coords).int()
    inserted_feat = F.spvoxelize(z.F, idx_query, counts)

    new_tensor = SparseTensor(inserted_feat, inserted_coords, 1)
    new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords)
    z.additional_features['idx_query'][1] = idx_query
    z.additional_features['counts'][1] = counts
    z.C = new_float_coord

    return new_tensor


# x: SparseTensor, z: PointTensor
# return: SparseTensor
def point_to_voxel(x, z):
    if z.additional_features is None or z.additional_features.get('idx_query') is None \
            or z.additional_features['idx_query'].get(x.s) is None:
        # pc_hash = hash_gpu(torch.floor(z.C).int())
        pc_hash = F.sphash(
            torch.cat([
                torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0],
                z.C[:, -1].int().view(-1, 1)
            ], 1))
        sparse_hash = F.sphash(x.C)
        idx_query = F.sphashquery(pc_hash, sparse_hash)
        counts = F.spcount(idx_query.int(), x.C.shape[0])
        z.additional_features['idx_query'][x.s] = idx_query
        z.additional_features['counts'][x.s] = counts
    else:
        idx_query = z.additional_features['idx_query'][x.s]
        counts = z.additional_features['counts'][x.s]

    inserted_feat = F.spvoxelize(z.F, idx_query, counts)
    new_tensor = SparseTensor(inserted_feat, x.C, x.s)
    new_tensor.cmaps = x.cmaps
    new_tensor.kmaps = x.kmaps

    return new_tensor


# x: SparseTensor, z: PointTensor
# return: PointTensor
def voxel_to_point(x, z, nearest=False):
    if z.idx_query is None or z.weights is None or z.idx_query.get(
            x.s) is None or z.weights.get(x.s) is None:
        off = get_kernel_offsets(2, x.s, 1, device=z.F.device)
        # old_hash = kernel_hash_gpu(torch.floor(z.C).int(), off)
        old_hash = F.sphash(
            torch.cat([
                torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0],
                z.C[:, -1].int().view(-1, 1)
            ], 1), off)
        mm = x.C.to(z.F.device)
        pc_hash = F.sphash(x.C.to(z.F.device))
        idx_query = F.sphashquery(old_hash, pc_hash)
        weights = F.calc_ti_weights(z.C, idx_query,
                                    scale=x.s[0]).transpose(0, 1).contiguous()
        idx_query = idx_query.transpose(0, 1).contiguous()
        if nearest:
            weights[:, 1:] = 0.
            idx_query[:, 1:] = -1
        new_feat = F.spdevoxelize(x.F, idx_query, weights)
        new_tensor = PointTensor(new_feat,
                                 z.C,
                                 idx_query=z.idx_query,
                                 weights=z.weights)
        new_tensor.additional_features = z.additional_features
        new_tensor.idx_query[x.s] = idx_query
        new_tensor.weights[x.s] = weights
        z.idx_query[x.s] = idx_query
        z.weights[x.s] = weights

    else:
        new_feat = F.spdevoxelize(x.F, z.idx_query.get(x.s),
                                  z.weights.get(x.s))  # - sparse trilinear interpolation operation
        new_tensor = PointTensor(new_feat,
                                 z.C,
                                 idx_query=z.idx_query,
                                 weights=z.weights)
        new_tensor.additional_features = z.additional_features

    return new_tensor


def sparse_to_dense_torch_batch(locs, values, dim, default_val):
    dense = torch.full([dim[0], dim[1], dim[2], dim[3]], float(default_val), device=locs.device)
    dense[locs[:, 0], locs[:, 1], locs[:, 2], locs[:, 3]] = values
    return dense


def sparse_to_dense_torch(locs, values, dim, default_val, device):
    dense = torch.full([dim[0], dim[1], dim[2]], float(default_val), device=device)
    if locs.shape[0] > 0:
        dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values
    return dense


def sparse_to_dense_channel(locs, values, dim, c, default_val, device):
    locs = locs.to(torch.int64)
    dense = torch.full([dim[0], dim[1], dim[2], c], float(default_val), device=device)
    if locs.shape[0] > 0:
        dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values
    return dense


def sparse_to_dense_np(locs, values, dim, default_val):
    dense = np.zeros([dim[0], dim[1], dim[2]], dtype=values.dtype)
    dense.fill(default_val)
    dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values
    return dense
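The sparse-to-dense helpers at the bottom simply scatter per-voxel values into a full grid at integer locations, with a fill value everywhere else. A quick check for the channel variant (sketch, assuming sparse_to_dense_channel above is in scope):

import torch

locs = torch.tensor([[0, 0, 0], [1, 2, 3]])    # integer voxel coordinates
values = torch.tensor([[1., 1.], [2., 2.]])    # one C=2 feature per voxel
dense = sparse_to_dense_channel(locs, values, dim=(4, 4, 4), c=2,
                                default_val=0, device='cpu')
print(dense.shape)       # torch.Size([4, 4, 4, 2])
print(dense[1, 2, 3])    # tensor([2., 2.])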
SparseNeuS_demo_v1/utils/__init__.py
ADDED
File without changes
SparseNeuS_demo_v1/utils/misc_utils.py
ADDED
@@ -0,0 +1,219 @@
import os, torch, cv2, re
import numpy as np

from PIL import Image
import torch.nn.functional as F
import torchvision.transforms as T

# Misc
img2mse = lambda x, y: torch.mean((x - y) ** 2)
mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
mse2psnr2 = lambda x: -10. * np.log(x) / np.log(10.)


def get_psnr(imgs_pred, imgs_gt):
    psnrs = []
    for (img, tar) in zip(imgs_pred, imgs_gt):
        psnrs.append(mse2psnr2(np.mean((img - tar.cpu().numpy()) ** 2)))
    return np.array(psnrs)


def init_log(log, keys):
    for key in keys:
        log[key] = torch.tensor([0.0], dtype=float)
    return log


def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
    """
    depth: (H, W)
    """

    x = np.nan_to_num(depth)  # change nan to 0
    if minmax is None:
        mi = np.min(x[x > 0])  # get minimum positive depth (ignore background)
        ma = np.max(x)
    else:
        mi, ma = minmax

    x = (x - mi) / (ma - mi + 1e-8)  # normalize to 0~1
    x = (255 * x).astype(np.uint8)
    x_ = cv2.applyColorMap(x, cmap)
    return x_, [mi, ma]


def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
    """
    depth: (H, W)
    """
    if type(depth) is not np.ndarray:
        depth = depth.cpu().numpy()

    x = np.nan_to_num(depth)  # change nan to 0
    if minmax is None:
        mi = np.min(x[x > 0])  # get minimum positive depth (ignore background)
        ma = np.max(x)
    else:
        mi, ma = minmax

    x = (x - mi) / (ma - mi + 1e-8)  # normalize to 0~1
    x = (255 * x).astype(np.uint8)
    x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
    x_ = T.ToTensor()(x_)  # (3, H, W)
    return x_, [mi, ma]


def abs_error_numpy(depth_pred, depth_gt, mask):
    depth_pred, depth_gt = depth_pred[mask], depth_gt[mask]
    return np.abs(depth_pred - depth_gt)


def abs_error(depth_pred, depth_gt, mask):
    depth_pred, depth_gt = depth_pred[mask], depth_gt[mask]
    err = depth_pred - depth_gt
    return np.abs(err) if type(depth_pred) is np.ndarray else err.abs()


def acc_threshold(depth_pred, depth_gt, mask, threshold):
    """
    computes the percentage of pixels whose depth error is less than @threshold
    """
    errors = abs_error(depth_pred, depth_gt, mask)
    acc_mask = errors < threshold
    return acc_mask.astype('float') if type(depth_pred) is np.ndarray else acc_mask.float()


def to_tensor_cuda(data, device, filter):
    for item in data.keys():

        if item in filter:
            continue

        if type(data[item]) is np.ndarray:
            data[item] = torch.tensor(data[item], dtype=torch.float32, device=device)
        else:
            data[item] = data[item].float().to(device)
    return data


def to_cuda(data, device, filter):
    for item in data.keys():
        if item in filter:
            continue

        data[item] = data[item].float().to(device)
    return data


def tensor_unsqueeze(data, filter):
    for item in data.keys():
        if item in filter:
            continue

        data[item] = data[item][None]
    return data


def filter_keys(dict):
    dict.pop('N_samples')
    if 'ndc' in dict.keys():
        dict.pop('ndc')
    if 'lindisp' in dict.keys():
        dict.pop('lindisp')
    return dict


def sub_selete_data(data_batch, device, idx, filtKey=[],
                    filtIndex=['view_ids_all', 'c2ws_all', 'scan', 'bbox', 'w2ref', 'ref2w', 'light_id', 'ckpt',
                               'idx']):
    data_sub_selete = {}
    for item in data_batch.keys():
        data_sub_selete[item] = data_batch[item][:, idx].float() if (
                item not in filtIndex and torch.is_tensor(item) and item.dim() > 2) else data_batch[item].float()
        if not data_sub_selete[item].is_cuda:
            data_sub_selete[item] = data_sub_selete[item].to(device)
    return data_sub_selete


def detach_data(dictionary):
    dictionary_new = {}
    for key in dictionary.keys():
        dictionary_new[key] = dictionary[key].detach().clone()
    return dictionary_new


def read_pfm(filename):
    file = open(filename, 'rb')
    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().decode('utf-8').rstrip()
    if header == 'PF':
        color = True
    elif header == 'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0:  # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>'  # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    file.close()
    return data, scale


from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR


# from warmup_scheduler import GradualWarmupScheduler
def get_scheduler(hparams, optimizer):
    eps = 1e-8
    if hparams.lr_scheduler == 'steplr':
        scheduler = MultiStepLR(optimizer, milestones=hparams.decay_step,
                                gamma=hparams.decay_gamma)
    elif hparams.lr_scheduler == 'cosine':
        scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=eps)

    else:
        raise ValueError('scheduler not recognized!')

    # if hparams.warmup_epochs > 0 and hparams.optimizer not in ['radam', 'ranger']:
    #     scheduler = GradualWarmupScheduler(optimizer, multiplier=hparams.warmup_multiplier,
    #                                        total_epoch=hparams.warmup_epochs, after_scheduler=scheduler)
    return scheduler


#### pairing ####
def get_nearest_pose_ids(tar_pose, ref_poses, num_select):
    '''
    Args:
        tar_pose: target pose [N, 4, 4]
        ref_poses: reference poses [M, 4, 4]
        num_select: the number of nearest views to select
    Returns: the selected indices
    '''

    dists = np.linalg.norm(tar_pose[:, None, :3, 3] - ref_poses[None, :, :3, 3], axis=-1)

    sorted_ids = np.argsort(dists, axis=-1)
    selected_ids = sorted_ids[:, :num_select]
    return selected_ids
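get_nearest_pose_ids ranks reference views purely by Euclidean distance between camera centers (the translation column of each 4x4 pose). A quick check with toy poses (sketch, assuming the function above is in scope):

import numpy as np

def toy_pose(t):
    pose = np.eye(4)
    pose[:3, 3] = t        # camera center lives in the translation column
    return pose

tar = np.stack([toy_pose([0, 0, 0])])                        # [1, 4, 4]
refs = np.stack([toy_pose([5, 0, 0]), toy_pose([1, 0, 0]),
                 toy_pose([0, 2, 0])])                       # [3, 4, 4]

print(get_nearest_pose_ids(tar, refs, num_select=2))         # [[1 2]]: distances 1 and 2 beat 5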
configs/sd-objaverse-finetune-c_concat-256.yaml
ADDED
@@ -0,0 +1,117 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "image_target"
    cond_stage_key: "image_cond"
    image_size: 32
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: hybrid
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 100 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder


data:
  target: ldm.data.simple.ObjaverseDataModuleFromConfig
  params:
    root_dir: 'views_whole_sphere'
    batch_size: 192
    num_workers: 16
    total_view: 4
    train:
      validation: False
      image_transforms:
        size: 256

    validation:
      validation: True
      image_transforms:
        size: 256


lightning:
  find_unused_parameters: false
  metrics_over_trainsteps_checkpoint: True
  modelcheckpoint:
    params:
      every_n_train_steps: 5000
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 500
        max_images: 32
        increase_log_steps: False
        log_first_step: True
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 32
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
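Configs in this layout are consumed by loading the YAML and instantiating the dotted `target` paths. A minimal reading sketch with OmegaConf (the instantiation helper is the usual latent-diffusion pattern, assumed to exist in this repo as `ldm.util.instantiate_from_config`):

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/sd-objaverse-finetune-c_concat-256.yaml")
print(cfg.model.target)                                   # ldm.models.diffusion.ddpm.LatentDiffusion
print(cfg.model.params.unet_config.params.in_channels)    # 8, consistent with concatenating a
                                                          # 4-channel condition latent onto the
                                                          # 4-channel noisy latent (c_concat)

# Typical instantiation in latent-diffusion-style repos:
# from ldm.util import instantiate_from_config
# model = instantiate_from_config(cfg.model)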
ldm/data/__init__.py
ADDED
File without changes
ldm/data/base.py
ADDED
@@ -0,0 +1,40 @@
import os
import numpy as np
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset


class Txt2ImgIterableBaseDataset(IterableDataset):
    '''
    Define an interface to make the IterableDatasets for text2img data chainable
    '''
    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        self.size = size

        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass


class PRNGMixin(object):
    """
    Adds a prng property which is a numpy RandomState which gets
    reinitialized whenever the pid changes to avoid synchronized sampling
    behavior when used in conjunction with multiprocessing.
    """
    @property
    def prng(self):
        currentpid = os.getpid()
        if getattr(self, "_initpid", None) != currentpid:
            self._initpid = currentpid
            self._prng = np.random.RandomState()
        return self._prng
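PRNGMixin lazily creates its RandomState per process id, so forked DataLoader workers do not all inherit an identically-seeded generator. A tiny demonstration (sketch, assuming PRNGMixin from the file above is in scope):

import numpy as np

class Sampler(PRNGMixin):
    def draw(self):
        return self.prng.randint(0, 100)

s = Sampler()
print(s.draw(), s.draw())   # one RandomState is created for this pid, then reused;
                            # a forked child process would get a fresh one on first access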
ldm/data/coco.py
ADDED
@@ -0,0 +1,253 @@
import os
import json
import albumentations
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset
from abc import abstractmethod


class CocoBase(Dataset):
    """needed for (image, caption, segmentation) pairs"""
    def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
                 crop_size=None, force_no_crop=False, given_files=None, use_segmentation=True, crop_type=None):
        self.split = self.get_split()
        self.size = size
        if crop_size is None:
            self.crop_size = size
        else:
            self.crop_size = crop_size

        assert crop_type in [None, 'random', 'center']
        self.crop_type = crop_type
        self.use_segmenation = use_segmentation
        self.onehot = onehot_segmentation  # return segmentation as rgb or one hot
        self.stuffthing = use_stuffthing  # include thing in segmentation
        if self.onehot and not self.stuffthing:
            raise NotImplementedError("One hot mode is only supported for the "
                                      "stuffthings version because labels are stored "
                                      "a bit different.")

        data_json = datajson
        with open(data_json) as json_file:
            self.json_data = json.load(json_file)
            self.img_id_to_captions = dict()
            self.img_id_to_filepath = dict()
            self.img_id_to_segmentation_filepath = dict()

        assert data_json.split("/")[-1] in [f"captions_train{self.year()}.json",
                                            f"captions_val{self.year()}.json"]
        # TODO currently hardcoded paths, would be better to follow logic in
        # cocstuff pixelmaps
        if self.use_segmenation:
            if self.stuffthing:
                self.segmentation_prefix = (
                    f"data/cocostuffthings/val{self.year()}" if
                    data_json.endswith(f"captions_val{self.year()}.json") else
                    f"data/cocostuffthings/train{self.year()}")
            else:
                self.segmentation_prefix = (
                    f"data/coco/annotations/stuff_val{self.year()}_pixelmaps" if
                    data_json.endswith(f"captions_val{self.year()}.json") else
                    f"data/coco/annotations/stuff_train{self.year()}_pixelmaps")

        imagedirs = self.json_data["images"]
        self.labels = {"image_ids": list()}
        for imgdir in tqdm(imagedirs, desc="ImgToPath"):
            self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
            self.img_id_to_captions[imgdir["id"]] = list()
            pngfilename = imgdir["file_name"].replace("jpg", "png")
            if self.use_segmenation:
                self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
                    self.segmentation_prefix, pngfilename)
            if given_files is not None:
                if pngfilename in given_files:
                    self.labels["image_ids"].append(imgdir["id"])
            else:
                self.labels["image_ids"].append(imgdir["id"])

        capdirs = self.json_data["annotations"]
        for capdir in tqdm(capdirs, desc="ImgToCaptions"):
            # there are in average 5 captions per image
            # self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
            self.img_id_to_captions[capdir["image_id"]].append(capdir["caption"])

        self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
        if self.split == "validation":
            self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
        else:
            # default option for train is random crop
            if self.crop_type in [None, 'random']:
                self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
            else:
                self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
        self.preprocessor = albumentations.Compose(
            [self.rescaler, self.cropper],
            additional_targets={"segmentation": "image"})
        if force_no_crop:
            self.rescaler = albumentations.Resize(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose(
                [self.rescaler],
                additional_targets={"segmentation": "image"})

    @abstractmethod
    def year(self):
        raise NotImplementedError()

    def __len__(self):
        return len(self.labels["image_ids"])

    def preprocess_image(self, image_path, segmentation_path=None):
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        if segmentation_path:
            segmentation = Image.open(segmentation_path)
            if not self.onehot and not segmentation.mode == "RGB":
                segmentation = segmentation.convert("RGB")
            segmentation = np.array(segmentation).astype(np.uint8)
            if self.onehot:
                assert self.stuffthing
                # stored in caffe format: unlabeled==255. stuff and thing from
                # 0-181. to be compatible with the labels in
                # https://github.com/nightrome/cocostuff/blob/master/labels.txt
                # we shift stuffthing one to the right and put unlabeled in zero
                # as long as segmentation is uint8 shifting to right handles the
                # latter too
                assert segmentation.dtype == np.uint8
                segmentation = segmentation + 1

            processed = self.preprocessor(image=image, segmentation=segmentation)

            image, segmentation = processed["image"], processed["segmentation"]
        else:
            image = self.preprocessor(image=image)['image']

        image = (image / 127.5 - 1.0).astype(np.float32)
        if segmentation_path:
            if self.onehot:
                assert segmentation.dtype == np.uint8
                # make it one hot
                n_labels = 183
                flatseg = np.ravel(segmentation)
                onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool)
                onehot[np.arange(flatseg.size), flatseg] = True
                onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
                segmentation = onehot
            else:
                segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
            return image, segmentation
        else:
            return image

    def __getitem__(self, i):
        img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
        if self.use_segmenation:
            seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
            image, segmentation = self.preprocess_image(img_path, seg_path)
        else:
            image = self.preprocess_image(img_path)
        captions = self.img_id_to_captions[self.labels["image_ids"][i]]
        # randomly draw one of all available captions per image
        caption = captions[np.random.randint(0, len(captions))]
        example = {"image": image,
                   # "caption": [str(caption[0])],
                   "caption": caption,
|
158 |
+
"img_path": img_path,
|
159 |
+
"filename_": img_path.split(os.sep)[-1]
|
160 |
+
}
|
161 |
+
if self.use_segmenation:
|
162 |
+
example.update({"seg_path": seg_path, 'segmentation': segmentation})
|
163 |
+
return example
|
164 |
+
|
165 |
+
|
166 |
+
class CocoImagesAndCaptionsTrain2017(CocoBase):
|
167 |
+
"""returns a pair of (image, caption)"""
|
168 |
+
def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,):
|
169 |
+
super().__init__(size=size,
|
170 |
+
dataroot="data/coco/train2017",
|
171 |
+
datajson="data/coco/annotations/captions_train2017.json",
|
172 |
+
onehot_segmentation=onehot_segmentation,
|
173 |
+
use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
|
174 |
+
|
175 |
+
def get_split(self):
|
176 |
+
return "train"
|
177 |
+
|
178 |
+
def year(self):
|
179 |
+
return '2017'
|
180 |
+
|
181 |
+
|
182 |
+
class CocoImagesAndCaptionsValidation2017(CocoBase):
|
183 |
+
"""returns a pair of (image, caption)"""
|
184 |
+
def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
|
185 |
+
given_files=None):
|
186 |
+
super().__init__(size=size,
|
187 |
+
dataroot="data/coco/val2017",
|
188 |
+
datajson="data/coco/annotations/captions_val2017.json",
|
189 |
+
onehot_segmentation=onehot_segmentation,
|
190 |
+
use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
|
191 |
+
given_files=given_files)
|
192 |
+
|
193 |
+
def get_split(self):
|
194 |
+
return "validation"
|
195 |
+
|
196 |
+
def year(self):
|
197 |
+
return '2017'
|
198 |
+
|
199 |
+
|
200 |
+
|
201 |
+
class CocoImagesAndCaptionsTrain2014(CocoBase):
|
202 |
+
"""returns a pair of (image, caption)"""
|
203 |
+
def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,crop_type='random'):
|
204 |
+
super().__init__(size=size,
|
205 |
+
dataroot="data/coco/train2014",
|
206 |
+
datajson="data/coco/annotations2014/annotations/captions_train2014.json",
|
207 |
+
onehot_segmentation=onehot_segmentation,
|
208 |
+
use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
|
209 |
+
use_segmentation=False,
|
210 |
+
crop_type=crop_type)
|
211 |
+
|
212 |
+
def get_split(self):
|
213 |
+
return "train"
|
214 |
+
|
215 |
+
def year(self):
|
216 |
+
return '2014'
|
217 |
+
|
218 |
+
class CocoImagesAndCaptionsValidation2014(CocoBase):
|
219 |
+
"""returns a pair of (image, caption)"""
|
220 |
+
def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
|
221 |
+
given_files=None,crop_type='center',**kwargs):
|
222 |
+
super().__init__(size=size,
|
223 |
+
dataroot="data/coco/val2014",
|
224 |
+
datajson="data/coco/annotations2014/annotations/captions_val2014.json",
|
225 |
+
onehot_segmentation=onehot_segmentation,
|
226 |
+
use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
|
227 |
+
given_files=given_files,
|
228 |
+
use_segmentation=False,
|
229 |
+
crop_type=crop_type)
|
230 |
+
|
231 |
+
def get_split(self):
|
232 |
+
return "validation"
|
233 |
+
|
234 |
+
def year(self):
|
235 |
+
return '2014'
|
236 |
+
|
237 |
+
if __name__ == '__main__':
|
238 |
+
with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file:
|
239 |
+
json_data = json.load(json_file)
|
240 |
+
capdirs = json_data["annotations"]
|
241 |
+
import pudb; pudb.set_trace()
|
242 |
+
#d2 = CocoImagesAndCaptionsTrain2014(size=256)
|
243 |
+
d2 = CocoImagesAndCaptionsValidation2014(size=256)
|
244 |
+
print("constructed dataset.")
|
245 |
+
print(f"length of {d2.__class__.__name__}: {len(d2)}")
|
246 |
+
|
247 |
+
ex2 = d2[0]
|
248 |
+
# ex3 = d3[0]
|
249 |
+
# print(ex1["image"].shape)
|
250 |
+
print(ex2["image"].shape)
|
251 |
+
# print(ex3["image"].shape)
|
252 |
+
# print(ex1["segmentation"].shape)
|
253 |
+
print(ex2["caption"].__class__.__name__)
|
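For quick orientation, the one-hot branch in preprocess_image above turns a (H, W) uint8 label map into a (H, W, 183) array with exactly one active channel per pixel. A standalone sketch of just that step (the 4x4 label map is made up):

import numpy as np

segmentation = np.random.randint(0, 183, size=(4, 4)).astype(np.uint8)
n_labels = 183
flatseg = np.ravel(segmentation)
onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
onehot[np.arange(flatseg.size), flatseg] = True      # one True per pixel
onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
assert onehot.shape == (4, 4, 183) and onehot.sum() == 16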
ldm/data/dummy.py
ADDED
@@ -0,0 +1,34 @@
import numpy as np
import random
import string
from torch.utils.data import Dataset, Subset

class DummyData(Dataset):
    def __init__(self, length, size):
        self.length = length
        self.size = size

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        x = np.random.randn(*self.size)
        letters = string.ascii_lowercase
        y = ''.join(random.choice(letters) for i in range(10))
        return {"jpg": x, "txt": y}


class DummyDataWithEmbeddings(Dataset):
    def __init__(self, length, size, emb_size):
        self.length = length
        self.size = size
        self.emb_size = emb_size

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        x = np.random.randn(*self.size)
        y = np.random.randn(*self.emb_size).astype(np.float32)
        return {"jpg": x, "txt": y}
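A quick smoke test of the dummy datasets above; the sizes are arbitrary, and emb_size=(77, 768) is just an assumption echoing a CLIP-like text-embedding shape:

from torch.utils.data import DataLoader
from ldm.data.dummy import DummyDataWithEmbeddings

ds = DummyDataWithEmbeddings(length=8, size=(256, 256, 3), emb_size=(77, 768))
batch = next(iter(DataLoader(ds, batch_size=4)))  # default collate stacks the numpy arrays
print(batch["jpg"].shape, batch["txt"].shape)     # [4, 256, 256, 3] and [4, 77, 768]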
ldm/data/imagenet.py
ADDED
@@ -0,0 +1,394 @@
import os, yaml, pickle, shutil, tarfile, glob
import cv2
import albumentations
import PIL
import numpy as np
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from functools import partial
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, Subset

import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
from taming.data.imagenet import ImagePaths

from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light


def synset2idx(path_to_yaml="data/index_synset.yaml"):
    with open(path_to_yaml) as f:
        di2s = yaml.load(f, Loader=yaml.SafeLoader)  # explicit Loader required by PyYAML >= 5
    return dict((v, k) for k, v in di2s.items())


class ImageNetBase(Dataset):
    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        if not type(self.config) == dict:
            self.config = OmegaConf.to_container(self.config)
        self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
        self.process_images = True  # if False we skip loading & processing images and self.data contains filepaths
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._prepare_human_to_integer_label()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        ignore = set([
            "n06596364_9591.JPEG",
        ])
        relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
        if "sub_indices" in self.config:
            indices = str_to_indices(self.config["sub_indices"])
            synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn)  # returns a list of strings
            self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
            files = []
            for rpath in relpaths:
                syn = rpath.split("/")[0]
                if syn in synsets:
                    files.append(rpath)
            return files
        else:
            return relpaths

    def _prepare_synset_to_human(self):
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict) == SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if (not os.path.exists(self.idx2syn)):
            download(URL, self.idx2syn)

    def _prepare_human_to_integer_label(self):
        URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
        self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
        if (not os.path.exists(self.human2integer)):
            download(URL, self.human2integer)
        with open(self.human2integer, "r") as f:
            lines = f.read().splitlines()
            assert len(lines) == 1000
            self.human2integer_dict = dict()
            for line in lines:
                value, key = line.split(":")
                self.human2integer_dict[key] = int(value)

    def _load(self):
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
            l1 = len(self.relpaths)
            self.relpaths = self._filter_relpaths(self.relpaths)
            print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
        if not self.keep_orig_class_label:
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }

        if self.process_images:
            self.size = retrieve(self.config, "size", default=256)
            self.data = ImagePaths(self.abspaths,
                                   labels=labels,
                                   size=self.size,
                                   random_crop=self.random_crop,
                                   )
        else:
            self.data = self.abspaths


class ImageNetTrain(ImageNetBase):
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.process_images = process_images
        self.data_root = data_root
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)

        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                print("Extracting sub-tars.")
                subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
                for subpath in tqdm(subpaths):
                    subdir = subpath[:-len(".tar")]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, "r:") as tar:
                        tar.extractall(path=subdir)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist) + "\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetValidation(ImageNetBase):
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.data_root = data_root
        self.process_images = process_images
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
            self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        if not tdu.is_prepared(self.root):
            # prep
            print("Preparing dataset {} in {}".format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
                    import academictorrents as at
                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print("Extracting {} to {}".format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                with tarfile.open(path, "r:") as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]:
                    download(self.VS_URL, vspath)

                with open(vspath, "r") as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print("Reorganizing into synset folders")
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = "\n".join(filelist) + "\n"
            with open(self.txt_filelist, "w") as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)



class ImageNetSR(Dataset):
    def __init__(self, size=None,
                 degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
                 random_crop=True):
        """
        ImageNet super-resolution dataloader.
        Performs the following ops in order:
        1. crops a crop of size s from the image, either as a random or a center crop
        2. resizes the crop to `size` with cv2 area interpolation
        3. degrades the resized crop with degradation_fn

        :param size: resizing to size after cropping
        :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
        :param downscale_f: low-resolution downsample factor
        :param min_crop_f: determines crop size s,
          where s = c * min_img_side_len with c sampled from the interval (min_crop_f, max_crop_f)
        :param max_crop_f: upper bound for the crop factor c, see min_crop_f
        :param random_crop: use random instead of center crops
        """
        self.base = self.get_base()
        assert size
        assert (size / downscale_f).is_integer()
        self.size = size
        self.LR_size = int(size / downscale_f)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert(max_crop_f <= 1.)
        self.center_crop = not random_crop

        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

        self.pil_interpolation = False  # gets reset below in case interp_op is from pillow

        if degradation == "bsrgan":
            self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)

        elif degradation == "bsrgan_light":
            self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)

        else:
            interpolation_fn = {
                "cv_nearest": cv2.INTER_NEAREST,
                "cv_bilinear": cv2.INTER_LINEAR,
                "cv_bicubic": cv2.INTER_CUBIC,
                "cv_area": cv2.INTER_AREA,
                "cv_lanczos": cv2.INTER_LANCZOS4,
                "pil_nearest": PIL.Image.NEAREST,
                "pil_bilinear": PIL.Image.BILINEAR,
                "pil_bicubic": PIL.Image.BICUBIC,
                "pil_box": PIL.Image.BOX,
                "pil_hamming": PIL.Image.HAMMING,
                "pil_lanczos": PIL.Image.LANCZOS,
            }[degradation]

            self.pil_interpolation = degradation.startswith("pil_")

            if self.pil_interpolation:
                self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)

            else:
                self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
                                                                          interpolation=interpolation_fn)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        image = Image.open(example["file_path_"])

        if not image.mode == "RGB":
            image = image.convert("RGB")

        image = np.array(image).astype(np.uint8)

        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
        crop_side_len = int(crop_side_len)

        if self.center_crop:
            self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)

        else:
            self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)

        image = self.cropper(image=image)["image"]
        image = self.image_rescaler(image=image)["image"]

        if self.pil_interpolation:
            image_pil = PIL.Image.fromarray(image)
            LR_image = self.degradation_process(image_pil)
            LR_image = np.array(LR_image).astype(np.uint8)

        else:
            LR_image = self.degradation_process(image=image)["image"]

        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32)
        example["caption"] = example["human_label"]  # dummy caption
        return example


class ImageNetSRTrain(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open("data/imagenet_train_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetTrain(process_images=False,)
        return Subset(dset, indices)


class ImageNetSRValidation(ImageNetSR):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open("data/imagenet_val_hr_indices.p", "rb") as f:
            indices = pickle.load(f)
        dset = ImageNetValidation(process_images=False,)
        return Subset(dset, indices)
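To see the ImageNetSR crop-then-rescale rule (steps 1 and 2 of its docstring) without downloading ImageNet, here is a self-contained sketch on a random image; the 480x640 input shape and the seed are made up:

import cv2
import numpy as np
import albumentations

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(480, 640, 3), dtype=np.uint8)

min_crop_f, max_crop_f, size = 0.5, 1.0, 256
# s = c * min_img_side_len with c ~ U(min_crop_f, max_crop_f)
crop_side_len = int(min(image.shape[:2]) * rng.uniform(min_crop_f, max_crop_f))
crop = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)(image=image)["image"]
resized = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)(image=crop)["image"]
print(crop.shape, resized.shape)  # square crop, then (256, 256, 3)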
ldm/data/inpainting/__init__.py
ADDED
File without changes
ldm/data/inpainting/synthetic_mask.py
ADDED
@@ -0,0 +1,166 @@
from PIL import Image, ImageDraw
import numpy as np

settings = {
    "256narrow": {
        "p_irr": 1,
        "min_n_irr": 4,
        "max_n_irr": 50,
        "max_l_irr": 40,
        "max_w_irr": 10,
        "min_n_box": None,
        "max_n_box": None,
        "min_s_box": None,
        "max_s_box": None,
        "marg": None,
    },
    "256train": {
        "p_irr": 0.5,
        "min_n_irr": 1,
        "max_n_irr": 5,
        "max_l_irr": 200,
        "max_w_irr": 100,
        "min_n_box": 1,
        "max_n_box": 4,
        "min_s_box": 30,
        "max_s_box": 150,
        "marg": 10,
    },
    "512train": {  # TODO: experimental
        "p_irr": 0.5,
        "min_n_irr": 1,
        "max_n_irr": 5,
        "max_l_irr": 450,
        "max_w_irr": 250,
        "min_n_box": 1,
        "max_n_box": 4,
        "min_s_box": 30,
        "max_s_box": 300,
        "marg": 10,
    },
    "512train-large": {  # TODO: experimental
        "p_irr": 0.5,
        "min_n_irr": 1,
        "max_n_irr": 5,
        "max_l_irr": 450,
        "max_w_irr": 400,
        "min_n_box": 1,
        "max_n_box": 4,
        "min_s_box": 75,
        "max_s_box": 450,
        "marg": 10,
    },
}


def gen_segment_mask(mask, start, end, brush_width):
    mask = mask > 0
    mask = (255 * mask).astype(np.uint8)
    mask = Image.fromarray(mask)
    draw = ImageDraw.Draw(mask)
    draw.line([start, end], fill=255, width=brush_width, joint="curve")
    mask = np.array(mask) / 255
    return mask


def gen_box_mask(mask, masked):
    x_0, y_0, w, h = masked
    mask[y_0:y_0 + h, x_0:x_0 + w] = 1
    return mask


def gen_round_mask(mask, masked, radius):
    x_0, y_0, w, h = masked
    xy = [(x_0, y_0), (x_0 + w, y_0 + h)]  # lower-right corner uses h for the height

    mask = mask > 0
    mask = (255 * mask).astype(np.uint8)
    mask = Image.fromarray(mask)
    draw = ImageDraw.Draw(mask)
    draw.rounded_rectangle(xy, radius=radius, fill=255)
    mask = np.array(mask) / 255
    return mask


def gen_large_mask(prng, img_h, img_w,
                   marg, p_irr, min_n_irr, max_n_irr, max_l_irr, max_w_irr,
                   min_n_box, max_n_box, min_s_box, max_s_box):
    """
    img_h: int, an image height
    img_w: int, an image width
    marg: int, a margin for a box starting coordinate
    p_irr: float, 0 <= p_irr <= 1, the probability of a polygonal chain mask

    min_n_irr: int, min number of segments
    max_n_irr: int, max number of segments
    max_l_irr: max length of a segment in the polygonal chain
    max_w_irr: max width of a segment in the polygonal chain

    min_n_box: int, min bound for the number of box primitives
    max_n_box: int, max bound for the number of box primitives
    min_s_box: int, min length of a box side
    max_s_box: int, max length of a box side
    """

    mask = np.zeros((img_h, img_w))
    uniform = prng.randint

    if prng.uniform(0, 1) < p_irr:  # generate polygonal chain; draw from prng for reproducibility
        n = uniform(min_n_irr, max_n_irr)  # sample number of segments

        for _ in range(n):
            y = uniform(0, img_h)  # sample a starting point
            x = uniform(0, img_w)

            a = uniform(0, 360)  # sample angle in degrees
            l = uniform(10, max_l_irr)  # sample segment length
            w = uniform(5, max_w_irr)  # sample a segment width

            # draw segment starting from (x,y) to (x_,y_) using a brush of width w
            x_ = x + l * np.sin(np.deg2rad(a))  # angle is in degrees, convert for sin/cos
            y_ = y + l * np.cos(np.deg2rad(a))

            mask = gen_segment_mask(mask, start=(x, y), end=(x_, y_), brush_width=w)
            x, y = x_, y_
    else:  # generate box masks
        n = uniform(min_n_box, max_n_box)  # sample number of rectangles

        for _ in range(n):
            h = uniform(min_s_box, max_s_box)  # sample box shape
            w = uniform(min_s_box, max_s_box)

            x_0 = uniform(marg, img_w - marg - w)  # sample upper-left coordinates of box
            y_0 = uniform(marg, img_h - marg - h)

            if prng.uniform(0, 1) < 0.5:
                mask = gen_box_mask(mask, masked=(x_0, y_0, w, h))
            else:
                r = uniform(0, 60)  # sample radius
                mask = gen_round_mask(mask, masked=(x_0, y_0, w, h), radius=r)
    return mask


make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"])
make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"])
make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"])
make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"])


MASK_MODES = {
    "256train": make_lama_mask,
    "256narrow": make_narrow_lama_mask,
    "512train": make_512_lama_mask,
    "512train-large": make_512_lama_mask_large
}

if __name__ == "__main__":
    import sys

    out = sys.argv[1]

    prng = np.random.RandomState(1)
    kwargs = settings["256train"]
    mask = gen_large_mask(prng, 256, 256, **kwargs)
    mask = (255 * mask).astype(np.uint8)
    mask = Image.fromarray(mask)
    mask.save(out)
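The MASK_MODES registry above is how downstream transforms select a generator (AddMask in ldm/data/laion.py, listed further below, does exactly this). A minimal sketch of that pattern, with a seeded RandomState for reproducibility:

import numpy as np
from ldm.data.inpainting.synthetic_mask import MASK_MODES

prng = np.random.RandomState(42)
mask = MASK_MODES["512train"](prng, 512, 512)  # float array in [0, 1], shape (512, 512)
mask = (mask > 0.5).astype(np.float32)         # binarize, roughly as AddMask does downstream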
ldm/data/laion.py
ADDED
@@ -0,0 +1,537 @@
import webdataset as wds
import kornia
from PIL import Image
import io
import os
import torchvision
import glob
import random
import numpy as np
import pytorch_lightning as pl
from tqdm import tqdm
from omegaconf import OmegaConf
from einops import rearrange
import torch
import torch.nn as nn  # needed for nn.Identity below
from webdataset.handlers import warn_and_continue


from ldm.util import instantiate_from_config
from ldm.data.inpainting.synthetic_mask import gen_large_mask, MASK_MODES
from ldm.data.base import PRNGMixin


class DataWithWings(torch.utils.data.IterableDataset):
    def __init__(self, min_size, transform=None, target_transform=None):
        # NOTE: OnDiskKV and mmh3 are external dependencies that are not
        # imported in this file, so this class is not usable as shipped.
        self.min_size = min_size
        self.transform = transform if transform is not None else nn.Identity()
        self.target_transform = target_transform if target_transform is not None else nn.Identity()
        self.kv = OnDiskKV(file='/home/ubuntu/laion5B-watermark-safety-ordered', key_format='q', value_format='ee')
        self.kv_aesthetic = OnDiskKV(file='/home/ubuntu/laion5B-aesthetic-tags-kv', key_format='q', value_format='e')
        self.pwatermark_threshold = 0.8
        self.punsafe_threshold = 0.5
        self.aesthetic_threshold = 5.
        self.total_samples = 0
        self.samples = 0
        location = 'pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-data/{000000..231349}.tar -'

        self.inner_dataset = wds.DataPipeline(
            wds.ResampledShards(location),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(1000, handler=wds.warn_and_continue),
            wds.decode('pilrgb', handler=wds.warn_and_continue),
            wds.map(self._add_tags, handler=wds.ignore_and_continue),
            wds.select(self._filter_predicate),
            wds.map_dict(jpg=self.transform, txt=self.target_transform, punsafe=self._punsafe_to_class, handler=wds.warn_and_continue),
            wds.to_tuple('jpg', 'txt', 'punsafe', handler=wds.warn_and_continue),
        )

    @staticmethod
    def _compute_hash(url, text):
        if url is None:
            url = ''
        if text is None:
            text = ''
        total = (url + text).encode('utf-8')
        return mmh3.hash64(total)[0]  # requires the mmh3 package

    def _add_tags(self, x):
        hsh = self._compute_hash(x['json']['url'], x['txt'])
        pwatermark, punsafe = self.kv[hsh]
        aesthetic = self.kv_aesthetic[hsh][0]
        return {**x, 'pwatermark': pwatermark, 'punsafe': punsafe, 'aesthetic': aesthetic}

    def _punsafe_to_class(self, punsafe):
        return torch.tensor(punsafe >= self.punsafe_threshold).long()

    def _filter_predicate(self, x):
        try:
            return x['pwatermark'] < self.pwatermark_threshold and x['aesthetic'] >= self.aesthetic_threshold and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
        except Exception:
            return False

    def __iter__(self):
        return iter(self.inner_dataset)


def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Take a list of samples (as dictionaries) and create a batch, preserving the keys.
    If `combine_tensors` is True, torch.Tensor and ndarray values are stacked into
    batched tensors/arrays; if `combine_scalars` is True, scalars are combined into arrays.
    :param samples: list of sample dicts
    :returns: a single dict representing the batch
    """
    keys = set.intersection(*[set(sample.keys()) for sample in samples])
    batched = {key: [] for key in keys}

    for s in samples:
        [batched[key].append(s[key]) for key in batched]

    result = {}
    for key in batched:
        if isinstance(batched[key][0], (int, float)):
            if combine_scalars:
                result[key] = np.array(list(batched[key]))
        elif isinstance(batched[key][0], torch.Tensor):
            if combine_tensors:
                result[key] = torch.stack(list(batched[key]))
        elif isinstance(batched[key][0], np.ndarray):
            if combine_tensors:
                result[key] = np.array(list(batched[key]))
        else:
            result[key] = list(batched[key])
    return result


class WebDataModuleFromConfig(pl.LightningDataModule):
    def __init__(self, tar_base, batch_size, train=None, validation=None,
                 test=None, num_workers=4, multinode=True, min_size=None,
                 max_pwatermark=1.0,
                 **kwargs):
        super().__init__()
        print(f'Setting tar base to {tar_base}')
        self.tar_base = tar_base
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train = train
        self.validation = validation
        self.test = test
        self.multinode = multinode
        self.min_size = min_size  # filter out very small images
        self.max_pwatermark = max_pwatermark  # filter out watermarked images

    def make_loader(self, dataset_config, train=True):
        if 'image_transforms' in dataset_config:
            image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
        else:
            image_transforms = []

        image_transforms.extend([torchvision.transforms.ToTensor(),
                                 torchvision.transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        image_transforms = torchvision.transforms.Compose(image_transforms)

        if 'transforms' in dataset_config:
            transforms_config = OmegaConf.to_container(dataset_config.transforms)
        else:
            transforms_config = dict()

        # NOTE: load_partial_from_config and identity are referenced here but
        # not imported in this file.
        transform_dict = {dkey: load_partial_from_config(transforms_config[dkey])
                          if transforms_config[dkey] != 'identity' else identity
                          for dkey in transforms_config}
        img_key = dataset_config.get('image_key', 'jpeg')
        transform_dict.update({img_key: image_transforms})

        if 'postprocess' in dataset_config:
            postprocess = instantiate_from_config(dataset_config['postprocess'])
        else:
            postprocess = None

        shuffle = dataset_config.get('shuffle', 0)
        shardshuffle = shuffle > 0

        nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only

        if self.tar_base == "__improvedaesthetic__":
            print("## Warning: loading the same improved-aesthetics dataset "
                  "for all splits and ignoring the shards parameter.")
            tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
        else:
            tars = os.path.join(self.tar_base, dataset_config.shards)

        dset = wds.WebDataset(
            tars,
            nodesplitter=nodesplitter,
            shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)
        print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')

        dset = (dset
                .select(self.filter_keys)
                .decode('pil', handler=wds.warn_and_continue)
                .select(self.filter_size)
                .map_dict(**transform_dict, handler=wds.warn_and_continue)
                )
        if postprocess is not None:
            dset = dset.map(postprocess)
        dset = (dset
                .batched(self.batch_size, partial=False,
                         collation_fn=dict_collation_fn)
                )

        loader = wds.WebLoader(dset, batch_size=None, shuffle=False,
                               num_workers=self.num_workers)

        return loader

    def filter_size(self, x):
        try:
            valid = True
            if self.min_size is not None and self.min_size > 1:
                try:
                    valid = valid and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
                except Exception:
                    valid = False
            if self.max_pwatermark is not None and self.max_pwatermark < 1.0:
                try:
                    valid = valid and x['json']['pwatermark'] <= self.max_pwatermark
                except Exception:
                    valid = False
            return valid
        except Exception:
            return False

    def filter_keys(self, x):
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    def train_dataloader(self):
        return self.make_loader(self.train)

    def val_dataloader(self):
        return self.make_loader(self.validation, train=False)

    def test_dataloader(self):
        return self.make_loader(self.test, train=False)


from ldm.modules.image_degradation import degradation_fn_bsr_light
import cv2

class AddLR(object):
    def __init__(self, factor, output_size, initial_size=None, image_key="jpg"):
        self.factor = factor
        self.output_size = output_size
        self.image_key = image_key
        self.initial_size = initial_size

    def pt2np(self, x):
        x = ((x + 1.0) * 127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy()
        return x

    def np2pt(self, x):
        x = torch.from_numpy(x) / 127.5 - 1.0
        return x

    def __call__(self, sample):
        # sample['jpg'] is tensor hwc in [-1, 1] at this point
        x = self.pt2np(sample[self.image_key])
        if self.initial_size is not None:
            x = cv2.resize(x, (self.initial_size, self.initial_size), interpolation=2)
        x = degradation_fn_bsr_light(x, sf=self.factor)['image']
        x = cv2.resize(x, (self.output_size, self.output_size), interpolation=2)
        x = self.np2pt(x)
        sample['lr'] = x
        return sample

class AddBW(object):
    def __init__(self, image_key="jpg"):
        self.image_key = image_key

    def pt2np(self, x):
        x = ((x + 1.0) * 127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy()
        return x

    def np2pt(self, x):
        x = torch.from_numpy(x) / 127.5 - 1.0
        return x

    def __call__(self, sample):
        # sample['jpg'] is tensor hwc in [-1, 1] at this point
        x = sample[self.image_key]
        w = torch.rand(3, device=x.device)
        w /= w.sum()
        out = torch.einsum('hwc,c->hw', x, w)

        # Keep as 3ch so we can pass to encoder, also we might want to add hints
        sample['lr'] = out.unsqueeze(-1).tile(1, 1, 3)
        return sample

class AddMask(PRNGMixin):
    def __init__(self, mode="512train", p_drop=0.):
        super().__init__()
        assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
        self.make_mask = MASK_MODES[mode]
        self.p_drop = p_drop

    def __call__(self, sample):
        # sample['jpg'] is tensor hwc in [-1, 1] at this point
        x = sample['jpg']
        mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
        if self.prng.choice(2, p=[1 - self.p_drop, self.p_drop]):
            mask = np.ones_like(mask)
        mask[mask < 0.5] = 0
        mask[mask > 0.5] = 1
        mask = torch.from_numpy(mask[..., None])
        sample['mask'] = mask
        sample['masked_image'] = x * (mask < 0.5)
        return sample


class AddEdge(PRNGMixin):
    def __init__(self, mode="512train", mask_edges=True):
        super().__init__()
        assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
        self.make_mask = MASK_MODES[mode]
        self.n_down_choices = [0]
        self.sigma_choices = [1, 2]
        self.mask_edges = mask_edges

    @torch.no_grad()
    def __call__(self, sample):
        # sample['jpg'] is tensor hwc in [-1, 1] at this point
        x = sample['jpg']

        mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
        mask[mask < 0.5] = 0
        mask[mask > 0.5] = 1
        mask = torch.from_numpy(mask[..., None])
        sample['mask'] = mask

        n_down_idx = self.prng.choice(len(self.n_down_choices))
        sigma_idx = self.prng.choice(len(self.sigma_choices))

        n_choices = len(self.n_down_choices) * len(self.sigma_choices)
        raveled_idx = np.ravel_multi_index((n_down_idx, sigma_idx),
                                           (len(self.n_down_choices), len(self.sigma_choices)))
        normalized_idx = raveled_idx / max(1, n_choices - 1)

        n_down = self.n_down_choices[n_down_idx]
        sigma = self.sigma_choices[sigma_idx]

        kernel_size = 4 * sigma + 1
        kernel_size = (kernel_size, kernel_size)
        sigma = (sigma, sigma)
        canny = kornia.filters.Canny(
            low_threshold=0.1,
            high_threshold=0.2,
            kernel_size=kernel_size,
            sigma=sigma,
            hysteresis=True,
        )
        y = (x + 1.0) / 2.0  # in [0, 1]
        y = y.unsqueeze(0).permute(0, 3, 1, 2).contiguous()

        # down
        for i_down in range(n_down):
            size = min(y.shape[-2], y.shape[-1]) // 2
            y = kornia.geometry.transform.resize(y, size, antialias=True)

        # edge
        _, y = canny(y)

        if n_down > 0:
            size = x.shape[0], x.shape[1]
            y = kornia.geometry.transform.resize(y, size, interpolation="nearest")

        y = y.permute(0, 2, 3, 1)[0].expand(-1, -1, 3).contiguous()
        y = y * 2.0 - 1.0

        if self.mask_edges:
            sample['masked_image'] = y * (mask < 0.5)
        else:
            sample['masked_image'] = y
            sample['mask'] = torch.zeros_like(sample['mask'])

        # concat normalized idx
        sample['smoothing_strength'] = torch.ones_like(sample['mask']) * normalized_idx

        return sample


def example00():
    url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
    dataset = wds.WebDataset(url)
    example = next(iter(dataset))
    for k in example:
        print(k, type(example[k]))

    print(example["__key__"])
    for k in ["json", "txt"]:
        print(example[k].decode())

    image = Image.open(io.BytesIO(example["jpg"]))
    outdir = "tmp"
    os.makedirs(outdir, exist_ok=True)
    image.save(os.path.join(outdir, example["__key__"] + ".png"))

    def load_example(example):
        return {
            "key": example["__key__"],
            "image": Image.open(io.BytesIO(example["jpg"])),
            "text": example["txt"].decode(),
        }

    for i, example in tqdm(enumerate(dataset)):
        ex = load_example(example)
        print(ex["image"].size, ex["text"])
        if i >= 100:
            break


def example01():
    # the first laion shards contain ~10k examples each
    url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/{000000..000002}.tar -"

    batch_size = 3
    shuffle_buffer = 10000
    dset = wds.WebDataset(
        url,
        nodesplitter=wds.shardlists.split_by_node,
        shardshuffle=True,
    )
    dset = (dset
            .shuffle(shuffle_buffer, initial=shuffle_buffer)
            .decode('pil', handler=warn_and_continue)
            .batched(batch_size, partial=False,
                     collation_fn=dict_collation_fn)
            )

    num_workers = 2
    loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=num_workers)

    batch_sizes = list()
    keys_per_epoch = list()
    for epoch in range(5):
        keys = list()
        for batch in tqdm(loader):
            batch_sizes.append(len(batch["__key__"]))
            keys.append(batch["__key__"])

        for bs in batch_sizes:
            assert bs == batch_size
        print(f"{len(batch_sizes)} batches of size {batch_size}.")
        batch_sizes = list()

        keys_per_epoch.append(keys)
        for i_batch in [0, 1, -1]:
            print(f"Batch {i_batch} of epoch {epoch}:")
            print(keys[i_batch])
        print("next epoch.")


def example02():
    from omegaconf import OmegaConf
    from torch.utils.data.distributed import DistributedSampler
    from torch.utils.data import IterableDataset
    from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
    from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

    #config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
    #config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
    config = OmegaConf.load("configs/stable-diffusion/txt2img-v2-clip-encoder-improved_aesthetics-256.yaml")
    datamod = WebDataModuleFromConfig(**config["data"]["params"])
    dataloader = datamod.train_dataloader()

    for batch in dataloader:
        print(batch.keys())
        print(batch["jpg"].shape)
        break


def example03():
    # improved aesthetics
    tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
    dataset = wds.WebDataset(tars)

    def filter_keys(x):
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    def filter_size(x):
        try:
            return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512
        except Exception:
            return False

    def filter_watermark(x):
        try:
            return x['json']['pwatermark'] < 0.5
        except Exception:
            return False

    dataset = (dataset
               .select(filter_keys)
               .decode('pil', handler=wds.warn_and_continue))
    n_save = 20
    n_total = 0
    n_large = 0
    n_large_nowm = 0
    for i, example in enumerate(dataset):
        n_total += 1
        if filter_size(example):
            n_large += 1
            if filter_watermark(example):
                n_large_nowm += 1
                if n_large_nowm < n_save + 1:
                    image = example["jpg"]
                    image.save(os.path.join("tmp", f"{n_large_nowm - 1:06}.png"))

        if i % 500 == 0:
            print(i)
            print(f"Large: {n_large}/{n_total} | {n_large / n_total * 100:.2f}%")
            if n_large > 0:
                print(f"No Watermark: {n_large_nowm}/{n_large} | {n_large_nowm / n_large * 100:.2f}%")



def example04():
    # improved aesthetics
    for i_shard in range(60208)[::-1]:
        print(i_shard)
        tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{:06}.tar -".format(i_shard)
        dataset = wds.WebDataset(tars)

        def filter_keys(x):
            try:
                return ("jpg" in x) and ("txt" in x)
            except Exception:
                return False

        def filter_size(x):
            try:
                return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512
            except Exception:
                return False

        dataset = (dataset
                   .select(filter_keys)
                   .decode('pil', handler=wds.warn_and_continue))
        try:
            example = next(iter(dataset))
        except Exception:
            print(f"Error @ {i_shard}")


if __name__ == "__main__":
    #example01()
    #example02()
    example03()
    #example04()
ldm/data/lsun.py
ADDED
@@ -0,0 +1,92 @@
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class LSUNBase(Dataset):
    def __init__(self,
                 txt_file,
                 data_root,
                 size=None,
                 interpolation="bicubic",
                 flip_p=0.5
                 ):
        self.data_paths = txt_file
        self.data_root = data_root
        with open(self.data_paths, "r") as f:
            self.image_paths = f.read().splitlines()
        self._length = len(self.image_paths)
        self.labels = {
            "relative_file_path_": [l for l in self.image_paths],
            "file_path_": [os.path.join(self.data_root, l)
                           for l in self.image_paths],
        }

        self.size = size
        self.interpolation = {"linear": PIL.Image.LINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = dict((k, self.labels[k][i]) for k in self.labels)
        image = Image.open(example["file_path_"])
        if not image.mode == "RGB":
            image = image.convert("RGB")

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)
        crop = min(img.shape[0], img.shape[1])
        h, w = img.shape[0], img.shape[1]
        img = img[(h - crop) // 2:(h + crop) // 2,
                  (w - crop) // 2:(w + crop) // 2]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example


class LSUNChurchesTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)


class LSUNChurchesValidation(LSUNBase):
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
                         flip_p=flip_p, **kwargs)


class LSUNBedroomsTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)


class LSUNBedroomsValidation(LSUNBase):
    def __init__(self, flip_p=0.0, **kwargs):
        super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
                         flip_p=flip_p, **kwargs)


class LSUNCatsTrain(LSUNBase):
    def __init__(self, **kwargs):
        super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)


class LSUNCatsValidation(LSUNBase):
    def __init__(self, flip_p=0., **kwargs):
        super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
                         flip_p=flip_p, **kwargs)
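Usage note: these are plain torch Datasets; a minimal sketch, assuming the LSUN list files referenced above have been prepared under data/lsun/:

from torch.utils.data import DataLoader
from ldm.data.lsun import LSUNChurchesTrain

train_ds = LSUNChurchesTrain(size=256)  # center-crop then resize to 256x256
loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=2)
batch = next(iter(loader))
print(batch["image"].shape)  # (4, 256, 256, 3), float32 in [-1, 1]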
ldm/data/nerf_like.py
ADDED
@@ -0,0 +1,165 @@
from torch.utils.data import Dataset
import os
import json
import numpy as np
import torch
import imageio
import math
import cv2
from torchvision import transforms

def cartesian_to_spherical(xyz):
    ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
    xy = xyz[:, 0]**2 + xyz[:, 1]**2
    z = np.sqrt(xy + xyz[:, 2]**2)
    theta = np.arctan2(np.sqrt(xy), xyz[:, 2])  # for elevation angle defined from Z-axis down
    # ptsnew[:, 4] = np.arctan2(xyz[:, 2], np.sqrt(xy))  # for elevation angle defined from XY-plane up
    azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
    return np.array([theta, azimuth, z])


def get_T(T_target, T_cond):
    theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :])
    theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])

    d_theta = theta_target - theta_cond
    d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
    d_z = z_target - z_cond

    d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
    return d_T

def get_spherical(T_target, T_cond):
    theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :])
    theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])

    d_theta = theta_target - theta_cond
    d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
    d_z = z_target - z_cond

    d_T = torch.tensor([math.degrees(d_theta.item()), math.degrees(d_azimuth.item()), d_z.item()])
    return d_T
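To make the convention concrete, a small self-contained check of the helpers above (camera centers chosen arbitrarily); get_T returns (Δelevation, sin Δazimuth, cos Δazimuth, Δradius):

import numpy as np
from ldm.data.nerf_like import get_T

# two camera centers on a sphere of radius 2, 90 degrees apart in azimuth
T_cond = np.array([2.0, 0.0, 0.0])
T_target = np.array([0.0, 2.0, 0.0])
print(get_T(T_target, T_cond))
# tensor([0., 1., ~0., 0.]) -> same elevation and radius, a pure 90-degree azimuth change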
class RTMV(Dataset):
    def __init__(self, root_dir='datasets/RTMV/google_scanned',
                 first_K=64, resolution=256, load_target=False):
        self.root_dir = root_dir
        self.scene_list = sorted(next(os.walk(root_dir))[1])
        self.resolution = resolution
        self.first_K = first_K
        self.load_target = load_target

    def __len__(self):
        return len(self.scene_list)

    def __getitem__(self, idx):
        scene_dir = os.path.join(self.root_dir, self.scene_list[idx])
        with open(os.path.join(scene_dir, 'transforms.json'), "r") as f:
            meta = json.load(f)
        imgs = []
        poses = []
        for i_img in range(self.first_K):
            meta_img = meta['frames'][i_img]

            if i_img == 0 or self.load_target:
                img_path = os.path.join(scene_dir, meta_img['file_path'])
                img = imageio.imread(img_path)
                img = cv2.resize(img, (self.resolution, self.resolution), interpolation=cv2.INTER_LINEAR)
                imgs.append(img)

            c2w = meta_img['transform_matrix']
            poses.append(c2w)

        imgs = (np.array(imgs) / 255.).astype(np.float32)  # (RGBA) imgs
        imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2)
        imgs = imgs * 2 - 1.  # convert to stable diffusion range
        poses = torch.tensor(np.array(poses).astype(np.float32))
        return imgs, poses

    def blend_rgba(self, img):
        img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:])  # blend A to RGB
        return img


class GSO(Dataset):
    def __init__(self, root_dir='datasets/GoogleScannedObjects',
                 split='val', first_K=5, resolution=256, load_target=False, name='render_mvs'):
        self.root_dir = root_dir
        with open(os.path.join(root_dir, '%s.json' % split), "r") as f:
            self.scene_list = json.load(f)
        self.resolution = resolution
        self.first_K = first_K
        self.load_target = load_target
        self.name = name

    def __len__(self):
        return len(self.scene_list)

    def __getitem__(self, idx):
        scene_dir = os.path.join(self.root_dir, self.scene_list[idx])
        with open(os.path.join(scene_dir, 'transforms_%s.json' % self.name), "r") as f:
            meta = json.load(f)
        imgs = []
        poses = []
        for i_img in range(self.first_K):
            meta_img = meta['frames'][i_img]

            if i_img == 0 or self.load_target:
                img_path = os.path.join(scene_dir, meta_img['file_path'])
                img = imageio.imread(img_path)
                img = cv2.resize(img, (self.resolution, self.resolution), interpolation=cv2.INTER_LINEAR)
                imgs.append(img)

            c2w = meta_img['transform_matrix']
            poses.append(c2w)

        imgs = (np.array(imgs) / 255.).astype(np.float32)  # (RGBA) imgs
        mask = imgs[:, :, :, -1]
        imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2)
        imgs = imgs * 2 - 1.  # convert to stable diffusion range
        poses = torch.tensor(np.array(poses).astype(np.float32))
        return imgs, poses

    def blend_rgba(self, img):
        img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:])  # blend A to RGB
        return img

class WILD(Dataset):
    def __init__(self, root_dir='data/nerf_wild',
                 first_K=33, resolution=256, load_target=False):
        self.root_dir = root_dir
        self.scene_list = sorted(next(os.walk(root_dir))[1])
        self.resolution = resolution
        self.first_K = first_K
        self.load_target = load_target

    def __len__(self):
        return len(self.scene_list)

    def __getitem__(self, idx):
        scene_dir = os.path.join(self.root_dir, self.scene_list[idx])
        with open(os.path.join(scene_dir, 'transforms_train.json'), "r") as f:
            meta = json.load(f)
        imgs = []
        poses = []
        for i_img in range(self.first_K):
            meta_img = meta['frames'][i_img]

            if i_img == 0 or self.load_target:
                img_path = os.path.join(scene_dir, meta_img['file_path'])
                img = imageio.imread(img_path + '.png')
                img = cv2.resize(img, (self.resolution, self.resolution), interpolation=cv2.INTER_LINEAR)
                imgs.append(img)

            c2w = meta_img['transform_matrix']
            poses.append(c2w)

        imgs = (np.array(imgs) / 255.).astype(np.float32)  # (RGBA) imgs
        imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2)
        imgs = imgs * 2 - 1.  # convert to stable diffusion range
        poses = torch.tensor(np.array(poses).astype(np.float32))
        return imgs, poses

    def blend_rgba(self, img):
        img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:])  # blend A to RGB
        return img
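All three loaders return the same (imgs, poses) pair; a sketch, assuming a GSO-style directory that contains a val.json split list and per-scene transforms_render_mvs.json files:

from ldm.data.nerf_like import GSO

ds = GSO(root_dir='datasets/GoogleScannedObjects', split='val', first_K=5, load_target=True)
imgs, poses = ds[0]
print(imgs.shape, poses.shape)  # (5, 3, 256, 256) in [-1, 1] and (5, 4, 4) camera-to-world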
ldm/data/simple.py
ADDED
@@ -0,0 +1,526 @@
from typing import Dict
import webdataset as wds
import numpy as np
from omegaconf import DictConfig, ListConfig
import torch
from torch.utils.data import Dataset
from pathlib import Path
import json
from PIL import Image
from torchvision import transforms
import torchvision
from einops import rearrange
from ldm.util import instantiate_from_config
from datasets import load_dataset
import pytorch_lightning as pl
import copy
import csv
import cv2
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import os, sys
import math
from torch.utils.data.distributed import DistributedSampler

# Some hacky things to make experimentation easier
def make_transform_multi_folder_data(paths, caption_files=None, **kwargs):
    ds = make_multi_folder_data(paths, caption_files, **kwargs)
    return TransformDataset(ds)

def make_nfp_data(base_path):
    dirs = list(Path(base_path).glob("*/"))
    print(f"Found {len(dirs)} folders")
    print(dirs)
    tforms = [transforms.Resize(512), transforms.CenterCrop(512)]
    datasets = [NfpDataset(x, image_transforms=copy.copy(tforms), default_caption="A view from a train window") for x in dirs]
    return torch.utils.data.ConcatDataset(datasets)


class VideoDataset(Dataset):
    def __init__(self, root_dir, image_transforms, caption_file, offset=8, n=2):
        self.root_dir = Path(root_dir)
        self.caption_file = caption_file
        self.n = n
        ext = "mp4"
        self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}")))
        self.offset = offset

        if isinstance(image_transforms, ListConfig):
            image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        image_transforms.extend([transforms.ToTensor(),
                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        image_transforms = transforms.Compose(image_transforms)
        self.tform = image_transforms
        with open(self.caption_file) as f:
            reader = csv.reader(f)
            rows = [row for row in reader]
            self.captions = dict(rows)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        for i in range(10):
            try:
                return self._load_sample(index)
            except Exception:
                # Not really good enough but...
                print("uh oh")

    def _load_sample(self, index):
        n = self.n
        filename = self.paths[index]
        min_frame = 2 * self.offset + 2
        vid = cv2.VideoCapture(str(filename))
        max_frames = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
        curr_frame_n = random.randint(min_frame, max_frames)
        vid.set(cv2.CAP_PROP_POS_FRAMES, curr_frame_n)
        _, curr_frame = vid.read()

        prev_frames = []
        for i in range(n):
            prev_frame_n = curr_frame_n - (i + 1) * self.offset
            vid.set(cv2.CAP_PROP_POS_FRAMES, prev_frame_n)
            _, prev_frame = vid.read()
            prev_frame = self.tform(Image.fromarray(prev_frame[..., ::-1]))
            prev_frames.append(prev_frame)

        vid.release()
        caption = self.captions[filename.name]
        data = {
            "image": self.tform(Image.fromarray(curr_frame[..., ::-1])),
            "prev": torch.cat(prev_frames, dim=-1),
            "txt": caption
        }
        return data

# end hacky things
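A sketch of driving the VideoDataset above (the folder and caption CSV are hypothetical; the CSV maps each mp4 filename to its caption):

from torchvision import transforms
from ldm.data.simple import VideoDataset

ds = VideoDataset(
    root_dir="data/videos",                    # hypothetical folder of *.mp4 files
    image_transforms=[transforms.Resize(256), transforms.CenterCrop(256)],
    caption_file="data/videos/captions.csv",   # rows of "name.mp4,caption"
    offset=8, n=2,
)
sample = ds[0]
print(sample["image"].shape, sample["prev"].shape, sample["txt"])
# (256, 256, 3) current frame; (256, 256, 6) two earlier frames stacked on channels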
def make_tranforms(image_transforms):
    # if isinstance(image_transforms, ListConfig):
    #     image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
    image_transforms = []
    image_transforms.extend([transforms.ToTensor(),
                             transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
    image_transforms = transforms.Compose(image_transforms)
    return image_transforms


def make_multi_folder_data(paths, caption_files=None, **kwargs):
    """Make a concat dataset from multiple folders
    Doesn't support captions yet

    If paths is a list, that's ok, if it's a Dict interpret it as:
    k=folder v=n_times to repeat that
    """
    list_of_paths = []
    if isinstance(paths, (Dict, DictConfig)):
        assert caption_files is None, \
            "Caption files not yet supported for repeats"
        for folder_path, repeats in paths.items():
            list_of_paths.extend([folder_path] * repeats)
        paths = list_of_paths

    if caption_files is not None:
        datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)]
    else:
        datasets = [FolderData(p, **kwargs) for p in paths]
    return torch.utils.data.ConcatDataset(datasets)


class NfpDataset(Dataset):
    def __init__(self,
                 root_dir,
                 image_transforms=[],
                 ext="jpg",
                 default_caption="",
                 ) -> None:
        """assume sequential frames and a deterministic transform"""

        self.root_dir = Path(root_dir)
        self.default_caption = default_caption

        self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}")))
        self.tform = make_tranforms(image_transforms)

    def __len__(self):
        return len(self.paths) - 1

    def __getitem__(self, index):
        prev = self.paths[index]
        curr = self.paths[index + 1]
        data = {}
        data["image"] = self._load_im(curr)
        data["prev"] = self._load_im(prev)
        data["txt"] = self.default_caption
        return data

    def _load_im(self, filename):
        im = Image.open(filename).convert("RGB")
        return self.tform(im)

class ObjaverseDataModuleFromConfig(pl.LightningDataModule):
    def __init__(self, root_dir, batch_size, total_view, train=None, validation=None,
                 test=None, num_workers=4, **kwargs):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.total_view = total_view

        if train is not None:
            dataset_config = train
        if validation is not None:
            dataset_config = validation

        if 'image_transforms' in dataset_config:
            image_transforms = [torchvision.transforms.Resize(dataset_config.image_transforms.size)]
        else:
            image_transforms = []
        image_transforms.extend([transforms.ToTensor(),
                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        self.image_transforms = torchvision.transforms.Compose(image_transforms)

    def train_dataloader(self):
        dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False,
                                image_transforms=self.image_transforms)
        sampler = DistributedSampler(dataset)
        return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)

    def val_dataloader(self):
        dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True,
                                image_transforms=self.image_transforms)
        sampler = DistributedSampler(dataset)
        return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        # note: relies on a `validation` attribute that is not assigned in __init__
        return wds.WebLoader(ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=self.validation),
                             batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)


class ObjaverseData(Dataset):
    def __init__(self,
                 root_dir='.objaverse/hf-objaverse-v1/views',
                 image_transforms=[],
                 ext="png",
                 default_trans=torch.zeros(3),
                 postprocess=None,
                 return_paths=False,
                 total_view=4,
                 validation=False
                 ) -> None:
        """Create a dataset from a folder of images.
        If you pass in a root directory it will be searched for images
        ending in ext (ext can be a list)
        """
        self.root_dir = Path(root_dir)
        self.default_trans = default_trans
        self.return_paths = return_paths
        if isinstance(postprocess, DictConfig):
            postprocess = instantiate_from_config(postprocess)
        self.postprocess = postprocess
        self.total_view = total_view

        if not isinstance(ext, (tuple, list, ListConfig)):
            ext = [ext]

        with open(os.path.join(root_dir, 'valid_paths.json')) as f:
            self.paths = json.load(f)

        total_objects = len(self.paths)
        if validation:
            self.paths = self.paths[math.floor(total_objects / 100. * 99.):]  # used last 1% as validation
        else:
            self.paths = self.paths[:math.floor(total_objects / 100. * 99.)]  # used first 99% as training
        print('============= length of dataset %d =============' % len(self.paths))
        self.tform = image_transforms

    def __len__(self):
        return len(self.paths)

    def cartesian_to_spherical(self, xyz):
        ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
        xy = xyz[:, 0]**2 + xyz[:, 1]**2
        z = np.sqrt(xy + xyz[:, 2]**2)
        theta = np.arctan2(np.sqrt(xy), xyz[:, 2])  # for elevation angle defined from Z-axis down
        # ptsnew[:, 4] = np.arctan2(xyz[:, 2], np.sqrt(xy))  # for elevation angle defined from XY-plane up
        azimuth = np.arctan2(xyz[:, 1], xyz[:, 0])
        return np.array([theta, azimuth, z])

    def get_T(self, target_RT, cond_RT):
        R, T = target_RT[:3, :3], target_RT[:, -1]
        T_target = -R.T @ T

        R, T = cond_RT[:3, :3], cond_RT[:, -1]
        T_cond = -R.T @ T

        theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
        theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])

        d_theta = theta_target - theta_cond
        d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
        d_z = z_target - z_cond

        d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
        return d_T

    def load_im(self, path, color):
        '''
        replace background pixel with random color in rendering
        '''
        try:
            img = plt.imread(path)
        except:
            print(path)
            sys.exit()
        img[img[:, :, -1] == 0.] = color
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        return img

    def __getitem__(self, index):

        data = {}
        if self.paths[index][-2:] == '_1':  # dirty fix for rendering dataset twice
            total_view = 8
        else:
            total_view = 4
        index_target, index_cond = random.sample(range(total_view), 2)  # without replacement
        filename = os.path.join(self.root_dir, self.paths[index])

        # print(self.paths[index])

        if self.return_paths:
            data["path"] = str(filename)

        color = [1., 1., 1., 1.]

        try:
            target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
            cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond), color))
            target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
            cond_RT = np.load(os.path.join(filename, '%03d.npy' % index_cond))
        except:
            # very hacky solution, sorry about this
            filename = os.path.join(self.root_dir, '692db5f2d3a04bb286cb977a7dba903e_1')  # this one we know is valid
            target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
            cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond), color))
            target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
            cond_RT = np.load(os.path.join(filename, '%03d.npy' % index_cond))
            target_im = torch.zeros_like(target_im)
            cond_im = torch.zeros_like(cond_im)

        data["image_target"] = target_im
        data["image_cond"] = cond_im
        data["T"] = self.get_T(target_RT, cond_RT)

        if self.postprocess is not None:
            data = self.postprocess(data)

        return data

    def process_im(self, im):
        im = im.convert("RGB")
        return self.tform(im)

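A sketch of using ObjaverseData directly (the paths are illustrative; valid_paths.json must list the rendered object folders). The class applies self.tform as-is, so the transform pipeline is passed in complete:

from torchvision import transforms
from einops import rearrange
from ldm.data.simple import ObjaverseData

tform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c')),
])
ds = ObjaverseData(root_dir='.objaverse/hf-objaverse-v1/views', image_transforms=tform)
item = ds[0]
print(item["image_target"].shape, item["image_cond"].shape, item["T"])
# two (256, 256, 3) views of the same object plus their relative camera offset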
class FolderData(Dataset):
    def __init__(self,
                 root_dir,
                 caption_file=None,
                 image_transforms=[],
                 ext="jpg",
                 default_caption="",
                 postprocess=None,
                 return_paths=False,
                 ) -> None:
        """Create a dataset from a folder of images.
        If you pass in a root directory it will be searched for images
        ending in ext (ext can be a list)
        """
        self.root_dir = Path(root_dir)
        self.default_caption = default_caption
        self.return_paths = return_paths
        if isinstance(postprocess, DictConfig):
            postprocess = instantiate_from_config(postprocess)
        self.postprocess = postprocess
        if caption_file is not None:
            with open(caption_file, "rt") as f:
                ext = Path(caption_file).suffix.lower()
                if ext == ".json":
                    captions = json.load(f)
                elif ext == ".jsonl":
                    lines = f.readlines()
                    lines = [json.loads(x) for x in lines]
                    captions = {x["file_name"]: x["text"].strip("\n") for x in lines}
                else:
                    raise ValueError(f"Unrecognised format: {ext}")
            self.captions = captions
        else:
            self.captions = None

        if not isinstance(ext, (tuple, list, ListConfig)):
            ext = [ext]

        # Only used if there is no caption file
        self.paths = []
        for e in ext:
            self.paths.extend(sorted(list(self.root_dir.rglob(f"*.{e}"))))
        self.tform = make_tranforms(image_transforms)

    def __len__(self):
        if self.captions is not None:
            return len(self.captions.keys())
        else:
            return len(self.paths)

    def __getitem__(self, index):
        data = {}
        if self.captions is not None:
            chosen = list(self.captions.keys())[index]
            caption = self.captions.get(chosen, None)
            if caption is None:
                caption = self.default_caption
            filename = self.root_dir / chosen
        else:
            filename = self.paths[index]

        if self.return_paths:
            data["path"] = str(filename)

        im = Image.open(filename).convert("RGB")
        im = self.process_im(im)
        data["image"] = im

        if self.captions is not None:
            data["txt"] = caption
        else:
            data["txt"] = self.default_caption

        if self.postprocess is not None:
            data = self.postprocess(data)

        return data

    def process_im(self, im):
        im = im.convert("RGB")
        return self.tform(im)

import random

class TransformDataset():
    def __init__(self, ds, extra_label="sksbspic"):
        self.ds = ds
        self.extra_label = extra_label
        self.transforms = {
            "align": transforms.Resize(768),
            "centerzoom": transforms.CenterCrop(768),
            "randzoom": transforms.RandomCrop(768),
        }

    def __getitem__(self, index):
        data = self.ds[index]

        im = data['image']
        im = im.permute(2, 0, 1)
        # In case data is smaller than expected
        im = transforms.Resize(1024)(im)

        tform_name = random.choice(list(self.transforms.keys()))
        im = self.transforms[tform_name](im)

        im = im.permute(1, 2, 0)

        data['image'] = im
        data['txt'] = data['txt'] + f" {self.extra_label} {tform_name}"

        return data

    def __len__(self):
        return len(self.ds)

def hf_dataset(
    name,
    image_transforms=[],
    image_column="image",
    text_column="text",
    split='train',
    image_key='image',
    caption_key='txt',
    ):
    """Make huggingface dataset with appropriate list of transforms applied
    """
    ds = load_dataset(name, split=split)
    tform = make_tranforms(image_transforms)

    assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
    assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"

    def pre_process(examples):
        processed = {}
        processed[image_key] = [tform(im) for im in examples[image_column]]
        processed[caption_key] = examples[text_column]
        return processed

    ds.set_transform(pre_process)
    return ds

class TextOnly(Dataset):
    def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
        """Returns only captions with dummy images"""
        self.output_size = output_size
        self.image_key = image_key
        self.caption_key = caption_key
        if isinstance(captions, Path):
            self.captions = self._load_caption_file(captions)
        else:
            self.captions = captions

        if n_gpus > 1:
            # hack to make sure that all the captions appear on each gpu
            repeated = [n_gpus * [x] for x in self.captions]
            self.captions = []
            [self.captions.extend(x) for x in repeated]

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, index):
        dummy_im = torch.zeros(3, self.output_size, self.output_size)
        dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
        return {self.image_key: dummy_im, self.caption_key: self.captions[index]}

    def _load_caption_file(self, filename):
        with open(filename, 'rt') as f:
            captions = f.readlines()
        return [x.strip('\n') for x in captions]


import random
import json
class IdRetreivalDataset(FolderData):
    def __init__(self, ret_file, *args, **kwargs):
        super().__init__(*args, **kwargs)
        with open(ret_file, "rt") as f:
            self.ret = json.load(f)

    def __getitem__(self, index):
        data = super().__getitem__(index)
        key = self.paths[index].name
        matches = self.ret[key]
        if len(matches) > 0:
            retreived = random.choice(matches)
        else:
            retreived = key
        filename = self.root_dir / retreived
        im = Image.open(filename).convert("RGB")
        im = self.process_im(im)
        # data["match"] = im
        data["match"] = torch.cat((data["image"], im), dim=-1)
        return data
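FolderData is the generic image+caption loader; a sketch with a hypothetical jsonl caption file (one {"file_name": ..., "text": ...} object per line). Note make_tranforms only appends ToTensor and the [-1, 1] rescale, so images keep their native resolution here:

from ldm.data.simple import FolderData

ds = FolderData(
    root_dir="data/my_images",                # hypothetical image folder
    caption_file="data/my_images/meta.jsonl",
    image_transforms=[],
)
item = ds[0]
print(item["image"].shape, item["txt"])  # (H, W, 3) in [-1, 1] plus the caption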
ldm/extras.py
ADDED
@@ -0,0 +1,77 @@
from pathlib import Path
from omegaconf import OmegaConf
import torch
from ldm.util import instantiate_from_config
import logging
from contextlib import contextmanager


@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    """
    A context manager that will prevent any logging messages
    triggered during the body from being processed.

    :param highest_level: the maximum logging level in use.
      This would only need to be changed if a custom level greater than CRITICAL
      is defined.

    https://gist.github.com/simon-weber/7853144
    """
    # two kind-of hacks here:
    # * can't get the highest logging level in effect => delegate to the user
    # * can't get the current module-level override => use an undocumented
    #   (but non-private!) interface

    previous_level = logging.root.manager.disable

    logging.disable(highest_level)

    try:
        yield
    finally:
        logging.disable(previous_level)

def load_training_dir(train_dir, device, epoch="last"):
    """Load a checkpoint and config from training directory"""
    train_dir = Path(train_dir)
    ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
    assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
    config = list(train_dir.rglob("*-project.yaml"))
    assert len(config) > 0, f"didn't find any config in {train_dir}"
    if len(config) > 1:
        print(f"found {len(config)} matching config files")
        config = sorted(config)[-1]
        print(f"selecting {config}")
    else:
        config = config[0]

    config = OmegaConf.load(config)
    return load_model_from_config(config, ckpt[0], device)

def load_model_from_config(config, ckpt, device="cpu", verbose=False):
    """Loads a model from config and a ckpt
    if config is a path will use omegaconf to load
    """
    if isinstance(config, (str, Path)):
        config = OmegaConf.load(config)

    with all_logging_disabled():
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
        sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)
        model.to(device)
        model.eval()
        model.cond_stage_model.device = device
        return model
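A sketch of the intended entry point (the training directory is a placeholder):

import torch
from ldm.extras import load_training_dir

device = "cuda" if torch.cuda.is_available() else "cpu"
# expects exactly one *last.ckpt and at least one *-project.yaml under the directory
model = load_training_dir("logs/my_finetune_run", device, epoch="last")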
ldm/guidance.py
ADDED
@@ -0,0 +1,96 @@
from typing import List, Tuple
from scipy import interpolate
import numpy as np
import torch
import matplotlib.pyplot as plt
from IPython.display import clear_output
import abc


class GuideModel(torch.nn.Module, abc.ABC):
    def __init__(self) -> None:
        super().__init__()

    @abc.abstractmethod
    def preprocess(self, x_img):
        pass

    @abc.abstractmethod
    def compute_loss(self, inp):
        pass


class Guider(torch.nn.Module):
    def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
        """Apply classifier guidance

        Specify a guidance scale as either a scalar
        Or a schedule as a list of tuples t = 0->1 and scale, e.g.
        [(0, 10), (0.5, 20), (1, 50)]
        """
        super().__init__()
        self.sampler = sampler
        self.index = 0
        self.show = verbose
        self.guide_model = guide_model
        self.history = []

        if isinstance(scale, (Tuple, List)):
            times = np.array([x[0] for x in scale])
            values = np.array([x[1] for x in scale])
            self.scale_schedule = {"times": times, "values": values}
        else:
            self.scale_schedule = float(scale)

        self.ddim_timesteps = sampler.ddim_timesteps
        self.ddpm_num_timesteps = sampler.ddpm_num_timesteps

    def get_scales(self):
        if isinstance(self.scale_schedule, float):
            return len(self.ddim_timesteps) * [self.scale_schedule]

        interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"])
        fractional_steps = np.array(self.ddim_timesteps) / self.ddpm_num_timesteps
        return interpolater(fractional_steps)

    def modify_score(self, model, e_t, x, t, c):

        # TODO look up index by t
        scale = self.get_scales()[self.index]

        if (scale == 0):
            return e_t

        sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
            x_img = model.first_stage_model.decode((1 / 0.18215) * pred_x0)

            inp = self.guide_model.preprocess(x_img)
            loss = self.guide_model.compute_loss(inp)
            grads = torch.autograd.grad(loss.sum(), x_in)[0]
            correction = grads * scale

        if self.show:
            clear_output(wait=True)
            print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item())
            self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()])
            plt.imshow((inp[0].detach().permute(1, 2, 0).clamp(-1, 1).cpu() + 1) / 2)
            plt.axis('off')
            plt.show()
            plt.imshow(correction[0][0].detach().cpu())
            plt.axis('off')
            plt.show()

        e_t_mod = e_t - sqrt_1ma * correction
        if self.show:
            fig, axs = plt.subplots(1, 3)
            axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
            axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
            axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
            plt.show()
        self.index += 1
        return e_t_mod
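A sketch of the guidance hook above with a toy guide model. The sampler here is a stand-in namespace; in practice it would be a DDIM sampler that calls modify_score at each step:

import numpy as np
import torch
from types import SimpleNamespace
from ldm.guidance import GuideModel, Guider

class BrightnessGuide(GuideModel):
    """Toy guide: reward brighter decoded images."""
    def preprocess(self, x_img):
        return x_img
    def compute_loss(self, inp):
        return inp.mean()

sampler = SimpleNamespace(ddim_timesteps=np.arange(0, 1000, 20), ddpm_num_timesteps=1000)
guider = Guider(sampler, BrightnessGuide(), scale=[(0.0, 10.0), (1.0, 0.0)])  # decay guidance over time
print(guider.get_scales()[:5])  # interpolated guidance scale per DDIM step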
ldm/lr_scheduler.py
ADDED
@@ -0,0 +1,98 @@
import numpy as np


class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        else:
            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                    1 + np.cos(t * np.pi))
            self.last_lr = lr
            return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """
    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        interval = 0
        for cl in self.cum_cycles[1:]:
            if n <= cl:
                return interval
            interval += 1

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                      f"current cycle {cycle}")
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
                    1 + np.cos(t * np.pi))
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                      f"current cycle {cycle}")

        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
            self.last_f = f
            return f
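These schedulers return a multiplier for a base_lr of 1.0 and plug straight into torch.optim.lr_scheduler.LambdaLR; a minimal sketch with one cycle:

import torch
from ldm.lr_scheduler import LambdaLinearScheduler

sched_fn = LambdaLinearScheduler(warm_up_steps=[100], f_min=[1.0], f_max=[10.0],
                                 f_start=[1e-6], cycle_lengths=[10000])
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=sched_fn.schedule)
for step in range(300):
    opt.step()
    scheduler.step()
print(opt.param_groups[0]["lr"])  # 1e-4 times the multiplier after 300 steps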